feat: implement provisioner auth middleware and proper org params (#12330)

* feat: provisioner auth in mw to allow ExtractOrg

Step to enable org scoped provisioner daemons

* chore: handle default org handling for provisioner daemons
This commit is contained in:
Steven Masley 2024-03-04 15:15:41 -06:00 committed by GitHub
parent 926fd7ffa6
commit 5c6974e55f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 201 additions and 30 deletions

View File

@ -170,6 +170,9 @@ var (
rbac.ResourceWorkspaceBuild.Type: {rbac.ActionRead, rbac.ActionUpdate, rbac.ActionDelete}, rbac.ResourceWorkspaceBuild.Type: {rbac.ActionRead, rbac.ActionUpdate, rbac.ActionDelete},
rbac.ResourceUserData.Type: {rbac.ActionRead, rbac.ActionUpdate}, rbac.ResourceUserData.Type: {rbac.ActionRead, rbac.ActionUpdate},
rbac.ResourceAPIKey.Type: {rbac.WildcardSymbol}, rbac.ResourceAPIKey.Type: {rbac.WildcardSymbol},
// When org scoped provisioner credentials are implemented,
// this can be reduced to read a specific org.
rbac.ResourceOrganization.Type: {rbac.ActionRead},
}), }),
Org: map[string][]rbac.Permission{}, Org: map[string][]rbac.Permission{},
User: []rbac.Permission{}, User: []rbac.Permission{},

View File

@ -64,3 +64,32 @@ func RequireAPIKeyOrWorkspaceAgent() func(http.Handler) http.Handler {
}) })
} }
} }
// RequireAPIKeyOrProvisionerDaemonAuth is middleware that should be inserted
// after optional ExtractAPIKey and ExtractProvisionerDaemonAuthenticated
// middlewares to ensure one of the two authentication methods is provided.
//
// If both are provided, an error is returned to avoid misuse.
func RequireAPIKeyOrProvisionerDaemonAuth() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, hasAPIKey := APIKeyOptional(r)
hasProvisionerDaemon := ProvisionerDaemonAuthenticated(r)
if hasAPIKey && hasProvisionerDaemon {
httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{
Message: "API key and external provisioner authentication provided, but only one is allowed",
})
return
}
if !hasAPIKey && !hasProvisionerDaemon {
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
Message: "API key or external provisioner authentication required, but none provided",
})
return
}
next.ServeHTTP(w, r)
})
}
}

View File

@ -53,15 +53,30 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler
} }
var organization database.Organization var organization database.Organization
var err error var dbErr error
// Try by name or uuid.
id, err := uuid.Parse(arg) // If the name is exactly "default", then we fetch the default
if err == nil { // organization. This is a special case to make it easier
organization, err = db.GetOrganizationByID(ctx, id) // for single org deployments.
//
// arg == uuid.Nil.String() should be a temporary workaround for
// legacy provisioners that don't provide an organization ID.
// This prevents a breaking change.
// TODO: This change was added March 2024. Nil uuid returning the
// default org should be removed some number of months after
// that date.
if arg == codersdk.DefaultOrganization || arg == uuid.Nil.String() {
organization, dbErr = db.GetDefaultOrganization(ctx)
} else { } else {
organization, err = db.GetOrganizationByName(ctx, arg) // Try by name or uuid.
id, err := uuid.Parse(arg)
if err == nil {
organization, dbErr = db.GetOrganizationByID(ctx, id)
} else {
organization, dbErr = db.GetOrganizationByName(ctx, arg)
}
} }
if httpapi.Is404Error(err) { if httpapi.Is404Error(dbErr) {
httpapi.ResourceNotFound(rw) httpapi.ResourceNotFound(rw)
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Organization %q not found.", arg), Message: fmt.Sprintf("Organization %q not found.", arg),
@ -69,10 +84,10 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler
}) })
return return
} }
if err != nil { if dbErr != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: fmt.Sprintf("Internal error fetching organization %q.", arg), Message: fmt.Sprintf("Internal error fetching organization %q.", arg),
Detail: err.Error(), Detail: dbErr.Error(),
}) })
return return
} }

View File

@ -208,5 +208,24 @@ func TestOrganizationParam(t *testing.T) {
res = rw.Result() res = rw.Result()
defer res.Body.Close() defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode, "by name") require.Equal(t, http.StatusOK, res.StatusCode, "by name")
// Try by 'default'
chi.RouteContext(r.Context()).URLParams.Add("organization", codersdk.DefaultOrganization)
chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String())
rtr.ServeHTTP(rw, r)
res = rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode, "by default keyword")
// Try by legacy
// TODO: This can be removed when legacy nil uuids are no longer supported.
// This is a temporary measure to ensure as legacy provisioners use
// nil uuids as the org id and expect the default org.
chi.RouteContext(r.Context()).URLParams.Add("organization", uuid.Nil.String())
chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String())
rtr.ServeHTTP(rw, r)
res = rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode, "by nil uuid (legacy)")
}) })
} }

View File

@ -0,0 +1,86 @@
package httpmw
import (
"context"
"crypto/subtle"
"net/http"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
)
type provisionerDaemonContextKey struct{}
func ProvisionerDaemonAuthenticated(r *http.Request) bool {
proxy, ok := r.Context().Value(provisionerDaemonContextKey{}).(bool)
return ok && proxy
}
type ExtractProvisionerAuthConfig struct {
DB database.Store
Optional bool
}
func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig, psk string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
handleOptional := func(code int, response codersdk.Response) {
if opts.Optional {
next.ServeHTTP(w, r)
return
}
httpapi.Write(ctx, w, code, response)
}
if psk == "" {
// No psk means external provisioner daemons are not allowed.
// So their auth is not valid.
handleOptional(http.StatusBadRequest, codersdk.Response{
Message: "External provisioner daemons not enabled",
})
return
}
token := r.Header.Get(codersdk.ProvisionerDaemonPSK)
if token == "" {
handleOptional(http.StatusUnauthorized, codersdk.Response{
Message: "provisioner daemon auth token required",
})
return
}
if subtle.ConstantTimeCompare([]byte(token), []byte(psk)) != 1 {
handleOptional(http.StatusUnauthorized, codersdk.Response{
Message: "provisioner daemon auth token invalid",
})
return
}
// The PSK does not indicate a specific provisioner daemon. So just
// store a boolean so the caller can check if the request is from an
// authenticated provisioner daemon.
ctx = context.WithValue(ctx, provisionerDaemonContextKey{}, true)
// nolint:gocritic // Authenticating as a provisioner daemon.
ctx = dbauthz.AsProvisionerd(ctx)
subj, ok := dbauthz.ActorFromContext(ctx)
if !ok {
// This should never happen
httpapi.InternalServerError(w, xerrors.New("developer error: ExtractProvisionerDaemonAuth missing rbac actor"))
}
// Use the same subject for the userAuthKey
ctx = context.WithValue(ctx, userAuthKey{}, Authorization{
Actor: subj,
ActorName: "provisioner_daemon",
})
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

View File

@ -50,6 +50,13 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
return return
} }
if req.Name == codersdk.DefaultOrganization {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Organization name %q is reserved.", codersdk.DefaultOrganization),
})
return
}
_, err := api.Database.GetOrganizationByName(ctx, req.Name) _, err := api.Database.GetOrganizationByName(ctx, req.Name)
if err == nil { if err == nil {
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{

View File

@ -11,6 +11,9 @@ import (
"golang.org/x/xerrors" "golang.org/x/xerrors"
) )
// DefaultOrganization is used as a replacement for the default organization.
var DefaultOrganization = "default"
type ProvisionerStorageMethod string type ProvisionerStorageMethod string
const ( const (

View File

@ -179,8 +179,8 @@ type ServeProvisionerDaemonRequest struct {
ID uuid.UUID `json:"id" format:"uuid"` ID uuid.UUID `json:"id" format:"uuid"`
// Name is the human-readable unique identifier for the daemon. // Name is the human-readable unique identifier for the daemon.
Name string `json:"name" example:"my-cool-provisioner-daemon"` Name string `json:"name" example:"my-cool-provisioner-daemon"`
// Organization is the organization for the URL. At present provisioner daemons ARE NOT scoped to organizations // Organization is the organization for the URL. If no orgID is provided,
// and so the organization ID is optional. // then it is assumed to use the default organization.
Organization uuid.UUID `json:"organization" format:"uuid"` Organization uuid.UUID `json:"organization" format:"uuid"`
// Provisioners is a list of provisioner types hosted by the provisioner daemon // Provisioners is a list of provisioner types hosted by the provisioner daemon
Provisioners []ProvisionerType `json:"provisioners"` Provisioners []ProvisionerType `json:"provisioners"`
@ -194,7 +194,12 @@ type ServeProvisionerDaemonRequest struct {
// implementation. The context is during dial, not during the lifetime of the // implementation. The context is during dial, not during the lifetime of the
// client. Client should be closed after use. // client. Client should be closed after use.
func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisionerDaemonRequest) (proto.DRPCProvisionerDaemonClient, error) { func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisionerDaemonRequest) (proto.DRPCProvisionerDaemonClient, error) {
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", req.Organization)) orgParam := req.Organization.String()
if req.Organization == uuid.Nil {
orgParam = DefaultOrganization
}
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", orgParam))
if err != nil { if err != nil {
return nil, xerrors.Errorf("parse url: %w", err) return nil, xerrors.Errorf("parse url: %w", err)
} }

View File

@ -292,6 +292,15 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
r.Route("/organizations/{organization}/provisionerdaemons", func(r chi.Router) { r.Route("/organizations/{organization}/provisionerdaemons", func(r chi.Router) {
r.Use( r.Use(
api.provisionerDaemonsEnabledMW, api.provisionerDaemonsEnabledMW,
apiKeyMiddlewareOptional,
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
DB: api.Database,
Optional: true,
}, api.ProvisionerDaemonPSK),
// Either a user auth or provisioner auth is required
// to move forward.
httpmw.RequireAPIKeyOrProvisionerDaemonAuth(),
httpmw.ExtractOrganizationParam(api.Database),
) )
r.With(apiKeyMiddleware).Get("/", api.provisionerDaemons) r.With(apiKeyMiddleware).Get("/", api.provisionerDaemons)
r.With(apiKeyMiddlewareOptional).Get("/serve", api.provisionerDaemonServe) r.With(apiKeyMiddlewareOptional).Get("/serve", api.provisionerDaemonServe)

View File

@ -2,7 +2,6 @@ package coderd
import ( import (
"context" "context"
"crypto/subtle"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
@ -86,11 +85,8 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
}) })
return return
} }
apiDaemons := make([]codersdk.ProvisionerDaemon, 0)
for _, daemon := range daemons { httpapi.Write(ctx, rw, http.StatusOK, db2sdk.List(daemons, db2sdk.ProvisionerDaemon))
apiDaemons = append(apiDaemons, db2sdk.ProvisionerDaemon(daemon))
}
httpapi.Write(ctx, rw, http.StatusOK, apiDaemons)
} }
type provisionerDaemonAuth struct { type provisionerDaemonAuth struct {
@ -118,13 +114,11 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin
} }
// Check for PSK // Check for PSK
if p.psk != "" { provAuth := httpmw.ProvisionerDaemonAuthenticated(r)
psk := r.Header.Get(codersdk.ProvisionerDaemonPSK) if provAuth {
if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 { // If using PSK auth, the daemon is, by definition, scoped to the organization.
// If using PSK auth, the daemon is, by definition, scoped to the organization. tags = provisionersdk.MutateTags(uuid.Nil, tags)
tags = provisionersdk.MutateTags(uuid.Nil, tags) return tags, true
return tags, true
}
} }
return nil, false return nil, false
} }

View File

@ -350,6 +350,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
t.Run("PSK_daily_cost", func(t *testing.T) { t.Run("PSK_daily_cost", func(t *testing.T) {
t.Parallel() t.Parallel()
const provPSK = `provisionersftw`
client, user := coderdenttest.New(t, &coderdenttest.Options{ client, user := coderdenttest.New(t, &coderdenttest.Options{
UserWorkspaceQuota: 10, UserWorkspaceQuota: 10,
LicenseOptions: &coderdenttest.LicenseOptions{ LicenseOptions: &coderdenttest.LicenseOptions{
@ -358,7 +359,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
codersdk.FeatureTemplateRBAC: 1, codersdk.FeatureTemplateRBAC: 1,
}, },
}, },
ProvisionerDaemonPSK: "provisionersftw", ProvisionerDaemonPSK: provPSK,
}) })
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
@ -397,7 +398,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
Tags: map[string]string{ Tags: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization, provisionersdk.TagScope: provisionersdk.ScopeOrganization,
}, },
PreSharedKey: "provisionersftw", PreSharedKey: provPSK,
}) })
}, &provisionerd.Options{ }, &provisionerd.Options{
Logger: logger.Named("provisionerd"), Logger: logger.Named("provisionerd"),
@ -480,7 +481,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.Error(t, err) require.Error(t, err)
var apiError *codersdk.Error var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError) require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode()) require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion. daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err) require.NoError(t, err)
@ -514,7 +515,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.Error(t, err) require.Error(t, err)
var apiError *codersdk.Error var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError) require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode()) require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion. daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err) require.NoError(t, err)
@ -548,7 +549,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.Error(t, err) require.Error(t, err)
var apiError *codersdk.Error var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError) require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode()) require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion. daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err) require.NoError(t, err)