mirror of https://github.com/coder/coder.git
chore: Fix golangci-lint configuration and patch errors (#34)
* chore: Fix golangci-lint configuration and patch errors Due to misconfiguration of a linting rules directory, our linter has not been working properly. This change fixes the configuration issue, and all remaining linting errors. * Fix race in peer logging * Fix race and return * Lock on bufferred amount low * Fix mutex lock
This commit is contained in:
parent
6a919aea79
commit
2654a93132
|
@ -100,10 +100,6 @@ linters-settings:
|
|||
# - whyNoLint
|
||||
# - wrapperFunc
|
||||
# - yodaStyleExpr
|
||||
settings:
|
||||
ruleguard:
|
||||
failOn: all
|
||||
rules: "${configDir}/lib/go/lintrules/*.go"
|
||||
|
||||
goimports:
|
||||
local-prefixes: coder.com,cdr.dev,go.coder.com,github.com/cdr,github.com/coder
|
||||
|
@ -113,24 +109,6 @@ linters-settings:
|
|||
|
||||
importas:
|
||||
no-unaliased: true
|
||||
alias:
|
||||
- pkg: k8s.io/api/(\w+)/(v[\w\d]+)
|
||||
alias: ${1}${2}
|
||||
|
||||
- pkg: k8s.io/apimachinery/pkg/apis/meta/(v[\w\d]+)
|
||||
alias: meta${1}
|
||||
|
||||
- pkg: k8s.io/client-go/kubernetes/typed/(\w+)/(v[\w\d]+)
|
||||
alias: ${1}${2}client
|
||||
|
||||
- pkg: k8s.io/metrics/pkg/apis/metrics/(v[\w\d]+)
|
||||
alias: metrics${1}
|
||||
|
||||
- pkg: github.com/docker/docker/api/types
|
||||
alias: dockertypes
|
||||
|
||||
- pkg: github.com/docker/docker/client
|
||||
alias: dockerclient
|
||||
|
||||
misspell:
|
||||
locale: US
|
||||
|
@ -195,6 +173,20 @@ linters-settings:
|
|||
- name: var-declaration
|
||||
- name: var-naming
|
||||
- name: waitgroup-by-value
|
||||
varnamelen:
|
||||
ignore-names:
|
||||
- err
|
||||
- rw
|
||||
- r
|
||||
- i
|
||||
- db
|
||||
# Optional list of variable declarations that should be ignored completely. (defaults to empty list)
|
||||
# Entries must be in the form of "<variable name> <type>" or "<variable name> *<type>" for
|
||||
# variables, or "const <name>" for constants.
|
||||
ignore-decls:
|
||||
- rw http.ResponseWriter
|
||||
- r *http.Request
|
||||
- t testing.T
|
||||
|
||||
issues:
|
||||
# Rules listed here: https://github.com/securego/gosec#available-rules
|
||||
|
@ -222,7 +214,6 @@ linters:
|
|||
- asciicheck
|
||||
- bidichk
|
||||
- bodyclose
|
||||
- contextcheck
|
||||
- deadcode
|
||||
- dogsled
|
||||
- errcheck
|
||||
|
@ -239,7 +230,6 @@ linters:
|
|||
- govet
|
||||
- importas
|
||||
- ineffassign
|
||||
# - ireturn
|
||||
- makezero
|
||||
- misspell
|
||||
- nilnil
|
||||
|
|
|
@ -3,5 +3,5 @@ package main
|
|||
import "fmt"
|
||||
|
||||
func main() {
|
||||
fmt.Println("Hello World!")
|
||||
_, _ = fmt.Println("Hello World!")
|
||||
}
|
||||
|
|
|
@ -39,7 +39,7 @@ func New(options *Options) http.Handler {
|
|||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
httpmw.ExtractUser(options.Database),
|
||||
)
|
||||
r.Get("/user", users.getAuthenticatedUser)
|
||||
r.Get("/user", users.authenticatedUser)
|
||||
})
|
||||
})
|
||||
r.NotFound(site.Handler().ServeHTTP)
|
||||
|
|
|
@ -32,11 +32,11 @@ func New(t *testing.T) Server {
|
|||
Database: db,
|
||||
})
|
||||
srv := httptest.NewServer(handler)
|
||||
u, err := url.Parse(srv.URL)
|
||||
serverURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
client := codersdk.New(u)
|
||||
client := codersdk.New(serverURL)
|
||||
_, err = client.CreateInitialUser(context.Background(), coderd.CreateUserRequest{
|
||||
Email: "testuser@coder.com",
|
||||
Username: "testuser",
|
||||
|
@ -54,6 +54,6 @@ func New(t *testing.T) Server {
|
|||
|
||||
return Server{
|
||||
Client: client,
|
||||
URL: u,
|
||||
URL: serverURL,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,14 +35,14 @@ func Compare(hashed string, password string) (bool, error) {
|
|||
if len(parts[0]) != 0 {
|
||||
return false, xerrors.Errorf("hash prefix is invalid")
|
||||
}
|
||||
if string(parts[1]) != hashScheme {
|
||||
if parts[1] != hashScheme {
|
||||
return false, xerrors.Errorf("hash isn't %q scheme: %q", hashScheme, parts[1])
|
||||
}
|
||||
iter, err := strconv.Atoi(string(parts[2]))
|
||||
iter, err := strconv.Atoi(parts[2])
|
||||
if err != nil {
|
||||
return false, xerrors.Errorf("parse iter from hash: %w", err)
|
||||
}
|
||||
salt, err := base64.RawStdEncoding.DecodeString(string(parts[3]))
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[3])
|
||||
if err != nil {
|
||||
return false, xerrors.Errorf("decode salt: %w", err)
|
||||
}
|
||||
|
|
|
@ -70,7 +70,7 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) {
|
|||
})
|
||||
return
|
||||
}
|
||||
user, err := users.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
|
||||
_, err = users.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
|
||||
Email: createUser.Email,
|
||||
Username: createUser.Username,
|
||||
})
|
||||
|
@ -91,7 +91,7 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
user, err = users.Database.InsertUser(context.Background(), database.InsertUserParams{
|
||||
user, err := users.Database.InsertUser(context.Background(), database.InsertUserParams{
|
||||
ID: uuid.NewString(),
|
||||
Email: createUser.Email,
|
||||
HashedPassword: []byte(hashedPassword),
|
||||
|
@ -111,7 +111,7 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Returns the currently authenticated user.
|
||||
func (users *users) getAuthenticatedUser(rw http.ResponseWriter, r *http.Request) {
|
||||
func (*users) authenticatedUser(rw http.ResponseWriter, r *http.Request) {
|
||||
user := httpmw.User(r)
|
||||
|
||||
render.JSON(rw, r, User{
|
||||
|
@ -158,11 +158,17 @@ func (users *users) loginWithPassword(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
id, secret, err := generateAPIKeyIDSecret()
|
||||
hashed := sha256.Sum256([]byte(secret))
|
||||
keyID, keySecret, err := generateAPIKeyIDSecret()
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("generate api key parts: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
hashed := sha256.Sum256([]byte(keySecret))
|
||||
|
||||
_, err = users.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: id,
|
||||
ID: keyID,
|
||||
UserID: user.ID,
|
||||
ExpiresAt: database.Now().Add(24 * time.Hour),
|
||||
CreatedAt: database.Now(),
|
||||
|
@ -178,7 +184,7 @@ func (users *users) loginWithPassword(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// This format is consumed by the APIKey middleware.
|
||||
sessionToken := fmt.Sprintf("%s-%s", id, secret)
|
||||
sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret)
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: httpmw.AuthCookie,
|
||||
Value: sessionToken,
|
||||
|
@ -194,14 +200,14 @@ func (users *users) loginWithPassword(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Generates a new ID and secret for an API key.
|
||||
func generateAPIKeyIDSecret() (string, string, error) {
|
||||
func generateAPIKeyIDSecret() (id string, secret string, err error) {
|
||||
// Length of an API Key ID.
|
||||
id, err := cryptorand.String(10)
|
||||
id, err = cryptorand.String(10)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
// Length of an API Key secret.
|
||||
secret, err := cryptorand.String(22)
|
||||
secret, err = cryptorand.String(22)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
|
|
@ -4,9 +4,10 @@ import (
|
|||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUsers(t *testing.T) {
|
||||
|
|
|
@ -18,9 +18,9 @@ import (
|
|||
)
|
||||
|
||||
// New creates a Coder client for the provided URL.
|
||||
func New(url *url.URL) *Client {
|
||||
func New(serverURL *url.URL) *Client {
|
||||
return &Client{
|
||||
url: url,
|
||||
url: serverURL,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ func (c *Client) SetSessionToken(token string) error {
|
|||
// request performs an HTTP request with the body provided.
|
||||
// The caller is responsible for closing the response body.
|
||||
func (c *Client) request(ctx context.Context, method, path string, body interface{}) (*http.Response, error) {
|
||||
url, err := c.url.Parse(path)
|
||||
serverURL, err := c.url.Parse(path)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
|
@ -65,7 +65,7 @@ func (c *Client) request(ctx context.Context, method, path string, body interfac
|
|||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url.String(), &buf)
|
||||
req, err := http.NewRequestWithContext(ctx, method, serverURL.String(), &buf)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create request: %w", err)
|
||||
}
|
||||
|
@ -81,7 +81,7 @@ func (c *Client) request(ctx context.Context, method, path string, body interfac
|
|||
}
|
||||
|
||||
// readBodyAsError reads the response as an httpapi.Message, and
|
||||
// wraps it in a codersdk.Error type for easy marshalling.
|
||||
// wraps it in a codersdk.Error type for easy marshaling.
|
||||
func readBodyAsError(res *http.Response) error {
|
||||
var m httpapi.Response
|
||||
err := json.NewDecoder(res.Body).Decode(&m)
|
||||
|
|
|
@ -26,7 +26,7 @@ func (c *Client) CreateInitialUser(ctx context.Context, req coderd.CreateUserReq
|
|||
|
||||
// User returns a user for the ID provided.
|
||||
// If the ID string is empty, the current user will be returned.
|
||||
func (c *Client) User(ctx context.Context, id string) (coderd.User, error) {
|
||||
func (c *Client) User(ctx context.Context, _ string) (coderd.User, error) {
|
||||
res, err := c.request(ctx, http.MethodGet, "/api/v2/user", nil)
|
||||
if err != nil {
|
||||
return coderd.User{}, err
|
||||
|
|
|
@ -5,10 +5,11 @@ import (
|
|||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUsers(t *testing.T) {
|
||||
|
|
|
@ -73,42 +73,43 @@ func Int() (int, error) {
|
|||
return int(i), nil
|
||||
}
|
||||
|
||||
// Int63n returns a non-negative random integer in [0,n) as a int64.
|
||||
func Int63n(n int64) (int64, error) {
|
||||
if n <= 0 {
|
||||
// Int63n returns a non-negative random integer in [0,max) as a int64.
|
||||
func Int63n(max int64) (int64, error) {
|
||||
if max <= 0 {
|
||||
panic("invalid argument to Int63n")
|
||||
}
|
||||
|
||||
max := int64((1 << 63) - 1 - (1<<63)%uint64(n))
|
||||
trueMax := int64((1 << 63) - 1 - (1<<63)%uint64(max))
|
||||
i, err := Int63()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for i > max {
|
||||
for i > trueMax {
|
||||
i, err = Int63()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return i % n, nil
|
||||
return i % max, nil
|
||||
}
|
||||
|
||||
// Int31n returns a non-negative integer in [0,n) as a int32.
|
||||
func Int31n(n int32) (int32, error) {
|
||||
// Int31n returns a non-negative integer in [0,max) as a int32.
|
||||
func Int31n(max int32) (int32, error) {
|
||||
i, err := Uint32()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return UnbiasedModulo32(i, n)
|
||||
return UnbiasedModulo32(i, max)
|
||||
}
|
||||
|
||||
// UnbiasedModulo32 uniformly modulos v by n over a sufficiently large data
|
||||
// set, regenerating v if necessary. n must be > 0. All input bits in v must be
|
||||
// fully random, you cannot cast a random uint8/uint16 for input into this
|
||||
// function.
|
||||
//nolint:varnamelen
|
||||
func UnbiasedModulo32(v uint32, n int32) (int32, error) {
|
||||
prod := uint64(v) * uint64(n)
|
||||
low := uint32(prod)
|
||||
|
@ -127,14 +128,14 @@ func UnbiasedModulo32(v uint32, n int32) (int32, error) {
|
|||
return int32(prod >> 32), nil
|
||||
}
|
||||
|
||||
// Intn returns a non-negative integer in [0,n) as a int.
|
||||
func Intn(n int) (int, error) {
|
||||
if n <= 0 {
|
||||
// Intn returns a non-negative integer in [0,max) as a int.
|
||||
func Intn(max int) (int, error) {
|
||||
if max <= 0 {
|
||||
panic("n must be a positive nonzero number")
|
||||
}
|
||||
|
||||
if n <= 1<<31-1 {
|
||||
i, err := Int31n(int32(n))
|
||||
if max <= 1<<31-1 {
|
||||
i, err := Int31n(int32(max))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -142,7 +143,7 @@ func Intn(n int) (int, error) {
|
|||
return int(i), nil
|
||||
}
|
||||
|
||||
i, err := Int63n(int64(n))
|
||||
i, err := Int63n(int64(max))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
|
|
@ -5,8 +5,9 @@ import (
|
|||
"encoding/binary"
|
||||
"testing"
|
||||
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
func TestInt63(t *testing.T) {
|
||||
|
@ -144,7 +145,7 @@ func TestBool(t *testing.T) {
|
|||
const iterations = 10000
|
||||
trueCount := 0
|
||||
|
||||
for i := 0; i < iterations; i += 1 {
|
||||
for i := 0; i < iterations; i++ {
|
||||
v, err := cryptorand.Bool()
|
||||
require.NoError(t, err, "unexpected error from Bool")
|
||||
if v {
|
||||
|
|
|
@ -53,7 +53,7 @@ func StringCharset(charSetStr string, size int) (string, error) {
|
|||
buf.Grow(size)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
c, err := UnbiasedModulo32(
|
||||
count, err := UnbiasedModulo32(
|
||||
binary.BigEndian.Uint32(ibuf[i*4:(i+1)*4]),
|
||||
int32(len(charSet)),
|
||||
)
|
||||
|
@ -61,7 +61,7 @@ func StringCharset(charSetStr string, size int) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
|
||||
_, _ = buf.WriteRune(charSet[c])
|
||||
_, _ = buf.WriteRune(charSet[count])
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
|
|
|
@ -8,8 +8,9 @@ import (
|
|||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
|
|
|
@ -22,11 +22,11 @@ type fakeQuerier struct {
|
|||
}
|
||||
|
||||
// InTx doesn't rollback data properly for in-memory yet.
|
||||
func (q *fakeQuerier) InTx(ctx context.Context, fn func(database.Store) error) error {
|
||||
func (q *fakeQuerier) InTx(fn func(database.Store) error) error {
|
||||
return fn(q)
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) {
|
||||
func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) {
|
||||
for _, apiKey := range q.apiKeys {
|
||||
if apiKey.ID == id {
|
||||
return apiKey, nil
|
||||
|
@ -35,7 +35,7 @@ func (q *fakeQuerier) GetAPIKeyByID(ctx context.Context, id string) (database.AP
|
|||
return database.APIKey{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetUserByEmailOrUsername(ctx context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) {
|
||||
func (q *fakeQuerier) GetUserByEmailOrUsername(_ context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) {
|
||||
for _, user := range q.users {
|
||||
if user.Email == arg.Email || user.Username == arg.Username {
|
||||
return user, nil
|
||||
|
@ -44,7 +44,7 @@ func (q *fakeQuerier) GetUserByEmailOrUsername(ctx context.Context, arg database
|
|||
return database.User{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetUserByID(ctx context.Context, id string) (database.User, error) {
|
||||
func (q *fakeQuerier) GetUserByID(_ context.Context, id string) (database.User, error) {
|
||||
for _, user := range q.users {
|
||||
if user.ID == id {
|
||||
return user, nil
|
||||
|
@ -53,11 +53,12 @@ func (q *fakeQuerier) GetUserByID(ctx context.Context, id string) (database.User
|
|||
return database.User{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetUserCount(ctx context.Context) (int64, error) {
|
||||
func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) {
|
||||
return int64(len(q.users)), nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) InsertAPIKey(ctx context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) {
|
||||
func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) {
|
||||
//nolint:gosimple
|
||||
key := database.APIKey{
|
||||
ID: arg.ID,
|
||||
HashedSecret: arg.HashedSecret,
|
||||
|
@ -79,7 +80,7 @@ func (q *fakeQuerier) InsertAPIKey(ctx context.Context, arg database.InsertAPIKe
|
|||
return key, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) InsertUser(ctx context.Context, arg database.InsertUserParams) (database.User, error) {
|
||||
func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, error) {
|
||||
user := database.User{
|
||||
ID: arg.ID,
|
||||
Email: arg.Email,
|
||||
|
@ -94,7 +95,7 @@ func (q *fakeQuerier) InsertUser(ctx context.Context, arg database.InsertUserPar
|
|||
return user, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
|
||||
func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPIKeyByIDParams) error {
|
||||
for index, apiKey := range q.apiKeys {
|
||||
if apiKey.ID != arg.ID {
|
||||
continue
|
||||
|
|
|
@ -21,7 +21,7 @@ import (
|
|||
type Store interface {
|
||||
querier
|
||||
|
||||
InTx(context.Context, func(Store) error) error
|
||||
InTx(func(Store) error) error
|
||||
}
|
||||
|
||||
// DBTX represents a database connection or transaction.
|
||||
|
@ -46,16 +46,16 @@ type sqlQuerier struct {
|
|||
}
|
||||
|
||||
// InTx performs database operations inside a transaction.
|
||||
func (q *sqlQuerier) InTx(ctx context.Context, fn func(Store) error) error {
|
||||
func (q *sqlQuerier) InTx(function func(Store) error) error {
|
||||
if q.sdb == nil {
|
||||
return nil
|
||||
}
|
||||
tx, err := q.sdb.Begin()
|
||||
transaction, err := q.sdb.Begin()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
rerr := tx.Rollback()
|
||||
rerr := transaction.Rollback()
|
||||
if rerr == nil || errors.Is(rerr, sql.ErrTxDone) {
|
||||
// no need to do anything, tx committed successfully
|
||||
return
|
||||
|
@ -63,11 +63,11 @@ func (q *sqlQuerier) InTx(ctx context.Context, fn func(Store) error) error {
|
|||
// couldn't roll back for some reason, extend returned error
|
||||
err = xerrors.Errorf("defer (%s): %w", rerr.Error(), err)
|
||||
}()
|
||||
err = fn(&sqlQuerier{db: tx})
|
||||
err = function(&sqlQuerier{db: transaction})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("execute transaction: %w", err)
|
||||
}
|
||||
err = tx.Commit()
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package main
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
@ -25,7 +24,7 @@ func main() {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = database.Migrate(context.Background(), db)
|
||||
err = database.Migrate(db)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -82,7 +81,7 @@ func main() {
|
|||
if !ok {
|
||||
panic("couldn't get caller path")
|
||||
}
|
||||
err = ioutil.WriteFile(filepath.Join(mainPath, "..", "..", "dump.sql"), []byte(dump), 0644)
|
||||
err = ioutil.WriteFile(filepath.Join(mainPath, "..", "..", "dump.sql"), []byte(dump), 0600)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"errors"
|
||||
|
@ -16,7 +15,7 @@ import (
|
|||
var migrations embed.FS
|
||||
|
||||
// Migrate runs SQL migrations to ensure the database schema is up-to-date.
|
||||
func Migrate(ctx context.Context, db *sql.DB) error {
|
||||
func Migrate(db *sql.DB) error {
|
||||
sourceDriver, err := iofs.New(migrations, "migrations")
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create iofs: %w", err)
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
package database_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
|
@ -27,6 +26,6 @@ func TestMigrate(t *testing.T) {
|
|||
db, err := sql.Open("postgres", connection)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
err = database.Migrate(context.Background(), db)
|
||||
err = database.Migrate(db)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package postgres
|
|||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/ory/dockertest/v3"
|
||||
|
@ -32,13 +31,16 @@ func Open() (string, func(), error) {
|
|||
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Could not start resource: %s", err)
|
||||
return "", nil, xerrors.Errorf("could not start resource: %w", err)
|
||||
}
|
||||
hostAndPort := resource.GetHostPort("5432/tcp")
|
||||
dbURL := fmt.Sprintf("postgres://postgres:postgres@%s/postgres?sslmode=disable", hostAndPort)
|
||||
|
||||
// Docker should hard-kill the container after 120 seconds.
|
||||
resource.Expire(120)
|
||||
err = resource.Expire(120)
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("could not expire resource: %w", err)
|
||||
}
|
||||
|
||||
pool.MaxWait = 120 * time.Second
|
||||
err = pool.Retry(func() error {
|
||||
|
|
|
@ -122,7 +122,7 @@ func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
|
|||
}
|
||||
|
||||
// NewPubsub creates a new Pubsub implementation using a PostgreSQL connection.
|
||||
func NewPubsub(ctx context.Context, db *sql.DB, connectURL string) (Pubsub, error) {
|
||||
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
|
||||
// Creates a new listener using pq.
|
||||
errCh := make(chan error)
|
||||
listener := pq.NewListener(connectURL, time.Second*10, time.Minute, func(event pq.ListenerEventType, err error) {
|
||||
|
@ -144,12 +144,12 @@ func NewPubsub(ctx context.Context, db *sql.DB, connectURL string) (Pubsub, erro
|
|||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
pg := &pgPubsub{
|
||||
db: db,
|
||||
pgPubsub := &pgPubsub{
|
||||
db: database,
|
||||
pgListener: listener,
|
||||
listeners: make(map[string]map[string]Listener),
|
||||
}
|
||||
go pg.listen(ctx)
|
||||
go pgPubsub.listen(ctx)
|
||||
|
||||
return pg, nil
|
||||
return pgPubsub, nil
|
||||
}
|
||||
|
|
|
@ -52,7 +52,7 @@ func (m *memoryPubsub) Publish(event string, message []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *memoryPubsub) Close() error {
|
||||
func (*memoryPubsub) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -17,9 +17,9 @@ func TestPubsubMemory(t *testing.T) {
|
|||
pubsub := database.NewPubsubInMemory()
|
||||
event := "test"
|
||||
data := "testing"
|
||||
ch := make(chan []byte)
|
||||
messageChannel := make(chan []byte)
|
||||
cancelFunc, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
|
||||
ch <- message
|
||||
messageChannel <- message
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancelFunc()
|
||||
|
@ -27,7 +27,7 @@ func TestPubsubMemory(t *testing.T) {
|
|||
err = pubsub.Publish(event, []byte(data))
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
message := <-ch
|
||||
message := <-messageChannel
|
||||
assert.Equal(t, string(message), data)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -32,9 +32,9 @@ func TestPubsub(t *testing.T) {
|
|||
defer pubsub.Close()
|
||||
event := "test"
|
||||
data := "testing"
|
||||
ch := make(chan []byte)
|
||||
messageChannel := make(chan []byte)
|
||||
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
|
||||
ch <- message
|
||||
messageChannel <- message
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancelFunc()
|
||||
|
@ -42,7 +42,7 @@ func TestPubsub(t *testing.T) {
|
|||
err = pubsub.Publish(event, []byte(data))
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
message := <-ch
|
||||
message := <-messageChannel
|
||||
assert.Equal(t, string(message), data)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ func init() {
|
|||
}
|
||||
return name
|
||||
})
|
||||
validate.RegisterValidation("username", func(fl validator.FieldLevel) bool {
|
||||
err := validate.RegisterValidation("username", func(fl validator.FieldLevel) bool {
|
||||
f := fl.Field().Interface()
|
||||
str, ok := f.(string)
|
||||
if !ok {
|
||||
|
@ -45,6 +45,9 @@ func init() {
|
|||
}
|
||||
return usernameRegex.MatchString(str)
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Response represents a generic HTTP response.
|
||||
|
@ -60,20 +63,20 @@ type Error struct {
|
|||
}
|
||||
|
||||
// Write outputs a standardized format to an HTTP response body.
|
||||
func Write(w http.ResponseWriter, status int, response Response) {
|
||||
func Write(rw http.ResponseWriter, status int, response Response) {
|
||||
buf := &bytes.Buffer{}
|
||||
enc := json.NewEncoder(buf)
|
||||
enc.SetEscapeHTML(true)
|
||||
err := enc.Encode(response)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(status)
|
||||
_, err = w.Write(buf.Bytes())
|
||||
rw.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
rw.WriteHeader(status)
|
||||
_, err = rw.Write(buf.Bytes())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,22 +58,22 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
|
|||
})
|
||||
return
|
||||
}
|
||||
id := parts[0]
|
||||
secret := parts[1]
|
||||
keyID := parts[0]
|
||||
keySecret := parts[1]
|
||||
// Ensuring key lengths are valid.
|
||||
if len(id) != 10 {
|
||||
if len(keyID) != 10 {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
Message: fmt.Sprintf("invalid %q cookie api key id", AuthCookie),
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(secret) != 22 {
|
||||
if len(keySecret) != 22 {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
Message: fmt.Sprintf("invalid %q cookie api key secret", AuthCookie),
|
||||
})
|
||||
return
|
||||
}
|
||||
key, err := db.GetAPIKeyByID(r.Context(), id)
|
||||
key, err := db.GetAPIKeyByID(r.Context(), keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
|
@ -86,7 +86,7 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
|
|||
})
|
||||
return
|
||||
}
|
||||
hashed := sha256.Sum256([]byte(secret))
|
||||
hashed := sha256.Sum256([]byte(keySecret))
|
||||
|
||||
// Checking to see if the secret is valid.
|
||||
if subtle.ConstantTimeCompare(key.HashedSecret, hashed[:]) != 1 {
|
||||
|
|
|
@ -19,9 +19,9 @@ import (
|
|||
"github.com/coder/coder/httpmw"
|
||||
)
|
||||
|
||||
func randomAPIKeyParts() (string, string) {
|
||||
id, _ := cryptorand.String(10)
|
||||
secret, _ := cryptorand.String(22)
|
||||
func randomAPIKeyParts() (id string, secret string) {
|
||||
id, _ = cryptorand.String(10)
|
||||
secret, _ = cryptorand.String(22)
|
||||
return id, secret
|
||||
}
|
||||
|
||||
|
@ -41,7 +41,9 @@ func TestAPIKey(t *testing.T) {
|
|||
rw = httptest.NewRecorder()
|
||||
)
|
||||
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
|
||||
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("InvalidFormat", func(t *testing.T) {
|
||||
|
@ -56,7 +58,9 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
|
||||
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("InvalidIDLength", func(t *testing.T) {
|
||||
|
@ -71,7 +75,9 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
|
||||
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("InvalidSecretLength", func(t *testing.T) {
|
||||
|
@ -86,7 +92,9 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
|
||||
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
|
@ -102,7 +110,9 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
|
||||
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("InvalidSecret", func(t *testing.T) {
|
||||
|
@ -125,7 +135,9 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
|
||||
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("Expired", func(t *testing.T) {
|
||||
|
@ -147,7 +159,9 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
|
||||
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("Valid", func(t *testing.T) {
|
||||
|
@ -177,7 +191,9 @@ func TestAPIKey(t *testing.T) {
|
|||
Message: "it worked!",
|
||||
})
|
||||
})).ServeHTTP(rw, r)
|
||||
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
|
||||
require.NoError(t, err)
|
||||
|
@ -207,7 +223,9 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
|
||||
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
|
||||
require.NoError(t, err)
|
||||
|
@ -237,7 +255,9 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
|
||||
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
|
||||
require.NoError(t, err)
|
||||
|
@ -268,7 +288,9 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
|
||||
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
|
||||
require.NoError(t, err)
|
||||
|
@ -310,7 +332,9 @@ func TestAPIKey(t *testing.T) {
|
|||
},
|
||||
},
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
|
||||
require.NoError(t, err)
|
||||
|
@ -325,7 +349,7 @@ type oauth2Config struct {
|
|||
tokenSource *oauth2TokenSource
|
||||
}
|
||||
|
||||
func (o *oauth2Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource {
|
||||
func (o *oauth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
|
||||
return o.tokenSource
|
||||
}
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ const (
|
|||
// the channel on open. The datachannel should not be manually
|
||||
// mutated after being passed to this function.
|
||||
func newChannel(conn *Conn, dc *webrtc.DataChannel, opts *ChannelOpts) *Channel {
|
||||
c := &Channel{
|
||||
channel := &Channel{
|
||||
opts: opts,
|
||||
conn: conn,
|
||||
dc: dc,
|
||||
|
@ -37,8 +37,8 @@ func newChannel(conn *Conn, dc *webrtc.DataChannel, opts *ChannelOpts) *Channel
|
|||
closed: make(chan struct{}),
|
||||
sendMore: make(chan struct{}, 1),
|
||||
}
|
||||
c.init()
|
||||
return c
|
||||
channel.init()
|
||||
return channel
|
||||
}
|
||||
|
||||
type ChannelOpts struct {
|
||||
|
@ -109,6 +109,8 @@ func (c *Channel) init() {
|
|||
return
|
||||
}
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
case c.sendMore <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
@ -167,7 +169,7 @@ func (c *Channel) init() {
|
|||
// Read blocks until data is received.
|
||||
//
|
||||
// This will block until the underlying DataChannel has been opened.
|
||||
func (c *Channel) Read(b []byte) (n int, err error) {
|
||||
func (c *Channel) Read(bytes []byte) (int, error) {
|
||||
if c.isClosed() {
|
||||
return 0, c.closeError
|
||||
}
|
||||
|
@ -178,7 +180,7 @@ func (c *Channel) Read(b []byte) (n int, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
n, err = c.rwc.Read(b)
|
||||
bytesRead, err := c.rwc.Read(bytes)
|
||||
if err != nil {
|
||||
if c.isClosed() {
|
||||
return 0, c.closeError
|
||||
|
@ -189,9 +191,9 @@ func (c *Channel) Read(b []byte) (n int, err error) {
|
|||
if xerrors.Is(err, io.EOF) {
|
||||
err = c.closeWithError(ErrClosed)
|
||||
}
|
||||
return
|
||||
return bytesRead, err
|
||||
}
|
||||
return
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
// Write sends data to the underlying DataChannel.
|
||||
|
@ -202,8 +204,8 @@ func (c *Channel) Read(b []byte) (n int, err error) {
|
|||
//
|
||||
// If the Channel is setup to close on disconnect, any buffered
|
||||
// data will be lost.
|
||||
func (c *Channel) Write(b []byte) (n int, err error) {
|
||||
if len(b) > maxMessageLength {
|
||||
func (c *Channel) Write(bytes []byte) (n int, err error) {
|
||||
if len(bytes) > maxMessageLength {
|
||||
return 0, xerrors.Errorf("outbound packet larger than maximum message size: %d", maxMessageLength)
|
||||
}
|
||||
|
||||
|
@ -220,7 +222,7 @@ func (c *Channel) Write(b []byte) (n int, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
if c.dc.BufferedAmount()+uint64(len(b)) >= maxBufferedAmount {
|
||||
if c.dc.BufferedAmount()+uint64(len(bytes)) >= maxBufferedAmount {
|
||||
<-c.sendMore
|
||||
}
|
||||
// TODO (@kyle): There's an obvious race-condition here.
|
||||
|
@ -230,7 +232,7 @@ func (c *Channel) Write(b []byte) (n int, err error) {
|
|||
// See: https://github.com/pion/sctp/issues/181
|
||||
time.Sleep(time.Microsecond)
|
||||
|
||||
return c.rwc.Write(b)
|
||||
return c.rwc.Write(bytes)
|
||||
}
|
||||
|
||||
// Close gracefully closes the DataChannel.
|
||||
|
|
53
peer/conn.go
53
peer/conn.go
|
@ -42,6 +42,7 @@ func Server(servers []webrtc.ICEServer, opts *ConnOpts) (*Conn, error) {
|
|||
}
|
||||
|
||||
// newWithClientOrServer constructs a new connection with the client option.
|
||||
// nolint:revive
|
||||
func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOpts) (*Conn, error) {
|
||||
if opts == nil {
|
||||
opts = &ConnOpts{}
|
||||
|
@ -60,7 +61,7 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("create peer connection: %w", err)
|
||||
}
|
||||
c := &Conn{
|
||||
conn := &Conn{
|
||||
pingChannelID: 1,
|
||||
pingEchoChannelID: 2,
|
||||
opts: opts,
|
||||
|
@ -77,13 +78,13 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp
|
|||
if client {
|
||||
// If we're the client, we want to flip the echo and
|
||||
// ping channel IDs so pings don't accidentally hit each other.
|
||||
c.pingChannelID, c.pingEchoChannelID = c.pingEchoChannelID, c.pingChannelID
|
||||
conn.pingChannelID, conn.pingEchoChannelID = conn.pingEchoChannelID, conn.pingChannelID
|
||||
}
|
||||
err = c.init()
|
||||
err = conn.init()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("init: %w", err)
|
||||
}
|
||||
return c, nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
type ConnOpts struct {
|
||||
|
@ -142,6 +143,10 @@ func (c *Conn) init() error {
|
|||
}
|
||||
})
|
||||
c.rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) {
|
||||
// Close must be locked here otherwise log output can appear
|
||||
// after the connection has been closed.
|
||||
c.closeMutex.Lock()
|
||||
defer c.closeMutex.Unlock()
|
||||
if c.isClosed() {
|
||||
return
|
||||
}
|
||||
|
@ -211,12 +216,12 @@ func (c *Conn) pingEchoChannel() (*Channel, error) {
|
|||
if c.isClosed() {
|
||||
return
|
||||
}
|
||||
_ = c.closeWithError(xerrors.Errorf("read ping echo channel: %w", err))
|
||||
_ = c.CloseWithError(xerrors.Errorf("read ping echo channel: %w", err))
|
||||
return
|
||||
}
|
||||
_, err = c.pingEchoChan.Write(data[:bytesRead])
|
||||
if err != nil {
|
||||
_ = c.closeWithError(xerrors.Errorf("write ping echo channel: %w", err))
|
||||
_ = c.CloseWithError(xerrors.Errorf("write ping echo channel: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -237,12 +242,12 @@ func (c *Conn) negotiate() {
|
|||
if c.offerrer {
|
||||
offer, err := c.rtc.CreateOffer(&webrtc.OfferOptions{})
|
||||
if err != nil {
|
||||
_ = c.closeWithError(xerrors.Errorf("create offer: %w", err))
|
||||
_ = c.CloseWithError(xerrors.Errorf("create offer: %w", err))
|
||||
return
|
||||
}
|
||||
err = c.rtc.SetLocalDescription(offer)
|
||||
if err != nil {
|
||||
_ = c.closeWithError(xerrors.Errorf("set local description: %w", err))
|
||||
_ = c.CloseWithError(xerrors.Errorf("set local description: %w", err))
|
||||
return
|
||||
}
|
||||
select {
|
||||
|
@ -261,19 +266,19 @@ func (c *Conn) negotiate() {
|
|||
|
||||
err := c.rtc.SetRemoteDescription(remoteDescription)
|
||||
if err != nil {
|
||||
_ = c.closeWithError(xerrors.Errorf("set remote description (closed %v): %w", c.isClosed(), err))
|
||||
_ = c.CloseWithError(xerrors.Errorf("set remote description (closed %v): %w", c.isClosed(), err))
|
||||
return
|
||||
}
|
||||
|
||||
if !c.offerrer {
|
||||
answer, err := c.rtc.CreateAnswer(&webrtc.AnswerOptions{})
|
||||
if err != nil {
|
||||
_ = c.closeWithError(xerrors.Errorf("create answer: %w", err))
|
||||
_ = c.CloseWithError(xerrors.Errorf("create answer: %w", err))
|
||||
return
|
||||
}
|
||||
err = c.rtc.SetLocalDescription(answer)
|
||||
if err != nil {
|
||||
_ = c.closeWithError(xerrors.Errorf("set local description: %w", err))
|
||||
_ = c.CloseWithError(xerrors.Errorf("set local description: %w", err))
|
||||
return
|
||||
}
|
||||
if c.isClosed() {
|
||||
|
@ -296,20 +301,20 @@ func (c *Conn) proxyICECandidates() func() {
|
|||
queue = []webrtc.ICECandidateInit{}
|
||||
flushed = false
|
||||
)
|
||||
c.rtc.OnICECandidate(func(i *webrtc.ICECandidate) {
|
||||
if i == nil {
|
||||
c.rtc.OnICECandidate(func(iceCandidate *webrtc.ICECandidate) {
|
||||
if iceCandidate == nil {
|
||||
return
|
||||
}
|
||||
mut.Lock()
|
||||
defer mut.Unlock()
|
||||
if !flushed {
|
||||
queue = append(queue, i.ToJSON())
|
||||
queue = append(queue, iceCandidate.ToJSON())
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
case c.localCandidateChannel <- i.ToJSON():
|
||||
case c.localCandidateChannel <- iceCandidate.ToJSON():
|
||||
}
|
||||
})
|
||||
return func() {
|
||||
|
@ -353,7 +358,7 @@ func (c *Conn) SetConfiguration(configuration webrtc.Configuration) error {
|
|||
}
|
||||
|
||||
// SetRemoteSessionDescription sets the remote description for the WebRTC connection.
|
||||
func (c *Conn) SetRemoteSessionDescription(s webrtc.SessionDescription) {
|
||||
func (c *Conn) SetRemoteSessionDescription(sessionDescription webrtc.SessionDescription) {
|
||||
if c.isClosed() {
|
||||
return
|
||||
}
|
||||
|
@ -361,7 +366,7 @@ func (c *Conn) SetRemoteSessionDescription(s webrtc.SessionDescription) {
|
|||
defer c.closeMutex.Unlock()
|
||||
select {
|
||||
case <-c.closed:
|
||||
case c.remoteSessionDescriptionChannel <- s:
|
||||
case c.remoteSessionDescriptionChannel <- sessionDescription:
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -407,7 +412,7 @@ func (c *Conn) dialChannel(ctx context.Context, label string, opts *ChannelOpts)
|
|||
return nil, xerrors.Errorf("closed: %w", c.closeError)
|
||||
}
|
||||
|
||||
dc, err := c.rtc.CreateDataChannel(label, &webrtc.DataChannelInit{
|
||||
dataChannel, err := c.rtc.CreateDataChannel(label, &webrtc.DataChannelInit{
|
||||
ID: id,
|
||||
Negotiated: &opts.Negotiated,
|
||||
Ordered: &ordered,
|
||||
|
@ -416,7 +421,7 @@ func (c *Conn) dialChannel(ctx context.Context, label string, opts *ChannelOpts)
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("create data channel: %w", err)
|
||||
}
|
||||
return newChannel(c, dc, opts), nil
|
||||
return newChannel(c, dataChannel, opts), nil
|
||||
}
|
||||
|
||||
// Ping returns the duration it took to round-trip data.
|
||||
|
@ -461,12 +466,7 @@ func (c *Conn) Closed() <-chan struct{} {
|
|||
|
||||
// Close closes the connection and frees all associated resources.
|
||||
func (c *Conn) Close() error {
|
||||
return c.closeWithError(nil)
|
||||
}
|
||||
|
||||
// CloseWithError closes the connection; subsequent reads/writes will return the error err.
|
||||
func (c *Conn) CloseWithError(err error) error {
|
||||
return c.closeWithError(err)
|
||||
return c.CloseWithError(nil)
|
||||
}
|
||||
|
||||
func (c *Conn) isClosed() bool {
|
||||
|
@ -478,7 +478,8 @@ func (c *Conn) isClosed() bool {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Conn) closeWithError(err error) error {
|
||||
// CloseWithError closes the connection; subsequent reads/writes will return the error err.
|
||||
func (c *Conn) CloseWithError(err error) error {
|
||||
c.closeMutex.Lock()
|
||||
defer c.closeMutex.Unlock()
|
||||
|
||||
|
|
|
@ -67,13 +67,13 @@ func TestConn(t *testing.T) {
|
|||
_, err := server.Ping()
|
||||
require.NoError(t, err)
|
||||
// Create a channel that closes on disconnect.
|
||||
ch, err := server.Dial(context.Background(), "wow", nil)
|
||||
channel, err := server.Dial(context.Background(), "wow", nil)
|
||||
assert.NoError(t, err)
|
||||
err = wan.Stop()
|
||||
require.NoError(t, err)
|
||||
// Once the connection is marked as disconnected, this
|
||||
// channel will be closed.
|
||||
_, err = ch.Read(make([]byte, 4))
|
||||
_, err = channel.Read(make([]byte, 4))
|
||||
assert.ErrorIs(t, err, peer.ErrClosed)
|
||||
err = wan.Start()
|
||||
require.NoError(t, err)
|
||||
|
@ -154,26 +154,26 @@ func TestConn(t *testing.T) {
|
|||
_, _ = io.Copy(nc2, nc1)
|
||||
}()
|
||||
go func() {
|
||||
s := http.Server{
|
||||
server := http.Server{
|
||||
Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(200)
|
||||
}),
|
||||
}
|
||||
defer s.Close()
|
||||
_ = s.Serve(srv)
|
||||
defer server.Close()
|
||||
_ = server.Serve(srv)
|
||||
}()
|
||||
|
||||
dt := http.DefaultTransport.(*http.Transport).Clone()
|
||||
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
var cch *peer.Channel
|
||||
dt.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
cch, err = client.Dial(context.Background(), "hello", &peer.ChannelOpts{})
|
||||
defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
cch, err = client.Dial(ctx, "hello", &peer.ChannelOpts{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cch.NetConn(), nil
|
||||
}
|
||||
c := http.Client{
|
||||
Transport: dt,
|
||||
Transport: defaultTransport,
|
||||
}
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", "http://localhost/", nil)
|
||||
require.NoError(t, err)
|
||||
|
@ -183,7 +183,7 @@ func TestConn(t *testing.T) {
|
|||
require.Equal(t, resp.StatusCode, 200)
|
||||
// Triggers any connections to close.
|
||||
// This test below ensures the DataChannel actually closes.
|
||||
dt.CloseIdleConnections()
|
||||
defaultTransport.CloseIdleConnections()
|
||||
err = cch.Close()
|
||||
require.ErrorIs(t, err, peer.ErrClosed)
|
||||
})
|
||||
|
@ -226,13 +226,13 @@ func TestConn(t *testing.T) {
|
|||
}
|
||||
|
||||
func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) {
|
||||
lf := logging.NewDefaultLoggerFactory()
|
||||
lf.DefaultLogLevel = logging.LogLevelDisabled
|
||||
loggingFactory := logging.NewDefaultLoggerFactory()
|
||||
loggingFactory.DefaultLogLevel = logging.LogLevelDisabled
|
||||
vnetMutex.Lock()
|
||||
defer vnetMutex.Unlock()
|
||||
wan, err := vnet.NewRouter(&vnet.RouterConfig{
|
||||
CIDR: "1.2.3.0/24",
|
||||
LoggerFactory: lf,
|
||||
LoggerFactory: loggingFactory,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
c1Net := vnet.NewNet(&vnet.NetConfig{
|
||||
|
@ -250,25 +250,25 @@ func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.R
|
|||
c1SettingEngine.SetVNet(c1Net)
|
||||
c1SettingEngine.SetPrflxAcceptanceMinWait(0)
|
||||
c1SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
|
||||
c1, err := peer.Client([]webrtc.ICEServer{}, &peer.ConnOpts{
|
||||
channel1, err := peer.Client([]webrtc.ICEServer{}, &peer.ConnOpts{
|
||||
SettingEngine: c1SettingEngine,
|
||||
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
c1.Close()
|
||||
channel1.Close()
|
||||
})
|
||||
c2SettingEngine := webrtc.SettingEngine{}
|
||||
c2SettingEngine.SetVNet(c2Net)
|
||||
c2SettingEngine.SetPrflxAcceptanceMinWait(0)
|
||||
c2SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
|
||||
c2, err := peer.Server([]webrtc.ICEServer{}, &peer.ConnOpts{
|
||||
channel2, err := peer.Server([]webrtc.ICEServer{}, &peer.ConnOpts{
|
||||
SettingEngine: c2SettingEngine,
|
||||
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
c2.Close()
|
||||
channel2.Close()
|
||||
})
|
||||
|
||||
err = wan.Start()
|
||||
|
@ -280,11 +280,11 @@ func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.R
|
|||
go func() {
|
||||
for {
|
||||
select {
|
||||
case c := <-c2.LocalCandidate():
|
||||
_ = c1.AddRemoteCandidate(c)
|
||||
case c := <-c2.LocalSessionDescription():
|
||||
c1.SetRemoteSessionDescription(c)
|
||||
case <-c2.Closed():
|
||||
case c := <-channel2.LocalCandidate():
|
||||
_ = channel1.AddRemoteCandidate(c)
|
||||
case c := <-channel2.LocalSessionDescription():
|
||||
channel1.SetRemoteSessionDescription(c)
|
||||
case <-channel2.Closed():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -293,15 +293,15 @@ func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.R
|
|||
go func() {
|
||||
for {
|
||||
select {
|
||||
case c := <-c1.LocalCandidate():
|
||||
_ = c2.AddRemoteCandidate(c)
|
||||
case c := <-c1.LocalSessionDescription():
|
||||
c2.SetRemoteSessionDescription(c)
|
||||
case <-c1.Closed():
|
||||
case c := <-channel1.LocalCandidate():
|
||||
_ = channel2.AddRemoteCandidate(c)
|
||||
case c := <-channel1.LocalSessionDescription():
|
||||
channel2.SetRemoteSessionDescription(c)
|
||||
case <-channel1.Closed():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return c1, c2, wan
|
||||
return channel1, channel2, wan
|
||||
}
|
||||
|
|
|
@ -10,11 +10,11 @@ type peerAddr struct{}
|
|||
// Statically checks if we properly implement net.Addr.
|
||||
var _ net.Addr = &peerAddr{}
|
||||
|
||||
func (a *peerAddr) Network() string {
|
||||
func (*peerAddr) Network() string {
|
||||
return "peer"
|
||||
}
|
||||
|
||||
func (a *peerAddr) String() string {
|
||||
func (*peerAddr) String() string {
|
||||
return "peer/unknown-addr"
|
||||
}
|
||||
|
||||
|
@ -46,14 +46,14 @@ func (c *fakeNetConn) RemoteAddr() net.Addr {
|
|||
return c.addr
|
||||
}
|
||||
|
||||
func (c *fakeNetConn) SetDeadline(_ time.Time) error {
|
||||
func (*fakeNetConn) SetDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeNetConn) SetReadDeadline(_ time.Time) error {
|
||||
func (*fakeNetConn) SetReadDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeNetConn) SetWriteDeadline(_ time.Time) error {
|
||||
func (*fakeNetConn) SetWriteDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -161,7 +161,6 @@ func (b *peerBrokerService) NegotiateConnection(stream proto.DRPCPeerBroker_Nego
|
|||
Type: webrtc.SDPType(clientToServerMessage.GetOffer().SdpType),
|
||||
SDP: clientToServerMessage.GetOffer().Sdp,
|
||||
})
|
||||
break
|
||||
case clientToServerMessage.GetServers() != nil:
|
||||
// Convert protobuf ICE servers to the WebRTC type.
|
||||
iceServers := make([]webrtc.ICEServer, 0, len(clientToServerMessage.GetServers().Servers))
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
)
|
||||
|
||||
// Parse extracts Terraform variables from source-code.
|
||||
func (t *terraform) Parse(ctx context.Context, request *proto.Parse_Request) (*proto.Parse_Response, error) {
|
||||
func (*terraform) Parse(_ context.Context, request *proto.Parse_Request) (*proto.Parse_Response, error) {
|
||||
module, diags := tfconfig.LoadModule(request.Directory)
|
||||
if diags.HasErrors() {
|
||||
return nil, xerrors.Errorf("load module: %w", diags.Err())
|
||||
|
|
|
@ -37,7 +37,7 @@ func TestParse(t *testing.T) {
|
|||
}()
|
||||
api := proto.NewDRPCProvisionerClient(drpcconn.New(client))
|
||||
|
||||
for _, tc := range []struct {
|
||||
for _, testCase := range []struct {
|
||||
Name string
|
||||
Files map[string]string
|
||||
Response *proto.Parse_Response
|
||||
|
@ -83,13 +83,13 @@ func TestParse(t *testing.T) {
|
|||
}},
|
||||
},
|
||||
}} {
|
||||
tc := tc
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
testCase := testCase
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Write all files to the temporary test directory.
|
||||
directory := t.TempDir()
|
||||
for path, content := range tc.Files {
|
||||
for path, content := range testCase.Files {
|
||||
err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0600)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
@ -100,7 +100,7 @@ func TestParse(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Ensure the want and got are equivalent!
|
||||
want, err := json.Marshal(tc.Response)
|
||||
want, err := json.Marshal(testCase.Response)
|
||||
require.NoError(t, err)
|
||||
got, err := json.Marshal(response)
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -15,9 +15,9 @@ import (
|
|||
// Provision executes `terraform apply`.
|
||||
func (t *terraform) Provision(ctx context.Context, request *proto.Provision_Request) (*proto.Provision_Response, error) {
|
||||
statefilePath := filepath.Join(request.Directory, "terraform.tfstate")
|
||||
err := os.WriteFile(statefilePath, request.State, 0644)
|
||||
err := os.WriteFile(statefilePath, request.State, 0600)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("write statefile %q: %w", err)
|
||||
return nil, xerrors.Errorf("write statefile %q: %w", statefilePath, err)
|
||||
}
|
||||
|
||||
terraform, err := tfexec.NewTerraform(request.Directory, t.binaryPath)
|
||||
|
|
|
@ -48,7 +48,7 @@ func TestProvision(t *testing.T) {
|
|||
}()
|
||||
api := proto.NewDRPCProvisionerClient(drpcconn.New(client))
|
||||
|
||||
for _, tc := range []struct {
|
||||
for _, testCase := range []struct {
|
||||
Name string
|
||||
Files map[string]string
|
||||
Request *proto.Provision_Request
|
||||
|
@ -93,25 +93,25 @@ func TestProvision(t *testing.T) {
|
|||
},
|
||||
Error: true,
|
||||
}} {
|
||||
tc := tc
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
testCase := testCase
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
directory := t.TempDir()
|
||||
for path, content := range tc.Files {
|
||||
err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0644)
|
||||
for path, content := range testCase.Files {
|
||||
err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0600)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
request := &proto.Provision_Request{
|
||||
Directory: directory,
|
||||
}
|
||||
if tc.Request != nil {
|
||||
request.ParameterValues = tc.Request.ParameterValues
|
||||
request.State = tc.Request.State
|
||||
if testCase.Request != nil {
|
||||
request.ParameterValues = testCase.Request.ParameterValues
|
||||
request.State = testCase.Request.State
|
||||
}
|
||||
response, err := api.Provision(ctx, request)
|
||||
if tc.Error {
|
||||
if testCase.Error {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ func TestProvision(t *testing.T) {
|
|||
resourcesGot, err := json.Marshal(response.Resources)
|
||||
require.NoError(t, err)
|
||||
|
||||
resourcesWant, err := json.Marshal(tc.Response.Resources)
|
||||
resourcesWant, err := json.Marshal(testCase.Response.Resources)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, string(resourcesWant), string(resourcesGot))
|
||||
|
|
|
@ -28,7 +28,7 @@ var site embed.FS
|
|||
|
||||
// Handler returns an HTTP handler for serving the static site.
|
||||
func Handler() http.Handler {
|
||||
f, err := fs.Sub(site, "out")
|
||||
filesystem, err := fs.Sub(site, "out")
|
||||
if err != nil {
|
||||
// This can't happen... Go would throw a compilation error.
|
||||
panic(err)
|
||||
|
@ -36,15 +36,15 @@ func Handler() http.Handler {
|
|||
|
||||
// html files are handled by a text/template. Non-html files
|
||||
// are served by the default file server.
|
||||
files, err := htmlFiles(f)
|
||||
files, err := htmlFiles(filesystem)
|
||||
if err != nil {
|
||||
panic(xerrors.Errorf("Failed to return handler for static files. Html files failed to load: %w", err))
|
||||
}
|
||||
|
||||
return secureHeaders(&handler{
|
||||
fs: f,
|
||||
fs: filesystem,
|
||||
htmlFiles: files,
|
||||
h: http.FileServer(http.FS(f)), // All other non-html static files
|
||||
h: http.FileServer(http.FS(filesystem)), // All other non-html static files
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -61,15 +61,15 @@ type handler struct {
|
|||
}
|
||||
|
||||
// filePath returns the filepath of the requested file.
|
||||
func (h *handler) filePath(p string) string {
|
||||
func (*handler) filePath(p string) string {
|
||||
if !strings.HasPrefix(p, "/") {
|
||||
p = "/" + p
|
||||
}
|
||||
return strings.TrimPrefix(path.Clean(p), "/")
|
||||
}
|
||||
|
||||
func (h *handler) exists(path string) bool {
|
||||
f, err := h.fs.Open(path)
|
||||
func (h *handler) exists(filePath string) bool {
|
||||
f, err := h.fs.Open(filePath)
|
||||
if err == nil {
|
||||
_ = f.Close()
|
||||
}
|
||||
|
@ -89,7 +89,7 @@ type csrfState struct {
|
|||
Token string
|
||||
}
|
||||
|
||||
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
// reqFile is the static file requested
|
||||
reqFile := h.filePath(r.URL.Path)
|
||||
state := htmlState{
|
||||
|
@ -100,13 +100,13 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// First check if it's a file we have in our templates
|
||||
if h.serveHtml(w, r, reqFile, state) {
|
||||
if h.serveHTML(rw, r, reqFile, state) {
|
||||
return
|
||||
}
|
||||
|
||||
// If the original file path exists we serve it.
|
||||
if h.exists(reqFile) {
|
||||
h.h.ServeHTTP(w, r)
|
||||
h.h.ServeHTTP(rw, r)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -117,28 +117,28 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
reqFile = h.filePath(r.URL.Path)
|
||||
// All html files should be served by the htmlFile templates
|
||||
if h.serveHtml(w, r, reqFile, state) {
|
||||
if h.serveHTML(rw, r, reqFile, state) {
|
||||
return
|
||||
}
|
||||
|
||||
// If we don't have the file... we should redirect to `/`
|
||||
// for our single-page-app.
|
||||
r.URL.Path = "/"
|
||||
if h.serveHtml(w, r, "", state) {
|
||||
if h.serveHTML(rw, r, "", state) {
|
||||
return
|
||||
}
|
||||
|
||||
// This will send a correct 404
|
||||
h.h.ServeHTTP(w, r)
|
||||
h.h.ServeHTTP(rw, r)
|
||||
}
|
||||
|
||||
func (h *handler) serveHtml(w http.ResponseWriter, r *http.Request, reqPath string, state htmlState) bool {
|
||||
func (h *handler) serveHTML(rw http.ResponseWriter, r *http.Request, reqPath string, state htmlState) bool {
|
||||
if data, err := h.htmlFiles.renderWithState(reqPath, state); err == nil {
|
||||
if reqPath == "" {
|
||||
// Pass "index.html" to the ServeContent so the ServeContent sets the right content headers.
|
||||
reqPath = "index.html"
|
||||
}
|
||||
http.ServeContent(w, r, reqPath, time.Time{}, bytes.NewReader(data))
|
||||
http.ServeContent(rw, r, reqPath, time.Time{}, bytes.NewReader(data))
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
@ -150,12 +150,12 @@ type htmlTemplates struct {
|
|||
|
||||
// renderWithState will render the file using the given nonce if the file exists
|
||||
// as a template. If it does not, it will return an error.
|
||||
func (t *htmlTemplates) renderWithState(path string, state htmlState) ([]byte, error) {
|
||||
func (t *htmlTemplates) renderWithState(filePath string, state htmlState) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
if path == "" {
|
||||
path = "index.html"
|
||||
if filePath == "" {
|
||||
filePath = "index.html"
|
||||
}
|
||||
err := t.tpls.ExecuteTemplate(&buf, path, state)
|
||||
err := t.tpls.ExecuteTemplate(&buf, filePath, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -168,13 +168,6 @@ func (t *htmlTemplates) renderWithState(path string, state htmlState) ([]byte, e
|
|||
// All directives are semi-colon separated as a single string for the csp header.
|
||||
type cspDirectives map[cspFetchDirective][]string
|
||||
|
||||
func (s cspDirectives) append(d cspFetchDirective, values ...string) {
|
||||
if _, ok := s[d]; !ok {
|
||||
s[d] = make([]string, 0)
|
||||
}
|
||||
s[d] = append(s[d], values...)
|
||||
}
|
||||
|
||||
// cspFetchDirective is the list of all constant fetch directives that
|
||||
// can be used/appended to.
|
||||
type cspFetchDirective string
|
||||
|
@ -234,7 +227,7 @@ func secureHeaders(next http.Handler) http.Handler {
|
|||
|
||||
var csp strings.Builder
|
||||
for src, vals := range cspSrcs {
|
||||
fmt.Fprintf(&csp, "%s %s; ", src, strings.Join(vals, " "))
|
||||
_, _ = fmt.Fprintf(&csp, "%s %s; ", src, strings.Join(vals, " "))
|
||||
}
|
||||
|
||||
// Permissions-Policy can be used to disabled various browser features that we do not use.
|
||||
|
@ -280,16 +273,16 @@ func htmlFiles(files fs.FS) (*htmlTemplates, error) {
|
|||
root := template.New("")
|
||||
|
||||
rootPath := "."
|
||||
err := fs.WalkDir(files, rootPath, func(path string, d fs.DirEntry, err error) error {
|
||||
err := fs.WalkDir(files, rootPath, func(path string, dirEntry fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.IsDir() {
|
||||
if dirEntry.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if filepath.Ext(d.Name()) != ".html" {
|
||||
if filepath.Ext(dirEntry.Name()) != ".html" {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
package site
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
|
@ -13,8 +15,11 @@ func TestIndexPageRenders(t *testing.T) {
|
|||
|
||||
srv := httptest.NewServer(Handler())
|
||||
|
||||
resp, err := srv.Client().Get(srv.URL)
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err, "get index")
|
||||
defer resp.Body.Close()
|
||||
data, _ := io.ReadAll(resp.Body)
|
||||
require.NotEmpty(t, data, "index should have contents")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue