From 2ba4a62a0de717a7cd2799bc55c3dc558984d8a5 Mon Sep 17 00:00:00 2001
From: Kyle Carberry
Date: Mon, 17 Oct 2022 08:43:30 -0500
Subject: [PATCH] feat: Add high availability for multiple replicas (#4555)

* feat: HA tailnet coordinator
* fixup! feat: HA tailnet coordinator
* fixup! feat: HA tailnet coordinator
* remove printlns
* close all connections on coordinator
* implement high availability feature
* fixup! implement high availability feature
* fixup! implement high availability feature
* fixup! implement high availability feature
* fixup! implement high availability feature
* Add replicas
* Add DERP meshing to arbitrary addresses
* Move packages to highavailability folder
* Move coordinator to high availability package
* Add flags for HA
* Rename to replicasync
* Denest packages for replicas
* Add test for multiple replicas
* Fix coordination test
* Add HA to the helm chart
* Rename function pointer
* Add warnings for HA
* Add the ability to block endpoints
* Add flag to disable P2P connections
* Wow, I made the tests pass
* Add replicas endpoint
* Ensure close kills replica
* Update sql
* Add database latency to high availability
* Pipe TLS to DERP mesh
* Fix DERP mesh with TLS
* Add tests for TLS
* Fix replica sync TLS
* Fix RootCA for replica meshing
* Remove ID from replicasync
* Fix getting certificates for meshing
* Remove excessive locking
* Fix linting
* Store mesh key in the database
* Fix replica key for tests
* Fix types gen
* Fix unlocking unlocked
* Fix race in tests
* Update enterprise/derpmesh/derpmesh.go

Co-authored-by: Colin Adler

* Rename to syncReplicas
* Reuse http client
* Delete old replicas on a CRON
* Fix race condition in connection tests
* Fix linting
* Fix nil type
* Move pubsub to in-memory for twenty test
* Add comment for configuration tweaking
* Fix leak with transport
* Fix close leak in derpmesh
* Fix race when creating server
* Remove handler update
* Skip test on Windows
* Fix DERP mesh test
* Wrap HTTP handler replacement in mutex
* Fix error message for relay
* Fix API handler for normal tests
* Fix speedtest
* Fix replica resend
* Fix derpmesh send
* Ping async
* Increase wait time of template version job
* Fix race when closing replica sync
* Add name to client
* Log the derpmap being used
* Don't connect if DERP is empty
* Improve agent coordinator logging
* Fix lock in coordinator
* Fix relay addr
* Fix race when updating durations
* Fix client publish race
* Run pubsub loop in a queue
* Store agent nodes in order
* Fix coordinator locking
* Check for closed pipe

Co-authored-by: Colin Adler
---
 .vscode/settings.json | 3 +
 agent/agent.go | 1 +
 agent/agent_test.go | 6 +-
 cli/agent_test.go | 14 +-
 cli/config/file.go | 5 +
 cli/configssh_test.go | 3 +-
 cli/deployment/flags.go | 7 +
 cli/portforward.go | 5 +-
 cli/root.go | 6 +-
 cli/server.go | 23 +-
 cli/speedtest.go | 6 +-
 cli/ssh.go | 4 +-
 coderd/activitybump_test.go | 6 +-
 coderd/coderd.go | 20 +-
 coderd/coderdtest/coderdtest.go | 144 +++--
 coderd/database/databasefake/databasefake.go | 88 +++
 coderd/database/db.go | 9 +
 coderd/database/dump.sql | 17 +-
 .../migrations/000061_replicas.down.sql | 2 +
 .../migrations/000061_replicas.up.sql | 28 +
 coderd/database/models.go | 15 +
 coderd/database/pubsub_memory.go | 3 +-
 coderd/database/querier.go | 6 +
 coderd/database/queries.sql.go | 189 +++++-
 coderd/database/queries/replicas.sql | 31 +
 coderd/database/queries/siteconfig.sql | 6 +
 coderd/provisionerjobs.go | 2 +-
 coderd/rbac/object.go | 4 +
 coderd/templates_test.go | 4 +-
 coderd/workspaceagents.go
| 20 +- coderd/workspaceagents_test.go | 8 +- coderd/workspacebuilds.go | 2 +- coderd/wsconncache/wsconncache_test.go | 4 +- codersdk/agentconn.go | 6 +- codersdk/features.go | 15 +- codersdk/flags.go | 1 + codersdk/replicas.go | 44 ++ codersdk/workspaceagents.go | 34 +- enterprise/cli/features_test.go | 4 +- enterprise/cli/server.go | 55 +- enterprise/coderd/authorize_test.go | 2 +- enterprise/coderd/coderd.go | 112 +++- enterprise/coderd/coderd_test.go | 8 +- .../coderd/coderdenttest/coderdenttest.go | 65 +- .../coderdenttest/coderdenttest_test.go | 6 +- enterprise/coderd/groups_test.go | 36 +- enterprise/coderd/license/license.go | 62 +- enterprise/coderd/license/license_test.go | 113 +++- enterprise/coderd/licenses_test.go | 50 +- enterprise/coderd/replicas.go | 37 ++ enterprise/coderd/replicas_test.go | 138 +++++ enterprise/coderd/templates_test.go | 34 +- enterprise/coderd/workspaceagents_test.go | 14 +- enterprise/coderd/workspaces_test.go | 2 +- enterprise/derpmesh/derpmesh.go | 165 +++++ enterprise/derpmesh/derpmesh_test.go | 219 +++++++ enterprise/replicasync/replicasync.go | 391 ++++++++++++ enterprise/replicasync/replicasync_test.go | 239 ++++++++ enterprise/tailnet/coordinator.go | 575 ++++++++++++++++++ enterprise/tailnet/coordinator_test.go | 261 ++++++++ go.mod | 2 +- go.sum | 4 +- helm/templates/coder.yaml | 12 +- helm/templates/service.yaml | 1 + helm/values.yaml | 8 +- site/src/api/api.ts | 1 + site/src/api/typesGenerated.ts | 13 + .../LicenseBanner/LicenseBanner.tsx | 6 +- .../LicenseBannerView.stories.tsx | 10 + .../LicenseBanner/LicenseBannerView.tsx | 66 +- site/src/testHelpers/entities.ts | 3 + .../entitlements/entitlementsXService.ts | 1 + tailnet/conn.go | 30 +- tailnet/coordinator.go | 248 +++++--- tailnet/coordinator_test.go | 4 +- testutil/certificate.go | 53 ++ 76 files changed, 3437 insertions(+), 404 deletions(-) create mode 100644 coderd/database/migrations/000061_replicas.down.sql create mode 100644 coderd/database/migrations/000061_replicas.up.sql create mode 100644 coderd/database/queries/replicas.sql create mode 100644 codersdk/replicas.go create mode 100644 enterprise/coderd/replicas.go create mode 100644 enterprise/coderd/replicas_test.go create mode 100644 enterprise/derpmesh/derpmesh.go create mode 100644 enterprise/derpmesh/derpmesh_test.go create mode 100644 enterprise/replicasync/replicasync.go create mode 100644 enterprise/replicasync/replicasync_test.go create mode 100644 enterprise/tailnet/coordinator.go create mode 100644 enterprise/tailnet/coordinator_test.go create mode 100644 testutil/certificate.go diff --git a/.vscode/settings.json b/.vscode/settings.json index 8b92ff2228..9771a27a0d 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -19,6 +19,7 @@ "derphttp", "derpmap", "devel", + "dflags", "drpc", "drpcconn", "drpcmux", @@ -86,8 +87,10 @@ "ptytest", "quickstart", "reconfig", + "replicasync", "retrier", "rpty", + "SCIM", "sdkproto", "sdktrace", "Signup", diff --git a/agent/agent.go b/agent/agent.go index 6d0a9a952f..f7c5598b7b 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -170,6 +170,7 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) { if a.isClosed() { return } + a.logger.Debug(ctx, "running tailnet with derpmap", slog.F("derpmap", derpMap)) if a.network != nil { a.network.SetDERPMap(derpMap) return diff --git a/agent/agent_test.go b/agent/agent_test.go index 06a33598b7..e10eee7f11 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -465,7 +465,7 @@ func TestAgent(t *testing.T) 
{ conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0) require.Eventually(t, func() bool { - _, err := conn.Ping() + _, err := conn.Ping(context.Background()) return err == nil }, testutil.WaitMedium, testutil.IntervalFast) conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) @@ -483,9 +483,7 @@ func TestAgent(t *testing.T) { t.Run("Speedtest", func(t *testing.T) { t.Parallel() - if testing.Short() { - t.Skip("The minimum duration for a speedtest is hardcoded in Tailscale to 5s!") - } + t.Skip("This test is relatively flakey because of Tailscale's speedtest code...") derpMap := tailnettest.RunDERPAndSTUN(t) conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{ DERPMap: derpMap, diff --git a/cli/agent_test.go b/cli/agent_test.go index dd0cb1d789..f487ebfc00 100644 --- a/cli/agent_test.go +++ b/cli/agent_test.go @@ -7,8 +7,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "cdr.dev/slog" - "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/provisioner/echo" @@ -67,11 +65,11 @@ func TestWorkspaceAgent(t *testing.T) { if assert.NotEmpty(t, workspace.LatestBuild.Resources) && assert.NotEmpty(t, resources[0].Agents) { assert.NotEmpty(t, resources[0].Agents[0].Version) } - dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) + dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) require.NoError(t, err) defer dialer.Close() require.Eventually(t, func() bool { - _, err := dialer.Ping() + _, err := dialer.Ping(ctx) return err == nil }, testutil.WaitMedium, testutil.IntervalFast) cancelFunc() @@ -128,11 +126,11 @@ func TestWorkspaceAgent(t *testing.T) { if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) { assert.NotEmpty(t, resources[0].Agents[0].Version) } - dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) + dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) require.NoError(t, err) defer dialer.Close() require.Eventually(t, func() bool { - _, err := dialer.Ping() + _, err := dialer.Ping(ctx) return err == nil }, testutil.WaitMedium, testutil.IntervalFast) cancelFunc() @@ -189,11 +187,11 @@ func TestWorkspaceAgent(t *testing.T) { if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) { assert.NotEmpty(t, resources[0].Agents[0].Version) } - dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) + dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) require.NoError(t, err) defer dialer.Close() require.Eventually(t, func() bool { - _, err := dialer.Ping() + _, err := dialer.Ping(ctx) return err == nil }, testutil.WaitMedium, testutil.IntervalFast) cancelFunc() diff --git a/cli/config/file.go b/cli/config/file.go index a98237afed..388ce0881f 100644 --- a/cli/config/file.go +++ b/cli/config/file.go @@ -13,6 +13,11 @@ func (r Root) Session() File { return File(filepath.Join(string(r), "session")) } +// ReplicaID is a unique identifier for the Coder server. 
+func (r Root) ReplicaID() File { + return File(filepath.Join(string(r), "replica_id")) +} + func (r Root) URL() File { return File(filepath.Join(string(r), "url")) } diff --git a/cli/configssh_test.go b/cli/configssh_test.go index 3e1512a0c3..4553cbe431 100644 --- a/cli/configssh_test.go +++ b/cli/configssh_test.go @@ -19,7 +19,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" @@ -115,7 +114,7 @@ func TestConfigSSH(t *testing.T) { _ = agentCloser.Close() }() resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) - agentConn, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, resources[0].Agents[0].ID) + agentConn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil) require.NoError(t, err) defer agentConn.Close() diff --git a/cli/deployment/flags.go b/cli/deployment/flags.go index df18e95027..714365cc8e 100644 --- a/cli/deployment/flags.go +++ b/cli/deployment/flags.go @@ -85,6 +85,13 @@ func Flags() *codersdk.DeploymentFlags { Description: "Addresses for STUN servers to establish P2P connections. Set empty to disable P2P connections.", Default: []string{"stun.l.google.com:19302"}, }, + DerpServerRelayAddress: &codersdk.StringFlag{ + Name: "DERP Server Relay Address", + Flag: "derp-server-relay-address", + EnvVar: "CODER_DERP_SERVER_RELAY_ADDRESS", + Description: "An HTTP address that is accessible by other replicas to relay DERP traffic. Required for high availability.", + Enterprise: true, + }, DerpConfigURL: &codersdk.StringFlag{ Name: "DERP Config URL", Flag: "derp-config-url", diff --git a/cli/portforward.go b/cli/portforward.go index 476809d601..5a6f4391dd 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -16,7 +16,6 @@ import ( "github.com/spf13/cobra" "golang.org/x/xerrors" - "cdr.dev/slog" "github.com/coder/coder/agent" "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/cliui" @@ -96,7 +95,7 @@ func portForward() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID) + conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil) if err != nil { return err } @@ -156,7 +155,7 @@ func portForward() *cobra.Command { case <-ticker.C: } - _, err = conn.Ping() + _, err = conn.Ping(ctx) if err != nil { continue } diff --git a/cli/root.go b/cli/root.go index e7104e6428..91d4551916 100644 --- a/cli/root.go +++ b/cli/root.go @@ -4,6 +4,7 @@ import ( "context" "flag" "fmt" + "io" "net/http" "net/url" "os" @@ -100,8 +101,9 @@ func Core() []*cobra.Command { } func AGPL() []*cobra.Command { - all := append(Core(), Server(deployment.Flags(), func(_ context.Context, o *coderd.Options) (*coderd.API, error) { - return coderd.New(o), nil + all := append(Core(), Server(deployment.Flags(), func(_ context.Context, o *coderd.Options) (*coderd.API, io.Closer, error) { + api := coderd.New(o) + return api, api, nil })) return all } diff --git a/cli/server.go b/cli/server.go index 9d828abbea..c2dbeac07e 100644 --- a/cli/server.go +++ b/cli/server.go @@ -69,7 +69,7 @@ import ( ) // nolint:gocyclo -func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *coderd.Options) (*coderd.API, error)) *cobra.Command { +func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *cobra.Command { root := &cobra.Command{ 
 		Use:   "server",
 		Short: "Start a Coder server",
@@ -167,9 +167,10 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
 		}
 		defer listener.Close()
 
+		var tlsConfig *tls.Config
 		if dflags.TLSEnable.Value {
-			listener, err = configureServerTLS(
-				listener, dflags.TLSMinVersion.Value,
+			tlsConfig, err = configureTLS(
+				dflags.TLSMinVersion.Value,
 				dflags.TLSClientAuth.Value,
 				dflags.TLSCertFiles.Value,
 				dflags.TLSKeyFiles.Value,
@@ -178,6 +179,7 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
 			if err != nil {
 				return xerrors.Errorf("configure tls: %w", err)
 			}
+			listener = tls.NewListener(listener, tlsConfig)
 		}
 
 		tcpAddr, valid := listener.Addr().(*net.TCPAddr)
@@ -328,6 +330,9 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
 			Experimental:    ExperimentalEnabled(cmd),
 			DeploymentFlags: dflags,
 		}
+		if tlsConfig != nil {
+			options.TLSCertificates = tlsConfig.Certificates
+		}
 
 		if dflags.OAuth2GithubClientSecret.Value != "" {
 			options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed,
@@ -471,11 +476,14 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
 				), dflags.PromAddress.Value, "prometheus")()
 		}
 
-		coderAPI, err := newAPI(ctx, options)
+		// We use a separate closer so the Enterprise API
+		// can have its own close functions. This is cleaner
+		// than abstracting the Coder API itself.
+		coderAPI, closer, err := newAPI(ctx, options)
 		if err != nil {
 			return err
 		}
-		defer coderAPI.Close()
+		defer closer.Close()
 
 		client := codersdk.New(localURL)
 		if dflags.TLSEnable.Value {
@@ -893,7 +901,7 @@ func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, er
 	return certs, nil
 }
 
-func configureServerTLS(listener net.Listener, tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (net.Listener, error) {
+func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (*tls.Config, error) {
 	tlsConfig := &tls.Config{
 		MinVersion: tls.VersionTLS12,
 	}
@@ -929,6 +937,7 @@ func configureServerTLS(listener net.Listener, tlsMinVersion, tlsClientAuth stri
 	if err != nil {
 		return nil, xerrors.Errorf("load certificates: %w", err)
 	}
+	tlsConfig.Certificates = certs
 
 	tlsConfig.GetCertificate = func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
 		// If there's only one certificate, return it.
if len(certs) == 1 { @@ -963,7 +972,7 @@ func configureServerTLS(listener net.Listener, tlsMinVersion, tlsClientAuth stri tlsConfig.ClientCAs = caPool } - return tls.NewListener(listener, tlsConfig), nil + return tlsConfig, nil } func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups bool, allowOrgs []string, rawTeams []string, enterpriseBaseURL string) (*coderd.GithubOAuth2Config, error) { diff --git a/cli/speedtest.go b/cli/speedtest.go index 357048f63e..f6c06641ec 100644 --- a/cli/speedtest.go +++ b/cli/speedtest.go @@ -55,7 +55,9 @@ func speedtest() *cobra.Command { if cliflag.IsSetBool(cmd, varVerbose) { logger = logger.Leveled(slog.LevelDebug) } - conn, err := client.DialWorkspaceAgentTailnet(ctx, logger, workspaceAgent.ID) + conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, &codersdk.DialWorkspaceAgentOptions{ + Logger: logger, + }) if err != nil { return err } @@ -68,7 +70,7 @@ func speedtest() *cobra.Command { return ctx.Err() case <-ticker.C: } - dur, err := conn.Ping() + dur, err := conn.Ping(ctx) if err != nil { continue } diff --git a/cli/ssh.go b/cli/ssh.go index ef8538764e..b4d4f6420d 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -20,8 +20,6 @@ import ( "golang.org/x/term" "golang.org/x/xerrors" - "cdr.dev/slog" - "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/autobuild/notify" @@ -86,7 +84,7 @@ func ssh() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID) + conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil) if err != nil { return err } diff --git a/coderd/activitybump_test.go b/coderd/activitybump_test.go index cd43b774d5..e498b98fa0 100644 --- a/coderd/activitybump_test.go +++ b/coderd/activitybump_test.go @@ -72,7 +72,7 @@ func TestWorkspaceActivityBump(t *testing.T) { "deadline %v never updated", firstDeadline, ) - require.WithinDuration(t, database.Now().Add(time.Hour), workspace.LatestBuild.Deadline.Time, time.Second) + require.WithinDuration(t, database.Now().Add(time.Hour), workspace.LatestBuild.Deadline.Time, 3*time.Second) } } @@ -82,7 +82,9 @@ func TestWorkspaceActivityBump(t *testing.T) { client, workspace, assertBumped := setupActivityTest(t) resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) - conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil), resources[0].Agents[0].ID) + conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{ + Logger: slogtest.Make(t, nil), + }) require.NoError(t, err) defer conn.Close() diff --git a/coderd/coderd.go b/coderd/coderd.go index 992ae6c7f5..cf8a20d373 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1,6 +1,7 @@ package coderd import ( + "crypto/tls" "crypto/x509" "io" "net/http" @@ -82,7 +83,10 @@ type Options struct { TracerProvider trace.TracerProvider AutoImportTemplates []AutoImportTemplate - TailnetCoordinator *tailnet.Coordinator + // TLSCertificates is used to mesh DERP servers securely. 
+ TLSCertificates []tls.Certificate + TailnetCoordinator tailnet.Coordinator + DERPServer *derp.Server DERPMap *tailcfg.DERPMap MetricsCacheRefreshInterval time.Duration @@ -130,6 +134,9 @@ func New(options *Options) *API { if options.TailnetCoordinator == nil { options.TailnetCoordinator = tailnet.NewCoordinator() } + if options.DERPServer == nil { + options.DERPServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp"))) + } if options.Auditor == nil { options.Auditor = audit.NewNop() } @@ -168,7 +175,7 @@ func New(options *Options) *API { api.Auditor.Store(&options.Auditor) api.WorkspaceQuotaEnforcer.Store(&options.WorkspaceQuotaEnforcer) api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0) - api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger)) + api.TailnetCoordinator.Store(&options.TailnetCoordinator) oauthConfigs := &httpmw.OAuth2Configs{ Github: options.GithubOAuth2Config, OIDC: options.OIDCConfig, @@ -246,7 +253,7 @@ func New(options *Options) *API { r.Route("/%40{user}/{workspace_and_agent}/apps/{workspaceapp}", apps) r.Route("/@{user}/{workspace_and_agent}/apps/{workspaceapp}", apps) r.Route("/derp", func(r chi.Router) { - r.Get("/", derphttp.Handler(api.derpServer).ServeHTTP) + r.Get("/", derphttp.Handler(api.DERPServer).ServeHTTP) // This is used when UDP is blocked, and latency must be checked via HTTP(s). r.Get("/latency-check", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -550,6 +557,7 @@ type API struct { Auditor atomic.Pointer[audit.Auditor] WorkspaceClientCoordinateOverride atomic.Pointer[func(rw http.ResponseWriter) bool] WorkspaceQuotaEnforcer atomic.Pointer[workspacequota.Enforcer] + TailnetCoordinator atomic.Pointer[tailnet.Coordinator] HTTPAuth *HTTPAuthorizer // APIHandler serves "/api/v2" @@ -557,7 +565,6 @@ type API struct { // RootHandler serves "/" RootHandler chi.Router - derpServer *derp.Server metricsCache *metricscache.Cache siteHandler http.Handler websocketWaitMutex sync.Mutex @@ -572,7 +579,10 @@ func (api *API) Close() error { api.websocketWaitMutex.Unlock() api.metricsCache.Close() - + coordinator := api.TailnetCoordinator.Load() + if coordinator != nil { + _ = (*coordinator).Close() + } return api.workspaceAgentCache.Close() } diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index f8695deb04..5cf307d842 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -7,6 +7,7 @@ import ( "crypto/rand" "crypto/rsa" "crypto/sha256" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/base64" @@ -23,6 +24,7 @@ import ( "regexp" "strconv" "strings" + "sync" "testing" "time" @@ -37,8 +39,10 @@ import ( "golang.org/x/xerrors" "google.golang.org/api/idtoken" "google.golang.org/api/option" + "tailscale.com/derp" "tailscale.com/net/stun/stuntest" "tailscale.com/tailcfg" + "tailscale.com/types/key" "tailscale.com/types/nettype" "cdr.dev/slog" @@ -60,6 +64,7 @@ import ( "github.com/coder/coder/provisionerd" "github.com/coder/coder/provisionersdk" "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/coder/tailnet" "github.com/coder/coder/testutil" ) @@ -77,12 +82,19 @@ type Options struct { AutobuildTicker <-chan time.Time AutobuildStats chan<- executor.Stats Auditor audit.Auditor + TLSCertificates []tls.Certificate // IncludeProvisionerDaemon when true means to start an in-memory provisionerD IncludeProvisionerDaemon bool MetricsCacheRefreshInterval time.Duration AgentStatsRefreshInterval 
time.Duration DeploymentFlags *codersdk.DeploymentFlags + + // Overriding the database is heavily discouraged. + // It should only be used in cases where multiple Coder + // test instances are running against the same database. + Database database.Store + Pubsub database.Pubsub } // New constructs a codersdk client connected to an in-memory API instance. @@ -116,7 +128,7 @@ func newWithCloser(t *testing.T, options *Options) (*codersdk.Client, io.Closer) return client, closer } -func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.CancelFunc, *coderd.Options) { +func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.CancelFunc, *coderd.Options) { if options == nil { options = &Options{} } @@ -137,23 +149,40 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.Cance close(options.AutobuildStats) }) } - - db, pubsub := dbtestutil.NewDB(t) + if options.Database == nil { + options.Database, options.Pubsub = dbtestutil.NewDB(t) + } ctx, cancelFunc := context.WithCancel(context.Background()) lifecycleExecutor := executor.New( ctx, - db, + options.Database, slogtest.Make(t, nil).Named("autobuild.executor").Leveled(slog.LevelDebug), options.AutobuildTicker, ).WithStatsChannel(options.AutobuildStats) lifecycleExecutor.Run() - srv := httptest.NewUnstartedServer(nil) + var mutex sync.RWMutex + var handler http.Handler + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mutex.RLock() + defer mutex.RUnlock() + if handler != nil { + handler.ServeHTTP(w, r) + } + })) srv.Config.BaseContext = func(_ net.Listener) context.Context { return ctx } - srv.Start() + if options.TLSCertificates != nil { + srv.TLS = &tls.Config{ + Certificates: options.TLSCertificates, + MinVersion: tls.VersionTLS12, + } + srv.StartTLS() + } else { + srv.Start() + } t.Cleanup(srv.Close) tcpAddr, ok := srv.Listener.Addr().(*net.TCPAddr) @@ -169,6 +198,9 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.Cance stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, nettype.Std{}) t.Cleanup(stunCleanup) + derpServer := derp.NewServer(key.NewNode(), tailnet.Logger(slogtest.Make(t, nil).Named("derp"))) + derpServer.SetMeshKey("test-key") + // match default with cli default if options.SSHKeygenAlgorithm == "" { options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519 @@ -181,53 +213,59 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.Cance require.NoError(t, err) } - return srv, cancelFunc, &coderd.Options{ - AgentConnectionUpdateFrequency: 150 * time.Millisecond, - // Force a long disconnection timeout to ensure - // agents are not marked as disconnected during slow tests. - AgentInactiveDisconnectTimeout: testutil.WaitShort, - AccessURL: serverURL, - AppHostname: options.AppHostname, - AppHostnameRegex: appHostnameRegex, - Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), - CacheDir: t.TempDir(), - Database: db, - Pubsub: pubsub, + return func(h http.Handler) { + mutex.Lock() + defer mutex.Unlock() + handler = h + }, cancelFunc, &coderd.Options{ + AgentConnectionUpdateFrequency: 150 * time.Millisecond, + // Force a long disconnection timeout to ensure + // agents are not marked as disconnected during slow tests. 
+ AgentInactiveDisconnectTimeout: testutil.WaitShort, + AccessURL: serverURL, + AppHostname: options.AppHostname, + AppHostnameRegex: appHostnameRegex, + Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), + CacheDir: t.TempDir(), + Database: options.Database, + Pubsub: options.Pubsub, - Auditor: options.Auditor, - AWSCertificates: options.AWSCertificates, - AzureCertificates: options.AzureCertificates, - GithubOAuth2Config: options.GithubOAuth2Config, - OIDCConfig: options.OIDCConfig, - GoogleTokenValidator: options.GoogleTokenValidator, - SSHKeygenAlgorithm: options.SSHKeygenAlgorithm, - APIRateLimit: options.APIRateLimit, - Authorizer: options.Authorizer, - Telemetry: telemetry.NewNoop(), - DERPMap: &tailcfg.DERPMap{ - Regions: map[int]*tailcfg.DERPRegion{ - 1: { - EmbeddedRelay: true, - RegionID: 1, - RegionCode: "coder", - RegionName: "Coder", - Nodes: []*tailcfg.DERPNode{{ - Name: "1a", - RegionID: 1, - IPv4: "127.0.0.1", - DERPPort: derpPort, - STUNPort: stunAddr.Port, - InsecureForTests: true, - ForceHTTP: true, - }}, + Auditor: options.Auditor, + AWSCertificates: options.AWSCertificates, + AzureCertificates: options.AzureCertificates, + GithubOAuth2Config: options.GithubOAuth2Config, + OIDCConfig: options.OIDCConfig, + GoogleTokenValidator: options.GoogleTokenValidator, + SSHKeygenAlgorithm: options.SSHKeygenAlgorithm, + DERPServer: derpServer, + APIRateLimit: options.APIRateLimit, + Authorizer: options.Authorizer, + Telemetry: telemetry.NewNoop(), + TLSCertificates: options.TLSCertificates, + DERPMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + EmbeddedRelay: true, + RegionID: 1, + RegionCode: "coder", + RegionName: "Coder", + Nodes: []*tailcfg.DERPNode{{ + Name: "1a", + RegionID: 1, + IPv4: "127.0.0.1", + DERPPort: derpPort, + STUNPort: stunAddr.Port, + InsecureForTests: true, + ForceHTTP: options.TLSCertificates == nil, + }}, + }, }, }, - }, - AutoImportTemplates: options.AutoImportTemplates, - MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval, - AgentStatsRefreshInterval: options.AgentStatsRefreshInterval, - DeploymentFlags: options.DeploymentFlags, - } + AutoImportTemplates: options.AutoImportTemplates, + MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval, + AgentStatsRefreshInterval: options.AgentStatsRefreshInterval, + DeploymentFlags: options.DeploymentFlags, + } } // NewWithAPI constructs an in-memory API instance and returns a client to talk to it. @@ -237,10 +275,10 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c if options == nil { options = &Options{} } - srv, cancelFunc, newOptions := NewOptions(t, options) + setHandler, cancelFunc, newOptions := NewOptions(t, options) // We set the handler after server creation for the access URL. 
 	coderAPI := coderd.New(newOptions)
-	srv.Config.Handler = coderAPI.RootHandler
+	setHandler(coderAPI.RootHandler)
 	var provisionerCloser io.Closer = nopcloser{}
 	if options.IncludeProvisionerDaemon {
 		provisionerCloser = NewProvisionerDaemon(t, coderAPI)
@@ -459,7 +497,7 @@ func AwaitTemplateVersionJob(t *testing.T, client *codersdk.Client, version uuid
 		var err error
 		templateVersion, err = client.TemplateVersion(context.Background(), version)
 		return assert.NoError(t, err) && templateVersion.Job.CompletedAt != nil
-	}, testutil.WaitShort, testutil.IntervalFast)
+	}, testutil.WaitMedium, testutil.IntervalFast)
 	return templateVersion
 }
 
diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go
index 63239bdf4d..65043d2412 100644
--- a/coderd/database/databasefake/databasefake.go
+++ b/coderd/database/databasefake/databasefake.go
@@ -107,11 +107,17 @@ type data struct {
 	workspaceApps []database.WorkspaceApp
 	workspaces    []database.Workspace
 	licenses      []database.License
+	replicas      []database.Replica
 
 	deploymentID  string
+	derpMeshKey   string
 	lastLicenseID int32
 }
 
+func (*fakeQuerier) Ping(_ context.Context) (time.Duration, error) {
+	return 0, nil
+}
+
 // InTx doesn't rollback data properly for in-memory yet.
 func (q *fakeQuerier) InTx(fn func(database.Store) error) error {
 	q.mutex.Lock()
@@ -2931,6 +2937,21 @@ func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) {
 	return q.deploymentID, nil
 }
 
+func (q *fakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error {
+	q.mutex.Lock()
+	defer q.mutex.Unlock()
+
+	q.derpMeshKey = id
+	return nil
+}
+
+func (q *fakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) {
+	q.mutex.RLock()
+	defer q.mutex.RUnlock()
+
+	return q.derpMeshKey, nil
+}
+
 func (q *fakeQuerier) InsertLicense(
 	_ context.Context, arg database.InsertLicenseParams,
 ) (database.License, error) {
@@ -3196,3 +3217,73 @@ func (q *fakeQuerier) DeleteGroupByID(_ context.Context, id uuid.UUID) error {
 
 	return sql.ErrNoRows
 }
+
+func (q *fakeQuerier) DeleteReplicasUpdatedBefore(_ context.Context, before time.Time) error {
+	q.mutex.Lock()
+	defer q.mutex.Unlock()
+
+	// Removing elements while ranging skips the entry that shifts into the
+	// deleted slot, so iterate by index and step back after each removal.
+	for i := 0; i < len(q.replicas); i++ {
+		if q.replicas[i].UpdatedAt.Before(before) {
+			q.replicas = append(q.replicas[:i], q.replicas[i+1:]...)
+			i--
+ } + } + + return nil +} + +func (q *fakeQuerier) InsertReplica(_ context.Context, arg database.InsertReplicaParams) (database.Replica, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + replica := database.Replica{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + StartedAt: arg.StartedAt, + UpdatedAt: arg.UpdatedAt, + Hostname: arg.Hostname, + RegionID: arg.RegionID, + RelayAddress: arg.RelayAddress, + Version: arg.Version, + DatabaseLatency: arg.DatabaseLatency, + } + q.replicas = append(q.replicas, replica) + return replica, nil +} + +func (q *fakeQuerier) UpdateReplica(_ context.Context, arg database.UpdateReplicaParams) (database.Replica, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for index, replica := range q.replicas { + if replica.ID != arg.ID { + continue + } + replica.Hostname = arg.Hostname + replica.StartedAt = arg.StartedAt + replica.StoppedAt = arg.StoppedAt + replica.UpdatedAt = arg.UpdatedAt + replica.RelayAddress = arg.RelayAddress + replica.RegionID = arg.RegionID + replica.Version = arg.Version + replica.Error = arg.Error + replica.DatabaseLatency = arg.DatabaseLatency + q.replicas[index] = replica + return replica, nil + } + return database.Replica{}, sql.ErrNoRows +} + +func (q *fakeQuerier) GetReplicasUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.Replica, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + replicas := make([]database.Replica, 0) + for _, replica := range q.replicas { + if replica.UpdatedAt.After(updatedAt) && !replica.StoppedAt.Valid { + replicas = append(replicas, replica) + } + } + return replicas, nil +} diff --git a/coderd/database/db.go b/coderd/database/db.go index 4cbbdb399f..020000888f 100644 --- a/coderd/database/db.go +++ b/coderd/database/db.go @@ -12,6 +12,7 @@ import ( "context" "database/sql" "errors" + "time" "github.com/jmoiron/sqlx" "golang.org/x/xerrors" @@ -24,6 +25,7 @@ type Store interface { // customQuerier contains custom queries that are not generated. customQuerier + Ping(ctx context.Context) (time.Duration, error) InTx(func(Store) error) error } @@ -58,6 +60,13 @@ type sqlQuerier struct { db DBTX } +// Ping returns the time it takes to ping the database. +func (q *sqlQuerier) Ping(ctx context.Context) (time.Duration, error) { + start := time.Now() + err := q.sdb.PingContext(ctx) + return time.Since(start), err +} + // InTx performs database operations inside a transaction. 
func (q *sqlQuerier) InTx(function func(Store) error) error {
	if _, ok := q.db.(*sqlx.Tx); ok {
diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql
index de2d352a6a..b946a1130e 100644
--- a/coderd/database/dump.sql
+++ b/coderd/database/dump.sql
@@ -256,7 +256,8 @@ CREATE TABLE provisioner_daemons (
     created_at timestamp with time zone NOT NULL,
     updated_at timestamp with time zone,
     name character varying(64) NOT NULL,
-    provisioners provisioner_type[] NOT NULL
+    provisioners provisioner_type[] NOT NULL,
+    replica_id uuid
 );
 
 CREATE TABLE provisioner_job_logs (
@@ -287,6 +288,20 @@ CREATE TABLE provisioner_jobs (
     file_id uuid NOT NULL
 );
 
+CREATE TABLE replicas (
+    id uuid NOT NULL,
+    created_at timestamp with time zone NOT NULL,
+    started_at timestamp with time zone NOT NULL,
+    stopped_at timestamp with time zone,
+    updated_at timestamp with time zone NOT NULL,
+    hostname text NOT NULL,
+    region_id integer NOT NULL,
+    relay_address text NOT NULL,
+    database_latency integer NOT NULL,
+    version text NOT NULL,
+    error text DEFAULT ''::text NOT NULL
+);
+
 CREATE TABLE site_configs (
     key character varying(256) NOT NULL,
     value character varying(8192) NOT NULL
diff --git a/coderd/database/migrations/000061_replicas.down.sql b/coderd/database/migrations/000061_replicas.down.sql
new file mode 100644
index 0000000000..4cca6615d4
--- /dev/null
+++ b/coderd/database/migrations/000061_replicas.down.sql
@@ -0,0 +1,2 @@
+DROP TABLE replicas;
+ALTER TABLE provisioner_daemons DROP COLUMN replica_id;
diff --git a/coderd/database/migrations/000061_replicas.up.sql b/coderd/database/migrations/000061_replicas.up.sql
new file mode 100644
index 0000000000..1400662e30
--- /dev/null
+++ b/coderd/database/migrations/000061_replicas.up.sql
@@ -0,0 +1,29 @@
+CREATE TABLE IF NOT EXISTS replicas (
+    -- A unique identifier for the replica that is stored on disk.
+    -- For persistent replicas, this will be reused.
+    -- For ephemeral replicas, this will be a new UUID for each one.
+    id uuid NOT NULL,
+    -- The time the replica was created.
+    created_at timestamp with time zone NOT NULL,
+    -- The time the replica was started.
+    started_at timestamp with time zone NOT NULL,
+    -- The time the replica was stopped, if it has been.
+    stopped_at timestamp with time zone,
+    -- Updated periodically to ensure the replica is still alive.
+    updated_at timestamp with time zone NOT NULL,
+    -- Hostname is the hostname of the replica.
+    hostname text NOT NULL,
+    -- Region is the region the replica is in.
+    -- We only DERP mesh to the same region ID of a running replica.
+    region_id integer NOT NULL,
+    -- An address that should be accessible to other replicas.
+    relay_address text NOT NULL,
+    -- The latency of the replica to the database in microseconds.
+    database_latency int NOT NULL,
+    -- Version is the Coder version of the replica.
+    version text NOT NULL,
+    error text NOT NULL DEFAULT ''
+);
+
+-- Associates a provisioner daemon with a replica.
+ALTER TABLE provisioner_daemons ADD COLUMN replica_id uuid; diff --git a/coderd/database/models.go b/coderd/database/models.go index e30615244e..53e074984a 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -508,6 +508,7 @@ type ProvisionerDaemon struct { UpdatedAt sql.NullTime `db:"updated_at" json:"updated_at"` Name string `db:"name" json:"name"` Provisioners []ProvisionerType `db:"provisioners" json:"provisioners"` + ReplicaID uuid.NullUUID `db:"replica_id" json:"replica_id"` } type ProvisionerJob struct { @@ -538,6 +539,20 @@ type ProvisionerJobLog struct { Output string `db:"output" json:"output"` } +type Replica struct { + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + StartedAt time.Time `db:"started_at" json:"started_at"` + StoppedAt sql.NullTime `db:"stopped_at" json:"stopped_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Hostname string `db:"hostname" json:"hostname"` + RegionID int32 `db:"region_id" json:"region_id"` + RelayAddress string `db:"relay_address" json:"relay_address"` + DatabaseLatency int32 `db:"database_latency" json:"database_latency"` + Version string `db:"version" json:"version"` + Error string `db:"error" json:"error"` +} + type SiteConfig struct { Key string `db:"key" json:"key"` Value string `db:"value" json:"value"` diff --git a/coderd/database/pubsub_memory.go b/coderd/database/pubsub_memory.go index 148d2f57b1..de5a940414 100644 --- a/coderd/database/pubsub_memory.go +++ b/coderd/database/pubsub_memory.go @@ -47,8 +47,9 @@ func (m *memoryPubsub) Publish(event string, message []byte) error { return nil } for _, listener := range listeners { - listener(context.Background(), message) + go listener(context.Background(), message) } + return nil } diff --git a/coderd/database/querier.go b/coderd/database/querier.go index ad26413873..393ab81fdd 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -26,6 +26,7 @@ type sqlcQuerier interface { DeleteLicense(ctx context.Context, id int32) (int32, error) DeleteOldAgentStats(ctx context.Context) error DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error + DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error) @@ -38,6 +39,7 @@ type sqlcQuerier interface { // This function returns roles for authorization purposes. Implied member roles // are included. 
GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error) + GetDERPMeshKey(ctx context.Context) (string, error) GetDeploymentID(ctx context.Context) (string, error) GetFileByHashAndCreator(ctx context.Context, arg GetFileByHashAndCreatorParams) (File, error) GetFileByID(ctx context.Context, id uuid.UUID) (File, error) @@ -67,6 +69,7 @@ type sqlcQuerier interface { GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]ProvisionerJob, error) GetProvisionerLogsByIDBetween(ctx context.Context, arg GetProvisionerLogsByIDBetweenParams) ([]ProvisionerJobLog, error) + GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error) GetTemplateAverageBuildTime(ctx context.Context, arg GetTemplateAverageBuildTimeParams) (GetTemplateAverageBuildTimeRow, error) GetTemplateByID(ctx context.Context, id uuid.UUID) (Template, error) GetTemplateByOrganizationAndName(ctx context.Context, arg GetTemplateByOrganizationAndNameParams) (Template, error) @@ -123,6 +126,7 @@ type sqlcQuerier interface { // every member of the org. InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error) + InsertDERPMeshKey(ctx context.Context, value string) error InsertDeploymentID(ctx context.Context, value string) error InsertFile(ctx context.Context, arg InsertFileParams) (File, error) InsertGitSSHKey(ctx context.Context, arg InsertGitSSHKeyParams) (GitSSHKey, error) @@ -136,6 +140,7 @@ type sqlcQuerier interface { InsertProvisionerDaemon(ctx context.Context, arg InsertProvisionerDaemonParams) (ProvisionerDaemon, error) InsertProvisionerJob(ctx context.Context, arg InsertProvisionerJobParams) (ProvisionerJob, error) InsertProvisionerJobLogs(ctx context.Context, arg InsertProvisionerJobLogsParams) ([]ProvisionerJobLog, error) + InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error) InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) (TemplateVersion, error) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) @@ -156,6 +161,7 @@ type sqlcQuerier interface { UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error UpdateProvisionerJobWithCancelByID(ctx context.Context, arg UpdateProvisionerJobWithCancelByIDParams) error UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error + UpdateReplica(ctx context.Context, arg UpdateReplicaParams) (Replica, error) UpdateTemplateActiveVersionByID(ctx context.Context, arg UpdateTemplateActiveVersionByIDParams) error UpdateTemplateDeletedByID(ctx context.Context, arg UpdateTemplateDeletedByIDParams) error UpdateTemplateMetaByID(ctx context.Context, arg UpdateTemplateMetaByIDParams) (Template, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 41eb029b59..3621050bc0 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2031,7 +2031,7 @@ func (q *sqlQuerier) ParameterValues(ctx context.Context, arg ParameterValuesPar const getProvisionerDaemonByID = `-- name: GetProvisionerDaemonByID :one SELECT - id, created_at, updated_at, name, provisioners + id, created_at, updated_at, name, provisioners, replica_id FROM provisioner_daemons 
WHERE @@ -2047,13 +2047,14 @@ func (q *sqlQuerier) GetProvisionerDaemonByID(ctx context.Context, id uuid.UUID) &i.UpdatedAt, &i.Name, pq.Array(&i.Provisioners), + &i.ReplicaID, ) return i, err } const getProvisionerDaemons = `-- name: GetProvisionerDaemons :many SELECT - id, created_at, updated_at, name, provisioners + id, created_at, updated_at, name, provisioners, replica_id FROM provisioner_daemons ` @@ -2073,6 +2074,7 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa &i.UpdatedAt, &i.Name, pq.Array(&i.Provisioners), + &i.ReplicaID, ); err != nil { return nil, err } @@ -2096,7 +2098,7 @@ INSERT INTO provisioners ) VALUES - ($1, $2, $3, $4) RETURNING id, created_at, updated_at, name, provisioners + ($1, $2, $3, $4) RETURNING id, created_at, updated_at, name, provisioners, replica_id ` type InsertProvisionerDaemonParams struct { @@ -2120,6 +2122,7 @@ func (q *sqlQuerier) InsertProvisionerDaemon(ctx context.Context, arg InsertProv &i.UpdatedAt, &i.Name, pq.Array(&i.Provisioners), + &i.ReplicaID, ) return i, err } @@ -2577,6 +2580,177 @@ func (q *sqlQuerier) UpdateProvisionerJobWithCompleteByID(ctx context.Context, a return err } +const deleteReplicasUpdatedBefore = `-- name: DeleteReplicasUpdatedBefore :exec +DELETE FROM replicas WHERE updated_at < $1 +` + +func (q *sqlQuerier) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error { + _, err := q.db.ExecContext(ctx, deleteReplicasUpdatedBefore, updatedAt) + return err +} + +const getReplicasUpdatedAfter = `-- name: GetReplicasUpdatedAfter :many +SELECT id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error FROM replicas WHERE updated_at > $1 AND stopped_at IS NULL +` + +func (q *sqlQuerier) GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error) { + rows, err := q.db.QueryContext(ctx, getReplicasUpdatedAfter, updatedAt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Replica + for rows.Next() { + var i Replica + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.StartedAt, + &i.StoppedAt, + &i.UpdatedAt, + &i.Hostname, + &i.RegionID, + &i.RelayAddress, + &i.DatabaseLatency, + &i.Version, + &i.Error, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertReplica = `-- name: InsertReplica :one +INSERT INTO replicas ( + id, + created_at, + started_at, + updated_at, + hostname, + region_id, + relay_address, + version, + database_latency +) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error +` + +type InsertReplicaParams struct { + ID uuid.UUID `db:"id" json:"id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + StartedAt time.Time `db:"started_at" json:"started_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Hostname string `db:"hostname" json:"hostname"` + RegionID int32 `db:"region_id" json:"region_id"` + RelayAddress string `db:"relay_address" json:"relay_address"` + Version string `db:"version" json:"version"` + DatabaseLatency int32 `db:"database_latency" json:"database_latency"` +} + +func (q *sqlQuerier) InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error) { + row := q.db.QueryRowContext(ctx, insertReplica, + arg.ID, + 
arg.CreatedAt, + arg.StartedAt, + arg.UpdatedAt, + arg.Hostname, + arg.RegionID, + arg.RelayAddress, + arg.Version, + arg.DatabaseLatency, + ) + var i Replica + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.StartedAt, + &i.StoppedAt, + &i.UpdatedAt, + &i.Hostname, + &i.RegionID, + &i.RelayAddress, + &i.DatabaseLatency, + &i.Version, + &i.Error, + ) + return i, err +} + +const updateReplica = `-- name: UpdateReplica :one +UPDATE replicas SET + updated_at = $2, + started_at = $3, + stopped_at = $4, + relay_address = $5, + region_id = $6, + hostname = $7, + version = $8, + error = $9, + database_latency = $10 +WHERE id = $1 RETURNING id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error +` + +type UpdateReplicaParams struct { + ID uuid.UUID `db:"id" json:"id"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + StartedAt time.Time `db:"started_at" json:"started_at"` + StoppedAt sql.NullTime `db:"stopped_at" json:"stopped_at"` + RelayAddress string `db:"relay_address" json:"relay_address"` + RegionID int32 `db:"region_id" json:"region_id"` + Hostname string `db:"hostname" json:"hostname"` + Version string `db:"version" json:"version"` + Error string `db:"error" json:"error"` + DatabaseLatency int32 `db:"database_latency" json:"database_latency"` +} + +func (q *sqlQuerier) UpdateReplica(ctx context.Context, arg UpdateReplicaParams) (Replica, error) { + row := q.db.QueryRowContext(ctx, updateReplica, + arg.ID, + arg.UpdatedAt, + arg.StartedAt, + arg.StoppedAt, + arg.RelayAddress, + arg.RegionID, + arg.Hostname, + arg.Version, + arg.Error, + arg.DatabaseLatency, + ) + var i Replica + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.StartedAt, + &i.StoppedAt, + &i.UpdatedAt, + &i.Hostname, + &i.RegionID, + &i.RelayAddress, + &i.DatabaseLatency, + &i.Version, + &i.Error, + ) + return i, err +} + +const getDERPMeshKey = `-- name: GetDERPMeshKey :one +SELECT value FROM site_configs WHERE key = 'derp_mesh_key' +` + +func (q *sqlQuerier) GetDERPMeshKey(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, getDERPMeshKey) + var value string + err := row.Scan(&value) + return value, err +} + const getDeploymentID = `-- name: GetDeploymentID :one SELECT value FROM site_configs WHERE key = 'deployment_id' ` @@ -2588,6 +2762,15 @@ func (q *sqlQuerier) GetDeploymentID(ctx context.Context) (string, error) { return value, err } +const insertDERPMeshKey = `-- name: InsertDERPMeshKey :exec +INSERT INTO site_configs (key, value) VALUES ('derp_mesh_key', $1) +` + +func (q *sqlQuerier) InsertDERPMeshKey(ctx context.Context, value string) error { + _, err := q.db.ExecContext(ctx, insertDERPMeshKey, value) + return err +} + const insertDeploymentID = `-- name: InsertDeploymentID :exec INSERT INTO site_configs (key, value) VALUES ('deployment_id', $1) ` diff --git a/coderd/database/queries/replicas.sql b/coderd/database/queries/replicas.sql new file mode 100644 index 0000000000..e87c1f4643 --- /dev/null +++ b/coderd/database/queries/replicas.sql @@ -0,0 +1,31 @@ +-- name: GetReplicasUpdatedAfter :many +SELECT * FROM replicas WHERE updated_at > $1 AND stopped_at IS NULL; + +-- name: InsertReplica :one +INSERT INTO replicas ( + id, + created_at, + started_at, + updated_at, + hostname, + region_id, + relay_address, + version, + database_latency +) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *; + +-- name: UpdateReplica :one +UPDATE replicas SET + updated_at = $2, + started_at = $3, + stopped_at = $4, + relay_address = 
$5, + region_id = $6, + hostname = $7, + version = $8, + error = $9, + database_latency = $10 +WHERE id = $1 RETURNING *; + +-- name: DeleteReplicasUpdatedBefore :exec +DELETE FROM replicas WHERE updated_at < $1; diff --git a/coderd/database/queries/siteconfig.sql b/coderd/database/queries/siteconfig.sql index 9d3936e238..b975d2f68c 100644 --- a/coderd/database/queries/siteconfig.sql +++ b/coderd/database/queries/siteconfig.sql @@ -3,3 +3,9 @@ INSERT INTO site_configs (key, value) VALUES ('deployment_id', $1); -- name: GetDeploymentID :one SELECT value FROM site_configs WHERE key = 'deployment_id'; + +-- name: InsertDERPMeshKey :exec +INSERT INTO site_configs (key, value) VALUES ('derp_mesh_key', $1); + +-- name: GetDERPMeshKey :one +SELECT value FROM site_configs WHERE key = 'derp_mesh_key'; diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index 294b013e00..04f050f0c5 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -270,7 +270,7 @@ func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request, } } - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading job agent.", diff --git a/coderd/rbac/object.go b/coderd/rbac/object.go index 5492e4397d..1a8861c984 100644 --- a/coderd/rbac/object.go +++ b/coderd/rbac/object.go @@ -146,6 +146,10 @@ var ( ResourceDeploymentFlags = Object{ Type: "deployment_flags", } + + ResourceReplicas = Object{ + Type: "replicas", + } ) // Object is used to create objects for authz checks when you have none in diff --git a/coderd/templates_test.go b/coderd/templates_test.go index 637ced633c..f6aacba8a5 100644 --- a/coderd/templates_test.go +++ b/coderd/templates_test.go @@ -627,7 +627,9 @@ func TestTemplateMetrics(t *testing.T) { require.NoError(t, err) assert.Zero(t, workspaces[0].LastUsedAt) - conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil).Named("tailnet"), resources[0].Agents[0].ID) + conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{ + Logger: slogtest.Make(t, nil).Named("tailnet"), + }) require.NoError(t, err) defer func() { _ = conn.Close() diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 295beff0d2..fb7f765cc7 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -49,7 +49,7 @@ func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) { }) return } - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -78,7 +78,7 @@ func (api *API) workspaceAgentApps(rw http.ResponseWriter, r *http.Request) { func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() workspaceAgent := httpmw.WorkspaceAgent(r) - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, 
workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -98,7 +98,7 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() workspaceAgent := httpmw.WorkspaceAgent(r) - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -152,7 +152,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { httpapi.ResourceNotFound(rw) return } - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -229,7 +229,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req return } - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error reading workspace agent.", @@ -376,8 +376,9 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (* }) conn.SetNodeCallback(sendNodes) go func() { - err := api.TailnetCoordinator.ServeClient(serverConn, uuid.New(), agentID) + err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID) if err != nil { + api.Logger.Warn(r.Context(), "tailnet coordinator client error", slog.Error(err)) _ = conn.Close() } }() @@ -514,8 +515,9 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request closeChan := make(chan struct{}) go func() { defer close(closeChan) - err := api.TailnetCoordinator.ServeAgent(wsNetConn, workspaceAgent.ID) + err := (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID) if err != nil { + api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err)) _ = conn.Close(websocket.StatusInternalError, err.Error()) return } @@ -583,7 +585,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R go httpapi.Heartbeat(ctx, conn) defer conn.Close(websocket.StatusNormalClosure, "") - err = api.TailnetCoordinator.ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID) + err = (*api.TailnetCoordinator.Load()).ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID) if err != nil { _ = conn.Close(websocket.StatusInternalError, err.Error()) return @@ 
-611,7 +613,7 @@ func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp { return apps } -func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator *tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) { +func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) { var envs map[string]string if dbAgent.EnvironmentVariables.Valid { err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs) diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 6bd569dde9..e8dd772095 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -123,13 +123,13 @@ func TestWorkspaceAgentListen(t *testing.T) { defer cancel() resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) - conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID) + conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) require.NoError(t, err) defer func() { _ = conn.Close() }() require.Eventually(t, func() bool { - _, err := conn.Ping() + _, err := conn.Ping(ctx) return err == nil }, testutil.WaitLong, testutil.IntervalFast) }) @@ -253,7 +253,9 @@ func TestWorkspaceAgentTailnet(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), resources[0].Agents[0].ID) + conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{ + Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), + }) require.NoError(t, err) defer conn.Close() sshClient, err := conn.SSHClient() diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go index ed136f372b..dc89f576b5 100644 --- a/coderd/workspacebuilds.go +++ b/coderd/workspacebuilds.go @@ -861,7 +861,7 @@ func (api *API) convertWorkspaceBuild( apiAgents := make([]codersdk.WorkspaceAgent, 0) for _, agent := range agents { apps := appsByAgentID[agent.ID] - apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(apps), api.AgentInactiveDisconnectTimeout) + apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), agent, convertApps(apps), api.AgentInactiveDisconnectTimeout) if err != nil { return codersdk.WorkspaceBuild{}, xerrors.Errorf("converting workspace agent: %w", err) } diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index 003d3cddb8..d4345ce9d5 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -128,7 +128,9 @@ func TestCache(t *testing.T) { return } defer release() - proxy.Transport = conn.HTTPTransport() + transport := conn.HTTPTransport() + defer transport.CloseIdleConnections() + proxy.Transport = transport res := httptest.NewRecorder() proxy.ServeHTTP(res, req) resp := res.Result() diff --git a/codersdk/agentconn.go b/codersdk/agentconn.go index b11c440ce3..ddfb9541a1 100644 --- a/codersdk/agentconn.go +++ b/codersdk/agentconn.go @@ -132,10 +132,10 @@ type AgentConn struct { CloseFunc func() } -func (c *AgentConn) Ping() (time.Duration, error) { +func (c *AgentConn) Ping(ctx 
context.Context) (time.Duration, error) { errCh := make(chan error, 1) durCh := make(chan time.Duration, 1) - c.Conn.Ping(TailnetIP, tailcfg.PingDisco, func(pr *ipnstate.PingResult) { + go c.Conn.Ping(TailnetIP, tailcfg.PingDisco, func(pr *ipnstate.PingResult) { if pr.Err != "" { errCh <- xerrors.New(pr.Err) return @@ -145,6 +145,8 @@ func (c *AgentConn) Ping() (time.Duration, error) { select { case err := <-errCh: return 0, err + case <-ctx.Done(): + return 0, ctx.Err() case dur := <-durCh: return dur, nil } diff --git a/codersdk/features.go b/codersdk/features.go index 291b5575a7..862411de62 100644 --- a/codersdk/features.go +++ b/codersdk/features.go @@ -15,12 +15,13 @@ const ( ) const ( - FeatureUserLimit = "user_limit" - FeatureAuditLog = "audit_log" - FeatureBrowserOnly = "browser_only" - FeatureSCIM = "scim" - FeatureWorkspaceQuota = "workspace_quota" - FeatureTemplateRBAC = "template_rbac" + FeatureUserLimit = "user_limit" + FeatureAuditLog = "audit_log" + FeatureBrowserOnly = "browser_only" + FeatureSCIM = "scim" + FeatureWorkspaceQuota = "workspace_quota" + FeatureTemplateRBAC = "template_rbac" + FeatureHighAvailability = "high_availability" ) var FeatureNames = []string{ @@ -30,6 +31,7 @@ var FeatureNames = []string{ FeatureSCIM, FeatureWorkspaceQuota, FeatureTemplateRBAC, + FeatureHighAvailability, } type Feature struct { @@ -42,6 +44,7 @@ type Feature struct { type Entitlements struct { Features map[string]Feature `json:"features"` Warnings []string `json:"warnings"` + Errors []string `json:"errors"` HasLicense bool `json:"has_license"` Experimental bool `json:"experimental"` Trial bool `json:"trial"` diff --git a/codersdk/flags.go b/codersdk/flags.go index 92f02941a5..09ca65b1ea 100644 --- a/codersdk/flags.go +++ b/codersdk/flags.go @@ -19,6 +19,7 @@ type DeploymentFlags struct { DerpServerRegionCode *StringFlag `json:"derp_server_region_code" typescript:",notnull"` DerpServerRegionName *StringFlag `json:"derp_server_region_name" typescript:",notnull"` DerpServerSTUNAddresses *StringArrayFlag `json:"derp_server_stun_address" typescript:",notnull"` + DerpServerRelayAddress *StringFlag `json:"derp_server_relay_address" typescript:",notnull"` DerpConfigURL *StringFlag `json:"derp_config_url" typescript:",notnull"` DerpConfigPath *StringFlag `json:"derp_config_path" typescript:",notnull"` PromEnabled *BoolFlag `json:"prom_enabled" typescript:",notnull"` diff --git a/codersdk/replicas.go b/codersdk/replicas.go new file mode 100644 index 0000000000..e74af021ee --- /dev/null +++ b/codersdk/replicas.go @@ -0,0 +1,44 @@ +package codersdk + +import ( + "context" + "encoding/json" + "net/http" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" +) + +type Replica struct { + // ID is the unique identifier for the replica. + ID uuid.UUID `json:"id"` + // Hostname is the hostname of the replica. + Hostname string `json:"hostname"` + // CreatedAt is when the replica was first seen. + CreatedAt time.Time `json:"created_at"` + // RelayAddress is the accessible address to relay DERP connections. + RelayAddress string `json:"relay_address"` + // RegionID is the region of the replica. + RegionID int32 `json:"region_id"` + // Error is the error. + Error string `json:"error"` + // DatabaseLatency is the latency in microseconds to the database. + DatabaseLatency int32 `json:"database_latency"` +} + +// Replicas fetches the list of replicas. 
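The Ping rework above is a standard adaptation of a callback-style API into a blocking, cancellable call: the probe runs in its own goroutine, results land on buffered channels so the callback can never block, and a select races them against ctx.Done(). A minimal self-contained sketch of the same pattern; pingFunc and pingCtx are hypothetical names, not Coder or Tailscale APIs:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// pingFunc stands in for a callback-based probe such as Tailscale's
// Conn.Ping: it invokes cb once when the probe completes.
type pingFunc func(cb func(dur time.Duration, err error))

// pingCtx adapts the callback API to a blocking, cancellable call.
// Buffered channels guarantee the callback never blocks, even if the
// caller has already given up and returned.
func pingCtx(ctx context.Context, ping pingFunc) (time.Duration, error) {
	errCh := make(chan error, 1)
	durCh := make(chan time.Duration, 1)
	go ping(func(dur time.Duration, err error) {
		if err != nil {
			errCh <- err
			return
		}
		durCh <- dur
	})
	select {
	case <-ctx.Done():
		return 0, ctx.Err()
	case err := <-errCh:
		return 0, err
	case dur := <-durCh:
		return dur, nil
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	// A probe that never completes: the context deadline wins the select.
	never := pingFunc(func(cb func(time.Duration, error)) {})
	_, err := pingCtx(ctx, never)
	fmt.Println(errors.Is(err, context.DeadlineExceeded)) // true
}
```

The Replicas client method that follows uses the SDK's usual request-and-decode shape.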
+func (c *Client) Replicas(ctx context.Context) ([]Replica, error) { + res, err := c.Request(ctx, http.MethodGet, "/api/v2/replicas", nil) + if err != nil { + return nil, xerrors.Errorf("execute request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return nil, readBodyAsError(res) + } + + var replicas []Replica + return replicas, json.NewDecoder(res.Body).Decode(&replicas) +} diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 253e8713fd..c86944ae2b 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -21,7 +21,6 @@ import ( "tailscale.com/tailcfg" "cdr.dev/slog" - "github.com/coder/coder/tailnet" "github.com/coder/retry" ) @@ -316,7 +315,8 @@ func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, err Value: c.SessionToken, }}) httpClient := &http.Client{ - Jar: jar, + Jar: jar, + Transport: c.HTTPClient.Transport, } // nolint:bodyclose conn, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{ @@ -332,7 +332,17 @@ func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, err return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil } -func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (*AgentConn, error) { +// @typescript-ignore DialWorkspaceAgentOptions +type DialWorkspaceAgentOptions struct { + Logger slog.Logger + // BlockEndpoints forces the connection to be relayed through DERP. + BlockEndpoints bool +} + +func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (*AgentConn, error) { + if options == nil { + options = &DialWorkspaceAgentOptions{} + } res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/connection", agentID), nil) if err != nil { return nil, err @@ -349,9 +359,10 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg ip := tailnet.IP() conn, err := tailnet.NewConn(&tailnet.Options{ - Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)}, - DERPMap: connInfo.DERPMap, - Logger: logger, + Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)}, + DERPMap: connInfo.DERPMap, + Logger: options.Logger, + BlockEndpoints: options.BlockEndpoints, }) if err != nil { return nil, xerrors.Errorf("create tailnet: %w", err) } @@ -370,7 +381,8 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg Value: c.SessionToken, }}) httpClient := &http.Client{ - Jar: jar, + Jar: jar, + Transport: c.HTTPClient.Transport, } ctx, cancelFunc := context.WithCancel(ctx) closed := make(chan struct{}) @@ -379,7 +391,7 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg defer close(closed) isFirst := true for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { - logger.Debug(ctx, "connecting") + options.Logger.Debug(ctx, "connecting") // nolint:bodyclose ws, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{ HTTPClient: httpClient, @@ -398,21 +410,21 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg if errors.Is(err, context.Canceled) { return } - logger.Debug(ctx, "failed to dial", slog.Error(err)) + options.Logger.Debug(ctx, "failed to dial", slog.Error(err)) continue } sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error { return conn.UpdateNodes(node) })
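The dial API now takes an options struct with nil-safe defaulting, which keeps the signature stable as knobs are added. A hedged usage sketch, mirroring how the enterprise tests later in this patch call it; pingAgent is a hypothetical helper, and client, agentID, and logger are assumed to come from the surrounding application:

```go
package example

import (
	"context"

	"github.com/google/uuid"

	"cdr.dev/slog"
	"github.com/coder/coder/codersdk"
)

// pingAgent dials a workspace agent with endpoints blocked, so all
// traffic must relay through DERP, then round-trips one ping. Passing
// a nil options pointer is also valid and yields default behavior.
func pingAgent(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, logger slog.Logger) error {
	conn, err := client.DialWorkspaceAgent(ctx, agentID, &codersdk.DialWorkspaceAgentOptions{
		Logger:         logger,
		BlockEndpoints: true, // never attempt a direct path
	})
	if err != nil {
		return err
	}
	defer conn.Close()
	_, err = conn.Ping(ctx)
	return err
}
```

The retry loop resumes below, re-registering the node callback after each successful coordinator dial.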
conn.SetNodeCallback(sendNode) - logger.Debug(ctx, "serving coordinator") + options.Logger.Debug(ctx, "serving coordinator") err = <-errChan if errors.Is(err, context.Canceled) { _ = ws.Close(websocket.StatusGoingAway, "") return } if err != nil { - logger.Debug(ctx, "error serving coordinator", slog.Error(err)) + options.Logger.Debug(ctx, "error serving coordinator", slog.Error(err)) _ = ws.Close(websocket.StatusGoingAway, "") continue } diff --git a/enterprise/cli/features_test.go b/enterprise/cli/features_test.go index 215809c173..78b94a6509 100644 --- a/enterprise/cli/features_test.go +++ b/enterprise/cli/features_test.go @@ -57,7 +57,7 @@ func TestFeaturesList(t *testing.T) { var entitlements codersdk.Entitlements err := json.Unmarshal(buf.Bytes(), &entitlements) require.NoError(t, err, "unmarshal JSON output") - assert.Len(t, entitlements.Features, 6) + assert.Len(t, entitlements.Features, 7) assert.Empty(t, entitlements.Warnings) assert.Equal(t, codersdk.EntitlementNotEntitled, entitlements.Features[codersdk.FeatureUserLimit].Entitlement) @@ -71,6 +71,8 @@ func TestFeaturesList(t *testing.T) { entitlements.Features[codersdk.FeatureTemplateRBAC].Entitlement) assert.Equal(t, codersdk.EntitlementNotEntitled, entitlements.Features[codersdk.FeatureSCIM].Entitlement) + assert.Equal(t, codersdk.EntitlementNotEntitled, + entitlements.Features[codersdk.FeatureHighAvailability].Entitlement) assert.False(t, entitlements.HasLicense) assert.False(t, entitlements.Experimental) }) diff --git a/enterprise/cli/server.go b/enterprise/cli/server.go index 62af6f2888..a65b8e8faa 100644 --- a/enterprise/cli/server.go +++ b/enterprise/cli/server.go @@ -2,11 +2,20 @@ package cli import ( "context" + "database/sql" + "errors" + "io" + "net/url" "github.com/spf13/cobra" + "golang.org/x/xerrors" + "tailscale.com/derp" + "tailscale.com/types/key" "github.com/coder/coder/cli/deployment" + "github.com/coder/coder/cryptorand" "github.com/coder/coder/enterprise/coderd" + "github.com/coder/coder/tailnet" agpl "github.com/coder/coder/cli" agplcoderd "github.com/coder/coder/coderd" @@ -14,23 +23,49 @@ import ( func server() *cobra.Command { dflags := deployment.Flags() - cmd := agpl.Server(dflags, func(ctx context.Context, options *agplcoderd.Options) (*agplcoderd.API, error) { + cmd := agpl.Server(dflags, func(ctx context.Context, options *agplcoderd.Options) (*agplcoderd.API, io.Closer, error) { + if dflags.DerpServerRelayAddress.Value != "" { + _, err := url.Parse(dflags.DerpServerRelayAddress.Value) + if err != nil { + return nil, nil, xerrors.Errorf("derp-server-relay-address must be a valid HTTP URL: %w", err) + } + } + + options.DERPServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp"))) + meshKey, err := options.Database.GetDERPMeshKey(ctx) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return nil, nil, xerrors.Errorf("get mesh key: %w", err) + } + meshKey, err = cryptorand.String(32) + if err != nil { + return nil, nil, xerrors.Errorf("generate mesh key: %w", err) + } + err = options.Database.InsertDERPMeshKey(ctx, meshKey) + if err != nil { + return nil, nil, xerrors.Errorf("insert mesh key: %w", err) + } + } + options.DERPServer.SetMeshKey(meshKey) + o := &coderd.Options{ - AuditLogging: dflags.AuditLogging.Value, - BrowserOnly: dflags.BrowserOnly.Value, - SCIMAPIKey: []byte(dflags.SCIMAuthHeader.Value), - UserWorkspaceQuota: dflags.UserWorkspaceQuota.Value, - RBACEnabled: true, - Options: options, + AuditLogging: dflags.AuditLogging.Value, + BrowserOnly: 
dflags.BrowserOnly.Value, + SCIMAPIKey: []byte(dflags.SCIMAuthHeader.Value), + UserWorkspaceQuota: dflags.UserWorkspaceQuota.Value, + RBAC: true, + DERPServerRelayAddress: dflags.DerpServerRelayAddress.Value, + DERPServerRegionID: dflags.DerpServerRegionID.Value, + + Options: options, } api, err := coderd.New(ctx, o) if err != nil { - return nil, err + return nil, nil, err } - return api.AGPL, nil + return api.AGPL, api, nil }) deployment.AttachFlags(cmd.Flags(), dflags, true) - return cmd } diff --git a/enterprise/coderd/authorize_test.go b/enterprise/coderd/authorize_test.go index 72cc4c5f38..9195387632 100644 --- a/enterprise/coderd/authorize_test.go +++ b/enterprise/coderd/authorize_test.go @@ -28,7 +28,7 @@ func TestCheckACLPermissions(t *testing.T) { // Create adminClient, member, and org adminClient adminUser := coderdtest.CreateFirstUser(t, adminClient) _ = coderdenttest.AddLicense(t, adminClient, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) memberClient := coderdtest.CreateAnotherUser(t, adminClient, adminUser.OrganizationID) diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 2c341dd13a..1250e6ae12 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -3,6 +3,8 @@ package coderd import ( "context" "crypto/ed25519" + "crypto/tls" + "crypto/x509" "net/http" "sync" "time" @@ -23,6 +25,10 @@ import ( "github.com/coder/coder/enterprise/audit" "github.com/coder/coder/enterprise/audit/backends" "github.com/coder/coder/enterprise/coderd/license" + "github.com/coder/coder/enterprise/derpmesh" + "github.com/coder/coder/enterprise/replicasync" + "github.com/coder/coder/enterprise/tailnet" + agpltailnet "github.com/coder/coder/tailnet" ) // New constructs an Enterprise coderd API instance. @@ -47,6 +53,7 @@ func New(ctx context.Context, options *Options) (*API, error) { Options: options, cancelEntitlementsLoop: cancelFunc, } + oauthConfigs := &httpmw.OAuth2Configs{ Github: options.GithubOAuth2Config, OIDC: options.OIDCConfig, @@ -59,6 +66,10 @@ func New(ctx context.Context, options *Options) (*API, error) { api.AGPL.APIHandler.Group(func(r chi.Router) { r.Get("/entitlements", api.serveEntitlements) + r.Route("/replicas", func(r chi.Router) { + r.Use(apiKeyMiddleware) + r.Get("/", api.replicas) + }) r.Route("/licenses", func(r chi.Router) { r.Use(apiKeyMiddleware) r.Post("/", api.postLicense) @@ -117,7 +128,40 @@ func New(ctx context.Context, options *Options) (*API, error) { }) } - err := api.updateEntitlements(ctx) + meshRootCA := x509.NewCertPool() + for _, certificate := range options.TLSCertificates { + for _, certificatePart := range certificate.Certificate { + certificate, err := x509.ParseCertificate(certificatePart) + if err != nil { + return nil, xerrors.Errorf("parse certificate %s: %w", certificate.Subject.CommonName, err) + } + meshRootCA.AddCert(certificate) + } + } + // This TLS configuration spoofs access from the access URL hostname + // assuming that the certificates provided will cover that hostname. + // + // Replica sync and DERP meshing require accessing replicas via their + // internal IP addresses, and if TLS is configured we use the same + // certificates. 
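The comment above describes a common trick for intra-cluster TLS: certificate verification follows tls.Config.ServerName, not the dialed address, so replicas can reach each other on internal IPs while presenting the certificate issued for the public access URL. A small sketch of the client side of that handshake; dialReplica and its parameters are illustrative:

```go
package example

import "crypto/tls"

// dialReplica opens a TLS connection to a peer replica at an internal
// ip:port. Verification runs against meshTLSConfig.ServerName (the
// access URL host) and meshTLSConfig.RootCAs (the deployment's own
// certificates), so no public DNS name or CA is needed for the dial.
func dialReplica(replicaAddr string, meshTLSConfig *tls.Config) (*tls.Conn, error) {
	return tls.Dial("tcp", replicaAddr, meshTLSConfig)
}
```

The configuration constructed next applies exactly that override.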
+ meshTLSConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: options.TLSCertificates, + RootCAs: meshRootCA, + ServerName: options.AccessURL.Hostname(), + } + var err error + api.replicaManager, err = replicasync.New(ctx, options.Logger, options.Database, options.Pubsub, &replicasync.Options{ + RelayAddress: options.DERPServerRelayAddress, + RegionID: int32(options.DERPServerRegionID), + TLSConfig: meshTLSConfig, + }) + if err != nil { + return nil, xerrors.Errorf("initialize replica: %w", err) + } + api.derpMesh = derpmesh.New(options.Logger.Named("derpmesh"), api.DERPServer, meshTLSConfig) + + err = api.updateEntitlements(ctx) if err != nil { return nil, xerrors.Errorf("update entitlements: %w", err) } @@ -129,13 +173,17 @@ func New(ctx context.Context, options *Options) (*API, error) { type Options struct { *coderd.Options - RBACEnabled bool + RBAC bool AuditLogging bool // Whether to block non-browser connections. BrowserOnly bool SCIMAPIKey []byte UserWorkspaceQuota int + // Used for high availability. + DERPServerRelayAddress string + DERPServerRegionID int + EntitlementsUpdateInterval time.Duration Keys map[string]ed25519.PublicKey } @@ -144,6 +192,11 @@ type API struct { AGPL *coderd.API *Options + // Detects multiple Coder replicas running at the same time. + replicaManager *replicasync.Manager + // Meshes DERP connections from multiple replicas. + derpMesh *derpmesh.Mesh + cancelEntitlementsLoop func() entitlementsMu sync.RWMutex entitlements codersdk.Entitlements @@ -151,6 +204,8 @@ type API struct { func (api *API) Close() error { api.cancelEntitlementsLoop() + _ = api.replicaManager.Close() + _ = api.derpMesh.Close() return api.AGPL.Close() } @@ -158,12 +213,13 @@ func (api *API) updateEntitlements(ctx context.Context) error { api.entitlementsMu.Lock() defer api.entitlementsMu.Unlock() - entitlements, err := license.Entitlements(ctx, api.Database, api.Logger, api.Keys, map[string]bool{ - codersdk.FeatureAuditLog: api.AuditLogging, - codersdk.FeatureBrowserOnly: api.BrowserOnly, - codersdk.FeatureSCIM: len(api.SCIMAPIKey) != 0, - codersdk.FeatureWorkspaceQuota: api.UserWorkspaceQuota != 0, - codersdk.FeatureTemplateRBAC: api.RBACEnabled, + entitlements, err := license.Entitlements(ctx, api.Database, api.Logger, len(api.replicaManager.All()), api.Keys, map[string]bool{ + codersdk.FeatureAuditLog: api.AuditLogging, + codersdk.FeatureBrowserOnly: api.BrowserOnly, + codersdk.FeatureSCIM: len(api.SCIMAPIKey) != 0, + codersdk.FeatureWorkspaceQuota: api.UserWorkspaceQuota != 0, + codersdk.FeatureHighAvailability: api.DERPServerRelayAddress != "", + codersdk.FeatureTemplateRBAC: api.RBAC, }) if err != nil { return err @@ -209,6 +265,46 @@ func (api *API) updateEntitlements(ctx context.Context) error { api.AGPL.WorkspaceQuotaEnforcer.Store(&enforcer) } + if changed, enabled := featureChanged(codersdk.FeatureHighAvailability); changed { + coordinator := agpltailnet.NewCoordinator() + if enabled { + haCoordinator, err := tailnet.NewCoordinator(api.Logger, api.Pubsub) + if err != nil { + api.Logger.Error(ctx, "unable to set up high availability coordinator", slog.Error(err)) + // If we try to setup the HA coordinator and it fails, nothing + // is actually changing. 
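This entitlement block hot-swaps the tailnet coordinator at runtime: request handlers read it through an atomic pointer (the *api.TailnetCoordinator.Load() calls earlier in the patch), and the swap below replaces it without locking the request path. A condensed sketch of the pattern, with the coordinator interface reduced to io.Closer for brevity; all names here are illustrative:

```go
package example

import (
	"io"
	"sync/atomic"
)

// Coordinator stands in for tailnet.Coordinator, which this patch turns
// into an interface so the in-memory and HA implementations swap freely.
type Coordinator interface{ io.Closer }

// swapCoordinator installs next as the active coordinator and closes
// whatever was active before. Readers call ptr.Load() and dereference,
// so they always see a complete value and never need a mutex.
func swapCoordinator(ptr *atomic.Pointer[Coordinator], next Coordinator) error {
	old := ptr.Swap(&next)
	if old == nil {
		return nil // nothing was installed yet
	}
	return (*old).Close()
}
```

Back in the hunk: if the HA coordinator fails to initialize, changed is reset below so the existing coordinator stays in place.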
+ changed = false + } else { + coordinator = haCoordinator + } + + api.replicaManager.SetCallback(func() { + addresses := make([]string, 0) + for _, replica := range api.replicaManager.Regional() { + addresses = append(addresses, replica.RelayAddress) + } + api.derpMesh.SetAddresses(addresses, false) + _ = api.updateEntitlements(ctx) + }) + } else { + api.derpMesh.SetAddresses([]string{}, false) + api.replicaManager.SetCallback(func() { + // If the number of replicas changes, so should our entitlements. + // This is to display a warning in the UI if the user is unlicensed. + _ = api.updateEntitlements(ctx) + }) + } + + // Recheck changed in case the HA coordinator failed to set up. + if changed { + oldCoordinator := *api.AGPL.TailnetCoordinator.Swap(&coordinator) + err := oldCoordinator.Close() + if err != nil { + api.Logger.Error(ctx, "close old tailnet coordinator", slog.Error(err)) + } + } + } + api.entitlements = entitlements return nil diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 050cad5f9b..7b51845ff3 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -41,9 +41,9 @@ func TestEntitlements(t *testing.T) { }) _ = coderdtest.CreateFirstUser(t, client) coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - UserLimit: 100, - AuditLog: true, - TemplateRBACEnabled: true, + UserLimit: 100, + AuditLog: true, + TemplateRBAC: true, }) res, err := client.Entitlements(context.Background()) require.NoError(t, err) @@ -85,7 +85,7 @@ func TestEntitlements(t *testing.T) { assert.False(t, res.HasLicense) al = res.Features[codersdk.FeatureAuditLog] assert.Equal(t, codersdk.EntitlementNotEntitled, al.Entitlement) - assert.True(t, al.Enabled) + assert.False(t, al.Enabled) }) t.Run("Pubsub", func(t *testing.T) { t.Parallel() diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go index 75760b3d4f..a8595b5bc6 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest.go +++ b/enterprise/coderd/coderdenttest/coderdenttest.go @@ -4,7 +4,9 @@ import ( "context" "crypto/ed25519" "crypto/rand" + "crypto/tls" "io" + "net/http" "testing" "time" @@ -60,19 +62,21 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c if options.Options == nil { options.Options = &coderdtest.Options{} } - srv, cancelFunc, oop := coderdtest.NewOptions(t, options.Options) + setHandler, cancelFunc, oop := coderdtest.NewOptions(t, options.Options) coderAPI, err := coderd.New(context.Background(), &coderd.Options{ - RBACEnabled: true, + RBAC: true, AuditLogging: options.AuditLogging, BrowserOnly: options.BrowserOnly, SCIMAPIKey: options.SCIMAPIKey, + DERPServerRelayAddress: oop.AccessURL.String(), + DERPServerRegionID: oop.DERPMap.RegionIDs()[0], UserWorkspaceQuota: options.UserWorkspaceQuota, Options: oop, EntitlementsUpdateInterval: options.EntitlementsUpdateInterval, Keys: Keys, }) assert.NoError(t, err) - srv.Config.Handler = coderAPI.AGPL.RootHandler + setHandler(coderAPI.AGPL.RootHandler) var provisionerCloser io.Closer = nopcloser{} if options.IncludeProvisionerDaemon { provisionerCloser = coderdtest.NewProvisionerDaemon(t, coderAPI.AGPL) @@ -83,22 +87,32 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c _ = provisionerCloser.Close() _ = coderAPI.Close() }) - return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI + client := codersdk.New(coderAPI.AccessURL) + client.HTTPClient = &http.Client{ + Transport:
&http.Transport{ + TLSClientConfig: &tls.Config{ + //nolint:gosec + InsecureSkipVerify: true, + }, + }, + } + return client, provisionerCloser, coderAPI } type LicenseOptions struct { - AccountType string - AccountID string - Trial bool - AllFeatures bool - GraceAt time.Time - ExpiresAt time.Time - UserLimit int64 - AuditLog bool - BrowserOnly bool - SCIM bool - WorkspaceQuota bool - TemplateRBACEnabled bool + AccountType string + AccountID string + Trial bool + AllFeatures bool + GraceAt time.Time + ExpiresAt time.Time + UserLimit int64 + AuditLog bool + BrowserOnly bool + SCIM bool + WorkspaceQuota bool + TemplateRBAC bool + HighAvailability bool } // AddLicense generates a new license with the options provided and inserts it. @@ -134,9 +148,13 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string { if options.WorkspaceQuota { workspaceQuota = 1 } + highAvailability := int64(0) + if options.HighAvailability { + highAvailability = 1 + } rbacEnabled := int64(0) - if options.TemplateRBACEnabled { + if options.TemplateRBAC { rbacEnabled = 1 } @@ -154,12 +172,13 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string { Version: license.CurrentVersion, AllFeatures: options.AllFeatures, Features: license.Features{ - UserLimit: options.UserLimit, - AuditLog: auditLog, - BrowserOnly: browserOnly, - SCIM: scim, - WorkspaceQuota: workspaceQuota, - TemplateRBAC: rbacEnabled, + UserLimit: options.UserLimit, + AuditLog: auditLog, + BrowserOnly: browserOnly, + SCIM: scim, + WorkspaceQuota: workspaceQuota, + HighAvailability: highAvailability, + TemplateRBAC: rbacEnabled, }, } tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c) diff --git a/enterprise/coderd/coderdenttest/coderdenttest_test.go b/enterprise/coderd/coderdenttest/coderdenttest_test.go index d526f6927b..e8ad88cd02 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest_test.go +++ b/enterprise/coderd/coderdenttest/coderdenttest_test.go @@ -33,7 +33,7 @@ func TestAuthorizeAllEndpoints(t *testing.T) { ctx, _ := testutil.Context(t) admin := coderdtest.CreateFirstUser(t, client) license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) group, err := client.CreateGroup(ctx, admin.OrganizationID, codersdk.CreateGroupRequest{ Name: "testgroup", @@ -58,6 +58,10 @@ func TestAuthorizeAllEndpoints(t *testing.T) { AssertAction: rbac.ActionRead, AssertObject: rbac.ResourceLicense, } + assertRoute["GET:/api/v2/replicas"] = coderdtest.RouteCheck{ + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceReplicas, + } assertRoute["DELETE:/api/v2/licenses/{id}"] = coderdtest.RouteCheck{ AssertAction: rbac.ActionDelete, AssertObject: rbac.ResourceLicense, diff --git a/enterprise/coderd/groups_test.go b/enterprise/coderd/groups_test.go index 2661da6bcc..eae51b0dfd 100644 --- a/enterprise/coderd/groups_test.go +++ b/enterprise/coderd/groups_test.go @@ -24,7 +24,7 @@ func TestCreateGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) ctx, _ := testutil.Context(t) group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{ @@ -43,7 +43,7 @@ func TestCreateGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) ctx, _ := testutil.Context(t) _, err := 
client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{ @@ -67,7 +67,7 @@ func TestCreateGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) ctx, _ := testutil.Context(t) _, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{ @@ -90,7 +90,7 @@ func TestPatchGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) ctx, _ := testutil.Context(t) group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{ @@ -112,7 +112,7 @@ func TestPatchGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) _, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -138,7 +138,7 @@ func TestPatchGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) _, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -173,7 +173,7 @@ func TestPatchGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) ctx, _ := testutil.Context(t) group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{ @@ -197,7 +197,7 @@ func TestPatchGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) ctx, _ := testutil.Context(t) group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{ @@ -221,7 +221,7 @@ func TestPatchGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) ctx, _ := testutil.Context(t) @@ -247,7 +247,7 @@ func TestPatchGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) ctx, _ := testutil.Context(t) group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{ @@ -276,7 +276,7 @@ func TestGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) ctx, _ := testutil.Context(t) group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{ @@ -296,7 +296,7 @@ func TestGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) _, user3 := 
coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -326,7 +326,7 @@ func TestGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) client1, _ := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -347,7 +347,7 @@ func TestGroup(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -380,7 +380,7 @@ func TestGroup(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -421,7 +421,7 @@ func TestGroups(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) _, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -467,7 +467,7 @@ func TestDeleteGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) ctx, _ := testutil.Context(t) group1, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{ @@ -492,7 +492,7 @@ func TestDeleteGroup(t *testing.T) { user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) ctx, _ := testutil.Context(t) err := client.DeleteGroup(ctx, user.OrganizationID) diff --git a/enterprise/coderd/license/license.go b/enterprise/coderd/license/license.go index ce9e5d1d59..c5bb689db6 100644 --- a/enterprise/coderd/license/license.go +++ b/enterprise/coderd/license/license.go @@ -17,12 +17,20 @@ import ( ) // Entitlements processes licenses to return whether features are enabled or not. -func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, keys map[string]ed25519.PublicKey, enablements map[string]bool) (codersdk.Entitlements, error) { +func Entitlements( + ctx context.Context, + db database.Store, + logger slog.Logger, + replicaCount int, + keys map[string]ed25519.PublicKey, + enablements map[string]bool, +) (codersdk.Entitlements, error) { now := time.Now() // Default all entitlements to be disabled. 
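Entitlements now takes the live replica count so licensing reacts to deployment topology, and it starts from a default-deny baseline: every known feature begins not entitled and disabled, and only license claims upgrade it. A stripped-down sketch of that initialization; Feature and featureNames are stand-ins for the codersdk types:

```go
package example

// Feature mirrors the shape of codersdk.Feature for this sketch.
type Feature struct {
	Entitlement string
	Enabled     bool
}

// defaultDeny builds the baseline the loop below establishes: nothing
// is entitled until a license claim says otherwise, and a final sweep
// (added later in this file) re-disables anything still not entitled.
func defaultDeny(featureNames []string) map[string]Feature {
	features := make(map[string]Feature, len(featureNames))
	for _, name := range featureNames {
		features[name] = Feature{Entitlement: "not_entitled", Enabled: false}
	}
	return features
}
```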
entitlements := codersdk.Entitlements{ Features: map[string]codersdk.Feature{}, Warnings: []string{}, + Errors: []string{}, } for _, featureName := range codersdk.FeatureNames { entitlements.Features[featureName] = codersdk.Feature{ @@ -96,6 +104,12 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, ke Enabled: enablements[codersdk.FeatureWorkspaceQuota], } } + if claims.Features.HighAvailability > 0 { + entitlements.Features[codersdk.FeatureHighAvailability] = codersdk.Feature{ + Entitlement: entitlement, + Enabled: enablements[codersdk.FeatureHighAvailability], + } + } if claims.Features.TemplateRBAC > 0 { entitlements.Features[codersdk.FeatureTemplateRBAC] = codersdk.Feature{ Entitlement: entitlement, @@ -132,6 +146,10 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, ke if featureName == codersdk.FeatureUserLimit { continue } + // High availability has its own warnings based on replica count! + if featureName == codersdk.FeatureHighAvailability { + continue + } feature := entitlements.Features[featureName] if !feature.Enabled { continue @@ -141,9 +159,6 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, ke case codersdk.EntitlementNotEntitled: entitlements.Warnings = append(entitlements.Warnings, fmt.Sprintf("%s is enabled but your license is not entitled to this feature.", niceName)) - // Disable the feature and add a warning... - feature.Enabled = false - entitlements.Features[featureName] = feature case codersdk.EntitlementGracePeriod: entitlements.Warnings = append(entitlements.Warnings, fmt.Sprintf("%s is enabled but your license for this feature is expired.", niceName)) @@ -152,6 +167,32 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, ke } } + if replicaCount > 1 { + feature := entitlements.Features[codersdk.FeatureHighAvailability] + + switch feature.Entitlement { + case codersdk.EntitlementNotEntitled: + if entitlements.HasLicense { + entitlements.Errors = append(entitlements.Errors, + "You have multiple replicas but your license is not entitled to high availability. You will be unable to connect to workspaces.") + } else { + entitlements.Errors = append(entitlements.Errors, + "You have multiple replicas but high availability is an Enterprise feature. You will be unable to connect to workspaces.") + } + case codersdk.EntitlementGracePeriod: + entitlements.Warnings = append(entitlements.Warnings, + "You have multiple replicas but your license for high availability is expired.
Reduce to one replica or workspace connections will stop working.") + } + } + + for _, featureName := range codersdk.FeatureNames { + feature := entitlements.Features[featureName] + if feature.Entitlement == codersdk.EntitlementNotEntitled { + feature.Enabled = false + entitlements.Features[featureName] = feature + } + } + return entitlements, nil } @@ -171,12 +212,13 @@ var ( ) type Features struct { - UserLimit int64 `json:"user_limit"` - AuditLog int64 `json:"audit_log"` - BrowserOnly int64 `json:"browser_only"` - SCIM int64 `json:"scim"` - WorkspaceQuota int64 `json:"workspace_quota"` - TemplateRBAC int64 `json:"template_rbac"` + UserLimit int64 `json:"user_limit"` + AuditLog int64 `json:"audit_log"` + BrowserOnly int64 `json:"browser_only"` + SCIM int64 `json:"scim"` + WorkspaceQuota int64 `json:"workspace_quota"` + TemplateRBAC int64 `json:"template_rbac"` + HighAvailability int64 `json:"high_availability"` } type Claims struct { diff --git a/enterprise/coderd/license/license_test.go b/enterprise/coderd/license/license_test.go index 8f15c5c009..6def291e3e 100644 --- a/enterprise/coderd/license/license_test.go +++ b/enterprise/coderd/license/license_test.go @@ -20,17 +20,18 @@ import ( func TestEntitlements(t *testing.T) { t.Parallel() all := map[string]bool{ - codersdk.FeatureAuditLog: true, - codersdk.FeatureBrowserOnly: true, - codersdk.FeatureSCIM: true, - codersdk.FeatureWorkspaceQuota: true, - codersdk.FeatureTemplateRBAC: true, + codersdk.FeatureAuditLog: true, + codersdk.FeatureBrowserOnly: true, + codersdk.FeatureSCIM: true, + codersdk.FeatureWorkspaceQuota: true, + codersdk.FeatureHighAvailability: true, + codersdk.FeatureTemplateRBAC: true, } t.Run("Defaults", func(t *testing.T) { t.Parallel() db := databasefake.New() - entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{}) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all) require.NoError(t, err) require.False(t, entitlements.HasLicense) require.False(t, entitlements.Trial) @@ -46,7 +47,7 @@ func TestEntitlements(t *testing.T) { JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{}), Exp: time.Now().Add(time.Hour), }) - entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{}) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{}) require.NoError(t, err) require.True(t, entitlements.HasLicense) require.False(t, entitlements.Trial) @@ -60,16 +61,17 @@ func TestEntitlements(t *testing.T) { db := databasefake.New() db.InsertLicense(context.Background(), database.InsertLicenseParams{ JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ - UserLimit: 100, - AuditLog: true, - BrowserOnly: true, - SCIM: true, - WorkspaceQuota: true, - TemplateRBACEnabled: true, + UserLimit: 100, + AuditLog: true, + BrowserOnly: true, + SCIM: true, + WorkspaceQuota: true, + HighAvailability: true, + TemplateRBAC: true, }), Exp: time.Now().Add(time.Hour), }) - entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{}) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{}) require.NoError(t, err) require.True(t, entitlements.HasLicense) require.False(t, entitlements.Trial) @@ -82,18 +84,19 @@ func TestEntitlements(t *testing.T) 
{ db := databasefake.New() db.InsertLicense(context.Background(), database.InsertLicenseParams{ JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ - UserLimit: 100, - AuditLog: true, - BrowserOnly: true, - SCIM: true, - WorkspaceQuota: true, - TemplateRBACEnabled: true, - GraceAt: time.Now().Add(-time.Hour), - ExpiresAt: time.Now().Add(time.Hour), + UserLimit: 100, + AuditLog: true, + BrowserOnly: true, + SCIM: true, + WorkspaceQuota: true, + HighAvailability: true, + TemplateRBAC: true, + GraceAt: time.Now().Add(-time.Hour), + ExpiresAt: time.Now().Add(time.Hour), }), Exp: time.Now().Add(time.Hour), }) - entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all) require.NoError(t, err) require.True(t, entitlements.HasLicense) require.False(t, entitlements.Trial) @@ -101,6 +104,9 @@ func TestEntitlements(t *testing.T) { if featureName == codersdk.FeatureUserLimit { continue } + if featureName == codersdk.FeatureHighAvailability { + continue + } niceName := strings.Title(strings.ReplaceAll(featureName, "_", " ")) require.Equal(t, codersdk.EntitlementGracePeriod, entitlements.Features[featureName].Entitlement) require.Contains(t, entitlements.Warnings, fmt.Sprintf("%s is enabled but your license for this feature is expired.", niceName)) @@ -113,7 +119,7 @@ func TestEntitlements(t *testing.T) { JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{}), Exp: time.Now().Add(time.Hour), }) - entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all) require.NoError(t, err) require.True(t, entitlements.HasLicense) require.False(t, entitlements.Trial) @@ -121,6 +127,9 @@ func TestEntitlements(t *testing.T) { if featureName == codersdk.FeatureUserLimit { continue } + if featureName == codersdk.FeatureHighAvailability { + continue + } niceName := strings.Title(strings.ReplaceAll(featureName, "_", " ")) // Ensures features that are not entitled are properly disabled. 
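The grace-period and not-entitled subtests here, plus the replica subtests added at the end of this file, reduce to a small decision table: with more than one replica, a missing HA entitlement is an error (connections will break), a grace period is a warning, and a valid entitlement is silent. A hypothetical table-driven condensation; the real tests construct licenses via coderdenttest.GenerateLicense instead:

```go
package license_test

// replicaCases is illustrative only: one row per outcome of the
// replicaCount check in license.Entitlements.
var replicaCases = []struct {
	name         string
	replicaCount int
	entitlement  string // state of the high_availability feature
	wantErrors   int    // errors block workspace connections
	wantWarnings int
}{
	{"SingleReplica", 1, "not_entitled", 0, 0},
	{"MultipleNotEntitled", 2, "not_entitled", 1, 0},
	{"MultipleGracePeriod", 2, "grace_period", 0, 1},
	{"MultipleEntitled", 2, "entitled", 0, 0},
}
```

The assertion that follows belongs to the not-entitled subtest above.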
require.False(t, entitlements.Features[featureName].Enabled) @@ -139,7 +148,7 @@ func TestEntitlements(t *testing.T) { }), Exp: time.Now().Add(time.Hour), }) - entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{}) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{}) require.NoError(t, err) require.True(t, entitlements.HasLicense) require.Contains(t, entitlements.Warnings, "Your deployment has 2 active users but is only licensed for 1.") @@ -161,7 +170,7 @@ func TestEntitlements(t *testing.T) { }), Exp: time.Now().Add(time.Hour), }) - entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{}) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{}) require.NoError(t, err) require.True(t, entitlements.HasLicense) require.Empty(t, entitlements.Warnings) @@ -184,7 +193,7 @@ func TestEntitlements(t *testing.T) { }), }) - entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{}) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{}) require.NoError(t, err) require.True(t, entitlements.HasLicense) require.False(t, entitlements.Trial) @@ -199,7 +208,7 @@ func TestEntitlements(t *testing.T) { AllFeatures: true, }), }) - entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all) require.NoError(t, err) require.True(t, entitlements.HasLicense) require.False(t, entitlements.Trial) @@ -211,4 +220,52 @@ func TestEntitlements(t *testing.T) { require.Equal(t, codersdk.EntitlementEntitled, entitlements.Features[featureName].Entitlement) } }) + + t.Run("MultipleReplicasNoLicense", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 2, coderdenttest.Keys, all) + require.NoError(t, err) + require.False(t, entitlements.HasLicense) + require.Len(t, entitlements.Errors, 1) + require.Equal(t, "You have multiple replicas but high availability is an Enterprise feature. You will be unable to connect to workspaces.", entitlements.Errors[0]) + }) + + t.Run("MultipleReplicasNotEntitled", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + db.InsertLicense(context.Background(), database.InsertLicenseParams{ + Exp: time.Now().Add(time.Hour), + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + AuditLog: true, + }), + }) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 2, coderdenttest.Keys, map[string]bool{ + codersdk.FeatureHighAvailability: true, + }) + require.NoError(t, err) + require.True(t, entitlements.HasLicense) + require.Len(t, entitlements.Errors, 1) + require.Equal(t, "You have multiple replicas but your license is not entitled to high availability. 
You will be unable to connect to workspaces.", entitlements.Errors[0]) + }) + + t.Run("MultipleReplicasGrace", func(t *testing.T) { + t.Parallel() + db := databasefake.New() + db.InsertLicense(context.Background(), database.InsertLicenseParams{ + JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ + HighAvailability: true, + GraceAt: time.Now().Add(-time.Hour), + ExpiresAt: time.Now().Add(time.Hour), + }), + Exp: time.Now().Add(time.Hour), + }) + entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 2, coderdenttest.Keys, map[string]bool{ + codersdk.FeatureHighAvailability: true, + }) + require.NoError(t, err) + require.True(t, entitlements.HasLicense) + require.Len(t, entitlements.Warnings, 1) + require.Equal(t, "You have multiple replicas but your license for high availability is expired. Reduce to one replica or workspace connections will stop working.", entitlements.Warnings[0]) + }) } diff --git a/enterprise/coderd/licenses_test.go b/enterprise/coderd/licenses_test.go index f7c1c63999..aa4dddf1fd 100644 --- a/enterprise/coderd/licenses_test.go +++ b/enterprise/coderd/licenses_test.go @@ -78,21 +78,21 @@ func TestGetLicense(t *testing.T) { defer cancel() coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - AccountID: "testing", - AuditLog: true, - SCIM: true, - BrowserOnly: true, - TemplateRBACEnabled: true, + AccountID: "testing", + AuditLog: true, + SCIM: true, + BrowserOnly: true, + TemplateRBAC: true, }) coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - AccountID: "testing2", - AuditLog: true, - SCIM: true, - BrowserOnly: true, - Trial: true, - UserLimit: 200, - TemplateRBACEnabled: false, + AccountID: "testing2", + AuditLog: true, + SCIM: true, + BrowserOnly: true, + Trial: true, + UserLimit: 200, + TemplateRBAC: false, }) licenses, err := client.Licenses(ctx) @@ -101,23 +101,25 @@ func TestGetLicense(t *testing.T) { assert.Equal(t, int32(1), licenses[0].ID) assert.Equal(t, "testing", licenses[0].Claims["account_id"]) assert.Equal(t, map[string]interface{}{ - codersdk.FeatureUserLimit: json.Number("0"), - codersdk.FeatureAuditLog: json.Number("1"), - codersdk.FeatureSCIM: json.Number("1"), - codersdk.FeatureBrowserOnly: json.Number("1"), - codersdk.FeatureWorkspaceQuota: json.Number("0"), - codersdk.FeatureTemplateRBAC: json.Number("1"), + codersdk.FeatureUserLimit: json.Number("0"), + codersdk.FeatureAuditLog: json.Number("1"), + codersdk.FeatureSCIM: json.Number("1"), + codersdk.FeatureBrowserOnly: json.Number("1"), + codersdk.FeatureWorkspaceQuota: json.Number("0"), + codersdk.FeatureHighAvailability: json.Number("0"), + codersdk.FeatureTemplateRBAC: json.Number("1"), }, licenses[0].Claims["features"]) assert.Equal(t, int32(2), licenses[1].ID) assert.Equal(t, "testing2", licenses[1].Claims["account_id"]) assert.Equal(t, true, licenses[1].Claims["trial"]) assert.Equal(t, map[string]interface{}{ - codersdk.FeatureUserLimit: json.Number("200"), - codersdk.FeatureAuditLog: json.Number("1"), - codersdk.FeatureSCIM: json.Number("1"), - codersdk.FeatureBrowserOnly: json.Number("1"), - codersdk.FeatureWorkspaceQuota: json.Number("0"), - codersdk.FeatureTemplateRBAC: json.Number("0"), + codersdk.FeatureUserLimit: json.Number("200"), + codersdk.FeatureAuditLog: json.Number("1"), + codersdk.FeatureSCIM: json.Number("1"), + codersdk.FeatureBrowserOnly: json.Number("1"), + codersdk.FeatureWorkspaceQuota: json.Number("0"), + codersdk.FeatureHighAvailability: json.Number("0"), + codersdk.FeatureTemplateRBAC: 
json.Number("0"), }, licenses[1].Claims["features"]) }) } diff --git a/enterprise/coderd/replicas.go b/enterprise/coderd/replicas.go new file mode 100644 index 0000000000..906597f257 --- /dev/null +++ b/enterprise/coderd/replicas.go @@ -0,0 +1,37 @@ +package coderd + +import ( + "net/http" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/codersdk" +) + +// replicas returns the number of replicas that are active in Coder. +func (api *API) replicas(rw http.ResponseWriter, r *http.Request) { + if !api.AGPL.Authorize(r, rbac.ActionRead, rbac.ResourceReplicas) { + httpapi.ResourceNotFound(rw) + return + } + + replicas := api.replicaManager.All() + res := make([]codersdk.Replica, 0, len(replicas)) + for _, replica := range replicas { + res = append(res, convertReplica(replica)) + } + httpapi.Write(r.Context(), rw, http.StatusOK, res) +} + +func convertReplica(replica database.Replica) codersdk.Replica { + return codersdk.Replica{ + ID: replica.ID, + Hostname: replica.Hostname, + CreatedAt: replica.CreatedAt, + RelayAddress: replica.RelayAddress, + RegionID: replica.RegionID, + Error: replica.Error, + DatabaseLatency: replica.DatabaseLatency, + } +} diff --git a/enterprise/coderd/replicas_test.go b/enterprise/coderd/replicas_test.go new file mode 100644 index 0000000000..7a3e130cf7 --- /dev/null +++ b/enterprise/coderd/replicas_test.go @@ -0,0 +1,138 @@ +package coderd_test + +import ( + "context" + "crypto/tls" + "testing" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/database/dbtestutil" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/coderd/coderdenttest" + "github.com/coder/coder/testutil" +) + +func TestReplicas(t *testing.T) { + t.Parallel() + t.Run("ErrorWithoutLicense", func(t *testing.T) { + t.Parallel() + db, pubsub := dbtestutil.NewDB(t) + firstClient := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + IncludeProvisionerDaemon: true, + Database: db, + Pubsub: pubsub, + }, + }) + _ = coderdtest.CreateFirstUser(t, firstClient) + secondClient, _, secondAPI := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pubsub, + }, + }) + secondClient.SessionToken = firstClient.SessionToken + ents, err := secondClient.Entitlements(context.Background()) + require.NoError(t, err) + require.Len(t, ents.Errors, 1) + _ = secondAPI.Close() + + ents, err = firstClient.Entitlements(context.Background()) + require.NoError(t, err) + require.Len(t, ents.Warnings, 0) + }) + t.Run("ConnectAcrossMultiple", func(t *testing.T) { + t.Parallel() + db, pubsub := dbtestutil.NewDB(t) + firstClient := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + IncludeProvisionerDaemon: true, + Database: db, + Pubsub: pubsub, + }, + }) + firstUser := coderdtest.CreateFirstUser(t, firstClient) + coderdenttest.AddLicense(t, firstClient, coderdenttest.LicenseOptions{ + HighAvailability: true, + }) + + secondClient := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pubsub, + }, + }) + secondClient.SessionToken = firstClient.SessionToken + replicas, err := secondClient.Replicas(context.Background()) + require.NoError(t, err) + require.Len(t, replicas, 2) + + _, agent := setupWorkspaceAgent(t, firstClient, 
firstUser, 0) + conn, err := secondClient.DialWorkspaceAgent(context.Background(), agent.ID, &codersdk.DialWorkspaceAgentOptions{ + BlockEndpoints: true, + Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), + }) + require.NoError(t, err) + require.Eventually(t, func() bool { + ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancelFunc() + _, err = conn.Ping(ctx) + return err == nil + }, testutil.WaitLong, testutil.IntervalFast) + _ = conn.Close() + }) + t.Run("ConnectAcrossMultipleTLS", func(t *testing.T) { + t.Parallel() + db, pubsub := dbtestutil.NewDB(t) + certificates := []tls.Certificate{testutil.GenerateTLSCertificate(t, "localhost")} + firstClient := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + IncludeProvisionerDaemon: true, + Database: db, + Pubsub: pubsub, + TLSCertificates: certificates, + }, + }) + firstUser := coderdtest.CreateFirstUser(t, firstClient) + coderdenttest.AddLicense(t, firstClient, coderdenttest.LicenseOptions{ + HighAvailability: true, + }) + + secondClient := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pubsub, + TLSCertificates: certificates, + }, + }) + secondClient.SessionToken = firstClient.SessionToken + replicas, err := secondClient.Replicas(context.Background()) + require.NoError(t, err) + require.Len(t, replicas, 2) + + _, agent := setupWorkspaceAgent(t, firstClient, firstUser, 0) + conn, err := secondClient.DialWorkspaceAgent(context.Background(), agent.ID, &codersdk.DialWorkspaceAgentOptions{ + BlockEndpoints: true, + Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), + }) + require.NoError(t, err) + require.Eventually(t, func() bool { + ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.IntervalSlow) + defer cancelFunc() + _, err = conn.Ping(ctx) + return err == nil + }, testutil.WaitLong, testutil.IntervalFast) + _ = conn.Close() + replicas, err = secondClient.Replicas(context.Background()) + require.NoError(t, err) + require.Len(t, replicas, 2) + for _, replica := range replicas { + require.Empty(t, replica.Error) + } + }) +} diff --git a/enterprise/coderd/templates_test.go b/enterprise/coderd/templates_test.go index fe6dd6f687..87aa5a4ca8 100644 --- a/enterprise/coderd/templates_test.go +++ b/enterprise/coderd/templates_test.go @@ -23,7 +23,7 @@ func TestTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -64,7 +64,7 @@ func TestTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -88,7 +88,7 @@ func TestTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) client1, _ := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -138,7 +138,7 @@ func TestTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ 
= coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -176,7 +176,7 @@ func TestTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -214,7 +214,7 @@ func TestTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) @@ -262,7 +262,7 @@ func TestTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) client1, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -318,7 +318,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -361,7 +361,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -422,7 +422,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) @@ -447,7 +447,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) @@ -472,7 +472,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) _, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -498,7 +498,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) client2, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -533,7 +533,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, 
coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) client2, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -575,7 +575,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) @@ -597,7 +597,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) client1, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) @@ -662,7 +662,7 @@ func TestUpdateTemplateACL(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBACEnabled: true, + TemplateRBAC: true, }) client1, _ := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID) diff --git a/enterprise/coderd/workspaceagents_test.go b/enterprise/coderd/workspaceagents_test.go index 9fe3cfeaa3..18285bcb94 100644 --- a/enterprise/coderd/workspaceagents_test.go +++ b/enterprise/coderd/workspaceagents_test.go @@ -2,6 +2,7 @@ package coderd_test import ( "context" + "crypto/tls" "fmt" "net/http" "testing" @@ -9,7 +10,6 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" - "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/agent" "github.com/coder/coder/coderd/coderdtest" @@ -42,7 +42,7 @@ func TestBlockNonBrowser(t *testing.T) { BrowserOnly: true, }) _, agent := setupWorkspaceAgent(t, client, user, 0) - _, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, agent.ID) + _, err := client.DialWorkspaceAgent(context.Background(), agent.ID, nil) var apiErr *codersdk.Error require.ErrorAs(t, err, &apiErr) require.Equal(t, http.StatusConflict, apiErr.StatusCode()) @@ -59,7 +59,7 @@ func TestBlockNonBrowser(t *testing.T) { BrowserOnly: false, }) _, agent := setupWorkspaceAgent(t, client, user, 0) - conn, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, agent.ID) + conn, err := client.DialWorkspaceAgent(context.Background(), agent.ID, nil) require.NoError(t, err) _ = conn.Close() }) @@ -109,6 +109,14 @@ func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.Cr workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) agentClient := codersdk.New(client.URL) + agentClient.HTTPClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + //nolint:gosec + InsecureSkipVerify: true, + }, + }, + } agentClient.SessionToken = authToken agentCloser := agent.New(agent.Options{ FetchMetadata: agentClient.WorkspaceAgentMetadata, diff --git a/enterprise/coderd/workspaces_test.go b/enterprise/coderd/workspaces_test.go index 33984e970d..824b3febb1 100644 --- a/enterprise/coderd/workspaces_test.go +++ b/enterprise/coderd/workspaces_test.go @@ -26,7 +26,7 @@ func TestCreateWorkspace(t *testing.T) { client := coderdenttest.New(t, nil) user := coderdtest.CreateFirstUser(t, client) _ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - 
TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})
		version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
diff --git a/enterprise/derpmesh/derpmesh.go b/enterprise/derpmesh/derpmesh.go
new file mode 100644
index 0000000000..3982542167
--- /dev/null
+++ b/enterprise/derpmesh/derpmesh.go
@@ -0,0 +1,165 @@
+package derpmesh
+
+import (
+	"context"
+	"crypto/tls"
+	"net"
+	"net/url"
+	"sync"
+
+	"golang.org/x/xerrors"
+	"tailscale.com/derp"
+	"tailscale.com/derp/derphttp"
+	"tailscale.com/types/key"
+
+	"github.com/coder/coder/tailnet"
+
+	"cdr.dev/slog"
+)
+
+// New constructs a new mesh for DERP servers.
+func New(logger slog.Logger, server *derp.Server, tlsConfig *tls.Config) *Mesh {
+	return &Mesh{
+		logger:    logger,
+		server:    server,
+		tlsConfig: tlsConfig,
+		ctx:       context.Background(),
+		closed:    make(chan struct{}),
+		active:    make(map[string]context.CancelFunc),
+	}
+}
+
+type Mesh struct {
+	logger    slog.Logger
+	server    *derp.Server
+	ctx       context.Context
+	tlsConfig *tls.Config
+
+	mutex  sync.Mutex
+	closed chan struct{}
+	active map[string]context.CancelFunc
+}
+
+// SetAddresses performs a diff of the incoming addresses and adds
+// or removes DERP clients from the mesh.
+//
+// Connect is only used for testing to ensure DERPs are meshed before
+// exchanging messages.
+// nolint:revive
+func (m *Mesh) SetAddresses(addresses []string, connect bool) {
+	total := make(map[string]struct{}, len(addresses))
+	for _, address := range addresses {
+		addressURL, err := url.Parse(address)
+		if err != nil {
+			m.logger.Error(m.ctx, "invalid address", slog.F("address", address), slog.Error(err))
+			continue
+		}
+		derpURL, err := addressURL.Parse("/derp")
+		if err != nil {
+			m.logger.Error(m.ctx, "parse derp", slog.F("address", address), slog.Error(err))
+			continue
+		}
+		address = derpURL.String()
+
+		total[address] = struct{}{}
+		added, err := m.addAddress(address, connect)
+		if err != nil {
+			m.logger.Error(m.ctx, "failed to add address", slog.F("address", address), slog.Error(err))
+			continue
+		}
+		if added {
+			m.logger.Debug(m.ctx, "added mesh address", slog.F("address", address))
+		}
+	}
+
+	m.mutex.Lock()
+	for address := range m.active {
+		_, found := total[address]
+		if found {
+			continue
+		}
+		removed := m.removeAddress(address)
+		if removed {
+			m.logger.Debug(m.ctx, "removed mesh address", slog.F("address", address))
+		}
+	}
+	m.mutex.Unlock()
+}
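For orientation, a minimal sketch of how this mesh is meant to be driven: each replica runs its own embedded DERP server, and SetAddresses is fed the relay URLs of the other replicas whenever the peer set changes. The constructor calls mirror the ones used in the tests below; the helper name, URL list, and mesh key value are illustrative assumptions, not part of this patch.

    package example // illustrative sketch, not part of this patch

    import (
    	"crypto/tls"

    	"tailscale.com/derp"
    	"tailscale.com/types/key"

    	"cdr.dev/slog"
    	"github.com/coder/coder/enterprise/derpmesh"
    	"github.com/coder/coder/tailnet"
    )

    // meshWithPeers wires a local DERP server into a mesh and points it
    // at the current set of peer relay URLs (hypothetical helper).
    func meshWithPeers(logger slog.Logger, tlsConfig *tls.Config, peerURLs []string) *derpmesh.Mesh {
    	server := derp.NewServer(key.NewNode(), tailnet.Logger(logger.Named("derp")))
    	// Every replica must share the same mesh key for packet
    	// forwarding to be authorized between servers.
    	server.SetMeshKey("shared-mesh-key")
    	mesh := derpmesh.New(logger.Named("derpmesh"), server, tlsConfig)
    	// Addresses are diffed internally: new ones are dialed and
    	// missing ones are cancelled, so this is safe to call on every
    	// membership change.
    	mesh.SetAddresses(peerURLs, false)
    	return mesh
    }

The mesh key gate is why this change also stores one key per deployment in the database: every replica has to present the same key to forward packets for another server.

+
+// addAddress begins meshing with a new address. It returns false if the address is already being meshed with.
+// It's expected that this is a full HTTP address with a path.
+// e.g.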
http://127.0.0.1:8080/derp +// nolint:revive +func (m *Mesh) addAddress(address string, connect bool) (bool, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + if m.isClosed() { + return false, nil + } + _, isActive := m.active[address] + if isActive { + return false, nil + } + client, err := derphttp.NewClient(m.server.PrivateKey(), address, tailnet.Logger(m.logger.Named("client"))) + if err != nil { + return false, xerrors.Errorf("create derp client: %w", err) + } + client.TLSConfig = m.tlsConfig + client.MeshKey = m.server.MeshKey() + client.SetURLDialer(func(ctx context.Context, network, addr string) (net.Conn, error) { + var dialer net.Dialer + return dialer.DialContext(ctx, network, addr) + }) + if connect { + _ = client.Connect(m.ctx) + } + ctx, cancelFunc := context.WithCancel(m.ctx) + closed := make(chan struct{}) + closeFunc := func() { + cancelFunc() + _ = client.Close() + <-closed + } + m.active[address] = closeFunc + go func() { + defer close(closed) + client.RunWatchConnectionLoop(ctx, m.server.PublicKey(), tailnet.Logger(m.logger.Named("loop")), func(np key.NodePublic) { + m.server.AddPacketForwarder(np, client) + }, func(np key.NodePublic) { + m.server.RemovePacketForwarder(np, client) + }) + }() + return true, nil +} + +// removeAddress stops meshing with a given address. +func (m *Mesh) removeAddress(address string) bool { + cancelFunc, isActive := m.active[address] + if isActive { + cancelFunc() + } + return isActive +} + +// Close ends all active meshes with the DERP server. +func (m *Mesh) Close() error { + m.mutex.Lock() + defer m.mutex.Unlock() + if m.isClosed() { + return nil + } + close(m.closed) + for _, cancelFunc := range m.active { + cancelFunc() + } + return nil +} + +func (m *Mesh) isClosed() bool { + select { + case <-m.closed: + return true + default: + } + return false +} diff --git a/enterprise/derpmesh/derpmesh_test.go b/enterprise/derpmesh/derpmesh_test.go new file mode 100644 index 0000000000..7fad141238 --- /dev/null +++ b/enterprise/derpmesh/derpmesh_test.go @@ -0,0 +1,219 @@ +package derpmesh_test + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "io" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "tailscale.com/derp" + "tailscale.com/derp/derphttp" + "tailscale.com/types/key" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/enterprise/derpmesh" + "github.com/coder/coder/tailnet" + "github.com/coder/coder/testutil" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestDERPMesh(t *testing.T) { + t.Parallel() + commonName := "something.org" + rawCert := testutil.GenerateTLSCertificate(t, commonName) + certificate, err := x509.ParseCertificate(rawCert.Certificate[0]) + require.NoError(t, err) + pool := x509.NewCertPool() + pool.AddCert(certificate) + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: commonName, + RootCAs: pool, + Certificates: []tls.Certificate{rawCert}, + } + + t.Run("ExchangeMessages", func(t *testing.T) { + // This tests messages passing through multiple DERP servers. 
+		t.Parallel()
+		firstServer, firstServerURL := startDERP(t, tlsConfig)
+		defer firstServer.Close()
+		secondServer, secondServerURL := startDERP(t, tlsConfig)
+		firstMesh := derpmesh.New(slogtest.Make(t, nil).Named("first").Leveled(slog.LevelDebug), firstServer, tlsConfig)
+		firstMesh.SetAddresses([]string{secondServerURL}, true)
+		secondMesh := derpmesh.New(slogtest.Make(t, nil).Named("second").Leveled(slog.LevelDebug), secondServer, tlsConfig)
+		secondMesh.SetAddresses([]string{firstServerURL}, true)
+		defer firstMesh.Close()
+		defer secondMesh.Close()
+
+		first := key.NewNode()
+		second := key.NewNode()
+		firstClient, err := derphttp.NewClient(first, secondServerURL, tailnet.Logger(slogtest.Make(t, nil)))
+		require.NoError(t, err)
+		firstClient.TLSConfig = tlsConfig
+		secondClient, err := derphttp.NewClient(second, firstServerURL, tailnet.Logger(slogtest.Make(t, nil)))
+		require.NoError(t, err)
+		secondClient.TLSConfig = tlsConfig
+		err = secondClient.Connect(context.Background())
+		require.NoError(t, err)
+
+		closed := make(chan struct{})
+		ctx, cancelFunc := context.WithCancel(context.Background())
+		defer cancelFunc()
+		sent := []byte("hello world")
+		go func() {
+			defer close(closed)
+			ticker := time.NewTicker(50 * time.Millisecond)
+			for {
+				select {
+				case <-ctx.Done():
+					return
+				case <-ticker.C:
+				}
+				err = firstClient.Send(second.Public(), sent)
+				require.NoError(t, err)
+			}
+		}()
+
+		got := recvData(t, secondClient)
+		require.Equal(t, sent, got)
+		cancelFunc()
+		<-closed
+	})
+	t.Run("RemoveAddress", func(t *testing.T) {
+		// This tests that a meshed address is removed when it no
+		// longer appears in the set passed to SetAddresses.
+		t.Parallel()
+		server, serverURL := startDERP(t, tlsConfig)
+		mesh := derpmesh.New(slogtest.Make(t, nil).Named("first").Leveled(slog.LevelDebug), server, tlsConfig)
+		mesh.SetAddresses([]string{"http://fake.com"}, false)
+		// This should trigger a removal...
+ mesh.SetAddresses([]string{}, false) + defer mesh.Close() + + first := key.NewNode() + second := key.NewNode() + firstClient, err := derphttp.NewClient(first, serverURL, tailnet.Logger(slogtest.Make(t, nil))) + require.NoError(t, err) + firstClient.TLSConfig = tlsConfig + secondClient, err := derphttp.NewClient(second, serverURL, tailnet.Logger(slogtest.Make(t, nil))) + require.NoError(t, err) + secondClient.TLSConfig = tlsConfig + err = secondClient.Connect(context.Background()) + require.NoError(t, err) + + closed := make(chan struct{}) + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + sent := []byte("hello world") + go func() { + defer close(closed) + ticker := time.NewTicker(50 * time.Millisecond) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + err = firstClient.Send(second.Public(), sent) + require.NoError(t, err) + } + }() + got := recvData(t, secondClient) + require.Equal(t, sent, got) + cancelFunc() + <-closed + }) + t.Run("TwentyMeshes", func(t *testing.T) { + t.Parallel() + meshes := make([]*derpmesh.Mesh, 0, 20) + serverURLs := make([]string, 0, 20) + for i := 0; i < 20; i++ { + server, url := startDERP(t, tlsConfig) + mesh := derpmesh.New(slogtest.Make(t, nil).Named("mesh").Leveled(slog.LevelDebug), server, tlsConfig) + t.Cleanup(func() { + _ = server.Close() + _ = mesh.Close() + }) + serverURLs = append(serverURLs, url) + meshes = append(meshes, mesh) + } + for _, mesh := range meshes { + mesh.SetAddresses(serverURLs, true) + } + + first := key.NewNode() + second := key.NewNode() + firstClient, err := derphttp.NewClient(first, serverURLs[9], tailnet.Logger(slogtest.Make(t, nil))) + require.NoError(t, err) + firstClient.TLSConfig = tlsConfig + secondClient, err := derphttp.NewClient(second, serverURLs[16], tailnet.Logger(slogtest.Make(t, nil))) + require.NoError(t, err) + secondClient.TLSConfig = tlsConfig + err = secondClient.Connect(context.Background()) + require.NoError(t, err) + + closed := make(chan struct{}) + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + sent := []byte("hello world") + go func() { + defer close(closed) + ticker := time.NewTicker(50 * time.Millisecond) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + err = firstClient.Send(second.Public(), sent) + require.NoError(t, err) + } + }() + + got := recvData(t, secondClient) + require.Equal(t, sent, got) + cancelFunc() + <-closed + }) +} + +func recvData(t *testing.T, client *derphttp.Client) []byte { + for { + msg, err := client.Recv() + if errors.Is(err, io.EOF) { + return nil + } + assert.NoError(t, err) + t.Logf("derp: %T", msg) + switch msg := msg.(type) { + case derp.ReceivedPacket: + return msg.Data + default: + // Drop all others! 
+ } + } +} + +func startDERP(t *testing.T, tlsConfig *tls.Config) (*derp.Server, string) { + logf := tailnet.Logger(slogtest.Make(t, nil)) + d := derp.NewServer(key.NewNode(), logf) + d.SetMeshKey("some-key") + server := httptest.NewUnstartedServer(derphttp.Handler(d)) + server.TLS = tlsConfig + server.StartTLS() + t.Cleanup(func() { + _ = d.Close() + }) + t.Cleanup(server.Close) + return d, server.URL +} diff --git a/enterprise/replicasync/replicasync.go b/enterprise/replicasync/replicasync.go new file mode 100644 index 0000000000..0534c55246 --- /dev/null +++ b/enterprise/replicasync/replicasync.go @@ -0,0 +1,391 @@ +package replicasync + +import ( + "context" + "crypto/tls" + "database/sql" + "errors" + "fmt" + "net/http" + "os" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/buildinfo" + "github.com/coder/coder/coderd/database" +) + +var ( + PubsubEvent = "replica" +) + +type Options struct { + CleanupInterval time.Duration + UpdateInterval time.Duration + PeerTimeout time.Duration + RelayAddress string + RegionID int32 + TLSConfig *tls.Config +} + +// New registers the replica with the database and periodically updates to ensure +// it's healthy. It contacts all other alive replicas to ensure they are reachable. +func New(ctx context.Context, logger slog.Logger, db database.Store, pubsub database.Pubsub, options *Options) (*Manager, error) { + if options == nil { + options = &Options{} + } + if options.PeerTimeout == 0 { + options.PeerTimeout = 3 * time.Second + } + if options.UpdateInterval == 0 { + options.UpdateInterval = 5 * time.Second + } + if options.CleanupInterval == 0 { + // The cleanup interval can be quite long, because it's + // primary purpose is to clean up dead replicas. + options.CleanupInterval = 30 * time.Minute + } + hostname, err := os.Hostname() + if err != nil { + return nil, xerrors.Errorf("get hostname: %w", err) + } + databaseLatency, err := db.Ping(ctx) + if err != nil { + return nil, xerrors.Errorf("ping database: %w", err) + } + id := uuid.New() + replica, err := db.InsertReplica(ctx, database.InsertReplicaParams{ + ID: id, + CreatedAt: database.Now(), + StartedAt: database.Now(), + UpdatedAt: database.Now(), + Hostname: hostname, + RegionID: options.RegionID, + RelayAddress: options.RelayAddress, + Version: buildinfo.Version(), + DatabaseLatency: int32(databaseLatency.Microseconds()), + }) + if err != nil { + return nil, xerrors.Errorf("insert replica: %w", err) + } + err = pubsub.Publish(PubsubEvent, []byte(id.String())) + if err != nil { + return nil, xerrors.Errorf("publish new replica: %w", err) + } + ctx, cancelFunc := context.WithCancel(ctx) + manager := &Manager{ + id: id, + options: options, + db: db, + pubsub: pubsub, + self: replica, + logger: logger, + closed: make(chan struct{}), + closeCancel: cancelFunc, + } + err = manager.syncReplicas(ctx) + if err != nil { + return nil, xerrors.Errorf("run replica: %w", err) + } + peers := manager.Regional() + if len(peers) > 0 { + self := manager.Self() + if self.RelayAddress == "" { + return nil, xerrors.Errorf("a relay address must be specified when running multiple replicas in the same region") + } + } + + err = manager.subscribe(ctx) + if err != nil { + return nil, xerrors.Errorf("subscribe: %w", err) + } + manager.closeWait.Add(1) + go manager.loop(ctx) + return manager, nil +} + +// Manager keeps the replica up to date and in sync with other replicas. 
+type Manager struct {
+	id      uuid.UUID
+	options *Options
+	db      database.Store
+	pubsub  database.Pubsub
+	logger  slog.Logger
+
+	closeWait   sync.WaitGroup
+	closeMutex  sync.Mutex
+	closed      chan struct{}
+	closeCancel context.CancelFunc
+
+	self     database.Replica
+	mutex    sync.Mutex
+	peers    []database.Replica
+	callback func()
+}
+
+// updateInterval returns the cutoff used to determine a replica's state.
+// Replicas updated after this time are considered healthy; replicas
+// updated before it are considered stale.
+func (m *Manager) updateInterval() time.Time {
+	return database.Now().Add(-3 * m.options.UpdateInterval)
+}
+
+// loop runs the replica update sequence on an update interval.
+func (m *Manager) loop(ctx context.Context) {
+	defer m.closeWait.Done()
+	updateTicker := time.NewTicker(m.options.UpdateInterval)
+	defer updateTicker.Stop()
+	deleteTicker := time.NewTicker(m.options.CleanupInterval)
+	defer deleteTicker.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-deleteTicker.C:
+			err := m.db.DeleteReplicasUpdatedBefore(ctx, m.updateInterval())
+			if err != nil {
+				m.logger.Warn(ctx, "delete old replicas", slog.Error(err))
+			}
+			continue
+		case <-updateTicker.C:
+		}
+		err := m.syncReplicas(ctx)
+		if err != nil && !errors.Is(err, context.Canceled) {
+			m.logger.Warn(ctx, "run replica update loop", slog.Error(err))
+		}
+	}
+}
+
+// subscribe listens for new replica information!
+func (m *Manager) subscribe(ctx context.Context) error {
+	var (
+		needsUpdate = false
+		updating    = false
+		updateMutex = sync.Mutex{}
+	)
+
+	// This loop will continually update nodes as updates are processed.
+	// The intent is to always be up to date without spamming the run
+	// function, so if a new update comes in while one is being processed,
+	// it will reprocess afterwards.
+	var update func()
+	update = func() {
+		err := m.syncReplicas(ctx)
+		if err != nil && !errors.Is(err, context.Canceled) {
+			m.logger.Warn(ctx, "run replica from subscribe", slog.Error(err))
+		}
+		updateMutex.Lock()
+		if needsUpdate {
+			needsUpdate = false
+			updateMutex.Unlock()
+			update()
+			return
+		}
+		updating = false
+		updateMutex.Unlock()
+	}
+	cancelFunc, err := m.pubsub.Subscribe(PubsubEvent, func(ctx context.Context, message []byte) {
+		updateMutex.Lock()
+		defer updateMutex.Unlock()
+		id, err := uuid.Parse(string(message))
+		if err != nil {
+			return
+		}
+		// Don't process updates for ourself!
+		if id == m.id {
+			return
+		}
+		if updating {
+			needsUpdate = true
+			return
+		}
+		updating = true
+		go update()
+	})
+	if err != nil {
+		return err
+	}
+	go func() {
+		<-ctx.Done()
+		cancelFunc()
+	}()
+	return nil
+}
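The needsUpdate/updating pair above is a coalescing guard: while one sync is in flight, any number of pubsub notifications collapse into exactly one follow-up sync. The same pattern in isolation, as an illustrative sketch rather than code from this patch:

    package example // illustrative sketch, not part of this patch

    import "sync"

    // Coalescer runs fn at most once at a time. A Trigger that fires
    // during a run schedules exactly one follow-up run, no matter how
    // many times it fires.
    type Coalescer struct {
    	mu      sync.Mutex
    	running bool
    	pending bool
    	fn      func()
    }

    func (c *Coalescer) Trigger() {
    	c.mu.Lock()
    	if c.running {
    		// A run is in flight; remember to run once more.
    		c.pending = true
    		c.mu.Unlock()
    		return
    	}
    	c.running = true
    	c.mu.Unlock()
    	go c.run()
    }

    func (c *Coalescer) run() {
    	for {
    		c.fn()
    		c.mu.Lock()
    		if !c.pending {
    			c.running = false
    			c.mu.Unlock()
    			return
    		}
    		c.pending = false
    		c.mu.Unlock()
    	}
    }

This keeps a burst of replica announcements from queuing one sync per message while still guaranteeing that the state observed after the last message is eventually fetched.

+
+func (m *Manager) syncReplicas(ctx context.Context) error {
+	m.closeMutex.Lock()
+	m.closeWait.Add(1)
+	m.closeMutex.Unlock()
+	defer m.closeWait.Done()
+	// Expect replicas to update at least once every three update
+	// intervals. If they don't, assume they're dead.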
+	replicas, err := m.db.GetReplicasUpdatedAfter(ctx, m.updateInterval())
+	if err != nil {
+		return xerrors.Errorf("get replicas: %w", err)
+	}
+
+	m.mutex.Lock()
+	m.peers = make([]database.Replica, 0, len(replicas))
+	for _, replica := range replicas {
+		if replica.ID == m.id {
+			continue
+		}
+		m.peers = append(m.peers, replica)
+	}
+	m.mutex.Unlock()
+
+	client := http.Client{
+		Timeout: m.options.PeerTimeout,
+		Transport: &http.Transport{
+			TLSClientConfig: m.options.TLSConfig,
+		},
+	}
+	defer client.CloseIdleConnections()
+	var wg sync.WaitGroup
+	var mu sync.Mutex
+	failed := make([]string, 0)
+	for _, peer := range m.Regional() {
+		wg.Add(1)
+		go func(peer database.Replica) {
+			defer wg.Done()
+			req, err := http.NewRequestWithContext(ctx, http.MethodGet, peer.RelayAddress, nil)
+			if err != nil {
+				m.logger.Warn(ctx, "create http request for relay probe",
+					slog.F("relay_address", peer.RelayAddress), slog.Error(err))
+				return
+			}
+			res, err := client.Do(req)
+			if err != nil {
+				mu.Lock()
+				failed = append(failed, fmt.Sprintf("relay %s (%s): %s", peer.Hostname, peer.RelayAddress, err))
+				mu.Unlock()
+				return
+			}
+			_ = res.Body.Close()
+		}(peer)
+	}
+	wg.Wait()
+	replicaError := ""
+	if len(failed) > 0 {
+		replicaError = fmt.Sprintf("Failed to dial peers: %s", strings.Join(failed, ", "))
+	}
+
+	databaseLatency, err := m.db.Ping(ctx)
+	if err != nil {
+		return xerrors.Errorf("ping database: %w", err)
+	}
+
+	replica, err := m.db.UpdateReplica(ctx, database.UpdateReplicaParams{
+		ID:              m.self.ID,
+		UpdatedAt:       database.Now(),
+		StartedAt:       m.self.StartedAt,
+		StoppedAt:       m.self.StoppedAt,
+		RelayAddress:    m.self.RelayAddress,
+		RegionID:        m.self.RegionID,
+		Hostname:        m.self.Hostname,
+		Version:         m.self.Version,
+		Error:           replicaError,
+		DatabaseLatency: int32(databaseLatency.Microseconds()),
+	})
+	if err != nil {
+		return xerrors.Errorf("update replica: %w", err)
+	}
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+	if m.self.Error != replica.Error {
+		// Publish an update occurred!
+		err = m.pubsub.Publish(PubsubEvent, []byte(m.self.ID.String()))
+		if err != nil {
+			return xerrors.Errorf("publish replica update: %w", err)
+		}
+	}
+	m.self = replica
+	if m.callback != nil {
+		go m.callback()
+	}
+	return nil
+}
+
+// Self represents the current replica.
+func (m *Manager) Self() database.Replica {
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+	return m.self
+}
+
+// All returns every replica, including itself.
+func (m *Manager) All() []database.Replica {
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+	// Copy the peers so the caller can't mutate or alias the
+	// internal slice.
+	replicas := make([]database.Replica, 0, len(m.peers)+1)
+	replicas = append(replicas, m.peers...)
+	return append(replicas, m.self)
+}
+
+// Regional returns all replicas in the same region excluding itself.
+func (m *Manager) Regional() []database.Replica {
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+	replicas := make([]database.Replica, 0)
+	for _, replica := range m.peers {
+		if replica.RegionID != m.self.RegionID {
+			continue
+		}
+		replicas = append(replicas, replica)
+	}
+	return replicas
+}
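Elsewhere in this change, the callback is what keeps the DERP mesh tracking the replica table. The wiring below is a sketch of that integration under the assumption that both managers are constructed at boot; it is not quoted from this diff:

    package example // illustrative sketch, not part of this patch

    import (
    	"github.com/coder/coder/enterprise/derpmesh"
    	"github.com/coder/coder/enterprise/replicasync"
    )

    // keepMeshCurrent re-derives mesh membership from the set of healthy
    // same-region replicas whenever replicasync refreshes its peer list.
    func keepMeshCurrent(manager *replicasync.Manager, mesh *derpmesh.Mesh) {
    	manager.SetCallback(func() {
    		regional := manager.Regional()
    		addresses := make([]string, 0, len(regional))
    		for _, replica := range regional {
    			addresses = append(addresses, replica.RelayAddress)
    		}
    		// SetAddresses diffs internally, so calling it on every
    		// refresh is cheap when nothing has changed.
    		mesh.SetAddresses(addresses, false)
    	})
    }

+
+// SetCallback sets a function to execute whenever new peers
+// are refreshed or updated.
+func (m *Manager) SetCallback(callback func()) {
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+	m.callback = callback
+	// Instantly call the callback to inform replicas!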
+	go callback()
+}
+
+func (m *Manager) Close() error {
+	m.closeMutex.Lock()
+	select {
+	case <-m.closed:
+		m.closeMutex.Unlock()
+		return nil
+	default:
+	}
+	close(m.closed)
+	m.closeCancel()
+	m.closeWait.Wait()
+	m.closeMutex.Unlock()
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+	ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancelFunc()
+	_, err := m.db.UpdateReplica(ctx, database.UpdateReplicaParams{
+		ID:        m.self.ID,
+		UpdatedAt: database.Now(),
+		StartedAt: m.self.StartedAt,
+		StoppedAt: sql.NullTime{
+			Time:  database.Now(),
+			Valid: true,
+		},
+		RelayAddress: m.self.RelayAddress,
+		RegionID:     m.self.RegionID,
+		Hostname:     m.self.Hostname,
+		Version:      m.self.Version,
+		Error:        m.self.Error,
+	})
+	if err != nil {
+		return xerrors.Errorf("update replica: %w", err)
+	}
+	err = m.pubsub.Publish(PubsubEvent, []byte(m.self.ID.String()))
+	if err != nil {
+		return xerrors.Errorf("publish replica update: %w", err)
+	}
+	return nil
+}
diff --git a/enterprise/replicasync/replicasync_test.go b/enterprise/replicasync/replicasync_test.go
new file mode 100644
index 0000000000..b7709c1f6f
--- /dev/null
+++ b/enterprise/replicasync/replicasync_test.go
@@ -0,0 +1,239 @@
+package replicasync_test
+
+import (
+	"context"
+	"crypto/tls"
+	"crypto/x509"
+	"net/http"
+	"net/http/httptest"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/goleak"
+
+	"cdr.dev/slog/sloggers/slogtest"
+	"github.com/coder/coder/coderd/database"
+	"github.com/coder/coder/coderd/database/databasefake"
+	"github.com/coder/coder/coderd/database/dbtestutil"
+	"github.com/coder/coder/enterprise/replicasync"
+	"github.com/coder/coder/testutil"
+)
+
+func TestMain(m *testing.M) {
+	goleak.VerifyTestMain(m)
+}
+
+func TestReplica(t *testing.T) {
+	t.Parallel()
+	t.Run("CreateOnNew", func(t *testing.T) {
+		// This ensures that a new replica is created on New.
+		t.Parallel()
+		db, pubsub := dbtestutil.NewDB(t)
+		closeChan := make(chan struct{}, 1)
+		cancel, err := pubsub.Subscribe(replicasync.PubsubEvent, func(ctx context.Context, message []byte) {
+			closeChan <- struct{}{}
+		})
+		require.NoError(t, err)
+		defer cancel()
+		server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, nil)
+		require.NoError(t, err)
+		<-closeChan
+		err = server.Close()
+		require.NoError(t, err)
+	})
+	t.Run("ErrorsWithoutRelayAddress", func(t *testing.T) {
+		// Ensures that New fails when a peer already exists in the
+		// same region but no relay address is configured.
+		t.Parallel()
+		db, pubsub := dbtestutil.NewDB(t)
+		_, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
+			ID:        uuid.New(),
+			CreatedAt: database.Now(),
+			StartedAt: database.Now(),
+			UpdatedAt: database.Now(),
+			Hostname:  "something",
+		})
+		require.NoError(t, err)
+		_, err = replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, nil)
+		require.Error(t, err)
+		require.Equal(t, "a relay address must be specified when running multiple replicas in the same region", err.Error())
+	})
+	t.Run("ConnectsToPeerReplica", func(t *testing.T) {
+		// Ensures that the replica reports a successful status for
+		// accessing all of its peers.
+ t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + db, pubsub := dbtestutil.NewDB(t) + peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + StartedAt: database.Now(), + UpdatedAt: database.Now(), + Hostname: "something", + RelayAddress: srv.URL, + }) + require.NoError(t, err) + server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{ + RelayAddress: "http://169.254.169.254", + }) + require.NoError(t, err) + require.Len(t, server.Regional(), 1) + require.Equal(t, peer.ID, server.Regional()[0].ID) + require.Empty(t, server.Self().Error) + _ = server.Close() + }) + t.Run("ConnectsToPeerReplicaTLS", func(t *testing.T) { + // Ensures that the replica reports a successful status for + // accessing all of its peers. + t.Parallel() + rawCert := testutil.GenerateTLSCertificate(t, "hello.org") + certificate, err := x509.ParseCertificate(rawCert.Certificate[0]) + require.NoError(t, err) + pool := x509.NewCertPool() + pool.AddCert(certificate) + // nolint:gosec + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{rawCert}, + ServerName: "hello.org", + RootCAs: pool, + } + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + srv.TLS = tlsConfig + srv.StartTLS() + defer srv.Close() + db, pubsub := dbtestutil.NewDB(t) + peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + StartedAt: database.Now(), + UpdatedAt: database.Now(), + Hostname: "something", + RelayAddress: srv.URL, + }) + require.NoError(t, err) + server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{ + RelayAddress: "http://169.254.169.254", + TLSConfig: tlsConfig, + }) + require.NoError(t, err) + require.Len(t, server.Regional(), 1) + require.Equal(t, peer.ID, server.Regional()[0].ID) + require.Empty(t, server.Self().Error) + _ = server.Close() + }) + t.Run("ConnectsToFakePeerWithError", func(t *testing.T) { + t.Parallel() + db, pubsub := dbtestutil.NewDB(t) + peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ + ID: uuid.New(), + CreatedAt: database.Now().Add(time.Minute), + StartedAt: database.Now().Add(time.Minute), + UpdatedAt: database.Now().Add(time.Minute), + Hostname: "something", + // Fake address to dial! + RelayAddress: "http://127.0.0.1:1", + }) + require.NoError(t, err) + server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{ + PeerTimeout: 1 * time.Millisecond, + RelayAddress: "http://127.0.0.1:1", + }) + require.NoError(t, err) + require.Len(t, server.Regional(), 1) + require.Equal(t, peer.ID, server.Regional()[0].ID) + require.NotEmpty(t, server.Self().Error) + require.Contains(t, server.Self().Error, "Failed to dial peers") + _ = server.Close() + }) + t.Run("RefreshOnPublish", func(t *testing.T) { + // Refresh when a new replica appears! 
+		t.Parallel()
+		db, pubsub := dbtestutil.NewDB(t)
+		server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, nil)
+		require.NoError(t, err)
+		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.WriteHeader(http.StatusOK)
+		}))
+		defer srv.Close()
+		peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
+			ID:           uuid.New(),
+			RelayAddress: srv.URL,
+			UpdatedAt:    database.Now(),
+		})
+		require.NoError(t, err)
+		// Publish multiple times to ensure it can handle that case.
+		err = pubsub.Publish(replicasync.PubsubEvent, []byte(peer.ID.String()))
+		require.NoError(t, err)
+		err = pubsub.Publish(replicasync.PubsubEvent, []byte(peer.ID.String()))
+		require.NoError(t, err)
+		require.Eventually(t, func() bool {
+			return len(server.Regional()) == 1
+		}, testutil.WaitShort, testutil.IntervalFast)
+		_ = server.Close()
+	})
+	t.Run("DeletesOld", func(t *testing.T) {
+		t.Parallel()
+		db, pubsub := dbtestutil.NewDB(t)
+		_, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
+			ID:        uuid.New(),
+			UpdatedAt: database.Now().Add(-time.Hour),
+		})
+		require.NoError(t, err)
+		server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{
+			RelayAddress:    "google.com",
+			CleanupInterval: time.Millisecond,
+		})
+		require.NoError(t, err)
+		defer server.Close()
+		require.Eventually(t, func() bool {
+			return len(server.Regional()) == 0
+		}, testutil.WaitShort, testutil.IntervalFast)
+	})
+	t.Run("TwentyConcurrent", func(t *testing.T) {
+		// Ensures that twenty concurrent replicas can spawn and all
+		// discover each other in parallel!
+		t.Parallel()
+		// This uses the in-memory fakes rather than PostgreSQL
+		// because creating this many PostgreSQL connections takes
+		// some configuration tweaking.
+		db := databasefake.New()
+		pubsub := database.NewPubsubInMemory()
+		logger := slogtest.Make(t, nil)
+		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.WriteHeader(http.StatusOK)
+		}))
+		defer srv.Close()
+		var wg sync.WaitGroup
+		count := 20
+		wg.Add(count)
+		for i := 0; i < count; i++ {
+			server, err := replicasync.New(context.Background(), logger, db, pubsub, &replicasync.Options{
+				RelayAddress: srv.URL,
+			})
+			require.NoError(t, err)
+			t.Cleanup(func() {
+				_ = server.Close()
+			})
+			done := false
+			server.SetCallback(func() {
+				if len(server.All()) != count {
+					return
+				}
+				if done {
+					return
+				}
+				done = true
+				wg.Done()
+			})
+		}
+		wg.Wait()
+	})
+}
diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go
new file mode 100644
index 0000000000..5749d9ef47
--- /dev/null
+++ b/enterprise/tailnet/coordinator.go
@@ -0,0 +1,575 @@
+package tailnet
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"io"
+	"net"
+	"sync"
+	"time"
+
+	"github.com/google/uuid"
+	"golang.org/x/xerrors"
+
+	"cdr.dev/slog"
+	"github.com/coder/coder/coderd/database"
+	agpl "github.com/coder/coder/tailnet"
+)
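The HA coordinator implements the same agpl.Coordinator interface as the in-memory coordinator in the AGPL tailnet package, so a server can pick one at boot. A rough sketch of that selection, assuming the AGPL package exposes a zero-argument NewCoordinator constructor (the real wiring in coderd carries more configuration):

    package example // illustrative sketch, not part of this patch

    import (
    	"cdr.dev/slog"
    	"github.com/coder/coder/coderd/database"
    	"github.com/coder/coder/enterprise/tailnet"
    	agpl "github.com/coder/coder/tailnet"
    )

    // newTailnetCoordinator returns the pubsub-backed coordinator when
    // high availability is in use, and the in-memory one otherwise.
    func newTailnetCoordinator(logger slog.Logger, ps database.Pubsub, ha bool) (agpl.Coordinator, error) {
    	if !ha {
    		return agpl.NewCoordinator(), nil
    	}
    	return tailnet.NewCoordinator(logger, ps)
    }

+
+// NewCoordinator creates a new high availability coordinator
+// that uses PostgreSQL pubsub to exchange handshakes.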
+func NewCoordinator(logger slog.Logger, pubsub database.Pubsub) (agpl.Coordinator, error) {
+	ctx, cancelFunc := context.WithCancel(context.Background())
+	coord := &haCoordinator{
+		id:                       uuid.New(),
+		log:                      logger,
+		pubsub:                   pubsub,
+		closeFunc:                cancelFunc,
+		close:                    make(chan struct{}),
+		nodes:                    map[uuid.UUID]*agpl.Node{},
+		agentSockets:             map[uuid.UUID]net.Conn{},
+		agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
+	}
+
+	if err := coord.runPubsub(ctx); err != nil {
+		return nil, xerrors.Errorf("run coordinator pubsub: %w", err)
+	}
+
+	return coord, nil
+}
+
+type haCoordinator struct {
+	id        uuid.UUID
+	log       slog.Logger
+	mutex     sync.RWMutex
+	pubsub    database.Pubsub
+	close     chan struct{}
+	closeFunc context.CancelFunc
+
+	// nodes maps agent and connection IDs to their respective node.
+	nodes map[uuid.UUID]*agpl.Node
+	// agentSockets maps agent IDs to their open websocket.
+	agentSockets map[uuid.UUID]net.Conn
+	// agentToConnectionSockets maps agent IDs to connection IDs of conns that
+	// are subscribed to updates for that agent.
+	agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn
+}
+
+// Node returns an in-memory node by ID.
+func (c *haCoordinator) Node(id uuid.UUID) *agpl.Node {
+	c.mutex.RLock()
+	defer c.mutex.RUnlock()
+	return c.nodes[id]
+}
+
+// ServeClient accepts a WebSocket connection that wants to connect to an agent
+// with the specified ID.
+func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
+	c.mutex.Lock()
+	// When a new connection is requested, we update it with the latest
+	// node of the agent. This allows the connection to establish.
+	node, ok := c.nodes[agent]
+	c.mutex.Unlock()
+	if ok {
+		data, err := json.Marshal([]*agpl.Node{node})
+		if err != nil {
+			return xerrors.Errorf("marshal node: %w", err)
+		}
+		_, err = conn.Write(data)
+		if err != nil {
+			return xerrors.Errorf("write nodes: %w", err)
+		}
+	} else {
+		err := c.publishClientHello(agent)
+		if err != nil {
+			return xerrors.Errorf("publish client hello: %w", err)
+		}
+	}
+
+	c.mutex.Lock()
+	connectionSockets, ok := c.agentToConnectionSockets[agent]
+	if !ok {
+		connectionSockets = map[uuid.UUID]net.Conn{}
+		c.agentToConnectionSockets[agent] = connectionSockets
+	}
+
+	// Insert this connection into a map so the agent can publish node updates.
+	connectionSockets[id] = conn
+	c.mutex.Unlock()
+
+	defer func() {
+		c.mutex.Lock()
+		defer c.mutex.Unlock()
+		// Clean all traces of this connection from the map.
+		delete(c.nodes, id)
+		connectionSockets, ok := c.agentToConnectionSockets[agent]
+		if !ok {
+			return
+		}
+		delete(connectionSockets, id)
+		if len(connectionSockets) != 0 {
+			return
+		}
+		delete(c.agentToConnectionSockets, agent)
+	}()
+
+	decoder := json.NewDecoder(conn)
+	// Indefinitely handle messages from the client websocket.
+	for {
+		err := c.handleNextClientMessage(id, agent, decoder)
+		if err != nil {
+			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
+				return nil
+			}
+			return xerrors.Errorf("handle next client message: %w", err)
+		}
+	}
+}
+
+func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
+	var node agpl.Node
+	err := decoder.Decode(&node)
+	if err != nil {
+		return xerrors.Errorf("read json: %w", err)
+	}
+
+	c.mutex.Lock()
+	// Update the node of this client in our in-memory map. If an agent entirely
+	// shuts down and reconnects, it needs to be aware of all clients attempting
+	// to establish connections.
+	c.nodes[id] = &node
+	// Fetch the socket of the agent this client wants to reach.
+	agentSocket, ok := c.agentSockets[agent]
+	c.mutex.Unlock()
+	if !ok {
+		// If we don't own the agent locally, send it over pubsub to a node that
+		// owns the agent.
+		err := c.publishNodesToAgent(agent, []*agpl.Node{&node})
+		if err != nil {
+			return xerrors.Errorf("publish node to agent: %w", err)
+		}
+		return nil
+	}
+
+	// Write the new node from this client to the actively
+	// connected agent.
+	data, err := json.Marshal([]*agpl.Node{&node})
+	if err != nil {
+		return xerrors.Errorf("marshal nodes: %w", err)
+	}
+
+	_, err = agentSocket.Write(data)
+	if err != nil {
+		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
+			return nil
+		}
+		return xerrors.Errorf("write json: %w", err)
+	}
+
+	return nil
+}
+
+// ServeAgent accepts a WebSocket connection to an agent that listens to
+// incoming connections and publishes node updates.
+func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
+	// Tell clients on other instances to send a callmemaybe to us.
+	err := c.publishAgentHello(id)
+	if err != nil {
+		return xerrors.Errorf("publish agent hello: %w", err)
+	}
+
+	// Publish all nodes on this instance that want to connect to this agent.
+	nodes := c.nodesSubscribedToAgent(id)
+	if len(nodes) > 0 {
+		data, err := json.Marshal(nodes)
+		if err != nil {
+			return xerrors.Errorf("marshal json: %w", err)
+		}
+		_, err = conn.Write(data)
+		if err != nil {
+			return xerrors.Errorf("write nodes: %w", err)
+		}
+	}
+
+	// If an old agent socket is connected, we close it
+	// to avoid any leaks. This shouldn't ever occur because
+	// we expect only one agent to be running.
+	c.mutex.Lock()
+	oldAgentSocket, ok := c.agentSockets[id]
+	if ok {
+		_ = oldAgentSocket.Close()
+	}
+	c.agentSockets[id] = conn
+	c.mutex.Unlock()
+	defer func() {
+		c.mutex.Lock()
+		defer c.mutex.Unlock()
+		delete(c.agentSockets, id)
+		delete(c.nodes, id)
+	}()
+
+	decoder := json.NewDecoder(conn)
+	for {
+		node, err := c.handleAgentUpdate(id, decoder)
+		if err != nil {
+			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
+				return nil
+			}
+			return xerrors.Errorf("handle next agent message: %w", err)
+		}
+
+		err = c.publishAgentToNodes(id, node)
+		if err != nil {
+			return xerrors.Errorf("publish agent to nodes: %w", err)
+		}
+	}
+}
+
+func (c *haCoordinator) nodesSubscribedToAgent(agentID uuid.UUID) []*agpl.Node {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	sockets, ok := c.agentToConnectionSockets[agentID]
+	if !ok {
+		return nil
+	}
+
+	nodes := make([]*agpl.Node, 0, len(sockets))
+	for targetID := range sockets {
+		node, ok := c.nodes[targetID]
+		if !ok {
+			continue
+		}
+		nodes = append(nodes, node)
+	}
+
+	return nodes
+}
+
+func (c *haCoordinator) handleClientHello(id uuid.UUID) error {
+	c.mutex.Lock()
+	node, ok := c.nodes[id]
+	c.mutex.Unlock()
+	if !ok {
+		return nil
+	}
+	return c.publishAgentToNodes(id, node)
+}
+
+func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (*agpl.Node, error) {
+	var node agpl.Node
+	err := decoder.Decode(&node)
+	if err != nil {
+		return nil, xerrors.Errorf("read json: %w", err)
+	}
+
+	c.mutex.Lock()
+	oldNode := c.nodes[id]
+	if oldNode != nil {
+		if oldNode.AsOf.After(node.AsOf) {
+			c.mutex.Unlock()
+			return oldNode, nil
+		}
+	}
+	c.nodes[id] = &node
+	connectionSockets, ok := c.agentToConnectionSockets[id]
+	if !ok {
+		c.mutex.Unlock()
+		return &node, nil
+	}
+
+	data, err := json.Marshal([]*agpl.Node{&node})
+	if err != nil {
+		c.mutex.Unlock()
+		return nil, xerrors.Errorf("marshal nodes: %w", err)
+	}
+
+	// Publish the new node to every listening socket.
+	var wg sync.WaitGroup
+	wg.Add(len(connectionSockets))
+	for _, connectionSocket := range connectionSockets {
+		connectionSocket := connectionSocket
+		go func() {
+			defer wg.Done()
+			_ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second))
+			_, _ = connectionSocket.Write(data)
+		}()
+	}
+	c.mutex.Unlock()
+	wg.Wait()
+	return &node, nil
+}
+
+// Close closes all of the open connections in the coordinator and stops the
+// coordinator from accepting new connections.
+func (c *haCoordinator) Close() error {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	select {
+	case <-c.close:
+		return nil
+	default:
+	}
+	close(c.close)
+	c.closeFunc()
+
+	wg := sync.WaitGroup{}
+
+	wg.Add(len(c.agentSockets))
+	for _, socket := range c.agentSockets {
+		socket := socket
+		go func() {
+			_ = socket.Close()
+			wg.Done()
+		}()
+	}
+
+	for _, connMap := range c.agentToConnectionSockets {
+		wg.Add(len(connMap))
+		for _, socket := range connMap {
+			socket := socket
+			go func() {
+				_ = socket.Close()
+				wg.Done()
+			}()
+		}
+	}
+
+	wg.Wait()
+	return nil
+}
+
+func (c *haCoordinator) publishNodesToAgent(recipient uuid.UUID, nodes []*agpl.Node) error {
+	msg, err := c.formatCallMeMaybe(recipient, nodes)
+	if err != nil {
+		return xerrors.Errorf("format publish message: %w", err)
+	}
+
+	err = c.pubsub.Publish("wireguard_peers", msg)
+	if err != nil {
+		return xerrors.Errorf("publish message: %w", err)
+	}
+
+	return nil
+}
+
+func (c *haCoordinator) publishAgentHello(id uuid.UUID) error {
+	msg, err := c.formatAgentHello(id)
+	if err != nil {
+		return xerrors.Errorf("format publish message: %w", err)
+	}
+
+	err = c.pubsub.Publish("wireguard_peers", msg)
+	if err != nil {
+		return xerrors.Errorf("publish message: %w", err)
+	}
+
+	return nil
+}
+
+func (c *haCoordinator) publishClientHello(id uuid.UUID) error {
+	msg, err := c.formatClientHello(id)
+	if err != nil {
+		return xerrors.Errorf("format client hello: %w", err)
+	}
+	err = c.pubsub.Publish("wireguard_peers", msg)
+	if err != nil {
+		return xerrors.Errorf("publish client hello: %w", err)
+	}
+	return nil
+}
+
+func (c *haCoordinator) publishAgentToNodes(id uuid.UUID, node *agpl.Node) error {
+	msg, err := c.formatAgentUpdate(id, node)
+	if err != nil {
+		return xerrors.Errorf("format publish message: %w", err)
+	}
+
+	err = c.pubsub.Publish("wireguard_peers", msg)
+	if err != nil {
+		return xerrors.Errorf("publish message: %w", err)
+	}
+
+	return nil
+}
+
+func (c *haCoordinator) runPubsub(ctx context.Context) error {
+	messageQueue := make(chan []byte, 64)
+	cancelSub, err := c.pubsub.Subscribe("wireguard_peers", func(ctx context.Context, message []byte) {
+		select {
+		case messageQueue <- message:
+		case <-ctx.Done():
+			return
+		}
+	})
+	if err != nil {
+		return xerrors.Errorf("subscribe wireguard peers: %w", err)
+	}
+	go func() {
+		for {
+			var message []byte
+			select {
+			case <-ctx.Done():
+				return
+			case message = <-messageQueue:
+			}
+			c.handlePubsubMessage(ctx, message)
+		}
+	}()
+
+	go func() {
+		defer cancelSub()
+		<-c.close
+	}()
+
+	return nil
+}
+
+func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte) {
+	sp := bytes.Split(message, []byte("|"))
+	if len(sp) != 4 {
+		c.log.Error(ctx, "invalid wireguard peer message", slog.F("msg", string(message)))
+		return
+	}
+
+	var (
+		coordinatorID = sp[0]
+		eventType     = sp[1]
+		agentID       = sp[2]
+		nodeJSON      = sp[3]
+	)
+
+	sender, err := uuid.ParseBytes(coordinatorID)
+	if err != nil {
+		c.log.Error(ctx, "invalid sender id", slog.F("id", string(coordinatorID)), slog.F("msg", string(message)))
+		return
+	}
+
+	// We sent this message!
+	if sender == c.id {
+		return
+	}
+
+	switch string(eventType) {
+	case "callmemaybe":
+		agentUUID, err := uuid.ParseBytes(agentID)
+		if err != nil {
+			c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
+			return
+		}
+
+		c.mutex.Lock()
+		agentSocket, ok := c.agentSockets[agentUUID]
+		if !ok {
+			c.mutex.Unlock()
+			return
+		}
+		c.mutex.Unlock()
+
+		// The payload is already a JSON-encoded array of nodes, so it
+		// can be written to the agent socket directly.
+		_, err = agentSocket.Write(nodeJSON)
+		if err != nil {
+			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
+				return
+			}
+			c.log.Error(ctx, "send callmemaybe to agent", slog.Error(err))
+			return
+		}
+	case "clienthello":
+		agentUUID, err := uuid.ParseBytes(agentID)
+		if err != nil {
+			c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
+			return
+		}
+
+		err = c.handleClientHello(agentUUID)
+		if err != nil {
+			c.log.Error(ctx, "handle agent request node", slog.Error(err))
+			return
+		}
+	case "agenthello":
+		agentUUID, err := uuid.ParseBytes(agentID)
+		if err != nil {
+			c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
+			return
+		}
+
+		nodes := c.nodesSubscribedToAgent(agentUUID)
+		if len(nodes) > 0 {
+			err := c.publishNodesToAgent(agentUUID, nodes)
+			if err != nil {
+				c.log.Error(ctx, "publish nodes to agent", slog.Error(err))
+				return
+			}
+		}
+	case "agentupdate":
+		agentUUID, err := uuid.ParseBytes(agentID)
+		if err != nil {
+			c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
+			return
+		}
+
+		decoder := json.NewDecoder(bytes.NewReader(nodeJSON))
+		_, err = c.handleAgentUpdate(agentUUID, decoder)
+		if err != nil {
+			c.log.Error(ctx, "handle agent update", slog.Error(err))
+			return
+		}
+	default:
+		c.log.Error(ctx, "unknown peer event", slog.F("name", string(eventType)))
+	}
+}
+
+// format: <coordinator id>|callmemaybe|<recipient id>|<node json>
+func (c *haCoordinator) formatCallMeMaybe(recipient uuid.UUID, nodes []*agpl.Node) ([]byte, error) {
+	buf := bytes.Buffer{}
+
+	buf.WriteString(c.id.String() + "|")
+	buf.WriteString("callmemaybe|")
+	buf.WriteString(recipient.String() + "|")
+	err := json.NewEncoder(&buf).Encode(nodes)
+	if err != nil {
+		return nil, xerrors.Errorf("encode node: %w", err)
+	}
+
+	return buf.Bytes(), nil
+}
+
+// format: <coordinator id>|agenthello|<agent id>|
+func (c *haCoordinator) formatAgentHello(id uuid.UUID) ([]byte, error) {
+	buf := bytes.Buffer{}
+
+	buf.WriteString(c.id.String() + "|")
+	buf.WriteString("agenthello|")
+	buf.WriteString(id.String() + "|")
+
+	return buf.Bytes(), nil
+}
+
+// format: <coordinator id>|clienthello|<client id>|
+func (c *haCoordinator) formatClientHello(id uuid.UUID) ([]byte, error) {
+	buf := bytes.Buffer{}
+
+	buf.WriteString(c.id.String() + "|")
+	buf.WriteString("clienthello|")
+	buf.WriteString(id.String() + "|")
+
+	return buf.Bytes(), nil
+}
+
+// format: <coordinator id>|agentupdate|<agent id>|<node json>
+func (c *haCoordinator) formatAgentUpdate(id uuid.UUID, node *agpl.Node) ([]byte, error) {
+	buf := bytes.Buffer{}
+
+	buf.WriteString(c.id.String() + "|")
+	buf.WriteString("agentupdate|")
+	buf.WriteString(id.String() + "|")
+	err := json.NewEncoder(&buf).Encode(node)
+	if err != nil {
+		return nil, xerrors.Errorf("encode node: %w", err)
+	}
+
+	return buf.Bytes(), nil
+}
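Because the pipe-delimited framing above is easy to get wrong, a standalone decoder for it may be clarifying. This is an illustrative sketch, not code from the patch (handlePubsubMessage does the equivalent inline); note it uses SplitN rather than Split so a "|" inside the JSON payload cannot break the frame:

    package example // illustrative sketch, not part of this patch

    import (
    	"bytes"

    	"github.com/google/uuid"
    	"golang.org/x/xerrors"
    )

    // pubsubMessage is the decoded form of one coordinator frame:
    // <coordinator id>|<event type>|<subject id>|<optional node json>
    type pubsubMessage struct {
    	Sender    uuid.UUID
    	EventType string
    	SubjectID uuid.UUID
    	NodeJSON  []byte
    }

    func parsePubsubMessage(raw []byte) (pubsubMessage, error) {
    	parts := bytes.SplitN(raw, []byte("|"), 4)
    	if len(parts) != 4 {
    		return pubsubMessage{}, xerrors.Errorf("expected 4 message parts, got %d", len(parts))
    	}
    	sender, err := uuid.ParseBytes(parts[0])
    	if err != nil {
    		return pubsubMessage{}, xerrors.Errorf("parse sender: %w", err)
    	}
    	subject, err := uuid.ParseBytes(parts[2])
    	if err != nil {
    		return pubsubMessage{}, xerrors.Errorf("parse subject: %w", err)
    	}
    	return pubsubMessage{
    		Sender:    sender,
    		EventType: string(parts[1]),
    		SubjectID: subject,
    		NodeJSON:  parts[3],
    	}, nil
    }

diff --git a/enterprise/tailnet/coordinator_test.go b/enterprise/tailnet/coordinator_test.go
new file mode 100644
index 0000000000..86cee94dbd
--- /dev/null
+++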
b/enterprise/tailnet/coordinator_test.go @@ -0,0 +1,261 @@ +package tailnet_test + +import ( + "net" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbtestutil" + "github.com/coder/coder/enterprise/tailnet" + agpl "github.com/coder/coder/tailnet" + "github.com/coder/coder/testutil" +) + +func TestCoordinatorSingle(t *testing.T) { + t.Parallel() + t.Run("ClientWithoutAgent", func(t *testing.T) { + t.Parallel() + coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory()) + require.NoError(t, err) + defer coordinator.Close() + + client, server := net.Pipe() + sendNode, errChan := agpl.ServeCoordinator(client, func(node []*agpl.Node) error { + return nil + }) + id := uuid.New() + closeChan := make(chan struct{}) + go func() { + err := coordinator.ServeClient(server, id, uuid.New()) + assert.NoError(t, err) + close(closeChan) + }() + sendNode(&agpl.Node{}) + require.Eventually(t, func() bool { + return coordinator.Node(id) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + err = client.Close() + require.NoError(t, err) + <-errChan + <-closeChan + }) + + t.Run("AgentWithoutClients", func(t *testing.T) { + t.Parallel() + coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory()) + require.NoError(t, err) + defer coordinator.Close() + + client, server := net.Pipe() + sendNode, errChan := agpl.ServeCoordinator(client, func(node []*agpl.Node) error { + return nil + }) + id := uuid.New() + closeChan := make(chan struct{}) + go func() { + err := coordinator.ServeAgent(server, id) + assert.NoError(t, err) + close(closeChan) + }() + sendNode(&agpl.Node{}) + require.Eventually(t, func() bool { + return coordinator.Node(id) != nil + }, testutil.WaitShort, testutil.IntervalFast) + err = client.Close() + require.NoError(t, err) + <-errChan + <-closeChan + }) + + t.Run("AgentWithClient", func(t *testing.T) { + t.Parallel() + + coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory()) + require.NoError(t, err) + defer coordinator.Close() + + agentWS, agentServerWS := net.Pipe() + defer agentWS.Close() + agentNodeChan := make(chan []*agpl.Node) + sendAgentNode, agentErrChan := agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error { + agentNodeChan <- nodes + return nil + }) + agentID := uuid.New() + closeAgentChan := make(chan struct{}) + go func() { + err := coordinator.ServeAgent(agentServerWS, agentID) + assert.NoError(t, err) + close(closeAgentChan) + }() + sendAgentNode(&agpl.Node{}) + require.Eventually(t, func() bool { + return coordinator.Node(agentID) != nil + }, testutil.WaitShort, testutil.IntervalFast) + + clientWS, clientServerWS := net.Pipe() + defer clientWS.Close() + defer clientServerWS.Close() + clientNodeChan := make(chan []*agpl.Node) + sendClientNode, clientErrChan := agpl.ServeCoordinator(clientWS, func(nodes []*agpl.Node) error { + clientNodeChan <- nodes + return nil + }) + clientID := uuid.New() + closeClientChan := make(chan struct{}) + go func() { + err := coordinator.ServeClient(clientServerWS, clientID, agentID) + assert.NoError(t, err) + close(closeClientChan) + }() + agentNodes := <-clientNodeChan + require.Len(t, agentNodes, 1) + sendClientNode(&agpl.Node{}) + clientNodes := <-agentNodeChan + require.Len(t, clientNodes, 1) + + // Ensure an update to the agent 
node reaches the client!
+		sendAgentNode(&agpl.Node{})
+		agentNodes = <-clientNodeChan
+		require.Len(t, agentNodes, 1)
+
+		// Close the agent WebSocket so a new one can connect.
+		err = agentWS.Close()
+		require.NoError(t, err)
+		<-agentErrChan
+		<-closeAgentChan
+
+		// Create a new agent connection. This is to simulate a reconnect!
+		agentWS, agentServerWS = net.Pipe()
+		defer agentWS.Close()
+		agentNodeChan = make(chan []*agpl.Node)
+		_, agentErrChan = agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error {
+			agentNodeChan <- nodes
+			return nil
+		})
+		closeAgentChan = make(chan struct{})
+		go func() {
+			err := coordinator.ServeAgent(agentServerWS, agentID)
+			assert.NoError(t, err)
+			close(closeAgentChan)
+		}()
+		// Ensure the existing listening client sends its node immediately!
+		clientNodes = <-agentNodeChan
+		require.Len(t, clientNodes, 1)
+
+		err = agentWS.Close()
+		require.NoError(t, err)
+		<-agentErrChan
+		<-closeAgentChan
+
+		err = clientWS.Close()
+		require.NoError(t, err)
+		<-clientErrChan
+		<-closeClientChan
+	})
+}
+
+func TestCoordinatorHA(t *testing.T) {
+	t.Parallel()
+
+	t.Run("AgentWithClient", func(t *testing.T) {
+		t.Parallel()
+
+		_, pubsub := dbtestutil.NewDB(t)
+
+		coordinator1, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub)
+		require.NoError(t, err)
+		defer coordinator1.Close()
+
+		agentWS, agentServerWS := net.Pipe()
+		defer agentWS.Close()
+		agentNodeChan := make(chan []*agpl.Node)
+		sendAgentNode, agentErrChan := agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error {
+			agentNodeChan <- nodes
+			return nil
+		})
+		agentID := uuid.New()
+		closeAgentChan := make(chan struct{})
+		go func() {
+			err := coordinator1.ServeAgent(agentServerWS, agentID)
+			assert.NoError(t, err)
+			close(closeAgentChan)
+		}()
+		sendAgentNode(&agpl.Node{})
+		require.Eventually(t, func() bool {
+			return coordinator1.Node(agentID) != nil
+		}, testutil.WaitShort, testutil.IntervalFast)
+
+		coordinator2, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub)
+		require.NoError(t, err)
+		defer coordinator2.Close()
+
+		clientWS, clientServerWS := net.Pipe()
+		defer clientWS.Close()
+		defer clientServerWS.Close()
+		clientNodeChan := make(chan []*agpl.Node)
+		sendClientNode, clientErrChan := agpl.ServeCoordinator(clientWS, func(nodes []*agpl.Node) error {
+			clientNodeChan <- nodes
+			return nil
+		})
+		clientID := uuid.New()
+		closeClientChan := make(chan struct{})
+		go func() {
+			err := coordinator2.ServeClient(clientServerWS, clientID, agentID)
+			assert.NoError(t, err)
+			close(closeClientChan)
+		}()
+		agentNodes := <-clientNodeChan
+		require.Len(t, agentNodes, 1)
+		sendClientNode(&agpl.Node{})
+		clientNodes := <-agentNodeChan
+		require.Len(t, clientNodes, 1)
+
+		// Ensure an update to the agent node reaches the client!
+		sendAgentNode(&agpl.Node{})
+		agentNodes = <-clientNodeChan
+		require.Len(t, agentNodes, 1)
+
+		// Close the agent WebSocket so a new one can connect.
+		require.NoError(t, agentWS.Close())
+		require.NoError(t, agentServerWS.Close())
+		<-agentErrChan
+		<-closeAgentChan
+
+		// Create a new agent connection. This is to simulate a reconnect!
+		agentWS, agentServerWS = net.Pipe()
+		defer agentWS.Close()
+		agentNodeChan = make(chan []*agpl.Node)
+		_, agentErrChan = agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error {
+			agentNodeChan <- nodes
+			return nil
+		})
+		closeAgentChan = make(chan struct{})
+		go func() {
+			err := coordinator1.ServeAgent(agentServerWS, agentID)
+			assert.NoError(t, err)
+			close(closeAgentChan)
+		}()
+		// Ensure the existing listening client sends its node immediately!
+		clientNodes = <-agentNodeChan
+		require.Len(t, clientNodes, 1)
+
+		err = agentWS.Close()
+		require.NoError(t, err)
+		<-agentErrChan
+		<-closeAgentChan
+
+		err = clientWS.Close()
+		require.NoError(t, err)
+		<-clientErrChan
+		<-closeClientChan
+	})
+}
diff --git a/go.mod b/go.mod
index 9834e27e5f..195a09ae2b 100644
--- a/go.mod
+++ b/go.mod
@@ -40,7 +40,7 @@ replace github.com/tcnksm/go-httpstat => github.com/kylecarbs/go-httpstat v0.0.0
 // There are a few minor changes we make to Tailscale that we're slowly upstreaming. Compare here:
 // https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main
-replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20220926024748-50f068456c6c
+replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20221015033036-5861cbbf7bf5
 // Switch to our fork that imports fixes from http://github.com/tailscale/ssh.
 // See: https://github.com/coder/coder/issues/3371
diff --git a/go.sum b/go.sum
index 13fdc5724f..b80c0d4173 100644
--- a/go.sum
+++ b/go.sum
@@ -351,8 +351,8 @@ github.com/coder/retry v1.3.0 h1:5lAAwt/2Cm6lVmnfBY7sOMXcBOwcwJhmV5QGSELIVWY=
 github.com/coder/retry v1.3.0/go.mod h1:tXuRgZgWjUnU5LZPT4lJh4ew2elUhexhlnXzrJWdyFY=
 github.com/coder/ssh v0.0.0-20220811105153-fcea99919338 h1:tN5GKFT68YLVzJoA8AHuiMNJ0qlhoD3pGN3JY9gxSko=
 github.com/coder/ssh v0.0.0-20220811105153-fcea99919338/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
-github.com/coder/tailscale v1.1.1-0.20220926024748-50f068456c6c h1:xa6lr5Pj87Is26tgpzwBsEGKL7aVz7/fRGgY9QIbf3E=
-github.com/coder/tailscale v1.1.1-0.20220926024748-50f068456c6c/go.mod h1:5amxy08qijEa8bcTW2SeIy4MIqcmd7LMsuOxqOlj2Ak=
+github.com/coder/tailscale v1.1.1-0.20221015033036-5861cbbf7bf5 h1:WVH6e/qK3Wpl0wbmpORD2oQ1qLJborF3fsFHyO1ps0Y=
+github.com/coder/tailscale v1.1.1-0.20221015033036-5861cbbf7bf5/go.mod h1:5amxy08qijEa8bcTW2SeIy4MIqcmd7LMsuOxqOlj2Ak=
 github.com/containerd/aufs v0.0.0-20200908144142-dab0cbea06f4/go.mod h1:nukgQABAEopAHvB6j7cnP5zJ+/3aVcE7hCYqvIwAHyE=
 github.com/containerd/aufs v0.0.0-20201003224125-76a6863f2989/go.mod h1:AkGGQs9NM2vtYHaUen+NljV0/baGCAPELGm2q9ZXpWU=
 github.com/containerd/aufs v0.0.0-20210316121734-20793ff83c97/go.mod h1:kL5kd6KM5TzQjR79jljyi4olc1Vrx6XBlcyj3gNv2PU=
diff --git a/helm/templates/coder.yaml b/helm/templates/coder.yaml
index 45f3f6e29a..1165251fc8 100644
--- a/helm/templates/coder.yaml
+++ b/helm/templates/coder.yaml
@@ -14,10 +14,7 @@ metadata:
 {{- include "coder.labels" . | nindent 4 }}
 annotations: {{ toYaml .Values.coder.annotations | nindent 4}}
 spec:
-  # NOTE: this is currently not used as coder v2 does not support high
-  # availability yet.
-  # replicas: {{ .Values.coder.replicaCount }}
-  replicas: 1
+  replicas: {{ .Values.coder.replicaCount }}
 selector:
 matchLabels: {{- include "coder.selectorLabels" . | nindent 6 }}
@@ -38,6 +35,13 @@ spec:
 env:
 - name: CODER_ADDRESS
 value: "0.0.0.0:{{ include "coder.port" . }}"
+ # Used for inter-pod communication with high-availability.
+            - name: KUBE_POD_IP
+              valueFrom:
+                fieldRef:
+                  fieldPath: status.podIP
+            - name: CODER_DERP_SERVER_RELAY_ADDRESS
+              value: "{{ include "coder.portName" . }}://$(KUBE_POD_IP):{{ include "coder.port" . }}"
             {{- include "coder.tlsEnv" . | nindent 12 }}
             {{- with .Values.coder.env -}}
               {{ toYaml . | nindent 12 }}
diff --git a/helm/templates/service.yaml b/helm/templates/service.yaml
index 28fe0e9f9a..b9a7e9a2f0 100644
--- a/helm/templates/service.yaml
+++ b/helm/templates/service.yaml
@@ -10,6 +10,7 @@ metadata:
     {{- toYaml .Values.coder.service.annotations | nindent 4 }}
 spec:
   type: {{ .Values.coder.service.type }}
+  sessionAffinity: ClientIP
   ports:
     - name: {{ include "coder.portName" . | quote }}
      port: {{ include "coder.servicePort" . }}
diff --git a/helm/values.yaml b/helm/values.yaml
index 30a21a8985..392a53c187 100644
--- a/helm/values.yaml
+++ b/helm/values.yaml
@@ -1,9 +1,9 @@
 # coder -- Primary configuration for `coder server`.
 coder:
-  # NOTE: this is currently not used as coder v2 does not support high
-  # availability yet.
-  #
-  # coder.replicaCount -- The number of Kubernetes deployment replicas.
-  # replicaCount: 1
+  # coder.replicaCount -- The number of Kubernetes deployment replicas.
+  # Increase this only if High Availability is enabled, which is an
+  # Enterprise feature. Contact sales@coder.com.
+  replicaCount: 1
 
   # coder.image -- The image to use for Coder.
   image:
diff --git a/site/src/api/api.ts b/site/src/api/api.ts
index 2e60a88b84..fb12571fd9 100644
--- a/site/src/api/api.ts
+++ b/site/src/api/api.ts
@@ -28,6 +28,7 @@ export const defaultEntitlements = (): TypesGen.Entitlements => {
   return {
     features: features,
     has_license: false,
+    errors: [],
     warnings: [],
     experimental: false,
     trial: false,
diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts
index 5347613e77..a4b2cf83a9 100644
--- a/site/src/api/typesGenerated.ts
+++ b/site/src/api/typesGenerated.ts
@@ -274,6 +274,7 @@ export interface DeploymentFlags {
   readonly derp_server_region_code: StringFlag
   readonly derp_server_region_name: StringFlag
   readonly derp_server_stun_address: StringArrayFlag
+  readonly derp_server_relay_address: StringFlag
   readonly derp_config_url: StringFlag
   readonly derp_config_path: StringFlag
   readonly prom_enabled: BoolFlag
@@ -337,6 +338,7 @@ export interface DurationFlag {
 export interface Entitlements {
   readonly features: Record<string, Feature>
   readonly warnings: string[]
+  readonly errors: string[]
   readonly has_license: boolean
   readonly experimental: boolean
   readonly trial: boolean
@@ -528,6 +530,17 @@ export interface PutExtendWorkspaceRequest {
   readonly deadline: string
 }
 
+// From codersdk/replicas.go
+export interface Replica {
+  readonly id: string
+  readonly hostname: string
+  readonly created_at: string
+  readonly relay_address: string
+  readonly region_id: number
+  readonly error: string
+  readonly database_latency: number
+}
+
 // From codersdk/error.go
 export interface Response {
   readonly message: string
diff --git a/site/src/components/LicenseBanner/LicenseBanner.tsx b/site/src/components/LicenseBanner/LicenseBanner.tsx
index 8532bfca2e..7ecfc2a2a2 100644
--- a/site/src/components/LicenseBanner/LicenseBanner.tsx
+++ b/site/src/components/LicenseBanner/LicenseBanner.tsx
@@ -8,15 +8,15 @@ export const LicenseBanner: React.FC = () => {
   const [entitlementsState, entitlementsSend] = useActor(
     xServices.entitlementsXService,
   )
-  const { warnings } = entitlementsState.context.entitlements
+  const { errors, warnings } = entitlementsState.context.entitlements
 
   /** Gets license data
   on app mount because LicenseBanner is mounted in App */
   useEffect(() => {
     entitlementsSend("GET_ENTITLEMENTS")
   }, [entitlementsSend])
 
-  if (warnings.length > 0) {
-    return <LicenseBannerView warnings={warnings} />
+  if (errors.length > 0 || warnings.length > 0) {
+    return <LicenseBannerView errors={errors} warnings={warnings} />
   } else {
     return null
   }
diff --git a/site/src/components/LicenseBanner/LicenseBannerView.stories.tsx b/site/src/components/LicenseBanner/LicenseBannerView.stories.tsx
index c37653eff7..c7ee69c261 100644
--- a/site/src/components/LicenseBanner/LicenseBannerView.stories.tsx
+++ b/site/src/components/LicenseBanner/LicenseBannerView.stories.tsx
@@ -12,13 +12,23 @@ const Template: Story<LicenseBannerViewProps> = (args) => (
 
 export const OneWarning = Template.bind({})
 OneWarning.args = {
+  errors: [],
   warnings: ["You have exceeded the number of seats in your license."],
 }
 
 export const TwoWarnings = Template.bind({})
 TwoWarnings.args = {
+  errors: [],
   warnings: [
     "You have exceeded the number of seats in your license.",
     "You are flying too close to the sun.",
   ],
 }
+
+export const OneError = Template.bind({})
+OneError.args = {
+  errors: [
+    "You have multiple replicas but high availability is an Enterprise feature. You will be unable to connect to workspaces.",
+  ],
+  warnings: [],
+}
diff --git a/site/src/components/LicenseBanner/LicenseBannerView.tsx b/site/src/components/LicenseBanner/LicenseBannerView.tsx
index 49276b1f0d..792bc191a0 100644
--- a/site/src/components/LicenseBanner/LicenseBannerView.tsx
+++ b/site/src/components/LicenseBanner/LicenseBannerView.tsx
@@ -2,47 +2,56 @@ import { makeStyles } from "@material-ui/core/styles"
 import { Expander } from "components/Expander/Expander"
 import { Pill } from "components/Pill/Pill"
 import { useState } from "react"
+import { colors } from "theme/colors"
 
 export const Language = {
   licenseIssue: "License Issue",
   licenseIssues: (num: number): string => `${num} License Issues`,
-  upgrade: "Contact us to upgrade your license.",
+  upgrade: "Contact sales@coder.com.",
   exceeded: "It looks like you've exceeded some limits of your license.",
   lessDetails: "Less",
   moreDetails: "More",
 }
 
 export interface LicenseBannerViewProps {
+  errors: string[]
   warnings: string[]
 }
 
 export const LicenseBannerView: React.FC<LicenseBannerViewProps> = ({
+  errors,
   warnings,
 }) => {
   const styles = useStyles()
   const [showDetails, setShowDetails] = useState(false)
-  if (warnings.length === 1) {
+  const isError = errors.length > 0
+  const messages = [...errors, ...warnings]
+  const type = isError ? "error" : "warning"
+
+  if (messages.length === 1) {
     return (
-      <div className={styles.container}>
-        <Pill text={Language.licenseIssue} type="warning" lang="en" />
-        <span className={styles.text}>
-          {warnings[0]}
-          &nbsp;
-          <a href="mailto:sales@coder.com" className={styles.link}>
-            {Language.upgrade}
-          </a>
-        </span>
-      </div>
+      <div className={`${styles.container} ${type}`}>
+        <Pill text={Language.licenseIssue} type={type} lang="en" />
+        <div className={styles.leftContent}>
+          {messages[0]}
+          &nbsp;
+          <a href="mailto:sales@coder.com" className={styles.link}>
+            {Language.upgrade}
+          </a>
+        </div>
+      </div>
     )
   } else {
     return (
-      <div className={styles.container}>
-        <Pill text={Language.licenseIssues(warnings.length)} type="warning" lang="en" />
-        <div className={styles.text}>
-          {Language.exceeded}
+      <div className={`${styles.container} ${type}`}>
+        <Pill text={Language.licenseIssues(messages.length)} type={type} lang="en" />
+        <div className={styles.leftContent}>
+          {Language.exceeded}
+          &nbsp;
+          <a href="mailto:sales@coder.com" className={styles.link}>
+            {Language.upgrade}
+          </a>
           <Expander expanded={showDetails} setExpanded={setShowDetails}>
             <ul className={styles.list}>
-              {warnings.map((warning) => (
-                <li className={styles.listItem} key={warning}>
-                  {warning}
+              {messages.map((message) => (
+                <li className={styles.listItem} key={message}>
+                  {message}
                 </li>
               ))}
             </ul>
           </Expander>
         </div>
       </div>
@@ -67,14 +76,18 @@ const useStyles = makeStyles((theme) => ({
   container: {
     padding: theme.spacing(1.5),
     backgroundColor: theme.palette.warning.main,
+    display: "flex",
+    alignItems: "center",
+
+    "&.error": {
+      backgroundColor: colors.red[12],
+    },
   },
   flex: {
     display: "flex",
+    flexDirection: "column",
   },
   leftContent: {
     marginRight: theme.spacing(1),
-  },
-  text: {
     marginLeft: theme.spacing(1),
   },
   link: {
@@ -83,9 +96,10 @@ const useStyles = makeStyles((theme) => ({
     fontWeight: "bold",
   },
   list: {
-    margin: theme.spacing(1.5),
+    padding: theme.spacing(1),
+    margin: 0,
   },
   listItem: {
-    margin: theme.spacing(1),
+    margin: theme.spacing(0.5),
   },
 }))
diff --git a/site/src/testHelpers/entities.ts b/site/src/testHelpers/entities.ts
index 59abb4a913..8d0358bc58 100644
--- a/site/src/testHelpers/entities.ts
+++ b/site/src/testHelpers/entities.ts
@@ -821,6 +821,7 @@ export const makeMockApiError = ({
 })
 
 export const MockEntitlements: TypesGen.Entitlements = {
+  errors: [],
   warnings: [],
   has_license: false,
   features: {},
@@ -829,6 +830,7 @@ export const MockEntitlements: TypesGen.Entitlements = {
 }
 
 export const MockEntitlementsWithWarnings: TypesGen.Entitlements = {
+  errors: [],
   warnings: ["You are over your active user limit.", "And another thing."],
   has_license: true,
   experimental: false,
@@ -852,6 +854,7 @@ export const MockEntitlementsWithWarnings: TypesGen.Entitlements = {
 }
 
 export const MockEntitlementsWithAuditLog: TypesGen.Entitlements = {
+  errors: [],
   warnings: [],
   has_license: true,
   experimental: false,
diff --git a/site/src/xServices/entitlements/entitlementsXService.ts b/site/src/xServices/entitlements/entitlementsXService.ts
index 83ed44d120..a1e8bb0d9b 100644
--- a/site/src/xServices/entitlements/entitlementsXService.ts
+++ b/site/src/xServices/entitlements/entitlementsXService.ts
@@ -20,6 +20,7 @@ export type EntitlementsEvent =
   | { type: "HIDE_MOCK_BANNER" }
 
 const emptyEntitlements = {
+  errors: [],
   warnings: [],
   features: {},
   has_license: false,
diff --git a/tailnet/conn.go b/tailnet/conn.go
index 1b454d6346..e3af3786ec 100644
--- a/tailnet/conn.go
+++ b/tailnet/conn.go
@@ -48,7 +48,10 @@ type Options struct {
 	Addresses []netip.Prefix
 	DERPMap   *tailcfg.DERPMap
 
-	Logger slog.Logger
+	// BlockEndpoints specifies whether P2P endpoints are blocked. If so,
+	// connections can only be established over DERP.
+	BlockEndpoints bool
+	Logger         slog.Logger
 }
 
 // NewConn constructs a new Wireguard server that will accept connections from the addresses provided.
@@ -175,6 +178,7 @@ func NewConn(options *Options) (*Conn, error) {
 	wireguardEngine.SetFilter(filter.New(netMap.PacketFilter, localIPs, logIPs, nil, Logger(options.Logger.Named("packet-filter"))))
 	dialContext, dialCancel := context.WithCancel(context.Background())
 	server := &Conn{
+		blockEndpoints: options.BlockEndpoints,
 		dialContext:    dialContext,
 		dialCancel:     dialCancel,
 		closed:         make(chan struct{}),
@@ -240,11 +244,12 @@ func IP() netip.Addr {
 
 // Conn is an actively listening Wireguard connection.
 type Conn struct {
-	dialContext context.Context
-	dialCancel  context.CancelFunc
-	mutex       sync.Mutex
-	closed      chan struct{}
-	logger      slog.Logger
+	dialContext    context.Context
+	dialCancel     context.CancelFunc
+	mutex          sync.Mutex
+	closed         chan struct{}
+	logger         slog.Logger
+	blockEndpoints bool
 
 	dialer    *tsdial.Dialer
 	tunDevice *tstun.Wrapper
@@ -323,6 +328,8 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
 		delete(c.peerMap, peer.ID)
 	}
 	for _, node := range nodes {
+		c.logger.Debug(context.Background(), "adding node", slog.F("node", node))
+
 		peerStatus, ok := status.Peer[node.Key]
 		peerNode := &tailcfg.Node{
 			ID:        node.ID,
@@ -339,6 +346,13 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
 			// reason. TODO: @kylecarbs debug this!
 			KeepAlive: ok && peerStatus.Active,
 		}
+		// If the node has no preferred DERP, clear the DERP address so
+		// we don't route the peer through a nonexistent region.
+		if node.PreferredDERP == 0 {
+			peerNode.DERP = ""
+		}
+		if c.blockEndpoints {
+			peerNode.Endpoints = nil
+		}
 		c.peerMap[node.ID] = peerNode
 	}
 	c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
@@ -421,6 +435,7 @@ func (c *Conn) sendNode() {
 	}
 	node := &Node{
 		ID:            c.netMap.SelfNode.ID,
+		AsOf:          c.lastStatus,
 		Key:           c.netMap.SelfNode.Key,
 		Addresses:     c.netMap.SelfNode.Addresses,
 		AllowedIPs:    c.netMap.SelfNode.AllowedIPs,
@@ -429,6 +444,9 @@ func (c *Conn) sendNode() {
 		PreferredDERP: c.lastPreferredDERP,
 		DERPLatency:   c.lastDERPLatency,
 	}
+	if c.blockEndpoints {
+		node.Endpoints = nil
+	}
 	nodeCallback := c.nodeCallback
 	if nodeCallback == nil {
 		return
diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go
index ee696b0925..4216bbc624 100644
--- a/tailnet/coordinator.go
+++ b/tailnet/coordinator.go
@@ -7,6 +7,7 @@ import (
 	"net"
 	"net/netip"
 	"sync"
+	"time"
 
 	"github.com/google/uuid"
 	"golang.org/x/xerrors"
@@ -14,10 +15,30 @@ import (
 	"tailscale.com/types/key"
 )
 
+// Coordinator exchanges nodes with agents to establish connections.
+// ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐
+// │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│
+// └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘
+// Implementations provide different guarantees for high-availability support.
+type Coordinator interface {
+	// Node returns an in-memory node by ID.
+	Node(id uuid.UUID) *Node
+	// ServeClient accepts a WebSocket connection that wants to connect to an agent
+	// with the specified ID.
+	ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error
+	// ServeAgent accepts a WebSocket connection to an agent that listens to
+	// incoming connections and publishes node updates.
+	ServeAgent(conn net.Conn, id uuid.UUID) error
+	// Close closes the coordinator.
+	Close() error
+}
+
 // Node represents a node in the network.
 type Node struct {
 	// ID is used to identify the connection.
 	ID tailcfg.NodeID `json:"id"`
+	// AsOf is the time the node was last updated.
+	AsOf time.Time `json:"as_of"`
 	// Key is the Wireguard public key of the node.
 	Key key.NodePublic `json:"key"`
 	// DiscoKey is used for discovery messages over DERP to establish peer-to-peer connections.
@@ -75,48 +96,59 @@ func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func
 	}, errChan
 }
 
-// NewCoordinator constructs a new in-memory connection coordinator.
-func NewCoordinator() *Coordinator {
-	return &Coordinator{
+// NewCoordinator constructs a new in-memory connection coordinator. This
+// coordinator is incompatible with multiple Coder replicas, as all node data is
+// in-memory.
+func NewCoordinator() Coordinator {
+	return &coordinator{
+		closed:                   false,
 		nodes:                    map[uuid.UUID]*Node{},
 		agentSockets:             map[uuid.UUID]net.Conn{},
 		agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
 	}
 }
 
-// Coordinator exchanges nodes with agents to establish connections.
+// coordinator exchanges nodes with agents to establish connections entirely in-memory.
+// The Enterprise implementation provides an equivalent with high-availability support.
 // ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐
 // │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│
 // └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘
 // This coordinator is incompatible with multiple Coder
 // replicas as all node data is in-memory.
-type Coordinator struct {
-	mutex sync.Mutex
+type coordinator struct {
+	mutex  sync.Mutex
+	closed bool
 
-	// Maps agent and connection IDs to a node.
+	// nodes maps agent and connection IDs to their respective node.
 	nodes map[uuid.UUID]*Node
-	// Maps agent ID to an open socket.
+	// agentSockets maps agent IDs to their open WebSocket.
 	agentSockets map[uuid.UUID]net.Conn
-	// Maps agent ID to connection ID for sending
-	// new node data as it comes in!
+	// agentToConnectionSockets maps agent IDs to connection IDs of conns that
+	// are subscribed to updates for that agent.
 	agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn
 }
 
 // Node returns an in-memory node by ID.
-func (c *Coordinator) Node(id uuid.UUID) *Node {
+// If the node does not exist, nil is returned.
+func (c *coordinator) Node(id uuid.UUID) *Node {
 	c.mutex.Lock()
 	defer c.mutex.Unlock()
-	node := c.nodes[id]
-	return node
+	return c.nodes[id]
 }
 
-// ServeClient accepts a WebSocket connection that wants to
-// connect to an agent with the specified ID.
-func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
+// ServeClient accepts a WebSocket connection that wants to connect to an agent
+// with the specified ID.
+func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
 	c.mutex.Lock()
+	if c.closed {
+		c.mutex.Unlock()
+		return xerrors.New("coordinator is closed")
+	}
+
 	// When a new connection is requested, we update it with the latest
 	// node of the agent. This allows the connection to establish.
 	node, ok := c.nodes[agent]
+	c.mutex.Unlock()
 	if ok {
 		data, err := json.Marshal([]*Node{node})
 		if err != nil {
@@ -129,6 +161,7 @@ func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID)
 			return xerrors.Errorf("write nodes: %w", err)
 		}
 	}
+	c.mutex.Lock()
 	connectionSockets, ok := c.agentToConnectionSockets[agent]
 	if !ok {
 		connectionSockets = map[uuid.UUID]net.Conn{}
@@ -156,47 +189,62 @@ func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID)
 
 	decoder := json.NewDecoder(conn)
 	for {
-		var node Node
-		err := decoder.Decode(&node)
-		if errors.Is(err, io.EOF) {
-			return nil
-		}
+		err := c.handleNextClientMessage(id, agent, decoder)
 		if err != nil {
-			return xerrors.Errorf("read json: %w", err)
-		}
-		c.mutex.Lock()
-		// Update the node of this client in our in-memory map.
-		// If an agent entirely shuts down and reconnects, it
-		// needs to be aware of all clients attempting to
-		// establish connections.
-		c.nodes[id] = &node
-		agentSocket, ok := c.agentSockets[agent]
-		if !ok {
-			c.mutex.Unlock()
-			continue
-		}
-		c.mutex.Unlock()
-		// Write the new node from this client to the actively
-		// connected agent.
-		data, err := json.Marshal([]*Node{&node})
-		if err != nil {
-			c.mutex.Unlock()
-			return xerrors.Errorf("marshal nodes: %w", err)
-		}
-		_, err = agentSocket.Write(data)
-		if errors.Is(err, io.EOF) {
-			return nil
-		}
-		if err != nil {
-			return xerrors.Errorf("write json: %w", err)
+			if errors.Is(err, io.EOF) {
+				return nil
+			}
+			return xerrors.Errorf("handle next client message: %w", err)
 		}
 	}
 }
 
+func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
+	var node Node
+	err := decoder.Decode(&node)
+	if err != nil {
+		return xerrors.Errorf("read json: %w", err)
+	}
+
+	c.mutex.Lock()
+	// Update the node of this client in our in-memory map. If an agent entirely
+	// shuts down and reconnects, it needs to be aware of all clients attempting
+	// to establish connections.
+	c.nodes[id] = &node
+
+	agentSocket, ok := c.agentSockets[agent]
+	if !ok {
+		c.mutex.Unlock()
+		return nil
+	}
+	c.mutex.Unlock()
+
+	// Write the new node from this client to the actively connected agent.
+	data, err := json.Marshal([]*Node{&node})
+	if err != nil {
+		return xerrors.Errorf("marshal nodes: %w", err)
+	}
+
+	_, err = agentSocket.Write(data)
+	if err != nil {
+		if errors.Is(err, io.EOF) {
+			return nil
+		}
+		return xerrors.Errorf("write json: %w", err)
+	}
+
+	return nil
+}
+
 // ServeAgent accepts a WebSocket connection to an agent that
 // listens to incoming connections and publishes node updates.
-func (c *Coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
+func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
 	c.mutex.Lock()
+	if c.closed {
+		c.mutex.Unlock()
+		return xerrors.New("coordinator is closed")
+	}
+
 	sockets, ok := c.agentToConnectionSockets[id]
 	if ok {
 		// Publish all nodes that want to connect to the
@@ -209,16 +257,16 @@ func (c *Coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
 			}
 			nodes = append(nodes, node)
 		}
+		c.mutex.Unlock()
 		data, err := json.Marshal(nodes)
 		if err != nil {
-			c.mutex.Unlock()
 			return xerrors.Errorf("marshal json: %w", err)
 		}
 		_, err = conn.Write(data)
 		if err != nil {
-			c.mutex.Unlock()
 			return xerrors.Errorf("write nodes: %w", err)
 		}
+		c.mutex.Lock()
 	}
 
 	// If an old agent socket is connected, we close it
@@ -239,36 +287,84 @@
 	decoder := json.NewDecoder(conn)
 	for {
-		var node Node
-		err := decoder.Decode(&node)
-		if errors.Is(err, io.EOF) {
-			return nil
-		}
+		err := c.handleNextAgentMessage(id, decoder)
 		if err != nil {
-			return xerrors.Errorf("read json: %w", err)
+			if errors.Is(err, io.EOF) {
+				return nil
+			}
+			return xerrors.Errorf("handle next agent message: %w", err)
 		}
-		c.mutex.Lock()
-		c.nodes[id] = &node
-		connectionSockets, ok := c.agentToConnectionSockets[id]
-		if !ok {
-			c.mutex.Unlock()
-			continue
-		}
-		data, err := json.Marshal([]*Node{&node})
-		if err != nil {
-			return xerrors.Errorf("marshal nodes: %w", err)
-		}
-		// Publish the new node to every listening socket.
-		var wg sync.WaitGroup
-		wg.Add(len(connectionSockets))
-		for _, connectionSocket := range connectionSockets {
-			connectionSocket := connectionSocket
+	}
+}
+
+func (c *coordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.Decoder) error {
+	var node Node
+	err := decoder.Decode(&node)
+	if err != nil {
+		return xerrors.Errorf("read json: %w", err)
+	}
+
+	c.mutex.Lock()
+	c.nodes[id] = &node
+	connectionSockets, ok := c.agentToConnectionSockets[id]
+	if !ok {
+		c.mutex.Unlock()
+		return nil
+	}
+	data, err := json.Marshal([]*Node{&node})
+	if err != nil {
+		c.mutex.Unlock()
+		return xerrors.Errorf("marshal nodes: %w", err)
+	}
+
+	// Publish the new node to every listening socket.
+	var wg sync.WaitGroup
+	wg.Add(len(connectionSockets))
+	for _, connectionSocket := range connectionSockets {
+		connectionSocket := connectionSocket
+		go func() {
+			_ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second))
+			_, _ = connectionSocket.Write(data)
+			wg.Done()
+		}()
+	}
+
+	c.mutex.Unlock()
+	wg.Wait()
+	return nil
+}
+
+// Close closes all of the open connections in the coordinator and stops the
+// coordinator from accepting new connections.
+func (c *coordinator) Close() error {
+	c.mutex.Lock()
+	if c.closed {
+		c.mutex.Unlock()
+		return nil
+	}
+	c.closed = true
+	c.mutex.Unlock()
+
+	wg := sync.WaitGroup{}
+
+	wg.Add(len(c.agentSockets))
+	for _, socket := range c.agentSockets {
+		socket := socket
+		go func() {
+			_ = socket.Close()
+			wg.Done()
+		}()
+	}
+
+	for _, connMap := range c.agentToConnectionSockets {
+		wg.Add(len(connMap))
+		for _, socket := range connMap {
+			socket := socket
 			go func() {
-				_, _ = connectionSocket.Write(data)
+				_ = socket.Close()
 				wg.Done()
 			}()
 		}
-		c.mutex.Unlock()
-		wg.Wait()
 	}
+
+	wg.Wait()
+	return nil
 }
diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go
index f3fdab88d5..a4a020dead 100644
--- a/tailnet/coordinator_test.go
+++ b/tailnet/coordinator_test.go
@@ -32,8 +32,8 @@ func TestCoordinator(t *testing.T) {
 		require.Eventually(t, func() bool {
 			return coordinator.Node(id) != nil
 		}, testutil.WaitShort, testutil.IntervalFast)
-		err := client.Close()
-		require.NoError(t, err)
+		require.NoError(t, client.Close())
+		require.NoError(t, server.Close())
 		<-errChan
 		<-closeChan
 	})
diff --git a/testutil/certificate.go b/testutil/certificate.go
new file mode 100644
index 0000000000..1edc975746
--- /dev/null
+++ b/testutil/certificate.go
@@ -0,0 +1,53 @@
+package testutil
+
+import (
+	"bytes"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"crypto/tls"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"encoding/pem"
+	"math/big"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+func GenerateTLSCertificate(t testing.TB, commonName string) tls.Certificate {
+	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	require.NoError(t, err)
+	template := x509.Certificate{
+		SerialNumber: big.NewInt(1),
+		Subject: pkix.Name{
+			Organization: []string{"Acme Co"},
+			CommonName:   commonName,
+		},
+		DNSNames:  []string{commonName},
+		NotBefore: time.Now(),
+		NotAfter:  time.Now().Add(time.Hour * 24 * 180),
+
+		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+		BasicConstraintsValid: true,
+		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
+	}
+
+	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
+	require.NoError(t, err)
+	var certFile bytes.Buffer
+	_, err =
+		certFile.Write(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
+	require.NoError(t, err)
+	privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
+	require.NoError(t, err)
+	var keyFile bytes.Buffer
+	err = pem.Encode(&keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes})
+	require.NoError(t, err)
+	cert, err := tls.X509KeyPair(certFile.Bytes(), keyFile.Bytes())
+	require.NoError(t, err)
+	return cert
+}
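
Reviewer note: for readers new to the tailnet package, here is roughly the handshake the coordinator tests above exercise, condensed into one sketch. It is illustrative only — the test name is made up, and the assertions, error channels, and cleanup from the real tests are elided — but the coordinator and ServeCoordinator identifiers match the AGPL package in this patch.

package tailnet_test

import (
	"net"
	"testing"

	"github.com/google/uuid"

	"github.com/coder/coder/tailnet"
)

func TestHandshakeSketch(t *testing.T) {
	t.Parallel()

	// The in-memory coordinator; the Enterprise version backs this with
	// pubsub so multiple replicas can share node data.
	coordinator := tailnet.NewCoordinator()
	defer coordinator.Close()

	// Agent side: register its socket, then publish its node.
	agentWS, agentServerWS := net.Pipe()
	defer agentWS.Close()
	agentID := uuid.New()
	go func() {
		// Blocks, reading node updates from the agent and fanning
		// client nodes back into it.
		_ = coordinator.ServeAgent(agentServerWS, agentID)
	}()
	sendAgentNode, _ := tailnet.ServeCoordinator(agentWS, func(nodes []*tailnet.Node) error {
		return nil // Nodes of clients that want this agent arrive here.
	})
	sendAgentNode(&tailnet.Node{})

	// Client side: subscribing immediately yields the agent's last known
	// node, then streams updates as they are published.
	clientWS, clientServerWS := net.Pipe()
	defer clientWS.Close()
	go func() {
		_ = coordinator.ServeClient(clientServerWS, uuid.New(), agentID)
	}()
	sendClientNode, _ := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
		return nil // The agent's node (and later updates) arrive here.
	})
	sendClientNode(&tailnet.Node{})
}

One design note worth calling out: the five-second SetWriteDeadline added in handleNextAgentMessage keeps a single slow client socket from blocking the fan-out to every other listener.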
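Similarly, a hypothetical caller for the new testutil.GenerateTLSCertificate helper. The httptest wiring below is illustrative and not part of this patch; it is just one plausible way a TLS test (such as the DERP mesh TLS tests this helper was added for) might consume the certificate.

package testutil_test

import (
	"crypto/tls"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/coder/coder/testutil"
)

func TestServeTLSSketch(t *testing.T) {
	t.Parallel()

	// Self-signed certificate valid for the "localhost" DNS name and
	// 127.0.0.1, matching how httptest servers are addressed.
	cert := testutil.GenerateTLSCertificate(t, "localhost")

	srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	// StartTLS uses the certificates in srv.TLS when provided.
	srv.TLS = &tls.Config{
		Certificates: []tls.Certificate{cert},
		MinVersion:   tls.VersionTLS12,
	}
	srv.StartTLS()
	defer srv.Close()
}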