chore: Fix golangci-lint configuration and patch errors (#34)

* chore: Fix golangci-lint configuration and patch errors

Due to misconfiguration of a linting rules directory, our linter has not been
working properly. This change fixes the configuration issue, and all remaining
linting errors.

* Fix race in peer logging

* Fix race and return

* Lock on bufferred amount low

* Fix mutex lock
This commit is contained in:
Kyle Carberry 2022-01-20 10:00:13 -06:00 committed by GitHub
parent 6a919aea79
commit 2654a93132
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 283 additions and 255 deletions

View File

@ -100,10 +100,6 @@ linters-settings:
# - whyNoLint
# - wrapperFunc
# - yodaStyleExpr
settings:
ruleguard:
failOn: all
rules: "${configDir}/lib/go/lintrules/*.go"
goimports:
local-prefixes: coder.com,cdr.dev,go.coder.com,github.com/cdr,github.com/coder
@ -113,24 +109,6 @@ linters-settings:
importas:
no-unaliased: true
alias:
- pkg: k8s.io/api/(\w+)/(v[\w\d]+)
alias: ${1}${2}
- pkg: k8s.io/apimachinery/pkg/apis/meta/(v[\w\d]+)
alias: meta${1}
- pkg: k8s.io/client-go/kubernetes/typed/(\w+)/(v[\w\d]+)
alias: ${1}${2}client
- pkg: k8s.io/metrics/pkg/apis/metrics/(v[\w\d]+)
alias: metrics${1}
- pkg: github.com/docker/docker/api/types
alias: dockertypes
- pkg: github.com/docker/docker/client
alias: dockerclient
misspell:
locale: US
@ -195,6 +173,20 @@ linters-settings:
- name: var-declaration
- name: var-naming
- name: waitgroup-by-value
varnamelen:
ignore-names:
- err
- rw
- r
- i
- db
# Optional list of variable declarations that should be ignored completely. (defaults to empty list)
# Entries must be in the form of "<variable name> <type>" or "<variable name> *<type>" for
# variables, or "const <name>" for constants.
ignore-decls:
- rw http.ResponseWriter
- r *http.Request
- t testing.T
issues:
# Rules listed here: https://github.com/securego/gosec#available-rules
@ -222,7 +214,6 @@ linters:
- asciicheck
- bidichk
- bodyclose
- contextcheck
- deadcode
- dogsled
- errcheck
@ -239,7 +230,6 @@ linters:
- govet
- importas
- ineffassign
# - ireturn
- makezero
- misspell
- nilnil

View File

@ -3,5 +3,5 @@ package main
import "fmt"
func main() {
fmt.Println("Hello World!")
_, _ = fmt.Println("Hello World!")
}

View File

@ -39,7 +39,7 @@ func New(options *Options) http.Handler {
httpmw.ExtractAPIKey(options.Database, nil),
httpmw.ExtractUser(options.Database),
)
r.Get("/user", users.getAuthenticatedUser)
r.Get("/user", users.authenticatedUser)
})
})
r.NotFound(site.Handler().ServeHTTP)

View File

@ -32,11 +32,11 @@ func New(t *testing.T) Server {
Database: db,
})
srv := httptest.NewServer(handler)
u, err := url.Parse(srv.URL)
serverURL, err := url.Parse(srv.URL)
require.NoError(t, err)
t.Cleanup(srv.Close)
client := codersdk.New(u)
client := codersdk.New(serverURL)
_, err = client.CreateInitialUser(context.Background(), coderd.CreateUserRequest{
Email: "testuser@coder.com",
Username: "testuser",
@ -54,6 +54,6 @@ func New(t *testing.T) Server {
return Server{
Client: client,
URL: u,
URL: serverURL,
}
}

View File

@ -35,14 +35,14 @@ func Compare(hashed string, password string) (bool, error) {
if len(parts[0]) != 0 {
return false, xerrors.Errorf("hash prefix is invalid")
}
if string(parts[1]) != hashScheme {
if parts[1] != hashScheme {
return false, xerrors.Errorf("hash isn't %q scheme: %q", hashScheme, parts[1])
}
iter, err := strconv.Atoi(string(parts[2]))
iter, err := strconv.Atoi(parts[2])
if err != nil {
return false, xerrors.Errorf("parse iter from hash: %w", err)
}
salt, err := base64.RawStdEncoding.DecodeString(string(parts[3]))
salt, err := base64.RawStdEncoding.DecodeString(parts[3])
if err != nil {
return false, xerrors.Errorf("decode salt: %w", err)
}

View File

@ -70,7 +70,7 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) {
})
return
}
user, err := users.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
_, err = users.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
Email: createUser.Email,
Username: createUser.Username,
})
@ -91,7 +91,7 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) {
return
}
user, err = users.Database.InsertUser(context.Background(), database.InsertUserParams{
user, err := users.Database.InsertUser(context.Background(), database.InsertUserParams{
ID: uuid.NewString(),
Email: createUser.Email,
HashedPassword: []byte(hashedPassword),
@ -111,7 +111,7 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) {
}
// Returns the currently authenticated user.
func (users *users) getAuthenticatedUser(rw http.ResponseWriter, r *http.Request) {
func (*users) authenticatedUser(rw http.ResponseWriter, r *http.Request) {
user := httpmw.User(r)
render.JSON(rw, r, User{
@ -158,11 +158,17 @@ func (users *users) loginWithPassword(rw http.ResponseWriter, r *http.Request) {
return
}
id, secret, err := generateAPIKeyIDSecret()
hashed := sha256.Sum256([]byte(secret))
keyID, keySecret, err := generateAPIKeyIDSecret()
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("generate api key parts: %s", err.Error()),
})
return
}
hashed := sha256.Sum256([]byte(keySecret))
_, err = users.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
ID: keyID,
UserID: user.ID,
ExpiresAt: database.Now().Add(24 * time.Hour),
CreatedAt: database.Now(),
@ -178,7 +184,7 @@ func (users *users) loginWithPassword(rw http.ResponseWriter, r *http.Request) {
}
// This format is consumed by the APIKey middleware.
sessionToken := fmt.Sprintf("%s-%s", id, secret)
sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret)
http.SetCookie(rw, &http.Cookie{
Name: httpmw.AuthCookie,
Value: sessionToken,
@ -194,14 +200,14 @@ func (users *users) loginWithPassword(rw http.ResponseWriter, r *http.Request) {
}
// Generates a new ID and secret for an API key.
func generateAPIKeyIDSecret() (string, string, error) {
func generateAPIKeyIDSecret() (id string, secret string, err error) {
// Length of an API Key ID.
id, err := cryptorand.String(10)
id, err = cryptorand.String(10)
if err != nil {
return "", "", err
}
// Length of an API Key secret.
secret, err := cryptorand.String(22)
secret, err = cryptorand.String(22)
if err != nil {
return "", "", err
}

View File

@ -4,9 +4,10 @@ import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/coderdtest"
"github.com/stretchr/testify/require"
)
func TestUsers(t *testing.T) {

View File

@ -18,9 +18,9 @@ import (
)
// New creates a Coder client for the provided URL.
func New(url *url.URL) *Client {
func New(serverURL *url.URL) *Client {
return &Client{
url: url,
url: serverURL,
httpClient: &http.Client{},
}
}
@ -50,7 +50,7 @@ func (c *Client) SetSessionToken(token string) error {
// request performs an HTTP request with the body provided.
// The caller is responsible for closing the response body.
func (c *Client) request(ctx context.Context, method, path string, body interface{}) (*http.Response, error) {
url, err := c.url.Parse(path)
serverURL, err := c.url.Parse(path)
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
@ -65,7 +65,7 @@ func (c *Client) request(ctx context.Context, method, path string, body interfac
}
}
req, err := http.NewRequestWithContext(ctx, method, url.String(), &buf)
req, err := http.NewRequestWithContext(ctx, method, serverURL.String(), &buf)
if err != nil {
return nil, xerrors.Errorf("create request: %w", err)
}
@ -81,7 +81,7 @@ func (c *Client) request(ctx context.Context, method, path string, body interfac
}
// readBodyAsError reads the response as an httpapi.Message, and
// wraps it in a codersdk.Error type for easy marshalling.
// wraps it in a codersdk.Error type for easy marshaling.
func readBodyAsError(res *http.Response) error {
var m httpapi.Response
err := json.NewDecoder(res.Body).Decode(&m)

View File

@ -26,7 +26,7 @@ func (c *Client) CreateInitialUser(ctx context.Context, req coderd.CreateUserReq
// User returns a user for the ID provided.
// If the ID string is empty, the current user will be returned.
func (c *Client) User(ctx context.Context, id string) (coderd.User, error) {
func (c *Client) User(ctx context.Context, _ string) (coderd.User, error) {
res, err := c.request(ctx, http.MethodGet, "/api/v2/user", nil)
if err != nil {
return coderd.User{}, err

View File

@ -5,10 +5,11 @@ import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/stretchr/testify/require"
)
func TestUsers(t *testing.T) {

View File

@ -73,42 +73,43 @@ func Int() (int, error) {
return int(i), nil
}
// Int63n returns a non-negative random integer in [0,n) as a int64.
func Int63n(n int64) (int64, error) {
if n <= 0 {
// Int63n returns a non-negative random integer in [0,max) as a int64.
func Int63n(max int64) (int64, error) {
if max <= 0 {
panic("invalid argument to Int63n")
}
max := int64((1 << 63) - 1 - (1<<63)%uint64(n))
trueMax := int64((1 << 63) - 1 - (1<<63)%uint64(max))
i, err := Int63()
if err != nil {
return 0, err
}
for i > max {
for i > trueMax {
i, err = Int63()
if err != nil {
return 0, err
}
}
return i % n, nil
return i % max, nil
}
// Int31n returns a non-negative integer in [0,n) as a int32.
func Int31n(n int32) (int32, error) {
// Int31n returns a non-negative integer in [0,max) as a int32.
func Int31n(max int32) (int32, error) {
i, err := Uint32()
if err != nil {
return 0, err
}
return UnbiasedModulo32(i, n)
return UnbiasedModulo32(i, max)
}
// UnbiasedModulo32 uniformly modulos v by n over a sufficiently large data
// set, regenerating v if necessary. n must be > 0. All input bits in v must be
// fully random, you cannot cast a random uint8/uint16 for input into this
// function.
//nolint:varnamelen
func UnbiasedModulo32(v uint32, n int32) (int32, error) {
prod := uint64(v) * uint64(n)
low := uint32(prod)
@ -127,14 +128,14 @@ func UnbiasedModulo32(v uint32, n int32) (int32, error) {
return int32(prod >> 32), nil
}
// Intn returns a non-negative integer in [0,n) as a int.
func Intn(n int) (int, error) {
if n <= 0 {
// Intn returns a non-negative integer in [0,max) as a int.
func Intn(max int) (int, error) {
if max <= 0 {
panic("n must be a positive nonzero number")
}
if n <= 1<<31-1 {
i, err := Int31n(int32(n))
if max <= 1<<31-1 {
i, err := Int31n(int32(max))
if err != nil {
return 0, err
}
@ -142,7 +143,7 @@ func Intn(n int) (int, error) {
return int(i), nil
}
i, err := Int63n(int64(n))
i, err := Int63n(int64(max))
if err != nil {
return 0, err
}

View File

@ -5,8 +5,9 @@ import (
"encoding/binary"
"testing"
"github.com/coder/coder/cryptorand"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cryptorand"
)
func TestInt63(t *testing.T) {
@ -144,7 +145,7 @@ func TestBool(t *testing.T) {
const iterations = 10000
trueCount := 0
for i := 0; i < iterations; i += 1 {
for i := 0; i < iterations; i++ {
v, err := cryptorand.Bool()
require.NoError(t, err, "unexpected error from Bool")
if v {

View File

@ -53,7 +53,7 @@ func StringCharset(charSetStr string, size int) (string, error) {
buf.Grow(size)
for i := 0; i < size; i++ {
c, err := UnbiasedModulo32(
count, err := UnbiasedModulo32(
binary.BigEndian.Uint32(ibuf[i*4:(i+1)*4]),
int32(len(charSet)),
)
@ -61,7 +61,7 @@ func StringCharset(charSetStr string, size int) (string, error) {
return "", err
}
_, _ = buf.WriteRune(charSet[c])
_, _ = buf.WriteRune(charSet[count])
}
return buf.String(), nil

View File

@ -8,8 +8,9 @@ import (
"testing"
"unicode/utf8"
"github.com/coder/coder/cryptorand"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cryptorand"
)
func TestString(t *testing.T) {

View File

@ -22,11 +22,11 @@ type fakeQuerier struct {
}
// InTx doesn't rollback data properly for in-memory yet.
func (q *fakeQuerier) InTx(ctx context.Context, fn func(database.Store) error) error {
func (q *fakeQuerier) InTx(fn func(database.Store) error) error {
return fn(q)
}
func (q *fakeQuerier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) {
func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) {
for _, apiKey := range q.apiKeys {
if apiKey.ID == id {
return apiKey, nil
@ -35,7 +35,7 @@ func (q *fakeQuerier) GetAPIKeyByID(ctx context.Context, id string) (database.AP
return database.APIKey{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetUserByEmailOrUsername(ctx context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) {
func (q *fakeQuerier) GetUserByEmailOrUsername(_ context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) {
for _, user := range q.users {
if user.Email == arg.Email || user.Username == arg.Username {
return user, nil
@ -44,7 +44,7 @@ func (q *fakeQuerier) GetUserByEmailOrUsername(ctx context.Context, arg database
return database.User{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetUserByID(ctx context.Context, id string) (database.User, error) {
func (q *fakeQuerier) GetUserByID(_ context.Context, id string) (database.User, error) {
for _, user := range q.users {
if user.ID == id {
return user, nil
@ -53,11 +53,12 @@ func (q *fakeQuerier) GetUserByID(ctx context.Context, id string) (database.User
return database.User{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetUserCount(ctx context.Context) (int64, error) {
func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) {
return int64(len(q.users)), nil
}
func (q *fakeQuerier) InsertAPIKey(ctx context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) {
func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) {
//nolint:gosimple
key := database.APIKey{
ID: arg.ID,
HashedSecret: arg.HashedSecret,
@ -79,7 +80,7 @@ func (q *fakeQuerier) InsertAPIKey(ctx context.Context, arg database.InsertAPIKe
return key, nil
}
func (q *fakeQuerier) InsertUser(ctx context.Context, arg database.InsertUserParams) (database.User, error) {
func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, error) {
user := database.User{
ID: arg.ID,
Email: arg.Email,
@ -94,7 +95,7 @@ func (q *fakeQuerier) InsertUser(ctx context.Context, arg database.InsertUserPar
return user, nil
}
func (q *fakeQuerier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPIKeyByIDParams) error {
for index, apiKey := range q.apiKeys {
if apiKey.ID != arg.ID {
continue

View File

@ -21,7 +21,7 @@ import (
type Store interface {
querier
InTx(context.Context, func(Store) error) error
InTx(func(Store) error) error
}
// DBTX represents a database connection or transaction.
@ -46,16 +46,16 @@ type sqlQuerier struct {
}
// InTx performs database operations inside a transaction.
func (q *sqlQuerier) InTx(ctx context.Context, fn func(Store) error) error {
func (q *sqlQuerier) InTx(function func(Store) error) error {
if q.sdb == nil {
return nil
}
tx, err := q.sdb.Begin()
transaction, err := q.sdb.Begin()
if err != nil {
return xerrors.Errorf("begin transaction: %w", err)
}
defer func() {
rerr := tx.Rollback()
rerr := transaction.Rollback()
if rerr == nil || errors.Is(rerr, sql.ErrTxDone) {
// no need to do anything, tx committed successfully
return
@ -63,11 +63,11 @@ func (q *sqlQuerier) InTx(ctx context.Context, fn func(Store) error) error {
// couldn't roll back for some reason, extend returned error
err = xerrors.Errorf("defer (%s): %w", rerr.Error(), err)
}()
err = fn(&sqlQuerier{db: tx})
err = function(&sqlQuerier{db: transaction})
if err != nil {
return xerrors.Errorf("execute transaction: %w", err)
}
err = tx.Commit()
err = transaction.Commit()
if err != nil {
return xerrors.Errorf("commit transaction: %w", err)
}

View File

@ -2,7 +2,6 @@ package main
import (
"bytes"
"context"
"database/sql"
"fmt"
"io/ioutil"
@ -25,7 +24,7 @@ func main() {
if err != nil {
panic(err)
}
err = database.Migrate(context.Background(), db)
err = database.Migrate(db)
if err != nil {
panic(err)
}
@ -82,7 +81,7 @@ func main() {
if !ok {
panic("couldn't get caller path")
}
err = ioutil.WriteFile(filepath.Join(mainPath, "..", "..", "dump.sql"), []byte(dump), 0644)
err = ioutil.WriteFile(filepath.Join(mainPath, "..", "..", "dump.sql"), []byte(dump), 0600)
if err != nil {
panic(err)
}

View File

@ -1,7 +1,6 @@
package database
import (
"context"
"database/sql"
"embed"
"errors"
@ -16,7 +15,7 @@ import (
var migrations embed.FS
// Migrate runs SQL migrations to ensure the database schema is up-to-date.
func Migrate(ctx context.Context, db *sql.DB) error {
func Migrate(db *sql.DB) error {
sourceDriver, err := iofs.New(migrations, "migrations")
if err != nil {
return xerrors.Errorf("create iofs: %w", err)

View File

@ -3,7 +3,6 @@
package database_test
import (
"context"
"database/sql"
"testing"
@ -27,6 +26,6 @@ func TestMigrate(t *testing.T) {
db, err := sql.Open("postgres", connection)
require.NoError(t, err)
defer db.Close()
err = database.Migrate(context.Background(), db)
err = database.Migrate(db)
require.NoError(t, err)
}

View File

@ -3,7 +3,6 @@ package postgres
import (
"database/sql"
"fmt"
"log"
"time"
"github.com/ory/dockertest/v3"
@ -32,13 +31,16 @@ func Open() (string, func(), error) {
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
})
if err != nil {
log.Fatalf("Could not start resource: %s", err)
return "", nil, xerrors.Errorf("could not start resource: %w", err)
}
hostAndPort := resource.GetHostPort("5432/tcp")
dbURL := fmt.Sprintf("postgres://postgres:postgres@%s/postgres?sslmode=disable", hostAndPort)
// Docker should hard-kill the container after 120 seconds.
resource.Expire(120)
err = resource.Expire(120)
if err != nil {
return "", nil, xerrors.Errorf("could not expire resource: %w", err)
}
pool.MaxWait = 120 * time.Second
err = pool.Retry(func() error {

View File

@ -122,7 +122,7 @@ func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
}
// NewPubsub creates a new Pubsub implementation using a PostgreSQL connection.
func NewPubsub(ctx context.Context, db *sql.DB, connectURL string) (Pubsub, error) {
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
// Creates a new listener using pq.
errCh := make(chan error)
listener := pq.NewListener(connectURL, time.Second*10, time.Minute, func(event pq.ListenerEventType, err error) {
@ -144,12 +144,12 @@ func NewPubsub(ctx context.Context, db *sql.DB, connectURL string) (Pubsub, erro
case <-ctx.Done():
return nil, ctx.Err()
}
pg := &pgPubsub{
db: db,
pgPubsub := &pgPubsub{
db: database,
pgListener: listener,
listeners: make(map[string]map[string]Listener),
}
go pg.listen(ctx)
go pgPubsub.listen(ctx)
return pg, nil
return pgPubsub, nil
}

View File

@ -52,7 +52,7 @@ func (m *memoryPubsub) Publish(event string, message []byte) error {
return nil
}
func (m *memoryPubsub) Close() error {
func (*memoryPubsub) Close() error {
return nil
}

View File

@ -17,9 +17,9 @@ func TestPubsubMemory(t *testing.T) {
pubsub := database.NewPubsubInMemory()
event := "test"
data := "testing"
ch := make(chan []byte)
messageChannel := make(chan []byte)
cancelFunc, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
ch <- message
messageChannel <- message
})
require.NoError(t, err)
defer cancelFunc()
@ -27,7 +27,7 @@ func TestPubsubMemory(t *testing.T) {
err = pubsub.Publish(event, []byte(data))
require.NoError(t, err)
}()
message := <-ch
message := <-messageChannel
assert.Equal(t, string(message), data)
})
}

View File

@ -32,9 +32,9 @@ func TestPubsub(t *testing.T) {
defer pubsub.Close()
event := "test"
data := "testing"
ch := make(chan []byte)
messageChannel := make(chan []byte)
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
ch <- message
messageChannel <- message
})
require.NoError(t, err)
defer cancelFunc()
@ -42,7 +42,7 @@ func TestPubsub(t *testing.T) {
err = pubsub.Publish(event, []byte(data))
require.NoError(t, err)
}()
message := <-ch
message := <-messageChannel
assert.Equal(t, string(message), data)
})
}

View File

@ -31,7 +31,7 @@ func init() {
}
return name
})
validate.RegisterValidation("username", func(fl validator.FieldLevel) bool {
err := validate.RegisterValidation("username", func(fl validator.FieldLevel) bool {
f := fl.Field().Interface()
str, ok := f.(string)
if !ok {
@ -45,6 +45,9 @@ func init() {
}
return usernameRegex.MatchString(str)
})
if err != nil {
panic(err)
}
}
// Response represents a generic HTTP response.
@ -60,20 +63,20 @@ type Error struct {
}
// Write outputs a standardized format to an HTTP response body.
func Write(w http.ResponseWriter, status int, response Response) {
func Write(rw http.ResponseWriter, status int, response Response) {
buf := &bytes.Buffer{}
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(true)
err := enc.Encode(response)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(status)
_, err = w.Write(buf.Bytes())
rw.Header().Set("Content-Type", "application/json; charset=utf-8")
rw.WriteHeader(status)
_, err = rw.Write(buf.Bytes())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
}

View File

@ -58,22 +58,22 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
})
return
}
id := parts[0]
secret := parts[1]
keyID := parts[0]
keySecret := parts[1]
// Ensuring key lengths are valid.
if len(id) != 10 {
if len(keyID) != 10 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("invalid %q cookie api key id", AuthCookie),
})
return
}
if len(secret) != 22 {
if len(keySecret) != 22 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("invalid %q cookie api key secret", AuthCookie),
})
return
}
key, err := db.GetAPIKeyByID(r.Context(), id)
key, err := db.GetAPIKeyByID(r.Context(), keyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
@ -86,7 +86,7 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
})
return
}
hashed := sha256.Sum256([]byte(secret))
hashed := sha256.Sum256([]byte(keySecret))
// Checking to see if the secret is valid.
if subtle.ConstantTimeCompare(key.HashedSecret, hashed[:]) != 1 {

View File

@ -19,9 +19,9 @@ import (
"github.com/coder/coder/httpmw"
)
func randomAPIKeyParts() (string, string) {
id, _ := cryptorand.String(10)
secret, _ := cryptorand.String(22)
func randomAPIKeyParts() (id string, secret string) {
id, _ = cryptorand.String(10)
secret, _ = cryptorand.String(22)
return id, secret
}
@ -41,7 +41,9 @@ func TestAPIKey(t *testing.T) {
rw = httptest.NewRecorder()
)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("InvalidFormat", func(t *testing.T) {
@ -56,7 +58,9 @@ func TestAPIKey(t *testing.T) {
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("InvalidIDLength", func(t *testing.T) {
@ -71,7 +75,9 @@ func TestAPIKey(t *testing.T) {
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("InvalidSecretLength", func(t *testing.T) {
@ -86,7 +92,9 @@ func TestAPIKey(t *testing.T) {
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("NotFound", func(t *testing.T) {
@ -102,7 +110,9 @@ func TestAPIKey(t *testing.T) {
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("InvalidSecret", func(t *testing.T) {
@ -125,7 +135,9 @@ func TestAPIKey(t *testing.T) {
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("Expired", func(t *testing.T) {
@ -147,7 +159,9 @@ func TestAPIKey(t *testing.T) {
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("Valid", func(t *testing.T) {
@ -177,7 +191,9 @@ func TestAPIKey(t *testing.T) {
Message: "it worked!",
})
})).ServeHTTP(rw, r)
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
@ -207,7 +223,9 @@ func TestAPIKey(t *testing.T) {
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
@ -237,7 +255,9 @@ func TestAPIKey(t *testing.T) {
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
@ -268,7 +288,9 @@ func TestAPIKey(t *testing.T) {
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
@ -310,7 +332,9 @@ func TestAPIKey(t *testing.T) {
},
},
})(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
@ -325,7 +349,7 @@ type oauth2Config struct {
tokenSource *oauth2TokenSource
}
func (o *oauth2Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource {
func (o *oauth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
return o.tokenSource
}

View File

@ -28,7 +28,7 @@ const (
// the channel on open. The datachannel should not be manually
// mutated after being passed to this function.
func newChannel(conn *Conn, dc *webrtc.DataChannel, opts *ChannelOpts) *Channel {
c := &Channel{
channel := &Channel{
opts: opts,
conn: conn,
dc: dc,
@ -37,8 +37,8 @@ func newChannel(conn *Conn, dc *webrtc.DataChannel, opts *ChannelOpts) *Channel
closed: make(chan struct{}),
sendMore: make(chan struct{}, 1),
}
c.init()
return c
channel.init()
return channel
}
type ChannelOpts struct {
@ -109,6 +109,8 @@ func (c *Channel) init() {
return
}
select {
case <-c.closed:
return
case c.sendMore <- struct{}{}:
default:
}
@ -167,7 +169,7 @@ func (c *Channel) init() {
// Read blocks until data is received.
//
// This will block until the underlying DataChannel has been opened.
func (c *Channel) Read(b []byte) (n int, err error) {
func (c *Channel) Read(bytes []byte) (int, error) {
if c.isClosed() {
return 0, c.closeError
}
@ -178,7 +180,7 @@ func (c *Channel) Read(b []byte) (n int, err error) {
}
}
n, err = c.rwc.Read(b)
bytesRead, err := c.rwc.Read(bytes)
if err != nil {
if c.isClosed() {
return 0, c.closeError
@ -189,9 +191,9 @@ func (c *Channel) Read(b []byte) (n int, err error) {
if xerrors.Is(err, io.EOF) {
err = c.closeWithError(ErrClosed)
}
return
return bytesRead, err
}
return
return bytesRead, err
}
// Write sends data to the underlying DataChannel.
@ -202,8 +204,8 @@ func (c *Channel) Read(b []byte) (n int, err error) {
//
// If the Channel is setup to close on disconnect, any buffered
// data will be lost.
func (c *Channel) Write(b []byte) (n int, err error) {
if len(b) > maxMessageLength {
func (c *Channel) Write(bytes []byte) (n int, err error) {
if len(bytes) > maxMessageLength {
return 0, xerrors.Errorf("outbound packet larger than maximum message size: %d", maxMessageLength)
}
@ -220,7 +222,7 @@ func (c *Channel) Write(b []byte) (n int, err error) {
}
}
if c.dc.BufferedAmount()+uint64(len(b)) >= maxBufferedAmount {
if c.dc.BufferedAmount()+uint64(len(bytes)) >= maxBufferedAmount {
<-c.sendMore
}
// TODO (@kyle): There's an obvious race-condition here.
@ -230,7 +232,7 @@ func (c *Channel) Write(b []byte) (n int, err error) {
// See: https://github.com/pion/sctp/issues/181
time.Sleep(time.Microsecond)
return c.rwc.Write(b)
return c.rwc.Write(bytes)
}
// Close gracefully closes the DataChannel.

View File

@ -42,6 +42,7 @@ func Server(servers []webrtc.ICEServer, opts *ConnOpts) (*Conn, error) {
}
// newWithClientOrServer constructs a new connection with the client option.
// nolint:revive
func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOpts) (*Conn, error) {
if opts == nil {
opts = &ConnOpts{}
@ -60,7 +61,7 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp
if err != nil {
return nil, xerrors.Errorf("create peer connection: %w", err)
}
c := &Conn{
conn := &Conn{
pingChannelID: 1,
pingEchoChannelID: 2,
opts: opts,
@ -77,13 +78,13 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp
if client {
// If we're the client, we want to flip the echo and
// ping channel IDs so pings don't accidentally hit each other.
c.pingChannelID, c.pingEchoChannelID = c.pingEchoChannelID, c.pingChannelID
conn.pingChannelID, conn.pingEchoChannelID = conn.pingEchoChannelID, conn.pingChannelID
}
err = c.init()
err = conn.init()
if err != nil {
return nil, xerrors.Errorf("init: %w", err)
}
return c, nil
return conn, nil
}
type ConnOpts struct {
@ -142,6 +143,10 @@ func (c *Conn) init() error {
}
})
c.rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) {
// Close must be locked here otherwise log output can appear
// after the connection has been closed.
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
if c.isClosed() {
return
}
@ -211,12 +216,12 @@ func (c *Conn) pingEchoChannel() (*Channel, error) {
if c.isClosed() {
return
}
_ = c.closeWithError(xerrors.Errorf("read ping echo channel: %w", err))
_ = c.CloseWithError(xerrors.Errorf("read ping echo channel: %w", err))
return
}
_, err = c.pingEchoChan.Write(data[:bytesRead])
if err != nil {
_ = c.closeWithError(xerrors.Errorf("write ping echo channel: %w", err))
_ = c.CloseWithError(xerrors.Errorf("write ping echo channel: %w", err))
return
}
}
@ -237,12 +242,12 @@ func (c *Conn) negotiate() {
if c.offerrer {
offer, err := c.rtc.CreateOffer(&webrtc.OfferOptions{})
if err != nil {
_ = c.closeWithError(xerrors.Errorf("create offer: %w", err))
_ = c.CloseWithError(xerrors.Errorf("create offer: %w", err))
return
}
err = c.rtc.SetLocalDescription(offer)
if err != nil {
_ = c.closeWithError(xerrors.Errorf("set local description: %w", err))
_ = c.CloseWithError(xerrors.Errorf("set local description: %w", err))
return
}
select {
@ -261,19 +266,19 @@ func (c *Conn) negotiate() {
err := c.rtc.SetRemoteDescription(remoteDescription)
if err != nil {
_ = c.closeWithError(xerrors.Errorf("set remote description (closed %v): %w", c.isClosed(), err))
_ = c.CloseWithError(xerrors.Errorf("set remote description (closed %v): %w", c.isClosed(), err))
return
}
if !c.offerrer {
answer, err := c.rtc.CreateAnswer(&webrtc.AnswerOptions{})
if err != nil {
_ = c.closeWithError(xerrors.Errorf("create answer: %w", err))
_ = c.CloseWithError(xerrors.Errorf("create answer: %w", err))
return
}
err = c.rtc.SetLocalDescription(answer)
if err != nil {
_ = c.closeWithError(xerrors.Errorf("set local description: %w", err))
_ = c.CloseWithError(xerrors.Errorf("set local description: %w", err))
return
}
if c.isClosed() {
@ -296,20 +301,20 @@ func (c *Conn) proxyICECandidates() func() {
queue = []webrtc.ICECandidateInit{}
flushed = false
)
c.rtc.OnICECandidate(func(i *webrtc.ICECandidate) {
if i == nil {
c.rtc.OnICECandidate(func(iceCandidate *webrtc.ICECandidate) {
if iceCandidate == nil {
return
}
mut.Lock()
defer mut.Unlock()
if !flushed {
queue = append(queue, i.ToJSON())
queue = append(queue, iceCandidate.ToJSON())
return
}
select {
case <-c.closed:
return
case c.localCandidateChannel <- i.ToJSON():
case c.localCandidateChannel <- iceCandidate.ToJSON():
}
})
return func() {
@ -353,7 +358,7 @@ func (c *Conn) SetConfiguration(configuration webrtc.Configuration) error {
}
// SetRemoteSessionDescription sets the remote description for the WebRTC connection.
func (c *Conn) SetRemoteSessionDescription(s webrtc.SessionDescription) {
func (c *Conn) SetRemoteSessionDescription(sessionDescription webrtc.SessionDescription) {
if c.isClosed() {
return
}
@ -361,7 +366,7 @@ func (c *Conn) SetRemoteSessionDescription(s webrtc.SessionDescription) {
defer c.closeMutex.Unlock()
select {
case <-c.closed:
case c.remoteSessionDescriptionChannel <- s:
case c.remoteSessionDescriptionChannel <- sessionDescription:
}
}
@ -407,7 +412,7 @@ func (c *Conn) dialChannel(ctx context.Context, label string, opts *ChannelOpts)
return nil, xerrors.Errorf("closed: %w", c.closeError)
}
dc, err := c.rtc.CreateDataChannel(label, &webrtc.DataChannelInit{
dataChannel, err := c.rtc.CreateDataChannel(label, &webrtc.DataChannelInit{
ID: id,
Negotiated: &opts.Negotiated,
Ordered: &ordered,
@ -416,7 +421,7 @@ func (c *Conn) dialChannel(ctx context.Context, label string, opts *ChannelOpts)
if err != nil {
return nil, xerrors.Errorf("create data channel: %w", err)
}
return newChannel(c, dc, opts), nil
return newChannel(c, dataChannel, opts), nil
}
// Ping returns the duration it took to round-trip data.
@ -461,12 +466,7 @@ func (c *Conn) Closed() <-chan struct{} {
// Close closes the connection and frees all associated resources.
func (c *Conn) Close() error {
return c.closeWithError(nil)
}
// CloseWithError closes the connection; subsequent reads/writes will return the error err.
func (c *Conn) CloseWithError(err error) error {
return c.closeWithError(err)
return c.CloseWithError(nil)
}
func (c *Conn) isClosed() bool {
@ -478,7 +478,8 @@ func (c *Conn) isClosed() bool {
}
}
func (c *Conn) closeWithError(err error) error {
// CloseWithError closes the connection; subsequent reads/writes will return the error err.
func (c *Conn) CloseWithError(err error) error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()

View File

@ -67,13 +67,13 @@ func TestConn(t *testing.T) {
_, err := server.Ping()
require.NoError(t, err)
// Create a channel that closes on disconnect.
ch, err := server.Dial(context.Background(), "wow", nil)
channel, err := server.Dial(context.Background(), "wow", nil)
assert.NoError(t, err)
err = wan.Stop()
require.NoError(t, err)
// Once the connection is marked as disconnected, this
// channel will be closed.
_, err = ch.Read(make([]byte, 4))
_, err = channel.Read(make([]byte, 4))
assert.ErrorIs(t, err, peer.ErrClosed)
err = wan.Start()
require.NoError(t, err)
@ -154,26 +154,26 @@ func TestConn(t *testing.T) {
_, _ = io.Copy(nc2, nc1)
}()
go func() {
s := http.Server{
server := http.Server{
Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(200)
}),
}
defer s.Close()
_ = s.Serve(srv)
defer server.Close()
_ = server.Serve(srv)
}()
dt := http.DefaultTransport.(*http.Transport).Clone()
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
var cch *peer.Channel
dt.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
cch, err = client.Dial(context.Background(), "hello", &peer.ChannelOpts{})
defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
cch, err = client.Dial(ctx, "hello", &peer.ChannelOpts{})
if err != nil {
return nil, err
}
return cch.NetConn(), nil
}
c := http.Client{
Transport: dt,
Transport: defaultTransport,
}
req, err := http.NewRequestWithContext(context.Background(), "GET", "http://localhost/", nil)
require.NoError(t, err)
@ -183,7 +183,7 @@ func TestConn(t *testing.T) {
require.Equal(t, resp.StatusCode, 200)
// Triggers any connections to close.
// This test below ensures the DataChannel actually closes.
dt.CloseIdleConnections()
defaultTransport.CloseIdleConnections()
err = cch.Close()
require.ErrorIs(t, err, peer.ErrClosed)
})
@ -226,13 +226,13 @@ func TestConn(t *testing.T) {
}
func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) {
lf := logging.NewDefaultLoggerFactory()
lf.DefaultLogLevel = logging.LogLevelDisabled
loggingFactory := logging.NewDefaultLoggerFactory()
loggingFactory.DefaultLogLevel = logging.LogLevelDisabled
vnetMutex.Lock()
defer vnetMutex.Unlock()
wan, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: "1.2.3.0/24",
LoggerFactory: lf,
LoggerFactory: loggingFactory,
})
require.NoError(t, err)
c1Net := vnet.NewNet(&vnet.NetConfig{
@ -250,25 +250,25 @@ func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.R
c1SettingEngine.SetVNet(c1Net)
c1SettingEngine.SetPrflxAcceptanceMinWait(0)
c1SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
c1, err := peer.Client([]webrtc.ICEServer{}, &peer.ConnOpts{
channel1, err := peer.Client([]webrtc.ICEServer{}, &peer.ConnOpts{
SettingEngine: c1SettingEngine,
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
t.Cleanup(func() {
c1.Close()
channel1.Close()
})
c2SettingEngine := webrtc.SettingEngine{}
c2SettingEngine.SetVNet(c2Net)
c2SettingEngine.SetPrflxAcceptanceMinWait(0)
c2SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
c2, err := peer.Server([]webrtc.ICEServer{}, &peer.ConnOpts{
channel2, err := peer.Server([]webrtc.ICEServer{}, &peer.ConnOpts{
SettingEngine: c2SettingEngine,
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
t.Cleanup(func() {
c2.Close()
channel2.Close()
})
err = wan.Start()
@ -280,11 +280,11 @@ func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.R
go func() {
for {
select {
case c := <-c2.LocalCandidate():
_ = c1.AddRemoteCandidate(c)
case c := <-c2.LocalSessionDescription():
c1.SetRemoteSessionDescription(c)
case <-c2.Closed():
case c := <-channel2.LocalCandidate():
_ = channel1.AddRemoteCandidate(c)
case c := <-channel2.LocalSessionDescription():
channel1.SetRemoteSessionDescription(c)
case <-channel2.Closed():
return
}
}
@ -293,15 +293,15 @@ func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.R
go func() {
for {
select {
case c := <-c1.LocalCandidate():
_ = c2.AddRemoteCandidate(c)
case c := <-c1.LocalSessionDescription():
c2.SetRemoteSessionDescription(c)
case <-c1.Closed():
case c := <-channel1.LocalCandidate():
_ = channel2.AddRemoteCandidate(c)
case c := <-channel1.LocalSessionDescription():
channel2.SetRemoteSessionDescription(c)
case <-channel1.Closed():
return
}
}
}()
return c1, c2, wan
return channel1, channel2, wan
}

View File

@ -10,11 +10,11 @@ type peerAddr struct{}
// Statically checks if we properly implement net.Addr.
var _ net.Addr = &peerAddr{}
func (a *peerAddr) Network() string {
func (*peerAddr) Network() string {
return "peer"
}
func (a *peerAddr) String() string {
func (*peerAddr) String() string {
return "peer/unknown-addr"
}
@ -46,14 +46,14 @@ func (c *fakeNetConn) RemoteAddr() net.Addr {
return c.addr
}
func (c *fakeNetConn) SetDeadline(_ time.Time) error {
func (*fakeNetConn) SetDeadline(_ time.Time) error {
return nil
}
func (c *fakeNetConn) SetReadDeadline(_ time.Time) error {
func (*fakeNetConn) SetReadDeadline(_ time.Time) error {
return nil
}
func (c *fakeNetConn) SetWriteDeadline(_ time.Time) error {
func (*fakeNetConn) SetWriteDeadline(_ time.Time) error {
return nil
}

View File

@ -161,7 +161,6 @@ func (b *peerBrokerService) NegotiateConnection(stream proto.DRPCPeerBroker_Nego
Type: webrtc.SDPType(clientToServerMessage.GetOffer().SdpType),
SDP: clientToServerMessage.GetOffer().Sdp,
})
break
case clientToServerMessage.GetServers() != nil:
// Convert protobuf ICE servers to the WebRTC type.
iceServers := make([]webrtc.ICEServer, 0, len(clientToServerMessage.GetServers().Servers))

View File

@ -12,7 +12,7 @@ import (
)
// Parse extracts Terraform variables from source-code.
func (t *terraform) Parse(ctx context.Context, request *proto.Parse_Request) (*proto.Parse_Response, error) {
func (*terraform) Parse(_ context.Context, request *proto.Parse_Request) (*proto.Parse_Response, error) {
module, diags := tfconfig.LoadModule(request.Directory)
if diags.HasErrors() {
return nil, xerrors.Errorf("load module: %w", diags.Err())

View File

@ -37,7 +37,7 @@ func TestParse(t *testing.T) {
}()
api := proto.NewDRPCProvisionerClient(drpcconn.New(client))
for _, tc := range []struct {
for _, testCase := range []struct {
Name string
Files map[string]string
Response *proto.Parse_Response
@ -83,13 +83,13 @@ func TestParse(t *testing.T) {
}},
},
}} {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
testCase := testCase
t.Run(testCase.Name, func(t *testing.T) {
t.Parallel()
// Write all files to the temporary test directory.
directory := t.TempDir()
for path, content := range tc.Files {
for path, content := range testCase.Files {
err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0600)
require.NoError(t, err)
}
@ -100,7 +100,7 @@ func TestParse(t *testing.T) {
require.NoError(t, err)
// Ensure the want and got are equivalent!
want, err := json.Marshal(tc.Response)
want, err := json.Marshal(testCase.Response)
require.NoError(t, err)
got, err := json.Marshal(response)
require.NoError(t, err)

View File

@ -15,9 +15,9 @@ import (
// Provision executes `terraform apply`.
func (t *terraform) Provision(ctx context.Context, request *proto.Provision_Request) (*proto.Provision_Response, error) {
statefilePath := filepath.Join(request.Directory, "terraform.tfstate")
err := os.WriteFile(statefilePath, request.State, 0644)
err := os.WriteFile(statefilePath, request.State, 0600)
if err != nil {
return nil, xerrors.Errorf("write statefile %q: %w", err)
return nil, xerrors.Errorf("write statefile %q: %w", statefilePath, err)
}
terraform, err := tfexec.NewTerraform(request.Directory, t.binaryPath)

View File

@ -48,7 +48,7 @@ func TestProvision(t *testing.T) {
}()
api := proto.NewDRPCProvisionerClient(drpcconn.New(client))
for _, tc := range []struct {
for _, testCase := range []struct {
Name string
Files map[string]string
Request *proto.Provision_Request
@ -93,25 +93,25 @@ func TestProvision(t *testing.T) {
},
Error: true,
}} {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
testCase := testCase
t.Run(testCase.Name, func(t *testing.T) {
t.Parallel()
directory := t.TempDir()
for path, content := range tc.Files {
err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0644)
for path, content := range testCase.Files {
err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0600)
require.NoError(t, err)
}
request := &proto.Provision_Request{
Directory: directory,
}
if tc.Request != nil {
request.ParameterValues = tc.Request.ParameterValues
request.State = tc.Request.State
if testCase.Request != nil {
request.ParameterValues = testCase.Request.ParameterValues
request.State = testCase.Request.State
}
response, err := api.Provision(ctx, request)
if tc.Error {
if testCase.Error {
require.Error(t, err)
return
}
@ -121,7 +121,7 @@ func TestProvision(t *testing.T) {
resourcesGot, err := json.Marshal(response.Resources)
require.NoError(t, err)
resourcesWant, err := json.Marshal(tc.Response.Resources)
resourcesWant, err := json.Marshal(testCase.Response.Resources)
require.NoError(t, err)
require.Equal(t, string(resourcesWant), string(resourcesGot))

View File

@ -28,7 +28,7 @@ var site embed.FS
// Handler returns an HTTP handler for serving the static site.
func Handler() http.Handler {
f, err := fs.Sub(site, "out")
filesystem, err := fs.Sub(site, "out")
if err != nil {
// This can't happen... Go would throw a compilation error.
panic(err)
@ -36,15 +36,15 @@ func Handler() http.Handler {
// html files are handled by a text/template. Non-html files
// are served by the default file server.
files, err := htmlFiles(f)
files, err := htmlFiles(filesystem)
if err != nil {
panic(xerrors.Errorf("Failed to return handler for static files. Html files failed to load: %w", err))
}
return secureHeaders(&handler{
fs: f,
fs: filesystem,
htmlFiles: files,
h: http.FileServer(http.FS(f)), // All other non-html static files
h: http.FileServer(http.FS(filesystem)), // All other non-html static files
})
}
@ -61,15 +61,15 @@ type handler struct {
}
// filePath returns the filepath of the requested file.
func (h *handler) filePath(p string) string {
func (*handler) filePath(p string) string {
if !strings.HasPrefix(p, "/") {
p = "/" + p
}
return strings.TrimPrefix(path.Clean(p), "/")
}
func (h *handler) exists(path string) bool {
f, err := h.fs.Open(path)
func (h *handler) exists(filePath string) bool {
f, err := h.fs.Open(filePath)
if err == nil {
_ = f.Close()
}
@ -89,7 +89,7 @@ type csrfState struct {
Token string
}
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h *handler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// reqFile is the static file requested
reqFile := h.filePath(r.URL.Path)
state := htmlState{
@ -100,13 +100,13 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// First check if it's a file we have in our templates
if h.serveHtml(w, r, reqFile, state) {
if h.serveHTML(rw, r, reqFile, state) {
return
}
// If the original file path exists we serve it.
if h.exists(reqFile) {
h.h.ServeHTTP(w, r)
h.h.ServeHTTP(rw, r)
return
}
@ -117,28 +117,28 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
reqFile = h.filePath(r.URL.Path)
// All html files should be served by the htmlFile templates
if h.serveHtml(w, r, reqFile, state) {
if h.serveHTML(rw, r, reqFile, state) {
return
}
// If we don't have the file... we should redirect to `/`
// for our single-page-app.
r.URL.Path = "/"
if h.serveHtml(w, r, "", state) {
if h.serveHTML(rw, r, "", state) {
return
}
// This will send a correct 404
h.h.ServeHTTP(w, r)
h.h.ServeHTTP(rw, r)
}
func (h *handler) serveHtml(w http.ResponseWriter, r *http.Request, reqPath string, state htmlState) bool {
func (h *handler) serveHTML(rw http.ResponseWriter, r *http.Request, reqPath string, state htmlState) bool {
if data, err := h.htmlFiles.renderWithState(reqPath, state); err == nil {
if reqPath == "" {
// Pass "index.html" to the ServeContent so the ServeContent sets the right content headers.
reqPath = "index.html"
}
http.ServeContent(w, r, reqPath, time.Time{}, bytes.NewReader(data))
http.ServeContent(rw, r, reqPath, time.Time{}, bytes.NewReader(data))
return true
}
return false
@ -150,12 +150,12 @@ type htmlTemplates struct {
// renderWithState will render the file using the given nonce if the file exists
// as a template. If it does not, it will return an error.
func (t *htmlTemplates) renderWithState(path string, state htmlState) ([]byte, error) {
func (t *htmlTemplates) renderWithState(filePath string, state htmlState) ([]byte, error) {
var buf bytes.Buffer
if path == "" {
path = "index.html"
if filePath == "" {
filePath = "index.html"
}
err := t.tpls.ExecuteTemplate(&buf, path, state)
err := t.tpls.ExecuteTemplate(&buf, filePath, state)
if err != nil {
return nil, err
}
@ -168,13 +168,6 @@ func (t *htmlTemplates) renderWithState(path string, state htmlState) ([]byte, e
// All directives are semi-colon separated as a single string for the csp header.
type cspDirectives map[cspFetchDirective][]string
func (s cspDirectives) append(d cspFetchDirective, values ...string) {
if _, ok := s[d]; !ok {
s[d] = make([]string, 0)
}
s[d] = append(s[d], values...)
}
// cspFetchDirective is the list of all constant fetch directives that
// can be used/appended to.
type cspFetchDirective string
@ -234,7 +227,7 @@ func secureHeaders(next http.Handler) http.Handler {
var csp strings.Builder
for src, vals := range cspSrcs {
fmt.Fprintf(&csp, "%s %s; ", src, strings.Join(vals, " "))
_, _ = fmt.Fprintf(&csp, "%s %s; ", src, strings.Join(vals, " "))
}
// Permissions-Policy can be used to disabled various browser features that we do not use.
@ -280,16 +273,16 @@ func htmlFiles(files fs.FS) (*htmlTemplates, error) {
root := template.New("")
rootPath := "."
err := fs.WalkDir(files, rootPath, func(path string, d fs.DirEntry, err error) error {
err := fs.WalkDir(files, rootPath, func(path string, dirEntry fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
if dirEntry.IsDir() {
return nil
}
if filepath.Ext(d.Name()) != ".html" {
if filepath.Ext(dirEntry.Name()) != ".html" {
return nil
}

View File

@ -1,7 +1,9 @@
package site
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
@ -13,8 +15,11 @@ func TestIndexPageRenders(t *testing.T) {
srv := httptest.NewServer(Handler())
resp, err := srv.Client().Get(srv.URL)
req, err := http.NewRequestWithContext(context.Background(), "GET", srv.URL, nil)
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err, "get index")
defer resp.Body.Close()
data, _ := io.ReadAll(resp.Body)
require.NotEmpty(t, data, "index should have contents")
}