mirror of https://github.com/coder/coder.git
feat: implement provisioner auth middleware and proper org params (#12330)
* feat: provisioner auth in mw to allow ExtractOrg Step to enable org scoped provisioner daemons * chore: handle default org handling for provisioner daemons
This commit is contained in:
parent
926fd7ffa6
commit
5c6974e55f
|
@ -170,6 +170,9 @@ var (
|
|||
rbac.ResourceWorkspaceBuild.Type: {rbac.ActionRead, rbac.ActionUpdate, rbac.ActionDelete},
|
||||
rbac.ResourceUserData.Type: {rbac.ActionRead, rbac.ActionUpdate},
|
||||
rbac.ResourceAPIKey.Type: {rbac.WildcardSymbol},
|
||||
// When org scoped provisioner credentials are implemented,
|
||||
// this can be reduced to read a specific org.
|
||||
rbac.ResourceOrganization.Type: {rbac.ActionRead},
|
||||
}),
|
||||
Org: map[string][]rbac.Permission{},
|
||||
User: []rbac.Permission{},
|
||||
|
|
|
@ -64,3 +64,32 @@ func RequireAPIKeyOrWorkspaceAgent() func(http.Handler) http.Handler {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAPIKeyOrProvisionerDaemonAuth is middleware that should be inserted
|
||||
// after optional ExtractAPIKey and ExtractProvisionerDaemonAuthenticated
|
||||
// middlewares to ensure one of the two authentication methods is provided.
|
||||
//
|
||||
// If both are provided, an error is returned to avoid misuse.
|
||||
func RequireAPIKeyOrProvisionerDaemonAuth() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, hasAPIKey := APIKeyOptional(r)
|
||||
hasProvisionerDaemon := ProvisionerDaemonAuthenticated(r)
|
||||
|
||||
if hasAPIKey && hasProvisionerDaemon {
|
||||
httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "API key and external provisioner authentication provided, but only one is allowed",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !hasAPIKey && !hasProvisionerDaemon {
|
||||
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
|
||||
Message: "API key or external provisioner authentication required, but none provided",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -53,15 +53,30 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler
|
|||
}
|
||||
|
||||
var organization database.Organization
|
||||
var err error
|
||||
// Try by name or uuid.
|
||||
id, err := uuid.Parse(arg)
|
||||
if err == nil {
|
||||
organization, err = db.GetOrganizationByID(ctx, id)
|
||||
var dbErr error
|
||||
|
||||
// If the name is exactly "default", then we fetch the default
|
||||
// organization. This is a special case to make it easier
|
||||
// for single org deployments.
|
||||
//
|
||||
// arg == uuid.Nil.String() should be a temporary workaround for
|
||||
// legacy provisioners that don't provide an organization ID.
|
||||
// This prevents a breaking change.
|
||||
// TODO: This change was added March 2024. Nil uuid returning the
|
||||
// default org should be removed some number of months after
|
||||
// that date.
|
||||
if arg == codersdk.DefaultOrganization || arg == uuid.Nil.String() {
|
||||
organization, dbErr = db.GetDefaultOrganization(ctx)
|
||||
} else {
|
||||
organization, err = db.GetOrganizationByName(ctx, arg)
|
||||
// Try by name or uuid.
|
||||
id, err := uuid.Parse(arg)
|
||||
if err == nil {
|
||||
organization, dbErr = db.GetOrganizationByID(ctx, id)
|
||||
} else {
|
||||
organization, dbErr = db.GetOrganizationByName(ctx, arg)
|
||||
}
|
||||
}
|
||||
if httpapi.Is404Error(err) {
|
||||
if httpapi.Is404Error(dbErr) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: fmt.Sprintf("Organization %q not found.", arg),
|
||||
|
@ -69,10 +84,10 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler
|
|||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
if dbErr != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: fmt.Sprintf("Internal error fetching organization %q.", arg),
|
||||
Detail: err.Error(),
|
||||
Detail: dbErr.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -208,5 +208,24 @@ func TestOrganizationParam(t *testing.T) {
|
|||
res = rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode, "by name")
|
||||
|
||||
// Try by 'default'
|
||||
chi.RouteContext(r.Context()).URLParams.Add("organization", codersdk.DefaultOrganization)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String())
|
||||
rtr.ServeHTTP(rw, r)
|
||||
res = rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode, "by default keyword")
|
||||
|
||||
// Try by legacy
|
||||
// TODO: This can be removed when legacy nil uuids are no longer supported.
|
||||
// This is a temporary measure to ensure as legacy provisioners use
|
||||
// nil uuids as the org id and expect the default org.
|
||||
chi.RouteContext(r.Context()).URLParams.Add("organization", uuid.Nil.String())
|
||||
chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String())
|
||||
rtr.ServeHTTP(rw, r)
|
||||
res = rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode, "by nil uuid (legacy)")
|
||||
})
|
||||
}
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
type provisionerDaemonContextKey struct{}
|
||||
|
||||
func ProvisionerDaemonAuthenticated(r *http.Request) bool {
|
||||
proxy, ok := r.Context().Value(provisionerDaemonContextKey{}).(bool)
|
||||
return ok && proxy
|
||||
}
|
||||
|
||||
type ExtractProvisionerAuthConfig struct {
|
||||
DB database.Store
|
||||
Optional bool
|
||||
}
|
||||
|
||||
func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig, psk string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
handleOptional := func(code int, response codersdk.Response) {
|
||||
if opts.Optional {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, w, code, response)
|
||||
}
|
||||
|
||||
if psk == "" {
|
||||
// No psk means external provisioner daemons are not allowed.
|
||||
// So their auth is not valid.
|
||||
handleOptional(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "External provisioner daemons not enabled",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token := r.Header.Get(codersdk.ProvisionerDaemonPSK)
|
||||
if token == "" {
|
||||
handleOptional(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: "provisioner daemon auth token required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if subtle.ConstantTimeCompare([]byte(token), []byte(psk)) != 1 {
|
||||
handleOptional(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: "provisioner daemon auth token invalid",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// The PSK does not indicate a specific provisioner daemon. So just
|
||||
// store a boolean so the caller can check if the request is from an
|
||||
// authenticated provisioner daemon.
|
||||
ctx = context.WithValue(ctx, provisionerDaemonContextKey{}, true)
|
||||
// nolint:gocritic // Authenticating as a provisioner daemon.
|
||||
ctx = dbauthz.AsProvisionerd(ctx)
|
||||
subj, ok := dbauthz.ActorFromContext(ctx)
|
||||
if !ok {
|
||||
// This should never happen
|
||||
httpapi.InternalServerError(w, xerrors.New("developer error: ExtractProvisionerDaemonAuth missing rbac actor"))
|
||||
}
|
||||
|
||||
// Use the same subject for the userAuthKey
|
||||
ctx = context.WithValue(ctx, userAuthKey{}, Authorization{
|
||||
Actor: subj,
|
||||
ActorName: "provisioner_daemon",
|
||||
})
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
|
@ -50,6 +50,13 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if req.Name == codersdk.DefaultOrganization {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Organization name %q is reserved.", codersdk.DefaultOrganization),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
_, err := api.Database.GetOrganizationByName(ctx, req.Name)
|
||||
if err == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
|
|
|
@ -11,6 +11,9 @@ import (
|
|||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// DefaultOrganization is used as a replacement for the default organization.
|
||||
var DefaultOrganization = "default"
|
||||
|
||||
type ProvisionerStorageMethod string
|
||||
|
||||
const (
|
||||
|
|
|
@ -179,8 +179,8 @@ type ServeProvisionerDaemonRequest struct {
|
|||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
// Name is the human-readable unique identifier for the daemon.
|
||||
Name string `json:"name" example:"my-cool-provisioner-daemon"`
|
||||
// Organization is the organization for the URL. At present provisioner daemons ARE NOT scoped to organizations
|
||||
// and so the organization ID is optional.
|
||||
// Organization is the organization for the URL. If no orgID is provided,
|
||||
// then it is assumed to use the default organization.
|
||||
Organization uuid.UUID `json:"organization" format:"uuid"`
|
||||
// Provisioners is a list of provisioner types hosted by the provisioner daemon
|
||||
Provisioners []ProvisionerType `json:"provisioners"`
|
||||
|
@ -194,7 +194,12 @@ type ServeProvisionerDaemonRequest struct {
|
|||
// implementation. The context is during dial, not during the lifetime of the
|
||||
// client. Client should be closed after use.
|
||||
func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisionerDaemonRequest) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", req.Organization))
|
||||
orgParam := req.Organization.String()
|
||||
if req.Organization == uuid.Nil {
|
||||
orgParam = DefaultOrganization
|
||||
}
|
||||
|
||||
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", orgParam))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
|
|
|
@ -292,6 +292,15 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
|||
r.Route("/organizations/{organization}/provisionerdaemons", func(r chi.Router) {
|
||||
r.Use(
|
||||
api.provisionerDaemonsEnabledMW,
|
||||
apiKeyMiddlewareOptional,
|
||||
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
|
||||
DB: api.Database,
|
||||
Optional: true,
|
||||
}, api.ProvisionerDaemonPSK),
|
||||
// Either a user auth or provisioner auth is required
|
||||
// to move forward.
|
||||
httpmw.RequireAPIKeyOrProvisionerDaemonAuth(),
|
||||
httpmw.ExtractOrganizationParam(api.Database),
|
||||
)
|
||||
r.With(apiKeyMiddleware).Get("/", api.provisionerDaemons)
|
||||
r.With(apiKeyMiddlewareOptional).Get("/serve", api.provisionerDaemonServe)
|
||||
|
|
|
@ -2,7 +2,6 @@ package coderd
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -86,11 +85,8 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
|
|||
})
|
||||
return
|
||||
}
|
||||
apiDaemons := make([]codersdk.ProvisionerDaemon, 0)
|
||||
for _, daemon := range daemons {
|
||||
apiDaemons = append(apiDaemons, db2sdk.ProvisionerDaemon(daemon))
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, apiDaemons)
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.List(daemons, db2sdk.ProvisionerDaemon))
|
||||
}
|
||||
|
||||
type provisionerDaemonAuth struct {
|
||||
|
@ -118,13 +114,11 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin
|
|||
}
|
||||
|
||||
// Check for PSK
|
||||
if p.psk != "" {
|
||||
psk := r.Header.Get(codersdk.ProvisionerDaemonPSK)
|
||||
if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 {
|
||||
// If using PSK auth, the daemon is, by definition, scoped to the organization.
|
||||
tags = provisionersdk.MutateTags(uuid.Nil, tags)
|
||||
return tags, true
|
||||
}
|
||||
provAuth := httpmw.ProvisionerDaemonAuthenticated(r)
|
||||
if provAuth {
|
||||
// If using PSK auth, the daemon is, by definition, scoped to the organization.
|
||||
tags = provisionersdk.MutateTags(uuid.Nil, tags)
|
||||
return tags, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
|
|
@ -350,6 +350,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
|||
|
||||
t.Run("PSK_daily_cost", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
const provPSK = `provisionersftw`
|
||||
client, user := coderdenttest.New(t, &coderdenttest.Options{
|
||||
UserWorkspaceQuota: 10,
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
|
@ -358,7 +359,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
|||
codersdk.FeatureTemplateRBAC: 1,
|
||||
},
|
||||
},
|
||||
ProvisionerDaemonPSK: "provisionersftw",
|
||||
ProvisionerDaemonPSK: provPSK,
|
||||
})
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
|
@ -397,7 +398,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
|||
Tags: map[string]string{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
},
|
||||
PreSharedKey: "provisionersftw",
|
||||
PreSharedKey: provPSK,
|
||||
})
|
||||
}, &provisionerd.Options{
|
||||
Logger: logger.Named("provisionerd"),
|
||||
|
@ -480,7 +481,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
var apiError *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiError)
|
||||
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
|
||||
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
|
||||
|
||||
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
|
||||
require.NoError(t, err)
|
||||
|
@ -514,7 +515,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
var apiError *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiError)
|
||||
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
|
||||
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
|
||||
|
||||
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
|
||||
require.NoError(t, err)
|
||||
|
@ -548,7 +549,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
var apiError *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiError)
|
||||
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
|
||||
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
|
||||
|
||||
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
|
||||
require.NoError(t, err)
|
||||
|
|
Loading…
Reference in New Issue