mirror of https://github.com/coder/coder.git
feat: Add project API endpoints (#51)
* feat: Add project models * Add project query functions * Add organization parameter query * Add project URL parameter parse * Add project create and list endpoints * Add test for organization provided * Remove unimplemented routes * Decrease conn timeout * Add test for UnbiasedModulo32 * Fix expected value * Add single user endpoint * Add query for project versions * Fix linting errors * Add comments * Add test for invalid archive * Check unauthenticated endpoints * Add check if no change happened * Ensure context close ends listener * Fix parallel test run * Test empty * Fix organization param comment
This commit is contained in:
parent
52e50fc9ca
commit
a44056cff5
|
@ -234,6 +234,7 @@ linters:
|
|||
- misspell
|
||||
- nilnil
|
||||
- noctx
|
||||
- paralleltest
|
||||
- revive
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
||||
|
|
5
Makefile
5
Makefile
|
@ -11,7 +11,7 @@ database/dump.sql: $(wildcard database/migrations/*.sql)
|
|||
go run database/dump/main.go
|
||||
|
||||
# Generates Go code for querying the database.
|
||||
database/generate: database/dump.sql database/query.sql
|
||||
database/generate: fmt/sql database/dump.sql database/query.sql
|
||||
cd database && sqlc generate && rm db_tmp.go
|
||||
cd database && gofmt -w -r 'Querier -> querier' *.go
|
||||
cd database && gofmt -w -r 'Queries -> sqlQuerier' *.go
|
||||
|
@ -27,12 +27,13 @@ else
|
|||
endif
|
||||
.PHONY: fmt/prettier
|
||||
|
||||
fmt/sql:
|
||||
fmt/sql: ./database/query.sql
|
||||
npx sql-formatter \
|
||||
--language postgresql \
|
||||
--lines-between-queries 2 \
|
||||
./database/query.sql \
|
||||
--output ./database/query.sql
|
||||
sed -i 's/@ /@/g' ./database/query.sql
|
||||
|
||||
fmt: fmt/prettier fmt/sql
|
||||
.PHONY: fmt
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
)
|
||||
|
||||
func TestRoot(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
go cancelFunc()
|
||||
err := cmd.Root().ExecuteContext(ctx)
|
||||
|
|
|
@ -20,6 +20,9 @@ type Options struct {
|
|||
|
||||
// New constructs the Coder API into an HTTP handler.
|
||||
func New(options *Options) http.Handler {
|
||||
projects := &projects{
|
||||
Database: options.Database,
|
||||
}
|
||||
users := &users{
|
||||
Database: options.Database,
|
||||
}
|
||||
|
@ -44,6 +47,25 @@ func New(options *Options) http.Handler {
|
|||
r.Get("/{user}/organizations", users.userOrganizations)
|
||||
})
|
||||
})
|
||||
r.Route("/projects", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
)
|
||||
r.Get("/", projects.allProjects)
|
||||
r.Route("/{organization}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractOrganizationParam(options.Database))
|
||||
r.Get("/", projects.allProjectsForOrganization)
|
||||
r.Post("/", projects.createProject)
|
||||
r.Route("/{project}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractProjectParameter(options.Database))
|
||||
r.Get("/", projects.project)
|
||||
r.Route("/versions", func(r chi.Router) {
|
||||
r.Get("/", projects.projectVersions)
|
||||
r.Post("/", projects.createProjectVersion)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
r.NotFound(site.Handler().ServeHTTP)
|
||||
return r
|
||||
|
|
|
@ -13,6 +13,7 @@ func TestMain(m *testing.M) {
|
|||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
_ = server.RandomInitialUser(t)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,229 @@
|
|||
package coderd
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/render"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
"github.com/coder/coder/httpmw"
|
||||
)
|
||||
|
||||
// Project is the JSON representation of a Coder project.
|
||||
// This type matches the database object for now, but is
|
||||
// abstracted for ease of change later on.
|
||||
type Project database.Project
|
||||
|
||||
// ProjectVersion is the JSON representation of a Coder project version.
|
||||
type ProjectVersion struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ProjectID uuid.UUID `json:"project_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Name string `json:"name"`
|
||||
StorageMethod database.ProjectStorageMethod `json:"storage_method"`
|
||||
}
|
||||
|
||||
// CreateProjectRequest enables callers to create a new Project.
|
||||
type CreateProjectRequest struct {
|
||||
Name string `json:"name" validate:"username,required"`
|
||||
Provisioner database.ProvisionerType `json:"provisioner" validate:"oneof=terraform cdr-basic,required"`
|
||||
}
|
||||
|
||||
// CreateProjectVersionRequest enables callers to create a new Project Version.
|
||||
type CreateProjectVersionRequest struct {
|
||||
Name string `json:"name,omitempty" validate:"username"`
|
||||
StorageMethod database.ProjectStorageMethod `json:"storage_method" validate:"oneof=inline-archive,required"`
|
||||
StorageSource []byte `json:"storage_source" validate:"max=1048576,required"`
|
||||
}
|
||||
|
||||
type projects struct {
|
||||
Database database.Store
|
||||
}
|
||||
|
||||
// allProjects lists all projects across organizations for a user.
|
||||
func (p *projects) allProjects(rw http.ResponseWriter, r *http.Request) {
|
||||
apiKey := httpmw.APIKey(r)
|
||||
organizations, err := p.Database.GetOrganizationsByUserID(r.Context(), apiKey.UserID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get organizations: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
organizationIDs := make([]string, 0, len(organizations))
|
||||
for _, organization := range organizations {
|
||||
organizationIDs = append(organizationIDs, organization.ID)
|
||||
}
|
||||
projects, err := p.Database.GetProjectsByOrganizationIDs(r.Context(), organizationIDs)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get projects: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
render.Status(r, http.StatusOK)
|
||||
render.JSON(rw, r, projects)
|
||||
}
|
||||
|
||||
// allProjectsForOrganization lists all projects for a specific organization.
|
||||
func (p *projects) allProjectsForOrganization(rw http.ResponseWriter, r *http.Request) {
|
||||
organization := httpmw.OrganizationParam(r)
|
||||
projects, err := p.Database.GetProjectsByOrganizationIDs(r.Context(), []string{organization.ID})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get projects: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
render.Status(r, http.StatusOK)
|
||||
render.JSON(rw, r, projects)
|
||||
}
|
||||
|
||||
// createProject makes a new project in an organization.
|
||||
func (p *projects) createProject(rw http.ResponseWriter, r *http.Request) {
|
||||
var createProject CreateProjectRequest
|
||||
if !httpapi.Read(rw, r, &createProject) {
|
||||
return
|
||||
}
|
||||
organization := httpmw.OrganizationParam(r)
|
||||
_, err := p.Database.GetProjectByOrganizationAndName(r.Context(), database.GetProjectByOrganizationAndNameParams{
|
||||
OrganizationID: organization.ID,
|
||||
Name: createProject.Name,
|
||||
})
|
||||
if err == nil {
|
||||
httpapi.Write(rw, http.StatusConflict, httpapi.Response{
|
||||
Message: fmt.Sprintf("project %q already exists", createProject.Name),
|
||||
Errors: []httpapi.Error{{
|
||||
Field: "name",
|
||||
Code: "exists",
|
||||
}},
|
||||
})
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get project by name: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
project, err := p.Database.InsertProject(r.Context(), database.InsertProjectParams{
|
||||
ID: uuid.New(),
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
OrganizationID: organization.ID,
|
||||
Name: createProject.Name,
|
||||
Provisioner: createProject.Provisioner,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("insert project: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
render.Status(r, http.StatusCreated)
|
||||
render.JSON(rw, r, project)
|
||||
}
|
||||
|
||||
// project returns a single project parsed from the URL path.
|
||||
func (*projects) project(rw http.ResponseWriter, r *http.Request) {
|
||||
project := httpmw.ProjectParam(r)
|
||||
|
||||
render.Status(r, http.StatusOK)
|
||||
render.JSON(rw, r, project)
|
||||
}
|
||||
|
||||
// projectVersions lists versions for a single project.
|
||||
func (p *projects) projectVersions(rw http.ResponseWriter, r *http.Request) {
|
||||
project := httpmw.ProjectParam(r)
|
||||
|
||||
history, err := p.Database.GetProjectHistoryByProjectID(r.Context(), project.ID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get project history: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
versions := make([]ProjectVersion, 0)
|
||||
for _, version := range history {
|
||||
versions = append(versions, convertProjectHistory(version))
|
||||
}
|
||||
render.Status(r, http.StatusOK)
|
||||
render.JSON(rw, r, versions)
|
||||
}
|
||||
|
||||
func (p *projects) createProjectVersion(rw http.ResponseWriter, r *http.Request) {
|
||||
var createProjectVersion CreateProjectVersionRequest
|
||||
if !httpapi.Read(rw, r, &createProjectVersion) {
|
||||
return
|
||||
}
|
||||
|
||||
switch createProjectVersion.StorageMethod {
|
||||
case database.ProjectStorageMethodInlineArchive:
|
||||
tarReader := tar.NewReader(bytes.NewReader(createProjectVersion.StorageSource))
|
||||
_, err := tarReader.Next()
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: "the archive must be a tar",
|
||||
})
|
||||
return
|
||||
}
|
||||
default:
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("unsupported storage method %s", createProjectVersion.StorageMethod),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
project := httpmw.ProjectParam(r)
|
||||
history, err := p.Database.InsertProjectHistory(r.Context(), database.InsertProjectHistoryParams{
|
||||
ID: uuid.New(),
|
||||
ProjectID: project.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
Name: namesgenerator.GetRandomName(1),
|
||||
StorageMethod: createProjectVersion.StorageMethod,
|
||||
StorageSource: createProjectVersion.StorageSource,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("insert project history: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: A job to process the new version should occur here.
|
||||
|
||||
render.Status(r, http.StatusCreated)
|
||||
render.JSON(rw, r, convertProjectHistory(history))
|
||||
}
|
||||
|
||||
func convertProjectHistory(history database.ProjectHistory) ProjectVersion {
|
||||
return ProjectVersion{
|
||||
ID: history.ID,
|
||||
ProjectID: history.ProjectID,
|
||||
CreatedAt: history.CreatedAt,
|
||||
UpdatedAt: history.UpdatedAt,
|
||||
Name: history.Name,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,183 @@
|
|||
package coderd_test
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/database"
|
||||
)
|
||||
|
||||
func TestProjects(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Create", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
_, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
|
||||
Name: "someproject",
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("AlreadyExists", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
_, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
|
||||
Name: "someproject",
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
|
||||
Name: "someproject",
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("ListEmpty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
_ = server.RandomInitialUser(t)
|
||||
projects, err := server.Client.Projects(context.Background(), "")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, projects, 0)
|
||||
})
|
||||
|
||||
t.Run("List", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
_, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
|
||||
Name: "someproject",
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// Ensure global query works.
|
||||
projects, err := server.Client.Projects(context.Background(), "")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, projects, 1)
|
||||
|
||||
// Ensure specified query works.
|
||||
projects, err = server.Client.Projects(context.Background(), user.Organization)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, projects, 1)
|
||||
})
|
||||
|
||||
t.Run("ListEmpty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
|
||||
projects, err := server.Client.Projects(context.Background(), user.Organization)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, projects, 0)
|
||||
})
|
||||
|
||||
t.Run("Single", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
|
||||
Name: "someproject",
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = server.Client.Project(context.Background(), user.Organization, project.Name)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("NoVersions", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
|
||||
Name: "someproject",
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
versions, err := server.Client.ProjectVersions(context.Background(), user.Organization, project.Name)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, versions, 0)
|
||||
})
|
||||
|
||||
t.Run("CreateVersion", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
|
||||
Name: "someproject",
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
var buffer bytes.Buffer
|
||||
writer := tar.NewWriter(&buffer)
|
||||
err = writer.WriteHeader(&tar.Header{
|
||||
Name: "file",
|
||||
Size: 1 << 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = writer.Write(make([]byte, 1<<10))
|
||||
require.NoError(t, err)
|
||||
_, err = server.Client.CreateProjectVersion(context.Background(), user.Organization, project.Name, coderd.CreateProjectVersionRequest{
|
||||
Name: "moo",
|
||||
StorageMethod: database.ProjectStorageMethodInlineArchive,
|
||||
StorageSource: buffer.Bytes(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
versions, err := server.Client.ProjectVersions(context.Background(), user.Organization, project.Name)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, versions, 1)
|
||||
})
|
||||
|
||||
t.Run("CreateVersionArchiveTooBig", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
|
||||
Name: "someproject",
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
var buffer bytes.Buffer
|
||||
writer := tar.NewWriter(&buffer)
|
||||
err = writer.WriteHeader(&tar.Header{
|
||||
Name: "file",
|
||||
Size: 1 << 21,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = writer.Write(make([]byte, 1<<21))
|
||||
require.NoError(t, err)
|
||||
_, err = server.Client.CreateProjectVersion(context.Background(), user.Organization, project.Name, coderd.CreateProjectVersionRequest{
|
||||
Name: "moo",
|
||||
StorageMethod: database.ProjectStorageMethodInlineArchive,
|
||||
StorageSource: buffer.Bytes(),
|
||||
})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("CreateVersionInvalidArchive", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
|
||||
Name: "someproject",
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = server.Client.CreateProjectVersion(context.Background(), user.Organization, project.Name, coderd.CreateProjectVersionRequest{
|
||||
Name: "moo",
|
||||
StorageMethod: database.ProjectStorageMethodInlineArchive,
|
||||
StorageSource: []byte{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
|
@ -9,7 +9,9 @@ import (
|
|||
)
|
||||
|
||||
func TestUserPassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Legacy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Ensures legacy v1 passwords function for v2.
|
||||
// This has is manually generated using a print statement from v1 code.
|
||||
equal, err := userpassword.Compare("$pbkdf2-sha256$65535$z8c1p1C2ru9EImBP1I+ZNA$pNjE3Yk0oG0PmJ0Je+y7ENOVlSkn/b0BEqqdKsq6Y97wQBq0xT+lD5bWJpyIKJqQICuPZcEaGDKrXJn8+SIHRg", "tomato")
|
||||
|
@ -18,6 +20,7 @@ func TestUserPassword(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Same", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
hash, err := userpassword.Hash("password")
|
||||
require.NoError(t, err)
|
||||
equal, err := userpassword.Compare(hash, "password")
|
||||
|
@ -26,6 +29,7 @@ func TestUserPassword(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Different", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
hash, err := userpassword.Hash("password")
|
||||
require.NoError(t, err)
|
||||
equal, err := userpassword.Compare(hash, "notpassword")
|
||||
|
@ -34,12 +38,14 @@ func TestUserPassword(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Invalid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
equal, err := userpassword.Compare("invalidhash", "password")
|
||||
require.False(t, equal)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("InvalidParts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
equal, err := userpassword.Compare("abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz", "test")
|
||||
require.False(t, equal)
|
||||
require.Error(t, err)
|
||||
|
|
|
@ -35,6 +35,7 @@ func TestUsers(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Login", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
_, err := server.Client.LoginWithPassword(context.Background(), coderd.LoginWithPasswordRequest{
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
package codersdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/coder/coder/coderd"
|
||||
)
|
||||
|
||||
// Projects lists projects inside an organization.
|
||||
// If organization is an empty string, all projects will be returned
|
||||
// for the authenticated user.
|
||||
func (c *Client) Projects(ctx context.Context, organization string) ([]coderd.Project, error) {
|
||||
route := "/api/v2/projects"
|
||||
if organization != "" {
|
||||
route = fmt.Sprintf("/api/v2/projects/%s", organization)
|
||||
}
|
||||
res, err := c.request(ctx, http.MethodGet, route, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, readBodyAsError(res)
|
||||
}
|
||||
var projects []coderd.Project
|
||||
return projects, json.NewDecoder(res.Body).Decode(&projects)
|
||||
}
|
||||
|
||||
// Project returns a single project.
|
||||
func (c *Client) Project(ctx context.Context, organization, project string) (coderd.Project, error) {
|
||||
res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/projects/%s/%s", organization, project), nil)
|
||||
if err != nil {
|
||||
return coderd.Project{}, nil
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return coderd.Project{}, readBodyAsError(res)
|
||||
}
|
||||
var resp coderd.Project
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// CreateProject creates a new project inside an organization.
|
||||
func (c *Client) CreateProject(ctx context.Context, organization string, request coderd.CreateProjectRequest) (coderd.Project, error) {
|
||||
res, err := c.request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/projects/%s", organization), request)
|
||||
if err != nil {
|
||||
return coderd.Project{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
return coderd.Project{}, readBodyAsError(res)
|
||||
}
|
||||
var project coderd.Project
|
||||
return project, json.NewDecoder(res.Body).Decode(&project)
|
||||
}
|
||||
|
||||
// ProjectVersions lists history for a project.
|
||||
func (c *Client) ProjectVersions(ctx context.Context, organization, project string) ([]coderd.ProjectVersion, error) {
|
||||
res, err := c.request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/projects/%s/%s/versions", organization, project), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, readBodyAsError(res)
|
||||
}
|
||||
var projectVersions []coderd.ProjectVersion
|
||||
return projectVersions, json.NewDecoder(res.Body).Decode(&projectVersions)
|
||||
}
|
||||
|
||||
// CreateProjectVersion inserts a new version for the project.
|
||||
func (c *Client) CreateProjectVersion(ctx context.Context, organization, project string, request coderd.CreateProjectVersionRequest) (coderd.ProjectVersion, error) {
|
||||
res, err := c.request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/projects/%s/%s/versions", organization, project), request)
|
||||
if err != nil {
|
||||
return coderd.ProjectVersion{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
return coderd.ProjectVersion{}, readBodyAsError(res)
|
||||
}
|
||||
var projectVersion coderd.ProjectVersion
|
||||
return projectVersion, json.NewDecoder(res.Body).Decode(&projectVersion)
|
||||
}
|
|
@ -0,0 +1,130 @@
|
|||
package codersdk_test
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/database"
|
||||
)
|
||||
|
||||
func TestProjects(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("UnauthenticatedList", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
_, err := server.Client.Projects(context.Background(), "")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("List", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
_, err := server.Client.Projects(context.Background(), "")
|
||||
require.NoError(t, err)
|
||||
_, err = server.Client.Projects(context.Background(), user.Organization)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("UnauthenticatedCreate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
_, err := server.Client.CreateProject(context.Background(), "", coderd.CreateProjectRequest{})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Create", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
_, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
|
||||
Name: "bananas",
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("UnauthenticatedSingle", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
_, err := server.Client.Project(context.Background(), "wow", "example")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Single", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
_, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
|
||||
Name: "bananas",
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = server.Client.Project(context.Background(), user.Organization, "bananas")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("UnauthenticatedVersions", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
_, err := server.Client.ProjectVersions(context.Background(), "org", "project")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Versions", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
|
||||
Name: "bananas",
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = server.Client.ProjectVersions(context.Background(), user.Organization, project.Name)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("CreateVersionUnauthenticated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
_, err := server.Client.CreateProjectVersion(context.Background(), "org", "project", coderd.CreateProjectVersionRequest{
|
||||
Name: "hello",
|
||||
StorageMethod: database.ProjectStorageMethodInlineArchive,
|
||||
StorageSource: []byte{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("CreateVersion", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
user := server.RandomInitialUser(t)
|
||||
project, err := server.Client.CreateProject(context.Background(), user.Organization, coderd.CreateProjectRequest{
|
||||
Name: "bananas",
|
||||
Provisioner: database.ProvisionerTypeTerraform,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
var buffer bytes.Buffer
|
||||
writer := tar.NewWriter(&buffer)
|
||||
err = writer.WriteHeader(&tar.Header{
|
||||
Name: "file",
|
||||
Size: 1 << 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = writer.Write(make([]byte, 1<<10))
|
||||
require.NoError(t, err)
|
||||
_, err = server.Client.CreateProjectVersion(context.Background(), user.Organization, project.Name, coderd.CreateProjectVersionRequest{
|
||||
Name: "hello",
|
||||
StorageMethod: database.ProjectStorageMethodInlineArchive,
|
||||
StorageSource: buffer.Bytes(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
|
@ -11,7 +11,9 @@ import (
|
|||
)
|
||||
|
||||
func TestUsers(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("CreateInitial", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
_, err := server.Client.CreateInitialUser(context.Background(), coderd.CreateInitialUserRequest{
|
||||
Email: "wowie@coder.com",
|
||||
|
@ -23,12 +25,14 @@ func TestUsers(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("NoUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
_, err := server.Client.User(context.Background(), "")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("User", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
_ = server.RandomInitialUser(t)
|
||||
_, err := server.Client.User(context.Background(), "")
|
||||
|
@ -36,6 +40,7 @@ func TestUsers(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("UserOrganizations", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := coderdtest.New(t)
|
||||
_ = server.RandomInitialUser(t)
|
||||
orgs, err := server.Client.UserOrganizations(context.Background(), "")
|
||||
|
|
|
@ -47,6 +47,9 @@ func TestUnbiasedModulo32(t *testing.T) {
|
|||
const mod = 7
|
||||
dist := [mod]uint32{}
|
||||
|
||||
_, err := cryptorand.UnbiasedModulo32(0, mod)
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
b := [4]byte{}
|
||||
_, _ = rand.Read(b[:])
|
||||
|
|
|
@ -91,6 +91,7 @@ func TestStringCharset(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
//nolint:paralleltest
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.Name, func(t *testing.T) {
|
||||
|
|
|
@ -3,6 +3,9 @@ package databasefake
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/database"
|
||||
)
|
||||
|
@ -14,15 +17,25 @@ func New() database.Store {
|
|||
organizations: make([]database.Organization, 0),
|
||||
organizationMembers: make([]database.OrganizationMember, 0),
|
||||
users: make([]database.User, 0),
|
||||
|
||||
project: make([]database.Project, 0),
|
||||
projectHistory: make([]database.ProjectHistory, 0),
|
||||
projectParameter: make([]database.ProjectParameter, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// fakeQuerier replicates database functionality to enable quick testing.
|
||||
type fakeQuerier struct {
|
||||
// Legacy tables
|
||||
apiKeys []database.APIKey
|
||||
organizations []database.Organization
|
||||
organizationMembers []database.OrganizationMember
|
||||
users []database.User
|
||||
|
||||
// New tables
|
||||
project []database.Project
|
||||
projectHistory []database.ProjectHistory
|
||||
projectParameter []database.ProjectParameter
|
||||
}
|
||||
|
||||
// InTx doesn't rollback data properly for in-memory yet.
|
||||
|
@ -89,6 +102,62 @@ func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID string)
|
|||
return organizations, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetProjectByOrganizationAndName(_ context.Context, arg database.GetProjectByOrganizationAndNameParams) (database.Project, error) {
|
||||
for _, project := range q.project {
|
||||
if project.OrganizationID != arg.OrganizationID {
|
||||
continue
|
||||
}
|
||||
if !strings.EqualFold(project.Name, arg.Name) {
|
||||
continue
|
||||
}
|
||||
return project, nil
|
||||
}
|
||||
return database.Project{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetProjectHistoryByProjectID(_ context.Context, projectID uuid.UUID) ([]database.ProjectHistory, error) {
|
||||
history := make([]database.ProjectHistory, 0)
|
||||
for _, projectHistory := range q.projectHistory {
|
||||
if projectHistory.ProjectID.String() != projectID.String() {
|
||||
continue
|
||||
}
|
||||
history = append(history, projectHistory)
|
||||
}
|
||||
if len(history) == 0 {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
return history, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetProjectsByOrganizationIDs(_ context.Context, ids []string) ([]database.Project, error) {
|
||||
projects := make([]database.Project, 0)
|
||||
for _, project := range q.project {
|
||||
for _, id := range ids {
|
||||
if project.OrganizationID == id {
|
||||
projects = append(projects, project)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(projects) == 0 {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
return projects, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
|
||||
for _, organizationMember := range q.organizationMembers {
|
||||
if organizationMember.OrganizationID != arg.OrganizationID {
|
||||
continue
|
||||
}
|
||||
if organizationMember.UserID != arg.UserID {
|
||||
continue
|
||||
}
|
||||
return organizationMember, nil
|
||||
}
|
||||
return database.OrganizationMember{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) {
|
||||
//nolint:gosimple
|
||||
key := database.APIKey{
|
||||
|
@ -136,6 +205,59 @@ func (q *fakeQuerier) InsertOrganizationMember(_ context.Context, arg database.I
|
|||
return organizationMember, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) InsertProject(_ context.Context, arg database.InsertProjectParams) (database.Project, error) {
|
||||
project := database.Project{
|
||||
ID: arg.ID,
|
||||
CreatedAt: arg.CreatedAt,
|
||||
UpdatedAt: arg.UpdatedAt,
|
||||
OrganizationID: arg.OrganizationID,
|
||||
Name: arg.Name,
|
||||
Provisioner: arg.Provisioner,
|
||||
}
|
||||
q.project = append(q.project, project)
|
||||
return project, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) InsertProjectHistory(_ context.Context, arg database.InsertProjectHistoryParams) (database.ProjectHistory, error) {
|
||||
//nolint:gosimple
|
||||
history := database.ProjectHistory{
|
||||
ID: arg.ID,
|
||||
ProjectID: arg.ProjectID,
|
||||
CreatedAt: arg.CreatedAt,
|
||||
UpdatedAt: arg.UpdatedAt,
|
||||
Name: arg.Name,
|
||||
Description: arg.Description,
|
||||
StorageMethod: arg.StorageMethod,
|
||||
StorageSource: arg.StorageSource,
|
||||
ImportJobID: arg.ImportJobID,
|
||||
}
|
||||
q.projectHistory = append(q.projectHistory, history)
|
||||
return history, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) InsertProjectParameter(_ context.Context, arg database.InsertProjectParameterParams) (database.ProjectParameter, error) {
|
||||
//nolint:gosimple
|
||||
param := database.ProjectParameter{
|
||||
ID: arg.ID,
|
||||
CreatedAt: arg.CreatedAt,
|
||||
ProjectHistoryID: arg.ProjectHistoryID,
|
||||
Name: arg.Name,
|
||||
Description: arg.Description,
|
||||
DefaultSource: arg.DefaultSource,
|
||||
AllowOverrideSource: arg.AllowOverrideSource,
|
||||
DefaultDestination: arg.DefaultDestination,
|
||||
AllowOverrideDestination: arg.AllowOverrideDestination,
|
||||
DefaultRefresh: arg.DefaultRefresh,
|
||||
RedisplayValue: arg.RedisplayValue,
|
||||
ValidationError: arg.ValidationError,
|
||||
ValidationCondition: arg.ValidationCondition,
|
||||
ValidationTypeSystem: arg.ValidationTypeSystem,
|
||||
ValidationValueType: arg.ValidationValueType,
|
||||
}
|
||||
q.projectParameter = append(q.projectParameter, param)
|
||||
return param, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, error) {
|
||||
user := database.User{
|
||||
ID: arg.ID,
|
||||
|
|
|
@ -6,6 +6,19 @@ CREATE TYPE login_type AS ENUM (
|
|||
'oidc'
|
||||
);
|
||||
|
||||
CREATE TYPE parameter_type_system AS ENUM (
|
||||
'hcl'
|
||||
);
|
||||
|
||||
CREATE TYPE project_storage_method AS ENUM (
|
||||
'inline-archive'
|
||||
);
|
||||
|
||||
CREATE TYPE provisioner_type AS ENUM (
|
||||
'terraform',
|
||||
'cdr-basic'
|
||||
);
|
||||
|
||||
CREATE TYPE userstatus AS ENUM (
|
||||
'active',
|
||||
'dormant',
|
||||
|
@ -57,6 +70,46 @@ CREATE TABLE organizations (
|
|||
workspace_auto_off boolean DEFAULT false NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE project (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
organization_id text NOT NULL,
|
||||
name character varying(64) NOT NULL,
|
||||
provisioner provisioner_type NOT NULL,
|
||||
active_version_id uuid
|
||||
);
|
||||
|
||||
CREATE TABLE project_history (
|
||||
id uuid NOT NULL,
|
||||
project_id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
name character varying(64) NOT NULL,
|
||||
description character varying(1048576) NOT NULL,
|
||||
storage_method project_storage_method NOT NULL,
|
||||
storage_source bytea NOT NULL,
|
||||
import_job_id uuid NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE project_parameter (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
project_history_id uuid NOT NULL,
|
||||
name character varying(64) NOT NULL,
|
||||
description character varying(8192) DEFAULT ''::character varying NOT NULL,
|
||||
default_source text,
|
||||
allow_override_source boolean NOT NULL,
|
||||
default_destination text,
|
||||
allow_override_destination boolean NOT NULL,
|
||||
default_refresh text NOT NULL,
|
||||
redisplay_value boolean NOT NULL,
|
||||
validation_error character varying(256) NOT NULL,
|
||||
validation_condition character varying(512) NOT NULL,
|
||||
validation_type_system parameter_type_system NOT NULL,
|
||||
validation_value_type character varying(64) NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE users (
|
||||
id text NOT NULL,
|
||||
email text NOT NULL,
|
||||
|
@ -79,3 +132,27 @@ CREATE TABLE users (
|
|||
shell text DEFAULT ''::text NOT NULL
|
||||
);
|
||||
|
||||
ALTER TABLE ONLY project_history
|
||||
ADD CONSTRAINT project_history_id_key UNIQUE (id);
|
||||
|
||||
ALTER TABLE ONLY project_history
|
||||
ADD CONSTRAINT project_history_project_id_name_key UNIQUE (project_id, name);
|
||||
|
||||
ALTER TABLE ONLY project
|
||||
ADD CONSTRAINT project_id_key UNIQUE (id);
|
||||
|
||||
ALTER TABLE ONLY project
|
||||
ADD CONSTRAINT project_organization_id_name_key UNIQUE (organization_id, name);
|
||||
|
||||
ALTER TABLE ONLY project_parameter
|
||||
ADD CONSTRAINT project_parameter_id_key UNIQUE (id);
|
||||
|
||||
ALTER TABLE ONLY project_parameter
|
||||
ADD CONSTRAINT project_parameter_project_history_id_name_key UNIQUE (project_history_id, name);
|
||||
|
||||
ALTER TABLE ONLY project_history
|
||||
ADD CONSTRAINT project_history_project_id_fkey FOREIGN KEY (project_id) REFERENCES project(id);
|
||||
|
||||
ALTER TABLE ONLY project_parameter
|
||||
ADD CONSTRAINT project_parameter_project_history_id_fkey FOREIGN KEY (project_history_id) REFERENCES project_history(id) ON DELETE CASCADE;
|
||||
|
||||
|
|
|
@ -20,12 +20,29 @@ func TestMain(m *testing.M) {
|
|||
func TestMigrate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
connection, closeFn, err := postgres.Open()
|
||||
require.NoError(t, err)
|
||||
defer closeFn()
|
||||
db, err := sql.Open("postgres", connection)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
err = database.Migrate(db)
|
||||
require.NoError(t, err)
|
||||
t.Run("Once", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
connection, closeFn, err := postgres.Open()
|
||||
require.NoError(t, err)
|
||||
defer closeFn()
|
||||
db, err := sql.Open("postgres", connection)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
err = database.Migrate(db)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Twice", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
connection, closeFn, err := postgres.Open()
|
||||
require.NoError(t, err)
|
||||
defer closeFn()
|
||||
db, err := sql.Open("postgres", connection)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
err = database.Migrate(db)
|
||||
require.NoError(t, err)
|
||||
err = database.Migrate(db)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
CREATE TYPE provisioner_type AS ENUM ('terraform', 'cdr-basic');
|
||||
|
||||
-- Project defines infrastructure that your software project
|
||||
-- requires for development.
|
||||
CREATE TABLE project (
|
||||
id uuid NOT NULL UNIQUE,
|
||||
created_at timestamptz NOT NULL,
|
||||
updated_at timestamptz NOT NULL,
|
||||
-- Projects must be scoped to an organization.
|
||||
organization_id text NOT NULL,
|
||||
name varchar(64) NOT NULL,
|
||||
provisioner provisioner_type NOT NULL,
|
||||
-- Target's a Project Version to use for Workspaces.
|
||||
-- If a Workspace doesn't match this version, it will be prompted to rebuild.
|
||||
active_version_id uuid,
|
||||
-- Disallow projects to have the same name under
|
||||
-- the same organization.
|
||||
UNIQUE(organization_id, name)
|
||||
);
|
||||
|
||||
CREATE TYPE project_storage_method AS ENUM ('inline-archive');
|
||||
|
||||
-- Project Versions store Project history. When a Project Version is imported,
|
||||
-- an "import" job is queued to parse parameters. A Project Version
|
||||
-- can only be used if the import job succeeds.
|
||||
CREATE TABLE project_history (
|
||||
id uuid NOT NULL UNIQUE,
|
||||
-- This should be indexed.
|
||||
project_id uuid NOT NULL REFERENCES project (id),
|
||||
created_at timestamptz NOT NULL,
|
||||
updated_at timestamptz NOT NULL,
|
||||
-- Name is generated for ease of differentiation.
|
||||
-- eg. TheCozyRabbit16
|
||||
name varchar(64) NOT NULL,
|
||||
-- Extracted from a README.md on import.
|
||||
-- Maximum of 1MB.
|
||||
description varchar(1048576) NOT NULL,
|
||||
storage_method project_storage_method NOT NULL,
|
||||
storage_source bytea NOT NULL,
|
||||
-- The import job for a Project Version. This is used
|
||||
-- to detect if an import was successful.
|
||||
import_job_id uuid NOT NULL,
|
||||
-- Disallow projects to have the same build name
|
||||
-- multiple times.
|
||||
UNIQUE(project_id, name)
|
||||
);
|
||||
|
||||
-- Types of parameters the automator supports.
|
||||
CREATE TYPE parameter_type_system AS ENUM ('hcl');
|
||||
|
||||
-- Stores project version parameters parsed on import.
|
||||
-- No secrets are stored here.
|
||||
--
|
||||
-- All parameter validation occurs server-side to process
|
||||
-- complex validations.
|
||||
--
|
||||
-- Parameter types, description, and validation will produce
|
||||
-- a UI for users to enter values.
|
||||
-- Needs to be made consistent with the examples below.
|
||||
CREATE TABLE project_parameter (
|
||||
id uuid NOT NULL UNIQUE,
|
||||
created_at timestamptz NOT NULL,
|
||||
project_history_id uuid NOT NULL REFERENCES project_history(id) ON DELETE CASCADE,
|
||||
name varchar(64) NOT NULL,
|
||||
-- 8KB limit
|
||||
description varchar(8192) NOT NULL DEFAULT '',
|
||||
-- eg. data://inlinevalue
|
||||
default_source text,
|
||||
-- Allows the user to override the source.
|
||||
allow_override_source boolean NOT null,
|
||||
-- eg. env://SOME_VARIABLE, tfvars://example
|
||||
default_destination text,
|
||||
-- Allows the user to override the destination.
|
||||
allow_override_destination boolean NOT null,
|
||||
default_refresh text NOT NULL,
|
||||
-- Whether the consumer can view the source and destinations.
|
||||
redisplay_value boolean NOT null,
|
||||
-- This error would appear in the UI if the condition is not met.
|
||||
validation_error varchar(256) NOT NULL,
|
||||
validation_condition varchar(512) NOT NULL,
|
||||
validation_type_system parameter_type_system NOT NULL,
|
||||
validation_value_type varchar(64) NOT NULL,
|
||||
UNIQUE(project_history_id, name)
|
||||
);
|
|
@ -3,9 +3,12 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type LoginType string
|
||||
|
@ -28,6 +31,61 @@ func (e *LoginType) Scan(src interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
type ParameterTypeSystem string
|
||||
|
||||
const (
|
||||
ParameterTypeSystemHCL ParameterTypeSystem = "hcl"
|
||||
)
|
||||
|
||||
func (e *ParameterTypeSystem) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ParameterTypeSystem(s)
|
||||
case string:
|
||||
*e = ParameterTypeSystem(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ParameterTypeSystem: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ProjectStorageMethod string
|
||||
|
||||
const (
|
||||
ProjectStorageMethodInlineArchive ProjectStorageMethod = "inline-archive"
|
||||
)
|
||||
|
||||
func (e *ProjectStorageMethod) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ProjectStorageMethod(s)
|
||||
case string:
|
||||
*e = ProjectStorageMethod(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ProjectStorageMethod: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ProvisionerType string
|
||||
|
||||
const (
|
||||
ProvisionerTypeTerraform ProvisionerType = "terraform"
|
||||
ProvisionerTypeCdrBasic ProvisionerType = "cdr-basic"
|
||||
)
|
||||
|
||||
func (e *ProvisionerType) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ProvisionerType(s)
|
||||
case string:
|
||||
*e = ProvisionerType(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ProvisionerType: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type UserStatus string
|
||||
|
||||
const (
|
||||
|
@ -93,6 +151,46 @@ type OrganizationMember struct {
|
|||
Roles []string `db:"roles" json:"roles"`
|
||||
}
|
||||
|
||||
type Project struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
OrganizationID string `db:"organization_id" json:"organization_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Provisioner ProvisionerType `db:"provisioner" json:"provisioner"`
|
||||
ActiveVersionID uuid.NullUUID `db:"active_version_id" json:"active_version_id"`
|
||||
}
|
||||
|
||||
type ProjectHistory struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
ProjectID uuid.UUID `db:"project_id" json:"project_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
StorageMethod ProjectStorageMethod `db:"storage_method" json:"storage_method"`
|
||||
StorageSource []byte `db:"storage_source" json:"storage_source"`
|
||||
ImportJobID uuid.UUID `db:"import_job_id" json:"import_job_id"`
|
||||
}
|
||||
|
||||
type ProjectParameter struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
ProjectHistoryID uuid.UUID `db:"project_history_id" json:"project_history_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
DefaultSource sql.NullString `db:"default_source" json:"default_source"`
|
||||
AllowOverrideSource bool `db:"allow_override_source" json:"allow_override_source"`
|
||||
DefaultDestination sql.NullString `db:"default_destination" json:"default_destination"`
|
||||
AllowOverrideDestination bool `db:"allow_override_destination" json:"allow_override_destination"`
|
||||
DefaultRefresh string `db:"default_refresh" json:"default_refresh"`
|
||||
RedisplayValue bool `db:"redisplay_value" json:"redisplay_value"`
|
||||
ValidationError string `db:"validation_error" json:"validation_error"`
|
||||
ValidationCondition string `db:"validation_condition" json:"validation_condition"`
|
||||
ValidationTypeSystem ParameterTypeSystem `db:"validation_type_system" json:"validation_type_system"`
|
||||
ValidationValueType string `db:"validation_value_type" json:"validation_value_type"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
Email string `db:"email" json:"email"`
|
||||
|
|
|
@ -18,6 +18,7 @@ func TestPubsub(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
t.Run("Postgres", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
|
@ -45,4 +46,20 @@ func TestPubsub(t *testing.T) {
|
|||
message := <-messageChannel
|
||||
assert.Equal(t, string(message), data)
|
||||
})
|
||||
|
||||
t.Run("PostgresCloseCancel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
connectionURL, close, err := postgres.Open()
|
||||
require.NoError(t, err)
|
||||
defer close()
|
||||
db, err := sql.Open("postgres", connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer pubsub.Close()
|
||||
cancelFunc()
|
||||
})
|
||||
}
|
||||
|
|
|
@ -4,18 +4,27 @@ package database
|
|||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type querier interface {
|
||||
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
|
||||
GetOrganizationByName(ctx context.Context, name string) (Organization, error)
|
||||
GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error)
|
||||
GetOrganizationsByUserID(ctx context.Context, userID string) ([]Organization, error)
|
||||
GetProjectByOrganizationAndName(ctx context.Context, arg GetProjectByOrganizationAndNameParams) (Project, error)
|
||||
GetProjectHistoryByProjectID(ctx context.Context, projectID uuid.UUID) ([]ProjectHistory, error)
|
||||
GetProjectsByOrganizationIDs(ctx context.Context, ids []string) ([]Project, error)
|
||||
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
|
||||
GetUserByID(ctx context.Context, id string) (User, error)
|
||||
GetUserCount(ctx context.Context) (int64, error)
|
||||
InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error)
|
||||
InsertOrganization(ctx context.Context, arg InsertOrganizationParams) (Organization, error)
|
||||
InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error)
|
||||
InsertProject(ctx context.Context, arg InsertProjectParams) (Project, error)
|
||||
InsertProjectHistory(ctx context.Context, arg InsertProjectHistoryParams) (ProjectHistory, error)
|
||||
InsertProjectParameter(ctx context.Context, arg InsertProjectParameterParams) (ProjectParameter, error)
|
||||
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
}
|
||||
|
|
|
@ -42,12 +42,67 @@ FROM
|
|||
users;
|
||||
|
||||
-- name: GetOrganizationByName :one
|
||||
SELECT * FROM organizations WHERE name = $1 LIMIT 1;
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
organizations
|
||||
WHERE
|
||||
name = $1
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetOrganizationsByUserID :many
|
||||
SELECT * FROM organizations WHERE id = (
|
||||
SELECT organization_id FROM organization_members WHERE user_id = $1
|
||||
);
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
organizations
|
||||
WHERE
|
||||
id = (
|
||||
SELECT
|
||||
organization_id
|
||||
FROM
|
||||
organization_members
|
||||
WHERE
|
||||
user_id = $1
|
||||
);
|
||||
|
||||
-- name: GetOrganizationMemberByUserID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
organization_members
|
||||
WHERE
|
||||
organization_id = $1
|
||||
AND user_id = $2
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetProjectByOrganizationAndName :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
project
|
||||
WHERE
|
||||
organization_id = $1
|
||||
AND name = $2
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetProjectsByOrganizationIDs :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
project
|
||||
WHERE
|
||||
organization_id = ANY(@ids :: text [ ]);
|
||||
|
||||
-- name: GetProjectHistoryByProjectID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
project_history
|
||||
WHERE
|
||||
project_id = $1;
|
||||
|
||||
-- name: InsertAPIKey :one
|
||||
INSERT INTO
|
||||
|
@ -88,10 +143,89 @@ VALUES
|
|||
) RETURNING *;
|
||||
|
||||
-- name: InsertOrganization :one
|
||||
INSERT INTO organizations (id, name, description, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) RETURNING *;
|
||||
INSERT INTO
|
||||
organizations (id, name, description, created_at, updated_at)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5) RETURNING *;
|
||||
|
||||
-- name: InsertOrganizationMember :one
|
||||
INSERT INTO organization_members (organization_id, user_id, created_at, updated_at, roles) VALUES ($1, $2, $3, $4, $5) RETURNING *;
|
||||
INSERT INTO
|
||||
organization_members (
|
||||
organization_id,
|
||||
user_id,
|
||||
created_at,
|
||||
updated_at,
|
||||
roles
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5) RETURNING *;
|
||||
|
||||
-- name: InsertProject :one
|
||||
INSERT INTO
|
||||
project (
|
||||
id,
|
||||
created_at,
|
||||
updated_at,
|
||||
organization_id,
|
||||
name,
|
||||
provisioner
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6) RETURNING *;
|
||||
|
||||
-- name: InsertProjectHistory :one
|
||||
INSERT INTO
|
||||
project_history (
|
||||
id,
|
||||
project_id,
|
||||
created_at,
|
||||
updated_at,
|
||||
name,
|
||||
description,
|
||||
storage_method,
|
||||
storage_source,
|
||||
import_job_id
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *;
|
||||
|
||||
-- name: InsertProjectParameter :one
|
||||
INSERT INTO
|
||||
project_parameter (
|
||||
id,
|
||||
created_at,
|
||||
project_history_id,
|
||||
name,
|
||||
description,
|
||||
default_source,
|
||||
allow_override_source,
|
||||
default_destination,
|
||||
allow_override_destination,
|
||||
default_refresh,
|
||||
redisplay_value,
|
||||
validation_error,
|
||||
validation_condition,
|
||||
validation_type_system,
|
||||
validation_value_type
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8,
|
||||
$9,
|
||||
$10,
|
||||
$11,
|
||||
$12,
|
||||
$13,
|
||||
$14,
|
||||
$15
|
||||
) RETURNING *;
|
||||
|
||||
-- name: InsertUser :one
|
||||
INSERT INTO
|
||||
|
|
|
@ -5,8 +5,10 @@ package database
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
|
@ -45,7 +47,14 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro
|
|||
}
|
||||
|
||||
const getOrganizationByName = `-- name: GetOrganizationByName :one
|
||||
SELECT id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off FROM organizations WHERE name = $1 LIMIT 1
|
||||
SELECT
|
||||
id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off
|
||||
FROM
|
||||
organizations
|
||||
WHERE
|
||||
name = $1
|
||||
LIMIT
|
||||
1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetOrganizationByName(ctx context.Context, name string) (Organization, error) {
|
||||
|
@ -66,10 +75,50 @@ func (q *sqlQuerier) GetOrganizationByName(ctx context.Context, name string) (Or
|
|||
return i, err
|
||||
}
|
||||
|
||||
const getOrganizationMemberByUserID = `-- name: GetOrganizationMemberByUserID :one
|
||||
SELECT
|
||||
organization_id, user_id, created_at, updated_at, roles
|
||||
FROM
|
||||
organization_members
|
||||
WHERE
|
||||
organization_id = $1
|
||||
AND user_id = $2
|
||||
LIMIT
|
||||
1
|
||||
`
|
||||
|
||||
type GetOrganizationMemberByUserIDParams struct {
|
||||
OrganizationID string `db:"organization_id" json:"organization_id"`
|
||||
UserID string `db:"user_id" json:"user_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error) {
|
||||
row := q.db.QueryRowContext(ctx, getOrganizationMemberByUserID, arg.OrganizationID, arg.UserID)
|
||||
var i OrganizationMember
|
||||
err := row.Scan(
|
||||
&i.OrganizationID,
|
||||
&i.UserID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
pq.Array(&i.Roles),
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOrganizationsByUserID = `-- name: GetOrganizationsByUserID :many
|
||||
SELECT id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off FROM organizations WHERE id = (
|
||||
SELECT organization_id FROM organization_members WHERE user_id = $1
|
||||
)
|
||||
SELECT
|
||||
id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off
|
||||
FROM
|
||||
organizations
|
||||
WHERE
|
||||
id = (
|
||||
SELECT
|
||||
organization_id
|
||||
FROM
|
||||
organization_members
|
||||
WHERE
|
||||
user_id = $1
|
||||
)
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetOrganizationsByUserID(ctx context.Context, userID string) ([]Organization, error) {
|
||||
|
@ -106,6 +155,120 @@ func (q *sqlQuerier) GetOrganizationsByUserID(ctx context.Context, userID string
|
|||
return items, nil
|
||||
}
|
||||
|
||||
const getProjectByOrganizationAndName = `-- name: GetProjectByOrganizationAndName :one
|
||||
SELECT
|
||||
id, created_at, updated_at, organization_id, name, provisioner, active_version_id
|
||||
FROM
|
||||
project
|
||||
WHERE
|
||||
organization_id = $1
|
||||
AND name = $2
|
||||
LIMIT
|
||||
1
|
||||
`
|
||||
|
||||
type GetProjectByOrganizationAndNameParams struct {
|
||||
OrganizationID string `db:"organization_id" json:"organization_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetProjectByOrganizationAndName(ctx context.Context, arg GetProjectByOrganizationAndNameParams) (Project, error) {
|
||||
row := q.db.QueryRowContext(ctx, getProjectByOrganizationAndName, arg.OrganizationID, arg.Name)
|
||||
var i Project
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.OrganizationID,
|
||||
&i.Name,
|
||||
&i.Provisioner,
|
||||
&i.ActiveVersionID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getProjectHistoryByProjectID = `-- name: GetProjectHistoryByProjectID :many
|
||||
SELECT
|
||||
id, project_id, created_at, updated_at, name, description, storage_method, storage_source, import_job_id
|
||||
FROM
|
||||
project_history
|
||||
WHERE
|
||||
project_id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetProjectHistoryByProjectID(ctx context.Context, projectID uuid.UUID) ([]ProjectHistory, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getProjectHistoryByProjectID, projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ProjectHistory
|
||||
for rows.Next() {
|
||||
var i ProjectHistory
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ProjectID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.Name,
|
||||
&i.Description,
|
||||
&i.StorageMethod,
|
||||
&i.StorageSource,
|
||||
&i.ImportJobID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getProjectsByOrganizationIDs = `-- name: GetProjectsByOrganizationIDs :many
|
||||
SELECT
|
||||
id, created_at, updated_at, organization_id, name, provisioner, active_version_id
|
||||
FROM
|
||||
project
|
||||
WHERE
|
||||
organization_id = ANY($1 :: text [ ])
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetProjectsByOrganizationIDs(ctx context.Context, ids []string) ([]Project, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getProjectsByOrganizationIDs, pq.Array(ids))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Project
|
||||
for rows.Next() {
|
||||
var i Project
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.OrganizationID,
|
||||
&i.Name,
|
||||
&i.Provisioner,
|
||||
&i.ActiveVersionID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one
|
||||
SELECT
|
||||
id, email, name, revoked, login_type, hashed_password, created_at, updated_at, temporary_password, avatar_hash, ssh_key_regenerated_at, username, dotfiles_git_uri, roles, status, relatime, gpg_key_regenerated_at, _decomissioned, shell
|
||||
|
@ -299,7 +462,10 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (
|
|||
}
|
||||
|
||||
const insertOrganization = `-- name: InsertOrganization :one
|
||||
INSERT INTO organizations (id, name, description, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) RETURNING id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off
|
||||
INSERT INTO
|
||||
organizations (id, name, description, created_at, updated_at)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5) RETURNING id, name, description, created_at, updated_at, "default", auto_off_threshold, cpu_provisioning_rate, memory_provisioning_rate, workspace_auto_off
|
||||
`
|
||||
|
||||
type InsertOrganizationParams struct {
|
||||
|
@ -335,7 +501,16 @@ func (q *sqlQuerier) InsertOrganization(ctx context.Context, arg InsertOrganizat
|
|||
}
|
||||
|
||||
const insertOrganizationMember = `-- name: InsertOrganizationMember :one
|
||||
INSERT INTO organization_members (organization_id, user_id, created_at, updated_at, roles) VALUES ($1, $2, $3, $4, $5) RETURNING organization_id, user_id, created_at, updated_at, roles
|
||||
INSERT INTO
|
||||
organization_members (
|
||||
organization_id,
|
||||
user_id,
|
||||
created_at,
|
||||
updated_at,
|
||||
roles
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5) RETURNING organization_id, user_id, created_at, updated_at, roles
|
||||
`
|
||||
|
||||
type InsertOrganizationMemberParams struct {
|
||||
|
@ -365,6 +540,203 @@ func (q *sqlQuerier) InsertOrganizationMember(ctx context.Context, arg InsertOrg
|
|||
return i, err
|
||||
}
|
||||
|
||||
const insertProject = `-- name: InsertProject :one
|
||||
INSERT INTO
|
||||
project (
|
||||
id,
|
||||
created_at,
|
||||
updated_at,
|
||||
organization_id,
|
||||
name,
|
||||
provisioner
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6) RETURNING id, created_at, updated_at, organization_id, name, provisioner, active_version_id
|
||||
`
|
||||
|
||||
type InsertProjectParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
OrganizationID string `db:"organization_id" json:"organization_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Provisioner ProvisionerType `db:"provisioner" json:"provisioner"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertProject(ctx context.Context, arg InsertProjectParams) (Project, error) {
|
||||
row := q.db.QueryRowContext(ctx, insertProject,
|
||||
arg.ID,
|
||||
arg.CreatedAt,
|
||||
arg.UpdatedAt,
|
||||
arg.OrganizationID,
|
||||
arg.Name,
|
||||
arg.Provisioner,
|
||||
)
|
||||
var i Project
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.OrganizationID,
|
||||
&i.Name,
|
||||
&i.Provisioner,
|
||||
&i.ActiveVersionID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertProjectHistory = `-- name: InsertProjectHistory :one
|
||||
INSERT INTO
|
||||
project_history (
|
||||
id,
|
||||
project_id,
|
||||
created_at,
|
||||
updated_at,
|
||||
name,
|
||||
description,
|
||||
storage_method,
|
||||
storage_source,
|
||||
import_job_id
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, project_id, created_at, updated_at, name, description, storage_method, storage_source, import_job_id
|
||||
`
|
||||
|
||||
type InsertProjectHistoryParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
ProjectID uuid.UUID `db:"project_id" json:"project_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
StorageMethod ProjectStorageMethod `db:"storage_method" json:"storage_method"`
|
||||
StorageSource []byte `db:"storage_source" json:"storage_source"`
|
||||
ImportJobID uuid.UUID `db:"import_job_id" json:"import_job_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertProjectHistory(ctx context.Context, arg InsertProjectHistoryParams) (ProjectHistory, error) {
|
||||
row := q.db.QueryRowContext(ctx, insertProjectHistory,
|
||||
arg.ID,
|
||||
arg.ProjectID,
|
||||
arg.CreatedAt,
|
||||
arg.UpdatedAt,
|
||||
arg.Name,
|
||||
arg.Description,
|
||||
arg.StorageMethod,
|
||||
arg.StorageSource,
|
||||
arg.ImportJobID,
|
||||
)
|
||||
var i ProjectHistory
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.ProjectID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.Name,
|
||||
&i.Description,
|
||||
&i.StorageMethod,
|
||||
&i.StorageSource,
|
||||
&i.ImportJobID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertProjectParameter = `-- name: InsertProjectParameter :one
|
||||
INSERT INTO
|
||||
project_parameter (
|
||||
id,
|
||||
created_at,
|
||||
project_history_id,
|
||||
name,
|
||||
description,
|
||||
default_source,
|
||||
allow_override_source,
|
||||
default_destination,
|
||||
allow_override_destination,
|
||||
default_refresh,
|
||||
redisplay_value,
|
||||
validation_error,
|
||||
validation_condition,
|
||||
validation_type_system,
|
||||
validation_value_type
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8,
|
||||
$9,
|
||||
$10,
|
||||
$11,
|
||||
$12,
|
||||
$13,
|
||||
$14,
|
||||
$15
|
||||
) RETURNING id, created_at, project_history_id, name, description, default_source, allow_override_source, default_destination, allow_override_destination, default_refresh, redisplay_value, validation_error, validation_condition, validation_type_system, validation_value_type
|
||||
`
|
||||
|
||||
type InsertProjectParameterParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
ProjectHistoryID uuid.UUID `db:"project_history_id" json:"project_history_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
DefaultSource sql.NullString `db:"default_source" json:"default_source"`
|
||||
AllowOverrideSource bool `db:"allow_override_source" json:"allow_override_source"`
|
||||
DefaultDestination sql.NullString `db:"default_destination" json:"default_destination"`
|
||||
AllowOverrideDestination bool `db:"allow_override_destination" json:"allow_override_destination"`
|
||||
DefaultRefresh string `db:"default_refresh" json:"default_refresh"`
|
||||
RedisplayValue bool `db:"redisplay_value" json:"redisplay_value"`
|
||||
ValidationError string `db:"validation_error" json:"validation_error"`
|
||||
ValidationCondition string `db:"validation_condition" json:"validation_condition"`
|
||||
ValidationTypeSystem ParameterTypeSystem `db:"validation_type_system" json:"validation_type_system"`
|
||||
ValidationValueType string `db:"validation_value_type" json:"validation_value_type"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertProjectParameter(ctx context.Context, arg InsertProjectParameterParams) (ProjectParameter, error) {
|
||||
row := q.db.QueryRowContext(ctx, insertProjectParameter,
|
||||
arg.ID,
|
||||
arg.CreatedAt,
|
||||
arg.ProjectHistoryID,
|
||||
arg.Name,
|
||||
arg.Description,
|
||||
arg.DefaultSource,
|
||||
arg.AllowOverrideSource,
|
||||
arg.DefaultDestination,
|
||||
arg.AllowOverrideDestination,
|
||||
arg.DefaultRefresh,
|
||||
arg.RedisplayValue,
|
||||
arg.ValidationError,
|
||||
arg.ValidationCondition,
|
||||
arg.ValidationTypeSystem,
|
||||
arg.ValidationValueType,
|
||||
)
|
||||
var i ProjectParameter
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.ProjectHistoryID,
|
||||
&i.Name,
|
||||
&i.Description,
|
||||
&i.DefaultSource,
|
||||
&i.AllowOverrideSource,
|
||||
&i.DefaultDestination,
|
||||
&i.AllowOverrideDestination,
|
||||
&i.DefaultRefresh,
|
||||
&i.RedisplayValue,
|
||||
&i.ValidationError,
|
||||
&i.ValidationCondition,
|
||||
&i.ValidationTypeSystem,
|
||||
&i.ValidationValueType,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertUser = `-- name: InsertUser :one
|
||||
INSERT INTO
|
||||
users (
|
||||
|
|
|
@ -25,4 +25,5 @@ rename:
|
|||
oidc_expiry: OIDCExpiry
|
||||
oidc_id_token: OIDCIDToken
|
||||
oidc_refresh_token: OIDCRefreshToken
|
||||
parameter_type_system_hcl: ParameterTypeSystemHCL
|
||||
userstatus: UserStatus
|
||||
|
|
1
go.mod
1
go.mod
|
@ -18,6 +18,7 @@ require (
|
|||
github.com/hashicorp/terraform-exec v0.15.0
|
||||
github.com/justinas/nosurf v1.1.1
|
||||
github.com/lib/pq v1.10.4
|
||||
github.com/moby/moby v20.10.12+incompatible
|
||||
github.com/ory/dockertest/v3 v3.8.1
|
||||
github.com/pion/datachannel v1.5.2
|
||||
github.com/pion/logging v0.2.2
|
||||
|
|
2
go.sum
2
go.sum
|
@ -903,6 +903,8 @@ github.com/mitchellh/osext v0.0.0-20151018003038-5e2d6d41470f/go.mod h1:OkQIRizQ
|
|||
github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
|
||||
github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
|
||||
github.com/moby/locker v1.0.1/go.mod h1:S7SDdo5zpBK84bzzVlKr2V0hz+7x9hWbYC/kq7oQppc=
|
||||
github.com/moby/moby v20.10.12+incompatible h1:MJVrdG0tIQqVJQBTdtooPuZQFIgski5pYTXlcW8ToE0=
|
||||
github.com/moby/moby v20.10.12+incompatible/go.mod h1:fDXVQ6+S340veQPv35CzDahGBmHsiclFwfEygB/TWMc=
|
||||
github.com/moby/sys/mountinfo v0.4.0/go.mod h1:rEr8tzG/lsIZHBtN/JjGG+LMYx9eXgW2JI+6q0qou+A=
|
||||
github.com/moby/sys/mountinfo v0.4.1/go.mod h1:rEr8tzG/lsIZHBtN/JjGG+LMYx9eXgW2JI+6q0qou+A=
|
||||
github.com/moby/sys/symlink v0.1.0/go.mod h1:GGDODQmbFOjFsXvfLVn3+ZRxkch54RkSiGqsZeMYowQ=
|
||||
|
|
|
@ -13,7 +13,9 @@ import (
|
|||
)
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("NoErrors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
rw := httptest.NewRecorder()
|
||||
httpapi.Write(rw, http.StatusOK, httpapi.Response{
|
||||
Message: "wow",
|
||||
|
@ -27,7 +29,9 @@ func TestWrite(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("EmptyStruct", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
rw := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", bytes.NewBufferString("{}"))
|
||||
v := struct{}{}
|
||||
|
@ -35,6 +39,7 @@ func TestRead(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("NoBody", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
rw := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/", nil)
|
||||
var v json.RawMessage
|
||||
|
@ -42,6 +47,7 @@ func TestRead(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Validate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
type toValidate struct {
|
||||
Value string `json:"value" validate:"required"`
|
||||
}
|
||||
|
@ -54,6 +60,7 @@ func TestRead(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("ValidateFailure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
type toValidate struct {
|
||||
Value string `json:"value" validate:"required"`
|
||||
}
|
||||
|
@ -72,6 +79,7 @@ func TestRead(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestReadUsername(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Tests whether usernames are valid or not.
|
||||
testCases := []struct {
|
||||
Username string
|
||||
|
@ -121,7 +129,9 @@ func TestReadUsername(t *testing.T) {
|
|||
Username string `json:"username" validate:"username"`
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testCase.Username, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
rw := httptest.NewRecorder()
|
||||
data, err := json.Marshal(toValidate{testCase.Username})
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -35,6 +35,7 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("NoCookie", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
|
@ -47,6 +48,7 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("InvalidFormat", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
|
@ -64,6 +66,7 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("InvalidIDLength", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
|
@ -81,6 +84,7 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("InvalidSecretLength", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
|
@ -98,6 +102,7 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
id, secret = randomAPIKeyParts()
|
||||
|
@ -116,6 +121,7 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("InvalidSecret", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
id, secret = randomAPIKeyParts()
|
||||
|
@ -141,6 +147,7 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Expired", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
id, secret = randomAPIKeyParts()
|
||||
|
@ -165,6 +172,7 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Valid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
id, secret = randomAPIKeyParts()
|
||||
|
@ -203,6 +211,7 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("ValidUpdateLastUsed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
id, secret = randomAPIKeyParts()
|
||||
|
@ -235,6 +244,7 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("ValidUpdateExpiry", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
id, secret = randomAPIKeyParts()
|
||||
|
@ -267,6 +277,7 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("OIDCNotExpired", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
id, secret = randomAPIKeyParts()
|
||||
|
@ -300,6 +311,7 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("OIDCRefresh", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
id, secret = randomAPIKeyParts()
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
)
|
||||
|
||||
type organizationParamContextKey struct{}
|
||||
type organizationMemberParamContextKey struct{}
|
||||
|
||||
// OrganizationParam returns the organization from the ExtractOrganizationParam handler.
|
||||
func OrganizationParam(r *http.Request) database.Organization {
|
||||
organization, ok := r.Context().Value(organizationParamContextKey{}).(database.Organization)
|
||||
if !ok {
|
||||
panic("developer error: organization param middleware not provided")
|
||||
}
|
||||
return organization
|
||||
}
|
||||
|
||||
// OrganizationMemberParam returns the organization membership that allowed the query
|
||||
// from the ExtractOrganizationParam handler.
|
||||
func OrganizationMemberParam(r *http.Request) database.OrganizationMember {
|
||||
organizationMember, ok := r.Context().Value(organizationMemberParamContextKey{}).(database.OrganizationMember)
|
||||
if !ok {
|
||||
panic("developer error: organization param middleware not provided")
|
||||
}
|
||||
return organizationMember
|
||||
}
|
||||
|
||||
// ExtractOrganizationParam grabs an organization and user membership from the "organization" URL parameter.
|
||||
// This middleware requires the API key middleware higher in the call stack for authentication.
|
||||
func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
apiKey := APIKey(r)
|
||||
organizationName := chi.URLParam(r, "organization")
|
||||
if organizationName == "" {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: "organization name must be provided",
|
||||
})
|
||||
return
|
||||
}
|
||||
organization, err := db.GetOrganizationByName(r.Context(), organizationName)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
|
||||
Message: fmt.Sprintf("organization %q does not exist", organizationName),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get organization: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
organizationMember, err := db.GetOrganizationMemberByUserID(r.Context(), database.GetOrganizationMemberByUserIDParams{
|
||||
OrganizationID: organization.ID,
|
||||
UserID: apiKey.UserID,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
Message: "not a member of the organization",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get organization member: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), organizationParamContextKey{}, organization)
|
||||
ctx = context.WithValue(ctx, organizationMemberParamContextKey{}, organizationMember)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,165 @@
|
|||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/database/databasefake"
|
||||
"github.com/coder/coder/httpmw"
|
||||
)
|
||||
|
||||
func TestOrganizationParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupAuthentication := func(db database.Store) (*http.Request, database.User) {
|
||||
var (
|
||||
id, secret = randomAPIKeyParts()
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
hashed = sha256.Sum256([]byte(secret))
|
||||
)
|
||||
r.AddCookie(&http.Cookie{
|
||||
Name: httpmw.AuthCookie,
|
||||
Value: fmt.Sprintf("%s-%s", id, secret),
|
||||
})
|
||||
userID, err := cryptorand.String(16)
|
||||
require.NoError(t, err)
|
||||
username, err := cryptorand.String(8)
|
||||
require.NoError(t, err)
|
||||
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
|
||||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: id,
|
||||
UserID: user.ID,
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
|
||||
return r, user
|
||||
}
|
||||
|
||||
t.Run("None", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
rw = httptest.NewRecorder()
|
||||
r, _ = setupAuthentication(db)
|
||||
rtr = chi.NewRouter()
|
||||
)
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil),
|
||||
httpmw.ExtractOrganizationParam(db),
|
||||
)
|
||||
rtr.Get("/", nil)
|
||||
rtr.ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusBadRequest, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
rw = httptest.NewRecorder()
|
||||
r, _ = setupAuthentication(db)
|
||||
rtr = chi.NewRouter()
|
||||
)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("organization", "nothin")
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil),
|
||||
httpmw.ExtractOrganizationParam(db),
|
||||
)
|
||||
rtr.Get("/", nil)
|
||||
rtr.ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusNotFound, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("NotInOrganization", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
rw = httptest.NewRecorder()
|
||||
r, _ = setupAuthentication(db)
|
||||
rtr = chi.NewRouter()
|
||||
)
|
||||
organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
|
||||
ID: uuid.NewString(),
|
||||
Name: "test",
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("organization", organization.Name)
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil),
|
||||
httpmw.ExtractOrganizationParam(db),
|
||||
)
|
||||
rtr.Get("/", nil)
|
||||
rtr.ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
rw = httptest.NewRecorder()
|
||||
r, user = setupAuthentication(db)
|
||||
rtr = chi.NewRouter()
|
||||
)
|
||||
organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
|
||||
ID: uuid.NewString(),
|
||||
Name: "test",
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
|
||||
OrganizationID: organization.ID,
|
||||
UserID: user.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("organization", organization.Name)
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil),
|
||||
httpmw.ExtractOrganizationParam(db),
|
||||
)
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
_ = httpmw.OrganizationParam(r)
|
||||
_ = httpmw.OrganizationMemberParam(r)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
rtr.ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
)
|
||||
|
||||
type projectParamContextKey struct{}
|
||||
|
||||
// ProjectParam returns the project from the ExtractProjectParameter handler.
|
||||
func ProjectParam(r *http.Request) database.Project {
|
||||
project, ok := r.Context().Value(projectParamContextKey{}).(database.Project)
|
||||
if !ok {
|
||||
panic("developer error: project param middleware not provided")
|
||||
}
|
||||
return project
|
||||
}
|
||||
|
||||
// ExtractProjectParameter grabs a project from the "project" URL parameter.
|
||||
func ExtractProjectParameter(db database.Store) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
organization := OrganizationParam(r)
|
||||
projectName := chi.URLParam(r, "project")
|
||||
if projectName == "" {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: "project name must be provided",
|
||||
})
|
||||
return
|
||||
}
|
||||
project, err := db.GetProjectByOrganizationAndName(r.Context(), database.GetProjectByOrganizationAndNameParams{
|
||||
OrganizationID: organization.ID,
|
||||
Name: projectName,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
|
||||
Message: fmt.Sprintf("project %q does not exist", projectName),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get project: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), projectParamContextKey{}, project)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,151 @@
|
|||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/database/databasefake"
|
||||
"github.com/coder/coder/httpmw"
|
||||
)
|
||||
|
||||
func TestProjectParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupAuthentication := func(db database.Store) (*http.Request, database.Organization) {
|
||||
var (
|
||||
id, secret = randomAPIKeyParts()
|
||||
hashed = sha256.Sum256([]byte(secret))
|
||||
)
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
r.AddCookie(&http.Cookie{
|
||||
Name: httpmw.AuthCookie,
|
||||
Value: fmt.Sprintf("%s-%s", id, secret),
|
||||
})
|
||||
userID, err := cryptorand.String(16)
|
||||
require.NoError(t, err)
|
||||
username, err := cryptorand.String(8)
|
||||
require.NoError(t, err)
|
||||
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
|
||||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: id,
|
||||
UserID: user.ID,
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
orgID, err := cryptorand.String(16)
|
||||
require.NoError(t, err)
|
||||
organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
|
||||
ID: orgID,
|
||||
Name: "banana",
|
||||
Description: "wowie",
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
|
||||
OrganizationID: orgID,
|
||||
UserID: user.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := chi.NewRouteContext()
|
||||
ctx.URLParams.Add("organization", organization.Name)
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
|
||||
return r, organization
|
||||
}
|
||||
|
||||
t.Run("None", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil),
|
||||
httpmw.ExtractOrganizationParam(db),
|
||||
httpmw.ExtractProjectParameter(db),
|
||||
)
|
||||
rtr.Get("/", nil)
|
||||
r, _ := setupAuthentication(db)
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusBadRequest, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil),
|
||||
httpmw.ExtractOrganizationParam(db),
|
||||
httpmw.ExtractProjectParameter(db),
|
||||
)
|
||||
rtr.Get("/", nil)
|
||||
|
||||
r, _ := setupAuthentication(db)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("project", "nothin")
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusNotFound, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("Project", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil),
|
||||
httpmw.ExtractOrganizationParam(db),
|
||||
httpmw.ExtractProjectParameter(db),
|
||||
)
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
_ = httpmw.ProjectParam(r)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
r, org := setupAuthentication(db)
|
||||
project, err := db.InsertProject(context.Background(), database.InsertProjectParams{
|
||||
ID: uuid.New(),
|
||||
OrganizationID: org.ID,
|
||||
Name: "moo",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("project", project.Name)
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
})
|
||||
}
|
|
@ -18,6 +18,7 @@ import (
|
|||
)
|
||||
|
||||
func TestUserParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
setup := func(t *testing.T) (database.Store, *httptest.ResponseRecorder, *http.Request) {
|
||||
var (
|
||||
db = databasefake.New()
|
||||
|
@ -47,6 +48,7 @@ func TestUserParam(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run("None", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, rw, r := setup(t)
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
|
||||
|
@ -62,6 +64,7 @@ func TestUserParam(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("NotMe", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, rw, r := setup(t)
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
|
||||
|
@ -80,6 +83,7 @@ func TestUserParam(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Me", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, rw, r := setup(t)
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
|
||||
|
|
|
@ -211,6 +211,7 @@ func TestConn(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("CloseWithError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := peer.Client([]webrtc.ICEServer{}, nil)
|
||||
require.NoError(t, err)
|
||||
expectedErr := errors.New("wow")
|
||||
|
|
|
@ -14,9 +14,11 @@ import (
|
|||
)
|
||||
|
||||
func TestListen(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Ensures connections blocked on Accept() are
|
||||
// closed if the listener is.
|
||||
t.Run("NoAcceptClosed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
client, server := provisionersdk.TransportPipe()
|
||||
defer client.Close()
|
||||
|
@ -37,6 +39,7 @@ func TestListen(t *testing.T) {
|
|||
|
||||
// Ensures Accept() properly exits when Close() is called.
|
||||
t.Run("AcceptClosed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server := provisionersdk.TransportPipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
|
|
@ -18,7 +18,9 @@ func TestMain(m *testing.M) {
|
|||
}
|
||||
|
||||
func TestProvisionerSDK(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Serve", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server := provisionersdk.TransportPipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
@ -37,6 +39,7 @@ func TestProvisionerSDK(t *testing.T) {
|
|||
require.Equal(t, drpcerr.Unimplemented, int(drpcerr.Code(err)))
|
||||
})
|
||||
t.Run("ServeClosedPipe", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server := provisionersdk.TransportPipe()
|
||||
_ = client.Close()
|
||||
_ = server.Close()
|
||||
|
|
Loading…
Reference in New Issue