coder/enterprise/coderd/provisionerdaemons.go

package coderd

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/hashicorp/yamux"
	"github.com/moby/moby/pkg/namesgenerator"
	"go.opentelemetry.io/otel/trace"
	"golang.org/x/xerrors"
	"nhooyr.io/websocket"
	"storj.io/drpc/drpcmux"
	"storj.io/drpc/drpcserver"

	"cdr.dev/slog"
	"github.com/coder/coder/coderd"
	"github.com/coder/coder/coderd/database"
	"github.com/coder/coder/coderd/httpapi"
	"github.com/coder/coder/coderd/httpmw"
	"github.com/coder/coder/coderd/provisionerdserver"
	"github.com/coder/coder/coderd/rbac"
	"github.com/coder/coder/coderd/schedule"
	"github.com/coder/coder/codersdk"
	"github.com/coder/coder/provisionerd/proto"
)
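
// provisionerDaemonsEnabledMW rejects the request with 403 Forbidden unless
// the external provisioner daemons entitlement is enabled for this deployment.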
func (api *API) provisionerDaemonsEnabledMW(next http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		api.entitlementsMu.RLock()
		epd := api.entitlements.Features[codersdk.FeatureExternalProvisionerDaemons].Enabled
		api.entitlementsMu.RUnlock()

		if !epd {
			httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{
				Message: "External provisioner daemons are an Enterprise feature. Contact sales!",
			})
			return
		}

		next.ServeHTTP(rw, r)
	})
}

// @Summary Get provisioner daemons
// @ID get-provisioner-daemons
// @Security CoderSessionToken
// @Produce json
// @Tags Enterprise
// @Param organization path string true "Organization ID" format(uuid)
// @Success 200 {array} codersdk.ProvisionerDaemon
// @Router /organizations/{organization}/provisionerdaemons [get]
func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	daemons, err := api.Database.GetProvisionerDaemons(ctx)
	if errors.Is(err, sql.ErrNoRows) {
		err = nil
	}
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error fetching provisioner daemons.",
			Detail:  err.Error(),
		})
		return
	}
	if daemons == nil {
		daemons = []database.ProvisionerDaemon{}
	}

	daemons, err = coderd.AuthorizeFilter(api.AGPL.HTTPAuth, r, rbac.ActionRead, daemons)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error fetching provisioner daemons.",
			Detail:  err.Error(),
		})
		return
	}

	apiDaemons := make([]codersdk.ProvisionerDaemon, 0)
	for _, daemon := range daemons {
		apiDaemons = append(apiDaemons, convertProvisionerDaemon(daemon))
	}
	httpapi.Write(ctx, rw, http.StatusOK, apiDaemons)
}

// Serves the provisioner daemon protobuf API over a WebSocket.
//
// @Summary Serve provisioner daemon
// @ID serve-provisioner-daemon
// @Security CoderSessionToken
// @Tags Enterprise
// @Param organization path string true "Organization ID" format(uuid)
// @Success 101
// @Router /organizations/{organization}/provisionerdaemons/serve [get]
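//
// For example, a daemon can connect with
// "?provisioner=terraform&tag=scope=organization" to serve
// organization-scoped Terraform jobs.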
func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	tags := map[string]string{}
	if r.URL.Query().Has("tag") {
		for _, tag := range r.URL.Query()["tag"] {
			parts := strings.SplitN(tag, "=", 2)
			if len(parts) < 2 {
				httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
					Message: fmt.Sprintf("Invalid format for tag %q. Key and value must be separated with =.", tag),
				})
				return
			}
			tags[parts[0]] = parts[1]
		}
	}
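
	// Parse the requested provisioner types; collecting them into a set
	// deduplicates repeated query values.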
	if !r.URL.Query().Has("provisioner") {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "The provisioner query parameter must be specified.",
		})
		return
	}
	provisionersMap := map[codersdk.ProvisionerType]struct{}{}
	for _, provisioner := range r.URL.Query()["provisioner"] {
		switch provisioner {
		case string(codersdk.ProvisionerTypeEcho):
			provisionersMap[codersdk.ProvisionerTypeEcho] = struct{}{}
		case string(codersdk.ProvisionerTypeTerraform):
			provisionersMap[codersdk.ProvisionerTypeTerraform] = struct{}{}
		default:
			httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
				Message: fmt.Sprintf("Unknown provisioner type %q", provisioner),
			})
			return
		}
	}

	// Any authenticated user can create provisioner daemons scoped
	// for jobs that they own, but only authorized users can create
	// globally scoped provisioners that attach to all jobs.
	apiKey := httpmw.APIKey(r)
	tags = provisionerdserver.MutateTags(apiKey.UserID, tags)

	if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeOrganization {
		if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) {
			httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
				Message: "You aren't allowed to create provisioner daemons for the organization.",
			})
			return
		}
	}
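
	// Convert the validated provisioner type set into database values for
	// the insert below.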
	provisioners := make([]database.ProvisionerType, 0)
	for p := range provisionersMap {
		switch p {
		case codersdk.ProvisionerTypeTerraform:
			provisioners = append(provisioners, database.ProvisionerTypeTerraform)
		case codersdk.ProvisionerTypeEcho:
			provisioners = append(provisioners, database.ProvisionerTypeEcho)
		}
	}

	name := namesgenerator.GetRandomName(1)
	daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
		ID:           uuid.New(),
		CreatedAt:    database.Now(),
		Name:         name,
		Provisioners: provisioners,
		Tags:         tags,
	})
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error writing provisioner daemon.",
			Detail:  err.Error(),
		})
		return
	}

	rawTags, err := json.Marshal(daemon.Tags)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error marshaling daemon tags.",
			Detail:  err.Error(),
		})
		return
	}
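
	// Track the connection in the websocket wait group so graceful shutdown
	// waits for active daemon connections to drain.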
	api.AGPL.WebsocketWaitMutex.Lock()
	api.AGPL.WebsocketWaitGroup.Add(1)
	api.AGPL.WebsocketWaitMutex.Unlock()
	defer api.AGPL.WebsocketWaitGroup.Done()

	conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
		// Need to disable compression to avoid a data-race.
		CompressionMode: websocket.CompressionDisabled,
	})
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Internal error accepting websocket connection.",
			Detail:  err.Error(),
		})
		return
	}
	// Align with the frame size of yamux.
	conn.SetReadLimit(256 * 1024)

	// Multiplexes the incoming connection using yamux.
	// This allows multiple function calls to occur over
	// the same connection.
	config := yamux.DefaultConfig()
	config.LogOutput = io.Discard
	ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
	defer wsNetConn.Close()
	session, err := yamux.Server(wsNetConn, config)
	if err != nil {
		_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err))
		return
	}
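
	// Register the provisioner daemon service on a dRPC mux and serve it
	// over the multiplexed session.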
	mux := drpcmux.New()
	err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
		AccessURL:             api.AccessURL,
		GitAuthConfigs:        api.GitAuthConfigs,
		OIDCConfig:            api.OIDCConfig,
		ID:                    daemon.ID,
		Database:              api.Database,
		Pubsub:                api.Pubsub,
		Provisioners:          daemon.Provisioners,
		Telemetry:             api.Telemetry,
		Auditor:               &api.AGPL.Auditor,
		TemplateScheduleStore: api.AGPL.TemplateScheduleStore,
		Logger:                api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
		Tags:                  rawTags,
		Tracer:                trace.NewNoopTracerProvider().Tracer("noop"),
		DeploymentValues:      api.DeploymentValues,
	})
	if err != nil {
		_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("drpc register provisioner daemon: %s", err))
		return
	}
	server := drpcserver.NewWithOptions(mux, drpcserver.Options{
		Log: func(err error) {
			if xerrors.Is(err, io.EOF) {
				return
			}
			api.Logger.Debug(ctx, "drpc server error", slog.Error(err))
		},
	})
	err = server.Serve(ctx, session)
	if err != nil && !xerrors.Is(err, io.EOF) {
		api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
		_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
		return
	}
	_ = conn.Close(websocket.StatusGoingAway, "")
}
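
// convertProvisionerDaemon maps a database provisioner daemon row to its
// codersdk representation.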
func convertProvisionerDaemon(daemon database.ProvisionerDaemon) codersdk.ProvisionerDaemon {
	result := codersdk.ProvisionerDaemon{
		ID:        daemon.ID,
		CreatedAt: daemon.CreatedAt,
		UpdatedAt: daemon.UpdatedAt,
		Name:      daemon.Name,
		Tags:      daemon.Tags,
	}
	for _, provisionerType := range daemon.Provisioners {
		result.Provisioners = append(result.Provisioners, codersdk.ProvisionerType(provisionerType))
	}
	return result
}

// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
	cancel context.CancelFunc
	net.Conn
}

func (c *wsNetConn) Read(b []byte) (n int, err error) {
	n, err = c.Conn.Read(b)
	if err != nil {
		c.cancel()
	}
	return n, err
}

func (c *wsNetConn) Write(b []byte) (n int, err error) {
	n, err = c.Conn.Write(b)
	if err != nil {
		c.cancel()
	}
	return n, err
}

func (c *wsNetConn) Close() error {
	defer c.cancel()
	return c.Conn.Close()
}

// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
	ctx, cancel := context.WithCancel(ctx)
	nc := websocket.NetConn(ctx, conn, msgType)
	return ctx, &wsNetConn{
		cancel: cancel,
		Conn:   nc,
	}
}
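
// enterpriseTemplateScheduleStore is the enterprise implementation of
// schedule.TemplateScheduleStore. It reads and enforces the scheduling
// options stored on templates.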
type enterpriseTemplateScheduleStore struct{}

var _ schedule.TemplateScheduleStore = &enterpriseTemplateScheduleStore{}
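
// GetTemplateScheduleOptions returns the scheduling options stored on the
// template row.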
func (*enterpriseTemplateScheduleStore) GetTemplateScheduleOptions(ctx context.Context, db database.Store, templateID uuid.UUID) (schedule.TemplateScheduleOptions, error) {
	tpl, err := db.GetTemplateByID(ctx, templateID)
	if err != nil {
		return schedule.TemplateScheduleOptions{}, err
	}

	return schedule.TemplateScheduleOptions{
		UserAutostartEnabled: tpl.AllowUserAutostart,
		UserAutostopEnabled:  tpl.AllowUserAutostop,
		DefaultTTL:           time.Duration(tpl.DefaultTTL),
		MaxTTL:               time.Duration(tpl.MaxTTL),
		FailureTTL:           time.Duration(tpl.FailureTTL),
		InactivityTTL:        time.Duration(tpl.InactivityTTL),
	}, nil
}
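
// SetTemplateScheduleOptions persists the given scheduling options on the
// template and clamps existing workspace TTLs to the new template maximum.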
func (*enterpriseTemplateScheduleStore) SetTemplateScheduleOptions(ctx context.Context, db database.Store, tpl database.Template, opts schedule.TemplateScheduleOptions) (database.Template, error) {
	if int64(opts.DefaultTTL) == tpl.DefaultTTL &&
		int64(opts.MaxTTL) == tpl.MaxTTL &&
		int64(opts.FailureTTL) == tpl.FailureTTL &&
		int64(opts.InactivityTTL) == tpl.InactivityTTL &&
		opts.UserAutostartEnabled == tpl.AllowUserAutostart &&
		opts.UserAutostopEnabled == tpl.AllowUserAutostop {
		// Avoid updating the UpdatedAt timestamp if nothing will be changed.
		return tpl, nil
	}

	template, err := db.UpdateTemplateScheduleByID(ctx, database.UpdateTemplateScheduleByIDParams{
		ID:                 tpl.ID,
		UpdatedAt:          database.Now(),
		AllowUserAutostart: opts.UserAutostartEnabled,
		AllowUserAutostop:  opts.UserAutostopEnabled,
		DefaultTTL:         int64(opts.DefaultTTL),
		MaxTTL:             int64(opts.MaxTTL),
		FailureTTL:         int64(opts.FailureTTL),
		InactivityTTL:      int64(opts.InactivityTTL),
	})
	if err != nil {
		return database.Template{}, xerrors.Errorf("update template schedule: %w", err)
	}

	// Update all workspaces using the template to set the user-defined schedule
	// to be within the new bounds. This essentially does the following for each
	// workspace using the template:
	//   if (template.ttl != NULL) {
	//     workspace.ttl = min(workspace.ttl, template.ttl)
	//   }
	//
	// NOTE: this does not apply to currently running workspaces as their
	// schedule information is committed to the workspace_build during start.
	// This limitation is displayed to the user while editing the template.
	if opts.MaxTTL > 0 {
		err = db.UpdateWorkspaceTTLToBeWithinTemplateMax(ctx, database.UpdateWorkspaceTTLToBeWithinTemplateMaxParams{
			TemplateID:     template.ID,
			TemplateMaxTTL: int64(opts.MaxTTL),
		})
		if err != nil {
			return database.Template{}, xerrors.Errorf("update TTL of all workspaces on template to be within new template max TTL: %w", err)
		}
	}

	return template, nil
}