feat: implement provisioner auth middleware and proper org params (#12330)

* feat: provisioner auth in mw to allow ExtractOrg

Step to enable org scoped provisioner daemons

* chore: handle default org handling for provisioner daemons
This commit is contained in:
Steven Masley 2024-03-04 15:15:41 -06:00 committed by GitHub
parent 926fd7ffa6
commit 5c6974e55f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 201 additions and 30 deletions

View File

@ -170,6 +170,9 @@ var (
rbac.ResourceWorkspaceBuild.Type: {rbac.ActionRead, rbac.ActionUpdate, rbac.ActionDelete},
rbac.ResourceUserData.Type: {rbac.ActionRead, rbac.ActionUpdate},
rbac.ResourceAPIKey.Type: {rbac.WildcardSymbol},
// When org scoped provisioner credentials are implemented,
// this can be reduced to read a specific org.
rbac.ResourceOrganization.Type: {rbac.ActionRead},
}),
Org: map[string][]rbac.Permission{},
User: []rbac.Permission{},

View File

@ -64,3 +64,32 @@ func RequireAPIKeyOrWorkspaceAgent() func(http.Handler) http.Handler {
})
}
}
// RequireAPIKeyOrProvisionerDaemonAuth is middleware that should be inserted
// after optional ExtractAPIKey and ExtractProvisionerDaemonAuthenticated
// middlewares to ensure one of the two authentication methods is provided.
//
// If both are provided, an error is returned to avoid misuse.
func RequireAPIKeyOrProvisionerDaemonAuth() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, hasAPIKey := APIKeyOptional(r)
hasProvisionerDaemon := ProvisionerDaemonAuthenticated(r)
if hasAPIKey && hasProvisionerDaemon {
httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{
Message: "API key and external provisioner authentication provided, but only one is allowed",
})
return
}
if !hasAPIKey && !hasProvisionerDaemon {
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
Message: "API key or external provisioner authentication required, but none provided",
})
return
}
next.ServeHTTP(w, r)
})
}
}

View File

@ -53,15 +53,30 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler
}
var organization database.Organization
var err error
// Try by name or uuid.
id, err := uuid.Parse(arg)
if err == nil {
organization, err = db.GetOrganizationByID(ctx, id)
var dbErr error
// If the name is exactly "default", then we fetch the default
// organization. This is a special case to make it easier
// for single org deployments.
//
// arg == uuid.Nil.String() should be a temporary workaround for
// legacy provisioners that don't provide an organization ID.
// This prevents a breaking change.
// TODO: This change was added March 2024. Nil uuid returning the
// default org should be removed some number of months after
// that date.
if arg == codersdk.DefaultOrganization || arg == uuid.Nil.String() {
organization, dbErr = db.GetDefaultOrganization(ctx)
} else {
organization, err = db.GetOrganizationByName(ctx, arg)
// Try by name or uuid.
id, err := uuid.Parse(arg)
if err == nil {
organization, dbErr = db.GetOrganizationByID(ctx, id)
} else {
organization, dbErr = db.GetOrganizationByName(ctx, arg)
}
}
if httpapi.Is404Error(err) {
if httpapi.Is404Error(dbErr) {
httpapi.ResourceNotFound(rw)
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Organization %q not found.", arg),
@ -69,10 +84,10 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler
})
return
}
if err != nil {
if dbErr != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: fmt.Sprintf("Internal error fetching organization %q.", arg),
Detail: err.Error(),
Detail: dbErr.Error(),
})
return
}

View File

@ -208,5 +208,24 @@ func TestOrganizationParam(t *testing.T) {
res = rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode, "by name")
// Try by 'default'
chi.RouteContext(r.Context()).URLParams.Add("organization", codersdk.DefaultOrganization)
chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String())
rtr.ServeHTTP(rw, r)
res = rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode, "by default keyword")
// Try by legacy
// TODO: This can be removed when legacy nil uuids are no longer supported.
// This is a temporary measure to ensure as legacy provisioners use
// nil uuids as the org id and expect the default org.
chi.RouteContext(r.Context()).URLParams.Add("organization", uuid.Nil.String())
chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String())
rtr.ServeHTTP(rw, r)
res = rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode, "by nil uuid (legacy)")
})
}

View File

@ -0,0 +1,86 @@
package httpmw
import (
"context"
"crypto/subtle"
"net/http"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
)
type provisionerDaemonContextKey struct{}
func ProvisionerDaemonAuthenticated(r *http.Request) bool {
proxy, ok := r.Context().Value(provisionerDaemonContextKey{}).(bool)
return ok && proxy
}
type ExtractProvisionerAuthConfig struct {
DB database.Store
Optional bool
}
func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig, psk string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
handleOptional := func(code int, response codersdk.Response) {
if opts.Optional {
next.ServeHTTP(w, r)
return
}
httpapi.Write(ctx, w, code, response)
}
if psk == "" {
// No psk means external provisioner daemons are not allowed.
// So their auth is not valid.
handleOptional(http.StatusBadRequest, codersdk.Response{
Message: "External provisioner daemons not enabled",
})
return
}
token := r.Header.Get(codersdk.ProvisionerDaemonPSK)
if token == "" {
handleOptional(http.StatusUnauthorized, codersdk.Response{
Message: "provisioner daemon auth token required",
})
return
}
if subtle.ConstantTimeCompare([]byte(token), []byte(psk)) != 1 {
handleOptional(http.StatusUnauthorized, codersdk.Response{
Message: "provisioner daemon auth token invalid",
})
return
}
// The PSK does not indicate a specific provisioner daemon. So just
// store a boolean so the caller can check if the request is from an
// authenticated provisioner daemon.
ctx = context.WithValue(ctx, provisionerDaemonContextKey{}, true)
// nolint:gocritic // Authenticating as a provisioner daemon.
ctx = dbauthz.AsProvisionerd(ctx)
subj, ok := dbauthz.ActorFromContext(ctx)
if !ok {
// This should never happen
httpapi.InternalServerError(w, xerrors.New("developer error: ExtractProvisionerDaemonAuth missing rbac actor"))
}
// Use the same subject for the userAuthKey
ctx = context.WithValue(ctx, userAuthKey{}, Authorization{
Actor: subj,
ActorName: "provisioner_daemon",
})
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

View File

@ -50,6 +50,13 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
return
}
if req.Name == codersdk.DefaultOrganization {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Organization name %q is reserved.", codersdk.DefaultOrganization),
})
return
}
_, err := api.Database.GetOrganizationByName(ctx, req.Name)
if err == nil {
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{

View File

@ -11,6 +11,9 @@ import (
"golang.org/x/xerrors"
)
// DefaultOrganization is used as a replacement for the default organization.
var DefaultOrganization = "default"
type ProvisionerStorageMethod string
const (

View File

@ -179,8 +179,8 @@ type ServeProvisionerDaemonRequest struct {
ID uuid.UUID `json:"id" format:"uuid"`
// Name is the human-readable unique identifier for the daemon.
Name string `json:"name" example:"my-cool-provisioner-daemon"`
// Organization is the organization for the URL. At present provisioner daemons ARE NOT scoped to organizations
// and so the organization ID is optional.
// Organization is the organization for the URL. If no orgID is provided,
// then it is assumed to use the default organization.
Organization uuid.UUID `json:"organization" format:"uuid"`
// Provisioners is a list of provisioner types hosted by the provisioner daemon
Provisioners []ProvisionerType `json:"provisioners"`
@ -194,7 +194,12 @@ type ServeProvisionerDaemonRequest struct {
// implementation. The context is during dial, not during the lifetime of the
// client. Client should be closed after use.
func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisionerDaemonRequest) (proto.DRPCProvisionerDaemonClient, error) {
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", req.Organization))
orgParam := req.Organization.String()
if req.Organization == uuid.Nil {
orgParam = DefaultOrganization
}
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", orgParam))
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}

View File

@ -292,6 +292,15 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
r.Route("/organizations/{organization}/provisionerdaemons", func(r chi.Router) {
r.Use(
api.provisionerDaemonsEnabledMW,
apiKeyMiddlewareOptional,
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
DB: api.Database,
Optional: true,
}, api.ProvisionerDaemonPSK),
// Either a user auth or provisioner auth is required
// to move forward.
httpmw.RequireAPIKeyOrProvisionerDaemonAuth(),
httpmw.ExtractOrganizationParam(api.Database),
)
r.With(apiKeyMiddleware).Get("/", api.provisionerDaemons)
r.With(apiKeyMiddlewareOptional).Get("/serve", api.provisionerDaemonServe)

View File

@ -2,7 +2,6 @@ package coderd
import (
"context"
"crypto/subtle"
"database/sql"
"errors"
"fmt"
@ -86,11 +85,8 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
})
return
}
apiDaemons := make([]codersdk.ProvisionerDaemon, 0)
for _, daemon := range daemons {
apiDaemons = append(apiDaemons, db2sdk.ProvisionerDaemon(daemon))
}
httpapi.Write(ctx, rw, http.StatusOK, apiDaemons)
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.List(daemons, db2sdk.ProvisionerDaemon))
}
type provisionerDaemonAuth struct {
@ -118,13 +114,11 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin
}
// Check for PSK
if p.psk != "" {
psk := r.Header.Get(codersdk.ProvisionerDaemonPSK)
if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 {
// If using PSK auth, the daemon is, by definition, scoped to the organization.
tags = provisionersdk.MutateTags(uuid.Nil, tags)
return tags, true
}
provAuth := httpmw.ProvisionerDaemonAuthenticated(r)
if provAuth {
// If using PSK auth, the daemon is, by definition, scoped to the organization.
tags = provisionersdk.MutateTags(uuid.Nil, tags)
return tags, true
}
return nil, false
}

View File

@ -350,6 +350,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
t.Run("PSK_daily_cost", func(t *testing.T) {
t.Parallel()
const provPSK = `provisionersftw`
client, user := coderdenttest.New(t, &coderdenttest.Options{
UserWorkspaceQuota: 10,
LicenseOptions: &coderdenttest.LicenseOptions{
@ -358,7 +359,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
codersdk.FeatureTemplateRBAC: 1,
},
},
ProvisionerDaemonPSK: "provisionersftw",
ProvisionerDaemonPSK: provPSK,
})
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
@ -397,7 +398,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
Tags: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
},
PreSharedKey: "provisionersftw",
PreSharedKey: provPSK,
})
}, &provisionerd.Options{
Logger: logger.Named("provisionerd"),
@ -480,7 +481,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.Error(t, err)
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
@ -514,7 +515,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.Error(t, err)
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
@ -548,7 +549,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.Error(t, err)
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)