feat: Add organizations endpoint for users (#50)

* feat: Add organizations endpoint for users

This moves the /user endpoint to /users/me instead. This
will reduce code duplication.

This adds /users/<name>/organizations to list organizations
a user has access to. It doesn't contain the permissions a
user has over the organizations, but that will come in a future
contribution.

* Fix requested changes

* Fix tests

* Fix timeout

* Add test for UserOrgs

* Add test for userparam getting

* Add test for NoUser
This commit is contained in:
Kyle Carberry 2022-01-22 23:58:10 -06:00 committed by GitHub
parent 50d8151995
commit 8be245616a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 597 additions and 244 deletions

View File

@ -23,5 +23,7 @@ ignore:
# This is generated code.
- database/models.go
- database/query.sql.go
# All coderd tests fail if this doesn't work.
- database/databasefake
- peerbroker/proto
- provisionersdk/proto

View File

@ -31,15 +31,18 @@ func New(options *Options) http.Handler {
Message: "👋",
})
})
r.Post("/user", users.createInitialUser)
r.Post("/login", users.loginWithPassword)
// Require an API key and authenticated user for this group.
r.Group(func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
httpmw.ExtractUser(options.Database),
)
r.Get("/user", users.authenticatedUser)
r.Route("/users", func(r chi.Router) {
r.Post("/", users.createInitialUser)
r.Group(func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
httpmw.ExtractUserParam(options.Database),
)
r.Get("/{user}", users.user)
r.Get("/{user}/organizations", users.userOrganizations)
})
})
})
r.NotFound(site.Handler().ServeHTTP)

View File

@ -13,6 +13,7 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/database"
"github.com/coder/coder/database/databasefake"
"github.com/coder/coder/database/postgres"
@ -27,7 +28,37 @@ type Server struct {
URL *url.URL
}
// New constructs a new coderd test instance.
// RandomInitialUser generates a random initial user and authenticates
// it with the client on the Server struct.
func (s *Server) RandomInitialUser(t *testing.T) coderd.CreateInitialUserRequest {
username, err := cryptorand.String(12)
require.NoError(t, err)
password, err := cryptorand.String(12)
require.NoError(t, err)
organization, err := cryptorand.String(12)
require.NoError(t, err)
req := coderd.CreateInitialUserRequest{
Email: "testuser@coder.com",
Username: username,
Password: password,
Organization: organization,
}
_, err = s.Client.CreateInitialUser(context.Background(), req)
require.NoError(t, err)
login, err := s.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{
Email: "testuser@coder.com",
Password: password,
})
require.NoError(t, err)
err = s.Client.SetSessionToken(login.SessionToken)
require.NoError(t, err)
return req
}
// New constructs a new coderd test instance. This returned Server
// should contain no side-effects.
func New(t *testing.T) Server {
// This can be hotswapped for a live database instance.
db := databasefake.New()
@ -54,24 +85,8 @@ func New(t *testing.T) Server {
require.NoError(t, err)
t.Cleanup(srv.Close)
client := codersdk.New(serverURL)
_, err = client.CreateInitialUser(context.Background(), coderd.CreateUserRequest{
Email: "testuser@coder.com",
Username: "testuser",
Password: "testpassword",
})
require.NoError(t, err)
login, err := client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{
Email: "testuser@coder.com",
Password: "testpassword",
})
require.NoError(t, err)
err = client.SetSessionToken(login.SessionToken)
require.NoError(t, err)
return Server{
Client: client,
Client: codersdk.New(serverURL),
URL: serverURL,
}
}

View File

@ -13,5 +13,6 @@ func TestMain(m *testing.M) {
}
func TestNew(t *testing.T) {
_ = coderdtest.New(t)
server := coderdtest.New(t)
_ = server.RandomInitialUser(t)
}

25
coderd/organizations.go Normal file
View File

@ -0,0 +1,25 @@
package coderd
import (
"time"
"github.com/coder/coder/database"
)
// Organization is the JSON representation of a Coder organization.
type Organization struct {
ID string `json:"id" validate:"required"`
Name string `json:"name" validate:"required"`
CreatedAt time.Time `json:"created_at" validate:"required"`
UpdatedAt time.Time `json:"updated_at" validate:"required"`
}
// convertOrganization consumes the database representation and outputs an API friendly representation.
func convertOrganization(organization database.Organization) Organization {
return Organization{
ID: organization.ID,
Name: organization.Name,
CreatedAt: organization.CreatedAt,
UpdatedAt: organization.UpdatedAt,
}
}

View File

@ -1,7 +1,6 @@
package coderd
import (
"context"
"crypto/sha256"
"database/sql"
"errors"
@ -11,6 +10,7 @@ import (
"github.com/go-chi/render"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/userpassword"
"github.com/coder/coder/cryptorand"
@ -27,11 +27,12 @@ type User struct {
Username string `json:"username" validate:"required"`
}
// CreateUserRequest enables callers to create a new user.
type CreateUserRequest struct {
Email string `json:"email" validate:"required,email"`
Username string `json:"username" validate:"required,username"`
Password string `json:"password" validate:"required"`
// CreateInitialUserRequest enables callers to create a new user.
type CreateInitialUserRequest struct {
Email string `json:"email" validate:"required,email"`
Username string `json:"username" validate:"required,username"`
Password string `json:"password" validate:"required"`
Organization string `json:"organization" validate:"required,username"`
}
// LoginWithPasswordRequest enables callers to authenticate with email and password.
@ -51,7 +52,7 @@ type users struct {
// Creates the initial user for a Coder deployment.
func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) {
var createUser CreateUserRequest
var createUser CreateInitialUserRequest
if !httpapi.Read(rw, r, &createUser) {
return
}
@ -70,19 +71,6 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) {
})
return
}
_, err = users.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
Email: createUser.Email,
Username: createUser.Username,
})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get user: %s", err.Error()),
})
return
}
hashedPassword, err := userpassword.Hash(createUser.Password)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
@ -91,28 +79,57 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) {
return
}
user, err := users.Database.InsertUser(context.Background(), database.InsertUserParams{
ID: uuid.NewString(),
Email: createUser.Email,
HashedPassword: []byte(hashedPassword),
Username: createUser.Username,
LoginType: database.LoginTypeBuiltIn,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
// Create the user, organization, and membership to the user.
var user database.User
err = users.Database.InTx(func(s database.Store) error {
user, err = users.Database.InsertUser(r.Context(), database.InsertUserParams{
ID: uuid.NewString(),
Email: createUser.Email,
HashedPassword: []byte(hashedPassword),
Username: createUser.Username,
LoginType: database.LoginTypeBuiltIn,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
if err != nil {
return xerrors.Errorf("create user: %w", err)
}
organization, err := users.Database.InsertOrganization(r.Context(), database.InsertOrganizationParams{
ID: uuid.NewString(),
Name: createUser.Organization,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
if err != nil {
return xerrors.Errorf("create organization: %w", err)
}
_, err = users.Database.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
OrganizationID: organization.ID,
UserID: user.ID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
Roles: []string{"organization-admin"},
})
if err != nil {
return xerrors.Errorf("create organization member: %w", err)
}
return nil
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("create user: %s", err.Error()),
Message: err.Error(),
})
return
}
render.Status(r, http.StatusCreated)
render.JSON(rw, r, user)
}
// Returns the currently authenticated user.
func (*users) authenticatedUser(rw http.ResponseWriter, r *http.Request) {
user := httpmw.User(r)
// Returns the parameterized user requested. All validation
// is completed in the middleware for this route.
func (*users) user(rw http.ResponseWriter, r *http.Request) {
user := httpmw.UserParam(r)
render.JSON(rw, r, User{
ID: user.ID,
@ -122,6 +139,27 @@ func (*users) authenticatedUser(rw http.ResponseWriter, r *http.Request) {
})
}
// Returns organizations the parameterized user has access to.
func (users *users) userOrganizations(rw http.ResponseWriter, r *http.Request) {
user := httpmw.UserParam(r)
organizations, err := users.Database.GetOrganizationsByUserID(r.Context(), user.ID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get organizations: %s", err.Error()),
})
return
}
publicOrganizations := make([]Organization, 0, len(organizations))
for _, organization := range organizations {
publicOrganizations = append(publicOrganizations, convertOrganization(organization))
}
render.Status(r, http.StatusOK)
render.JSON(rw, r, publicOrganizations)
}
// Authenticates the user with an email and password.
func (users *users) loginWithPassword(rw http.ResponseWriter, r *http.Request) {
var loginWithPassword LoginWithPasswordRequest

View File

@ -16,6 +16,7 @@ func TestUsers(t *testing.T) {
t.Run("Authenticated", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_ = server.RandomInitialUser(t)
_, err := server.Client.User(context.Background(), "")
require.NoError(t, err)
})
@ -23,15 +24,27 @@ func TestUsers(t *testing.T) {
t.Run("CreateMultipleInitial", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_, err := server.Client.CreateInitialUser(context.Background(), coderd.CreateUserRequest{
Email: "dummy@coder.com",
Username: "fake",
Password: "password",
_ = server.RandomInitialUser(t)
_, err := server.Client.CreateInitialUser(context.Background(), coderd.CreateInitialUserRequest{
Email: "dummy@coder.com",
Organization: "bananas",
Username: "fake",
Password: "password",
})
require.Error(t, err)
})
t.Run("LoginNoEmail", func(t *testing.T) {
t.Run("Login", func(t *testing.T) {
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
_, err := server.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{
Email: user.Email,
Password: user.Password,
})
require.NoError(t, err)
})
t.Run("LoginInvalidUser", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_, err := server.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{
@ -44,13 +57,20 @@ func TestUsers(t *testing.T) {
t.Run("LoginBadPassword", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user, err := server.Client.User(context.Background(), "")
require.NoError(t, err)
_, err = server.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{
user := server.RandomInitialUser(t)
_, err := server.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{
Email: user.Email,
Password: "bananas",
})
require.Error(t, err)
})
t.Run("ListOrganizations", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_ = server.RandomInitialUser(t)
orgs, err := server.Client.UserOrganizations(context.Background(), "")
require.NoError(t, err)
require.Len(t, orgs, 1)
})
}

View File

@ -3,6 +3,7 @@ package codersdk
import (
"context"
"encoding/json"
"fmt"
"net/http"
"github.com/coder/coder/coderd"
@ -11,8 +12,8 @@ import (
// CreateInitialUser attempts to create the first user on a Coder deployment.
// This initial user has superadmin privileges. If >0 users exist, this request
// will fail.
func (c *Client) CreateInitialUser(ctx context.Context, req coderd.CreateUserRequest) (coderd.User, error) {
res, err := c.request(ctx, http.MethodPost, "/api/v2/user", req)
func (c *Client) CreateInitialUser(ctx context.Context, req coderd.CreateInitialUserRequest) (coderd.User, error) {
res, err := c.request(ctx, http.MethodPost, "/api/v2/users", req)
if err != nil {
return coderd.User{}, err
}
@ -24,21 +25,6 @@ func (c *Client) CreateInitialUser(ctx context.Context, req coderd.CreateUserReq
return user, json.NewDecoder(res.Body).Decode(&user)
}
// User returns a user for the ID provided.
// If the ID string is empty, the current user will be returned.
func (c *Client) User(ctx context.Context, _ string) (coderd.User, error) {
res, err := c.request(ctx, http.MethodGet, "/api/v2/user", nil)
if err != nil {
return coderd.User{}, err
}
defer res.Body.Close()
if res.StatusCode > http.StatusOK {
return coderd.User{}, readBodyAsError(res)
}
var user coderd.User
return user, json.NewDecoder(res.Body).Decode(&user)
}
// LoginWithPassword creates a session token authenticating with an email and password.
// Call `SetSessionToken()` to apply the newly acquired token to the client.
func (c *Client) LoginWithPassword(ctx context.Context, req coderd.LoginWithPasswordRequest) (coderd.LoginWithPasswordResponse, error) {
@ -57,3 +43,38 @@ func (c *Client) LoginWithPassword(ctx context.Context, req coderd.LoginWithPass
}
return resp, nil
}
// User returns a user for the ID provided.
// If the ID string is empty, the current user will be returned.
func (c *Client) User(ctx context.Context, id string) (coderd.User, error) {
if id == "" {
id = "me"
}
res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/users/%s", id), nil)
if err != nil {
return coderd.User{}, err
}
defer res.Body.Close()
if res.StatusCode > http.StatusOK {
return coderd.User{}, readBodyAsError(res)
}
var user coderd.User
return user, json.NewDecoder(res.Body).Decode(&user)
}
// UserOrganizations fetches organizations a user is part of.
func (c *Client) UserOrganizations(ctx context.Context, id string) ([]coderd.Organization, error) {
if id == "" {
id = "me"
}
res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/users/%s/organizations", id), nil)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, readBodyAsError(res)
}
var orgs []coderd.Organization
return orgs, json.NewDecoder(res.Body).Decode(&orgs)
}

View File

@ -2,33 +2,44 @@ package codersdk_test
import (
"context"
"net/http"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
)
func TestUsers(t *testing.T) {
t.Run("MultipleInitial", func(t *testing.T) {
t.Run("CreateInitial", func(t *testing.T) {
server := coderdtest.New(t)
_, err := server.Client.CreateInitialUser(context.Background(), coderd.CreateUserRequest{
Email: "wowie@coder.com",
Username: "tester",
Password: "moo",
_, err := server.Client.CreateInitialUser(context.Background(), coderd.CreateInitialUserRequest{
Email: "wowie@coder.com",
Organization: "somethin",
Username: "tester",
Password: "moo",
})
var cerr *codersdk.Error
require.ErrorAs(t, err, &cerr)
require.Equal(t, cerr.StatusCode(), http.StatusConflict)
require.Greater(t, len(cerr.Error()), 0)
require.NoError(t, err)
})
t.Run("Get", func(t *testing.T) {
t.Run("NoUser", func(t *testing.T) {
server := coderdtest.New(t)
_, err := server.Client.User(context.Background(), "")
require.Error(t, err)
})
t.Run("User", func(t *testing.T) {
server := coderdtest.New(t)
_ = server.RandomInitialUser(t)
_, err := server.Client.User(context.Background(), "")
require.NoError(t, err)
})
t.Run("UserOrganizations", func(t *testing.T) {
server := coderdtest.New(t)
_ = server.RandomInitialUser(t)
orgs, err := server.Client.UserOrganizations(context.Background(), "")
require.NoError(t, err)
require.Len(t, orgs, 1)
})
}

View File

@ -10,15 +10,19 @@ import (
// New returns an in-memory fake of the database.
func New() database.Store {
return &fakeQuerier{
apiKeys: make([]database.APIKey, 0),
users: make([]database.User, 0),
apiKeys: make([]database.APIKey, 0),
organizations: make([]database.Organization, 0),
organizationMembers: make([]database.OrganizationMember, 0),
users: make([]database.User, 0),
}
}
// fakeQuerier replicates database functionality to enable quick testing.
type fakeQuerier struct {
apiKeys []database.APIKey
users []database.User
apiKeys []database.APIKey
organizations []database.Organization
organizationMembers []database.OrganizationMember
users []database.User
}
// InTx doesn't rollback data properly for in-memory yet.
@ -57,6 +61,34 @@ func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) {
return int64(len(q.users)), nil
}
func (q *fakeQuerier) GetOrganizationByName(_ context.Context, name string) (database.Organization, error) {
for _, organization := range q.organizations {
if organization.Name == name {
return organization, nil
}
}
return database.Organization{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID string) ([]database.Organization, error) {
organizations := make([]database.Organization, 0)
for _, organizationMember := range q.organizationMembers {
if organizationMember.UserID != userID {
continue
}
for _, organization := range q.organizations {
if organization.ID != organizationMember.OrganizationID {
continue
}
organizations = append(organizations, organization)
}
}
if len(organizations) == 0 {
return nil, sql.ErrNoRows
}
return organizations, nil
}
func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) {
//nolint:gosimple
key := database.APIKey{
@ -80,6 +112,30 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP
return key, nil
}
func (q *fakeQuerier) InsertOrganization(_ context.Context, arg database.InsertOrganizationParams) (database.Organization, error) {
organization := database.Organization{
ID: arg.ID,
Name: arg.Name,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
}
q.organizations = append(q.organizations, organization)
return organization, nil
}
func (q *fakeQuerier) InsertOrganizationMember(_ context.Context, arg database.InsertOrganizationMemberParams) (database.OrganizationMember, error) {
//nolint:gosimple
organizationMember := database.OrganizationMember{
OrganizationID: arg.OrganizationID,
UserID: arg.UserID,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
Roles: arg.Roles,
}
q.organizationMembers = append(q.organizationMembers, organizationMember)
return organizationMember, nil
}
func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, error) {
user := database.User{
ID: arg.ID,

View File

@ -8,10 +8,14 @@ import (
type querier interface {
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
GetOrganizationByName(ctx context.Context, name string) (Organization, error)
GetOrganizationsByUserID(ctx context.Context, userID string) ([]Organization, error)
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
GetUserByID(ctx context.Context, id string) (User, error)
GetUserCount(ctx context.Context) (int64, error)
InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error)
InsertOrganization(ctx context.Context, arg InsertOrganizationParams) (Organization, error)
InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error)
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
}

View File

@ -41,6 +41,14 @@ SELECT
FROM
users;
-- name: GetOrganizationByName :one
SELECT * FROM organizations WHERE name = $1 LIMIT 1;
-- name: GetOrganizationsByUserID :many
SELECT * FROM organizations WHERE id = (
SELECT organization_id FROM organization_members WHERE user_id = $1
);
-- name: InsertAPIKey :one
INSERT INTO
api_keys (
@ -79,6 +87,12 @@ VALUES
$15
) RETURNING *;
-- name: InsertOrganization :one
INSERT INTO organizations (id, name, description, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) RETURNING *;
-- name: InsertOrganizationMember :one
INSERT INTO organization_members (organization_id, user_id, created_at, updated_at, roles) VALUES ($1, $2, $3, $4, $5) RETURNING *;
-- name: InsertUser :one
INSERT INTO
users (

View File

@ -44,6 +44,68 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro
return i, err
}
const getOrganizationByName = `-- name: GetOrganizationByName :one
SELECT id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off FROM organizations WHERE name = $1 LIMIT 1
`
func (q *sqlQuerier) GetOrganizationByName(ctx context.Context, name string) (Organization, error) {
row := q.db.QueryRowContext(ctx, getOrganizationByName, name)
var i Organization
err := row.Scan(
&i.ID,
&i.Name,
&i.Description,
&i.CreatedAt,
&i.UpdatedAt,
&i.Default,
&i.AutoOffThreshold,
&i.CpuProvisioningRate,
&i.MemoryProvisioningRate,
&i.WorkspaceAutoOff,
)
return i, err
}
const getOrganizationsByUserID = `-- name: GetOrganizationsByUserID :many
SELECT id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off FROM organizations WHERE id = (
SELECT organization_id FROM organization_members WHERE user_id = $1
)
`
func (q *sqlQuerier) GetOrganizationsByUserID(ctx context.Context, userID string) ([]Organization, error) {
rows, err := q.db.QueryContext(ctx, getOrganizationsByUserID, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Organization
for rows.Next() {
var i Organization
if err := rows.Scan(
&i.ID,
&i.Name,
&i.Description,
&i.CreatedAt,
&i.UpdatedAt,
&i.Default,
&i.AutoOffThreshold,
&i.CpuProvisioningRate,
&i.MemoryProvisioningRate,
&i.WorkspaceAutoOff,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one
SELECT
id, email, name, revoked, login_type, hashed_password, created_at, updated_at, temporary_password, avatar_hash, ssh_key_regenerated_at, username, dotfiles_git_uri, roles, status, relatime, gpg_key_regenerated_at, _decomissioned, shell
@ -236,6 +298,73 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (
return i, err
}
const insertOrganization = `-- name: InsertOrganization :one
INSERT INTO organizations (id, name, description, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) RETURNING id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off
`
type InsertOrganizationParams struct {
ID string `db:"id" json:"id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
func (q *sqlQuerier) InsertOrganization(ctx context.Context, arg InsertOrganizationParams) (Organization, error) {
row := q.db.QueryRowContext(ctx, insertOrganization,
arg.ID,
arg.Name,
arg.Description,
arg.CreatedAt,
arg.UpdatedAt,
)
var i Organization
err := row.Scan(
&i.ID,
&i.Name,
&i.Description,
&i.CreatedAt,
&i.UpdatedAt,
&i.Default,
&i.AutoOffThreshold,
&i.CpuProvisioningRate,
&i.MemoryProvisioningRate,
&i.WorkspaceAutoOff,
)
return i, err
}
const insertOrganizationMember = `-- name: InsertOrganizationMember :one
INSERT INTO organization_members (organization_id, user_id, created_at, updated_at, roles) VALUES ($1, $2, $3, $4, $5) RETURNING organization_id, user_id, created_at, updated_at, roles
`
type InsertOrganizationMemberParams struct {
OrganizationID string `db:"organization_id" json:"organization_id"`
UserID string `db:"user_id" json:"user_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Roles []string `db:"roles" json:"roles"`
}
func (q *sqlQuerier) InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error) {
row := q.db.QueryRowContext(ctx, insertOrganizationMember,
arg.OrganizationID,
arg.UserID,
arg.CreatedAt,
arg.UpdatedAt,
pq.Array(arg.Roles),
)
var i OrganizationMember
err := row.Scan(
&i.OrganizationID,
&i.UserID,
&i.CreatedAt,
&i.UpdatedAt,
pq.Array(&i.Roles),
)
return i, err
}
const insertUser = `-- name: InsertUser :one
INSERT INTO
users (

View File

@ -1,51 +0,0 @@
package httpmw
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
)
type userContextKey struct{}
// User returns the user from the ExtractUser handler.
func User(r *http.Request) database.User {
user, ok := r.Context().Value(userContextKey{}).(database.User)
if !ok {
panic("developer error: user middleware not provided")
}
return user
}
// ExtractUser consumes an API key and queries the user attached to it.
// It attaches the user to the request context.
func ExtractUser(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// The user handler depends on API Key to get the authenticated user.
apiKey := APIKey(r)
user, err := db.GetUserByID(r.Context(), apiKey.UserID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "user not found for api key",
})
return
}
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("couldn't fetch user for api key: %s", err.Error()),
})
return
}
ctx := context.WithValue(r.Context(), userContextKey{}, user)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -1,89 +0,0 @@
package httpmw_test
import (
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/coder/coder/database"
"github.com/coder/coder/database/databasefake"
"github.com/coder/coder/httpmw"
)
func TestUser(t *testing.T) {
t.Run("NoUser", func(t *testing.T) {
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
_, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: "bananas",
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
r = returnedRequest
})).ServeHTTP(rw, r)
httpmw.ExtractUser(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
})).ServeHTTP(rw, r)
})
t.Run("User", func(t *testing.T) {
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
ID: "testing",
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: user.ID,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
r = returnedRequest
})).ServeHTTP(rw, r)
httpmw.ExtractUser(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// Makes sure the context properly adds the User!
_ = httpmw.User(r)
rw.WriteHeader(http.StatusOK)
})).ServeHTTP(rw, r)
})
}

54
httpmw/userparam.go Normal file
View File

@ -0,0 +1,54 @@
package httpmw
import (
"context"
"fmt"
"net/http"
"github.com/go-chi/chi"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
)
type userParamContextKey struct{}
// UserParam returns the user from the ExtractUserParam handler.
func UserParam(r *http.Request) database.User {
user, ok := r.Context().Value(userParamContextKey{}).(database.User)
if !ok {
panic("developer error: user parameter middleware not provided")
}
return user
}
// ExtractUserParam extracts a user from an ID/username in the {user} URL parameter.
func ExtractUserParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
userID := chi.URLParam(r, "user")
if userID == "" {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "user id or name must be provided",
})
return
}
if userID != "me" {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "getting non-personal users isn't supported yet",
})
return
}
apiKey := APIKey(r)
user, err := db.GetUserByID(r.Context(), apiKey.UserID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get user: %s", err.Error()),
})
}
ctx := context.WithValue(r.Context(), userParamContextKey{}, user)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

100
httpmw/userparam_test.go Normal file
View File

@ -0,0 +1,100 @@
package httpmw_test
import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/stretchr/testify/require"
"github.com/coder/coder/database"
"github.com/coder/coder/database/databasefake"
"github.com/coder/coder/httpmw"
)
func TestUserParam(t *testing.T) {
setup := func(t *testing.T) (database.Store, *httptest.ResponseRecorder, *http.Request) {
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
_, err := db.InsertUser(r.Context(), database.InsertUserParams{
ID: "bananas",
})
require.NoError(t, err)
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: "bananas",
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
return db, rw, r
}
t.Run("None", func(t *testing.T) {
db, rw, r := setup(t)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
r = returnedRequest
})).ServeHTTP(rw, r)
httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
})).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("NotMe", func(t *testing.T) {
db, rw, r := setup(t)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
r = returnedRequest
})).ServeHTTP(rw, r)
routeContext := chi.NewRouteContext()
routeContext.URLParams.Add("user", "ben")
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext))
httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
})).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("Me", func(t *testing.T) {
db, rw, r := setup(t)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
r = returnedRequest
})).ServeHTTP(rw, r)
routeContext := chi.NewRouteContext()
routeContext.URLParams.Add("user", "me")
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext))
httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.UserParam(r)
rw.WriteHeader(http.StatusOK)
})).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
})
}

View File

@ -23,7 +23,7 @@ import (
)
const (
disconnectedTimeout = time.Second
disconnectedTimeout = 5 * time.Second
failedTimeout = disconnectedTimeout * 5
keepAliveInterval = time.Millisecond * 2
)

View File

@ -87,7 +87,7 @@ export const SignInForm: React.FC<SignInProps> = ({
try {
await loginHandler(email, password)
// Tell SWR to invalidate the cache for the user endpoint
await mutate("/api/v2/user")
await mutate("/api/v2/users/me")
await router.push("/")
} catch (err) {
helpers.setFieldError("password", "The username or password is incorrect.")

View File

@ -38,7 +38,7 @@ export const useUser = (redirectOnError = false): UserContext => {
}
export const UserProvider: React.FC = (props) => {
const { data, error } = useSWR("/api/v2/user")
const { data, error } = useSWR("/api/v2/users/me")
return (
<UserContext.Provider