coder/cli/server.go

2594 lines
85 KiB
Go

//go:build !slim
package cli
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"database/sql"
"encoding/hex"
"errors"
"flag"
"fmt"
"io"
"log"
"math/big"
"net"
"net/http"
"net/http/pprof"
"net/url"
"os"
"os/user"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/coreos/go-systemd/daemon"
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
"github.com/google/go-github/v43/github"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
"golang.org/x/mod/semver"
"golang.org/x/oauth2"
xgithub "golang.org/x/oauth2/github"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"google.golang.org/api/idtoken"
"google.golang.org/api/option"
"gopkg.in/yaml.v3"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/cli/clilog"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/cli/cliutil"
"github.com/coder/coder/v2/cli/config"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/autobuild"
"github.com/coder/coder/v2/coderd/batchstats"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/awsiamrds"
"github.com/coder/coder/v2/coderd/database/dbmem"
"github.com/coder/coder/v2/coderd/database/dbmetrics"
"github.com/coder/coder/v2/coderd/database/dbpurge"
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/devtunnel"
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/gitsshkey"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/oauthpki"
"github.com/coder/coder/v2/coderd/prometheusmetrics"
"github.com/coder/coder/v2/coderd/prometheusmetrics/insights"
"github.com/coder/coder/v2/coderd/promoauth"
"github.com/coder/coder/v2/coderd/schedule"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/unhanger"
"github.com/coder/coder/v2/coderd/updatecheck"
"github.com/coder/coder/v2/coderd/util/slice"
stringutil "github.com/coder/coder/v2/coderd/util/strings"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
"github.com/coder/coder/v2/coderd/workspaceusage"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/provisioner/terraform"
"github.com/coder/coder/v2/provisionerd"
"github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionersdk"
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/pretty"
"github.com/coder/retry"
"github.com/coder/serpent"
"github.com/coder/wgtunnel/tunnelsdk"
)
// createOIDCConfig assembles the coderd OIDC login configuration from the
// deployment values.
//
// It returns an error when the client ID or issuer URL is empty, when the
// provider cannot be constructed from the issuer URL, when the callback URL
// cannot be derived from the access URL, when both a client secret and a
// client key file are supplied, when PKI setup fails, or when an allowed-groups
// list is configured without a group field.
//
// Side effect: if the "groups" scope is requested and no group field was
// configured, vals.OIDC.GroupField is set to "groups".
func createOIDCConfig(ctx context.Context, vals *codersdk.DeploymentValues) (*coderd.OIDCConfig, error) {
	// Both identifiers are mandatory; the client ID is checked first.
	switch {
	case vals.OIDC.ClientID == "":
		return nil, xerrors.Errorf("OIDC client ID must be set!")
	case vals.OIDC.IssuerURL == "":
		return nil, xerrors.Errorf("OIDC issuer URL must be set!")
	}

	provider, err := oidc.NewProvider(ctx, vals.OIDC.IssuerURL.String())
	if err != nil {
		return nil, xerrors.Errorf("configure oidc provider: %w", err)
	}

	// The callback path is resolved relative to the deployment's access URL.
	callbackURL, err := vals.AccessURL.Value().Parse("/api/v2/users/oidc/callback")
	if err != nil {
		return nil, xerrors.Errorf("parse oidc oauth callback url: %w", err)
	}

	// Requesting the 'groups' scope turns on group support, but an explicit
	// group field chosen by the user always wins.
	if vals.OIDC.GroupField == "" && slice.Contains(vals.OIDC.Scopes, "groups") {
		vals.OIDC.GroupField = "groups"
	}

	baseCfg := &oauth2.Config{
		ClientID:     vals.OIDC.ClientID.String(),
		ClientSecret: vals.OIDC.ClientSecret.String(),
		RedirectURL:  callbackURL.String(),
		Endpoint:     provider.Endpoint(),
		Scopes:       vals.OIDC.Scopes,
	}

	// By default the plain oauth2 config is used; a client key file swaps it
	// for the PKI-backed variant.
	var effectiveCfg promoauth.OAuth2Config = baseCfg
	if vals.OIDC.ClientKeyFile != "" {
		// PKI authentication is sent in the request params. If a counter
		// example turns up, this can become a config option.
		baseCfg.Endpoint.AuthStyle = oauth2.AuthStyleInParams
		if vals.OIDC.ClientSecret != "" {
			return nil, xerrors.Errorf("cannot specify both oidc client secret and oidc client key file")
		}

		keyFile := vals.OIDC.ClientKeyFile.Value()
		certFile := vals.OIDC.ClientCertFile.Value()
		pki, err := configureOIDCPKI(baseCfg, keyFile, certFile)
		if err != nil {
			return nil, xerrors.Errorf("configure oauth pki authentication: %w", err)
		}
		effectiveCfg = pki
	}

	// An allow list is meaningless without a claim to read groups from.
	if len(vals.OIDC.GroupAllowList) > 0 && vals.OIDC.GroupField == "" {
		return nil, xerrors.Errorf("'oidc-group-field' must be set if 'oidc-allowed-groups' is set. Either unset 'oidc-allowed-groups' or set 'oidc-group-field'")
	}

	// Turn the allow list into a set for membership lookups.
	allowed := vals.OIDC.GroupAllowList.Value()
	allowedGroups := make(map[string]bool, len(allowed))
	for i := range allowed {
		allowedGroups[allowed[i]] = true
	}

	verifier := provider.Verifier(&oidc.Config{
		ClientID: vals.OIDC.ClientID.String(),
	})

	cfg := &coderd.OIDCConfig{
		OAuth2Config:        effectiveCfg,
		Provider:            provider,
		Verifier:            verifier,
		EmailDomain:         vals.OIDC.EmailDomain,
		AllowSignups:        vals.OIDC.AllowSignups.Value(),
		UsernameField:       vals.OIDC.UsernameField.String(),
		EmailField:          vals.OIDC.EmailField.String(),
		AuthURLParams:       vals.OIDC.AuthURLParams.Value,
		IgnoreUserInfo:      vals.OIDC.IgnoreUserInfo.Value(),
		GroupField:          vals.OIDC.GroupField.String(),
		GroupFilter:         vals.OIDC.GroupRegexFilter.Value(),
		GroupAllowList:      allowedGroups,
		CreateMissingGroups: vals.OIDC.GroupAutoCreate.Value(),
		GroupMapping:        vals.OIDC.GroupMapping.Value,
		UserRoleField:       vals.OIDC.UserRoleField.String(),
		UserRoleMapping:     vals.OIDC.UserRoleMapping.Value,
		UserRolesDefault:    vals.OIDC.UserRolesDefault.GetSlice(),
		SignInText:          vals.OIDC.SignInText.String(),
		SignupsDisabledText: vals.OIDC.SignupsDisabledText.String(),
		IconURL:             vals.OIDC.IconURL.String(),
		IgnoreEmailVerified: vals.OIDC.IgnoreEmailVerified.Value(),
	}
	return cfg, nil
}
// afterCtx schedules fn to run once ctx is done. It returns immediately; fn is
// invoked from a separate goroutine, and never runs if ctx is never done.
func afterCtx(ctx context.Context, fn func()) {
	done := ctx.Done()
	go func() {
		<-done
		fn()
	}()
}
// enablePrometheus registers the server's Prometheus collectors on
// options.PrometheusRegistry and serves that registry over HTTP at
// vals.Prometheus.Address.
//
// Every close/cancel function handed back by a collector is passed to afterCtx,
// so it runs once ctx is done. When vals.Prometheus.CollectAgentStats is set,
// agent stats and the metrics aggregator are set up as well and
// options.UpdateAgentMetrics is pointed at the aggregator's Update method.
//
// The returned closeFn is the value returned by ServeHandler for the metrics
// endpoint. Registering the Go and process collectors uses MustRegister and
// therefore panics rather than returning an error.
func enablePrometheus(
	ctx context.Context,
	logger slog.Logger,
	vals *codersdk.DeploymentValues,
	options *coderd.Options,
) (closeFn func(), err error) {
	registry := options.PrometheusRegistry
	store := options.Database

	// Runtime and process level collectors.
	registry.MustRegister(
		collectors.NewGoCollector(),
		collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
	)

	stopActiveUsers, err := prometheusmetrics.ActiveUsers(ctx, registry, store, 0)
	if err != nil {
		return nil, xerrors.Errorf("register active users prometheus metric: %w", err)
	}
	afterCtx(ctx, stopActiveUsers)

	workspacesLogger := options.Logger.Named("workspaces_metrics")
	stopWorkspaces, err := prometheusmetrics.Workspaces(ctx, workspacesLogger, registry, store, 0)
	if err != nil {
		return nil, xerrors.Errorf("register workspaces prometheus metric: %w", err)
	}
	afterCtx(ctx, stopWorkspaces)

	// The insights collector is created, registered, and only then started.
	insightsCollector, err := insights.NewMetricsCollector(store, options.Logger, 0, 0)
	if err != nil {
		return nil, xerrors.Errorf("unable to initialize insights metrics collector: %w", err)
	}
	if err = registry.Register(insightsCollector); err != nil {
		return nil, xerrors.Errorf("unable to register insights metrics collector: %w", err)
	}
	stopInsights, err := insightsCollector.Run(ctx)
	if err != nil {
		return nil, xerrors.Errorf("unable to run insights metrics collector: %w", err)
	}
	afterCtx(ctx, stopInsights)

	if vals.Prometheus.CollectAgentStats {
		aggregateBy := options.DeploymentValues.Prometheus.AggregateAgentStatsBy.Value()

		stopAgentStats, err := prometheusmetrics.AgentStats(ctx, logger, registry, store, time.Now(), 0, aggregateBy)
		if err != nil {
			return nil, xerrors.Errorf("register agent stats prometheus metric: %w", err)
		}
		afterCtx(ctx, stopAgentStats)

		aggregator, err := prometheusmetrics.NewMetricsAggregator(logger, registry, 0, aggregateBy)
		if err != nil {
			return nil, xerrors.Errorf("can't initialize metrics aggregator: %w", err)
		}

		// The aggregator is started and wired into options before it is
		// registered as a collector.
		afterCtx(ctx, aggregator.Run(ctx))
		options.UpdateAgentMetrics = aggregator.Update

		if err = registry.Register(aggregator); err != nil {
			return nil, xerrors.Errorf("can't register metrics aggregator as collector: %w", err)
		}
	}

	metricsHandler := promhttp.InstrumentMetricHandler(
		registry, promhttp.HandlerFor(registry, promhttp.HandlerOpts{}),
	)
	//nolint:revive
	return ServeHandler(ctx, logger, metricsHandler, vals.Prometheus.Address.String(), "prometheus"), nil
}
//nolint:gocognit // TODO(dannyk): reduce complexity of this function
func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *serpent.Command {
if newAPI == nil {
newAPI = func(_ context.Context, o *coderd.Options) (*coderd.API, io.Closer, error) {
api := coderd.New(o)
return api, api, nil
}
}
var (
vals = new(codersdk.DeploymentValues)
opts = vals.Options()
)
serverCmd := &serpent.Command{
Use: "server",
Short: "Start a Coder server",
Options: opts,
Middleware: serpent.Chain(
WriteConfigMW(vals),
PrintDeprecatedOptions(),
serpent.RequireNArgs(0),
),
Handler: func(inv *serpent.Invocation) error {
// Main command context for managing cancellation of running
// services.
ctx, cancel := context.WithCancel(inv.Context())
defer cancel()
if vals.Config != "" {
cliui.Warnf(inv.Stderr, "YAML support is experimental and offers no compatibility guarantees.")
}
go DumpHandler(ctx, "coderd")
// Validate bind addresses.
if vals.Address.String() != "" {
if vals.TLS.Enable {
vals.HTTPAddress = ""
vals.TLS.Address = vals.Address
} else {
_ = vals.HTTPAddress.Set(vals.Address.String())
vals.TLS.Address.Host = ""
vals.TLS.Address.Port = ""
}
}
if vals.TLS.Enable && vals.TLS.Address.String() == "" {
return xerrors.Errorf("TLS address must be set if TLS is enabled")
}
if !vals.TLS.Enable && vals.HTTPAddress.String() == "" {
return xerrors.Errorf("TLS is disabled. Enable with --tls-enable or specify a HTTP address")
}
if vals.AccessURL.String() != "" &&
!(vals.AccessURL.Scheme == "http" || vals.AccessURL.Scheme == "https") {
return xerrors.Errorf("access-url must include a scheme (e.g. 'http://' or 'https://)")
}
// Disable rate limits if the `--dangerous-disable-rate-limits` flag
// was specified.
loginRateLimit := 60
filesRateLimit := 12
if vals.RateLimit.DisableAll {
vals.RateLimit.API = -1
loginRateLimit = -1
filesRateLimit = -1
}
PrintLogo(inv, "Coder")
logger, logCloser, err := clilog.New(clilog.FromDeploymentValues(vals)).Build(inv)
if err != nil {
return xerrors.Errorf("make logger: %w", err)
}
defer logCloser()
// This line is helpful in tests.
logger.Debug(ctx, "started debug logging")
logger.Sync()
// Register signals early on so that graceful shutdown can't
// be interrupted by additional signals. Note that we avoid
// shadowing cancel() (from above) here because stopCancel()
// restores default behavior for the signals. This protects
// the shutdown sequence from abruptly terminating things
// like: database migrations, provisioner work, workspace
// cleanup in dev-mode, etc.
//
// To get out of a graceful shutdown, the user can send
// SIGQUIT with ctrl+\ or SIGKILL with `kill -9`.
stopCtx, stopCancel := signalNotifyContext(ctx, inv, StopSignalsNoInterrupt...)
defer stopCancel()
interruptCtx, interruptCancel := signalNotifyContext(ctx, inv, InterruptSignals...)
defer interruptCancel()
cacheDir := vals.CacheDir.String()
err = os.MkdirAll(cacheDir, 0o700)
if err != nil {
return xerrors.Errorf("create cache directory: %w", err)
}
// Clean up idle connections at the end, e.g.
// embedded-postgres can leave an idle connection
// which is caught by goleaks.
defer http.DefaultClient.CloseIdleConnections()
tracerProvider, sqlDriver, closeTracing := ConfigureTraceProvider(ctx, logger, vals)
defer func() {
logger.Debug(ctx, "closing tracing")
traceCloseErr := shutdownWithTimeout(closeTracing, 5*time.Second)
logger.Debug(ctx, "tracing closed", slog.Error(traceCloseErr))
}()
httpServers, err := ConfigureHTTPServers(logger, inv, vals)
if err != nil {
return xerrors.Errorf("configure http(s): %w", err)
}
defer httpServers.Close()
config := r.createConfig()
builtinPostgres := false
// Only use built-in if PostgreSQL URL isn't specified!
if !vals.InMemoryDatabase && vals.PostgresURL == "" {
var closeFunc func() error
cliui.Infof(inv.Stdout, "Using built-in PostgreSQL (%s)", config.PostgresPath())
pgURL, closeFunc, err := startBuiltinPostgres(ctx, config, logger)
if err != nil {
return err
}
err = vals.PostgresURL.Set(pgURL)
if err != nil {
return err
}
builtinPostgres = true
defer func() {
cliui.Infof(inv.Stdout, "Stopping built-in PostgreSQL...")
// Gracefully shut PostgreSQL down!
if err := closeFunc(); err != nil {
cliui.Errorf(inv.Stderr, "Failed to stop built-in PostgreSQL: %v", err)
} else {
cliui.Infof(inv.Stdout, "Stopped built-in PostgreSQL")
}
}()
}
// Prefer HTTP because it's less prone to TLS errors over localhost.
localURL := httpServers.TLSUrl
if httpServers.HTTPUrl != nil {
localURL = httpServers.HTTPUrl
}
ctx, httpClient, err := ConfigureHTTPClient(
ctx,
vals.TLS.ClientCertFile.String(),
vals.TLS.ClientKeyFile.String(),
vals.TLS.ClientCAFile.String(),
)
if err != nil {
return xerrors.Errorf("configure http client: %w", err)
}
// If the access URL is empty, we attempt to run a reverse-proxy
// tunnel to make the initial setup really simple.
var (
tunnel *tunnelsdk.Tunnel
tunnelDone <-chan struct{} = make(chan struct{}, 1)
)
if vals.AccessURL.String() == "" {
cliui.Infof(inv.Stderr, "Opening tunnel so workspaces can connect to your deployment. For production scenarios, specify an external access URL")
tunnel, err = devtunnel.New(ctx, logger.Named("net.devtunnel"), vals.WgtunnelHost.String())
if err != nil {
return xerrors.Errorf("create tunnel: %w", err)
}
defer tunnel.Close()
tunnelDone = tunnel.Wait()
vals.AccessURL = serpent.URL(*tunnel.URL)
if vals.WildcardAccessURL.String() == "" {
// Suffixed wildcard access URL.
wu := fmt.Sprintf("*--%s", tunnel.URL.Hostname())
err = vals.WildcardAccessURL.Set(wu)
if err != nil {
return xerrors.Errorf("set wildcard access url %q: %w", wu, err)
}
}
}
_, accessURLPortRaw, _ := net.SplitHostPort(vals.AccessURL.Host)
if accessURLPortRaw == "" {
accessURLPortRaw = "80"
if vals.AccessURL.Scheme == "https" {
accessURLPortRaw = "443"
}
}
accessURLPort, err := strconv.Atoi(accessURLPortRaw)
if err != nil {
return xerrors.Errorf("parse access URL port: %w", err)
}
// Warn the user if the access URL is loopback or unresolvable.
isLocal, err := IsLocalURL(ctx, vals.AccessURL.Value())
if isLocal || err != nil {
reason := "could not be resolved"
if isLocal {
reason = "isn't externally reachable"
}
cliui.Warnf(
inv.Stderr,
"The access URL %s %s, this may cause unexpected problems when creating workspaces. Generate a unique *.try.coder.app URL by not specifying an access URL.\n",
pretty.Sprint(cliui.DefaultStyles.Field, vals.AccessURL.String()), reason,
)
}
// A newline is added before for visibility in terminal output.
cliui.Infof(inv.Stdout, "\nView the Web UI: %s", vals.AccessURL.String())
// Used for zero-trust instance identity with Google Cloud.
googleTokenValidator, err := idtoken.NewValidator(ctx, option.WithoutAuthentication())
if err != nil {
return err
}
sshKeygenAlgorithm, err := gitsshkey.ParseAlgorithm(vals.SSHKeygenAlgorithm.String())
if err != nil {
return xerrors.Errorf("parse ssh keygen algorithm %s: %w", vals.SSHKeygenAlgorithm, err)
}
defaultRegion := &tailcfg.DERPRegion{
EmbeddedRelay: true,
RegionID: int(vals.DERP.Server.RegionID.Value()),
RegionCode: vals.DERP.Server.RegionCode.String(),
RegionName: vals.DERP.Server.RegionName.String(),
Nodes: []*tailcfg.DERPNode{{
Name: fmt.Sprintf("%db", vals.DERP.Server.RegionID),
RegionID: int(vals.DERP.Server.RegionID.Value()),
HostName: vals.AccessURL.Value().Hostname(),
DERPPort: accessURLPort,
STUNPort: -1,
ForceHTTP: vals.AccessURL.Scheme == "http",
}},
}
if !vals.DERP.Server.Enable {
defaultRegion = nil
}
derpMap, err := tailnet.NewDERPMap(
ctx, defaultRegion, vals.DERP.Server.STUNAddresses,
vals.DERP.Config.URL.String(), vals.DERP.Config.Path.String(),
vals.DERP.Config.BlockDirect.Value(),
)
if err != nil {
return xerrors.Errorf("create derp map: %w", err)
}
appHostname := vals.WildcardAccessURL.String()
var appHostnameRegex *regexp.Regexp
if appHostname != "" {
appHostnameRegex, err = appurl.CompileHostnamePattern(appHostname)
if err != nil {
return xerrors.Errorf("parse wildcard access URL %q: %w", appHostname, err)
}
}
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
if err != nil {
return xerrors.Errorf("read external auth providers from env: %w", err)
}
promRegistry := prometheus.NewRegistry()
oauthInstrument := promoauth.NewFactory(promRegistry)
vals.ExternalAuthConfigs.Value = append(vals.ExternalAuthConfigs.Value, extAuthEnv...)
externalAuthConfigs, err := externalauth.ConvertConfig(
oauthInstrument,
vals.ExternalAuthConfigs.Value,
vals.AccessURL.Value(),
)
if err != nil {
return xerrors.Errorf("convert external auth config: %w", err)
}
for _, c := range externalAuthConfigs {
logger.Debug(
ctx, "loaded external auth config",
slog.F("id", c.ID),
)
}
realIPConfig, err := httpmw.ParseRealIPConfig(vals.ProxyTrustedHeaders, vals.ProxyTrustedOrigins)
if err != nil {
return xerrors.Errorf("parse real ip config: %w", err)
}
configSSHOptions, err := vals.SSHConfig.ParseOptions()
if err != nil {
return xerrors.Errorf("parse ssh config options %q: %w", vals.SSHConfig.SSHConfigOptions.String(), err)
}
options := &coderd.Options{
AccessURL: vals.AccessURL.Value(),
AppHostname: appHostname,
AppHostnameRegex: appHostnameRegex,
Logger: logger.Named("coderd"),
Database: dbmem.New(),
BaseDERPMap: derpMap,
Pubsub: pubsub.NewInMemory(),
CacheDir: cacheDir,
GoogleTokenValidator: googleTokenValidator,
ExternalAuthConfigs: externalAuthConfigs,
RealIPConfig: realIPConfig,
SecureAuthCookie: vals.SecureAuthCookie.Value(),
SSHKeygenAlgorithm: sshKeygenAlgorithm,
TracerProvider: tracerProvider,
Telemetry: telemetry.NewNoop(),
MetricsCacheRefreshInterval: vals.MetricsCacheRefreshInterval.Value(),
AgentStatsRefreshInterval: vals.AgentStatRefreshInterval.Value(),
DeploymentValues: vals,
// Do not pass secret values to DeploymentOptions. All values should be read from
// the DeploymentValues instead, this just serves to indicate the source of each
// option. This is just defensive to prevent accidentally leaking.
DeploymentOptions: codersdk.DeploymentOptionsWithoutSecrets(opts),
PrometheusRegistry: promRegistry,
APIRateLimit: int(vals.RateLimit.API.Value()),
LoginRateLimit: loginRateLimit,
FilesRateLimit: filesRateLimit,
HTTPClient: httpClient,
TemplateScheduleStore: &atomic.Pointer[schedule.TemplateScheduleStore]{},
UserQuietHoursScheduleStore: &atomic.Pointer[schedule.UserQuietHoursScheduleStore]{},
SSHConfig: codersdk.SSHConfigResponse{
HostnamePrefix: vals.SSHConfig.DeploymentName.String(),
SSHConfigOptions: configSSHOptions,
},
AllowWorkspaceRenames: vals.AllowWorkspaceRenames.Value(),
}
if httpServers.TLSConfig != nil {
options.TLSCertificates = httpServers.TLSConfig.Certificates
}
if vals.StrictTransportSecurity > 0 {
options.StrictTransportSecurityCfg, err = httpmw.HSTSConfigOptions(
int(vals.StrictTransportSecurity.Value()), vals.StrictTransportSecurityOptions,
)
if err != nil {
return xerrors.Errorf("coderd: setting hsts header failed (options: %v): %w", vals.StrictTransportSecurityOptions, err)
}
}
if vals.UpdateCheck {
options.UpdateCheckOptions = &updatecheck.Options{
// Avoid spamming GitHub API checking for updates.
Interval: 24 * time.Hour,
// Inform server admins of new versions.
Notify: func(r updatecheck.Result) {
if semver.Compare(r.Version, buildinfo.Version()) > 0 {
options.Logger.Info(
context.Background(),
"new version of coder available",
slog.F("new_version", r.Version),
slog.F("url", r.URL),
slog.F("upgrade_instructions", "https://coder.com/docs/coder-oss/latest/admin/upgrade"),
)
}
},
}
}
if vals.OAuth2.Github.ClientSecret != "" {
options.GithubOAuth2Config, err = configureGithubOAuth2(
oauthInstrument,
vals.AccessURL.Value(),
vals.OAuth2.Github.ClientID.String(),
vals.OAuth2.Github.ClientSecret.String(),
vals.OAuth2.Github.AllowSignups.Value(),
vals.OAuth2.Github.AllowEveryone.Value(),
vals.OAuth2.Github.AllowedOrgs,
vals.OAuth2.Github.AllowedTeams,
vals.OAuth2.Github.EnterpriseBaseURL.String(),
)
if err != nil {
return xerrors.Errorf("configure github oauth2: %w", err)
}
}
if vals.OIDC.ClientKeyFile != "" || vals.OIDC.ClientSecret != "" {
if vals.OIDC.IgnoreEmailVerified {
logger.Warn(ctx, "coder will not check email_verified for OIDC logins")
}
// This OIDC config is **not** being instrumented with the
// oauth2 instrument wrapper. If we implement the missing
// oidc methods, then we can instrument it.
// Missing:
// - Userinfo
// - Verify
oc, err := createOIDCConfig(ctx, vals)
if err != nil {
return xerrors.Errorf("create oidc config: %w", err)
}
options.OIDCConfig = oc
}
// We'll read from this channel in the select below that tracks shutdown. If it remains
// nil, that case of the select will just never fire, but it's important not to have a
// "bare" read on this channel.
var pubsubWatchdogTimeout <-chan struct{}
if vals.InMemoryDatabase {
// This is only used for testing.
options.Database = dbmem.New()
options.Pubsub = pubsub.NewInMemory()
} else {
sqlDB, dbURL, err := getPostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver)
if err != nil {
return xerrors.Errorf("connect to postgres: %w", err)
}
defer func() {
_ = sqlDB.Close()
}()
options.Database = database.New(sqlDB)
ps, err := pubsub.New(ctx, logger.Named("pubsub"), sqlDB, dbURL)
if err != nil {
return xerrors.Errorf("create pubsub: %w", err)
}
options.Pubsub = ps
if options.DeploymentValues.Prometheus.Enable {
options.PrometheusRegistry.MustRegister(ps)
}
defer options.Pubsub.Close()
psWatchdog := pubsub.NewWatchdog(ctx, logger.Named("pswatch"), ps)
pubsubWatchdogTimeout = psWatchdog.Timeout()
defer psWatchdog.Close()
}
if options.DeploymentValues.Prometheus.Enable && options.DeploymentValues.Prometheus.CollectDBMetrics {
options.Database = dbmetrics.New(options.Database, options.PrometheusRegistry)
}
var deploymentID string
err = options.Database.InTx(func(tx database.Store) error {
// This will block until the lock is acquired, and will be
// automatically released when the transaction ends.
err := tx.AcquireLock(ctx, database.LockIDDeploymentSetup)
if err != nil {
return xerrors.Errorf("acquire lock: %w", err)
}
deploymentID, err = tx.GetDeploymentID(ctx)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get deployment id: %w", err)
}
if deploymentID == "" {
deploymentID = uuid.NewString()
err = tx.InsertDeploymentID(ctx, deploymentID)
if err != nil {
return xerrors.Errorf("set deployment id: %w", err)
}
}
// Read the app signing key from the DB. We store it hex encoded
// since the config table uses strings for the value and we
// don't want to deal with automatic encoding issues.
appSecurityKeyStr, err := tx.GetAppSecurityKey(ctx)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get app signing key: %w", err)
}
// If the string in the DB is an invalid hex string or the
// length is not equal to the current key length, generate a new
// one.
//
// If the key is regenerated, old signed tokens and encrypted
// strings will become invalid. New signed app tokens will be
// generated automatically on failure. Any workspace app token
// smuggling operations in progress may fail, although with a
// helpful error.
if decoded, err := hex.DecodeString(appSecurityKeyStr); err != nil || len(decoded) != len(workspaceapps.SecurityKey{}) {
b := make([]byte, len(workspaceapps.SecurityKey{}))
_, err := rand.Read(b)
if err != nil {
return xerrors.Errorf("generate fresh app signing key: %w", err)
}
appSecurityKeyStr = hex.EncodeToString(b)
err = tx.UpsertAppSecurityKey(ctx, appSecurityKeyStr)
if err != nil {
return xerrors.Errorf("insert freshly generated app signing key to database: %w", err)
}
}
appSecurityKey, err := workspaceapps.KeyFromString(appSecurityKeyStr)
if err != nil {
return xerrors.Errorf("decode app signing key from database: %w", err)
}
options.AppSecurityKey = appSecurityKey
// Read the oauth signing key from the database. Like the app security, generate a new one
// if it is invalid for any reason.
oauthSigningKeyStr, err := tx.GetOAuthSigningKey(ctx)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get app oauth signing key: %w", err)
}
if decoded, err := hex.DecodeString(oauthSigningKeyStr); err != nil || len(decoded) != len(options.OAuthSigningKey) {
b := make([]byte, len(options.OAuthSigningKey))
_, err := rand.Read(b)
if err != nil {
return xerrors.Errorf("generate fresh oauth signing key: %w", err)
}
oauthSigningKeyStr = hex.EncodeToString(b)
err = tx.UpsertOAuthSigningKey(ctx, oauthSigningKeyStr)
if err != nil {
return xerrors.Errorf("insert freshly generated oauth signing key to database: %w", err)
}
}
keyBytes, err := hex.DecodeString(oauthSigningKeyStr)
if err != nil {
return xerrors.Errorf("decode oauth signing key from database: %w", err)
}
if len(keyBytes) != len(options.OAuthSigningKey) {
return xerrors.Errorf("oauth signing key in database is not the correct length, expect %d got %d", len(options.OAuthSigningKey), len(keyBytes))
}
copy(options.OAuthSigningKey[:], keyBytes)
if options.OAuthSigningKey == [32]byte{} {
return xerrors.Errorf("oauth signing key in database is empty")
}
return nil
}, nil)
if err != nil {
return err
}
// This should be output before the logs start streaming.
cliui.Infof(inv.Stdout, "\n==> Logs will stream in below (press ctrl+c to gracefully exit):")
if vals.Telemetry.Enable {
gitAuth := make([]telemetry.GitAuth, 0)
// TODO:
var gitAuthConfigs []codersdk.ExternalAuthConfig
for _, cfg := range gitAuthConfigs {
gitAuth = append(gitAuth, telemetry.GitAuth{
Type: cfg.Type,
})
}
options.Telemetry, err = telemetry.New(telemetry.Options{
BuiltinPostgres: builtinPostgres,
DeploymentID: deploymentID,
Database: options.Database,
Logger: logger.Named("telemetry"),
URL: vals.Telemetry.URL.Value(),
Wildcard: vals.WildcardAccessURL.String() != "",
DERPServerRelayURL: vals.DERP.Server.RelayURL.String(),
GitAuth: gitAuth,
GitHubOAuth: vals.OAuth2.Github.ClientID != "",
OIDCAuth: vals.OIDC.ClientID != "",
OIDCIssuerURL: vals.OIDC.IssuerURL.String(),
Prometheus: vals.Prometheus.Enable.Value(),
STUN: len(vals.DERP.Server.STUNAddresses) != 0,
Tunnel: tunnel != nil,
Experiments: vals.Experiments.Value(),
ParseLicenseJWT: func(lic *telemetry.License) error {
// This will be nil when running in AGPL-only mode.
if options.ParseLicenseClaims == nil {
return nil
}
email, trial, err := options.ParseLicenseClaims(lic.JWT)
if err != nil {
return err
}
if email != "" {
lic.Email = &email
}
lic.Trial = &trial
return nil
},
})
if err != nil {
return xerrors.Errorf("create telemetry reporter: %w", err)
}
defer options.Telemetry.Close()
} else {
logger.Warn(ctx, `telemetry disabled, unable to notify of security issues. Read more: https://coder.com/docs/v2/latest/admin/telemetry`)
}
// This prevents the pprof import from being accidentally deleted.
_ = pprof.Handler
if vals.Pprof.Enable {
//nolint:revive
defer ServeHandler(ctx, logger, nil, vals.Pprof.Address.String(), "pprof")()
}
if vals.Prometheus.Enable {
closeFn, err := enablePrometheus(
ctx,
logger.Named("prometheus"),
vals,
options,
)
if err != nil {
return xerrors.Errorf("enable prometheus: %w", err)
}
defer closeFn()
}
if vals.Swagger.Enable {
options.SwaggerEndpoint = vals.Swagger.Enable.Value()
}
batcher, closeBatcher, err := batchstats.New(ctx,
batchstats.WithLogger(options.Logger.Named("batchstats")),
batchstats.WithStore(options.Database),
)
if err != nil {
return xerrors.Errorf("failed to create agent stats batcher: %w", err)
}
options.StatsBatcher = batcher
defer closeBatcher()
// We use a separate coderAPICloser so the Enterprise API
// can have its own close functions. This is cleaner
// than abstracting the Coder API itself.
coderAPI, coderAPICloser, err := newAPI(ctx, options)
if err != nil {
return xerrors.Errorf("create coder API: %w", err)
}
if vals.Prometheus.Enable {
// Agent metrics require reference to the tailnet coordinator, so must be initiated after Coder API.
closeAgentsFunc, err := prometheusmetrics.Agents(ctx, logger, options.PrometheusRegistry, coderAPI.Database, &coderAPI.TailnetCoordinator, coderAPI.DERPMap, coderAPI.Options.AgentInactiveDisconnectTimeout, 0)
if err != nil {
return xerrors.Errorf("register agents prometheus metric: %w", err)
}
defer closeAgentsFunc()
var active codersdk.Experiments
for _, exp := range options.DeploymentValues.Experiments.Value() {
active = append(active, codersdk.Experiment(exp))
}
if err = prometheusmetrics.Experiments(options.PrometheusRegistry, active); err != nil {
return xerrors.Errorf("register experiments metric: %w", err)
}
}
client := codersdk.New(localURL)
if localURL.Scheme == "https" && IsLocalhost(localURL.Hostname()) {
// The certificate will likely be self-signed or for a different
// hostname, so we need to skip verification.
client.HTTPClient.Transport = &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec
InsecureSkipVerify: true,
},
}
}
defer client.HTTPClient.CloseIdleConnections()
// This is helpful for tests, but can be silently ignored.
// Coder may be ran as users that don't have permission to write in the homedir,
// such as via the systemd service.
err = config.URL().Write(client.URL.String())
if err != nil && flag.Lookup("test.v") != nil {
return xerrors.Errorf("write config url: %w", err)
}
// Since errCh only has one buffered slot, all routines
// sending on it must be wrapped in a select/default to
// avoid leaving dangling goroutines waiting for the
// channel to be consumed.
errCh := make(chan error, 1)
provisionerDaemons := make([]*provisionerd.Server, 0)
defer func() {
// We have no graceful shutdown of provisionerDaemons
// here because that's handled at the end of main, this
// is here in case the program exits early.
for _, daemon := range provisionerDaemons {
_ = daemon.Close()
}
}()
var provisionerdWaitGroup sync.WaitGroup
defer provisionerdWaitGroup.Wait()
provisionerdMetrics := provisionerd.NewMetrics(options.PrometheusRegistry)
// Built in provisioner daemons will support the same types.
// By default, this is the slice {"terraform"}
provisionerTypes := make([]codersdk.ProvisionerType, 0)
for _, pt := range vals.Provisioner.DaemonTypes {
provisionerTypes = append(provisionerTypes, codersdk.ProvisionerType(pt))
}
for i := int64(0); i < vals.Provisioner.Daemons.Value(); i++ {
suffix := fmt.Sprintf("%d", i)
// The suffix is added to the hostname, so we may need to trim to fit into
// the 64 character limit.
hostname := stringutil.Truncate(cliutil.Hostname(), 63-len(suffix))
name := fmt.Sprintf("%s-%s", hostname, suffix)
daemonCacheDir := filepath.Join(cacheDir, fmt.Sprintf("provisioner-%d", i))
daemon, err := newProvisionerDaemon(
ctx, coderAPI, provisionerdMetrics, logger, vals, daemonCacheDir, errCh, &provisionerdWaitGroup, name, provisionerTypes,
)
if err != nil {
return xerrors.Errorf("create provisioner daemon: %w", err)
}
provisionerDaemons = append(provisionerDaemons, daemon)
}
provisionerdMetrics.Runner.NumDaemons.Set(float64(len(provisionerDaemons)))
shutdownConnsCtx, shutdownConns := context.WithCancel(ctx)
defer shutdownConns()
// Ensures that old database entries are cleaned up over time!
purger := dbpurge.New(ctx, logger.Named("dbpurge"), options.Database)
defer purger.Close()
// Updates workspace usage
tracker := workspaceusage.New(options.Database,
workspaceusage.WithLogger(logger.Named("workspace_usage_tracker")),
)
options.WorkspaceUsageTracker = tracker
defer tracker.Close()
// Wrap the server in middleware that redirects to the access URL if
// the request is not to a local IP.
var handler http.Handler = coderAPI.RootHandler
if vals.RedirectToAccessURL {
handler = redirectToAccessURL(handler, vals.AccessURL.Value(), tunnel != nil, appHostnameRegex)
}
// ReadHeaderTimeout is purposefully not enabled. It caused some
// issues with websockets over the dev tunnel.
// See: https://github.com/coder/coder/pull/3730
//nolint:gosec
httpServer := &http.Server{
// These errors are typically noise like "TLS: EOF". Vault does
// similar:
// https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714
ErrorLog: log.New(io.Discard, "", 0),
Handler: handler,
BaseContext: func(_ net.Listener) context.Context {
return shutdownConnsCtx
},
}
defer func() {
_ = shutdownWithTimeout(httpServer.Shutdown, 5*time.Second)
}()
// We call this in the routine so we can kill the other listeners if
// one of them fails.
closeListenersNow := func() {
httpServers.Close()
if tunnel != nil {
_ = tunnel.Listener.Close()
}
}
eg := errgroup.Group{}
eg.Go(func() error {
defer closeListenersNow()
return httpServers.Serve(httpServer)
})
if tunnel != nil {
eg.Go(func() error {
defer closeListenersNow()
return httpServer.Serve(tunnel.Listener)
})
}
go func() {
select {
case errCh <- eg.Wait():
default:
}
}()
// Updates the systemd status from activating to activated.
_, err = daemon.SdNotify(false, daemon.SdNotifyReady)
if err != nil {
return xerrors.Errorf("notify systemd: %w", err)
}
autobuildTicker := time.NewTicker(vals.AutobuildPollInterval.Value())
defer autobuildTicker.Stop()
autobuildExecutor := autobuild.NewExecutor(
ctx, options.Database, options.Pubsub, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, logger, autobuildTicker.C)
autobuildExecutor.Run()
hangDetectorTicker := time.NewTicker(vals.JobHangDetectorInterval.Value())
defer hangDetectorTicker.Stop()
hangDetector := unhanger.New(ctx, options.Database, options.Pubsub, logger, hangDetectorTicker.C)
hangDetector.Start()
defer hangDetector.Close()
waitForProvisionerJobs := false
// Currently there is no way to ask the server to shut
// itself down, so any exit signal will result in a non-zero
// exit of the server.
var exitErr error
select {
case <-stopCtx.Done():
exitErr = stopCtx.Err()
waitForProvisionerJobs = true
_, _ = io.WriteString(inv.Stdout, cliui.Bold("Stop caught, waiting for provisioner jobs to complete and gracefully exiting. Use ctrl+\\ to force quit"))
case <-interruptCtx.Done():
exitErr = interruptCtx.Err()
_, _ = io.WriteString(inv.Stdout, cliui.Bold("Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit"))
case <-tunnelDone:
exitErr = xerrors.New("dev tunnel closed unexpectedly")
case <-pubsubWatchdogTimeout:
exitErr = xerrors.New("pubsub Watchdog timed out")
case exitErr = <-errCh:
}
if exitErr != nil && !xerrors.Is(exitErr, context.Canceled) {
cliui.Errorf(inv.Stderr, "Unexpected error, shutting down server: %s\n", exitErr)
}
// Begin clean shut down stage, we try to shut down services
// gracefully in an order that gives the best experience.
// This procedure should not differ greatly from the order
// of `defer`s in this function, but allows us to inform
// the user about what's going on and handle errors more
// explicitly.
_, err = daemon.SdNotify(false, daemon.SdNotifyStopping)
if err != nil {
cliui.Errorf(inv.Stderr, "Notify systemd failed: %s", err)
}
// Stop accepting new connections without interrupting
// in-flight requests, give in-flight requests 5 seconds to
// complete.
cliui.Info(inv.Stdout, "Shutting down API server..."+"\n")
err = shutdownWithTimeout(httpServer.Shutdown, 3*time.Second)
if err != nil {
cliui.Errorf(inv.Stderr, "API server shutdown took longer than 3s: %s\n", err)
} else {
cliui.Info(inv.Stdout, "Gracefully shut down API server\n")
}
// Cancel any remaining in-flight requests.
shutdownConns()
// Shut down provisioners before waiting for WebSockets
// connections to close.
var wg sync.WaitGroup
for i, provisionerDaemon := range provisionerDaemons {
id := i + 1
provisionerDaemon := provisionerDaemon
wg.Add(1)
go func() {
defer wg.Done()
r.Verbosef(inv, "Shutting down provisioner daemon %d...", id)
timeout := 5 * time.Second
if waitForProvisionerJobs {
// It can last for a long time...
timeout = 30 * time.Minute
}
err := shutdownWithTimeout(func(ctx context.Context) error {
// We only want to cancel active jobs if we aren't exiting gracefully.
return provisionerDaemon.Shutdown(ctx, !waitForProvisionerJobs)
}, timeout)
if err != nil {
cliui.Errorf(inv.Stderr, "Failed to shut down provisioner daemon %d: %s\n", id, err)
return
}
err = provisionerDaemon.Close()
if err != nil {
cliui.Errorf(inv.Stderr, "Close provisioner daemon %d: %s\n", id, err)
return
}
r.Verbosef(inv, "Gracefully shut down provisioner daemon %d", id)
}()
}
wg.Wait()
cliui.Info(inv.Stdout, "Waiting for WebSocket connections to close..."+"\n")
_ = coderAPICloser.Close()
cliui.Info(inv.Stdout, "Done waiting for WebSocket connections"+"\n")
// Close tunnel after we no longer have in-flight connections.
if tunnel != nil {
cliui.Infof(inv.Stdout, "Waiting for tunnel to close...")
_ = tunnel.Close()
<-tunnel.Wait()
cliui.Infof(inv.Stdout, "Done waiting for tunnel")
}
// Ensures a last report can be sent before exit!
options.Telemetry.Close()
// Trigger context cancellation for any remaining services.
cancel()
switch {
case xerrors.Is(exitErr, context.DeadlineExceeded):
cliui.Warnf(inv.Stderr, "Graceful shutdown timed out")
// Errors here cause a significant number of benign CI failures.
return nil
case xerrors.Is(exitErr, context.Canceled):
return nil
case exitErr != nil:
return xerrors.Errorf("graceful shutdown: %w", exitErr)
default:
return nil
}
},
}
var pgRawURL bool
postgresBuiltinURLCmd := &serpent.Command{
Use: "postgres-builtin-url",
Short: "Output the connection URL for the built-in PostgreSQL deployment.",
Handler: func(inv *serpent.Invocation) error {
url, err := embeddedPostgresURL(r.createConfig())
if err != nil {
return err
}
if pgRawURL {
_, _ = fmt.Fprintf(inv.Stdout, "%s\n", url)
} else {
_, _ = fmt.Fprintf(inv.Stdout, "%s\n", pretty.Sprint(cliui.DefaultStyles.Code, fmt.Sprintf("psql %q", url)))
}
return nil
},
}
postgresBuiltinServeCmd := &serpent.Command{
Use: "postgres-builtin-serve",
Short: "Run the built-in PostgreSQL deployment.",
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
cfg := r.createConfig()
logger := inv.Logger.AppendSinks(sloghuman.Sink(inv.Stderr))
if ok, _ := inv.ParsedFlags().GetBool(varVerbose); ok {
logger = logger.Leveled(slog.LevelDebug)
}
ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer cancel()
url, closePg, err := startBuiltinPostgres(ctx, cfg, logger)
if err != nil {
return err
}
defer func() { _ = closePg() }()
if pgRawURL {
_, _ = fmt.Fprintf(inv.Stdout, "%s\n", url)
} else {
_, _ = fmt.Fprintf(inv.Stdout, "%s\n", pretty.Sprint(cliui.DefaultStyles.Code, fmt.Sprintf("psql %q", url)))
}
<-ctx.Done()
return nil
},
}
createAdminUserCmd := r.newCreateAdminUserCommand()
rawURLOpt := serpent.Option{
Flag: "raw-url",
Value: serpent.BoolOf(&pgRawURL),
Description: "Output the raw connection URL instead of a psql command.",
}
createAdminUserCmd.Options.Add(rawURLOpt)
postgresBuiltinURLCmd.Options.Add(rawURLOpt)
postgresBuiltinServeCmd.Options.Add(rawURLOpt)
serverCmd.Children = append(
serverCmd.Children,
createAdminUserCmd, postgresBuiltinURLCmd, postgresBuiltinServeCmd,
)
return serverCmd
}
// PrintDeprecatedOptions returns middleware that, before invoking the wrapped
// handler, writes a warning to the invocation's stderr for every command
// option that lists replacements in UseInstead and whose value came from
// somewhere other than "none" or its default.
func PrintDeprecatedOptions() serpent.MiddlewareFunc {
	return func(next serpent.HandlerFunc) serpent.HandlerFunc {
		return func(inv *serpent.Invocation) error {
			for _, opt := range inv.Command.Options {
				// Options without replacements, or that were never explicitly
				// set, do not warrant a warning.
				if opt.UseInstead == nil ||
					opt.ValueSource == serpent.ValueSourceNone ||
					opt.ValueSource == serpent.ValueSourceDefault {
					continue
				}

				// Produces "<old> is deprecated, please use <a> and <b> instead.\n".
				var msg strings.Builder
				msg.WriteString(opt.Name)
				msg.WriteString(" is deprecated, please use ")
				lastIdx := len(opt.UseInstead) - 1
				for idx, replacement := range opt.UseInstead {
					msg.WriteString(replacement.Name)
					msg.WriteString(" ")
					if idx != lastIdx {
						msg.WriteString("and ")
					}
				}
				msg.WriteString("instead.\n")

				cliui.Warn(inv.Stderr, msg.String())
			}
			return next(inv)
		}
	}
}
// WriteConfigMW returns middleware that short-circuits the wrapped command
// when cfg.WriteConfig is set: instead of calling the next handler it marshals
// the command's options to YAML (2-space indent) and writes the document to
// the invocation's stdout.
func WriteConfigMW(cfg *codersdk.DeploymentValues) serpent.MiddlewareFunc {
	return func(next serpent.HandlerFunc) serpent.HandlerFunc {
		return func(inv *serpent.Invocation) error {
			if !cfg.WriteConfig {
				return next(inv)
			}

			node, err := inv.Command.Options.MarshalYAML()
			if err != nil {
				return xerrors.Errorf("generate yaml: %w", err)
			}

			encoder := yaml.NewEncoder(inv.Stdout)
			encoder.SetIndent(2)
			if err := encoder.Encode(node); err != nil {
				return xerrors.Errorf("encode yaml: %w", err)
			}
			if err := encoder.Close(); err != nil {
				return xerrors.Errorf("close yaml encoder: %w", err)
			}
			return nil
		}
	}
}
// IsLocalURL reports whether the hostname of u resolves to at least one
// loopback address. Resolution errors are returned to the caller.
func IsLocalURL(ctx context.Context, u *url.URL) (bool, error) {
	host := u.Hostname()

	// Tests commonly use "example.com" or "google.com". Neither is loopback,
	// so answer directly and skip the DNS lookup to avoid flakes.
	if flag.Lookup("test.v") != nil {
		switch host {
		case "example.com", "google.com":
			return false, nil
		}
	}

	var resolver net.Resolver
	addrs, err := resolver.LookupIPAddr(ctx, host)
	if err != nil {
		return false, err
	}
	for i := range addrs {
		if addrs[i].IP.IsLoopback() {
			return true, nil
		}
	}
	return false, nil
}
// shutdownWithTimeout invokes shutdown with a fresh background context that
// expires after timeout and returns whatever shutdown returns. The context is
// released once shutdown has returned.
func shutdownWithTimeout(shutdown func(context.Context) error, timeout time.Duration) error {
	deadlineCtx, release := context.WithTimeout(context.Background(), timeout)
	defer release()

	return shutdown(deadlineCtx)
}
// newProvisionerDaemon builds a provisionerd.Server named name that connects
// to coderAPI through CreateInMemoryProvisionerDaemon and serves the given
// provisionerTypes (duplicates are dropped).
//
// For each type, an in-memory DRPC pipe is created and the matching
// provisioner (echo or terraform) is served on a background goroutine. Every
// goroutine started here is registered with wg. Serve errors are forwarded to
// errCh with a non-blocking send.
//
// cacheDir and its "work" subdirectory (plus "tf" for terraform) are created
// with mode 0700. An error is returned if a directory cannot be created or a
// provisioner type is unknown; in that case the derived context is cancelled.
//
// nolint:revive
func newProvisionerDaemon(
ctx context.Context,
coderAPI *coderd.API,
metrics provisionerd.Metrics,
logger slog.Logger,
cfg *codersdk.DeploymentValues,
cacheDir string,
errCh chan error,
wg *sync.WaitGroup,
name string,
provisionerTypes []codersdk.ProvisionerType,
) (srv *provisionerd.Server, err error) {
ctx, cancel := context.WithCancel(ctx)
// Cancel the derived context only if construction fails (err is the named
// result). On success the serve goroutines below cancel it when they exit.
defer func() {
if err != nil {
cancel()
}
}()
err = os.MkdirAll(cacheDir, 0o700)
if err != nil {
return nil, xerrors.Errorf("mkdir %q: %w", cacheDir, err)
}
// Work directory handed to every provisioner served below.
workDir := filepath.Join(cacheDir, "work")
err = os.MkdirAll(workDir, 0o700)
if err != nil {
return nil, xerrors.Errorf("mkdir work dir: %w", err)
}
// Omit any duplicates
provisionerTypes = slice.Unique(provisionerTypes)
// Populate the connector with the supported types.
connector := provisionerd.LocalProvisioners{}
for _, provisionerType := range provisionerTypes {
switch provisionerType {
case codersdk.ProvisionerTypeEcho:
echoClient, echoServer := drpc.MemTransportPipe()
wg.Add(1)
// Close both ends of the in-memory pipe once the context is done.
go func() {
defer wg.Done()
<-ctx.Done()
_ = echoClient.Close()
_ = echoServer.Close()
}()
wg.Add(1)
// Serve the echo provisioner; cancel the shared context when it returns.
go func() {
defer wg.Done()
defer cancel()
err := echo.Serve(ctx, &provisionersdk.ServeOptions{
Listener: echoServer,
WorkDirectory: workDir,
Logger: logger.Named("echo"),
})
// NOTE(review): unlike the terraform case below, context.Canceled is
// not filtered out here — confirm this is intentional.
if err != nil {
// Non-blocking send: the error is dropped if errCh cannot accept it.
select {
case errCh <- err:
default:
}
}
}()
connector[string(database.ProvisionerTypeEcho)] = sdkproto.NewDRPCProvisionerClient(echoClient)
case codersdk.ProvisionerTypeTerraform:
// Terraform gets its own cache directory under cacheDir.
tfDir := filepath.Join(cacheDir, "tf")
err = os.MkdirAll(tfDir, 0o700)
if err != nil {
return nil, xerrors.Errorf("mkdir terraform dir: %w", err)
}
tracer := coderAPI.TracerProvider.Tracer(tracing.TracerName)
terraformClient, terraformServer := drpc.MemTransportPipe()
wg.Add(1)
// Close both ends of the in-memory pipe once the context is done.
go func() {
defer wg.Done()
<-ctx.Done()
_ = terraformClient.Close()
_ = terraformServer.Close()
}()
wg.Add(1)
// Serve the terraform provisioner; cancel the shared context when it returns.
go func() {
defer wg.Done()
defer cancel()
err := terraform.Serve(ctx, &terraform.ServeOptions{
ServeOptions: &provisionersdk.ServeOptions{
Listener: terraformServer,
Logger: logger.Named("terraform"),
WorkDirectory: workDir,
},
CachePath: tfDir,
Tracer: tracer,
})
// Cancellation is not reported as a failure.
if err != nil && !xerrors.Is(err, context.Canceled) {
// Non-blocking send: the error is dropped if errCh cannot accept it.
select {
case errCh <- err:
default:
}
}
}()
connector[string(database.ProvisionerTypeTerraform)] = sdkproto.NewDRPCProvisionerClient(terraformClient)
default:
return nil, xerrors.Errorf("unknown provisioner type %q", provisionerType)
}
}
return provisionerd.New(func(dialCtx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
// This debounces calls to listen every second. Read the comment
// in provisionerdserver.go to learn more!
return coderAPI.CreateInMemoryProvisionerDaemon(dialCtx, name, provisionerTypes)
}, &provisionerd.Options{
Logger: logger.Named(fmt.Sprintf("provisionerd-%s", name)),
UpdateInterval: time.Second,
ForceCancelInterval: cfg.Provisioner.ForceCancelInterval.Value(),
Connector: connector,
TracerProvider: coderAPI.TracerProvider,
Metrics: &metrics,
}), nil
}
// PrintLogo writes a one-line banner made of daemonTitle, the build version
// and the product tagline to the invocation's stdout. Nothing is printed when
// stdout is not a TTY.
//
// nolint: revive
func PrintLogo(inv *serpent.Invocation, daemonTitle string) {
	if isTTYOut(inv) {
		banner := cliui.Bold(daemonTitle + " " + buildinfo.Version())
		_, _ = fmt.Fprintf(inv.Stdout, "%s - Your Self-Hosted Remote Development Platform\n", banner)
	}
}
// loadCertificates loads one X.509 key pair per (cert file, key file) pair,
// matching the two slices by index. It returns an error when the slices have
// different lengths or when any pair fails to load.
func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, error) {
	if len(tlsCertFiles) != len(tlsKeyFiles) {
		return nil, xerrors.New("--tls-cert-file and --tls-key-file must be used the same amount of times")
	}

	pairs := make([]tls.Certificate, 0, len(tlsCertFiles))
	for idx, certPath := range tlsCertFiles {
		keyPath := tlsKeyFiles[idx]
		pair, err := tls.LoadX509KeyPair(certPath, keyPath)
		if err != nil {
			// Include the full file lists so a mismatched ordering is easy to spot.
			return nil, xerrors.Errorf(
				"load TLS key pair %d (%q, %q): %w\ncertFiles: %+v\nkeyFiles: %+v",
				idx, certPath, keyPath, err,
				tlsCertFiles, tlsKeyFiles,
			)
		}
		pairs = append(pairs, pair)
	}
	return pairs, nil
}
// generateSelfSignedCertificate creates an unsafe self-signed certificate
// at random that allows users to proceed with setup in the event they
// haven't configured any TLS certificates.
//
// The certificate uses a fresh P-256 ECDSA key, is valid for 180 days from
// now, is restricted to server authentication and lists 127.0.0.1 as its only
// IP SAN. Each certificate gets a random positive serial number so that
// certificates generated on successive starts do not share an issuer/serial
// pair.
func generateSelfSignedCertificate() (*tls.Certificate, error) {
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, err
	}

	// Draw from [0, 2^127) and add one so the serial is always strictly
	// positive and fits well within the 20-octet limit.
	serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 127))
	if err != nil {
		return nil, err
	}
	serialNumber.Add(serialNumber, big.NewInt(1))

	now := time.Now()
	template := x509.Certificate{
		SerialNumber:          serialNumber,
		NotBefore:             now,
		NotAfter:              now.Add(time.Hour * 24 * 180),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
	}

	// Self-signed: the template acts as both the certificate and its parent.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	if err != nil {
		return nil, err
	}

	var cert tls.Certificate
	cert.Certificate = append(cert.Certificate, derBytes)
	cert.PrivateKey = privateKey
	return &cert, nil
}
// configureServerTLS builds the TLS configuration used by the Coderd server
// for client connections. The logger is only used to emit warnings that do
// not block startup.
//
// Validation happens in this order: minimum TLS version, custom cipher suites
// (only when ciphers is non-empty), client auth mode, certificate loading and
// finally the client CA pool. When no certificate files are given, a
// self-signed certificate is generated.
//
//nolint:revive
func configureServerTLS(ctx context.Context, logger slog.Logger, tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string, ciphers []string, allowInsecureCiphers bool) (*tls.Config, error) {
	tlsConfig := &tls.Config{
		MinVersion: tls.VersionTLS12,
		NextProtos: []string{"h2", "http/1.1"},
	}

	minVersions := map[string]uint16{
		"tls10": tls.VersionTLS10,
		"tls11": tls.VersionTLS11,
		"tls12": tls.VersionTLS12,
		"tls13": tls.VersionTLS13,
	}
	minVersion, known := minVersions[tlsMinVersion]
	if !known {
		return nil, xerrors.Errorf("unrecognized tls version: %q", tlsMinVersion)
	}
	tlsConfig.MinVersion = minVersion

	// A custom set of supported ciphers.
	if len(ciphers) > 0 {
		suiteIDs, err := configureCipherSuites(ctx, logger, ciphers, allowInsecureCiphers, tlsConfig.MinVersion, tls.VersionTLS13)
		if err != nil {
			return nil, err
		}
		tlsConfig.CipherSuites = suiteIDs
	}

	clientAuthModes := map[string]tls.ClientAuthType{
		"none":               tls.NoClientCert,
		"request":            tls.RequestClientCert,
		"require-any":        tls.RequireAnyClientCert,
		"verify-if-given":    tls.VerifyClientCertIfGiven,
		"require-and-verify": tls.RequireAndVerifyClientCert,
	}
	authMode, known := clientAuthModes[tlsClientAuth]
	if !known {
		return nil, xerrors.Errorf("unrecognized tls client auth: %q", tlsClientAuth)
	}
	tlsConfig.ClientAuth = authMode

	certs, err := loadCertificates(tlsCertFiles, tlsKeyFiles)
	if err != nil {
		return nil, xerrors.Errorf("load certificates: %w", err)
	}
	if len(certs) == 0 {
		fallback, err := generateSelfSignedCertificate()
		if err != nil {
			return nil, xerrors.Errorf("generate self signed certificate: %w", err)
		}
		certs = append(certs, *fallback)
	}
	tlsConfig.Certificates = certs

	tlsConfig.GetCertificate = func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
		// With a single certificate there is nothing to choose between.
		if len(certs) == 1 {
			return &certs[0], nil
		}
		// Expensively check which certificate matches the client hello.
		for _, candidate := range certs {
			candidate := candidate
			if hello.SupportsCertificate(&candidate) == nil {
				return &candidate, nil
			}
		}
		// Nothing matched: fall back to the first certificate if there is
		// one, or return nil so the server doesn't fail.
		if len(certs) > 0 {
			return &certs[0], nil
		}
		return nil, nil //nolint:nilnil
	}

	if err := configureCAPool(tlsClientCAFile, tlsConfig); err != nil {
		return nil, err
	}
	return tlsConfig, nil
}
// configureCipherSuites resolves the user-supplied cipher names into cipher
// suite IDs suitable for tls.Config.CipherSuites and validates them against
// the TLS version range [minTLS, maxTLS].
//
// It returns an error when:
//   - minTLS is greater than maxTLS,
//   - minTLS is TLS 1.3 or higher (cipher suites are not configurable there),
//   - a cipher name is not recognized,
//   - an insecure cipher is listed and allowInsecureCiphers is false, or
//   - some enabled TLS version below 1.3 is not covered by any listed cipher.
//
// Insecure ciphers, and ciphers that support none of the enabled versions,
// are reported through logger as warnings.
//
//nolint:revive
func configureCipherSuites(ctx context.Context, logger slog.Logger, ciphers []string, allowInsecureCiphers bool, minTLS, maxTLS uint16) ([]uint16, error) {
	if minTLS > maxTLS {
		return nil, xerrors.Errorf("minimum tls version (%s) cannot be greater than maximum tls version (%s)", versionName(minTLS), versionName(maxTLS))
	}
	if minTLS >= tls.VersionTLS13 {
		// The cipher suites config option is ignored for tls 1.3 and higher.
		// So this user flag is a no-op if the min version is 1.3.
		return nil, xerrors.Errorf("'--tls-ciphers' cannot be specified when using minimum tls version 1.3 or higher, %d ciphers found as input.", len(ciphers))
	}
	// Configure the cipher suites which parses the strings and converts them
	// to golang cipher suites.
	supported, err := parseTLSCipherSuites(ciphers)
	if err != nil {
		return nil, xerrors.Errorf("tls ciphers: %w", err)
	}

	// allVersions is all tls versions the server supports.
	// We enumerate these to ensure if ciphers are configured, at least
	// 1 cipher for each version exists.
	allVersions := make(map[uint16]bool)
	for v := minTLS; v <= maxTLS; v++ {
		allVersions[v] = false
	}

	var insecure []string
	cipherIDs := make([]uint16, 0, len(supported))
	for _, cipher := range supported {
		if cipher.Insecure {
			// Always show this warning, even if they have allowInsecureCiphers
			// specified.
			logger.Warn(ctx, "insecure tls cipher specified for server use", slog.F("cipher", cipher.Name))
			insecure = append(insecure, cipher.Name)
		}

		// This is a warning message to tell the user if they are specifying
		// a cipher that does not support the tls versions they have specified.
		// This makes the cipher essentially a "noop" cipher.
		if !hasSupportedVersion(minTLS, maxTLS, cipher.SupportedVersions) {
			versions := make([]string, 0, len(cipher.SupportedVersions))
			for _, sv := range cipher.SupportedVersions {
				versions = append(versions, versionName(sv))
			}
			logger.Warn(ctx, "cipher not supported for tls versions enabled, cipher will not be used",
				slog.F("cipher", cipher.Name),
				slog.F("cipher_supported_versions", strings.Join(versions, ",")),
				slog.F("server_min_version", versionName(minTLS)),
				slog.F("server_max_version", versionName(maxTLS)),
			)
		}

		for _, v := range cipher.SupportedVersions {
			allVersions[v] = true
		}

		cipherIDs = append(cipherIDs, cipher.ID)
	}

	if len(insecure) > 0 && !allowInsecureCiphers {
		return nil, xerrors.Errorf("insecure tls ciphers specified, must use '--tls-allow-insecure-ciphers' to allow these: %s", strings.Join(insecure, ", "))
	}

	// This is an additional sanity check. The user can specify ciphers that
	// do not cover the full range of tls versions they have specified.
	// They can unintentionally break TLS for some tls configured versions.
	//
	// Walk the enabled range in ascending order rather than ranging over the
	// map, so the resulting error message is deterministic.
	var missedVersions []string
	for version := minTLS; version <= maxTLS; version++ {
		if version == tls.VersionTLS13 {
			continue // v1.3 ignores configured cipher suites.
		}
		if !allVersions[version] {
			missedVersions = append(missedVersions, versionName(version))
		}
	}
	if len(missedVersions) > 0 {
		return nil, xerrors.Errorf("no tls ciphers supported for tls versions %q. "+
			"Add additional ciphers, set the minimum version to 'tls13', or remove the ciphers configured and rely on the default",
			strings.Join(missedVersions, ","))
	}

	return cipherIDs, nil
}
// parseTLSCipherSuites resolves cipher suite names such as
// 'TLS_RSA_WITH_AES_128_CBC_SHA' (matched case-insensitively) to their
// tls.CipherSuite descriptions, preserving input order. Both the secure and
// the insecure suites known to crypto/tls are considered, so the result may
// contain insecure suites. If any name is unknown, an error listing all
// unknown names is returned. An empty input yields (nil, nil).
func parseTLSCipherSuites(ciphers []string) ([]tls.CipherSuite, error) {
	if len(ciphers) == 0 {
		return nil, nil
	}

	known := append(tls.CipherSuites(), tls.InsecureCipherSuites()...)
	// lookup returns the suite whose name matches, or nil when there is none.
	lookup := func(name string) *tls.CipherSuite {
		for _, candidate := range known {
			if strings.EqualFold(candidate.Name, name) {
				return candidate
			}
		}
		return nil
	}

	var (
		resolved []tls.CipherSuite
		unknown  []string
	)
	for _, name := range ciphers {
		if suite := lookup(name); suite != nil {
			resolved = append(resolved, *suite)
		} else {
			unknown = append(unknown, name)
		}
	}

	if len(unknown) > 0 {
		return nil, xerrors.Errorf("unsupported tls ciphers specified, see https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L53-L75: %s", strings.Join(unknown, ", "))
	}
	return resolved, nil
}
// hasSupportedVersion reports whether at least one entry of versions lies in
// the inclusive range [min, max]. It returns false when every entry is
// outside the range or when versions is empty.
func hasSupportedVersion(min, max uint16, versions []uint16) bool {
	for i := range versions {
		if versions[i] < min || versions[i] > max {
			continue
		}
		// A single in-range version is enough.
		return true
	}
	return false
}
// versionName returns a human-readable name for a TLS version number, or the
// number formatted as 0xNNNN when it is not a known version. It mirrors
// tls.VersionName from Go 1.21 and is kept as a local copy until the switch.
func versionName(version uint16) string {
	names := map[uint16]string{
		tls.VersionSSL30: "SSLv3",
		tls.VersionTLS10: "TLS 1.0",
		tls.VersionTLS11: "TLS 1.1",
		tls.VersionTLS12: "TLS 1.2",
		tls.VersionTLS13: "TLS 1.3",
	}
	if name, ok := names[version]; ok {
		return name
	}
	return fmt.Sprintf("0x%04X", version)
}
// configureOIDCPKI builds an oauthpki.Config from the original OAuth2 config,
// the PEM key read from keyFile and, when certFile is non-empty, the PEM
// certificate read from certFile. File read failures are returned wrapped.
func configureOIDCPKI(orig *oauth2.Config, keyFile string, certFile string) (*oauthpki.Config, error) {
	pemKey, err := os.ReadFile(keyFile)
	if err != nil {
		return nil, xerrors.Errorf("read oidc client key file: %w", err)
	}

	params := oauthpki.ConfigParams{
		ClientID:      orig.ClientID,
		TokenURL:      orig.Endpoint.TokenURL,
		Scopes:        orig.Scopes,
		PemEncodedKey: pemKey,
		Config:        orig,
	}

	// According to the spec, the certificate is not required. So do not
	// require it on the initial loading of the PKI config.
	if certFile != "" {
		pemCert, err := os.ReadFile(certFile)
		if err != nil {
			return nil, xerrors.Errorf("read oidc client cert file: %w", err)
		}
		params.PemEncodedCert = pemCert
	}

	return oauthpki.NewOauth2PKIConfig(params)
}
// configureCAPool loads the PEM certificates in tlsClientCAFile into a new
// pool and assigns it to tlsConfig.ClientCAs. An empty path is a no-op.
// It returns an error when the file cannot be read or holds no parsable
// certificate.
func configureCAPool(tlsClientCAFile string, tlsConfig *tls.Config) error {
	if tlsClientCAFile == "" {
		return nil
	}

	pemData, err := os.ReadFile(tlsClientCAFile)
	if err != nil {
		return xerrors.Errorf("read %q: %w", tlsClientCAFile, err)
	}

	pool := x509.NewCertPool()
	if ok := pool.AppendCertsFromPEM(pemData); !ok {
		return xerrors.Errorf("failed to parse CA certificate in tls-client-ca-file")
	}
	tlsConfig.ClientCAs = pool
	return nil
}
// configureGithubOAuth2 assembles the GitHub OAuth2 login configuration.
//
// The callback URL is derived from accessURL. allowEveryone is mutually
// exclusive with allowOrgs and rawTeams, and when allowEveryone is false at
// least one org must be given. Each rawTeams entry must have the form
// <organization>/<team>. When enterpriseBaseURL is set, the OAuth endpoints
// and the API client target that GitHub Enterprise instance instead of
// github.com.
//
//nolint:revive // Ignore flag-parameter: parameter 'allowEveryone' seems to be a control flag, avoid control coupling (revive)
func configureGithubOAuth2(instrument *promoauth.Factory, accessURL *url.URL, clientID, clientSecret string, allowSignups, allowEveryone bool, allowOrgs []string, rawTeams []string, enterpriseBaseURL string) (*coderd.GithubOAuth2Config, error) {
	callbackURL, err := accessURL.Parse("/api/v2/users/oauth2/github/callback")
	if err != nil {
		return nil, xerrors.Errorf("parse github oauth callback url: %w", err)
	}

	switch {
	case allowEveryone && len(allowOrgs) > 0:
		return nil, xerrors.New("allow everyone and allowed orgs cannot be used together")
	case allowEveryone && len(rawTeams) > 0:
		return nil, xerrors.New("allow everyone and allowed teams cannot be used together")
	case !allowEveryone && len(allowOrgs) == 0:
		return nil, xerrors.New("allowed orgs is empty: must specify at least one org or allow everyone")
	}

	teams := make([]coderd.GithubOAuth2Team, 0, len(rawTeams))
	for _, entry := range rawTeams {
		// Split on the first slash only; the slug keeps any further slashes.
		org, slug, ok := strings.Cut(entry, "/")
		if !ok {
			return nil, xerrors.Errorf("github team allowlist is formatted incorrectly. got %s; wanted <organization>/<team>", entry)
		}
		teams = append(teams, coderd.GithubOAuth2Team{
			Organization: org,
			Slug:         slug,
		})
	}

	endpoint := xgithub.Endpoint
	useEnterprise := enterpriseBaseURL != ""
	if useEnterprise {
		base, err := url.Parse(enterpriseBaseURL)
		if err != nil {
			return nil, xerrors.Errorf("parse enterprise base url: %w", err)
		}
		authorize, err := base.Parse("/login/oauth/authorize")
		if err != nil {
			return nil, xerrors.Errorf("parse enterprise auth url: %w", err)
		}
		accessToken, err := base.Parse("/login/oauth/access_token")
		if err != nil {
			return nil, xerrors.Errorf("parse enterprise token url: %w", err)
		}
		endpoint = oauth2.Endpoint{
			AuthURL:  authorize.String(),
			TokenURL: accessToken.String(),
		}
	}

	instrumentedOauth := instrument.NewGithub("github-login", &oauth2.Config{
		ClientID:     clientID,
		ClientSecret: clientSecret,
		Endpoint:     endpoint,
		RedirectURL:  callbackURL.String(),
		Scopes: []string{
			"read:user",
			"read:org",
			"user:email",
		},
	})

	// newAPI wraps the HTTP client with instrumentation for the given source
	// and returns a GitHub API client for github.com or the enterprise host.
	newAPI := func(client *http.Client, source promoauth.Oauth2Source) (*github.Client, error) {
		instrumented := instrumentedOauth.InstrumentHTTPClient(client, source)
		if useEnterprise {
			return github.NewEnterpriseClient(enterpriseBaseURL, "", instrumented)
		}
		return github.NewClient(instrumented), nil
	}

	authenticatedUser := func(ctx context.Context, client *http.Client) (*github.User, error) {
		api, err := newAPI(client, promoauth.SourceGitAPIAuthUser)
		if err != nil {
			return nil, err
		}
		user, _, err := api.Users.Get(ctx, "")
		return user, err
	}

	listEmails := func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
		api, err := newAPI(client, promoauth.SourceGitAPIListEmails)
		if err != nil {
			return nil, err
		}
		emails, _, err := api.Users.ListEmails(ctx, &github.ListOptions{})
		return emails, err
	}

	listOrgMemberships := func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
		api, err := newAPI(client, promoauth.SourceGitAPIOrgMemberships)
		if err != nil {
			return nil, err
		}
		opts := &github.ListOrgMembershipsOptions{
			State: "active",
			ListOptions: github.ListOptions{
				PerPage: 100,
			},
		}
		memberships, _, err := api.Organizations.ListOrgMemberships(ctx, opts)
		return memberships, err
	}

	teamMembership := func(ctx context.Context, client *http.Client, org, teamSlug, username string) (*github.Membership, error) {
		api, err := newAPI(client, promoauth.SourceGitAPITeamMemberships)
		if err != nil {
			return nil, err
		}
		membership, _, err := api.Teams.GetTeamMembershipBySlug(ctx, org, teamSlug, username)
		return membership, err
	}

	return &coderd.GithubOAuth2Config{
		OAuth2Config:                instrumentedOauth,
		AllowSignups:                allowSignups,
		AllowEveryone:               allowEveryone,
		AllowOrganizations:          allowOrgs,
		AllowTeams:                  teams,
		AuthenticatedUser:           authenticatedUser,
		ListEmails:                  listEmails,
		ListOrganizationMemberships: listOrgMemberships,
		TeamMembership:              teamMembership,
	}, nil
}
// embeddedPostgresURL returns the connection URL for the embedded PostgreSQL
// deployment described by cfg.
//
// The password and port are read from cfg. When either does not exist yet it
// is generated (a random 16-character password, or a currently free local TCP
// port) and written back to cfg so later calls return the same URL. Any other
// read failure is returned as an error rather than producing an incomplete
// URL.
func embeddedPostgresURL(cfg config.Root) (string, error) {
	pgPassword, err := cfg.PostgresPassword().Read()
	switch {
	case errors.Is(err, os.ErrNotExist):
		pgPassword, err = cryptorand.String(16)
		if err != nil {
			return "", xerrors.Errorf("generate password: %w", err)
		}
		err = cfg.PostgresPassword().Write(pgPassword)
		if err != nil {
			return "", xerrors.Errorf("write password: %w", err)
		}
	case err != nil:
		return "", xerrors.Errorf("read postgres password: %w", err)
	}

	pgPort, err := cfg.PostgresPort().Read()
	switch {
	case errors.Is(err, os.ErrNotExist):
		// Ask the OS for a free port by listening on port 0, then release it
		// immediately so PostgreSQL can bind it later.
		listener, listenErr := net.Listen("tcp4", "127.0.0.1:0")
		if listenErr != nil {
			return "", xerrors.Errorf("listen for random port: %w", listenErr)
		}
		addr := listener.Addr()
		_ = listener.Close()
		tcpAddr, valid := addr.(*net.TCPAddr)
		if !valid {
			return "", xerrors.Errorf("listener returned non TCP addr: %T", addr)
		}
		pgPort = strconv.Itoa(tcpAddr.Port)
		err = cfg.PostgresPort().Write(pgPort)
		if err != nil {
			return "", xerrors.Errorf("write postgres port: %w", err)
		}
	case err != nil:
		// Previously this error was dropped and a URL with an empty port was
		// returned.
		return "", xerrors.Errorf("read postgres port: %w", err)
	}

	return fmt.Sprintf("postgres://coder@localhost:%s/coder?sslmode=disable&password=%s", pgPort, pgPassword), nil
}
// startBuiltinPostgres starts the embedded PostgreSQL server using the
// password and port stored in cfg (generating them first if needed) and the
// bin/data/runtime/cache directories under cfg.PostgresPath(). PostgreSQL
// output is forwarded to logger at debug level under the "postgres" name.
//
// It returns the connection URL and a function that stops the server. It
// refuses to run as the root user.
func startBuiltinPostgres(ctx context.Context, cfg config.Root, logger slog.Logger) (string, func() error, error) {
	currentUser, err := user.Current()
	if err != nil {
		return "", nil, err
	}
	if currentUser.Uid == "0" {
		return "", nil, xerrors.New("The built-in PostgreSQL cannot run as the root user. Create a non-root user and run again!")
	}

	// Ensure a password and port have been generated!
	connectionURL, err := embeddedPostgresURL(cfg)
	if err != nil {
		return "", nil, err
	}

	password, err := cfg.PostgresPassword().Read()
	if err != nil {
		return "", nil, xerrors.Errorf("read postgres password: %w", err)
	}
	rawPort, err := cfg.PostgresPort().Read()
	if err != nil {
		return "", nil, xerrors.Errorf("read postgres port: %w", err)
	}
	port, err := strconv.ParseUint(rawPort, 10, 16)
	if err != nil {
		return "", nil, xerrors.Errorf("parse postgres port: %w", err)
	}

	pgLog := slog.Stdlib(ctx, logger.Named("postgres"), slog.LevelDebug)
	pgRoot := cfg.PostgresPath()
	pgConfig := embeddedpostgres.DefaultConfig().
		Version(embeddedpostgres.V13).
		BinariesPath(filepath.Join(pgRoot, "bin")).
		DataPath(filepath.Join(pgRoot, "data")).
		RuntimePath(filepath.Join(pgRoot, "runtime")).
		CachePath(filepath.Join(pgRoot, "cache")).
		Username("coder").
		Password(password).
		Database("coder").
		Port(uint32(port)).
		Logger(pgLog.Writer())

	server := embeddedpostgres.NewDatabase(pgConfig)
	if err := server.Start(); err != nil {
		return "", nil, xerrors.Errorf("Failed to start built-in PostgreSQL. Optionally, specify an external deployment with `--postgres-url`: %w", err)
	}
	return connectionURL, server.Stop, nil
}
// ConfigureHTTPClient returns an HTTP client and a context for outbound
// requests.
//
// When both clientCertFile and clientKeyFile are set, the client presents
// that certificate, trusts the CAs in tlsClientCAFile (if given), and is
// stored in the returned context under oauth2.HTTPClient. Otherwise a plain
// client and the unmodified ctx are returned; tlsClientCAFile is not
// consulted in that case.
//
// On error the original ctx is returned along with a nil client.
func ConfigureHTTPClient(ctx context.Context, clientCertFile, clientKeyFile string, tlsClientCAFile string) (context.Context, *http.Client, error) {
	if clientCertFile == "" || clientKeyFile == "" {
		return ctx, &http.Client{}, nil
	}

	certificates, err := loadCertificates([]string{clientCertFile}, []string{clientKeyFile})
	if err != nil {
		return ctx, nil, err
	}

	tlsClientConfig := &tls.Config{ //nolint:gosec
		Certificates: certificates,
		NextProtos:   []string{"h2", "http/1.1"},
	}
	err = configureCAPool(tlsClientCAFile, tlsClientConfig)
	if err != nil {
		// Return ctx rather than nil so both error paths behave the same and
		// callers never receive a nil context.
		return ctx, nil, err
	}

	httpClient := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: tlsClientConfig,
		},
	}
	return context.WithValue(ctx, oauth2.HTTPClient, httpClient), httpClient, nil
}
// redirectToAccessURL wraps handler so that requests not addressed to the
// access URL are answered with a temporary redirect to it.
//
// Requests are passed through unchanged when they target /healthz or a DERP
// path, or when their Host (or X-Forwarded-Host header) equals the access URL
// host, or when their Host matches appHostnameRegex. Plain-HTTP requests are
// redirected when the access URL is https and no tunnel is in use.
//
// nolint:revive
func redirectToAccessURL(handler http.Handler, accessURL *url.URL, tunnel bool, appHostnameRegex *regexp.Regexp) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		serve := func() { handler.ServeHTTP(w, r) }
		redirect := func() {
			http.Redirect(w, r, accessURL.String(), http.StatusTemporaryRedirect)
		}

		switch {
		// Exception: /healthz
		// Kubernetes doesn't like it if you redirect your healthcheck or
		// liveness check endpoint.
		case r.URL.Path == "/healthz":
			serve()

		// Exception: DERP
		// This endpoint is used when creating a DERP-mesh in the enterprise
		// version to directly dial other Coderd derpers. Redirecting to the
		// access URL breaks direct dial since the access URL will be
		// load-balanced in a multi-replica deployment.
		//
		// It's totally fine to access DERP over TLS, but there is also no need
		// to redirect HTTP to HTTPS as DERP is itself an encrypted protocol.
		case isDERPPath(r.URL.Path):
			serve()

		// Only upgrade to HTTPS when not tunneling: the tunnel doesn't proxy
		// with TLS, so tunneled requests must be allowed through.
		case !tunnel && accessURL.Scheme == "https" && r.TLS == nil:
			redirect()

		case r.Host == accessURL.Host:
			serve()
		case r.Header.Get("X-Forwarded-Host") == accessURL.Host:
			serve()
		case appHostnameRegex != nil && appHostnameRegex.MatchString(r.Host):
			serve()

		default:
			redirect()
		}
	})
}
// isDERPPath reports whether the second "/"-separated segment of p is exactly
// "derp", e.g. "/derp" or "/derp/latency-check". A path without any "/" never
// matches.
func isDERPPath(p string) bool {
	_, afterFirstSlash, hasSlash := strings.Cut(p, "/")
	if !hasSlash {
		return false
	}
	segment, _, _ := strings.Cut(afterFirstSlash, "/")
	return segment == "derp"
}
// IsLocalhost reports whether host is one of the literal names for the local
// machine: "localhost", "127.0.0.1" or "::1". The comparison is exact, so it
// is intended to be called with `u.Hostname()` (no port, no brackets).
func IsLocalhost(host string) bool {
	switch host {
	case "localhost", "127.0.0.1", "::1":
		return true
	default:
		return false
	}
}
// ConnectToPostgres opens dbURL with the given driver, retrying for up to 30
// seconds until a ping succeeds. Once connected it verifies that the server
// is PostgreSQL 13 or newer, runs the migrations, and applies the connection
// pool limits before returning the handle.
//
// On any failure the returned *sql.DB is nil, the handle opened so far (if
// any) is closed, and the error is logged.
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string) (sqlDB *sql.DB, err error) {
	logger.Debug(ctx, "connecting to postgresql")

	// Try to connect for 30 seconds.
	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	// The handle is tracked in a local variable rather than the named result:
	// `return nil, err` sets the named result to nil before deferred
	// functions run, so cleanup keyed on the result would never close it.
	var db *sql.DB
	defer func() {
		if err == nil {
			return
		}
		if db != nil {
			_ = db.Close()
		}
		sqlDB = nil
		logger.Error(ctx, "connect to postgres failed", slog.Error(err))
	}()

	var tries int
	for r := retry.New(time.Second, 3*time.Second); r.Wait(ctx); {
		tries++

		db, err = sql.Open(driver, dbURL)
		if err != nil {
			logger.Warn(ctx, "connect to postgres: retrying", slog.Error(err), slog.F("try", tries))
			continue
		}

		err = pingPostgres(ctx, db)
		if err != nil {
			logger.Warn(ctx, "ping postgres: retrying", slog.Error(err), slog.F("try", tries))
			_ = db.Close()
			db = nil
			continue
		}

		break
	}
	// If the loop never recorded an error (for example it never ran), report
	// why the context ended instead.
	if err == nil {
		err = ctx.Err()
	}
	if err != nil {
		return nil, xerrors.Errorf("unable to connect after %d tries; last error: %w", tries, err)
	}

	// Ensure the PostgreSQL version is >=13.0.0!
	version, err := db.QueryContext(ctx, "SHOW server_version_num;")
	if err != nil {
		return nil, xerrors.Errorf("get postgres version: %w", err)
	}
	if !version.Next() {
		// Next reports false both for an empty result and for a failure, so
		// surface the underlying error when there is one.
		rowsErr := version.Err()
		_ = version.Close()
		if rowsErr != nil {
			return nil, xerrors.Errorf("get postgres version: %w", rowsErr)
		}
		return nil, xerrors.Errorf("no rows returned for version select")
	}
	var versionNum int
	err = version.Scan(&versionNum)
	// Release the connection held by the result set before doing anything
	// else, whether or not the scan succeeded.
	_ = version.Close()
	if err != nil {
		return nil, xerrors.Errorf("scan version: %w", err)
	}

	if versionNum < 130000 {
		return nil, xerrors.Errorf("PostgreSQL version must be v13.0.0 or higher! Got: %d", versionNum)
	}
	logger.Debug(ctx, "connected to postgresql", slog.F("version", versionNum))

	err = migrations.Up(db)
	if err != nil {
		return nil, xerrors.Errorf("migrate up: %w", err)
	}
	// The default is 0 but the request will fail with a 500 if the DB
	// cannot accept new connections, so we try to limit that here.
	// Requests will wait for a new connection instead of a hard error
	// if a limit is set.
	db.SetMaxOpenConns(10)
	// Allow a max of 3 idle connections at a time. Lower values end up
	// creating a lot of connection churn. Since each connection uses about
	// 10MB of memory, we're allocating 30MB to Postgres connections per
	// replica, but is better than causing Postgres to spawn a thread 15-20
	// times/sec. PGBouncer's transaction pooling is not the greatest so
	// it's not optimal for us to deploy.
	//
	// This was set to 10 before we started doing HA deployments, but 3 was
	// later determined to be a better middle ground as to not use up all
	// of PGs default connection limit while simultaneously avoiding a lot
	// of connection churn.
	db.SetMaxIdleConns(3)

	return db, nil
}
// pingPostgres pings db, giving up after 5 seconds or when ctx ends,
// whichever comes first.
func pingPostgres(ctx context.Context, db *sql.DB) error {
	pingCtx, stop := context.WithTimeout(ctx, 5*time.Second)
	defer stop()

	return db.PingContext(pingCtx)
}
// HTTPServers bundles the plain-HTTP and TLS listeners of the server together
// with the URLs used to reach them. Either listener may be nil; Serve and
// Close skip nil listeners.
type HTTPServers struct {
// HTTPUrl is the http:// URL of HTTPListener.
HTTPUrl *url.URL
// HTTPListener is the plain-HTTP listener, or nil if none was opened.
HTTPListener net.Listener
// TLS
// TLSUrl is the https:// URL of TLSListener.
TLSUrl *url.URL
// TLSListener is the TLS listener, or nil if none was opened.
TLSListener net.Listener
// TLSConfig is the TLS configuration associated with TLSListener.
TLSConfig *tls.Config
}
// Serve acts just like http.Serve, but on every configured listener at once.
// It blocks until the server is closed and returns an error if any underlying
// Serve call fails. As soon as one Serve call returns, all listeners are
// closed so the remaining one stops too.
func (s *HTTPServers) Serve(srv *http.Server) error {
	var group errgroup.Group

	// serveOn starts srv on l in the group; nil listeners are skipped.
	serveOn := func(l net.Listener) {
		if l == nil {
			return
		}
		group.Go(func() error {
			defer s.Close() // close all listeners on error
			return srv.Serve(l)
		})
	}

	serveOn(s.HTTPListener)
	serveOn(s.TLSListener)

	return group.Wait()
}
// Close closes the HTTP listener and then the TLS listener, skipping any that
// are nil. Errors from closing are ignored.
func (s *HTTPServers) Close() {
	for _, l := range []net.Listener{s.HTTPListener, s.TLSListener} {
		if l == nil {
			continue
		}
		_ = l.Close()
	}
}
// ConfigureTraceProvider installs the global text-map propagator and, when
// any tracing option in cfg is enabled, starts the tracer provider.
//
// It returns the tracer provider, the name of the SQL driver to open the
// database with, and a function that shuts tracing down. If tracing is
// disabled or the exporter cannot be started, a no-op provider, the plain
// "postgres" driver and a no-op close function are returned. If only the
// traced Postgres driver fails to register, tracing stays enabled and the
// plain "postgres" driver is used. Failures are logged as warnings, never
// returned.
func ConfigureTraceProvider(
	ctx context.Context,
	logger slog.Logger,
	cfg *codersdk.DeploymentValues,
) (trace.TracerProvider, string, func(context.Context) error) {
	otel.SetTextMapPropagator(
		propagation.NewCompositeTextMapPropagator(
			propagation.TraceContext{},
			propagation.Baggage{},
		),
	)

	const plainDriver = "postgres"
	noopClose := func(context.Context) error { return nil }

	tracingRequested := cfg.Trace.Enable.Value() || cfg.Trace.DataDog.Value() || cfg.Trace.HoneycombAPIKey != ""
	if !tracingRequested {
		return trace.NewNoopTracerProvider(), plainDriver, noopClose
	}

	provider, shutdown, err := tracing.TracerProvider(ctx, "coderd", tracing.TracerOpts{
		Default:   cfg.Trace.Enable.Value(),
		DataDog:   cfg.Trace.DataDog.Value(),
		Honeycomb: cfg.Trace.HoneycombAPIKey.String(),
	})
	if err != nil {
		logger.Warn(ctx, "start telemetry exporter", slog.Error(err))
		return trace.NewNoopTracerProvider(), plainDriver, noopClose
	}

	driver, err := tracing.PostgresDriver(provider, "coderd.database")
	if err != nil {
		logger.Warn(ctx, "start postgres tracing driver", slog.Error(err))
		driver = plainDriver
	}

	return provider, driver, shutdown
}
// ConfigureHTTPServers opens the HTTP and/or TLS listeners described by cfg
// and returns them bundled in an HTTPServers. Each started listener is
// announced on inv.Stdout. If any step fails, every listener opened so far is
// closed before the error is returned.
//
// cfg is mutated: when the combined Address option is set, it is copied into
// either TLS.Address or HTTPAddress (depending on TLS.Enable) and the other
// address is cleared. The deprecated redirect options are also folded into
// cfg.RedirectToAccessURL via redirectHTTPToHTTPSDeprecation when TLS is on.
func ConfigureHTTPServers(logger slog.Logger, inv *serpent.Invocation, cfg *codersdk.DeploymentValues) (_ *HTTPServers, err error) {
ctx := inv.Context()
httpServers := &HTTPServers{}
// err is a named result so this cleanup sees the error of every return.
defer func() {
if err != nil {
// Always close the listeners if we fail.
httpServers.Close()
}
}()
// Validate bind addresses.
// The combined Address option, when set, overrides the dedicated HTTP/TLS
// address options.
if cfg.Address.String() != "" {
if cfg.TLS.Enable {
cfg.HTTPAddress = ""
cfg.TLS.Address = cfg.Address
} else {
_ = cfg.HTTPAddress.Set(cfg.Address.String())
cfg.TLS.Address.Host = ""
cfg.TLS.Address.Port = ""
}
}
if cfg.TLS.Enable && cfg.TLS.Address.String() == "" {
return nil, xerrors.Errorf("TLS address must be set if TLS is enabled")
}
if !cfg.TLS.Enable && cfg.HTTPAddress.String() == "" {
return nil, xerrors.Errorf("TLS is disabled. Enable with --tls-enable or specify a HTTP address")
}
// A non-empty access URL must use the http or https scheme.
if cfg.AccessURL.String() != "" &&
!(cfg.AccessURL.Scheme == "http" || cfg.AccessURL.Scheme == "https") {
return nil, xerrors.Errorf("access-url must include a scheme (e.g. 'http://' or 'https://)")
}
// addrString renders a listener address for display to the user.
addrString := func(l net.Listener) string {
listenAddrStr := l.Addr().String()
// For some reason if 0.0.0.0:x is provided as the https
// address, httpsListener.Addr().String() likes to return it as
// an ipv6 address (i.e. [::]:x). If the input ip is 0.0.0.0,
// try to coerce the output back to ipv4 to make it less
// confusing.
//
// NOTE(review): this consults cfg.HTTPAddress even when l is the TLS
// listener, so a TLS address of 0.0.0.0 is only rewritten if the HTTP
// address also contains 0.0.0.0 — confirm whether cfg.TLS.Address
// should be checked for the TLS case.
if strings.Contains(cfg.HTTPAddress.String(), "0.0.0.0") {
listenAddrStr = strings.ReplaceAll(listenAddrStr, "[::]", "0.0.0.0")
}
return listenAddrStr
}
if cfg.HTTPAddress.String() != "" {
httpServers.HTTPListener, err = net.Listen("tcp", cfg.HTTPAddress.String())
if err != nil {
return nil, err
}
// We want to print out the address the user supplied, not the
// loopback device.
_, _ = fmt.Fprintf(inv.Stdout, "Started HTTP listener at %s\n", (&url.URL{Scheme: "http", Host: addrString(httpServers.HTTPListener)}).String())
// Set the http URL we want to use when connecting to ourselves.
tcpAddr, tcpAddrValid := httpServers.HTTPListener.Addr().(*net.TCPAddr)
if !tcpAddrValid {
return nil, xerrors.Errorf("invalid TCP address type %T", httpServers.HTTPListener.Addr())
}
// A wildcard bind address is replaced with IPv4 loopback so the
// self-connect URL names a concrete host.
if tcpAddr.IP.IsUnspecified() {
tcpAddr.IP = net.IPv4(127, 0, 0, 1)
}
httpServers.HTTPUrl = &url.URL{
Scheme: "http",
Host: tcpAddr.String(),
}
}
if cfg.TLS.Enable {
// NOTE(review): this repeats the TLS address validation done above.
if cfg.TLS.Address.String() == "" {
return nil, xerrors.New("tls address must be set if tls is enabled")
}
redirectHTTPToHTTPSDeprecation(ctx, logger, inv, cfg)
// err is shadowed inside this block; the explicit `return nil, ...`
// statements below still set the named result, so the deferred cleanup
// runs on these failures as well.
tlsConfig, err := configureServerTLS(
ctx,
logger,
cfg.TLS.MinVersion.String(),
cfg.TLS.ClientAuth.String(),
cfg.TLS.CertFiles,
cfg.TLS.KeyFiles,
cfg.TLS.ClientCAFile.String(),
cfg.TLS.SupportedCiphers.Value(),
cfg.TLS.AllowInsecureCiphers.Value(),
)
if err != nil {
return nil, xerrors.Errorf("configure tls: %w", err)
}
httpsListenerInner, err := net.Listen("tcp", cfg.TLS.Address.String())
if err != nil {
return nil, err
}
httpServers.TLSConfig = tlsConfig
// Wrap the raw TCP listener so accepted connections speak TLS.
httpServers.TLSListener = tls.NewListener(httpsListenerInner, tlsConfig)
// We want to print out the address the user supplied, not the
// loopback device.
_, _ = fmt.Fprintf(inv.Stdout, "Started TLS/HTTPS listener at %s\n", (&url.URL{Scheme: "https", Host: addrString(httpServers.TLSListener)}).String())
// Set the https URL we want to use when connecting to
// ourselves.
tcpAddr, tcpAddrValid := httpServers.TLSListener.Addr().(*net.TCPAddr)
if !tcpAddrValid {
return nil, xerrors.Errorf("invalid TCP address type %T", httpServers.TLSListener.Addr())
}
// Same wildcard-to-loopback substitution as for the HTTP listener.
if tcpAddr.IP.IsUnspecified() {
tcpAddr.IP = net.IPv4(127, 0, 0, 1)
}
httpServers.TLSUrl = &url.URL{
Scheme: "https",
Host: tcpAddr.String(),
}
}
// Final safety net: at least one listener must have been opened.
if httpServers.HTTPListener == nil && httpServers.TLSListener == nil {
return nil, xerrors.New("must listen on at least one address")
}
return httpServers, nil
}
// redirectHTTPToHTTPSDeprecation handles deprecation of the
// --tls-redirect-http-to-https flag and "related" environment variables.
//
// --tls-redirect-http-to-https used to default to true; it made more sense to
// have the redirect be opt-in.
//
// For a while the environment variable "CODER_TLS_REDIRECT_HTTP" was accepted
// as well (but not a corresponding flag!), and it appeared in a configuration
// example, so it is still accepted to not break backward compat.
//
// If either environment variable parses as a true boolean, or the flag was
// set explicitly, a deprecation warning is logged and the TLS.RedirectHTTP
// value is copied into cfg.RedirectToAccessURL.
func redirectHTTPToHTTPSDeprecation(ctx context.Context, logger slog.Logger, inv *serpent.Invocation, cfg *codersdk.DeploymentValues) {
	// envEnabled reports whether the named variable holds a true boolean;
	// unset or unparsable values count as false.
	envEnabled := func(name string) bool {
		enabled, err := strconv.ParseBool(inv.Environ.Get(name))
		return err == nil && enabled
	}

	deprecatedOptionUsed := envEnabled("CODER_TLS_REDIRECT_HTTP") ||
		envEnabled("CODER_TLS_REDIRECT_HTTP_TO_HTTPS") ||
		inv.ParsedFlags().Changed("tls-redirect-http-to-https")
	if !deprecatedOptionUsed {
		return
	}

	logger.Warn(ctx, "⚠️ --tls-redirect-http-to-https is deprecated, please use --redirect-to-access-url instead")
	cfg.RedirectToAccessURL = cfg.TLS.RedirectHTTP
}
// ReadExternalAuthProvidersFromEnv is provided for compatibility purposes with
// the viper CLI. It parses providers from the CODER_EXTERNAL_AUTH_ variables
// first, followed by those from the legacy CODER_GITAUTH_ variables, and
// returns them as one list. The first parse error aborts the whole read.
func ReadExternalAuthProvidersFromEnv(environ []string) ([]codersdk.ExternalAuthConfig, error) {
	prefixes := []string{
		"CODER_EXTERNAL_AUTH_",
		// Deprecated: To support legacy git auth!
		"CODER_GITAUTH_",
	}

	var all []codersdk.ExternalAuthConfig
	for _, prefix := range prefixes {
		parsed, err := parseExternalAuthProvidersFromEnv(prefix, environ)
		if err != nil {
			return nil, err
		}
		all = append(all, parsed...)
	}
	return all, nil
}
// parseExternalAuthProvidersFromEnv consumes environment variables to parse
// external auth providers. A prefix is provided to support the legacy
// parsing of `GITAUTH` environment variables.
//
// Variables are expected in the form <prefix><index>_<KEY>=<value>. Indices
// must be non-negative and contiguous starting at 0; a gap yields a
// "skipped" error. Unknown keys are ignored. The environ slice is not
// modified.
func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]codersdk.ExternalAuthConfig, error) {
	// Sort a copy so the caller's slice is left untouched. The lexicographic
	// order only serves to make processing deterministic; provider order is
	// established numerically below.
	sorted := make([]string, len(environ))
	copy(sorted, environ)
	sort.Strings(sorted)

	type indexedVar struct {
		num   int    // provider index parsed from the variable name
		name  string // variable name after the prefix, for error messages
		key   string // setting name, e.g. "CLIENT_ID"
		value string
	}
	var vars []indexedVar
	for _, v := range serpent.ParseEnviron(sorted, prefix) {
		tokens := strings.SplitN(v.Name, "_", 2)
		if len(tokens) != 2 {
			return nil, xerrors.Errorf("invalid env var: %s", v.Name)
		}

		providerNum, err := strconv.Atoi(tokens[0])
		// A negative index can never address a provider; reject it here
		// instead of indexing out of range later.
		if err != nil || providerNum < 0 {
			return nil, xerrors.Errorf("parse number: %s", v.Name)
		}
		vars = append(vars, indexedVar{
			num:   providerNum,
			name:  v.Name,
			key:   tokens[1],
			value: v.Value,
		})
	}

	// The index numbers must be in-order. Sorting numerically (rather than
	// relying on the string sort, where "10_" precedes "1_") allows more than
	// ten providers.
	sort.SliceStable(vars, func(i, j int) bool {
		return vars[i].num < vars[j].num
	})

	var providers []codersdk.ExternalAuthConfig
	for _, v := range vars {
		switch {
		case len(providers) < v.num:
			return nil, xerrors.Errorf(
				"provider num %v skipped: %s",
				len(providers),
				v.name,
			)
		case len(providers) == v.num:
			// At the next provider.
			providers = append(providers, codersdk.ExternalAuthConfig{})
		}
		// Because vars is sorted by index, v.num now refers to the last
		// provider in the list.
		provider := &providers[v.num]

		switch v.key {
		case "ID":
			provider.ID = v.value
		case "TYPE":
			provider.Type = v.value
		case "CLIENT_ID":
			provider.ClientID = v.value
		case "CLIENT_SECRET":
			provider.ClientSecret = v.value
		case "AUTH_URL":
			provider.AuthURL = v.value
		case "TOKEN_URL":
			provider.TokenURL = v.value
		case "VALIDATE_URL":
			provider.ValidateURL = v.value
		case "REGEX":
			provider.Regex = v.value
		case "DEVICE_FLOW":
			b, err := strconv.ParseBool(v.value)
			if err != nil {
				return nil, xerrors.Errorf("parse bool: %s", v.value)
			}
			provider.DeviceFlow = b
		case "DEVICE_CODE_URL":
			provider.DeviceCodeURL = v.value
		case "NO_REFRESH":
			b, err := strconv.ParseBool(v.value)
			if err != nil {
				return nil, xerrors.Errorf("parse bool: %s", v.value)
			}
			provider.NoRefresh = b
		case "SCOPES":
			provider.Scopes = strings.Split(v.value, " ")
		case "EXTRA_TOKEN_KEYS":
			provider.ExtraTokenKeys = strings.Split(v.value, " ")
		case "APP_INSTALL_URL":
			provider.AppInstallURL = v.value
		case "APP_INSTALLATIONS_URL":
			provider.AppInstallationsURL = v.value
		case "DISPLAY_NAME":
			provider.DisplayName = v.value
		case "DISPLAY_ICON":
			provider.DisplayIcon = v.value
		}
	}
	return providers, nil
}
// escapePostgresURLUserInfo returns v in a form that url.Parse accepts.
//
// If the user provides a postgres URL with a password that contains special
// characters, the URL will be invalid. We need to escape the password so that
// the URL parse doesn't fail at the DB connector level.
//
// A URL that already parses is returned unchanged. Otherwise, if the parse
// error points at the userinfo, each "@" in v is tried in turn as the
// userinfo/host separator: the password before it is percent-escaped and the
// first rewritten URL that parses is returned. If no rewrite parses, or the
// parse error is unrelated to the userinfo, an error is returned.
func escapePostgresURLUserInfo(v string) (string, error) {
	_, parseErr := url.Parse(v)
	if parseErr == nil {
		return v, nil
	}

	// I wish I could use errors.Is here, but these errors are not declared as
	// variables in net/url. :(
	//
	// "invalid port" is included because a password containing "/", "?" or
	// "#" cuts the authority short, so the parser sees the text after the
	// username's ":" as a bad port rather than as bad userinfo.
	reason := parseErr.Error()
	if !strings.Contains(reason, "net/url: invalid userinfo") && !strings.Contains(reason, "invalid port") {
		return "", xerrors.Errorf("parse postgres url: %w", parseErr)
	}

	// We assume the failure is caused by special characters in the password.
	// The password itself may contain "@", so every "@" is a candidate for
	// the end of the userinfo; the earliest one that yields a valid URL wins.
	for offset := 0; offset < len(v); {
		rel := strings.Index(v[offset:], "@")
		if rel < 0 {
			break
		}
		at := offset + rel
		offset = at + 1

		head, tail := v[:at], v[at+1:]

		// Skip past the scheme so its ":" is never mistaken for the
		// username/password separator.
		infoStart := 0
		if i := strings.Index(head, "://"); i >= 0 {
			infoStart = i + len("://")
		}
		userinfo := head[infoStart:]

		// The password is everything after the last ":" of the userinfo.
		// Without a ":" there is no password and the username is escaped
		// instead.
		//
		// url.PathEscape is used here because url.QueryEscape
		// will not escape spaces correctly.
		if colon := strings.LastIndex(userinfo, ":"); colon >= 0 {
			userinfo = userinfo[:colon+1] + url.PathEscape(userinfo[colon+1:])
		} else {
			userinfo = url.PathEscape(userinfo)
		}

		candidate := head[:infoStart] + userinfo + "@" + tail
		if _, err := url.Parse(candidate); err == nil {
			return candidate, nil
		}
	}

	return "", xerrors.Errorf("parse postgres url: %w", parseErr)
}
// signalNotifyContext returns a context that is canceled when one of sig is
// received, as provided by inv.SignalNotifyContext. With no signals it falls
// back to a plain cancelable child of ctx.
func signalNotifyContext(ctx context.Context, inv *serpent.Invocation, sig ...os.Signal) (context.Context, context.CancelFunc) {
	if len(sig) > 0 {
		return inv.SignalNotifyContext(ctx, sig...)
	}
	// On Windows, some of our signal functions lack support, so callers may
	// pass in no signals at all. In that case there is nothing to listen
	// for.
	return context.WithCancel(ctx)
}
// getPostgresDB escapes the userinfo of postgresURL, registers the AWS IAM RDS
// driver on top of sqlDriver when auth asks for it, and connects to the
// database. It returns the open handle together with the escaped URL that was
// used to connect.
func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) {
	escapedURL, err := escapePostgresURLUserInfo(postgresURL)
	if err != nil {
		return nil, "", xerrors.Errorf("escaping postgres URL: %w", err)
	}

	driverName := sqlDriver
	if auth == codersdk.PostgresAuthAWSIAMRDS {
		driverName, err = awsiamrds.Register(ctx, sqlDriver)
		if err != nil {
			return nil, "", xerrors.Errorf("register aws rds iam auth: %w", err)
		}
	}

	db, err := ConnectToPostgres(ctx, logger, driverName, escapedURL)
	if err != nil {
		return nil, "", xerrors.Errorf("connect to postgres: %w", err)
	}
	return db, escapedURL, nil
}