coder/cli/server.go

1561 lines
50 KiB
Go

//go:build !slim
package cli
import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/pprof"
"net/url"
"os"
"os/signal"
"os/user"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/coreos/go-systemd/daemon"
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
"github.com/google/go-github/v43/github"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"go.opentelemetry.io/otel/trace"
"golang.org/x/mod/semver"
"golang.org/x/oauth2"
xgithub "golang.org/x/oauth2/github"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"google.golang.org/api/idtoken"
"google.golang.org/api/option"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"cdr.dev/slog/sloggers/slogjson"
"cdr.dev/slog/sloggers/slogstackdriver"
"github.com/coder/coder/buildinfo"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/cli/config"
"github.com/coder/coder/cli/deployment"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/autobuild/executor"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/database/migrations"
"github.com/coder/coder/coderd/devtunnel"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/prometheusmetrics"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/updatecheck"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisioner/terraform"
"github.com/coder/coder/provisionerd"
"github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
sdkproto "github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/tailnet"
)
// Server builds the "server" cobra command, which runs a Coder deployment.
//
// The command loads the deployment config from the flags and vip, opens the
// HTTP and/or TLS listeners, connects to PostgreSQL (starting the built-in
// one when no URL is given and the in-memory database is not requested),
// constructs the API by calling newAPI, starts the configured number of
// provisioner daemons and then serves until an interrupt signal arrives, the
// dev tunnel fails or a service reports an error on the shared error channel.
// After that it walks through an explicit graceful shutdown sequence.
//
// newAPI receives the fully populated coderd options and returns the API
// together with an io.Closer that is closed during shutdown.
//
// The returned command also carries the "postgres-builtin-url" and
// "postgres-builtin-serve" subcommands.
//
// nolint:gocyclo
func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *cobra.Command {
	root := &cobra.Command{
		Use:   "server",
		Short: "Start a Coder server",
		RunE: func(cmd *cobra.Command, args []string) error {
			// Main command context for managing cancellation of running
			// services.
			ctx, cancel := context.WithCancel(cmd.Context())
			defer cancel()

			go dumpHandler(ctx)

			cfg, err := deployment.Config(cmd.Flags(), vip)
			if err != nil {
				return xerrors.Errorf("getting deployment config: %w", err)
			}

			// Validate bind addresses. The deprecated --address flag is
			// mapped onto either the TLS or the HTTP address depending on
			// whether TLS is enabled.
			if cfg.Address.Value != "" {
				cmd.PrintErr(cliui.Styles.Warn.Render("WARN:") + " --address and -a are deprecated, please use --http-address and --tls-address instead")
				if cfg.TLS.Enable.Value {
					cfg.HTTPAddress.Value = ""
					cfg.TLS.Address.Value = cfg.Address.Value
				} else {
					cfg.HTTPAddress.Value = cfg.Address.Value
					cfg.TLS.Address.Value = ""
				}
			}
			if cfg.TLS.Enable.Value && cfg.TLS.Address.Value == "" {
				return xerrors.Errorf("TLS address must be set if TLS is enabled")
			}
			if !cfg.TLS.Enable.Value && cfg.HTTPAddress.Value == "" {
				return xerrors.Errorf("either HTTP or TLS must be enabled")
			}

			// Disable rate limits if the `--dangerous-disable-rate-limits` flag
			// was specified.
			loginRateLimit := 60
			filesRateLimit := 12
			if cfg.RateLimit.DisableAll.Value {
				cfg.RateLimit.API.Value = -1
				loginRateLimit = -1
				filesRateLimit = -1
			}

			printLogo(cmd)
			logger, logCloser, err := buildLogger(cmd, cfg)
			if err != nil {
				return xerrors.Errorf("make logger: %w", err)
			}
			defer logCloser()

			// Register signals early on so that graceful shutdown can't
			// be interrupted by additional signals. Note that we avoid
			// shadowing cancel() (from above) here because notifyStop()
			// restores default behavior for the signals. This protects
			// the shutdown sequence from abruptly terminating things
			// like: database migrations, provisioner work, workspace
			// cleanup in dev-mode, etc.
			//
			// To get out of a graceful shutdown, the user can send
			// SIGQUIT with ctrl+\ or SIGKILL with `kill -9`.
			notifyCtx, notifyStop := signal.NotifyContext(ctx, InterruptSignals...)
			defer notifyStop()

			// Ensure we have a unique cache directory for this process.
			cacheDir := filepath.Join(cfg.CacheDirectory.Value, uuid.NewString())
			err = os.MkdirAll(cacheDir, 0o700)
			if err != nil {
				return xerrors.Errorf("create cache directory: %w", err)
			}
			defer os.RemoveAll(cacheDir)

			// Clean up idle connections at the end, e.g.
			// embedded-postgres can leave an idle connection
			// which is caught by goleaks.
			defer http.DefaultClient.CloseIdleConnections()

			var (
				tracerProvider trace.TracerProvider
				sqlDriver      = "postgres"
			)

			// Coder tracing should be disabled if telemetry is disabled unless
			// --telemetry-trace was explicitly provided.
			shouldCoderTrace := cfg.Telemetry.Enable.Value && !isTest()
			// Only override if telemetryTraceEnable was specifically set.
			// By default we want it to be controlled by telemetryEnable.
			if cmd.Flags().Changed("telemetry-trace") {
				shouldCoderTrace = cfg.Telemetry.Trace.Value
			}

			if cfg.Trace.Enable.Value || shouldCoderTrace || cfg.Trace.HoneycombAPIKey.Value != "" {
				sdkTracerProvider, closeTracing, err := tracing.TracerProvider(ctx, "coderd", tracing.TracerOpts{
					Default:   cfg.Trace.Enable.Value,
					Coder:     shouldCoderTrace,
					Honeycomb: cfg.Trace.HoneycombAPIKey.Value,
				})
				if err != nil {
					// Tracing is best-effort: log and keep starting up.
					logger.Warn(ctx, "start telemetry exporter", slog.Error(err))
				} else {
					// allow time for traces to flush even if command context is canceled
					defer func() {
						_ = shutdownWithTimeout(closeTracing, 5*time.Second)
					}()

					d, err := tracing.PostgresDriver(sdkTracerProvider, "coderd.database")
					if err != nil {
						logger.Warn(ctx, "start postgres tracing driver", slog.Error(err))
					} else {
						sqlDriver = d
					}

					tracerProvider = sdkTracerProvider
				}
			}

			config := createConfig(cmd)
			builtinPostgres := false
			// Only use built-in if PostgreSQL URL isn't specified!
			if !cfg.InMemoryDatabase.Value && cfg.PostgresURL.Value == "" {
				var closeFunc func() error
				cmd.Printf("Using built-in PostgreSQL (%s)\n", config.PostgresPath())
				cfg.PostgresURL.Value, closeFunc, err = startBuiltinPostgres(ctx, config, logger)
				if err != nil {
					return err
				}
				builtinPostgres = true
				defer func() {
					cmd.Printf("Stopping built-in PostgreSQL...\n")
					// Gracefully shut PostgreSQL down!
					if err := closeFunc(); err != nil {
						cmd.Printf("Failed to stop built-in PostgreSQL: %v\n", err)
					} else {
						cmd.Printf("Stopped built-in PostgreSQL\n")
					}
				}()
			}

			var (
				httpListener net.Listener
				httpURL      *url.URL
			)
			if cfg.HTTPAddress.Value != "" {
				httpListener, err = net.Listen("tcp", cfg.HTTPAddress.Value)
				if err != nil {
					return xerrors.Errorf("listen %q: %w", cfg.HTTPAddress.Value, err)
				}
				defer httpListener.Close()

				listenAddrStr := httpListener.Addr().String()
				// For some reason if 0.0.0.0:x is provided as the http address,
				// httpListener.Addr().String() likes to return it as an ipv6
				// address (i.e. [::]:x). If the input ip is 0.0.0.0, try to
				// coerce the output back to ipv4 to make it less confusing.
				if strings.Contains(cfg.HTTPAddress.Value, "0.0.0.0") {
					listenAddrStr = strings.ReplaceAll(listenAddrStr, "[::]", "0.0.0.0")
				}

				// We want to print out the address the user supplied, not the
				// loopback device.
				cmd.Println("Started HTTP listener at", (&url.URL{Scheme: "http", Host: listenAddrStr}).String())

				// Set the http URL we want to use when connecting to ourselves.
				tcpAddr, tcpAddrValid := httpListener.Addr().(*net.TCPAddr)
				if !tcpAddrValid {
					return xerrors.Errorf("invalid TCP address type %T", httpListener.Addr())
				}
				if tcpAddr.IP.IsUnspecified() {
					tcpAddr.IP = net.IPv4(127, 0, 0, 1)
				}
				httpURL = &url.URL{
					Scheme: "http",
					Host:   tcpAddr.String(),
				}
			}

			var (
				tlsConfig     *tls.Config
				httpsListener net.Listener
				httpsURL      *url.URL
			)
			if cfg.TLS.Enable.Value {
				if cfg.TLS.Address.Value == "" {
					return xerrors.New("tls address must be set if tls is enabled")
				}

				tlsConfig, err = configureTLS(
					cfg.TLS.MinVersion.Value,
					cfg.TLS.ClientAuth.Value,
					cfg.TLS.CertFiles.Value,
					cfg.TLS.KeyFiles.Value,
					cfg.TLS.ClientCAFile.Value,
				)
				if err != nil {
					return xerrors.Errorf("configure tls: %w", err)
				}
				httpsListenerInner, err := net.Listen("tcp", cfg.TLS.Address.Value)
				if err != nil {
					return xerrors.Errorf("listen %q: %w", cfg.TLS.Address.Value, err)
				}
				defer httpsListenerInner.Close()

				httpsListener = tls.NewListener(httpsListenerInner, tlsConfig)
				defer httpsListener.Close()

				listenAddrStr := httpsListener.Addr().String()
				// For some reason if 0.0.0.0:x is provided as the https
				// address, httpsListener.Addr().String() likes to return it as
				// an ipv6 address (i.e. [::]:x). If the input ip is 0.0.0.0,
				// try to coerce the output back to ipv4 to make it less
				// confusing.
				//
				// This must inspect the TLS address (the one this listener
				// was bound to), not the HTTP address.
				if strings.Contains(cfg.TLS.Address.Value, "0.0.0.0") {
					listenAddrStr = strings.ReplaceAll(listenAddrStr, "[::]", "0.0.0.0")
				}

				// We want to print out the address the user supplied, not the
				// loopback device.
				cmd.Println("Started TLS/HTTPS listener at", (&url.URL{Scheme: "https", Host: listenAddrStr}).String())

				// Set the https URL we want to use when connecting to
				// ourselves.
				tcpAddr, tcpAddrValid := httpsListener.Addr().(*net.TCPAddr)
				if !tcpAddrValid {
					return xerrors.Errorf("invalid TCP address type %T", httpsListener.Addr())
				}
				if tcpAddr.IP.IsUnspecified() {
					tcpAddr.IP = net.IPv4(127, 0, 0, 1)
				}
				httpsURL = &url.URL{
					Scheme: "https",
					Host:   tcpAddr.String(),
				}
			}

			// Sanity check that at least one listener was started.
			if httpListener == nil && httpsListener == nil {
				return xerrors.New("must listen on at least one address")
			}

			// Prefer HTTP because it's less prone to TLS errors over localhost.
			localURL := httpsURL
			if httpURL != nil {
				localURL = httpURL
			}

			ctx, httpClient, err := configureHTTPClient(
				ctx,
				cfg.TLS.ClientCertFile.Value,
				cfg.TLS.ClientKeyFile.Value,
				cfg.TLS.ClientCAFile.Value,
			)
			if err != nil {
				return xerrors.Errorf("configure http client: %w", err)
			}

			var (
				ctxTunnel, closeTunnel = context.WithCancel(ctx)
				tunnel                 *devtunnel.Tunnel
				tunnelErr              <-chan error
			)
			defer closeTunnel()

			// If the access URL is empty, we attempt to run a reverse-proxy
			// tunnel to make the initial setup really simple.
			if cfg.AccessURL.Value == "" {
				cmd.Printf("Opening tunnel so workspaces can connect to your deployment. For production scenarios, specify an external access URL\n")
				tunnel, tunnelErr, err = devtunnel.New(ctxTunnel, logger.Named("devtunnel"))
				if err != nil {
					return xerrors.Errorf("create tunnel: %w", err)
				}
				cfg.AccessURL.Value = tunnel.URL

				if cfg.WildcardAccessURL.Value == "" {
					u, err := parseURL(tunnel.URL)
					if err != nil {
						return xerrors.Errorf("parse tunnel url: %w", err)
					}

					// Suffixed wildcard access URL.
					cfg.WildcardAccessURL.Value = fmt.Sprintf("*--%s", u.Hostname())
				}
			}

			accessURLParsed, err := parseURL(cfg.AccessURL.Value)
			if err != nil {
				return xerrors.Errorf("parse URL: %w", err)
			}
			// Fall back to the scheme's default port when none is given.
			accessURLPortRaw := accessURLParsed.Port()
			if accessURLPortRaw == "" {
				accessURLPortRaw = "80"
				if accessURLParsed.Scheme == "https" {
					accessURLPortRaw = "443"
				}
			}
			accessURLPort, err := strconv.Atoi(accessURLPortRaw)
			if err != nil {
				return xerrors.Errorf("parse access URL port: %w", err)
			}

			// Warn the user if the access URL appears to be a loopback address.
			isLocal, err := isLocalURL(ctx, accessURLParsed)
			if isLocal || err != nil {
				reason := "could not be resolved"
				if isLocal {
					reason = "isn't externally reachable"
				}
				cmd.Printf("%s The access URL %s %s, this may cause unexpected problems when creating workspaces. Generate a unique *.try.coder.app URL by not specifying an access URL.\n", cliui.Styles.Warn.Render("Warning:"), cliui.Styles.Field.Render(accessURLParsed.String()), reason)
			}

			// Redirect from the HTTP listener to the access URL if:
			// 1. The redirect flag is enabled.
			// 2. HTTP listening is enabled (obviously).
			// 3. TLS is enabled (otherwise they're likely using a reverse proxy
			//    which can do this instead).
			// 4. The access URL has been set manually (not a tunnel).
			// 5. The access URL is HTTPS.
			shouldRedirectHTTPToAccessURL := cfg.TLS.RedirectHTTP.Value && cfg.HTTPAddress.Value != "" && cfg.TLS.Enable.Value && tunnel == nil && accessURLParsed.Scheme == "https"

			// A newline is added before for visibility in terminal output.
			cmd.Printf("\nView the Web UI: %s\n", accessURLParsed.String())

			// Used for zero-trust instance identity with Google Cloud.
			googleTokenValidator, err := idtoken.NewValidator(ctx, option.WithoutAuthentication())
			if err != nil {
				return err
			}

			sshKeygenAlgorithm, err := gitsshkey.ParseAlgorithm(cfg.SSHKeygenAlgorithm.Value)
			if err != nil {
				return xerrors.Errorf("parse ssh keygen algorithm %s: %w", cfg.SSHKeygenAlgorithm.Value, err)
			}

			defaultRegion := &tailcfg.DERPRegion{
				EmbeddedRelay: true,
				RegionID:      cfg.DERP.Server.RegionID.Value,
				RegionCode:    cfg.DERP.Server.RegionCode.Value,
				RegionName:    cfg.DERP.Server.RegionName.Value,
				Nodes: []*tailcfg.DERPNode{{
					Name:      fmt.Sprintf("%db", cfg.DERP.Server.RegionID.Value),
					RegionID:  cfg.DERP.Server.RegionID.Value,
					HostName:  accessURLParsed.Hostname(),
					DERPPort:  accessURLPort,
					STUNPort:  -1,
					ForceHTTP: accessURLParsed.Scheme == "http",
				}},
			}
			if !cfg.DERP.Server.Enable.Value {
				defaultRegion = nil
			}
			derpMap, err := tailnet.NewDERPMap(ctx, defaultRegion, cfg.DERP.Server.STUNAddresses.Value, cfg.DERP.Config.URL.Value, cfg.DERP.Config.Path.Value)
			if err != nil {
				return xerrors.Errorf("create derp map: %w", err)
			}

			appHostname := strings.TrimSpace(cfg.WildcardAccessURL.Value)
			var appHostnameRegex *regexp.Regexp
			if appHostname != "" {
				appHostnameRegex, err = httpapi.CompileHostnamePattern(appHostname)
				if err != nil {
					return xerrors.Errorf("parse wildcard access URL %q: %w", appHostname, err)
				}
			}

			gitAuthConfigs, err := gitauth.ConvertConfig(cfg.GitAuth.Value, accessURLParsed)
			if err != nil {
				return xerrors.Errorf("parse git auth config: %w", err)
			}

			realIPConfig, err := httpmw.ParseRealIPConfig(cfg.ProxyTrustedHeaders.Value, cfg.ProxyTrustedOrigins.Value)
			if err != nil {
				return xerrors.Errorf("parse real ip config: %w", err)
			}

			// Database, Pubsub and Telemetry start out as in-memory/no-op
			// implementations and are replaced further down when a real
			// database or telemetry is configured.
			options := &coderd.Options{
				AccessURL:                   accessURLParsed,
				AppHostname:                 appHostname,
				AppHostnameRegex:            appHostnameRegex,
				Logger:                      logger.Named("coderd"),
				Database:                    databasefake.New(),
				DERPMap:                     derpMap,
				Pubsub:                      database.NewPubsubInMemory(),
				CacheDir:                    cacheDir,
				GoogleTokenValidator:        googleTokenValidator,
				GitAuthConfigs:              gitAuthConfigs,
				RealIPConfig:                realIPConfig,
				SecureAuthCookie:            cfg.SecureAuthCookie.Value,
				SSHKeygenAlgorithm:          sshKeygenAlgorithm,
				TracerProvider:              tracerProvider,
				Telemetry:                   telemetry.NewNoop(),
				MetricsCacheRefreshInterval: cfg.MetricsCacheRefreshInterval.Value,
				AgentStatsRefreshInterval:   cfg.AgentStatRefreshInterval.Value,
				DeploymentConfig:            cfg,
				PrometheusRegistry:          prometheus.NewRegistry(),
				APIRateLimit:                cfg.RateLimit.API.Value,
				LoginRateLimit:              loginRateLimit,
				FilesRateLimit:              filesRateLimit,
				HTTPClient:                  httpClient,
			}
			if tlsConfig != nil {
				options.TLSCertificates = tlsConfig.Certificates
			}

			if cfg.UpdateCheck.Value {
				options.UpdateCheckOptions = &updatecheck.Options{
					// Avoid spamming GitHub API checking for updates.
					Interval: 24 * time.Hour,
					// Inform server admins of new versions.
					Notify: func(r updatecheck.Result) {
						if semver.Compare(r.Version, buildinfo.Version()) > 0 {
							options.Logger.Info(
								context.Background(),
								"new version of coder available",
								slog.F("new_version", r.Version),
								slog.F("url", r.URL),
								slog.F("upgrade_instructions", "https://coder.com/docs/coder-oss/latest/admin/upgrade"),
							)
						}
					},
				}
			}

			if cfg.OAuth2.Github.ClientSecret.Value != "" {
				options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed,
					cfg.OAuth2.Github.ClientID.Value,
					cfg.OAuth2.Github.ClientSecret.Value,
					cfg.OAuth2.Github.AllowSignups.Value,
					cfg.OAuth2.Github.AllowEveryone.Value,
					cfg.OAuth2.Github.AllowedOrgs.Value,
					cfg.OAuth2.Github.AllowedTeams.Value,
					cfg.OAuth2.Github.EnterpriseBaseURL.Value,
				)
				if err != nil {
					return xerrors.Errorf("configure github oauth2: %w", err)
				}
			}

			if cfg.OIDC.ClientSecret.Value != "" {
				if cfg.OIDC.ClientID.Value == "" {
					return xerrors.Errorf("OIDC client ID must be set!")
				}
				if cfg.OIDC.IssuerURL.Value == "" {
					return xerrors.Errorf("OIDC issuer URL must be set!")
				}

				if cfg.OIDC.IgnoreEmailVerified.Value {
					logger.Warn(ctx, "coder will not check email_verified for OIDC logins")
				}

				oidcProvider, err := oidc.NewProvider(ctx, cfg.OIDC.IssuerURL.Value)
				if err != nil {
					return xerrors.Errorf("configure oidc provider: %w", err)
				}
				redirectURL, err := accessURLParsed.Parse("/api/v2/users/oidc/callback")
				if err != nil {
					return xerrors.Errorf("parse oidc oauth callback url: %w", err)
				}
				options.OIDCConfig = &coderd.OIDCConfig{
					OAuth2Config: &oauth2.Config{
						ClientID:     cfg.OIDC.ClientID.Value,
						ClientSecret: cfg.OIDC.ClientSecret.Value,
						RedirectURL:  redirectURL.String(),
						Endpoint:     oidcProvider.Endpoint(),
						Scopes:       cfg.OIDC.Scopes.Value,
					},
					Provider: oidcProvider,
					Verifier: oidcProvider.Verifier(&oidc.Config{
						ClientID: cfg.OIDC.ClientID.Value,
					}),
					EmailDomain:   cfg.OIDC.EmailDomain.Value,
					AllowSignups:  cfg.OIDC.AllowSignups.Value,
					UsernameField: cfg.OIDC.UsernameField.Value,
				}
			}

			if cfg.InMemoryDatabase.Value {
				options.Database = databasefake.New()
				options.Pubsub = database.NewPubsubInMemory()
			} else {
				logger.Debug(ctx, "connecting to postgresql")
				sqlDB, err := sql.Open(sqlDriver, cfg.PostgresURL.Value)
				if err != nil {
					return xerrors.Errorf("dial postgres: %w", err)
				}
				defer sqlDB.Close()

				pingCtx, pingCancel := context.WithTimeout(ctx, 15*time.Second)
				defer pingCancel()

				err = sqlDB.PingContext(pingCtx)
				if err != nil {
					return xerrors.Errorf("ping postgres: %w", err)
				}

				// Ensure the PostgreSQL version is >=13.0.0!
				version, err := sqlDB.QueryContext(ctx, "SHOW server_version;")
				if err != nil {
					return xerrors.Errorf("get postgres version: %w", err)
				}
				if !version.Next() {
					return xerrors.Errorf("no rows returned for version select")
				}
				var versionStr string
				err = version.Scan(&versionStr)
				if err != nil {
					// Release the result set before bailing out so the
					// underlying connection is not left checked out.
					_ = version.Close()
					return xerrors.Errorf("scan version: %w", err)
				}
				_ = version.Close()
				// Keep only the leading version token of the reported string.
				versionStr = strings.Split(versionStr, " ")[0]
				if semver.Compare("v"+versionStr, "v13") < 0 {
					return xerrors.New("PostgreSQL version must be v13.0.0 or higher!")
				}
				logger.Debug(ctx, "connected to postgresql", slog.F("version", versionStr))

				err = migrations.Up(sqlDB)
				if err != nil {
					return xerrors.Errorf("migrate up: %w", err)
				}
				// The default is 0 but the request will fail with a 500 if the DB
				// cannot accept new connections, so we try to limit that here.
				// Requests will wait for a new connection instead of a hard error
				// if a limit is set.
				sqlDB.SetMaxOpenConns(10)
				// Allow a max of 3 idle connections at a time. Lower values end up
				// creating a lot of connection churn. Since each connection uses about
				// 10MB of memory, we're allocating 30MB to Postgres connections per
				// replica, but is better than causing Postgres to spawn a thread 15-20
				// times/sec. PGBouncer's transaction pooling is not the greatest so
				// it's not optimal for us to deploy.
				//
				// This was set to 10 before we started doing HA deployments, but 3 was
				// later determined to be a better middle ground as to not use up all
				// of PGs default connection limit while simultaneously avoiding a lot
				// of connection churn.
				sqlDB.SetMaxIdleConns(3)

				options.Database = database.New(sqlDB)
				options.Pubsub, err = database.NewPubsub(ctx, sqlDB, cfg.PostgresURL.Value)
				if err != nil {
					return xerrors.Errorf("create pubsub: %w", err)
				}
				defer options.Pubsub.Close()
			}

			// A missing deployment ID is not an error: one is generated and
			// stored below.
			deploymentID, err := options.Database.GetDeploymentID(ctx)
			if errors.Is(err, sql.ErrNoRows) {
				err = nil
			}
			if err != nil {
				return xerrors.Errorf("get deployment id: %w", err)
			}
			if deploymentID == "" {
				deploymentID = uuid.NewString()
				err = options.Database.InsertDeploymentID(ctx, deploymentID)
				if err != nil {
					return xerrors.Errorf("set deployment id: %w", err)
				}
			}

			// Disable telemetry if the in-memory database is used unless explicitly defined!
			if cfg.InMemoryDatabase.Value && !cmd.Flags().Changed(cfg.Telemetry.Enable.Flag) {
				cfg.Telemetry.Enable.Value = false
			}
			if cfg.Telemetry.Enable.Value {
				// Parse the raw telemetry URL!
				telemetryURL, err := parseURL(cfg.Telemetry.URL.Value)
				if err != nil {
					return xerrors.Errorf("parse telemetry url: %w", err)
				}

				gitAuth := make([]telemetry.GitAuth, 0)
				for _, gitAuthConfig := range gitAuthConfigs {
					gitAuth = append(gitAuth, telemetry.GitAuth{
						Type: string(gitAuthConfig.Type),
					})
				}

				options.Telemetry, err = telemetry.New(telemetry.Options{
					BuiltinPostgres:    builtinPostgres,
					DeploymentID:       deploymentID,
					Database:           options.Database,
					Logger:             logger.Named("telemetry"),
					URL:                telemetryURL,
					Wildcard:           cfg.WildcardAccessURL.Value != "",
					DERPServerRelayURL: cfg.DERP.Server.RelayURL.Value,
					GitAuth:            gitAuth,
					GitHubOAuth:        cfg.OAuth2.Github.ClientID.Value != "",
					OIDCAuth:           cfg.OIDC.ClientID.Value != "",
					OIDCIssuerURL:      cfg.OIDC.IssuerURL.Value,
					Prometheus:         cfg.Prometheus.Enable.Value,
					STUN:               len(cfg.DERP.Server.STUNAddresses.Value) != 0,
					Tunnel:             tunnel != nil,
				})
				if err != nil {
					return xerrors.Errorf("create telemetry reporter: %w", err)
				}
				defer options.Telemetry.Close()
			}

			// This prevents the pprof import from being accidentally deleted.
			_ = pprof.Handler
			if cfg.Pprof.Enable.Value {
				//nolint:revive
				defer serveHandler(ctx, logger, nil, cfg.Pprof.Address.Value, "pprof")()
			}
			if cfg.Prometheus.Enable.Value {
				options.PrometheusRegistry.MustRegister(collectors.NewGoCollector())
				options.PrometheusRegistry.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))

				closeUsersFunc, err := prometheusmetrics.ActiveUsers(ctx, options.PrometheusRegistry, options.Database, 0)
				if err != nil {
					return xerrors.Errorf("register active users prometheus metric: %w", err)
				}
				defer closeUsersFunc()

				closeWorkspacesFunc, err := prometheusmetrics.Workspaces(ctx, options.PrometheusRegistry, options.Database, 0)
				if err != nil {
					return xerrors.Errorf("register workspaces prometheus metric: %w", err)
				}
				defer closeWorkspacesFunc()

				//nolint:revive
				defer serveHandler(ctx, logger, promhttp.InstrumentMetricHandler(
					options.PrometheusRegistry, promhttp.HandlerFor(options.PrometheusRegistry, promhttp.HandlerOpts{}),
				), cfg.Prometheus.Address.Value, "prometheus")()
			}

			if cfg.Swagger.Enable.Value {
				options.SwaggerEndpoint = cfg.Swagger.Enable.Value
			}

			// We use a separate coderAPICloser so the Enterprise API
			// can have it's own close functions. This is cleaner
			// than abstracting the Coder API itself.
			coderAPI, coderAPICloser, err := newAPI(ctx, options)
			if err != nil {
				return err
			}

			client := codersdk.New(localURL)
			if localURL.Scheme == "https" && isLocalhost(localURL.Hostname()) {
				// The certificate will likely be self-signed or for a different
				// hostname, so we need to skip verification.
				client.HTTPClient.Transport = &http.Transport{
					TLSClientConfig: &tls.Config{
						//nolint:gosec
						InsecureSkipVerify: true,
					},
				}
			}
			defer client.HTTPClient.CloseIdleConnections()

			// This is helpful for tests, but can be silently ignored.
			// Coder may be ran as users that don't have permission to write in the homedir,
			// such as via the systemd service.
			_ = config.URL().Write(client.URL.String())

			// Since errCh only has one buffered slot, all routines
			// sending on it must be wrapped in a select/default to
			// avoid leaving dangling goroutines waiting for the
			// channel to be consumed.
			errCh := make(chan error, 1)
			provisionerDaemons := make([]*provisionerd.Server, 0)
			defer func() {
				// We have no graceful shutdown of provisionerDaemons
				// here because that's handled at the end of main, this
				// is here in case the program exits early.
				for _, daemon := range provisionerDaemons {
					_ = daemon.Close()
				}
			}()
			provisionerdMetrics := provisionerd.NewMetrics(options.PrometheusRegistry)
			for i := 0; i < cfg.Provisioner.Daemons.Value; i++ {
				daemonCacheDir := filepath.Join(cacheDir, fmt.Sprintf("provisioner-%d", i))
				daemon, err := newProvisionerDaemon(ctx, coderAPI, provisionerdMetrics, logger, cfg, daemonCacheDir, errCh, false)
				if err != nil {
					return xerrors.Errorf("create provisioner daemon: %w", err)
				}
				provisionerDaemons = append(provisionerDaemons, daemon)
			}

			shutdownConnsCtx, shutdownConns := context.WithCancel(ctx)
			defer shutdownConns()

			// Wrap the server in middleware that redirects to the access URL if
			// the request is not to a local IP.
			var handler http.Handler = coderAPI.RootHandler
			if shouldRedirectHTTPToAccessURL {
				handler = redirectHTTPToAccessURL(handler, accessURLParsed)
			}

			// ReadHeaderTimeout is purposefully not enabled. It caused some
			// issues with websockets over the dev tunnel.
			// See: https://github.com/coder/coder/pull/3730
			//nolint:gosec
			httpServer := &http.Server{
				// These errors are typically noise like "TLS: EOF". Vault does
				// similar:
				// https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714
				ErrorLog: log.New(io.Discard, "", 0),
				Handler:  handler,
				BaseContext: func(_ net.Listener) context.Context {
					return shutdownConnsCtx
				},
			}
			defer func() {
				_ = shutdownWithTimeout(httpServer.Shutdown, 5*time.Second)
			}()

			// We call this in the routine so we can kill the other listeners if
			// one of them fails.
			closeListenersNow := func() {
				if httpListener != nil {
					_ = httpListener.Close()
				}
				if httpsListener != nil {
					_ = httpsListener.Close()
				}
				if tunnel != nil {
					_ = tunnel.Listener.Close()
				}
			}

			// Serve the same http.Server on every available listener.
			eg := errgroup.Group{}
			if httpListener != nil {
				eg.Go(func() error {
					defer closeListenersNow()
					return httpServer.Serve(httpListener)
				})
			}
			if httpsListener != nil {
				eg.Go(func() error {
					defer closeListenersNow()
					return httpServer.Serve(httpsListener)
				})
			}
			if tunnel != nil {
				eg.Go(func() error {
					defer closeListenersNow()
					return httpServer.Serve(tunnel.Listener)
				})
			}

			// Forward the first serve error without blocking if errCh is
			// already full.
			go func() {
				select {
				case errCh <- eg.Wait():
				default:
				}
			}()

			hasFirstUser, err := client.HasFirstUser(ctx)
			if err != nil {
				cmd.Println("\nFailed to check for the first user: " + err.Error())
			} else if !hasFirstUser {
				cmd.Println("\nGet started by creating the first user (in a new terminal):")
				cmd.Println(cliui.Styles.Code.Render("coder login " + accessURLParsed.String()))
			}

			cmd.Println("\n==> Logs will stream in below (press ctrl+c to gracefully exit):")

			// Updates the systemd status from activating to activated.
			_, err = daemon.SdNotify(false, daemon.SdNotifyReady)
			if err != nil {
				return xerrors.Errorf("notify systemd: %w", err)
			}

			autobuildPoller := time.NewTicker(cfg.AutobuildPollInterval.Value)
			defer autobuildPoller.Stop()
			autobuildExecutor := executor.New(ctx, options.Database, logger, autobuildPoller.C)
			autobuildExecutor.Run()

			// Currently there is no way to ask the server to shut
			// itself down, so any exit signal will result in a non-zero
			// exit of the server.
			var exitErr error
			select {
			case <-notifyCtx.Done():
				exitErr = notifyCtx.Err()
				_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Bold.Render(
					"Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit",
				))
			case exitErr = <-tunnelErr:
				if exitErr == nil {
					exitErr = xerrors.New("dev tunnel closed unexpectedly")
				}
			case exitErr = <-errCh:
			}
			if exitErr != nil && !xerrors.Is(exitErr, context.Canceled) {
				cmd.Printf("Unexpected error, shutting down server: %s\n", exitErr)
			}

			// Begin clean shut down stage, we try to shut down services
			// gracefully in an order that gives the best experience.
			// This procedure should not differ greatly from the order
			// of `defer`s in this function, but allows us to inform
			// the user about what's going on and handle errors more
			// explicitly.

			_, err = daemon.SdNotify(false, daemon.SdNotifyStopping)
			if err != nil {
				cmd.Printf("Notify systemd failed: %s", err)
			}

			// Stop accepting new connections without interrupting
			// in-flight requests, give in-flight requests 3 seconds to
			// complete.
			cmd.Println("Shutting down API server...")
			err = shutdownWithTimeout(httpServer.Shutdown, 3*time.Second)
			if err != nil {
				cmd.Printf("API server shutdown took longer than 3s: %s\n", err)
			} else {
				cmd.Printf("Gracefully shut down API server\n")
			}
			// Cancel any remaining in-flight requests.
			shutdownConns()

			// Shut down provisioners before waiting for WebSockets
			// connections to close.
			var wg sync.WaitGroup
			for i, provisionerDaemon := range provisionerDaemons {
				id := i + 1
				provisionerDaemon := provisionerDaemon
				wg.Add(1)
				go func() {
					defer wg.Done()

					if ok, _ := cmd.Flags().GetBool(varVerbose); ok {
						cmd.Printf("Shutting down provisioner daemon %d...\n", id)
					}
					err := shutdownWithTimeout(provisionerDaemon.Shutdown, 5*time.Second)
					if err != nil {
						cmd.PrintErrf("Failed to shutdown provisioner daemon %d: %s\n", id, err)
						return
					}
					err = provisionerDaemon.Close()
					if err != nil {
						cmd.PrintErrf("Close provisioner daemon %d: %s\n", id, err)
						return
					}
					if ok, _ := cmd.Flags().GetBool(varVerbose); ok {
						cmd.Printf("Gracefully shut down provisioner daemon %d\n", id)
					}
				}()
			}
			wg.Wait()

			cmd.Println("Waiting for WebSocket connections to close...")
			_ = coderAPICloser.Close()
			cmd.Println("Done waiting for WebSocket connections")

			// Close tunnel after we no longer have in-flight connections.
			if tunnel != nil {
				cmd.Println("Waiting for tunnel to close...")
				closeTunnel()
				<-tunnelErr
				cmd.Println("Done waiting for tunnel")
			}

			// Ensures a last report can be sent before exit!
			options.Telemetry.Close()

			// Trigger context cancellation for any remaining services.
			cancel()

			// A cancellation (e.g. from an interrupt) is a normal exit.
			if xerrors.Is(exitErr, context.Canceled) {
				return nil
			}
			return exitErr
		},
	}

	// pgRawURL is shared by both postgres-builtin subcommands; each binds
	// its own --raw-url flag to it.
	var pgRawURL bool
	postgresBuiltinURLCmd := &cobra.Command{
		Use:   "postgres-builtin-url",
		Short: "Output the connection URL for the built-in PostgreSQL deployment.",
		RunE: func(cmd *cobra.Command, _ []string) error {
			cfg := createConfig(cmd)
			url, err := embeddedPostgresURL(cfg)
			if err != nil {
				return err
			}
			if pgRawURL {
				_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", url)
			} else {
				_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", cliui.Styles.Code.Render(fmt.Sprintf("psql %q", url)))
			}
			return nil
		},
	}
	postgresBuiltinServeCmd := &cobra.Command{
		Use:   "postgres-builtin-serve",
		Short: "Run the built-in PostgreSQL deployment.",
		RunE: func(cmd *cobra.Command, args []string) error {
			ctx := cmd.Context()

			cfg := createConfig(cmd)
			logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr()))
			if ok, _ := cmd.Flags().GetBool(varVerbose); ok {
				logger = logger.Leveled(slog.LevelDebug)
			}

			ctx, cancel := signal.NotifyContext(ctx, InterruptSignals...)
			defer cancel()

			url, closePg, err := startBuiltinPostgres(ctx, cfg, logger)
			if err != nil {
				return err
			}
			defer func() { _ = closePg() }()

			if pgRawURL {
				_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", url)
			} else {
				_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", cliui.Styles.Code.Render(fmt.Sprintf("psql %q", url)))
			}

			// Keep PostgreSQL running until an interrupt signal arrives.
			<-ctx.Done()
			return nil
		},
	}
	postgresBuiltinURLCmd.Flags().BoolVar(&pgRawURL, "raw-url", false, "Output the raw connection URL instead of a psql command.")
	postgresBuiltinServeCmd.Flags().BoolVar(&pgRawURL, "raw-url", false, "Output the raw connection URL instead of a psql command.")

	root.AddCommand(postgresBuiltinURLCmd, postgresBuiltinServeCmd)

	deployment.AttachFlags(root.Flags(), vip, false)

	return root
}
// parseURL converts the string u into a *url.URL.
//
// Only strings beginning with "http:" or "https:" are accepted; anything
// else is rejected with an error before parsing is attempted. Errors from
// url.Parse are returned unchanged.
func parseURL(u string) (*url.URL, error) {
	switch {
	case strings.HasPrefix(u, "http:"), strings.HasPrefix(u, "https:"):
		// Accepted scheme prefix, fall through to parsing.
	default:
		return nil, xerrors.Errorf("URL %q must have a scheme of either http or https", u)
	}

	// url.Parse yields a nil URL alongside any error, so its results can be
	// handed back directly.
	return url.Parse(u)
}
// isLocalURL reports whether the hostname of u resolves to at least one
// loopback address.
//
// The hostname is looked up with a zero-value net.Resolver under ctx. A
// lookup failure is returned as (false, err).
func isLocalURL(ctx context.Context, u *url.URL) (bool, error) {
	var resolver net.Resolver

	addrs, err := resolver.LookupIPAddr(ctx, u.Hostname())
	if err != nil {
		return false, err
	}

	// A single loopback result is enough to treat the URL as local.
	for i := range addrs {
		if addrs[i].IP.IsLoopback() {
			return true, nil
		}
	}
	return false, nil
}
// shutdownWithTimeout invokes shutdown with a fresh context that expires
// after timeout and returns whatever shutdown returns.
//
// The context is derived from context.Background(), so it is independent of
// any caller context that may already have been canceled.
func shutdownWithTimeout(shutdown func(context.Context) error, timeout time.Duration) error {
	deadlineCtx, release := context.WithTimeout(context.Background(), timeout)
	// Free the timer as soon as shutdown has returned.
	defer release()

	return shutdown(deadlineCtx)
}
// newProvisionerDaemon builds a provisioner daemon that talks to coderAPI and
// to provisioners served over in-memory transport pipes.
//
// It makes sure cacheDir exists, serves the Terraform provisioner (and, when
// dev is true, the echo provisioner) on background goroutines, and creates a
// temporary directory that is handed to the daemon as its work directory.
// Errors from the background provisioner servers are offered to errCh without
// blocking, so they are dropped when the send cannot proceed immediately. If
// this function returns an error, the derived context is canceled so that
// whatever was already started shuts down.
//
// nolint:revive
func newProvisionerDaemon(
	ctx context.Context,
	coderAPI *coderd.API,
	metrics provisionerd.Metrics,
	logger slog.Logger,
	cfg *codersdk.DeploymentConfig,
	cacheDir string,
	errCh chan error,
	dev bool,
) (srv *provisionerd.Server, err error) {
	ctx, cancel := context.WithCancel(ctx)
	defer func() {
		// Cancel only on failure; on success the goroutines below keep
		// using ctx.
		if err == nil {
			return
		}
		cancel()
	}()

	// reportErr offers a background failure to errCh and drops it when the
	// send would block.
	reportErr := func(serveErr error) {
		select {
		case errCh <- serveErr:
		default:
		}
	}

	if err = os.MkdirAll(cacheDir, 0o700); err != nil {
		return nil, xerrors.Errorf("mkdir %q: %w", cacheDir, err)
	}

	terraformClient, terraformServer := provisionersdk.MemTransportPipe()
	go func() {
		// Close both ends of the pipe once the context ends.
		<-ctx.Done()
		_ = terraformClient.Close()
		_ = terraformServer.Close()
	}()
	go func() {
		defer cancel()

		serveOpts := &terraform.ServeOptions{
			ServeOptions: &provisionersdk.ServeOptions{
				Listener: terraformServer,
			},
			CachePath: cacheDir,
			Logger:    logger,
		}
		serveErr := terraform.Serve(ctx, serveOpts)
		// A canceled context is not reported as a failure.
		if serveErr == nil || xerrors.Is(serveErr, context.Canceled) {
			return
		}
		reportErr(serveErr)
	}()

	workDir, err := os.MkdirTemp("", "provisionerd")
	if err != nil {
		return nil, err
	}

	provisioners := provisionerd.Provisioners{
		string(database.ProvisionerTypeTerraform): sdkproto.NewDRPCProvisionerClient(terraformClient),
	}
	// include echo provisioner when in dev mode
	if dev {
		echoClient, echoServer := provisionersdk.MemTransportPipe()
		go func() {
			<-ctx.Done()
			_ = echoClient.Close()
			_ = echoServer.Close()
		}()
		go func() {
			defer cancel()

			serveErr := echo.Serve(ctx, afero.NewOsFs(), &provisionersdk.ServeOptions{Listener: echoServer})
			if serveErr != nil {
				reportErr(serveErr)
			}
		}()
		provisioners[string(database.ProvisionerTypeEcho)] = sdkproto.NewDRPCProvisionerClient(echoClient)
	}

	pollDebounce := time.Second
	dial := func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
		// This debounces calls to listen every second. Read the comment
		// in provisionerdserver.go to learn more!
		return coderAPI.CreateInMemoryProvisionerDaemon(ctx, pollDebounce)
	}
	daemonOpts := &provisionerd.Options{
		Logger:              logger,
		JobPollInterval:     cfg.Provisioner.DaemonPollInterval.Value,
		JobPollJitter:       cfg.Provisioner.DaemonPollJitter.Value,
		JobPollDebounce:     pollDebounce,
		UpdateInterval:      500 * time.Millisecond,
		ForceCancelInterval: cfg.Provisioner.ForceCancelInterval.Value,
		Provisioners:        provisioners,
		WorkDirectory:       workDir,
		TracerProvider:      coderAPI.TracerProvider,
		Metrics:             &metrics,
	}
	return provisionerd.New(dial, daemonOpts), nil
}
// printLogo writes a one-line banner containing the Coder version to the
// command's stdout. Nothing is written when isTTYOut reports that the output
// is not a TTY.
//
// nolint: revive
func printLogo(cmd *cobra.Command) {
	// Only print the logo in TTYs.
	if !isTTYOut(cmd) {
		return
	}

	_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s - Software development on your infrastructure\n", cliui.Styles.Bold.Render("Coder "+buildinfo.Version()))
}
// loadCertificates loads one X.509 key pair per index, pairing tlsCertFiles[i]
// with tlsKeyFiles[i]. Both lists must have the same length and be non-empty;
// otherwise, or when any pair fails to load, an error is returned and no
// certificates are produced.
func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, error) {
	// The order of these checks decides which message the user sees.
	switch {
	case len(tlsCertFiles) != len(tlsKeyFiles):
		return nil, xerrors.New("--tls-cert-file and --tls-key-file must be used the same amount of times")
	case len(tlsCertFiles) == 0:
		return nil, xerrors.New("--tls-cert-file is required when tls is enabled")
	case len(tlsKeyFiles) == 0:
		return nil, xerrors.New("--tls-key-file is required when tls is enabled")
	}

	pairs := make([]tls.Certificate, 0, len(tlsCertFiles))
	for idx, certPath := range tlsCertFiles {
		keyPath := tlsKeyFiles[idx]
		pair, err := tls.LoadX509KeyPair(certPath, keyPath)
		if err != nil {
			return nil, xerrors.Errorf("load TLS key pair %d (%q, %q): %w", idx, certPath, keyPath, err)
		}
		pairs = append(pairs, pair)
	}
	return pairs, nil
}
// configureTLS builds the server-side *tls.Config.
//
// tlsMinVersion must be one of "tls10", "tls11", "tls12" or "tls13", and
// tlsClientAuth one of "none", "request", "require-any", "verify-if-given" or
// "require-and-verify"; any other value is an error. The key pairs named by
// tlsCertFiles and tlsKeyFiles are loaded via loadCertificates, and
// tlsClientCAFile, when set, is applied via configureCAPool.
//
// Certificate selection at handshake time: a lone certificate is returned
// directly; otherwise the first certificate the client hello supports wins,
// falling back to the first configured certificate.
func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (*tls.Config, error) {
	minVersions := map[string]uint16{
		"tls10": tls.VersionTLS10,
		"tls11": tls.VersionTLS11,
		"tls12": tls.VersionTLS12,
		"tls13": tls.VersionTLS13,
	}
	clientAuthModes := map[string]tls.ClientAuthType{
		"none":               tls.NoClientCert,
		"request":            tls.RequestClientCert,
		"require-any":        tls.RequireAnyClientCert,
		"verify-if-given":    tls.VerifyClientCertIfGiven,
		"require-and-verify": tls.RequireAndVerifyClientCert,
	}

	tlsConfig := &tls.Config{
		MinVersion: tls.VersionTLS12,
	}

	minVersion, known := minVersions[tlsMinVersion]
	if !known {
		return nil, xerrors.Errorf("unrecognized tls version: %q", tlsMinVersion)
	}
	tlsConfig.MinVersion = minVersion

	clientAuth, known := clientAuthModes[tlsClientAuth]
	if !known {
		return nil, xerrors.Errorf("unrecognized tls client auth: %q", tlsClientAuth)
	}
	tlsConfig.ClientAuth = clientAuth

	certs, err := loadCertificates(tlsCertFiles, tlsKeyFiles)
	if err != nil {
		return nil, xerrors.Errorf("load certificates: %w", err)
	}
	tlsConfig.Certificates = certs
	tlsConfig.GetCertificate = func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
		// If there's only one certificate, return it.
		if len(certs) == 1 {
			return &certs[0], nil
		}

		// Expensively check which certificate matches the client hello.
		for i := range certs {
			candidate := certs[i]
			if hello.SupportsCertificate(&candidate) == nil {
				return &candidate, nil
			}
		}

		// Return nil when there is nothing to offer so the server doesn't
		// fail; otherwise fall back to the first certificate.
		if len(certs) == 0 {
			return nil, nil //nolint:nilnil
		}
		return &certs[0], nil
	}

	if err := configureCAPool(tlsClientCAFile, tlsConfig); err != nil {
		return nil, err
	}
	return tlsConfig, nil
}
// configureCAPool reads the PEM file at tlsClientCAFile and installs the
// certificates it contains as tlsConfig.ClientCAs. An empty tlsClientCAFile
// leaves tlsConfig untouched. An error is returned when the file cannot be
// read or no certificate could be parsed from it.
func configureCAPool(tlsClientCAFile string, tlsConfig *tls.Config) error {
	if tlsClientCAFile == "" {
		return nil
	}

	pemBytes, err := os.ReadFile(tlsClientCAFile)
	if err != nil {
		return xerrors.Errorf("read %q: %w", tlsClientCAFile, err)
	}

	pool := x509.NewCertPool()
	if parsed := pool.AppendCertsFromPEM(pemBytes); !parsed {
		return xerrors.Errorf("failed to parse CA certificate in tls-client-ca-file")
	}
	tlsConfig.ClientCAs = pool
	return nil
}
// configureGithubOAuth2 assembles the GitHub OAuth2 login configuration.
//
// The OAuth2 callback URL is resolved relative to accessURL. Access control
// is validated next: allowEveryone cannot be combined with allowOrgs or
// rawTeams, and when allowEveryone is false at least one org is required.
// Every rawTeams entry must be of the form <organization>/<team>. A non-empty
// enterpriseBaseURL switches both the OAuth2 endpoints and the API client to
// that GitHub Enterprise host.
//
//nolint:revive // Ignore flag-parameter: parameter 'allowEveryone' seems to be a control flag, avoid control coupling (revive)
func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups, allowEveryone bool, allowOrgs []string, rawTeams []string, enterpriseBaseURL string) (*coderd.GithubOAuth2Config, error) {
	callbackURL, err := accessURL.Parse("/api/v2/users/oauth2/github/callback")
	if err != nil {
		return nil, xerrors.Errorf("parse github oauth callback url: %w", err)
	}

	switch {
	case allowEveryone && len(allowOrgs) > 0:
		return nil, xerrors.New("allow everyone and allowed orgs cannot be used together")
	case allowEveryone && len(rawTeams) > 0:
		return nil, xerrors.New("allow everyone and allowed teams cannot be used together")
	case !allowEveryone && len(allowOrgs) == 0:
		return nil, xerrors.New("allowed orgs is empty: must specify at least one org or allow everyone")
	}

	teams := make([]coderd.GithubOAuth2Team, 0, len(rawTeams))
	for _, entry := range rawTeams {
		// Split on the first slash only; the slug keeps any later ones.
		slash := strings.Index(entry, "/")
		if slash < 0 {
			return nil, xerrors.Errorf("github team allowlist is formatted incorrectly. got %s; wanted <organization>/<team>", entry)
		}
		teams = append(teams, coderd.GithubOAuth2Team{
			Organization: entry[:slash],
			Slug:         entry[slash+1:],
		})
	}

	// newAPIClient wraps an authenticated HTTP client in a GitHub API client,
	// targeting the enterprise host when one is configured.
	newAPIClient := func(httpClient *http.Client) (*github.Client, error) {
		if enterpriseBaseURL == "" {
			return github.NewClient(httpClient), nil
		}
		return github.NewEnterpriseClient(enterpriseBaseURL, "", httpClient)
	}

	oauthEndpoint := xgithub.Endpoint
	if enterpriseBaseURL != "" {
		base, err := url.Parse(enterpriseBaseURL)
		if err != nil {
			return nil, xerrors.Errorf("parse enterprise base url: %w", err)
		}
		authorize, err := base.Parse("/login/oauth/authorize")
		if err != nil {
			return nil, xerrors.Errorf("parse enterprise auth url: %w", err)
		}
		accessToken, err := base.Parse("/login/oauth/access_token")
		if err != nil {
			return nil, xerrors.Errorf("parse enterprise token url: %w", err)
		}
		oauthEndpoint = oauth2.Endpoint{
			AuthURL:  authorize.String(),
			TokenURL: accessToken.String(),
		}
	}

	oauthConfig := &oauth2.Config{
		ClientID:     clientID,
		ClientSecret: clientSecret,
		Endpoint:     oauthEndpoint,
		RedirectURL:  callbackURL.String(),
		Scopes: []string{
			"read:user",
			"read:org",
			"user:email",
		},
	}

	return &coderd.GithubOAuth2Config{
		OAuth2Config:       oauthConfig,
		AllowSignups:       allowSignups,
		AllowEveryone:      allowEveryone,
		AllowOrganizations: allowOrgs,
		AllowTeams:         teams,
		AuthenticatedUser: func(ctx context.Context, httpClient *http.Client) (*github.User, error) {
			gh, err := newAPIClient(httpClient)
			if err != nil {
				return nil, err
			}
			// An empty username asks for the authenticated user.
			authed, _, err := gh.Users.Get(ctx, "")
			return authed, err
		},
		ListEmails: func(ctx context.Context, httpClient *http.Client) ([]*github.UserEmail, error) {
			gh, err := newAPIClient(httpClient)
			if err != nil {
				return nil, err
			}
			addresses, _, err := gh.Users.ListEmails(ctx, &github.ListOptions{})
			return addresses, err
		},
		ListOrganizationMemberships: func(ctx context.Context, httpClient *http.Client) ([]*github.Membership, error) {
			gh, err := newAPIClient(httpClient)
			if err != nil {
				return nil, err
			}
			listOpts := &github.ListOrgMembershipsOptions{
				State: "active",
				ListOptions: github.ListOptions{
					PerPage: 100,
				},
			}
			orgMemberships, _, err := gh.Organizations.ListOrgMemberships(ctx, listOpts)
			return orgMemberships, err
		},
		TeamMembership: func(ctx context.Context, httpClient *http.Client, org, teamSlug, username string) (*github.Membership, error) {
			gh, err := newAPIClient(httpClient)
			if err != nil {
				return nil, err
			}
			membership, _, err := gh.Teams.GetTeamMembershipBySlug(ctx, org, teamSlug, username)
			return membership, err
		},
	}, nil
}
// embeddedPostgresURL returns the connection URL for the embedded PostgreSQL
// deployment, creating and persisting any missing settings along the way.
//
// The password and port are read from cfg. When the password file does not
// exist, a random 16-character password is generated and written back. When
// the port file does not exist, a port is picked by briefly listening on
// 127.0.0.1:0 and is written back; the listener is closed right away, so the
// port is only probed, not reserved. Any other read failure is returned as an
// error instead of producing an incomplete URL.
func embeddedPostgresURL(cfg config.Root) (string, error) {
	pgPassword, err := cfg.PostgresPassword().Read()
	switch {
	case errors.Is(err, os.ErrNotExist):
		pgPassword, err = cryptorand.String(16)
		if err != nil {
			return "", xerrors.Errorf("generate password: %w", err)
		}
		err = cfg.PostgresPassword().Write(pgPassword)
		if err != nil {
			return "", xerrors.Errorf("write password: %w", err)
		}
	case err != nil:
		return "", xerrors.Errorf("read postgres password: %w", err)
	}

	pgPort, err := cfg.PostgresPort().Read()
	switch {
	case errors.Is(err, os.ErrNotExist):
		listener, err := net.Listen("tcp4", "127.0.0.1:0")
		if err != nil {
			return "", xerrors.Errorf("listen for random port: %w", err)
		}
		addr := listener.Addr()
		_ = listener.Close()
		tcpAddr, valid := addr.(*net.TCPAddr)
		if !valid {
			// Report the address type we actually got, not the type we
			// failed to assert to.
			return "", xerrors.Errorf("listener returned non TCP addr: %T", addr)
		}
		pgPort = strconv.Itoa(tcpAddr.Port)
		err = cfg.PostgresPort().Write(pgPort)
		if err != nil {
			return "", xerrors.Errorf("write postgres port: %w", err)
		}
	case err != nil:
		// Without this check an unreadable port file would yield a URL with
		// an empty port.
		return "", xerrors.Errorf("read postgres port: %w", err)
	}

	return fmt.Sprintf("postgres://coder@localhost:%s/coder?sslmode=disable&password=%s", pgPort, pgPassword), nil
}
// startBuiltinPostgres starts the embedded PostgreSQL server using the
// password and port stored in cfg, generating them first when needed. It
// refuses to run when the current user's UID is "0". PostgreSQL output is
// routed to logger (named "postgres") at debug level.
//
// On success it returns the connection URL and the function that stops the
// server.
func startBuiltinPostgres(ctx context.Context, cfg config.Root, logger slog.Logger) (string, func() error, error) {
	currentUser, err := user.Current()
	if err != nil {
		return "", nil, err
	}
	if currentUser.Uid == "0" {
		return "", nil, xerrors.New("The built-in PostgreSQL cannot run as the root user. Create a non-root user and run again!")
	}

	// Ensure a password and port have been generated!
	connectionURL, err := embeddedPostgresURL(cfg)
	if err != nil {
		return "", nil, err
	}
	password, err := cfg.PostgresPassword().Read()
	if err != nil {
		return "", nil, xerrors.Errorf("read postgres password: %w", err)
	}
	rawPort, err := cfg.PostgresPort().Read()
	if err != nil {
		return "", nil, xerrors.Errorf("read postgres port: %w", err)
	}
	port, err := strconv.ParseUint(rawPort, 10, 16)
	if err != nil {
		return "", nil, xerrors.Errorf("parse postgres port: %w", err)
	}

	pgLogWriter := slog.Stdlib(ctx, logger.Named("postgres"), slog.LevelDebug).Writer()
	pgRoot := cfg.PostgresPath()
	pgConfig := embeddedpostgres.DefaultConfig().
		Version(embeddedpostgres.V13).
		BinariesPath(filepath.Join(pgRoot, "bin")).
		DataPath(filepath.Join(pgRoot, "data")).
		RuntimePath(filepath.Join(pgRoot, "runtime")).
		CachePath(filepath.Join(pgRoot, "cache")).
		Username("coder").
		Password(password).
		Database("coder").
		Port(uint32(port)).
		Logger(pgLogWriter)

	db := embeddedpostgres.NewDatabase(pgConfig)
	if err := db.Start(); err != nil {
		return "", nil, xerrors.Errorf("Failed to start built-in PostgreSQL. Optionally, specify an external deployment with `--postgres-url`: %w", err)
	}
	return connectionURL, db.Stop, nil
}
// configureHTTPClient returns the HTTP client to use for outbound requests
// together with a context to carry forward.
//
// When both clientCertFile and clientKeyFile are set, the key pair is loaded
// as the client certificate, tlsClientCAFile is applied through
// configureCAPool, and the returned context additionally stores the client
// under the oauth2.HTTPClient key. Otherwise the input context and a default
// client are returned.
//
// On error the original ctx is always returned (never nil) with a nil client,
// so a caller that reassigns its ctx from the result keeps a usable context.
func configureHTTPClient(ctx context.Context, clientCertFile, clientKeyFile string, tlsClientCAFile string) (context.Context, *http.Client, error) {
	if clientCertFile == "" || clientKeyFile == "" {
		return ctx, &http.Client{}, nil
	}

	certificates, err := loadCertificates([]string{clientCertFile}, []string{clientKeyFile})
	if err != nil {
		return ctx, nil, err
	}

	tlsClientConfig := &tls.Config{ //nolint:gosec
		Certificates: certificates,
	}
	if err := configureCAPool(tlsClientCAFile, tlsClientConfig); err != nil {
		return ctx, nil, err
	}

	httpClient := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: tlsClientConfig,
		},
	}
	return context.WithValue(ctx, oauth2.HTTPClient, httpClient), httpClient, nil
}
// redirectHTTPToAccessURL wraps handler so that requests which arrived
// without TLS are answered with a 307 Temporary Redirect to accessURL, while
// requests that did use TLS are passed through to handler untouched.
func redirectHTTPToAccessURL(handler http.Handler, accessURL *url.URL) http.Handler {
	redirecting := func(rw http.ResponseWriter, req *http.Request) {
		if req.TLS != nil {
			handler.ServeHTTP(rw, req)
			return
		}
		// The target is always accessURL itself; the path and query of the
		// incoming request are not carried over.
		http.Redirect(rw, req, accessURL.String(), http.StatusTemporaryRedirect)
	}
	return http.HandlerFunc(redirecting)
}
// isLocalhost reports whether host is exactly one of the literal names
// "localhost", "127.0.0.1" or "::1". The comparison is a plain string match,
// so pass the output of `u.Hostname()` (which carries no port or brackets).
func isLocalhost(host string) bool {
	switch host {
	case "localhost", "127.0.0.1", "::1":
		return true
	default:
		return false
	}
}
// buildLogger assembles the server logger from the logging section of cfg.
//
// The human, JSON and Stackdriver outputs are each configured by a location
// string: "" disables the output, "/dev/stdout" and "/dev/stderr" map to the
// command's stdout and stderr writers, and anything else is treated as a file
// path that is opened for appending (created if missing). When trace log
// capture is enabled a tracing sink is added as well. The level is debug when
// the verbose flag is set and info otherwise.
//
// It returns the logger and a function that closes every log file opened
// here. An error is returned when a log file cannot be opened or when no sink
// ends up configured; in the former case any files opened so far are closed
// before returning, so nothing leaks.
func buildLogger(cmd *cobra.Command, cfg *codersdk.DeploymentConfig) (slog.Logger, func(), error) {
	var (
		sinks   []slog.Sink
		closers []func() error
	)
	// closeAll closes every file opened so far. Close errors are ignored on
	// purpose: this is best-effort cleanup.
	closeAll := func() {
		for _, closer := range closers {
			_ = closer()
		}
	}

	addSinkIfProvided := func(sinkFn func(io.Writer) slog.Sink, loc string) error {
		switch loc {
		case "":
			// Output disabled.
		case "/dev/stdout":
			sinks = append(sinks, sinkFn(cmd.OutOrStdout()))
		case "/dev/stderr":
			sinks = append(sinks, sinkFn(cmd.ErrOrStderr()))
		default:
			fi, err := os.OpenFile(loc, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
			if err != nil {
				return xerrors.Errorf("open log file %q: %w", loc, err)
			}
			closers = append(closers, fi.Close)
			sinks = append(sinks, sinkFn(fi))
		}
		return nil
	}

	outputs := []struct {
		name   string
		sinkFn func(io.Writer) slog.Sink
		loc    string
	}{
		{name: "human", sinkFn: sloghuman.Sink, loc: cfg.Logging.Human.Value},
		{name: "json", sinkFn: slogjson.Sink, loc: cfg.Logging.JSON.Value},
		{name: "stackdriver", sinkFn: slogstackdriver.Sink, loc: cfg.Logging.Stackdriver.Value},
	}
	for _, output := range outputs {
		if err := addSinkIfProvided(output.sinkFn, output.loc); err != nil {
			// Release files opened for earlier outputs before bailing out.
			closeAll()
			return slog.Logger{}, nil, xerrors.Errorf("add %s sink: %w", output.name, err)
		}
	}

	if cfg.Trace.CaptureLogs.Value {
		sinks = append(sinks, tracing.SlogSink{})
	}

	level := slog.LevelInfo
	if ok, _ := cmd.Flags().GetBool(varVerbose); ok {
		level = slog.LevelDebug
	}

	if len(sinks) == 0 {
		return slog.Logger{}, nil, xerrors.New("no loggers provided")
	}

	return slog.Make(sinks...).Leveled(level), closeAll, nil
}