feat: Add authentication and personal user endpoint (#29)

* feat: Add authentication and personal user endpoint

This contribution adds a lot of scaffolding for the database fake
and testability of coderd.

A new endpoint "/user" is added to return the currently authenticated
user to the requester.

* Use TestMain to catch leak instead

* Add userpassword package

* Add WIP

* Add user auth

* Fix test

* Add comments

* Fix login response

* Fix order

* Fix generated code

* Update httpapi/httpapi.go

Co-authored-by: Bryan <bryan@coder.com>

Co-authored-by: Bryan <bryan@coder.com>
This commit is contained in:
Kyle Carberry 2022-01-20 07:46:51 -06:00 committed by GitHub
parent 36b7b20e2a
commit 6a919aea79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 2232 additions and 61 deletions

View File

@ -27,7 +27,14 @@ else
endif
.PHONY: fmt/prettier
fmt: fmt/prettier
fmt/sql:
npx sql-formatter \
--language postgresql \
--lines-between-queries 2 \
./database/query.sql \
--output ./database/query.sql
fmt: fmt/prettier fmt/sql
.PHONY: fmt
gen: database/generate peerbroker/proto provisionersdk/proto

View File

@ -21,5 +21,7 @@ coverage:
ignore:
# This is generated code.
- database/models.go
- database/query.sql.go
- peerbroker/proto
- provisionersdk/proto

View File

@ -5,12 +5,13 @@ import (
"net/http"
"os"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/coderd"
"github.com/coder/coder/database"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"github.com/coder/coder/database/databasefake"
)
func Root() *cobra.Command {
@ -22,7 +23,7 @@ func Root() *cobra.Command {
RunE: func(cmd *cobra.Command, args []string) error {
handler := coderd.New(&coderd.Options{
Logger: slog.Make(sloghuman.Sink(os.Stderr)),
Database: database.NewInMemory(),
Database: databasefake.New(),
})
listener, err := net.Listen("tcp", address)

View File

@ -4,8 +4,9 @@ import (
"context"
"testing"
"github.com/coder/coder/coderd/cmd"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/cmd"
)
func TestRoot(t *testing.T) {

View File

@ -3,11 +3,13 @@ package coderd
import (
"net/http"
"github.com/go-chi/chi"
"cdr.dev/slog"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
"github.com/coder/coder/site"
"github.com/go-chi/chi"
"github.com/go-chi/render"
)
// Options are requires parameters for Coder to start.
@ -18,15 +20,27 @@ type Options struct {
// New constructs the Coder API into an HTTP handler.
func New(options *Options) http.Handler {
users := &users{
Database: options.Database,
}
r := chi.NewRouter()
r.Route("/api/v2", func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
render.JSON(w, r, struct {
Message string `json:"message"`
}{
httpapi.Write(w, http.StatusOK, httpapi.Response{
Message: "👋",
})
})
r.Post("/user", users.createInitialUser)
r.Post("/login", users.loginWithPassword)
// Require an API key and authenticated user for this group.
r.Group(func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
httpmw.ExtractUser(options.Database),
)
r.Get("/user", users.getAuthenticatedUser)
})
})
r.NotFound(site.Handler().ServeHTTP)
return r

View File

@ -0,0 +1,59 @@
package coderdtest
import (
"context"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database/databasefake"
)
// Server represents a test instance of coderd.
// The database is intentionally omitted from
// this struct to promote data being exposed via
// the API.
type Server struct {
Client *codersdk.Client
URL *url.URL
}
// New constructs a new coderd test instance.
func New(t *testing.T) Server {
// This can be hotswapped for a live database instance.
db := databasefake.New()
handler := coderd.New(&coderd.Options{
Logger: slogtest.Make(t, nil),
Database: db,
})
srv := httptest.NewServer(handler)
u, err := url.Parse(srv.URL)
require.NoError(t, err)
t.Cleanup(srv.Close)
client := codersdk.New(u)
_, err = client.CreateInitialUser(context.Background(), coderd.CreateUserRequest{
Email: "testuser@coder.com",
Username: "testuser",
Password: "testpassword",
})
require.NoError(t, err)
login, err := client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{
Email: "testuser@coder.com",
Password: "testpassword",
})
require.NoError(t, err)
err = client.SetSessionToken(login.SessionToken)
require.NoError(t, err)
return Server{
Client: client,
URL: u,
}
}

View File

@ -0,0 +1,17 @@
package coderdtest_test
import (
"testing"
"go.uber.org/goleak"
"github.com/coder/coder/coderd/coderdtest"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestNew(t *testing.T) {
_ = coderdtest.New(t)
}

View File

@ -0,0 +1,78 @@
package userpassword
import (
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"fmt"
"strconv"
"strings"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/xerrors"
)
const (
// This is the length of our output hash.
// bcrypt has a hash size of 59, so we rounded up to a power of 8.
hashLength = 64
// The scheme to include in our hashed password.
hashScheme = "pbkdf2-sha256"
)
// Compare checks the equality of passwords from a hashed pbkdf2 string.
// This uses pbkdf2 to ensure FIPS 140-2 compliance. See:
// https://csrc.nist.gov/csrc/media/projects/cryptographic-module-validation-program/documents/security-policies/140sp2261.pdf
func Compare(hashed string, password string) (bool, error) {
if len(hashed) < hashLength {
return false, xerrors.Errorf("hash too short: %d", len(hashed))
}
parts := strings.SplitN(hashed, "$", 5)
if len(parts) != 5 {
return false, xerrors.Errorf("hash has too many parts: %d", len(parts))
}
if len(parts[0]) != 0 {
return false, xerrors.Errorf("hash prefix is invalid")
}
if string(parts[1]) != hashScheme {
return false, xerrors.Errorf("hash isn't %q scheme: %q", hashScheme, parts[1])
}
iter, err := strconv.Atoi(string(parts[2]))
if err != nil {
return false, xerrors.Errorf("parse iter from hash: %w", err)
}
salt, err := base64.RawStdEncoding.DecodeString(string(parts[3]))
if err != nil {
return false, xerrors.Errorf("decode salt: %w", err)
}
if subtle.ConstantTimeCompare([]byte(hashWithSaltAndIter(password, salt, iter)), []byte(hashed)) != 1 {
return false, nil
}
return true, nil
}
// Hash generates a hash using pbkdf2.
// See the Compare() comment for rationale.
func Hash(password string) (string, error) {
// bcrypt uses a salt size of 16 bytes.
salt := make([]byte, 16)
_, err := rand.Read(salt)
if err != nil {
return "", xerrors.Errorf("read random bytes for salt: %w", err)
}
// The default hash iteration is 1024 for speed.
// As this is increased, the password is hashed more.
return hashWithSaltAndIter(password, salt, 1024), nil
}
// Produces a string representation of the hash.
func hashWithSaltAndIter(password string, salt []byte, iter int) string {
hash := pbkdf2.Key([]byte(password), salt, iter, hashLength, sha256.New)
hash = []byte(base64.RawStdEncoding.EncodeToString(hash))
salt = []byte(base64.RawStdEncoding.EncodeToString(salt))
// This format is similar to bcrypt. See:
// https://en.wikipedia.org/wiki/Bcrypt#Description
return fmt.Sprintf("$%s$%d$%s$%s", hashScheme, iter, salt, hash)
}

View File

@ -0,0 +1,47 @@
package userpassword_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/userpassword"
)
func TestUserPassword(t *testing.T) {
t.Run("Legacy", func(t *testing.T) {
// Ensures legacy v1 passwords function for v2.
// This has is manually generated using a print statement from v1 code.
equal, err := userpassword.Compare("$pbkdf2-sha256$65535$z8c1p1C2ru9EImBP1I+ZNA$pNjE3Yk0oG0PmJ0Je+y7ENOVlSkn/b0BEqqdKsq6Y97wQBq0xT+lD5bWJpyIKJqQICuPZcEaGDKrXJn8+SIHRg", "tomato")
require.NoError(t, err)
require.True(t, equal)
})
t.Run("Same", func(t *testing.T) {
hash, err := userpassword.Hash("password")
require.NoError(t, err)
equal, err := userpassword.Compare(hash, "password")
require.NoError(t, err)
require.True(t, equal)
})
t.Run("Different", func(t *testing.T) {
hash, err := userpassword.Hash("password")
require.NoError(t, err)
equal, err := userpassword.Compare(hash, "notpassword")
require.NoError(t, err)
require.False(t, equal)
})
t.Run("Invalid", func(t *testing.T) {
equal, err := userpassword.Compare("invalidhash", "password")
require.False(t, equal)
require.Error(t, err)
})
t.Run("InvalidParts", func(t *testing.T) {
equal, err := userpassword.Compare("abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz", "test")
require.False(t, equal)
require.Error(t, err)
})
}

209
coderd/users.go Normal file
View File

@ -0,0 +1,209 @@
package coderd
import (
"context"
"crypto/sha256"
"database/sql"
"errors"
"fmt"
"net/http"
"time"
"github.com/go-chi/render"
"github.com/google/uuid"
"github.com/coder/coder/coderd/userpassword"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
)
// User is the JSON representation of a Coder user.
type User struct {
ID string `json:"id" validate:"required"`
Email string `json:"email" validate:"required"`
CreatedAt time.Time `json:"created_at" validate:"required"`
Username string `json:"username" validate:"required"`
}
// CreateUserRequest enables callers to create a new user.
type CreateUserRequest struct {
Email string `json:"email" validate:"required,email"`
Username string `json:"username" validate:"required,username"`
Password string `json:"password" validate:"required"`
}
// LoginWithPasswordRequest enables callers to authenticate with email and password.
type LoginWithPasswordRequest struct {
Email string `json:"email" validate:"required,email"`
Password string `json:"password" validate:"required"`
}
// LoginWithPasswordResponse contains a session token for the newly authenticated user.
type LoginWithPasswordResponse struct {
SessionToken string `json:"session_token" validate:"required"`
}
type users struct {
Database database.Store
}
// Creates the initial user for a Coder deployment.
func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) {
var createUser CreateUserRequest
if !httpapi.Read(rw, r, &createUser) {
return
}
// This should only function for the first user.
userCount, err := users.Database.GetUserCount(r.Context())
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get user count: %s", err.Error()),
})
return
}
// If a user already exists, the initial admin user no longer can be created.
if userCount != 0 {
httpapi.Write(rw, http.StatusConflict, httpapi.Response{
Message: "the initial user has already been created",
})
return
}
user, err := users.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
Email: createUser.Email,
Username: createUser.Username,
})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get user: %s", err.Error()),
})
return
}
hashedPassword, err := userpassword.Hash(createUser.Password)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("hash password: %s", err.Error()),
})
return
}
user, err = users.Database.InsertUser(context.Background(), database.InsertUserParams{
ID: uuid.NewString(),
Email: createUser.Email,
HashedPassword: []byte(hashedPassword),
Username: createUser.Username,
LoginType: database.LoginTypeBuiltIn,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("create user: %s", err.Error()),
})
return
}
render.Status(r, http.StatusCreated)
render.JSON(rw, r, user)
}
// Returns the currently authenticated user.
func (users *users) getAuthenticatedUser(rw http.ResponseWriter, r *http.Request) {
user := httpmw.User(r)
render.JSON(rw, r, User{
ID: user.ID,
Email: user.Email,
CreatedAt: user.CreatedAt,
Username: user.Username,
})
}
// Authenticates the user with an email and password.
func (users *users) loginWithPassword(rw http.ResponseWriter, r *http.Request) {
var loginWithPassword LoginWithPasswordRequest
if !httpapi.Read(rw, r, &loginWithPassword) {
return
}
user, err := users.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
Email: loginWithPassword.Email,
})
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "invalid email or password",
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get user: %s", err.Error()),
})
return
}
equal, err := userpassword.Compare(string(user.HashedPassword), loginWithPassword.Password)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("compare: %s", err.Error()),
})
}
if !equal {
// This message is the same as above to remove ease in detecting whether
// users are registered or not. Attackers still could with a timing attack.
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "invalid email or password",
})
return
}
id, secret, err := generateAPIKeyIDSecret()
hashed := sha256.Sum256([]byte(secret))
_, err = users.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: user.ID,
ExpiresAt: database.Now().Add(24 * time.Hour),
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
HashedSecret: hashed[:],
LoginType: database.LoginTypeBuiltIn,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("insert api key: %s", err.Error()),
})
return
}
// This format is consumed by the APIKey middleware.
sessionToken := fmt.Sprintf("%s-%s", id, secret)
http.SetCookie(rw, &http.Cookie{
Name: httpmw.AuthCookie,
Value: sessionToken,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
render.Status(r, http.StatusCreated)
render.JSON(rw, r, LoginWithPasswordResponse{
SessionToken: sessionToken,
})
}
// Generates a new ID and secret for an API key.
func generateAPIKeyIDSecret() (string, string, error) {
// Length of an API Key ID.
id, err := cryptorand.String(10)
if err != nil {
return "", "", err
}
// Length of an API Key secret.
secret, err := cryptorand.String(22)
if err != nil {
return "", "", err
}
return id, secret, nil
}

55
coderd/users_test.go Normal file
View File

@ -0,0 +1,55 @@
package coderd_test
import (
"context"
"testing"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/coderdtest"
"github.com/stretchr/testify/require"
)
func TestUsers(t *testing.T) {
t.Parallel()
t.Run("Authenticated", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_, err := server.Client.User(context.Background(), "")
require.NoError(t, err)
})
t.Run("CreateMultipleInitial", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_, err := server.Client.CreateInitialUser(context.Background(), coderd.CreateUserRequest{
Email: "dummy@coder.com",
Username: "fake",
Password: "password",
})
require.Error(t, err)
})
t.Run("LoginNoEmail", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_, err := server.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{
Email: "hello@io.io",
Password: "wowie",
})
require.Error(t, err)
})
t.Run("LoginBadPassword", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user, err := server.Client.User(context.Background(), "")
require.NoError(t, err)
_, err = server.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{
Email: user.Email,
Password: "bananas",
})
require.Error(t, err)
})
}

116
codersdk/client.go Normal file
View File

@ -0,0 +1,116 @@
package codersdk
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/url"
"golang.org/x/xerrors"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
)
// New creates a Coder client for the provided URL.
func New(url *url.URL) *Client {
return &Client{
url: url,
httpClient: &http.Client{},
}
}
// Client is an HTTP caller for methods to the Coder API.
type Client struct {
url *url.URL
httpClient *http.Client
}
// SetSessionToken applies the provided token to the current client.
func (c *Client) SetSessionToken(token string) error {
if c.httpClient.Jar == nil {
var err error
c.httpClient.Jar, err = cookiejar.New(nil)
if err != nil {
return err
}
}
c.httpClient.Jar.SetCookies(c.url, []*http.Cookie{{
Name: httpmw.AuthCookie,
Value: token,
}})
return nil
}
// request performs an HTTP request with the body provided.
// The caller is responsible for closing the response body.
func (c *Client) request(ctx context.Context, method, path string, body interface{}) (*http.Response, error) {
url, err := c.url.Parse(path)
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
var buf bytes.Buffer
if body != nil {
enc := json.NewEncoder(&buf)
enc.SetEscapeHTML(false)
err = enc.Encode(body)
if err != nil {
return nil, xerrors.Errorf("encode body: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, method, url.String(), &buf)
if err != nil {
return nil, xerrors.Errorf("create request: %w", err)
}
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, xerrors.Errorf("do: %w", err)
}
return resp, err
}
// readBodyAsError reads the response as an httpapi.Message, and
// wraps it in a codersdk.Error type for easy marshalling.
func readBodyAsError(res *http.Response) error {
var m httpapi.Response
err := json.NewDecoder(res.Body).Decode(&m)
if err != nil {
if errors.Is(err, io.EOF) {
// If no body is sent, we'll just provide the status code.
return &Error{
statusCode: res.StatusCode,
}
}
return xerrors.Errorf("decode body: %w", err)
}
return &Error{
Response: m,
statusCode: res.StatusCode,
}
}
// Error represents an unaccepted or invalid request to the API.
type Error struct {
httpapi.Response
statusCode int
}
func (e *Error) StatusCode() int {
return e.statusCode
}
func (e *Error) Error() string {
return fmt.Sprintf("status code %d: %s", e.statusCode, e.Message)
}

59
codersdk/users.go Normal file
View File

@ -0,0 +1,59 @@
package codersdk
import (
"context"
"encoding/json"
"net/http"
"github.com/coder/coder/coderd"
)
// CreateInitialUser attempts to create the first user on a Coder deployment.
// This initial user has superadmin privileges. If >0 users exist, this request
// will fail.
func (c *Client) CreateInitialUser(ctx context.Context, req coderd.CreateUserRequest) (coderd.User, error) {
res, err := c.request(ctx, http.MethodPost, "/api/v2/user", req)
if err != nil {
return coderd.User{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return coderd.User{}, readBodyAsError(res)
}
var user coderd.User
return user, json.NewDecoder(res.Body).Decode(&user)
}
// User returns a user for the ID provided.
// If the ID string is empty, the current user will be returned.
func (c *Client) User(ctx context.Context, id string) (coderd.User, error) {
res, err := c.request(ctx, http.MethodGet, "/api/v2/user", nil)
if err != nil {
return coderd.User{}, err
}
defer res.Body.Close()
if res.StatusCode > http.StatusOK {
return coderd.User{}, readBodyAsError(res)
}
var user coderd.User
return user, json.NewDecoder(res.Body).Decode(&user)
}
// LoginWithPassword creates a session token authenticating with an email and password.
// Call `SetSessionToken()` to apply the newly acquired token to the client.
func (c *Client) LoginWithPassword(ctx context.Context, req coderd.LoginWithPasswordRequest) (coderd.LoginWithPasswordResponse, error) {
res, err := c.request(ctx, http.MethodPost, "/api/v2/login", req)
if err != nil {
return coderd.LoginWithPasswordResponse{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return coderd.LoginWithPasswordResponse{}, readBodyAsError(res)
}
var resp coderd.LoginWithPasswordResponse
err = json.NewDecoder(res.Body).Decode(&resp)
if err != nil {
return coderd.LoginWithPasswordResponse{}, err
}
return resp, nil
}

33
codersdk/users_test.go Normal file
View File

@ -0,0 +1,33 @@
package codersdk_test
import (
"context"
"net/http"
"testing"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/stretchr/testify/require"
)
func TestUsers(t *testing.T) {
t.Run("MultipleInitial", func(t *testing.T) {
server := coderdtest.New(t)
_, err := server.Client.CreateInitialUser(context.Background(), coderd.CreateUserRequest{
Email: "wowie@coder.com",
Username: "tester",
Password: "moo",
})
var cerr *codersdk.Error
require.ErrorAs(t, err, &cerr)
require.Equal(t, cerr.StatusCode(), http.StatusConflict)
require.Greater(t, len(cerr.Error()), 0)
})
t.Run("Get", func(t *testing.T) {
server := coderdtest.New(t)
_, err := server.Client.User(context.Background(), "")
require.NoError(t, err)
})
}

View File

@ -0,0 +1,111 @@
package databasefake
import (
"context"
"database/sql"
"github.com/coder/coder/database"
)
// New returns an in-memory fake of the database.
func New() database.Store {
return &fakeQuerier{
apiKeys: make([]database.APIKey, 0),
users: make([]database.User, 0),
}
}
// fakeQuerier replicates database functionality to enable quick testing.
type fakeQuerier struct {
apiKeys []database.APIKey
users []database.User
}
// InTx doesn't rollback data properly for in-memory yet.
func (q *fakeQuerier) InTx(ctx context.Context, fn func(database.Store) error) error {
return fn(q)
}
func (q *fakeQuerier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) {
for _, apiKey := range q.apiKeys {
if apiKey.ID == id {
return apiKey, nil
}
}
return database.APIKey{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetUserByEmailOrUsername(ctx context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) {
for _, user := range q.users {
if user.Email == arg.Email || user.Username == arg.Username {
return user, nil
}
}
return database.User{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetUserByID(ctx context.Context, id string) (database.User, error) {
for _, user := range q.users {
if user.ID == id {
return user, nil
}
}
return database.User{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetUserCount(ctx context.Context) (int64, error) {
return int64(len(q.users)), nil
}
func (q *fakeQuerier) InsertAPIKey(ctx context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) {
key := database.APIKey{
ID: arg.ID,
HashedSecret: arg.HashedSecret,
UserID: arg.UserID,
Application: arg.Application,
Name: arg.Name,
LastUsed: arg.LastUsed,
ExpiresAt: arg.ExpiresAt,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
LoginType: arg.LoginType,
OIDCAccessToken: arg.OIDCAccessToken,
OIDCRefreshToken: arg.OIDCRefreshToken,
OIDCIDToken: arg.OIDCIDToken,
OIDCExpiry: arg.OIDCExpiry,
DevurlToken: arg.DevurlToken,
}
q.apiKeys = append(q.apiKeys, key)
return key, nil
}
func (q *fakeQuerier) InsertUser(ctx context.Context, arg database.InsertUserParams) (database.User, error) {
user := database.User{
ID: arg.ID,
Email: arg.Email,
Name: arg.Name,
LoginType: arg.LoginType,
HashedPassword: arg.HashedPassword,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
Username: arg.Username,
}
q.users = append(q.users, user)
return user, nil
}
func (q *fakeQuerier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
for index, apiKey := range q.apiKeys {
if apiKey.ID != arg.ID {
continue
}
apiKey.LastUsed = arg.LastUsed
apiKey.ExpiresAt = arg.ExpiresAt
apiKey.OIDCAccessToken = arg.OIDCAccessToken
apiKey.OIDCRefreshToken = arg.OIDCRefreshToken
apiKey.OIDCExpiry = arg.OIDCExpiry
q.apiKeys[index] = apiKey
return nil
}
return sql.ErrNoRows
}

View File

@ -1,19 +0,0 @@
package database
import "context"
// NewInMemory returns an in-memory store of the database.
func NewInMemory() Store {
return &memoryQuerier{}
}
type memoryQuerier struct{}
// InTx doesn't rollback data properly for in-memory yet.
func (q *memoryQuerier) InTx(ctx context.Context, fn func(Store) error) error {
return fn(q)
}
func (q *memoryQuerier) ExampleQuery(ctx context.Context) error {
return nil
}

View File

@ -13,7 +13,7 @@ type LoginType string
const (
LoginTypeBuiltIn LoginType = "built-in"
LoginTypeSaml LoginType = "saml"
LoginTypeOidc LoginType = "oidc"
LoginTypeOIDC LoginType = "oidc"
)
func (e *LoginType) Scan(src interface{}) error {
@ -48,7 +48,7 @@ func (e *UserStatus) Scan(src interface{}) error {
return nil
}
type ApiKey struct {
type APIKey struct {
ID string `db:"id" json:"id"`
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
UserID string `db:"user_id" json:"user_id"`
@ -59,10 +59,10 @@ type ApiKey struct {
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
LoginType LoginType `db:"login_type" json:"login_type"`
OidcAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
OidcRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
OidcIDToken string `db:"oidc_id_token" json:"oidc_id_token"`
OidcExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
OIDCIDToken string `db:"oidc_id_token" json:"oidc_id_token"`
OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
DevurlToken bool `db:"devurl_token" json:"devurl_token"`
}

View File

@ -7,7 +7,13 @@ import (
)
type querier interface {
ExampleQuery(ctx context.Context) error
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
GetUserByID(ctx context.Context, id string) (User, error)
GetUserCount(ctx context.Context) (int64, error)
InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error)
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
}
var _ querier = (*sqlQuerier)(nil)

View File

@ -1,2 +1,107 @@
-- name: ExampleQuery :exec
SELECT 'example query';
-- Database queries are generated using sqlc. See:
-- https://docs.sqlc.dev/en/latest/tutorials/getting-started-postgresql.html
--
-- Run "make gen" to generate models and query functions.
;
-- name: GetAPIKeyByID :one
SELECT
*
FROM
api_keys
WHERE
id = $1
LIMIT
1;
-- name: GetUserByID :one
SELECT
*
FROM
users
WHERE
id = $1
LIMIT
1;
-- name: GetUserByEmailOrUsername :one
SELECT
*
FROM
users
WHERE
username = $1
OR email = $2
LIMIT
1;
-- name: GetUserCount :one
SELECT
COUNT(*)
FROM
users;
-- name: InsertAPIKey :one
INSERT INTO
api_keys (
id,
hashed_secret,
user_id,
application,
name,
last_used,
expires_at,
created_at,
updated_at,
login_type,
oidc_access_token,
oidc_refresh_token,
oidc_id_token,
oidc_expiry,
devurl_token
)
VALUES
(
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8,
$9,
$10,
$11,
$12,
$13,
$14,
$15
) RETURNING *;
-- name: InsertUser :one
INSERT INTO
users (
id,
email,
name,
login_type,
hashed_password,
created_at,
updated_at,
username
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *;
-- name: UpdateAPIKeyByID :exec
UPDATE
api_keys
SET
last_used = $2,
expires_at = $3,
oidc_access_token = $4,
oidc_refresh_token = $5,
oidc_expiry = $6
WHERE
id = $1;

View File

@ -5,13 +5,330 @@ package database
import (
"context"
"time"
"github.com/lib/pq"
)
const exampleQuery = `-- name: ExampleQuery :exec
SELECT 'example query'
const getAPIKeyByID = `-- name: GetAPIKeyByID :one
SELECT
id, hashed_secret, user_id, application, name, last_used, expires_at, created_at, updated_at, login_type, oidc_access_token, oidc_refresh_token, oidc_id_token, oidc_expiry, devurl_token
FROM
api_keys
WHERE
id = $1
LIMIT
1
`
func (q *sqlQuerier) ExampleQuery(ctx context.Context) error {
_, err := q.db.ExecContext(ctx, exampleQuery)
func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) {
row := q.db.QueryRowContext(ctx, getAPIKeyByID, id)
var i APIKey
err := row.Scan(
&i.ID,
&i.HashedSecret,
&i.UserID,
&i.Application,
&i.Name,
&i.LastUsed,
&i.ExpiresAt,
&i.CreatedAt,
&i.UpdatedAt,
&i.LoginType,
&i.OIDCAccessToken,
&i.OIDCRefreshToken,
&i.OIDCIDToken,
&i.OIDCExpiry,
&i.DevurlToken,
)
return i, err
}
const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one
SELECT
id, email, name, revoked, login_type, hashed_password, created_at, updated_at, temporary_password, avatar_hash, ssh_key_regenerated_at, username, dotfiles_git_uri, roles, status, relatime, gpg_key_regenerated_at, _decomissioned, shell
FROM
users
WHERE
username = $1
OR email = $2
LIMIT
1
`
type GetUserByEmailOrUsernameParams struct {
Username string `db:"username" json:"username"`
Email string `db:"email" json:"email"`
}
func (q *sqlQuerier) GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error) {
row := q.db.QueryRowContext(ctx, getUserByEmailOrUsername, arg.Username, arg.Email)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.Name,
&i.Revoked,
&i.LoginType,
&i.HashedPassword,
&i.CreatedAt,
&i.UpdatedAt,
&i.TemporaryPassword,
&i.AvatarHash,
&i.SshKeyRegeneratedAt,
&i.Username,
&i.DotfilesGitUri,
pq.Array(&i.Roles),
&i.Status,
&i.Relatime,
&i.GpgKeyRegeneratedAt,
&i.Decomissioned,
&i.Shell,
)
return i, err
}
const getUserByID = `-- name: GetUserByID :one
SELECT
id, email, name, revoked, login_type, hashed_password, created_at, updated_at, temporary_password, avatar_hash, ssh_key_regenerated_at, username, dotfiles_git_uri, roles, status, relatime, gpg_key_regenerated_at, _decomissioned, shell
FROM
users
WHERE
id = $1
LIMIT
1
`
func (q *sqlQuerier) GetUserByID(ctx context.Context, id string) (User, error) {
row := q.db.QueryRowContext(ctx, getUserByID, id)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.Name,
&i.Revoked,
&i.LoginType,
&i.HashedPassword,
&i.CreatedAt,
&i.UpdatedAt,
&i.TemporaryPassword,
&i.AvatarHash,
&i.SshKeyRegeneratedAt,
&i.Username,
&i.DotfilesGitUri,
pq.Array(&i.Roles),
&i.Status,
&i.Relatime,
&i.GpgKeyRegeneratedAt,
&i.Decomissioned,
&i.Shell,
)
return i, err
}
const getUserCount = `-- name: GetUserCount :one
SELECT
COUNT(*)
FROM
users
`
func (q *sqlQuerier) GetUserCount(ctx context.Context) (int64, error) {
row := q.db.QueryRowContext(ctx, getUserCount)
var count int64
err := row.Scan(&count)
return count, err
}
const insertAPIKey = `-- name: InsertAPIKey :one
INSERT INTO
api_keys (
id,
hashed_secret,
user_id,
application,
name,
last_used,
expires_at,
created_at,
updated_at,
login_type,
oidc_access_token,
oidc_refresh_token,
oidc_id_token,
oidc_expiry,
devurl_token
)
VALUES
(
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8,
$9,
$10,
$11,
$12,
$13,
$14,
$15
) RETURNING id, hashed_secret, user_id, application, name, last_used, expires_at, created_at, updated_at, login_type, oidc_access_token, oidc_refresh_token, oidc_id_token, oidc_expiry, devurl_token
`
type InsertAPIKeyParams struct {
ID string `db:"id" json:"id"`
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
UserID string `db:"user_id" json:"user_id"`
Application bool `db:"application" json:"application"`
Name string `db:"name" json:"name"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
LoginType LoginType `db:"login_type" json:"login_type"`
OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
OIDCIDToken string `db:"oidc_id_token" json:"oidc_id_token"`
OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
DevurlToken bool `db:"devurl_token" json:"devurl_token"`
}
func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) {
row := q.db.QueryRowContext(ctx, insertAPIKey,
arg.ID,
arg.HashedSecret,
arg.UserID,
arg.Application,
arg.Name,
arg.LastUsed,
arg.ExpiresAt,
arg.CreatedAt,
arg.UpdatedAt,
arg.LoginType,
arg.OIDCAccessToken,
arg.OIDCRefreshToken,
arg.OIDCIDToken,
arg.OIDCExpiry,
arg.DevurlToken,
)
var i APIKey
err := row.Scan(
&i.ID,
&i.HashedSecret,
&i.UserID,
&i.Application,
&i.Name,
&i.LastUsed,
&i.ExpiresAt,
&i.CreatedAt,
&i.UpdatedAt,
&i.LoginType,
&i.OIDCAccessToken,
&i.OIDCRefreshToken,
&i.OIDCIDToken,
&i.OIDCExpiry,
&i.DevurlToken,
)
return i, err
}
const insertUser = `-- name: InsertUser :one
INSERT INTO
users (
id,
email,
name,
login_type,
hashed_password,
created_at,
updated_at,
username
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, email, name, revoked, login_type, hashed_password, created_at, updated_at, temporary_password, avatar_hash, ssh_key_regenerated_at, username, dotfiles_git_uri, roles, status, relatime, gpg_key_regenerated_at, _decomissioned, shell
`
type InsertUserParams struct {
ID string `db:"id" json:"id"`
Email string `db:"email" json:"email"`
Name string `db:"name" json:"name"`
LoginType LoginType `db:"login_type" json:"login_type"`
HashedPassword []byte `db:"hashed_password" json:"hashed_password"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Username string `db:"username" json:"username"`
}
func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) {
row := q.db.QueryRowContext(ctx, insertUser,
arg.ID,
arg.Email,
arg.Name,
arg.LoginType,
arg.HashedPassword,
arg.CreatedAt,
arg.UpdatedAt,
arg.Username,
)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.Name,
&i.Revoked,
&i.LoginType,
&i.HashedPassword,
&i.CreatedAt,
&i.UpdatedAt,
&i.TemporaryPassword,
&i.AvatarHash,
&i.SshKeyRegeneratedAt,
&i.Username,
&i.DotfilesGitUri,
pq.Array(&i.Roles),
&i.Status,
&i.Relatime,
&i.GpgKeyRegeneratedAt,
&i.Decomissioned,
&i.Shell,
)
return i, err
}
const updateAPIKeyByID = `-- name: UpdateAPIKeyByID :exec
UPDATE
api_keys
SET
last_used = $2,
expires_at = $3,
oidc_access_token = $4,
oidc_refresh_token = $5,
oidc_expiry = $6
WHERE
id = $1
`
type UpdateAPIKeyByIDParams struct {
ID string `db:"id" json:"id"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
}
func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error {
_, err := q.db.ExecContext(ctx, updateAPIKeyByID,
arg.ID,
arg.LastUsed,
arg.ExpiresAt,
arg.OIDCAccessToken,
arg.OIDCRefreshToken,
arg.OIDCExpiry,
)
return err
}

View File

@ -19,4 +19,10 @@ overrides:
- db_type: citext
go_type: string
rename:
api_key: APIKey
login_type_oidc: LoginTypeOIDC
oidc_access_token: OIDCAccessToken
oidc_expiry: OIDCExpiry
oidc_id_token: OIDCIDToken
oidc_refresh_token: OIDCRefreshToken
userstatus: UserStatus

8
database/time.go Normal file
View File

@ -0,0 +1,8 @@
package database
import "time"
// Now returns a standardized timezone used for database resources.
func Now() time.Time {
return time.Now().UTC()
}

10
go.mod
View File

@ -12,6 +12,7 @@ require (
cdr.dev/slog v1.4.1
github.com/go-chi/chi v1.5.4
github.com/go-chi/render v1.0.1
github.com/go-playground/validator/v10 v10.10.0
github.com/golang-migrate/migrate/v4 v4.15.1
github.com/google/uuid v1.3.0
github.com/hashicorp/go-version v1.3.0
@ -30,6 +31,8 @@ require (
github.com/unrolled/secure v1.0.9
go.uber.org/atomic v1.7.0
go.uber.org/goleak v1.1.12
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1
google.golang.org/protobuf v1.27.1
storj.io/drpc v0.0.26
@ -44,6 +47,7 @@ require (
github.com/alecthomas/chroma v0.9.1 // indirect
github.com/apparentlymart/go-textseg v1.0.0 // indirect
github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect
github.com/aws/aws-sdk-go v1.34.28 // indirect
github.com/cenkalti/backoff/v4 v4.1.2 // indirect
github.com/containerd/continuity v0.1.0 // indirect
github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect
@ -55,8 +59,11 @@ require (
github.com/docker/go-connections v0.4.0 // indirect
github.com/docker/go-units v0.4.0 // indirect
github.com/fatih/color v1.13.0 // indirect
github.com/go-playground/locales v0.14.0 // indirect
github.com/go-playground/universal-translator v0.18.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/google/go-cmp v0.5.6 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect
@ -67,6 +74,7 @@ require (
github.com/hashicorp/terraform-json v0.13.0 // indirect
github.com/imdario/mergo v0.3.12 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/leodido/go-urn v1.2.1 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/mitchellh/go-wordwrap v1.0.0 // indirect
@ -98,12 +106,12 @@ require (
github.com/zclconf/go-cty v1.9.1 // indirect
github.com/zeebo/errs v1.2.2 // indirect
go.opencensus.io v0.23.0 // indirect
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 // indirect
golang.org/x/net v0.0.0-20211216030914-fe4d6282115f // indirect
golang.org/x/sys v0.0.0-20211210111614-af8b64212486 // indirect
golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b // indirect
golang.org/x/text v0.3.7 // indirect
google.golang.org/api v0.63.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/grpc v1.43.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect

24
go.sum
View File

@ -153,8 +153,9 @@ github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkY
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
github.com/aws/aws-sdk-go v1.15.11/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0=
github.com/aws/aws-sdk-go v1.15.78/go.mod h1:E3/ieXAlvM0XWO57iftYVDLLvQ824smPP3ATZkfNZeM=
github.com/aws/aws-sdk-go v1.17.7 h1:/4+rDPe0W95KBmNGYCG+NUvdL8ssPYBMxL+aSCg6nIA=
github.com/aws/aws-sdk-go v1.17.7/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/aws/aws-sdk-go v1.34.28 h1:sscPpn/Ns3i0F4HPEWAVcwdIRaZZCuL7llJ2/60yPIk=
github.com/aws/aws-sdk-go v1.34.28/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48=
github.com/aws/aws-sdk-go-v2 v1.8.0/go.mod h1:xEFuWz+3TYdlPRuo+CqATbeDWIWyaT5uAPwPaWtgse0=
github.com/aws/aws-sdk-go-v2 v1.9.2/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4=
github.com/aws/aws-sdk-go-v2/config v1.6.0/go.mod h1:TNtBVmka80lRPk5+S9ZqVfFszOQAGJJ9KbT3EM3CHNU=
@ -468,6 +469,14 @@ github.com/go-openapi/jsonreference v0.19.3/go.mod h1:rjx6GuL8TTa9VaixXglHmQmIL9
github.com/go-openapi/spec v0.19.3/go.mod h1:FpwSN1ksY1eteniUU7X0N/BgJ7a4WvBFVA8Lj9mJglo=
github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho=
github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA=
github.com/go-playground/validator/v10 v10.10.0 h1:I7mrTYv78z8k8VXa/qJlOlEXn/nBh+BF8dHX5nt/dr0=
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
@ -802,8 +811,9 @@ github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA=
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
@ -818,6 +828,8 @@ github.com/kylecarbs/terraform-config-inspect v0.0.0-20211215004401-bbc517866b88
github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
@ -1020,6 +1032,7 @@ github.com/pion/webrtc/v3 v3.1.13 h1:2XxgGstOqt03ba8QD5+m9S8DCA3Ez53mULT4If8onOg
github.com/pion/webrtc/v3 v3.1.13/go.mod h1:RACpyE1EDYlzonfbdPvXkIGDaqD8+NsHqZJN0yEbRbA=
github.com/pkg/browser v0.0.0-20210706143420-7d21f8c997e2/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1-0.20171018195549-f15c970de5b7/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -1069,6 +1082,9 @@ github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6L
github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
@ -1266,11 +1282,13 @@ golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 h1:0es+/5331RGQPcXlMfP+WrnIIS6dNnNRe0WB02W0F4M=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce h1:Roh6XWxHFKrPgC/EQhVubSAGQ6Ozk6IdxHSzt1mR0EI=
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=

114
httpapi/httpapi.go Normal file
View File

@ -0,0 +1,114 @@
package httpapi
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"
"regexp"
"strings"
"github.com/go-playground/validator/v10"
)
var (
validate *validator.Validate
usernameRegex = regexp.MustCompile("^[a-zA-Z0-9]+(?:-[a-zA-Z0-9]+)*$")
)
// This init is used to create a validator and register validation-specific
// functionality for the HTTP API.
//
// A single validator instance is used, because it caches struct parsing.
func init() {
validate = validator.New()
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
validate.RegisterValidation("username", func(fl validator.FieldLevel) bool {
f := fl.Field().Interface()
str, ok := f.(string)
if !ok {
return false
}
if len(str) > 32 {
return false
}
if len(str) < 1 {
return false
}
return usernameRegex.MatchString(str)
})
}
// Response represents a generic HTTP response.
type Response struct {
Message string `json:"message" validate:"required"`
Errors []Error `json:"errors,omitempty" validate:"required"`
}
// Error represents a scoped error to a user input.
type Error struct {
Field string `json:"field" validate:"required"`
Code string `json:"code" validate:"required"`
}
// Write outputs a standardized format to an HTTP response body.
func Write(w http.ResponseWriter, status int, response Response) {
buf := &bytes.Buffer{}
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(true)
err := enc.Encode(response)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(status)
_, err = w.Write(buf.Bytes())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
// Read decodes JSON from the HTTP request into the value provided.
// It uses go-validator to validate the incoming request body.
func Read(rw http.ResponseWriter, r *http.Request, value interface{}) bool {
err := json.NewDecoder(r.Body).Decode(value)
if err != nil {
Write(rw, http.StatusBadRequest, Response{
Message: fmt.Sprintf("read body: %s", err.Error()),
})
return false
}
err = validate.Struct(value)
var validationErrors validator.ValidationErrors
if errors.As(err, &validationErrors) {
apiErrors := make([]Error, 0, len(validationErrors))
for _, validationError := range validationErrors {
apiErrors = append(apiErrors, Error{
Field: validationError.Field(),
Code: validationError.Tag(),
})
}
Write(rw, http.StatusBadRequest, Response{
Message: "Validation failed",
Errors: apiErrors,
})
return false
}
if err != nil {
Write(rw, http.StatusInternalServerError, Response{
Message: fmt.Sprintf("validation: %s", err.Error()),
})
return false
}
return true
}

134
httpapi/httpapi_test.go Normal file
View File

@ -0,0 +1,134 @@
package httpapi_test
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/httpapi"
)
func TestWrite(t *testing.T) {
t.Run("NoErrors", func(t *testing.T) {
rw := httptest.NewRecorder()
httpapi.Write(rw, http.StatusOK, httpapi.Response{
Message: "wow",
})
var m map[string]interface{}
err := json.NewDecoder(rw.Body).Decode(&m)
require.NoError(t, err)
_, ok := m["errors"]
require.False(t, ok)
})
}
func TestRead(t *testing.T) {
t.Run("EmptyStruct", func(t *testing.T) {
rw := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", bytes.NewBufferString("{}"))
v := struct{}{}
require.True(t, httpapi.Read(rw, r, &v))
})
t.Run("NoBody", func(t *testing.T) {
rw := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", nil)
var v json.RawMessage
require.False(t, httpapi.Read(rw, r, v))
})
t.Run("Validate", func(t *testing.T) {
type toValidate struct {
Value string `json:"value" validate:"required"`
}
rw := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", bytes.NewBufferString(`{"value":"hi"}`))
var validate toValidate
require.True(t, httpapi.Read(rw, r, &validate))
require.Equal(t, validate.Value, "hi")
})
t.Run("ValidateFailure", func(t *testing.T) {
type toValidate struct {
Value string `json:"value" validate:"required"`
}
rw := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", bytes.NewBufferString("{}"))
var validate toValidate
require.False(t, httpapi.Read(rw, r, &validate))
var v httpapi.Response
err := json.NewDecoder(rw.Body).Decode(&v)
require.NoError(t, err)
require.Len(t, v.Errors, 1)
require.Equal(t, v.Errors[0].Field, "value")
require.Equal(t, v.Errors[0].Code, "required")
})
}
func TestReadUsername(t *testing.T) {
// Tests whether usernames are valid or not.
testCases := []struct {
Username string
Valid bool
}{
{"1", true},
{"12", true},
{"123", true},
{"12345678901234567890", true},
{"123456789012345678901", true},
{"a", true},
{"a1", true},
{"a1b2", true},
{"a1b2c3d4e5f6g7h8i9j0", true},
{"a1b2c3d4e5f6g7h8i9j0k", true},
{"aa", true},
{"abc", true},
{"abcdefghijklmnopqrst", true},
{"abcdefghijklmnopqrstu", true},
{"wow-test", true},
{"", false},
{" ", false},
{" a", false},
{" a ", false},
{" 1", false},
{"1 ", false},
{" aa", false},
{"aa ", false},
{" 12", false},
{"12 ", false},
{" a1", false},
{"a1 ", false},
{" abcdefghijklmnopqrstu", false},
{"abcdefghijklmnopqrstu ", false},
{" 123456789012345678901", false},
{" a1b2c3d4e5f6g7h8i9j0k", false},
{"a1b2c3d4e5f6g7h8i9j0k ", false},
{"bananas_wow", false},
{"test--now", false},
{"123456789012345678901234567890123", false},
{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", false},
{"123456789012345678901234567890123123456789012345678901234567890123", false},
}
type toValidate struct {
Username string `json:"username" validate:"username"`
}
for _, testCase := range testCases {
t.Run(testCase.Username, func(t *testing.T) {
rw := httptest.NewRecorder()
data, err := json.Marshal(toValidate{testCase.Username})
require.NoError(t, err)
r := httptest.NewRequest("POST", "/", bytes.NewBuffer(data))
var validate toValidate
require.Equal(t, httpapi.Read(rw, r, &validate), testCase.Valid)
})
}
}

167
httpmw/apikey.go Normal file
View File

@ -0,0 +1,167 @@
package httpmw
import (
"context"
"crypto/sha256"
"crypto/subtle"
"database/sql"
"errors"
"fmt"
"net/http"
"strings"
"time"
"golang.org/x/oauth2"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
)
// AuthCookie represents the name of the cookie the API key is stored in.
const AuthCookie = "session_token"
// OAuth2Config contains a subset of functions exposed from oauth2.Config.
// It is abstracted for simple testing.
type OAuth2Config interface {
TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource
}
type apiKeyContextKey struct{}
// APIKey returns the API key from the ExtractAPIKey handler.
func APIKey(r *http.Request) database.APIKey {
apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
if !ok {
panic("developer error: apikey middleware not provided")
}
return apiKey
}
// ExtractAPIKey requires authentication using a valid API key.
// It handles extending an API key if it comes close to expiry,
// updating the last used time in the database.
func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(AuthCookie)
if err != nil {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("%q cookie must be provided", AuthCookie),
})
return
}
parts := strings.Split(cookie.Value, "-")
// APIKeys are formatted: ID-SECRET
if len(parts) != 2 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("invalid %q cookie api key format", AuthCookie),
})
return
}
id := parts[0]
secret := parts[1]
// Ensuring key lengths are valid.
if len(id) != 10 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("invalid %q cookie api key id", AuthCookie),
})
return
}
if len(secret) != 22 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("invalid %q cookie api key secret", AuthCookie),
})
return
}
key, err := db.GetAPIKeyByID(r.Context(), id)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "api key is invalid",
})
return
}
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get api key by id: %s", err.Error()),
})
return
}
hashed := sha256.Sum256([]byte(secret))
// Checking to see if the secret is valid.
if subtle.ConstantTimeCompare(key.HashedSecret, hashed[:]) != 1 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "api key secret is invalid",
})
return
}
now := database.Now()
// Tracks if the API key has properties updated!
changed := false
if key.LoginType == database.LoginTypeOIDC {
// Check if the OIDC token is expired!
if key.OIDCExpiry.Before(now) && !key.OIDCExpiry.IsZero() {
// If it is, let's refresh it from the provided config!
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
AccessToken: key.OIDCAccessToken,
RefreshToken: key.OIDCRefreshToken,
Expiry: key.OIDCExpiry,
}).Token()
if err != nil {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("couldn't refresh expired oauth token: %s", err.Error()),
})
return
}
key.OIDCAccessToken = token.AccessToken
key.OIDCRefreshToken = token.RefreshToken
key.OIDCExpiry = token.Expiry
key.ExpiresAt = token.Expiry
changed = true
}
}
// Checking if the key is expired.
if key.ExpiresAt.Before(now) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("api key expired at %q", key.ExpiresAt.String()),
})
return
}
// Only update LastUsed once an hour to prevent database spam.
if now.Sub(key.LastUsed) > time.Hour {
key.LastUsed = now
changed = true
}
// Only update the ExpiresAt once an hour to prevent database spam.
// We extend the ExpiresAt to reduce reauthentication.
apiKeyLifetime := 24 * time.Hour
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
key.ExpiresAt = now.Add(apiKeyLifetime)
changed = true
}
if changed {
err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
ID: key.ID,
ExpiresAt: key.ExpiresAt,
LastUsed: key.LastUsed,
OIDCAccessToken: key.OIDCAccessToken,
OIDCRefreshToken: key.OIDCRefreshToken,
OIDCExpiry: key.OIDCExpiry,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("api key couldn't update: %s", err.Error()),
})
return
}
}
ctx := context.WithValue(r.Context(), apiKeyContextKey{}, key)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

338
httpmw/apikey_test.go Normal file
View File

@ -0,0 +1,338 @@
package httpmw_test
import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/database"
"github.com/coder/coder/database/databasefake"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
)
func randomAPIKeyParts() (string, string) {
id, _ := cryptorand.String(10)
secret, _ := cryptorand.String(22)
return id, secret
}
func TestAPIKey(t *testing.T) {
t.Parallel()
successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// Only called if the API key passes through the handler.
httpapi.Write(rw, http.StatusOK, httpapi.Response{
Message: "it worked!",
})
})
t.Run("NoCookie", func(t *testing.T) {
var (
db = databasefake.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
})
t.Run("InvalidFormat", func(t *testing.T) {
var (
db = databasefake.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: "test-wow-hello",
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
})
t.Run("InvalidIDLength", func(t *testing.T) {
var (
db = databasefake.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: "test-wow",
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
})
t.Run("InvalidSecretLength", func(t *testing.T) {
var (
db = databasefake.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: "testtestid-wow",
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
})
t.Run("NotFound", func(t *testing.T) {
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
})
t.Run("InvalidSecret", func(t *testing.T) {
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
// Use a different secret so they don't match!
hashed := sha256.Sum256([]byte("differentsecret"))
_, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
})
t.Run("Expired", func(t *testing.T) {
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
_, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode)
})
t.Run("Valid", func(t *testing.T) {
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().AddDate(0, 0, 1),
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// Checks that it exists on the context!
_ = httpmw.APIKey(r)
httpapi.Write(rw, http.StatusOK, httpapi.Response{
Message: "it worked!",
})
})).ServeHTTP(rw, r)
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("ValidUpdateLastUsed", func(t *testing.T) {
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LastUsed: database.Now().AddDate(0, 0, -1),
ExpiresAt: database.Now().AddDate(0, 0, 1),
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
require.NotEqual(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("ValidUpdateExpiry", func(t *testing.T) {
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.NotEqual(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("OIDCNotExpired", func(t *testing.T) {
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LoginType: database.LoginTypeOIDC,
LastUsed: database.Now(),
ExpiresAt: database.Now().AddDate(0, 0, 1),
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("OIDCRefresh", func(t *testing.T) {
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LoginType: database.LoginTypeOIDC,
LastUsed: database.Now(),
OIDCExpiry: database.Now().AddDate(0, 0, -1),
})
require.NoError(t, err)
token := &oauth2.Token{
AccessToken: "wow",
RefreshToken: "moo",
Expiry: database.Now().AddDate(0, 0, 1),
}
httpmw.ExtractAPIKey(db, &oauth2Config{
tokenSource: &oauth2TokenSource{
token: func() (*oauth2.Token, error) {
return token, nil
},
},
})(successHandler).ServeHTTP(rw, r)
require.Equal(t, http.StatusOK, rw.Result().StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt)
require.Equal(t, token.AccessToken, gotAPIKey.OIDCAccessToken)
})
}
type oauth2Config struct {
tokenSource *oauth2TokenSource
}
func (o *oauth2Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource {
return o.tokenSource
}
type oauth2TokenSource struct {
token func() (*oauth2.Token, error)
}
func (o *oauth2TokenSource) Token() (*oauth2.Token, error) {
return o.token()
}

51
httpmw/user.go Normal file
View File

@ -0,0 +1,51 @@
package httpmw
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
)
type userContextKey struct{}
// User returns the user from the ExtractUser handler.
func User(r *http.Request) database.User {
user, ok := r.Context().Value(userContextKey{}).(database.User)
if !ok {
panic("developer error: user middleware not provided")
}
return user
}
// ExtractUser consumes an API key and queries the user attached to it.
// It attaches the user to the request context.
func ExtractUser(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// The user handler depends on API Key to get the authenticated user.
apiKey := APIKey(r)
user, err := db.GetUserByID(r.Context(), apiKey.UserID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "user not found for api key",
})
return
}
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("couldn't fetch user for api key: %s", err.Error()),
})
return
}
ctx := context.WithValue(r.Context(), userContextKey{}, user)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

89
httpmw/user_test.go Normal file
View File

@ -0,0 +1,89 @@
package httpmw_test
import (
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/coder/coder/database"
"github.com/coder/coder/database/databasefake"
"github.com/coder/coder/httpmw"
)
func TestUser(t *testing.T) {
t.Run("NoUser", func(t *testing.T) {
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
_, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: "bananas",
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
r = returnedRequest
})).ServeHTTP(rw, r)
httpmw.ExtractUser(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
})).ServeHTTP(rw, r)
})
t.Run("User", func(t *testing.T) {
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
ID: "testing",
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: user.ID,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
r = returnedRequest
})).ServeHTTP(rw, r)
httpmw.ExtractUser(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// Makes sure the context properly adds the User!
_ = httpmw.User(r)
rw.WriteHeader(http.StatusOK)
})).ServeHTTP(rw, r)
})
}

View File

@ -9,7 +9,7 @@
"dev": "ts-node site/dev.ts",
"export": "next export site",
"format:check": "prettier --check '**/*.{css,html,js,json,jsx,md,ts,tsx,yaml,yml}'",
"format:write": "prettier --write '**/*.{css,html,js,json,jsx,md,ts,tsx,yaml,yml}'",
"format:write": "prettier --write '**/*.{css,html,js,json,jsx,md,ts,tsx,yaml,yml}' && sql-formatter -l postgresql ./database/query.sql -o ./database/query.sql",
"test": "jest --selectProjects test",
"test:coverage": "jest --selectProjects test --collectCoverage"
},
@ -31,6 +31,7 @@
"prettier": "2.5.1",
"react": "17.0.2",
"react-dom": "17.0.2",
"sql-formatter": "^4.0.2",
"ts-jest": "27.1.2",
"ts-loader": "9.2.6",
"ts-node": "10.4.0",

View File

@ -4,13 +4,14 @@ import (
"context"
"testing"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"storj.io/drpc/drpcconn"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
)
func TestMain(m *testing.M) {

View File

@ -5,11 +5,12 @@ import (
"io"
"testing"
"github.com/stretchr/testify/require"
"storj.io/drpc/drpcconn"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
"github.com/stretchr/testify/require"
"storj.io/drpc/drpcconn"
)
func TestListen(t *testing.T) {

View File

@ -5,9 +5,10 @@ import (
"encoding/json"
"os"
"github.com/coder/coder/provisionersdk/proto"
"github.com/hashicorp/terraform-config-inspect/tfconfig"
"golang.org/x/xerrors"
"github.com/coder/coder/provisionersdk/proto"
)
// Parse extracts Terraform variables from source-code.

View File

@ -9,10 +9,11 @@ import (
"path/filepath"
"testing"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/provisionersdk/proto"
"github.com/stretchr/testify/require"
"storj.io/drpc/drpcconn"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/provisionersdk/proto"
)
func TestParse(t *testing.T) {
@ -89,7 +90,7 @@ func TestParse(t *testing.T) {
// Write all files to the temporary test directory.
directory := t.TempDir()
for path, content := range tc.Files {
err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0644)
err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0600)
require.NoError(t, err)
}

View File

@ -6,9 +6,10 @@ import (
"os"
"path/filepath"
"github.com/coder/coder/provisionersdk/proto"
"github.com/hashicorp/terraform-exec/tfexec"
"golang.org/x/xerrors"
"github.com/coder/coder/provisionersdk/proto"
)
// Provision executes `terraform apply`.

View File

@ -9,12 +9,13 @@ import (
"path/filepath"
"testing"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/provisionersdk/proto"
"github.com/hashicorp/go-version"
"github.com/stretchr/testify/require"
"storj.io/drpc/drpcconn"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/provisionersdk/proto"
"github.com/hashicorp/hc-install/product"
"github.com/hashicorp/hc-install/releases"
)

View File

@ -4,9 +4,10 @@ import (
"context"
"os/exec"
"github.com/coder/coder/provisionersdk"
"github.com/hashicorp/go-version"
"golang.org/x/xerrors"
"github.com/coder/coder/provisionersdk"
)
var (

View File

@ -1102,6 +1102,11 @@ argparse@^1.0.7:
dependencies:
sprintf-js "~1.0.2"
argparse@^2.0.1:
version "2.0.1"
resolved "https://registry.yarnpkg.com/argparse/-/argparse-2.0.1.tgz#246f50f3ca78a3240f6c997e8a9bd1eac49e4b38"
integrity sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==
aria-query@^5.0.0:
version "5.0.0"
resolved "https://registry.yarnpkg.com/aria-query/-/aria-query-5.0.0.tgz#210c21aaf469613ee8c9a62c7f86525e058db52c"
@ -4205,6 +4210,13 @@ sprintf-js@~1.0.2:
resolved "https://registry.yarnpkg.com/sprintf-js/-/sprintf-js-1.0.3.tgz#04e6926f662895354f3dd015203633b857297e2c"
integrity sha1-BOaSb2YolTVPPdAVIDYzuFcpfiw=
sql-formatter@^4.0.2:
version "4.0.2"
resolved "https://registry.yarnpkg.com/sql-formatter/-/sql-formatter-4.0.2.tgz#2b359e5a4c611498d327b9659da7329d71724607"
integrity sha512-R6u9GJRiXZLr/lDo8p56L+OyyN2QFJPCDnsyEOsbdIpsnDKL8gubYFo7lNR7Zx7hfdWT80SfkoVS0CMaF/DE2w==
dependencies:
argparse "^2.0.1"
stack-utils@^2.0.3:
version "2.0.5"
resolved "https://registry.yarnpkg.com/stack-utils/-/stack-utils-2.0.5.tgz#d25265fca995154659dbbfba3b49254778d2fdd5"