From 5c6974e55ff5aa3a1110de95954924887e583c2a Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 4 Mar 2024 15:15:41 -0600 Subject: [PATCH] feat: implement provisioner auth middleware and proper org params (#12330) * feat: provisioner auth in mw to allow ExtractOrg Step to enable org scoped provisioner daemons * chore: handle default org handling for provisioner daemons --- coderd/database/dbauthz/dbauthz.go | 3 + coderd/httpmw/actor.go | 29 +++++++ coderd/httpmw/organizationparam.go | 33 ++++++-- coderd/httpmw/organizationparam_test.go | 19 +++++ coderd/httpmw/provisionerdaemon.go | 86 ++++++++++++++++++++ coderd/organizations.go | 7 ++ codersdk/organizations.go | 3 + codersdk/provisionerdaemons.go | 11 ++- enterprise/coderd/coderd.go | 9 ++ enterprise/coderd/provisionerdaemons.go | 20 ++--- enterprise/coderd/provisionerdaemons_test.go | 11 +-- 11 files changed, 201 insertions(+), 30 deletions(-) create mode 100644 coderd/httpmw/provisionerdaemon.go diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 135703bb0b..32f2761f91 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -170,6 +170,9 @@ var ( rbac.ResourceWorkspaceBuild.Type: {rbac.ActionRead, rbac.ActionUpdate, rbac.ActionDelete}, rbac.ResourceUserData.Type: {rbac.ActionRead, rbac.ActionUpdate}, rbac.ResourceAPIKey.Type: {rbac.WildcardSymbol}, + // When org scoped provisioner credentials are implemented, + // this can be reduced to read a specific org. + rbac.ResourceOrganization.Type: {rbac.ActionRead}, }), Org: map[string][]rbac.Permission{}, User: []rbac.Permission{}, diff --git a/coderd/httpmw/actor.go b/coderd/httpmw/actor.go index af3142aed2..59eb1cf907 100644 --- a/coderd/httpmw/actor.go +++ b/coderd/httpmw/actor.go @@ -64,3 +64,32 @@ func RequireAPIKeyOrWorkspaceAgent() func(http.Handler) http.Handler { }) } } + +// RequireAPIKeyOrProvisionerDaemonAuth is middleware that should be inserted +// after optional ExtractAPIKey and ExtractProvisionerDaemonAuthenticated +// middlewares to ensure one of the two authentication methods is provided. +// +// If both are provided, an error is returned to avoid misuse. +func RequireAPIKeyOrProvisionerDaemonAuth() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, hasAPIKey := APIKeyOptional(r) + hasProvisionerDaemon := ProvisionerDaemonAuthenticated(r) + + if hasAPIKey && hasProvisionerDaemon { + httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{ + Message: "API key and external provisioner authentication provided, but only one is allowed", + }) + return + } + if !hasAPIKey && !hasProvisionerDaemon { + httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{ + Message: "API key or external provisioner authentication required, but none provided", + }) + return + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/coderd/httpmw/organizationparam.go b/coderd/httpmw/organizationparam.go index c219751e2b..0c8ccae96c 100644 --- a/coderd/httpmw/organizationparam.go +++ b/coderd/httpmw/organizationparam.go @@ -53,15 +53,30 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler } var organization database.Organization - var err error - // Try by name or uuid. - id, err := uuid.Parse(arg) - if err == nil { - organization, err = db.GetOrganizationByID(ctx, id) + var dbErr error + + // If the name is exactly "default", then we fetch the default + // organization. This is a special case to make it easier + // for single org deployments. + // + // arg == uuid.Nil.String() should be a temporary workaround for + // legacy provisioners that don't provide an organization ID. + // This prevents a breaking change. + // TODO: This change was added March 2024. Nil uuid returning the + // default org should be removed some number of months after + // that date. + if arg == codersdk.DefaultOrganization || arg == uuid.Nil.String() { + organization, dbErr = db.GetDefaultOrganization(ctx) } else { - organization, err = db.GetOrganizationByName(ctx, arg) + // Try by name or uuid. + id, err := uuid.Parse(arg) + if err == nil { + organization, dbErr = db.GetOrganizationByID(ctx, id) + } else { + organization, dbErr = db.GetOrganizationByName(ctx, arg) + } } - if httpapi.Is404Error(err) { + if httpapi.Is404Error(dbErr) { httpapi.ResourceNotFound(rw) httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ Message: fmt.Sprintf("Organization %q not found.", arg), @@ -69,10 +84,10 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler }) return } - if err != nil { + if dbErr != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: fmt.Sprintf("Internal error fetching organization %q.", arg), - Detail: err.Error(), + Detail: dbErr.Error(), }) return } diff --git a/coderd/httpmw/organizationparam_test.go b/coderd/httpmw/organizationparam_test.go index e5415d1348..02b7ce1e14 100644 --- a/coderd/httpmw/organizationparam_test.go +++ b/coderd/httpmw/organizationparam_test.go @@ -208,5 +208,24 @@ func TestOrganizationParam(t *testing.T) { res = rw.Result() defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode, "by name") + + // Try by 'default' + chi.RouteContext(r.Context()).URLParams.Add("organization", codersdk.DefaultOrganization) + chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String()) + rtr.ServeHTTP(rw, r) + res = rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode, "by default keyword") + + // Try by legacy + // TODO: This can be removed when legacy nil uuids are no longer supported. + // This is a temporary measure to ensure as legacy provisioners use + // nil uuids as the org id and expect the default org. + chi.RouteContext(r.Context()).URLParams.Add("organization", uuid.Nil.String()) + chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String()) + rtr.ServeHTTP(rw, r) + res = rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode, "by nil uuid (legacy)") }) } diff --git a/coderd/httpmw/provisionerdaemon.go b/coderd/httpmw/provisionerdaemon.go new file mode 100644 index 0000000000..84e335751f --- /dev/null +++ b/coderd/httpmw/provisionerdaemon.go @@ -0,0 +1,86 @@ +package httpmw + +import ( + "context" + "crypto/subtle" + "net/http" + + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/codersdk" +) + +type provisionerDaemonContextKey struct{} + +func ProvisionerDaemonAuthenticated(r *http.Request) bool { + proxy, ok := r.Context().Value(provisionerDaemonContextKey{}).(bool) + return ok && proxy +} + +type ExtractProvisionerAuthConfig struct { + DB database.Store + Optional bool +} + +func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig, psk string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + handleOptional := func(code int, response codersdk.Response) { + if opts.Optional { + next.ServeHTTP(w, r) + return + } + httpapi.Write(ctx, w, code, response) + } + + if psk == "" { + // No psk means external provisioner daemons are not allowed. + // So their auth is not valid. + handleOptional(http.StatusBadRequest, codersdk.Response{ + Message: "External provisioner daemons not enabled", + }) + return + } + + token := r.Header.Get(codersdk.ProvisionerDaemonPSK) + if token == "" { + handleOptional(http.StatusUnauthorized, codersdk.Response{ + Message: "provisioner daemon auth token required", + }) + return + } + + if subtle.ConstantTimeCompare([]byte(token), []byte(psk)) != 1 { + handleOptional(http.StatusUnauthorized, codersdk.Response{ + Message: "provisioner daemon auth token invalid", + }) + return + } + + // The PSK does not indicate a specific provisioner daemon. So just + // store a boolean so the caller can check if the request is from an + // authenticated provisioner daemon. + ctx = context.WithValue(ctx, provisionerDaemonContextKey{}, true) + // nolint:gocritic // Authenticating as a provisioner daemon. + ctx = dbauthz.AsProvisionerd(ctx) + subj, ok := dbauthz.ActorFromContext(ctx) + if !ok { + // This should never happen + httpapi.InternalServerError(w, xerrors.New("developer error: ExtractProvisionerDaemonAuth missing rbac actor")) + } + + // Use the same subject for the userAuthKey + ctx = context.WithValue(ctx, userAuthKey{}, Authorization{ + Actor: subj, + ActorName: "provisioner_daemon", + }) + + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/coderd/organizations.go b/coderd/organizations.go index d50c0a4e25..e5098a9697 100644 --- a/coderd/organizations.go +++ b/coderd/organizations.go @@ -50,6 +50,13 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) { return } + if req.Name == codersdk.DefaultOrganization { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Organization name %q is reserved.", codersdk.DefaultOrganization), + }) + return + } + _, err := api.Database.GetOrganizationByName(ctx, req.Name) if err == nil { httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{ diff --git a/codersdk/organizations.go b/codersdk/organizations.go index 55e2a6b1ab..88a710f642 100644 --- a/codersdk/organizations.go +++ b/codersdk/organizations.go @@ -11,6 +11,9 @@ import ( "golang.org/x/xerrors" ) +// DefaultOrganization is used as a replacement for the default organization. +var DefaultOrganization = "default" + type ProvisionerStorageMethod string const ( diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index fbfa800cfd..300e24b64e 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -179,8 +179,8 @@ type ServeProvisionerDaemonRequest struct { ID uuid.UUID `json:"id" format:"uuid"` // Name is the human-readable unique identifier for the daemon. Name string `json:"name" example:"my-cool-provisioner-daemon"` - // Organization is the organization for the URL. At present provisioner daemons ARE NOT scoped to organizations - // and so the organization ID is optional. + // Organization is the organization for the URL. If no orgID is provided, + // then it is assumed to use the default organization. Organization uuid.UUID `json:"organization" format:"uuid"` // Provisioners is a list of provisioner types hosted by the provisioner daemon Provisioners []ProvisionerType `json:"provisioners"` @@ -194,7 +194,12 @@ type ServeProvisionerDaemonRequest struct { // implementation. The context is during dial, not during the lifetime of the // client. Client should be closed after use. func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisionerDaemonRequest) (proto.DRPCProvisionerDaemonClient, error) { - serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", req.Organization)) + orgParam := req.Organization.String() + if req.Organization == uuid.Nil { + orgParam = DefaultOrganization + } + + serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", orgParam)) if err != nil { return nil, xerrors.Errorf("parse url: %w", err) } diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 8c22ea0f0b..16da0453a5 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -292,6 +292,15 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { r.Route("/organizations/{organization}/provisionerdaemons", func(r chi.Router) { r.Use( api.provisionerDaemonsEnabledMW, + apiKeyMiddlewareOptional, + httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ + DB: api.Database, + Optional: true, + }, api.ProvisionerDaemonPSK), + // Either a user auth or provisioner auth is required + // to move forward. + httpmw.RequireAPIKeyOrProvisionerDaemonAuth(), + httpmw.ExtractOrganizationParam(api.Database), ) r.With(apiKeyMiddleware).Get("/", api.provisionerDaemons) r.With(apiKeyMiddlewareOptional).Get("/serve", api.provisionerDaemonServe) diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index 709cce3fa6..7c89722974 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -2,7 +2,6 @@ package coderd import ( "context" - "crypto/subtle" "database/sql" "errors" "fmt" @@ -86,11 +85,8 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) { }) return } - apiDaemons := make([]codersdk.ProvisionerDaemon, 0) - for _, daemon := range daemons { - apiDaemons = append(apiDaemons, db2sdk.ProvisionerDaemon(daemon)) - } - httpapi.Write(ctx, rw, http.StatusOK, apiDaemons) + + httpapi.Write(ctx, rw, http.StatusOK, db2sdk.List(daemons, db2sdk.ProvisionerDaemon)) } type provisionerDaemonAuth struct { @@ -118,13 +114,11 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin } // Check for PSK - if p.psk != "" { - psk := r.Header.Get(codersdk.ProvisionerDaemonPSK) - if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 { - // If using PSK auth, the daemon is, by definition, scoped to the organization. - tags = provisionersdk.MutateTags(uuid.Nil, tags) - return tags, true - } + provAuth := httpmw.ProvisionerDaemonAuthenticated(r) + if provAuth { + // If using PSK auth, the daemon is, by definition, scoped to the organization. + tags = provisionersdk.MutateTags(uuid.Nil, tags) + return tags, true } return nil, false } diff --git a/enterprise/coderd/provisionerdaemons_test.go b/enterprise/coderd/provisionerdaemons_test.go index 362aaee427..caa65c8850 100644 --- a/enterprise/coderd/provisionerdaemons_test.go +++ b/enterprise/coderd/provisionerdaemons_test.go @@ -350,6 +350,7 @@ func TestProvisionerDaemonServe(t *testing.T) { t.Run("PSK_daily_cost", func(t *testing.T) { t.Parallel() + const provPSK = `provisionersftw` client, user := coderdenttest.New(t, &coderdenttest.Options{ UserWorkspaceQuota: 10, LicenseOptions: &coderdenttest.LicenseOptions{ @@ -358,7 +359,7 @@ func TestProvisionerDaemonServe(t *testing.T) { codersdk.FeatureTemplateRBAC: 1, }, }, - ProvisionerDaemonPSK: "provisionersftw", + ProvisionerDaemonPSK: provPSK, }) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) @@ -397,7 +398,7 @@ func TestProvisionerDaemonServe(t *testing.T) { Tags: map[string]string{ provisionersdk.TagScope: provisionersdk.ScopeOrganization, }, - PreSharedKey: "provisionersftw", + PreSharedKey: provPSK, }) }, &provisionerd.Options{ Logger: logger.Named("provisionerd"), @@ -480,7 +481,7 @@ func TestProvisionerDaemonServe(t *testing.T) { require.Error(t, err) var apiError *codersdk.Error require.ErrorAs(t, err, &apiError) - require.Equal(t, http.StatusForbidden, apiError.StatusCode()) + require.Equal(t, http.StatusUnauthorized, apiError.StatusCode()) daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion. require.NoError(t, err) @@ -514,7 +515,7 @@ func TestProvisionerDaemonServe(t *testing.T) { require.Error(t, err) var apiError *codersdk.Error require.ErrorAs(t, err, &apiError) - require.Equal(t, http.StatusForbidden, apiError.StatusCode()) + require.Equal(t, http.StatusUnauthorized, apiError.StatusCode()) daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion. require.NoError(t, err) @@ -548,7 +549,7 @@ func TestProvisionerDaemonServe(t *testing.T) { require.Error(t, err) var apiError *codersdk.Error require.ErrorAs(t, err, &apiError) - require.Equal(t, http.StatusForbidden, apiError.StatusCode()) + require.Equal(t, http.StatusUnauthorized, apiError.StatusCode()) daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion. require.NoError(t, err)