feat: Add project API endpoints (#51)

* feat: Add project models

* Add project query functions

* Add organization parameter query

* Add project URL parameter parse

* Add project create and list endpoints

* Add test for organization provided

* Remove unimplemented routes

* Decrease conn timeout

* Add test for UnbiasedModulo32

* Fix expected value

* Add single user endpoint

* Add query for project versions

* Fix linting errors

* Add comments

* Add test for invalid archive

* Check unauthenticated endpoints

* Add check if no change happened

* Ensure context close ends listener

* Fix parallel test run

* Test empty

* Fix organization param comment
This commit is contained in:
Kyle Carberry 2022-01-24 11:07:42 -06:00 committed by GitHub
parent 52e50fc9ca
commit a44056cff5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 2121 additions and 22 deletions

View File

@ -234,6 +234,7 @@ linters:
- misspell
- nilnil
- noctx
- paralleltest
- revive
- rowserrcheck
- sqlclosecheck

View File

@ -11,7 +11,7 @@ database/dump.sql: $(wildcard database/migrations/*.sql)
go run database/dump/main.go
# Generates Go code for querying the database.
database/generate: database/dump.sql database/query.sql
database/generate: fmt/sql database/dump.sql database/query.sql
cd database && sqlc generate && rm db_tmp.go
cd database && gofmt -w -r 'Querier -> querier' *.go
cd database && gofmt -w -r 'Queries -> sqlQuerier' *.go
@ -27,12 +27,13 @@ else
endif
.PHONY: fmt/prettier
fmt/sql:
fmt/sql: ./database/query.sql
npx sql-formatter \
--language postgresql \
--lines-between-queries 2 \
./database/query.sql \
--output ./database/query.sql
sed -i 's/@ /@/g' ./database/query.sql
fmt: fmt/prettier fmt/sql
.PHONY: fmt

View File

@ -10,6 +10,7 @@ import (
)
func TestRoot(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
go cancelFunc()
err := cmd.Root().ExecuteContext(ctx)

View File

@ -20,6 +20,9 @@ type Options struct {
// New constructs the Coder API into an HTTP handler.
func New(options *Options) http.Handler {
projects := &projects{
Database: options.Database,
}
users := &users{
Database: options.Database,
}
@ -44,6 +47,25 @@ func New(options *Options) http.Handler {
r.Get("/{user}/organizations", users.userOrganizations)
})
})
r.Route("/projects", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
)
r.Get("/", projects.allProjects)
r.Route("/{organization}", func(r chi.Router) {
r.Use(httpmw.ExtractOrganizationParam(options.Database))
r.Get("/", projects.allProjectsForOrganization)
r.Post("/", projects.createProject)
r.Route("/{project}", func(r chi.Router) {
r.Use(httpmw.ExtractProjectParameter(options.Database))
r.Get("/", projects.project)
r.Route("/versions", func(r chi.Router) {
r.Get("/", projects.projectVersions)
r.Post("/", projects.createProjectVersion)
})
})
})
})
})
r.NotFound(site.Handler().ServeHTTP)
return r

View File

@ -13,6 +13,7 @@ func TestMain(m *testing.M) {
}
func TestNew(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_ = server.RandomInitialUser(t)
}

229
coderd/projects.go Normal file
View File

@ -0,0 +1,229 @@
package coderd
import (
"archive/tar"
"bytes"
"database/sql"
"errors"
"fmt"
"net/http"
"time"
"github.com/go-chi/render"
"github.com/google/uuid"
"github.com/moby/moby/pkg/namesgenerator"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
)
// Project is the JSON representation of a Coder project.
// This type matches the database object for now, but is
// abstracted for ease of change later on.
type Project database.Project
// ProjectVersion is the JSON representation of a Coder project version.
type ProjectVersion struct {
ID uuid.UUID `json:"id"`
ProjectID uuid.UUID `json:"project_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Name string `json:"name"`
StorageMethod database.ProjectStorageMethod `json:"storage_method"`
}
// CreateProjectRequest enables callers to create a new Project.
type CreateProjectRequest struct {
Name string `json:"name" validate:"username,required"`
Provisioner database.ProvisionerType `json:"provisioner" validate:"oneof=terraform cdr-basic,required"`
}
// CreateProjectVersionRequest enables callers to create a new Project Version.
type CreateProjectVersionRequest struct {
Name string `json:"name,omitempty" validate:"username"`
StorageMethod database.ProjectStorageMethod `json:"storage_method" validate:"oneof=inline-archive,required"`
StorageSource []byte `json:"storage_source" validate:"max=1048576,required"`
}
type projects struct {
Database database.Store
}
// allProjects lists all projects across organizations for a user.
func (p *projects) allProjects(rw http.ResponseWriter, r *http.Request) {
apiKey := httpmw.APIKey(r)
organizations, err := p.Database.GetOrganizationsByUserID(r.Context(), apiKey.UserID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get organizations: %s", err.Error()),
})
return
}
organizationIDs := make([]string, 0, len(organizations))
for _, organization := range organizations {
organizationIDs = append(organizationIDs, organization.ID)
}
projects, err := p.Database.GetProjectsByOrganizationIDs(r.Context(), organizationIDs)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get projects: %s", err.Error()),
})
return
}
render.Status(r, http.StatusOK)
render.JSON(rw, r, projects)
}
// allProjectsForOrganization lists all projects for a specific organization.
func (p *projects) allProjectsForOrganization(rw http.ResponseWriter, r *http.Request) {
organization := httpmw.OrganizationParam(r)
projects, err := p.Database.GetProjectsByOrganizationIDs(r.Context(), []string{organization.ID})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get projects: %s", err.Error()),
})
return
}
render.Status(r, http.StatusOK)
render.JSON(rw, r, projects)
}
// createProject makes a new project in an organization.
func (p *projects) createProject(rw http.ResponseWriter, r *http.Request) {
var createProject CreateProjectRequest
if !httpapi.Read(rw, r, &createProject) {
return
}
organization := httpmw.OrganizationParam(r)
_, err := p.Database.GetProjectByOrganizationAndName(r.Context(), database.GetProjectByOrganizationAndNameParams{
OrganizationID: organization.ID,
Name: createProject.Name,
})
if err == nil {
httpapi.Write(rw, http.StatusConflict, httpapi.Response{
Message: fmt.Sprintf("project %q already exists", createProject.Name),
Errors: []httpapi.Error{{
Field: "name",
Code: "exists",
}},
})
return
}
if !errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get project by name: %s", err.Error()),
})
return
}
project, err := p.Database.InsertProject(r.Context(), database.InsertProjectParams{
ID: uuid.New(),
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
OrganizationID: organization.ID,
Name: createProject.Name,
Provisioner: createProject.Provisioner,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("insert project: %s", err),
})
return
}
render.Status(r, http.StatusCreated)
render.JSON(rw, r, project)
}
// project returns a single project parsed from the URL path.
func (*projects) project(rw http.ResponseWriter, r *http.Request) {
project := httpmw.ProjectParam(r)
render.Status(r, http.StatusOK)
render.JSON(rw, r, project)
}
// projectVersions lists versions for a single project.
func (p *projects) projectVersions(rw http.ResponseWriter, r *http.Request) {
project := httpmw.ProjectParam(r)
history, err := p.Database.GetProjectHistoryByProjectID(r.Context(), project.ID)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get project history: %s", err),
})
return
}
versions := make([]ProjectVersion, 0)
for _, version := range history {
versions = append(versions, convertProjectHistory(version))
}
render.Status(r, http.StatusOK)
render.JSON(rw, r, versions)
}
func (p *projects) createProjectVersion(rw http.ResponseWriter, r *http.Request) {
var createProjectVersion CreateProjectVersionRequest
if !httpapi.Read(rw, r, &createProjectVersion) {
return
}
switch createProjectVersion.StorageMethod {
case database.ProjectStorageMethodInlineArchive:
tarReader := tar.NewReader(bytes.NewReader(createProjectVersion.StorageSource))
_, err := tarReader.Next()
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "the archive must be a tar",
})
return
}
default:
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("unsupported storage method %s", createProjectVersion.StorageMethod),
})
return
}
project := httpmw.ProjectParam(r)
history, err := p.Database.InsertProjectHistory(r.Context(), database.InsertProjectHistoryParams{
ID: uuid.New(),
ProjectID: project.ID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
Name: namesgenerator.GetRandomName(1),
StorageMethod: createProjectVersion.StorageMethod,
StorageSource: createProjectVersion.StorageSource,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("insert project history: %s", err),
})
return
}
// TODO: A job to process the new version should occur here.
render.Status(r, http.StatusCreated)
render.JSON(rw, r, convertProjectHistory(history))
}
func convertProjectHistory(history database.ProjectHistory) ProjectVersion {
return ProjectVersion{
ID: history.ID,
ProjectID: history.ProjectID,
CreatedAt: history.CreatedAt,
UpdatedAt: history.UpdatedAt,
Name: history.Name,
}
}

183
coderd/projects_test.go Normal file
View File

@ -0,0 +1,183 @@
package coderd_test
import (
"archive/tar"
"bytes"
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/database"
)
func TestProjects(t *testing.T) {
t.Parallel()
t.Run("Create", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
_, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
Name: "someproject",
Provisioner: database.ProvisionerTypeTerraform,
})
require.NoError(t, err)
})
t.Run("AlreadyExists", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
_, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
Name: "someproject",
Provisioner: database.ProvisionerTypeTerraform,
})
require.NoError(t, err)
_, err = server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
Name: "someproject",
Provisioner: database.ProvisionerTypeTerraform,
})
require.Error(t, err)
})
t.Run("ListEmpty", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_ = server.RandomInitialUser(t)
projects, err := server.Client.Projects(context.Background(), "")
require.NoError(t, err)
require.Len(t, projects, 0)
})
t.Run("List", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
_, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
Name: "someproject",
Provisioner: database.ProvisionerTypeTerraform,
})
require.NoError(t, err)
// Ensure global query works.
projects, err := server.Client.Projects(context.Background(), "")
require.NoError(t, err)
require.Len(t, projects, 1)
// Ensure specified query works.
projects, err = server.Client.Projects(context.Background(), user.Organization)
require.NoError(t, err)
require.Len(t, projects, 1)
})
t.Run("ListEmpty", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
projects, err := server.Client.Projects(context.Background(), user.Organization)
require.NoError(t, err)
require.Len(t, projects, 0)
})
t.Run("Single", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
Name: "someproject",
Provisioner: database.ProvisionerTypeTerraform,
})
require.NoError(t, err)
_, err = server.Client.Project(context.Background(), user.Organization, project.Name)
require.NoError(t, err)
})
t.Run("NoVersions", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
Name: "someproject",
Provisioner: database.ProvisionerTypeTerraform,
})
require.NoError(t, err)
versions, err := server.Client.ProjectVersions(context.Background(), user.Organization, project.Name)
require.NoError(t, err)
require.Len(t, versions, 0)
})
t.Run("CreateVersion", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
Name: "someproject",
Provisioner: database.ProvisionerTypeTerraform,
})
require.NoError(t, err)
var buffer bytes.Buffer
writer := tar.NewWriter(&buffer)
err = writer.WriteHeader(&tar.Header{
Name: "file",
Size: 1 << 10,
})
require.NoError(t, err)
_, err = writer.Write(make([]byte, 1<<10))
require.NoError(t, err)
_, err = server.Client.CreateProjectVersion(context.Background(), user.Organization, project.Name, coderd.CreateProjectVersionRequest{
Name: "moo",
StorageMethod: database.ProjectStorageMethodInlineArchive,
StorageSource: buffer.Bytes(),
})
require.NoError(t, err)
versions, err := server.Client.ProjectVersions(context.Background(), user.Organization, project.Name)
require.NoError(t, err)
require.Len(t, versions, 1)
})
t.Run("CreateVersionArchiveTooBig", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
Name: "someproject",
Provisioner: database.ProvisionerTypeTerraform,
})
require.NoError(t, err)
var buffer bytes.Buffer
writer := tar.NewWriter(&buffer)
err = writer.WriteHeader(&tar.Header{
Name: "file",
Size: 1 << 21,
})
require.NoError(t, err)
_, err = writer.Write(make([]byte, 1<<21))
require.NoError(t, err)
_, err = server.Client.CreateProjectVersion(context.Background(), user.Organization, project.Name, coderd.CreateProjectVersionRequest{
Name: "moo",
StorageMethod: database.ProjectStorageMethodInlineArchive,
StorageSource: buffer.Bytes(),
})
require.Error(t, err)
})
t.Run("CreateVersionInvalidArchive", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
Name: "someproject",
Provisioner: database.ProvisionerTypeTerraform,
})
require.NoError(t, err)
_, err = server.Client.CreateProjectVersion(context.Background(), user.Organization, project.Name, coderd.CreateProjectVersionRequest{
Name: "moo",
StorageMethod: database.ProjectStorageMethodInlineArchive,
StorageSource: []byte{},
})
require.Error(t, err)
})
}

View File

@ -9,7 +9,9 @@ import (
)
func TestUserPassword(t *testing.T) {
t.Parallel()
t.Run("Legacy", func(t *testing.T) {
t.Parallel()
// Ensures legacy v1 passwords function for v2.
// This has is manually generated using a print statement from v1 code.
equal, err := userpassword.Compare("$pbkdf2-sha256$65535$z8c1p1C2ru9EImBP1I+ZNA$pNjE3Yk0oG0PmJ0Je+y7ENOVlSkn/b0BEqqdKsq6Y97wQBq0xT+lD5bWJpyIKJqQICuPZcEaGDKrXJn8+SIHRg", "tomato")
@ -18,6 +20,7 @@ func TestUserPassword(t *testing.T) {
})
t.Run("Same", func(t *testing.T) {
t.Parallel()
hash, err := userpassword.Hash("password")
require.NoError(t, err)
equal, err := userpassword.Compare(hash, "password")
@ -26,6 +29,7 @@ func TestUserPassword(t *testing.T) {
})
t.Run("Different", func(t *testing.T) {
t.Parallel()
hash, err := userpassword.Hash("password")
require.NoError(t, err)
equal, err := userpassword.Compare(hash, "notpassword")
@ -34,12 +38,14 @@ func TestUserPassword(t *testing.T) {
})
t.Run("Invalid", func(t *testing.T) {
t.Parallel()
equal, err := userpassword.Compare("invalidhash", "password")
require.False(t, equal)
require.Error(t, err)
})
t.Run("InvalidParts", func(t *testing.T) {
t.Parallel()
equal, err := userpassword.Compare("abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz", "test")
require.False(t, equal)
require.Error(t, err)

View File

@ -35,6 +35,7 @@ func TestUsers(t *testing.T) {
})
t.Run("Login", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
_, err := server.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{

86
codersdk/projects.go Normal file
View File

@ -0,0 +1,86 @@
package codersdk
import (
"context"
"encoding/json"
"fmt"
"net/http"
"github.com/coder/coder/coderd"
)
// Projects lists projects inside an organization.
// If organization is an empty string, all projects will be returned
// for the authenticated user.
func (c *Client) Projects(ctx context.Context, organization string) ([]coderd.Project, error) {
route := "/api/v2/projects"
if organization != "" {
route = fmt.Sprintf("/api/v2/projects/%s", organization)
}
res, err := c.request(ctx, http.MethodGet, route, nil)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, readBodyAsError(res)
}
var projects []coderd.Project
return projects, json.NewDecoder(res.Body).Decode(&projects)
}
// Project returns a single project.
func (c *Client) Project(ctx context.Context, organization, project string) (coderd.Project, error) {
res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/projects/%s/%s", organization, project), nil)
if err != nil {
return coderd.Project{}, nil
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return coderd.Project{}, readBodyAsError(res)
}
var resp coderd.Project
return resp, json.NewDecoder(res.Body).Decode(&resp)
}
// CreateProject creates a new project inside an organization.
func (c *Client) CreateProject(ctx context.Context, organization string, request coderd.CreateProjectRequest) (coderd.Project, error) {
res, err := c.request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/projects/%s", organization), request)
if err != nil {
return coderd.Project{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return coderd.Project{}, readBodyAsError(res)
}
var project coderd.Project
return project, json.NewDecoder(res.Body).Decode(&project)
}
// ProjectVersions lists history for a project.
func (c *Client) ProjectVersions(ctx context.Context, organization, project string) ([]coderd.ProjectVersion, error) {
res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/projects/%s/%s/versions", organization, project), nil)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, readBodyAsError(res)
}
var projectVersions []coderd.ProjectVersion
return projectVersions, json.NewDecoder(res.Body).Decode(&projectVersions)
}
// CreateProjectVersion inserts a new version for the project.
func (c *Client) CreateProjectVersion(ctx context.Context, organization, project string, request coderd.CreateProjectVersionRequest) (coderd.ProjectVersion, error) {
res, err := c.request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/projects/%s/%s/versions", organization, project), request)
if err != nil {
return coderd.ProjectVersion{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return coderd.ProjectVersion{}, readBodyAsError(res)
}
var projectVersion coderd.ProjectVersion
return projectVersion, json.NewDecoder(res.Body).Decode(&projectVersion)
}

130
codersdk/projects_test.go Normal file
View File

@ -0,0 +1,130 @@
package codersdk_test
import (
"archive/tar"
"bytes"
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/database"
)
func TestProjects(t *testing.T) {
t.Parallel()
t.Run("UnauthenticatedList", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_, err := server.Client.Projects(context.Background(), "")
require.Error(t, err)
})
t.Run("List", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
_, err := server.Client.Projects(context.Background(), "")
require.NoError(t, err)
_, err = server.Client.Projects(context.Background(), user.Organization)
require.NoError(t, err)
})
t.Run("UnauthenticatedCreate", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_, err := server.Client.CreateProject(context.Background(), "", coderd.CreateProjectRequest{})
require.Error(t, err)
})
t.Run("Create", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
_, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
Name: "bananas",
Provisioner: database.ProvisionerTypeTerraform,
})
require.NoError(t, err)
})
t.Run("UnauthenticatedSingle", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_, err := server.Client.Project(context.Background(), "wow", "example")
require.Error(t, err)
})
t.Run("Single", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
_, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
Name: "bananas",
Provisioner: database.ProvisionerTypeTerraform,
})
require.NoError(t, err)
_, err = server.Client.Project(context.Background(), user.Organization, "bananas")
require.NoError(t, err)
})
t.Run("UnauthenticatedVersions", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_, err := server.Client.ProjectVersions(context.Background(), "org", "project")
require.Error(t, err)
})
t.Run("Versions", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
Name: "bananas",
Provisioner: database.ProvisionerTypeTerraform,
})
require.NoError(t, err)
_, err = server.Client.ProjectVersions(context.Background(), user.Organization, project.Name)
require.NoError(t, err)
})
t.Run("CreateVersionUnauthenticated", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_, err := server.Client.CreateProjectVersion(context.Background(), "org", "project", coderd.CreateProjectVersionRequest{
Name: "hello",
StorageMethod: database.ProjectStorageMethodInlineArchive,
StorageSource: []byte{},
})
require.Error(t, err)
})
t.Run("CreateVersion", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
user := server.RandomInitialUser(t)
project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
Name: "bananas",
Provisioner: database.ProvisionerTypeTerraform,
})
require.NoError(t, err)
var buffer bytes.Buffer
writer := tar.NewWriter(&buffer)
err = writer.WriteHeader(&tar.Header{
Name: "file",
Size: 1 << 10,
})
require.NoError(t, err)
_, err = writer.Write(make([]byte, 1<<10))
require.NoError(t, err)
_, err = server.Client.CreateProjectVersion(context.Background(), user.Organization, project.Name, coderd.CreateProjectVersionRequest{
Name: "hello",
StorageMethod: database.ProjectStorageMethodInlineArchive,
StorageSource: buffer.Bytes(),
})
require.NoError(t, err)
})
}

View File

@ -11,7 +11,9 @@ import (
)
func TestUsers(t *testing.T) {
t.Parallel()
t.Run("CreateInitial", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_, err := server.Client.CreateInitialUser(context.Background(), coderd.CreateInitialUserRequest{
Email: "wowie@coder.com",
@ -23,12 +25,14 @@ func TestUsers(t *testing.T) {
})
t.Run("NoUser", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_, err := server.Client.User(context.Background(), "")
require.Error(t, err)
})
t.Run("User", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_ = server.RandomInitialUser(t)
_, err := server.Client.User(context.Background(), "")
@ -36,6 +40,7 @@ func TestUsers(t *testing.T) {
})
t.Run("UserOrganizations", func(t *testing.T) {
t.Parallel()
server := coderdtest.New(t)
_ = server.RandomInitialUser(t)
orgs, err := server.Client.UserOrganizations(context.Background(), "")

View File

@ -47,6 +47,9 @@ func TestUnbiasedModulo32(t *testing.T) {
const mod = 7
dist := [mod]uint32{}
_, err := cryptorand.UnbiasedModulo32(0, mod)
require.NoError(t, err)
for i := 0; i < 1000; i++ {
b := [4]byte{}
_, _ = rand.Read(b[:])

View File

@ -91,6 +91,7 @@ func TestStringCharset(t *testing.T) {
},
}
//nolint:paralleltest
for _, test := range tests {
test := test
t.Run(test.Name, func(t *testing.T) {

View File

@ -3,6 +3,9 @@ package databasefake
import (
"context"
"database/sql"
"strings"
"github.com/google/uuid"
"github.com/coder/coder/database"
)
@ -14,15 +17,25 @@ func New() database.Store {
organizations: make([]database.Organization, 0),
organizationMembers: make([]database.OrganizationMember, 0),
users: make([]database.User, 0),
project: make([]database.Project, 0),
projectHistory: make([]database.ProjectHistory, 0),
projectParameter: make([]database.ProjectParameter, 0),
}
}
// fakeQuerier replicates database functionality to enable quick testing.
type fakeQuerier struct {
// Legacy tables
apiKeys []database.APIKey
organizations []database.Organization
organizationMembers []database.OrganizationMember
users []database.User
// New tables
project []database.Project
projectHistory []database.ProjectHistory
projectParameter []database.ProjectParameter
}
// InTx doesn't rollback data properly for in-memory yet.
@ -89,6 +102,62 @@ func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID string)
return organizations, nil
}
func (q *fakeQuerier) GetProjectByOrganizationAndName(_ context.Context, arg database.GetProjectByOrganizationAndNameParams) (database.Project, error) {
for _, project := range q.project {
if project.OrganizationID != arg.OrganizationID {
continue
}
if !strings.EqualFold(project.Name, arg.Name) {
continue
}
return project, nil
}
return database.Project{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetProjectHistoryByProjectID(_ context.Context, projectID uuid.UUID) ([]database.ProjectHistory, error) {
history := make([]database.ProjectHistory, 0)
for _, projectHistory := range q.projectHistory {
if projectHistory.ProjectID.String() != projectID.String() {
continue
}
history = append(history, projectHistory)
}
if len(history) == 0 {
return nil, sql.ErrNoRows
}
return history, nil
}
func (q *fakeQuerier) GetProjectsByOrganizationIDs(_ context.Context, ids []string) ([]database.Project, error) {
projects := make([]database.Project, 0)
for _, project := range q.project {
for _, id := range ids {
if project.OrganizationID == id {
projects = append(projects, project)
break
}
}
}
if len(projects) == 0 {
return nil, sql.ErrNoRows
}
return projects, nil
}
func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
for _, organizationMember := range q.organizationMembers {
if organizationMember.OrganizationID != arg.OrganizationID {
continue
}
if organizationMember.UserID != arg.UserID {
continue
}
return organizationMember, nil
}
return database.OrganizationMember{}, sql.ErrNoRows
}
func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) {
//nolint:gosimple
key := database.APIKey{
@ -136,6 +205,59 @@ func (q *fakeQuerier) InsertOrganizationMember(_ context.Context, arg database.I
return organizationMember, nil
}
func (q *fakeQuerier) InsertProject(_ context.Context, arg database.InsertProjectParams) (database.Project, error) {
project := database.Project{
ID: arg.ID,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
OrganizationID: arg.OrganizationID,
Name: arg.Name,
Provisioner: arg.Provisioner,
}
q.project = append(q.project, project)
return project, nil
}
func (q *fakeQuerier) InsertProjectHistory(_ context.Context, arg database.InsertProjectHistoryParams) (database.ProjectHistory, error) {
//nolint:gosimple
history := database.ProjectHistory{
ID: arg.ID,
ProjectID: arg.ProjectID,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
Name: arg.Name,
Description: arg.Description,
StorageMethod: arg.StorageMethod,
StorageSource: arg.StorageSource,
ImportJobID: arg.ImportJobID,
}
q.projectHistory = append(q.projectHistory, history)
return history, nil
}
func (q *fakeQuerier) InsertProjectParameter(_ context.Context, arg database.InsertProjectParameterParams) (database.ProjectParameter, error) {
//nolint:gosimple
param := database.ProjectParameter{
ID: arg.ID,
CreatedAt: arg.CreatedAt,
ProjectHistoryID: arg.ProjectHistoryID,
Name: arg.Name,
Description: arg.Description,
DefaultSource: arg.DefaultSource,
AllowOverrideSource: arg.AllowOverrideSource,
DefaultDestination: arg.DefaultDestination,
AllowOverrideDestination: arg.AllowOverrideDestination,
DefaultRefresh: arg.DefaultRefresh,
RedisplayValue: arg.RedisplayValue,
ValidationError: arg.ValidationError,
ValidationCondition: arg.ValidationCondition,
ValidationTypeSystem: arg.ValidationTypeSystem,
ValidationValueType: arg.ValidationValueType,
}
q.projectParameter = append(q.projectParameter, param)
return param, nil
}
func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, error) {
user := database.User{
ID: arg.ID,

View File

@ -6,6 +6,19 @@ CREATE TYPE login_type AS ENUM (
'oidc'
);
CREATE TYPE parameter_type_system AS ENUM (
'hcl'
);
CREATE TYPE project_storage_method AS ENUM (
'inline-archive'
);
CREATE TYPE provisioner_type AS ENUM (
'terraform',
'cdr-basic'
);
CREATE TYPE userstatus AS ENUM (
'active',
'dormant',
@ -57,6 +70,46 @@ CREATE TABLE organizations (
workspace_auto_off boolean DEFAULT false NOT NULL
);
CREATE TABLE project (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
organization_id text NOT NULL,
name character varying(64) NOT NULL,
provisioner provisioner_type NOT NULL,
active_version_id uuid
);
CREATE TABLE project_history (
id uuid NOT NULL,
project_id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
name character varying(64) NOT NULL,
description character varying(1048576) NOT NULL,
storage_method project_storage_method NOT NULL,
storage_source bytea NOT NULL,
import_job_id uuid NOT NULL
);
CREATE TABLE project_parameter (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
project_history_id uuid NOT NULL,
name character varying(64) NOT NULL,
description character varying(8192) DEFAULT ''::character varying NOT NULL,
default_source text,
allow_override_source boolean NOT NULL,
default_destination text,
allow_override_destination boolean NOT NULL,
default_refresh text NOT NULL,
redisplay_value boolean NOT NULL,
validation_error character varying(256) NOT NULL,
validation_condition character varying(512) NOT NULL,
validation_type_system parameter_type_system NOT NULL,
validation_value_type character varying(64) NOT NULL
);
CREATE TABLE users (
id text NOT NULL,
email text NOT NULL,
@ -79,3 +132,27 @@ CREATE TABLE users (
shell text DEFAULT ''::text NOT NULL
);
ALTER TABLE ONLY project_history
ADD CONSTRAINT project_history_id_key UNIQUE (id);
ALTER TABLE ONLY project_history
ADD CONSTRAINT project_history_project_id_name_key UNIQUE (project_id, name);
ALTER TABLE ONLY project
ADD CONSTRAINT project_id_key UNIQUE (id);
ALTER TABLE ONLY project
ADD CONSTRAINT project_organization_id_name_key UNIQUE (organization_id, name);
ALTER TABLE ONLY project_parameter
ADD CONSTRAINT project_parameter_id_key UNIQUE (id);
ALTER TABLE ONLY project_parameter
ADD CONSTRAINT project_parameter_project_history_id_name_key UNIQUE (project_history_id, name);
ALTER TABLE ONLY project_history
ADD CONSTRAINT project_history_project_id_fkey FOREIGN KEY (project_id) REFERENCES project(id);
ALTER TABLE ONLY project_parameter
ADD CONSTRAINT project_parameter_project_history_id_fkey FOREIGN KEY (project_history_id) REFERENCES project_history(id) ON DELETE CASCADE;

View File

@ -20,12 +20,29 @@ func TestMain(m *testing.M) {
func TestMigrate(t *testing.T) {
t.Parallel()
connection, closeFn, err := postgres.Open()
require.NoError(t, err)
defer closeFn()
db, err := sql.Open("postgres", connection)
require.NoError(t, err)
defer db.Close()
err = database.Migrate(db)
require.NoError(t, err)
t.Run("Once", func(t *testing.T) {
t.Parallel()
connection, closeFn, err := postgres.Open()
require.NoError(t, err)
defer closeFn()
db, err := sql.Open("postgres", connection)
require.NoError(t, err)
defer db.Close()
err = database.Migrate(db)
require.NoError(t, err)
})
t.Run("Twice", func(t *testing.T) {
t.Parallel()
connection, closeFn, err := postgres.Open()
require.NoError(t, err)
defer closeFn()
db, err := sql.Open("postgres", connection)
require.NoError(t, err)
defer db.Close()
err = database.Migrate(db)
require.NoError(t, err)
err = database.Migrate(db)
require.NoError(t, err)
})
}

View File

@ -0,0 +1,84 @@
CREATE TYPE provisioner_type AS ENUM ('terraform', 'cdr-basic');
-- Project defines infrastructure that your software project
-- requires for development.
CREATE TABLE project (
id uuid NOT NULL UNIQUE,
created_at timestamptz NOT NULL,
updated_at timestamptz NOT NULL,
-- Projects must be scoped to an organization.
organization_id text NOT NULL,
name varchar(64) NOT NULL,
provisioner provisioner_type NOT NULL,
-- Target's a Project Version to use for Workspaces.
-- If a Workspace doesn't match this version, it will be prompted to rebuild.
active_version_id uuid,
-- Disallow projects to have the same name under
-- the same organization.
UNIQUE(organization_id, name)
);
CREATE TYPE project_storage_method AS ENUM ('inline-archive');
-- Project Versions store Project history. When a Project Version is imported,
-- an "import" job is queued to parse parameters. A Project Version
-- can only be used if the import job succeeds.
CREATE TABLE project_history (
id uuid NOT NULL UNIQUE,
-- This should be indexed.
project_id uuid NOT NULL REFERENCES project (id),
created_at timestamptz NOT NULL,
updated_at timestamptz NOT NULL,
-- Name is generated for ease of differentiation.
-- eg. TheCozyRabbit16
name varchar(64) NOT NULL,
-- Extracted from a README.md on import.
-- Maximum of 1MB.
description varchar(1048576) NOT NULL,
storage_method project_storage_method NOT NULL,
storage_source bytea NOT NULL,
-- The import job for a Project Version. This is used
-- to detect if an import was successful.
import_job_id uuid NOT NULL,
-- Disallow projects to have the same build name
-- multiple times.
UNIQUE(project_id, name)
);
-- Types of parameters the automator supports.
CREATE TYPE parameter_type_system AS ENUM ('hcl');
-- Stores project version parameters parsed on import.
-- No secrets are stored here.
--
-- All parameter validation occurs server-side to process
-- complex validations.
--
-- Parameter types, description, and validation will produce
-- a UI for users to enter values.
-- Needs to be made consistent with the examples below.
CREATE TABLE project_parameter (
id uuid NOT NULL UNIQUE,
created_at timestamptz NOT NULL,
project_history_id uuid NOT NULL REFERENCES project_history(id) ON DELETE CASCADE,
name varchar(64) NOT NULL,
-- 8KB limit
description varchar(8192) NOT NULL DEFAULT '',
-- eg. data://inlinevalue
default_source text,
-- Allows the user to override the source.
allow_override_source boolean NOT null,
-- eg. env://SOME_VARIABLE, tfvars://example
default_destination text,
-- Allows the user to override the destination.
allow_override_destination boolean NOT null,
default_refresh text NOT NULL,
-- Whether the consumer can view the source and destinations.
redisplay_value boolean NOT null,
-- This error would appear in the UI if the condition is not met.
validation_error varchar(256) NOT NULL,
validation_condition varchar(512) NOT NULL,
validation_type_system parameter_type_system NOT NULL,
validation_value_type varchar(64) NOT NULL,
UNIQUE(project_history_id, name)
);

View File

@ -3,9 +3,12 @@
package database
import (
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
)
type LoginType string
@ -28,6 +31,61 @@ func (e *LoginType) Scan(src interface{}) error {
return nil
}
type ParameterTypeSystem string
const (
ParameterTypeSystemHCL ParameterTypeSystem = "hcl"
)
func (e *ParameterTypeSystem) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ParameterTypeSystem(s)
case string:
*e = ParameterTypeSystem(s)
default:
return fmt.Errorf("unsupported scan type for ParameterTypeSystem: %T", src)
}
return nil
}
type ProjectStorageMethod string
const (
ProjectStorageMethodInlineArchive ProjectStorageMethod = "inline-archive"
)
func (e *ProjectStorageMethod) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ProjectStorageMethod(s)
case string:
*e = ProjectStorageMethod(s)
default:
return fmt.Errorf("unsupported scan type for ProjectStorageMethod: %T", src)
}
return nil
}
type ProvisionerType string
const (
ProvisionerTypeTerraform ProvisionerType = "terraform"
ProvisionerTypeCdrBasic ProvisionerType = "cdr-basic"
)
func (e *ProvisionerType) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ProvisionerType(s)
case string:
*e = ProvisionerType(s)
default:
return fmt.Errorf("unsupported scan type for ProvisionerType: %T", src)
}
return nil
}
type UserStatus string
const (
@ -93,6 +151,46 @@ type OrganizationMember struct {
Roles []string `db:"roles" json:"roles"`
}
type Project struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OrganizationID string `db:"organization_id" json:"organization_id"`
Name string `db:"name" json:"name"`
Provisioner ProvisionerType `db:"provisioner" json:"provisioner"`
ActiveVersionID uuid.NullUUID `db:"active_version_id" json:"active_version_id"`
}
type ProjectHistory struct {
ID uuid.UUID `db:"id" json:"id"`
ProjectID uuid.UUID `db:"project_id" json:"project_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
StorageMethod ProjectStorageMethod `db:"storage_method" json:"storage_method"`
StorageSource []byte `db:"storage_source" json:"storage_source"`
ImportJobID uuid.UUID `db:"import_job_id" json:"import_job_id"`
}
type ProjectParameter struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
ProjectHistoryID uuid.UUID `db:"project_history_id" json:"project_history_id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
DefaultSource sql.NullString `db:"default_source" json:"default_source"`
AllowOverrideSource bool `db:"allow_override_source" json:"allow_override_source"`
DefaultDestination sql.NullString `db:"default_destination" json:"default_destination"`
AllowOverrideDestination bool `db:"allow_override_destination" json:"allow_override_destination"`
DefaultRefresh string `db:"default_refresh" json:"default_refresh"`
RedisplayValue bool `db:"redisplay_value" json:"redisplay_value"`
ValidationError string `db:"validation_error" json:"validation_error"`
ValidationCondition string `db:"validation_condition" json:"validation_condition"`
ValidationTypeSystem ParameterTypeSystem `db:"validation_type_system" json:"validation_type_system"`
ValidationValueType string `db:"validation_value_type" json:"validation_value_type"`
}
type User struct {
ID string `db:"id" json:"id"`
Email string `db:"email" json:"email"`

View File

@ -18,6 +18,7 @@ func TestPubsub(t *testing.T) {
t.Parallel()
t.Run("Postgres", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
@ -45,4 +46,20 @@ func TestPubsub(t *testing.T) {
message := <-messageChannel
assert.Equal(t, string(message), data)
})
t.Run("PostgresCloseCancel", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
connectionURL, close, err := postgres.Open()
require.NoError(t, err)
defer close()
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()
cancelFunc()
})
}

View File

@ -4,18 +4,27 @@ package database
import (
"context"
"github.com/google/uuid"
)
type querier interface {
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
GetOrganizationByName(ctx context.Context, name string) (Organization, error)
GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error)
GetOrganizationsByUserID(ctx context.Context, userID string) ([]Organization, error)
GetProjectByOrganizationAndName(ctx context.Context, arg GetProjectByOrganizationAndNameParams) (Project, error)
GetProjectHistoryByProjectID(ctx context.Context, projectID uuid.UUID) ([]ProjectHistory, error)
GetProjectsByOrganizationIDs(ctx context.Context, ids []string) ([]Project, error)
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
GetUserByID(ctx context.Context, id string) (User, error)
GetUserCount(ctx context.Context) (int64, error)
InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error)
InsertOrganization(ctx context.Context, arg InsertOrganizationParams) (Organization, error)
InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error)
InsertProject(ctx context.Context, arg InsertProjectParams) (Project, error)
InsertProjectHistory(ctx context.Context, arg InsertProjectHistoryParams) (ProjectHistory, error)
InsertProjectParameter(ctx context.Context, arg InsertProjectParameterParams) (ProjectParameter, error)
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
}

View File

@ -42,12 +42,67 @@ FROM
users;
-- name: GetOrganizationByName :one
SELECT * FROM organizations WHERE name = $1 LIMIT 1;
SELECT
*
FROM
organizations
WHERE
name = $1
LIMIT
1;
-- name: GetOrganizationsByUserID :many
SELECT * FROM organizations WHERE id = (
SELECT organization_id FROM organization_members WHERE user_id = $1
);
SELECT
*
FROM
organizations
WHERE
id = (
SELECT
organization_id
FROM
organization_members
WHERE
user_id = $1
);
-- name: GetOrganizationMemberByUserID :one
SELECT
*
FROM
organization_members
WHERE
organization_id = $1
AND user_id = $2
LIMIT
1;
-- name: GetProjectByOrganizationAndName :one
SELECT
*
FROM
project
WHERE
organization_id = $1
AND name = $2
LIMIT
1;
-- name: GetProjectsByOrganizationIDs :many
SELECT
*
FROM
project
WHERE
organization_id = ANY(@ids :: text [ ]);
-- name: GetProjectHistoryByProjectID :many
SELECT
*
FROM
project_history
WHERE
project_id = $1;
-- name: InsertAPIKey :one
INSERT INTO
@ -88,10 +143,89 @@ VALUES
) RETURNING *;
-- name: InsertOrganization :one
INSERT INTO organizations (id, name, description, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) RETURNING *;
INSERT INTO
organizations (id, name, description, created_at, updated_at)
VALUES
($1, $2, $3, $4, $5) RETURNING *;
-- name: InsertOrganizationMember :one
INSERT INTO organization_members (organization_id, user_id, created_at, updated_at, roles) VALUES ($1, $2, $3, $4, $5) RETURNING *;
INSERT INTO
organization_members (
organization_id,
user_id,
created_at,
updated_at,
roles
)
VALUES
($1, $2, $3, $4, $5) RETURNING *;
-- name: InsertProject :one
INSERT INTO
project (
id,
created_at,
updated_at,
organization_id,
name,
provisioner
)
VALUES
($1, $2, $3, $4, $5, $6) RETURNING *;
-- name: InsertProjectHistory :one
INSERT INTO
project_history (
id,
project_id,
created_at,
updated_at,
name,
description,
storage_method,
storage_source,
import_job_id
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *;
-- name: InsertProjectParameter :one
INSERT INTO
project_parameter (
id,
created_at,
project_history_id,
name,
description,
default_source,
allow_override_source,
default_destination,
allow_override_destination,
default_refresh,
redisplay_value,
validation_error,
validation_condition,
validation_type_system,
validation_value_type
)
VALUES
(
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8,
$9,
$10,
$11,
$12,
$13,
$14,
$15
) RETURNING *;
-- name: InsertUser :one
INSERT INTO

View File

@ -5,8 +5,10 @@ package database
import (
"context"
"database/sql"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
)
@ -45,7 +47,14 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro
}
const getOrganizationByName = `-- name: GetOrganizationByName :one
SELECT id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off FROM organizations WHERE name = $1 LIMIT 1
SELECT
id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off
FROM
organizations
WHERE
name = $1
LIMIT
1
`
func (q *sqlQuerier) GetOrganizationByName(ctx context.Context, name string) (Organization, error) {
@ -66,10 +75,50 @@ func (q *sqlQuerier) GetOrganizationByName(ctx context.Context, name string) (Or
return i, err
}
const getOrganizationMemberByUserID = `-- name: GetOrganizationMemberByUserID :one
SELECT
organization_id, user_id, created_at, updated_at, roles
FROM
organization_members
WHERE
organization_id = $1
AND user_id = $2
LIMIT
1
`
type GetOrganizationMemberByUserIDParams struct {
OrganizationID string `db:"organization_id" json:"organization_id"`
UserID string `db:"user_id" json:"user_id"`
}
func (q *sqlQuerier) GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error) {
row := q.db.QueryRowContext(ctx, getOrganizationMemberByUserID, arg.OrganizationID, arg.UserID)
var i OrganizationMember
err := row.Scan(
&i.OrganizationID,
&i.UserID,
&i.CreatedAt,
&i.UpdatedAt,
pq.Array(&i.Roles),
)
return i, err
}
const getOrganizationsByUserID = `-- name: GetOrganizationsByUserID :many
SELECT id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off FROM organizations WHERE id = (
SELECT organization_id FROM organization_members WHERE user_id = $1
)
SELECT
id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off
FROM
organizations
WHERE
id = (
SELECT
organization_id
FROM
organization_members
WHERE
user_id = $1
)
`
func (q *sqlQuerier) GetOrganizationsByUserID(ctx context.Context, userID string) ([]Organization, error) {
@ -106,6 +155,120 @@ func (q *sqlQuerier) GetOrganizationsByUserID(ctx context.Context, userID string
return items, nil
}
const getProjectByOrganizationAndName = `-- name: GetProjectByOrganizationAndName :one
SELECT
id, created_at, updated_at, organization_id, name, provisioner, active_version_id
FROM
project
WHERE
organization_id = $1
AND name = $2
LIMIT
1
`
type GetProjectByOrganizationAndNameParams struct {
OrganizationID string `db:"organization_id" json:"organization_id"`
Name string `db:"name" json:"name"`
}
func (q *sqlQuerier) GetProjectByOrganizationAndName(ctx context.Context, arg GetProjectByOrganizationAndNameParams) (Project, error) {
row := q.db.QueryRowContext(ctx, getProjectByOrganizationAndName, arg.OrganizationID, arg.Name)
var i Project
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.OrganizationID,
&i.Name,
&i.Provisioner,
&i.ActiveVersionID,
)
return i, err
}
const getProjectHistoryByProjectID = `-- name: GetProjectHistoryByProjectID :many
SELECT
id, project_id, created_at, updated_at, name, description, storage_method, storage_source, import_job_id
FROM
project_history
WHERE
project_id = $1
`
func (q *sqlQuerier) GetProjectHistoryByProjectID(ctx context.Context, projectID uuid.UUID) ([]ProjectHistory, error) {
rows, err := q.db.QueryContext(ctx, getProjectHistoryByProjectID, projectID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ProjectHistory
for rows.Next() {
var i ProjectHistory
if err := rows.Scan(
&i.ID,
&i.ProjectID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Name,
&i.Description,
&i.StorageMethod,
&i.StorageSource,
&i.ImportJobID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getProjectsByOrganizationIDs = `-- name: GetProjectsByOrganizationIDs :many
SELECT
id, created_at, updated_at, organization_id, name, provisioner, active_version_id
FROM
project
WHERE
organization_id = ANY($1 :: text [ ])
`
func (q *sqlQuerier) GetProjectsByOrganizationIDs(ctx context.Context, ids []string) ([]Project, error) {
rows, err := q.db.QueryContext(ctx, getProjectsByOrganizationIDs, pq.Array(ids))
if err != nil {
return nil, err
}
defer rows.Close()
var items []Project
for rows.Next() {
var i Project
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.OrganizationID,
&i.Name,
&i.Provisioner,
&i.ActiveVersionID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one
SELECT
id, email, name, revoked, login_type, hashed_password, created_at, updated_at, temporary_password, avatar_hash, ssh_key_regenerated_at, username, dotfiles_git_uri, roles, status, relatime, gpg_key_regenerated_at, _decomissioned, shell
@ -299,7 +462,10 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (
}
const insertOrganization = `-- name: InsertOrganization :one
INSERT INTO organizations (id, name, description, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) RETURNING id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off
INSERT INTO
organizations (id, name, description, created_at, updated_at)
VALUES
($1, $2, $3, $4, $5) RETURNING id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off
`
type InsertOrganizationParams struct {
@ -335,7 +501,16 @@ func (q *sqlQuerier) InsertOrganization(ctx context.Context, arg InsertOrganizat
}
const insertOrganizationMember = `-- name: InsertOrganizationMember :one
INSERT INTO organization_members (organization_id, user_id, created_at, updated_at, roles) VALUES ($1, $2, $3, $4, $5) RETURNING organization_id, user_id, created_at, updated_at, roles
INSERT INTO
organization_members (
organization_id,
user_id,
created_at,
updated_at,
roles
)
VALUES
($1, $2, $3, $4, $5) RETURNING organization_id, user_id, created_at, updated_at, roles
`
type InsertOrganizationMemberParams struct {
@ -365,6 +540,203 @@ func (q *sqlQuerier) InsertOrganizationMember(ctx context.Context, arg InsertOrg
return i, err
}
const insertProject = `-- name: InsertProject :one
INSERT INTO
project (
id,
created_at,
updated_at,
organization_id,
name,
provisioner
)
VALUES
($1, $2, $3, $4, $5, $6) RETURNING id, created_at, updated_at, organization_id, name, provisioner, active_version_id
`
type InsertProjectParams struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OrganizationID string `db:"organization_id" json:"organization_id"`
Name string `db:"name" json:"name"`
Provisioner ProvisionerType `db:"provisioner" json:"provisioner"`
}
func (q *sqlQuerier) InsertProject(ctx context.Context, arg InsertProjectParams) (Project, error) {
row := q.db.QueryRowContext(ctx, insertProject,
arg.ID,
arg.CreatedAt,
arg.UpdatedAt,
arg.OrganizationID,
arg.Name,
arg.Provisioner,
)
var i Project
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.OrganizationID,
&i.Name,
&i.Provisioner,
&i.ActiveVersionID,
)
return i, err
}
const insertProjectHistory = `-- name: InsertProjectHistory :one
INSERT INTO
project_history (
id,
project_id,
created_at,
updated_at,
name,
description,
storage_method,
storage_source,
import_job_id
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, project_id, created_at, updated_at, name, description, storage_method, storage_source, import_job_id
`
type InsertProjectHistoryParams struct {
ID uuid.UUID `db:"id" json:"id"`
ProjectID uuid.UUID `db:"project_id" json:"project_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
StorageMethod ProjectStorageMethod `db:"storage_method" json:"storage_method"`
StorageSource []byte `db:"storage_source" json:"storage_source"`
ImportJobID uuid.UUID `db:"import_job_id" json:"import_job_id"`
}
func (q *sqlQuerier) InsertProjectHistory(ctx context.Context, arg InsertProjectHistoryParams) (ProjectHistory, error) {
row := q.db.QueryRowContext(ctx, insertProjectHistory,
arg.ID,
arg.ProjectID,
arg.CreatedAt,
arg.UpdatedAt,
arg.Name,
arg.Description,
arg.StorageMethod,
arg.StorageSource,
arg.ImportJobID,
)
var i ProjectHistory
err := row.Scan(
&i.ID,
&i.ProjectID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Name,
&i.Description,
&i.StorageMethod,
&i.StorageSource,
&i.ImportJobID,
)
return i, err
}
const insertProjectParameter = `-- name: InsertProjectParameter :one
INSERT INTO
project_parameter (
id,
created_at,
project_history_id,
name,
description,
default_source,
allow_override_source,
default_destination,
allow_override_destination,
default_refresh,
redisplay_value,
validation_error,
validation_condition,
validation_type_system,
validation_value_type
)
VALUES
(
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8,
$9,
$10,
$11,
$12,
$13,
$14,
$15
) RETURNING id, created_at, project_history_id, name, description, default_source, allow_override_source, default_destination, allow_override_destination, default_refresh, redisplay_value, validation_error, validation_condition, validation_type_system, validation_value_type
`
type InsertProjectParameterParams struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
ProjectHistoryID uuid.UUID `db:"project_history_id" json:"project_history_id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
DefaultSource sql.NullString `db:"default_source" json:"default_source"`
AllowOverrideSource bool `db:"allow_override_source" json:"allow_override_source"`
DefaultDestination sql.NullString `db:"default_destination" json:"default_destination"`
AllowOverrideDestination bool `db:"allow_override_destination" json:"allow_override_destination"`
DefaultRefresh string `db:"default_refresh" json:"default_refresh"`
RedisplayValue bool `db:"redisplay_value" json:"redisplay_value"`
ValidationError string `db:"validation_error" json:"validation_error"`
ValidationCondition string `db:"validation_condition" json:"validation_condition"`
ValidationTypeSystem ParameterTypeSystem `db:"validation_type_system" json:"validation_type_system"`
ValidationValueType string `db:"validation_value_type" json:"validation_value_type"`
}
func (q *sqlQuerier) InsertProjectParameter(ctx context.Context, arg InsertProjectParameterParams) (ProjectParameter, error) {
row := q.db.QueryRowContext(ctx, insertProjectParameter,
arg.ID,
arg.CreatedAt,
arg.ProjectHistoryID,
arg.Name,
arg.Description,
arg.DefaultSource,
arg.AllowOverrideSource,
arg.DefaultDestination,
arg.AllowOverrideDestination,
arg.DefaultRefresh,
arg.RedisplayValue,
arg.ValidationError,
arg.ValidationCondition,
arg.ValidationTypeSystem,
arg.ValidationValueType,
)
var i ProjectParameter
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.ProjectHistoryID,
&i.Name,
&i.Description,
&i.DefaultSource,
&i.AllowOverrideSource,
&i.DefaultDestination,
&i.AllowOverrideDestination,
&i.DefaultRefresh,
&i.RedisplayValue,
&i.ValidationError,
&i.ValidationCondition,
&i.ValidationTypeSystem,
&i.ValidationValueType,
)
return i, err
}
const insertUser = `-- name: InsertUser :one
INSERT INTO
users (

View File

@ -25,4 +25,5 @@ rename:
oidc_expiry: OIDCExpiry
oidc_id_token: OIDCIDToken
oidc_refresh_token: OIDCRefreshToken
parameter_type_system_hcl: ParameterTypeSystemHCL
userstatus: UserStatus

1
go.mod
View File

@ -18,6 +18,7 @@ require (
github.com/hashicorp/terraform-exec v0.15.0
github.com/justinas/nosurf v1.1.1
github.com/lib/pq v1.10.4
github.com/moby/moby v20.10.12+incompatible
github.com/ory/dockertest/v3 v3.8.1
github.com/pion/datachannel v1.5.2
github.com/pion/logging v0.2.2

2
go.sum
View File

@ -903,6 +903,8 @@ github.com/mitchellh/osext v0.0.0-20151018003038-5e2d6d41470f/go.mod h1:OkQIRizQ
github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
github.com/moby/locker v1.0.1/go.mod h1:S7SDdo5zpBK84bzzVlKr2V0hz+7x9hWbYC/kq7oQppc=
github.com/moby/moby v20.10.12+incompatible h1:MJVrdG0tIQqVJQBTdtooPuZQFIgski5pYTXlcW8ToE0=
github.com/moby/moby v20.10.12+incompatible/go.mod h1:fDXVQ6+S340veQPv35CzDahGBmHsiclFwfEygB/TWMc=
github.com/moby/sys/mountinfo v0.4.0/go.mod h1:rEr8tzG/lsIZHBtN/JjGG+LMYx9eXgW2JI+6q0qou+A=
github.com/moby/sys/mountinfo v0.4.1/go.mod h1:rEr8tzG/lsIZHBtN/JjGG+LMYx9eXgW2JI+6q0qou+A=
github.com/moby/sys/symlink v0.1.0/go.mod h1:GGDODQmbFOjFsXvfLVn3+ZRxkch54RkSiGqsZeMYowQ=

View File

@ -13,7 +13,9 @@ import (
)
func TestWrite(t *testing.T) {
t.Parallel()
t.Run("NoErrors", func(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
httpapi.Write(rw, http.StatusOK, httpapi.Response{
Message: "wow",
@ -27,7 +29,9 @@ func TestWrite(t *testing.T) {
}
func TestRead(t *testing.T) {
t.Parallel()
t.Run("EmptyStruct", func(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", bytes.NewBufferString("{}"))
v := struct{}{}
@ -35,6 +39,7 @@ func TestRead(t *testing.T) {
})
t.Run("NoBody", func(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", nil)
var v json.RawMessage
@ -42,6 +47,7 @@ func TestRead(t *testing.T) {
})
t.Run("Validate", func(t *testing.T) {
t.Parallel()
type toValidate struct {
Value string `json:"value" validate:"required"`
}
@ -54,6 +60,7 @@ func TestRead(t *testing.T) {
})
t.Run("ValidateFailure", func(t *testing.T) {
t.Parallel()
type toValidate struct {
Value string `json:"value" validate:"required"`
}
@ -72,6 +79,7 @@ func TestRead(t *testing.T) {
}
func TestReadUsername(t *testing.T) {
t.Parallel()
// Tests whether usernames are valid or not.
testCases := []struct {
Username string
@ -121,7 +129,9 @@ func TestReadUsername(t *testing.T) {
Username string `json:"username" validate:"username"`
}
for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.Username, func(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
data, err := json.Marshal(toValidate{testCase.Username})
require.NoError(t, err)

View File

@ -35,6 +35,7 @@ func TestAPIKey(t *testing.T) {
})
t.Run("NoCookie", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
r = httptest.NewRequest("GET", "/", nil)
@ -47,6 +48,7 @@ func TestAPIKey(t *testing.T) {
})
t.Run("InvalidFormat", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
r = httptest.NewRequest("GET", "/", nil)
@ -64,6 +66,7 @@ func TestAPIKey(t *testing.T) {
})
t.Run("InvalidIDLength", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
r = httptest.NewRequest("GET", "/", nil)
@ -81,6 +84,7 @@ func TestAPIKey(t *testing.T) {
})
t.Run("InvalidSecretLength", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
r = httptest.NewRequest("GET", "/", nil)
@ -98,6 +102,7 @@ func TestAPIKey(t *testing.T) {
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
@ -116,6 +121,7 @@ func TestAPIKey(t *testing.T) {
})
t.Run("InvalidSecret", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
@ -141,6 +147,7 @@ func TestAPIKey(t *testing.T) {
})
t.Run("Expired", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
@ -165,6 +172,7 @@ func TestAPIKey(t *testing.T) {
})
t.Run("Valid", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
@ -203,6 +211,7 @@ func TestAPIKey(t *testing.T) {
})
t.Run("ValidUpdateLastUsed", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
@ -235,6 +244,7 @@ func TestAPIKey(t *testing.T) {
})
t.Run("ValidUpdateExpiry", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
@ -267,6 +277,7 @@ func TestAPIKey(t *testing.T) {
})
t.Run("OIDCNotExpired", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
@ -300,6 +311,7 @@ func TestAPIKey(t *testing.T) {
})
t.Run("OIDCRefresh", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()

View File

@ -0,0 +1,86 @@
package httpmw
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"github.com/go-chi/chi"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
)
type organizationParamContextKey struct{}
type organizationMemberParamContextKey struct{}
// OrganizationParam returns the organization from the ExtractOrganizationParam handler.
func OrganizationParam(r *http.Request) database.Organization {
organization, ok := r.Context().Value(organizationParamContextKey{}).(database.Organization)
if !ok {
panic("developer error: organization param middleware not provided")
}
return organization
}
// OrganizationMemberParam returns the organization membership that allowed the query
// from the ExtractOrganizationParam handler.
func OrganizationMemberParam(r *http.Request) database.OrganizationMember {
organizationMember, ok := r.Context().Value(organizationMemberParamContextKey{}).(database.OrganizationMember)
if !ok {
panic("developer error: organization param middleware not provided")
}
return organizationMember
}
// ExtractOrganizationParam grabs an organization and user membership from the "organization" URL parameter.
// This middleware requires the API key middleware higher in the call stack for authentication.
func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
apiKey := APIKey(r)
organizationName := chi.URLParam(r, "organization")
if organizationName == "" {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "organization name must be provided",
})
return
}
organization, err := db.GetOrganizationByName(r.Context(), organizationName)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: fmt.Sprintf("organization %q does not exist", organizationName),
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get organization: %s", err.Error()),
})
return
}
organizationMember, err := db.GetOrganizationMemberByUserID(r.Context(), database.GetOrganizationMemberByUserIDParams{
OrganizationID: organization.ID,
UserID: apiKey.UserID,
})
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "not a member of the organization",
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get organization member: %s", err.Error()),
})
return
}
ctx := context.WithValue(r.Context(), organizationParamContextKey{}, organization)
ctx = context.WithValue(ctx, organizationMemberParamContextKey{}, organizationMember)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,165 @@
package httpmw_test
import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/database"
"github.com/coder/coder/database/databasefake"
"github.com/coder/coder/httpmw"
)
func TestOrganizationParam(t *testing.T) {
t.Parallel()
setupAuthentication := func(db database.Store) (*http.Request, database.User) {
var (
id, secret = randomAPIKeyParts()
r = httptest.NewRequest("GET", "/", nil)
hashed = sha256.Sum256([]byte(secret))
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
userID, err := cryptorand.String(16)
require.NoError(t, err)
username, err := cryptorand.String(8)
require.NoError(t, err)
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: user.ID,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
return r, user
}
t.Run("None", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
rw = httptest.NewRecorder()
r, _ = setupAuthentication(db)
rtr = chi.NewRouter()
)
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", nil)
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
rw = httptest.NewRecorder()
r, _ = setupAuthentication(db)
rtr = chi.NewRouter()
)
chi.RouteContext(r.Context()).URLParams.Add("organization", "nothin")
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", nil)
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusNotFound, res.StatusCode)
})
t.Run("NotInOrganization", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
rw = httptest.NewRecorder()
r, _ = setupAuthentication(db)
rtr = chi.NewRouter()
)
organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
ID: uuid.NewString(),
Name: "test",
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
chi.RouteContext(r.Context()).URLParams.Add("organization", organization.Name)
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", nil)
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("Success", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
rw = httptest.NewRecorder()
r, user = setupAuthentication(db)
rtr = chi.NewRouter()
)
organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
ID: uuid.NewString(),
Name: "test",
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
OrganizationID: organization.ID,
UserID: user.ID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
chi.RouteContext(r.Context()).URLParams.Add("organization", organization.Name)
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.OrganizationParam(r)
_ = httpmw.OrganizationMemberParam(r)
rw.WriteHeader(http.StatusOK)
})
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
})
}

60
httpmw/projectparam.go Normal file
View File

@ -0,0 +1,60 @@
package httpmw
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"github.com/go-chi/chi"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
)
type projectParamContextKey struct{}
// ProjectParam returns the project from the ExtractProjectParameter handler.
func ProjectParam(r *http.Request) database.Project {
project, ok := r.Context().Value(projectParamContextKey{}).(database.Project)
if !ok {
panic("developer error: project param middleware not provided")
}
return project
}
// ExtractProjectParameter grabs a project from the "project" URL parameter.
func ExtractProjectParameter(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
organization := OrganizationParam(r)
projectName := chi.URLParam(r, "project")
if projectName == "" {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "project name must be provided",
})
return
}
project, err := db.GetProjectByOrganizationAndName(r.Context(), database.GetProjectByOrganizationAndNameParams{
OrganizationID: organization.ID,
Name: projectName,
})
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: fmt.Sprintf("project %q does not exist", projectName),
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get project: %s", err.Error()),
})
return
}
ctx := context.WithValue(r.Context(), projectParamContextKey{}, project)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

151
httpmw/projectparam_test.go Normal file
View File

@ -0,0 +1,151 @@
package httpmw_test
import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/database"
"github.com/coder/coder/database/databasefake"
"github.com/coder/coder/httpmw"
)
func TestProjectParam(t *testing.T) {
t.Parallel()
setupAuthentication := func(db database.Store) (*http.Request, database.Organization) {
var (
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
)
r := httptest.NewRequest("GET", "/", nil)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
userID, err := cryptorand.String(16)
require.NoError(t, err)
username, err := cryptorand.String(8)
require.NoError(t, err)
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: user.ID,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
orgID, err := cryptorand.String(16)
require.NoError(t, err)
organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
ID: orgID,
Name: "banana",
Description: "wowie",
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
OrganizationID: orgID,
UserID: user.ID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
ctx := chi.NewRouteContext()
ctx.URLParams.Add("organization", organization.Name)
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
return r, organization
}
t.Run("None", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractOrganizationParam(db),
httpmw.ExtractProjectParameter(db),
)
rtr.Get("/", nil)
r, _ := setupAuthentication(db)
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractOrganizationParam(db),
httpmw.ExtractProjectParameter(db),
)
rtr.Get("/", nil)
r, _ := setupAuthentication(db)
chi.RouteContext(r.Context()).URLParams.Add("project", "nothin")
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusNotFound, res.StatusCode)
})
t.Run("Project", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractOrganizationParam(db),
httpmw.ExtractProjectParameter(db),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.ProjectParam(r)
rw.WriteHeader(http.StatusOK)
})
r, org := setupAuthentication(db)
project, err := db.InsertProject(context.Background(), database.InsertProjectParams{
ID: uuid.New(),
OrganizationID: org.ID,
Name: "moo",
})
require.NoError(t, err)
chi.RouteContext(r.Context()).URLParams.Add("project", project.Name)
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
})
}

View File

@ -18,6 +18,7 @@ import (
)
func TestUserParam(t *testing.T) {
t.Parallel()
setup := func(t *testing.T) (database.Store, *httptest.ResponseRecorder, *http.Request) {
var (
db = databasefake.New()
@ -47,6 +48,7 @@ func TestUserParam(t *testing.T) {
}
t.Run("None", func(t *testing.T) {
t.Parallel()
db, rw, r := setup(t)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
@ -62,6 +64,7 @@ func TestUserParam(t *testing.T) {
})
t.Run("NotMe", func(t *testing.T) {
t.Parallel()
db, rw, r := setup(t)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
@ -80,6 +83,7 @@ func TestUserParam(t *testing.T) {
})
t.Run("Me", func(t *testing.T) {
t.Parallel()
db, rw, r := setup(t)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {

View File

@ -211,6 +211,7 @@ func TestConn(t *testing.T) {
})
t.Run("CloseWithError", func(t *testing.T) {
t.Parallel()
conn, err := peer.Client([]webrtc.ICEServer{}, nil)
require.NoError(t, err)
expectedErr := errors.New("wow")

View File

@ -14,9 +14,11 @@ import (
)
func TestListen(t *testing.T) {
t.Parallel()
// Ensures connections blocked on Accept() are
// closed if the listener is.
t.Run("NoAcceptClosed", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client, server := provisionersdk.TransportPipe()
defer client.Close()
@ -37,6 +39,7 @@ func TestListen(t *testing.T) {
// Ensures Accept() properly exits when Close() is called.
t.Run("AcceptClosed", func(t *testing.T) {
t.Parallel()
client, server := provisionersdk.TransportPipe()
defer client.Close()
defer server.Close()

View File

@ -18,7 +18,9 @@ func TestMain(m *testing.M) {
}
func TestProvisionerSDK(t *testing.T) {
t.Parallel()
t.Run("Serve", func(t *testing.T) {
t.Parallel()
client, server := provisionersdk.TransportPipe()
defer client.Close()
defer server.Close()
@ -37,6 +39,7 @@ func TestProvisionerSDK(t *testing.T) {
require.Equal(t, drpcerr.Unimplemented, int(drpcerr.Code(err)))
})
t.Run("ServeClosedPipe", func(t *testing.T) {
t.Parallel()
client, server := provisionersdk.TransportPipe()
_ = client.Close()
_ = server.Close()