feat: Add high availability for multiple replicas (#4555)

* feat: HA tailnet coordinator

* fixup! feat: HA tailnet coordinator

* fixup! feat: HA tailnet coordinator

* remove printlns

* close all connections on coordinator

* implement high availability feature

* fixup! implement high availability feature

* fixup! implement high availability feature

* fixup! implement high availability feature

* fixup! implement high availability feature

* Add replicas

* Add DERP meshing to arbitrary addresses

* Move packages to highavailability folder

* Move coordinator to high availability package

* Add flags for HA

* Rename to replicasync

* Denest packages for replicas

* Add test for multiple replicas

* Fix coordination test

* Add HA to the helm chart

* Rename function pointer

* Add warnings for HA

* Add the ability to block endpoints

* Add flag to disable P2P connections

* Wow, I made the tests pass

* Add replicas endpoint

* Ensure close kills replica

* Update sql

* Add database latency to high availability

* Pipe TLS to DERP mesh

* Fix DERP mesh with TLS

* Add tests for TLS

* Fix replica sync TLS

* Fix RootCA for replica meshing

* Remove ID from replicasync

* Fix getting certificates for meshing

* Remove excessive locking

* Fix linting

* Store mesh key in the database

* Fix replica key for tests

* Fix types gen

* Fix unlocking unlocked

* Fix race in tests

* Update enterprise/derpmesh/derpmesh.go

Co-authored-by: Colin Adler <colin1adler@gmail.com>

* Rename to syncReplicas

* Reuse http client

* Delete old replicas on a CRON

* Fix race condition in connection tests

* Fix linting

* Fix nil type

* Move pubsub to in-memory for twenty tests

* Add comment for configuration tweaking

* Fix leak with transport

* Fix close leak in derpmesh

* Fix race when creating server

* Remove handler update

* Skip test on Windows

* Fix DERP mesh test

* Wrap HTTP handler replacement in mutex

* Fix error message for relay

* Fix API handler for normal tests

* Fix speedtest

* Fix replica resend

* Fix derpmesh send

* Ping async

* Increase wait time of template version job

* Fix race when closing replica sync

* Add name to client

* Log the derpmap being used

* Don't connect if DERP is empty

* Improve agent coordinator logging

* Fix lock in coordinator

* Fix relay addr

* Fix race when updating durations

* Fix client publish race

* Run pubsub loop in a queue

* Store agent nodes in order

* Fix coordinator locking

* Check for closed pipe

Co-authored-by: Colin Adler <colin1adler@gmail.com>
Kyle Carberry, 2022-10-17 08:43:30 -05:00, committed by GitHub
parent dc3519e973
commit 2ba4a62a0d
76 changed files with 3437 additions and 404 deletions

View File

@ -19,6 +19,7 @@
"derphttp",
"derpmap",
"devel",
"dflags",
"drpc",
"drpcconn",
"drpcmux",
@ -86,8 +87,10 @@
"ptytest",
"quickstart",
"reconfig",
"replicasync",
"retrier",
"rpty",
"SCIM",
"sdkproto",
"sdktrace",
"Signup",

View File

@ -170,6 +170,7 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
if a.isClosed() {
return
}
a.logger.Debug(ctx, "running tailnet with derpmap", slog.F("derpmap", derpMap))
if a.network != nil {
a.network.SetDERPMap(derpMap)
return

View File

@ -465,7 +465,7 @@ func TestAgent(t *testing.T) {
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
require.Eventually(t, func() bool {
_, err := conn.Ping()
_, err := conn.Ping(context.Background())
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
@ -483,9 +483,7 @@ func TestAgent(t *testing.T) {
t.Run("Speedtest", func(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skip("The minimum duration for a speedtest is hardcoded in Tailscale to 5s!")
}
t.Skip("This test is relatively flakey because of Tailscale's speedtest code...")
derpMap := tailnettest.RunDERPAndSTUN(t)
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{
DERPMap: derpMap,

View File

@ -7,8 +7,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/provisioner/echo"
@ -67,11 +65,11 @@ func TestWorkspaceAgent(t *testing.T) {
if assert.NotEmpty(t, workspace.LatestBuild.Resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.Eventually(t, func() bool {
_, err := dialer.Ping()
_, err := dialer.Ping(ctx)
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
cancelFunc()
@ -128,11 +126,11 @@ func TestWorkspaceAgent(t *testing.T) {
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.Eventually(t, func() bool {
_, err := dialer.Ping()
_, err := dialer.Ping(ctx)
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
cancelFunc()
@ -189,11 +187,11 @@ func TestWorkspaceAgent(t *testing.T) {
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.Eventually(t, func() bool {
_, err := dialer.Ping()
_, err := dialer.Ping(ctx)
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
cancelFunc()

View File

@ -13,6 +13,11 @@ func (r Root) Session() File {
return File(filepath.Join(string(r), "session"))
}
// ReplicaID is a unique identifier for the Coder server.
func (r Root) ReplicaID() File {
return File(filepath.Join(string(r), "replica_id"))
}
func (r Root) URL() File {
return File(filepath.Join(string(r), "url"))
}
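
For context, a minimal sketch of how this file could back a persistent replica identity: read the stored ID, or generate and persist one on first run. It assumes the Read/Write helpers used by the other config files; the get-or-create flow here is illustrative, not the exact server logic.

package replicaid

import (
    "github.com/google/uuid"

    "github.com/coder/coder/cli/config"
)

// loadReplicaID returns the replica ID stored on disk, creating one on
// first run so a persistent replica reuses the same identity across
// restarts. Ephemeral replicas would skip the file and use a fresh UUID.
func loadReplicaID(root config.Root) (uuid.UUID, error) {
    raw, err := root.ReplicaID().Read()
    if err == nil {
        return uuid.Parse(raw)
    }
    // No ID on disk yet: generate and persist a fresh one.
    id := uuid.New()
    if err := root.ReplicaID().Write(id.String()); err != nil {
        return uuid.Nil, err
    }
    return id, nil
}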

View File

@ -19,7 +19,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
@ -115,7 +114,7 @@ func TestConfigSSH(t *testing.T) {
_ = agentCloser.Close()
}()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
agentConn, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, resources[0].Agents[0].ID)
agentConn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer agentConn.Close()

View File

@ -85,6 +85,13 @@ func Flags() *codersdk.DeploymentFlags {
Description: "Addresses for STUN servers to establish P2P connections. Set empty to disable P2P connections.",
Default: []string{"stun.l.google.com:19302"},
},
DerpServerRelayAddress: &codersdk.StringFlag{
Name: "DERP Server Relay Address",
Flag: "derp-server-relay-address",
EnvVar: "CODER_DERP_SERVER_RELAY_ADDRESS",
Description: "An HTTP address that is accessible by other replicas to relay DERP traffic. Required for high availability.",
Enterprise: true,
},
DerpConfigURL: &codersdk.StringFlag{
Name: "DERP Config URL",
Flag: "derp-config-url",

View File

@ -16,7 +16,6 @@ import (
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/agent"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui"
@ -96,7 +95,7 @@ func portForward() *cobra.Command {
return xerrors.Errorf("await agent: %w", err)
}
conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID)
conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
if err != nil {
return err
}
@ -156,7 +155,7 @@ func portForward() *cobra.Command {
case <-ticker.C:
}
_, err = conn.Ping()
_, err = conn.Ping(ctx)
if err != nil {
continue
}

View File

@ -4,6 +4,7 @@ import (
"context"
"flag"
"fmt"
"io"
"net/http"
"net/url"
"os"
@ -100,8 +101,9 @@ func Core() []*cobra.Command {
}
func AGPL() []*cobra.Command {
all := append(Core(), Server(deployment.Flags(), func(_ context.Context, o *coderd.Options) (*coderd.API, error) {
return coderd.New(o), nil
all := append(Core(), Server(deployment.Flags(), func(_ context.Context, o *coderd.Options) (*coderd.API, io.Closer, error) {
api := coderd.New(o)
return api, api, nil
}))
return all
}

View File

@ -69,7 +69,7 @@ import (
)
// nolint:gocyclo
func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *coderd.Options) (*coderd.API, error)) *cobra.Command {
func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *cobra.Command {
root := &cobra.Command{
Use: "server",
Short: "Start a Coder server",
@ -167,9 +167,10 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
}
defer listener.Close()
var tlsConfig *tls.Config
if dflags.TLSEnable.Value {
listener, err = configureServerTLS(
listener, dflags.TLSMinVersion.Value,
tlsConfig, err = configureTLS(
dflags.TLSMinVersion.Value,
dflags.TLSClientAuth.Value,
dflags.TLSCertFiles.Value,
dflags.TLSKeyFiles.Value,
@ -178,6 +179,7 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
if err != nil {
return xerrors.Errorf("configure tls: %w", err)
}
listener = tls.NewListener(listener, tlsConfig)
}
tcpAddr, valid := listener.Addr().(*net.TCPAddr)
@ -328,6 +330,9 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
Experimental: ExperimentalEnabled(cmd),
DeploymentFlags: dflags,
}
if tlsConfig != nil {
options.TLSCertificates = tlsConfig.Certificates
}
if dflags.OAuth2GithubClientSecret.Value != "" {
options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed,
@ -471,11 +476,14 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
), dflags.PromAddress.Value, "prometheus")()
}
coderAPI, err := newAPI(ctx, options)
// We use a separate closer so the Enterprise API
// can have its own close functions. This is cleaner
// than abstracting the Coder API itself.
coderAPI, closer, err := newAPI(ctx, options)
if err != nil {
return err
}
defer coderAPI.Close()
defer closer.Close()
client := codersdk.New(localURL)
if dflags.TLSEnable.Value {
@ -893,7 +901,7 @@ func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, er
return certs, nil
}
func configureServerTLS(listener net.Listener, tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (net.Listener, error) {
func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (*tls.Config, error) {
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
}
@ -929,6 +937,7 @@ func configureServerTLS(listener net.Listener, tlsMinVersion, tlsClientAuth stri
if err != nil {
return nil, xerrors.Errorf("load certificates: %w", err)
}
tlsConfig.Certificates = certs
tlsConfig.GetCertificate = func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
// If there's only one certificate, return it.
if len(certs) == 1 {
@ -963,7 +972,7 @@ func configureServerTLS(listener net.Listener, tlsMinVersion, tlsClientAuth stri
tlsConfig.ClientCAs = caPool
}
return tls.NewListener(listener, tlsConfig), nil
return tlsConfig, nil
}
func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups bool, allowOrgs []string, rawTeams []string, enterpriseBaseURL string) (*coderd.GithubOAuth2Config, error) {

View File

@ -55,7 +55,9 @@ func speedtest() *cobra.Command {
if cliflag.IsSetBool(cmd, varVerbose) {
logger = logger.Leveled(slog.LevelDebug)
}
conn, err := client.DialWorkspaceAgentTailnet(ctx, logger, workspaceAgent.ID)
conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, &codersdk.DialWorkspaceAgentOptions{
Logger: logger,
})
if err != nil {
return err
}
@ -68,7 +70,7 @@ func speedtest() *cobra.Command {
return ctx.Err()
case <-ticker.C:
}
dur, err := conn.Ping()
dur, err := conn.Ping(ctx)
if err != nil {
continue
}

View File

@ -20,8 +20,6 @@ import (
"golang.org/x/term"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/autobuild/notify"
@ -86,7 +84,7 @@ func ssh() *cobra.Command {
return xerrors.Errorf("await agent: %w", err)
}
conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID)
conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
if err != nil {
return err
}

View File

@ -72,7 +72,7 @@ func TestWorkspaceActivityBump(t *testing.T) {
"deadline %v never updated", firstDeadline,
)
require.WithinDuration(t, database.Now().Add(time.Hour), workspace.LatestBuild.Deadline.Time, time.Second)
require.WithinDuration(t, database.Now().Add(time.Hour), workspace.LatestBuild.Deadline.Time, 3*time.Second)
}
}
@ -82,7 +82,9 @@ func TestWorkspaceActivityBump(t *testing.T) {
client, workspace, assertBumped := setupActivityTest(t)
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil), resources[0].Agents[0].ID)
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{
Logger: slogtest.Make(t, nil),
})
require.NoError(t, err)
defer conn.Close()

View File

@ -1,6 +1,7 @@
package coderd
import (
"crypto/tls"
"crypto/x509"
"io"
"net/http"
@ -82,7 +83,10 @@ type Options struct {
TracerProvider trace.TracerProvider
AutoImportTemplates []AutoImportTemplate
TailnetCoordinator *tailnet.Coordinator
// TLSCertificates is used to mesh DERP servers securely.
TLSCertificates []tls.Certificate
TailnetCoordinator tailnet.Coordinator
DERPServer *derp.Server
DERPMap *tailcfg.DERPMap
MetricsCacheRefreshInterval time.Duration
@ -130,6 +134,9 @@ func New(options *Options) *API {
if options.TailnetCoordinator == nil {
options.TailnetCoordinator = tailnet.NewCoordinator()
}
if options.DERPServer == nil {
options.DERPServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp")))
}
if options.Auditor == nil {
options.Auditor = audit.NewNop()
}
@ -168,7 +175,7 @@ func New(options *Options) *API {
api.Auditor.Store(&options.Auditor)
api.WorkspaceQuotaEnforcer.Store(&options.WorkspaceQuotaEnforcer)
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger))
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
oauthConfigs := &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
OIDC: options.OIDCConfig,
@ -246,7 +253,7 @@ func New(options *Options) *API {
r.Route("/%40{user}/{workspace_and_agent}/apps/{workspaceapp}", apps)
r.Route("/@{user}/{workspace_and_agent}/apps/{workspaceapp}", apps)
r.Route("/derp", func(r chi.Router) {
r.Get("/", derphttp.Handler(api.derpServer).ServeHTTP)
r.Get("/", derphttp.Handler(api.DERPServer).ServeHTTP)
// This is used when UDP is blocked, and latency must be checked via HTTP(s).
r.Get("/latency-check", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
@ -550,6 +557,7 @@ type API struct {
Auditor atomic.Pointer[audit.Auditor]
WorkspaceClientCoordinateOverride atomic.Pointer[func(rw http.ResponseWriter) bool]
WorkspaceQuotaEnforcer atomic.Pointer[workspacequota.Enforcer]
TailnetCoordinator atomic.Pointer[tailnet.Coordinator]
HTTPAuth *HTTPAuthorizer
// APIHandler serves "/api/v2"
@ -557,7 +565,6 @@ type API struct {
// RootHandler serves "/"
RootHandler chi.Router
derpServer *derp.Server
metricsCache *metricscache.Cache
siteHandler http.Handler
websocketWaitMutex sync.Mutex
@ -572,7 +579,10 @@ func (api *API) Close() error {
api.websocketWaitMutex.Unlock()
api.metricsCache.Close()
coordinator := api.TailnetCoordinator.Load()
if coordinator != nil {
_ = (*coordinator).Close()
}
return api.workspaceAgentCache.Close()
}
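
To make the coordinator hot-swap concrete, here is a minimal sketch of the atomic.Pointer pattern used above, assuming tailnet.Coordinator is the interface form this change introduces (and that tailnet.NewCoordinator returns it, as the diff shows). Storing a pointer to the interface lets enterprise code replace the coordinator at runtime, for example with an HA-aware implementation, without putting a lock on every request path.

package main

import (
    "net"
    "sync/atomic"

    "github.com/google/uuid"

    "github.com/coder/coder/tailnet"
)

func main() {
    var coordinator atomic.Pointer[tailnet.Coordinator]

    // Install the default in-memory coordinator at startup.
    c := tailnet.NewCoordinator()
    coordinator.Store(&c)

    // Request paths dereference whichever implementation is currently
    // installed; enterprise code can Store a replacement at any time.
    serveClient := func(conn net.Conn, agentID uuid.UUID) error {
        return (*coordinator.Load()).ServeClient(conn, uuid.New(), agentID)
    }
    _ = serveClient
}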

View File

@ -7,6 +7,7 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
@ -23,6 +24,7 @@ import (
"regexp"
"strconv"
"strings"
"sync"
"testing"
"time"
@ -37,8 +39,10 @@ import (
"golang.org/x/xerrors"
"google.golang.org/api/idtoken"
"google.golang.org/api/option"
"tailscale.com/derp"
"tailscale.com/net/stun/stuntest"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/nettype"
"cdr.dev/slog"
@ -60,6 +64,7 @@ import (
"github.com/coder/coder/provisionerd"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/testutil"
)
@ -77,12 +82,19 @@ type Options struct {
AutobuildTicker <-chan time.Time
AutobuildStats chan<- executor.Stats
Auditor audit.Auditor
TLSCertificates []tls.Certificate
// IncludeProvisionerDaemon when true means to start an in-memory provisionerD
IncludeProvisionerDaemon bool
MetricsCacheRefreshInterval time.Duration
AgentStatsRefreshInterval time.Duration
DeploymentFlags *codersdk.DeploymentFlags
// Overriding the database is heavily discouraged.
// It should only be used in cases where multiple Coder
// test instances are running against the same database.
Database database.Store
Pubsub database.Pubsub
}
// New constructs a codersdk client connected to an in-memory API instance.
@ -116,7 +128,7 @@ func newWithCloser(t *testing.T, options *Options) (*codersdk.Client, io.Closer)
return client, closer
}
func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.CancelFunc, *coderd.Options) {
func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.CancelFunc, *coderd.Options) {
if options == nil {
options = &Options{}
}
@ -137,23 +149,40 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.Cance
close(options.AutobuildStats)
})
}
db, pubsub := dbtestutil.NewDB(t)
if options.Database == nil {
options.Database, options.Pubsub = dbtestutil.NewDB(t)
}
ctx, cancelFunc := context.WithCancel(context.Background())
lifecycleExecutor := executor.New(
ctx,
db,
options.Database,
slogtest.Make(t, nil).Named("autobuild.executor").Leveled(slog.LevelDebug),
options.AutobuildTicker,
).WithStatsChannel(options.AutobuildStats)
lifecycleExecutor.Run()
srv := httptest.NewUnstartedServer(nil)
var mutex sync.RWMutex
var handler http.Handler
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mutex.RLock()
defer mutex.RUnlock()
if handler != nil {
handler.ServeHTTP(w, r)
}
}))
srv.Config.BaseContext = func(_ net.Listener) context.Context {
return ctx
}
srv.Start()
if options.TLSCertificates != nil {
srv.TLS = &tls.Config{
Certificates: options.TLSCertificates,
MinVersion: tls.VersionTLS12,
}
srv.StartTLS()
} else {
srv.Start()
}
t.Cleanup(srv.Close)
tcpAddr, ok := srv.Listener.Addr().(*net.TCPAddr)
@ -169,6 +198,9 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.Cance
stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, nettype.Std{})
t.Cleanup(stunCleanup)
derpServer := derp.NewServer(key.NewNode(), tailnet.Logger(slogtest.Make(t, nil).Named("derp")))
derpServer.SetMeshKey("test-key")
// Match the CLI default.
if options.SSHKeygenAlgorithm == "" {
options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519
@ -181,53 +213,59 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.Cance
require.NoError(t, err)
}
return srv, cancelFunc, &coderd.Options{
AgentConnectionUpdateFrequency: 150 * time.Millisecond,
// Force a long disconnection timeout to ensure
// agents are not marked as disconnected during slow tests.
AgentInactiveDisconnectTimeout: testutil.WaitShort,
AccessURL: serverURL,
AppHostname: options.AppHostname,
AppHostnameRegex: appHostnameRegex,
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
CacheDir: t.TempDir(),
Database: db,
Pubsub: pubsub,
return func(h http.Handler) {
mutex.Lock()
defer mutex.Unlock()
handler = h
}, cancelFunc, &coderd.Options{
AgentConnectionUpdateFrequency: 150 * time.Millisecond,
// Force a long disconnection timeout to ensure
// agents are not marked as disconnected during slow tests.
AgentInactiveDisconnectTimeout: testutil.WaitShort,
AccessURL: serverURL,
AppHostname: options.AppHostname,
AppHostnameRegex: appHostnameRegex,
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
CacheDir: t.TempDir(),
Database: options.Database,
Pubsub: options.Pubsub,
Auditor: options.Auditor,
AWSCertificates: options.AWSCertificates,
AzureCertificates: options.AzureCertificates,
GithubOAuth2Config: options.GithubOAuth2Config,
OIDCConfig: options.OIDCConfig,
GoogleTokenValidator: options.GoogleTokenValidator,
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
APIRateLimit: options.APIRateLimit,
Authorizer: options.Authorizer,
Telemetry: telemetry.NewNoop(),
DERPMap: &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
1: {
EmbeddedRelay: true,
RegionID: 1,
RegionCode: "coder",
RegionName: "Coder",
Nodes: []*tailcfg.DERPNode{{
Name: "1a",
RegionID: 1,
IPv4: "127.0.0.1",
DERPPort: derpPort,
STUNPort: stunAddr.Port,
InsecureForTests: true,
ForceHTTP: true,
}},
Auditor: options.Auditor,
AWSCertificates: options.AWSCertificates,
AzureCertificates: options.AzureCertificates,
GithubOAuth2Config: options.GithubOAuth2Config,
OIDCConfig: options.OIDCConfig,
GoogleTokenValidator: options.GoogleTokenValidator,
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
DERPServer: derpServer,
APIRateLimit: options.APIRateLimit,
Authorizer: options.Authorizer,
Telemetry: telemetry.NewNoop(),
TLSCertificates: options.TLSCertificates,
DERPMap: &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
1: {
EmbeddedRelay: true,
RegionID: 1,
RegionCode: "coder",
RegionName: "Coder",
Nodes: []*tailcfg.DERPNode{{
Name: "1a",
RegionID: 1,
IPv4: "127.0.0.1",
DERPPort: derpPort,
STUNPort: stunAddr.Port,
InsecureForTests: true,
ForceHTTP: options.TLSCertificates == nil,
}},
},
},
},
},
AutoImportTemplates: options.AutoImportTemplates,
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
DeploymentFlags: options.DeploymentFlags,
}
AutoImportTemplates: options.AutoImportTemplates,
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
DeploymentFlags: options.DeploymentFlags,
}
}
// NewWithAPI constructs an in-memory API instance and returns a client to talk to it.
@ -237,10 +275,10 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
if options == nil {
options = &Options{}
}
srv, cancelFunc, newOptions := NewOptions(t, options)
setHandler, cancelFunc, newOptions := NewOptions(t, options)
// We set the handler after server creation for the access URL.
coderAPI := coderd.New(newOptions)
srv.Config.Handler = coderAPI.RootHandler
setHandler(coderAPI.RootHandler)
var provisionerCloser io.Closer = nopcloser{}
if options.IncludeProvisionerDaemon {
provisionerCloser = NewProvisionerDaemon(t, coderAPI)
@ -459,7 +497,7 @@ func AwaitTemplateVersionJob(t *testing.T, client *codersdk.Client, version uuid
var err error
templateVersion, err = client.TemplateVersion(context.Background(), version)
return assert.NoError(t, err) && templateVersion.Job.CompletedAt != nil
}, testutil.WaitShort, testutil.IntervalFast)
}, testutil.WaitMedium, testutil.IntervalFast)
return templateVersion
}
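
The handler indirection above exists because the access URL must be known, and therefore the test server must already be running, before coderd.New can build the real handler. A standalone sketch of the same pattern (simplified to a plain httptest.NewServer, without the TLS and BaseContext handling shown above):

package testserver

import (
    "net/http"
    "net/http/httptest"
    "sync"
)

// newDeferredServer starts a test server before its real handler exists, so
// the server URL can be fed into whatever constructs the handler. The
// returned setter installs the handler once it is ready; an RWMutex makes
// the swap safe against in-flight requests.
func newDeferredServer() (*httptest.Server, func(http.Handler)) {
    var (
        mu      sync.RWMutex
        handler http.Handler
    )
    srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        mu.RLock()
        defer mu.RUnlock()
        if handler == nil {
            http.Error(w, "handler not ready", http.StatusServiceUnavailable)
            return
        }
        handler.ServeHTTP(w, r)
    }))
    return srv, func(h http.Handler) {
        mu.Lock()
        defer mu.Unlock()
        handler = h
    }
}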

View File

@ -107,11 +107,17 @@ type data struct {
workspaceApps []database.WorkspaceApp
workspaces []database.Workspace
licenses []database.License
replicas []database.Replica
deploymentID string
derpMeshKey string
lastLicenseID int32
}
func (*fakeQuerier) Ping(_ context.Context) (time.Duration, error) {
return 0, nil
}
// InTx doesn't rollback data properly for in-memory yet.
func (q *fakeQuerier) InTx(fn func(database.Store) error) error {
q.mutex.Lock()
@ -2931,6 +2937,21 @@ func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) {
return q.deploymentID, nil
}
func (q *fakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error {
q.mutex.Lock()
defer q.mutex.Unlock()
q.derpMeshKey = id
return nil
}
func (q *fakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
return q.derpMeshKey, nil
}
func (q *fakeQuerier) InsertLicense(
_ context.Context, arg database.InsertLicenseParams,
) (database.License, error) {
@ -3196,3 +3217,70 @@ func (q *fakeQuerier) DeleteGroupByID(_ context.Context, id uuid.UUID) error {
return sql.ErrNoRows
}
func (q *fakeQuerier) DeleteReplicasUpdatedBefore(_ context.Context, before time.Time) error {
q.mutex.Lock()
defer q.mutex.Unlock()
// Filter into a fresh slice: removing by index while ranging over the
// same slice skips the element that shifts into the removed slot.
alive := make([]database.Replica, 0, len(q.replicas))
for _, replica := range q.replicas {
if !replica.UpdatedAt.Before(before) {
alive = append(alive, replica)
}
}
q.replicas = alive
return nil
}
func (q *fakeQuerier) InsertReplica(_ context.Context, arg database.InsertReplicaParams) (database.Replica, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
replica := database.Replica{
ID: arg.ID,
CreatedAt: arg.CreatedAt,
StartedAt: arg.StartedAt,
UpdatedAt: arg.UpdatedAt,
Hostname: arg.Hostname,
RegionID: arg.RegionID,
RelayAddress: arg.RelayAddress,
Version: arg.Version,
DatabaseLatency: arg.DatabaseLatency,
}
q.replicas = append(q.replicas, replica)
return replica, nil
}
func (q *fakeQuerier) UpdateReplica(_ context.Context, arg database.UpdateReplicaParams) (database.Replica, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
for index, replica := range q.replicas {
if replica.ID != arg.ID {
continue
}
replica.Hostname = arg.Hostname
replica.StartedAt = arg.StartedAt
replica.StoppedAt = arg.StoppedAt
replica.UpdatedAt = arg.UpdatedAt
replica.RelayAddress = arg.RelayAddress
replica.RegionID = arg.RegionID
replica.Version = arg.Version
replica.Error = arg.Error
replica.DatabaseLatency = arg.DatabaseLatency
q.replicas[index] = replica
return replica, nil
}
return database.Replica{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetReplicasUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.Replica, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
replicas := make([]database.Replica, 0)
for _, replica := range q.replicas {
if replica.UpdatedAt.After(updatedAt) && !replica.StoppedAt.Valid {
replicas = append(replicas, replica)
}
}
return replicas, nil
}

View File

@ -12,6 +12,7 @@ import (
"context"
"database/sql"
"errors"
"time"
"github.com/jmoiron/sqlx"
"golang.org/x/xerrors"
@ -24,6 +25,7 @@ type Store interface {
// customQuerier contains custom queries that are not generated.
customQuerier
Ping(ctx context.Context) (time.Duration, error)
InTx(func(Store) error) error
}
@ -58,6 +60,13 @@ type sqlQuerier struct {
db DBTX
}
// Ping returns the time it takes to ping the database.
func (q *sqlQuerier) Ping(ctx context.Context) (time.Duration, error) {
start := time.Now()
err := q.sdb.PingContext(ctx)
return time.Since(start), err
}
// InTx performs database operations inside a transaction.
func (q *sqlQuerier) InTx(function func(Store) error) error {
if _, ok := q.db.(*sqlx.Tx); ok {

View File

@ -256,7 +256,8 @@ CREATE TABLE provisioner_daemons (
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone,
name character varying(64) NOT NULL,
provisioners provisioner_type[] NOT NULL
provisioners provisioner_type[] NOT NULL,
replica_id uuid
);
CREATE TABLE provisioner_job_logs (
@ -287,6 +288,20 @@ CREATE TABLE provisioner_jobs (
file_id uuid NOT NULL
);
CREATE TABLE replicas (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
started_at timestamp with time zone NOT NULL,
stopped_at timestamp with time zone,
updated_at timestamp with time zone NOT NULL,
hostname text NOT NULL,
region_id integer NOT NULL,
relay_address text NOT NULL,
database_latency integer NOT NULL,
version text NOT NULL,
error text DEFAULT ''::text NOT NULL
);
CREATE TABLE site_configs (
key character varying(256) NOT NULL,
value character varying(8192) NOT NULL

View File

@ -0,0 +1,2 @@
DROP TABLE replicas;
ALTER TABLE provisioner_daemons DROP COLUMN replica_id;

View File

@ -0,0 +1,28 @@
CREATE TABLE IF NOT EXISTS replicas (
-- A unique identifier for the replica that is stored on disk.
-- For persistent replicas, this will be reused.
-- For ephemeral replicas, this will be a new UUID for each one.
id uuid NOT NULL,
-- The time the replica was created.
created_at timestamp with time zone NOT NULL,
started_at timestamp with time zone NOT NULL,
stopped_at timestamp with time zone,
-- The time the replica was last seen.
-- Updated periodically to ensure the replica is still alive.
updated_at timestamp with time zone NOT NULL,
-- Hostname is the hostname of the replica.
hostname text NOT NULL,
-- Region is the region the replica is in.
-- We only DERP mesh to the same region ID of a running replica.
region_id integer NOT NULL,
-- An address that should be accessible to other replicas.
relay_address text NOT NULL,
-- The latency of the replica to the database in microseconds.
database_latency int NOT NULL,
-- Version is the Coder version of the replica.
version text NOT NULL,
error text NOT NULL DEFAULT ''
);
-- Associates a provisioner daemon with a replica.
ALTER TABLE provisioner_daemons ADD COLUMN replica_id uuid;

View File

@ -508,6 +508,7 @@ type ProvisionerDaemon struct {
UpdatedAt sql.NullTime `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
Provisioners []ProvisionerType `db:"provisioners" json:"provisioners"`
ReplicaID uuid.NullUUID `db:"replica_id" json:"replica_id"`
}
type ProvisionerJob struct {
@ -538,6 +539,20 @@ type ProvisionerJobLog struct {
Output string `db:"output" json:"output"`
}
type Replica struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
StartedAt time.Time `db:"started_at" json:"started_at"`
StoppedAt sql.NullTime `db:"stopped_at" json:"stopped_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Hostname string `db:"hostname" json:"hostname"`
RegionID int32 `db:"region_id" json:"region_id"`
RelayAddress string `db:"relay_address" json:"relay_address"`
DatabaseLatency int32 `db:"database_latency" json:"database_latency"`
Version string `db:"version" json:"version"`
Error string `db:"error" json:"error"`
}
type SiteConfig struct {
Key string `db:"key" json:"key"`
Value string `db:"value" json:"value"`

View File

@ -47,8 +47,9 @@ func (m *memoryPubsub) Publish(event string, message []byte) error {
return nil
}
for _, listener := range listeners {
listener(context.Background(), message)
go listener(context.Background(), message)
}
return nil
}

View File

@ -26,6 +26,7 @@ type sqlcQuerier interface {
DeleteLicense(ctx context.Context, id int32) (int32, error)
DeleteOldAgentStats(ctx context.Context) error
DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error
DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error)
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
@ -38,6 +39,7 @@ type sqlcQuerier interface {
// This function returns roles for authorization purposes. Implied member roles
// are included.
GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error)
GetDERPMeshKey(ctx context.Context) (string, error)
GetDeploymentID(ctx context.Context) (string, error)
GetFileByHashAndCreator(ctx context.Context, arg GetFileByHashAndCreatorParams) (File, error)
GetFileByID(ctx context.Context, id uuid.UUID) (File, error)
@ -67,6 +69,7 @@ type sqlcQuerier interface {
GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error)
GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]ProvisionerJob, error)
GetProvisionerLogsByIDBetween(ctx context.Context, arg GetProvisionerLogsByIDBetweenParams) ([]ProvisionerJobLog, error)
GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error)
GetTemplateAverageBuildTime(ctx context.Context, arg GetTemplateAverageBuildTimeParams) (GetTemplateAverageBuildTimeRow, error)
GetTemplateByID(ctx context.Context, id uuid.UUID) (Template, error)
GetTemplateByOrganizationAndName(ctx context.Context, arg GetTemplateByOrganizationAndNameParams) (Template, error)
@ -123,6 +126,7 @@ type sqlcQuerier interface {
// every member of the org.
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
InsertDERPMeshKey(ctx context.Context, value string) error
InsertDeploymentID(ctx context.Context, value string) error
InsertFile(ctx context.Context, arg InsertFileParams) (File, error)
InsertGitSSHKey(ctx context.Context, arg InsertGitSSHKeyParams) (GitSSHKey, error)
@ -136,6 +140,7 @@ type sqlcQuerier interface {
InsertProvisionerDaemon(ctx context.Context, arg InsertProvisionerDaemonParams) (ProvisionerDaemon, error)
InsertProvisionerJob(ctx context.Context, arg InsertProvisionerJobParams) (ProvisionerJob, error)
InsertProvisionerJobLogs(ctx context.Context, arg InsertProvisionerJobLogsParams) ([]ProvisionerJobLog, error)
InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error)
InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error)
InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) (TemplateVersion, error)
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
@ -156,6 +161,7 @@ type sqlcQuerier interface {
UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error
UpdateProvisionerJobWithCancelByID(ctx context.Context, arg UpdateProvisionerJobWithCancelByIDParams) error
UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error
UpdateReplica(ctx context.Context, arg UpdateReplicaParams) (Replica, error)
UpdateTemplateActiveVersionByID(ctx context.Context, arg UpdateTemplateActiveVersionByIDParams) error
UpdateTemplateDeletedByID(ctx context.Context, arg UpdateTemplateDeletedByIDParams) error
UpdateTemplateMetaByID(ctx context.Context, arg UpdateTemplateMetaByIDParams) (Template, error)

View File

@ -2031,7 +2031,7 @@ func (q *sqlQuerier) ParameterValues(ctx context.Context, arg ParameterValuesPar
const getProvisionerDaemonByID = `-- name: GetProvisionerDaemonByID :one
SELECT
id, created_at, updated_at, name, provisioners
id, created_at, updated_at, name, provisioners, replica_id
FROM
provisioner_daemons
WHERE
@ -2047,13 +2047,14 @@ func (q *sqlQuerier) GetProvisionerDaemonByID(ctx context.Context, id uuid.UUID)
&i.UpdatedAt,
&i.Name,
pq.Array(&i.Provisioners),
&i.ReplicaID,
)
return i, err
}
const getProvisionerDaemons = `-- name: GetProvisionerDaemons :many
SELECT
id, created_at, updated_at, name, provisioners
id, created_at, updated_at, name, provisioners, replica_id
FROM
provisioner_daemons
`
@ -2073,6 +2074,7 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa
&i.UpdatedAt,
&i.Name,
pq.Array(&i.Provisioners),
&i.ReplicaID,
); err != nil {
return nil, err
}
@ -2096,7 +2098,7 @@ INSERT INTO
provisioners
)
VALUES
($1, $2, $3, $4) RETURNING id, created_at, updated_at, name, provisioners
($1, $2, $3, $4) RETURNING id, created_at, updated_at, name, provisioners, replica_id
`
type InsertProvisionerDaemonParams struct {
@ -2120,6 +2122,7 @@ func (q *sqlQuerier) InsertProvisionerDaemon(ctx context.Context, arg InsertProv
&i.UpdatedAt,
&i.Name,
pq.Array(&i.Provisioners),
&i.ReplicaID,
)
return i, err
}
@ -2577,6 +2580,177 @@ func (q *sqlQuerier) UpdateProvisionerJobWithCompleteByID(ctx context.Context, a
return err
}
const deleteReplicasUpdatedBefore = `-- name: DeleteReplicasUpdatedBefore :exec
DELETE FROM replicas WHERE updated_at < $1
`
func (q *sqlQuerier) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error {
_, err := q.db.ExecContext(ctx, deleteReplicasUpdatedBefore, updatedAt)
return err
}
const getReplicasUpdatedAfter = `-- name: GetReplicasUpdatedAfter :many
SELECT id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error FROM replicas WHERE updated_at > $1 AND stopped_at IS NULL
`
func (q *sqlQuerier) GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error) {
rows, err := q.db.QueryContext(ctx, getReplicasUpdatedAfter, updatedAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Replica
for rows.Next() {
var i Replica
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.StartedAt,
&i.StoppedAt,
&i.UpdatedAt,
&i.Hostname,
&i.RegionID,
&i.RelayAddress,
&i.DatabaseLatency,
&i.Version,
&i.Error,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertReplica = `-- name: InsertReplica :one
INSERT INTO replicas (
id,
created_at,
started_at,
updated_at,
hostname,
region_id,
relay_address,
version,
database_latency
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error
`
type InsertReplicaParams struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
StartedAt time.Time `db:"started_at" json:"started_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Hostname string `db:"hostname" json:"hostname"`
RegionID int32 `db:"region_id" json:"region_id"`
RelayAddress string `db:"relay_address" json:"relay_address"`
Version string `db:"version" json:"version"`
DatabaseLatency int32 `db:"database_latency" json:"database_latency"`
}
func (q *sqlQuerier) InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error) {
row := q.db.QueryRowContext(ctx, insertReplica,
arg.ID,
arg.CreatedAt,
arg.StartedAt,
arg.UpdatedAt,
arg.Hostname,
arg.RegionID,
arg.RelayAddress,
arg.Version,
arg.DatabaseLatency,
)
var i Replica
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.StartedAt,
&i.StoppedAt,
&i.UpdatedAt,
&i.Hostname,
&i.RegionID,
&i.RelayAddress,
&i.DatabaseLatency,
&i.Version,
&i.Error,
)
return i, err
}
const updateReplica = `-- name: UpdateReplica :one
UPDATE replicas SET
updated_at = $2,
started_at = $3,
stopped_at = $4,
relay_address = $5,
region_id = $6,
hostname = $7,
version = $8,
error = $9,
database_latency = $10
WHERE id = $1 RETURNING id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error
`
type UpdateReplicaParams struct {
ID uuid.UUID `db:"id" json:"id"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
StartedAt time.Time `db:"started_at" json:"started_at"`
StoppedAt sql.NullTime `db:"stopped_at" json:"stopped_at"`
RelayAddress string `db:"relay_address" json:"relay_address"`
RegionID int32 `db:"region_id" json:"region_id"`
Hostname string `db:"hostname" json:"hostname"`
Version string `db:"version" json:"version"`
Error string `db:"error" json:"error"`
DatabaseLatency int32 `db:"database_latency" json:"database_latency"`
}
func (q *sqlQuerier) UpdateReplica(ctx context.Context, arg UpdateReplicaParams) (Replica, error) {
row := q.db.QueryRowContext(ctx, updateReplica,
arg.ID,
arg.UpdatedAt,
arg.StartedAt,
arg.StoppedAt,
arg.RelayAddress,
arg.RegionID,
arg.Hostname,
arg.Version,
arg.Error,
arg.DatabaseLatency,
)
var i Replica
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.StartedAt,
&i.StoppedAt,
&i.UpdatedAt,
&i.Hostname,
&i.RegionID,
&i.RelayAddress,
&i.DatabaseLatency,
&i.Version,
&i.Error,
)
return i, err
}
const getDERPMeshKey = `-- name: GetDERPMeshKey :one
SELECT value FROM site_configs WHERE key = 'derp_mesh_key'
`
func (q *sqlQuerier) GetDERPMeshKey(ctx context.Context) (string, error) {
row := q.db.QueryRowContext(ctx, getDERPMeshKey)
var value string
err := row.Scan(&value)
return value, err
}
const getDeploymentID = `-- name: GetDeploymentID :one
SELECT value FROM site_configs WHERE key = 'deployment_id'
`
@ -2588,6 +2762,15 @@ func (q *sqlQuerier) GetDeploymentID(ctx context.Context) (string, error) {
return value, err
}
const insertDERPMeshKey = `-- name: InsertDERPMeshKey :exec
INSERT INTO site_configs (key, value) VALUES ('derp_mesh_key', $1)
`
func (q *sqlQuerier) InsertDERPMeshKey(ctx context.Context, value string) error {
_, err := q.db.ExecContext(ctx, insertDERPMeshKey, value)
return err
}
const insertDeploymentID = `-- name: InsertDeploymentID :exec
INSERT INTO site_configs (key, value) VALUES ('deployment_id', $1)
`

View File

@ -0,0 +1,31 @@
-- name: GetReplicasUpdatedAfter :many
SELECT * FROM replicas WHERE updated_at > $1 AND stopped_at IS NULL;
-- name: InsertReplica :one
INSERT INTO replicas (
id,
created_at,
started_at,
updated_at,
hostname,
region_id,
relay_address,
version,
database_latency
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *;
-- name: UpdateReplica :one
UPDATE replicas SET
updated_at = $2,
started_at = $3,
stopped_at = $4,
relay_address = $5,
region_id = $6,
hostname = $7,
version = $8,
error = $9,
database_latency = $10
WHERE id = $1 RETURNING *;
-- name: DeleteReplicasUpdatedBefore :exec
DELETE FROM replicas WHERE updated_at < $1;
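
Taken together, these queries support a register/heartbeat/prune lifecycle. Below is a sketch of how a replica might drive them through the generated store methods; the interval, staleness window, region ID, and version are illustrative values, not the ones the real replicasync package uses.

package heartbeat

import (
    "context"
    "os"
    "time"

    "github.com/google/uuid"

    "github.com/coder/coder/coderd/database"
)

// run registers this process as a replica, bumps updated_at on an interval,
// and prunes rows that have missed several beats.
func run(ctx context.Context, db database.Store, relayAddress string) error {
    hostname, err := os.Hostname()
    if err != nil {
        return err
    }
    now := time.Now()
    replica, err := db.InsertReplica(ctx, database.InsertReplicaParams{
        ID:           uuid.New(),
        CreatedAt:    now,
        StartedAt:    now,
        UpdatedAt:    now,
        Hostname:     hostname,
        RegionID:     1,
        RelayAddress: relayAddress,
        Version:      "v0.0.0",
    })
    if err != nil {
        return err
    }
    ticker := time.NewTicker(5 * time.Second)
    defer ticker.Stop()
    for {
        select {
        case <-ctx.Done():
            return ctx.Err()
        case <-ticker.C:
        }
        // Each beat reports fresh database latency (in microseconds).
        latency, err := db.Ping(ctx)
        if err != nil {
            return err
        }
        replica, err = db.UpdateReplica(ctx, database.UpdateReplicaParams{
            ID:              replica.ID,
            UpdatedAt:       time.Now(),
            StartedAt:       replica.StartedAt,
            StoppedAt:       replica.StoppedAt,
            RelayAddress:    replica.RelayAddress,
            RegionID:        replica.RegionID,
            Hostname:        replica.Hostname,
            Version:         replica.Version,
            DatabaseLatency: int32(latency.Microseconds()),
        })
        if err != nil {
            return err
        }
        // Anything quiet for six beats is considered dead and removed.
        err = db.DeleteReplicasUpdatedBefore(ctx, time.Now().Add(-30*time.Second))
        if err != nil {
            return err
        }
    }
}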

View File

@ -3,3 +3,9 @@ INSERT INTO site_configs (key, value) VALUES ('deployment_id', $1);
-- name: GetDeploymentID :one
SELECT value FROM site_configs WHERE key = 'deployment_id';
-- name: InsertDERPMeshKey :exec
INSERT INTO site_configs (key, value) VALUES ('derp_mesh_key', $1);
-- name: GetDERPMeshKey :one
SELECT value FROM site_configs WHERE key = 'derp_mesh_key';

View File

@ -270,7 +270,7 @@ func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request,
}
}
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading job agent.",

View File

@ -146,6 +146,10 @@ var (
ResourceDeploymentFlags = Object{
Type: "deployment_flags",
}
ResourceReplicas = Object{
Type: "replicas",
}
)
// Object is used to create objects for authz checks when you have none in

View File

@ -627,7 +627,9 @@ func TestTemplateMetrics(t *testing.T) {
require.NoError(t, err)
assert.Zero(t, workspaces[0].LastUsedAt)
conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil).Named("tailnet"), resources[0].Agents[0].ID)
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{
Logger: slogtest.Make(t, nil).Named("tailnet"),
})
require.NoError(t, err)
defer func() {
_ = conn.Close()

View File

@ -49,7 +49,7 @@ func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) {
})
return
}
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -78,7 +78,7 @@ func (api *API) workspaceAgentApps(rw http.ResponseWriter, r *http.Request) {
func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -98,7 +98,7 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request)
func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -152,7 +152,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
httpapi.ResourceNotFound(rw)
return
}
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -229,7 +229,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
return
}
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -376,8 +376,9 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
})
conn.SetNodeCallback(sendNodes)
go func() {
err := api.TailnetCoordinator.ServeClient(serverConn, uuid.New(), agentID)
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
if err != nil {
api.Logger.Warn(r.Context(), "tailnet coordinator client error", slog.Error(err))
_ = conn.Close()
}
}()
@ -514,8 +515,9 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
closeChan := make(chan struct{})
go func() {
defer close(closeChan)
err := api.TailnetCoordinator.ServeAgent(wsNetConn, workspaceAgent.ID)
err := (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID)
if err != nil {
api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, err.Error())
return
}
@ -583,7 +585,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
go httpapi.Heartbeat(ctx, conn)
defer conn.Close(websocket.StatusNormalClosure, "")
err = api.TailnetCoordinator.ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID)
err = (*api.TailnetCoordinator.Load()).ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, err.Error())
return
@ -611,7 +613,7 @@ func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp {
return apps
}
func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator *tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) {
func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) {
var envs map[string]string
if dbAgent.EnvironmentVariables.Valid {
err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs)

View File

@ -123,13 +123,13 @@ func TestWorkspaceAgentListen(t *testing.T) {
defer cancel()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer func() {
_ = conn.Close()
}()
require.Eventually(t, func() bool {
_, err := conn.Ping()
_, err := conn.Ping(ctx)
return err == nil
}, testutil.WaitLong, testutil.IntervalFast)
})
@ -253,7 +253,9 @@ func TestWorkspaceAgentTailnet(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), resources[0].Agents[0].ID)
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
defer conn.Close()
sshClient, err := conn.SSHClient()

View File

@ -861,7 +861,7 @@ func (api *API) convertWorkspaceBuild(
apiAgents := make([]codersdk.WorkspaceAgent, 0)
for _, agent := range agents {
apps := appsByAgentID[agent.ID]
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(apps), api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), agent, convertApps(apps), api.AgentInactiveDisconnectTimeout)
if err != nil {
return codersdk.WorkspaceBuild{}, xerrors.Errorf("converting workspace agent: %w", err)
}

View File

@ -128,7 +128,9 @@ func TestCache(t *testing.T) {
return
}
defer release()
proxy.Transport = conn.HTTPTransport()
transport := conn.HTTPTransport()
defer transport.CloseIdleConnections()
proxy.Transport = transport
res := httptest.NewRecorder()
proxy.ServeHTTP(res, req)
resp := res.Result()

View File

@ -132,10 +132,10 @@ type AgentConn struct {
CloseFunc func()
}
func (c *AgentConn) Ping() (time.Duration, error) {
func (c *AgentConn) Ping(ctx context.Context) (time.Duration, error) {
errCh := make(chan error, 1)
durCh := make(chan time.Duration, 1)
c.Conn.Ping(TailnetIP, tailcfg.PingDisco, func(pr *ipnstate.PingResult) {
go c.Conn.Ping(TailnetIP, tailcfg.PingDisco, func(pr *ipnstate.PingResult) {
if pr.Err != "" {
errCh <- xerrors.New(pr.Err)
return
@ -145,6 +145,8 @@ func (c *AgentConn) Ping() (time.Duration, error) {
select {
case err := <-errCh:
return 0, err
case <-ctx.Done():
return 0, ctx.Err()
case dur := <-durCh:
return dur, nil
}

View File

@ -15,12 +15,13 @@ const (
)
const (
FeatureUserLimit = "user_limit"
FeatureAuditLog = "audit_log"
FeatureBrowserOnly = "browser_only"
FeatureSCIM = "scim"
FeatureWorkspaceQuota = "workspace_quota"
FeatureTemplateRBAC = "template_rbac"
FeatureUserLimit = "user_limit"
FeatureAuditLog = "audit_log"
FeatureBrowserOnly = "browser_only"
FeatureSCIM = "scim"
FeatureWorkspaceQuota = "workspace_quota"
FeatureTemplateRBAC = "template_rbac"
FeatureHighAvailability = "high_availability"
)
var FeatureNames = []string{
@ -30,6 +31,7 @@ var FeatureNames = []string{
FeatureSCIM,
FeatureWorkspaceQuota,
FeatureTemplateRBAC,
FeatureHighAvailability,
}
type Feature struct {
@ -42,6 +44,7 @@ type Feature struct {
type Entitlements struct {
Features map[string]Feature `json:"features"`
Warnings []string `json:"warnings"`
Errors []string `json:"errors"`
HasLicense bool `json:"has_license"`
Experimental bool `json:"experimental"`
Trial bool `json:"trial"`

View File

@ -19,6 +19,7 @@ type DeploymentFlags struct {
DerpServerRegionCode *StringFlag `json:"derp_server_region_code" typescript:",notnull"`
DerpServerRegionName *StringFlag `json:"derp_server_region_name" typescript:",notnull"`
DerpServerSTUNAddresses *StringArrayFlag `json:"derp_server_stun_address" typescript:",notnull"`
DerpServerRelayAddress *StringFlag `json:"derp_server_relay_address" typescript:",notnull"`
DerpConfigURL *StringFlag `json:"derp_config_url" typescript:",notnull"`
DerpConfigPath *StringFlag `json:"derp_config_path" typescript:",notnull"`
PromEnabled *BoolFlag `json:"prom_enabled" typescript:",notnull"`

codersdk/replicas.go (new file, 44 lines)
View File

@ -0,0 +1,44 @@
package codersdk
import (
"context"
"encoding/json"
"net/http"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
)
type Replica struct {
// ID is the unique identifier for the replica.
ID uuid.UUID `json:"id"`
// Hostname is the hostname of the replica.
Hostname string `json:"hostname"`
// CreatedAt is when the replica was first seen.
CreatedAt time.Time `json:"created_at"`
// RelayAddress is the accessible address to relay DERP connections.
RelayAddress string `json:"relay_address"`
// RegionID is the region of the replica.
RegionID int32 `json:"region_id"`
// Error is the most recent error the replica reported, if any.
Error string `json:"error"`
// DatabaseLatency is the latency in microseconds to the database.
DatabaseLatency int32 `json:"database_latency"`
}
// Replicas fetches the list of replicas.
func (c *Client) Replicas(ctx context.Context) ([]Replica, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/v2/replicas", nil)
if err != nil {
return nil, xerrors.Errorf("execute request: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, readBodyAsError(res)
}
var replicas []Replica
return replicas, json.NewDecoder(res.Body).Decode(&replicas)
}
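
A usage example for the new endpoint; the deployment URL and session token are placeholders:

package main

import (
    "context"
    "fmt"
    "net/url"

    "github.com/coder/coder/codersdk"
)

func main() {
    serverURL, err := url.Parse("https://coder.example.com")
    if err != nil {
        panic(err)
    }
    client := codersdk.New(serverURL)
    client.SessionToken = "<session-token>"

    replicas, err := client.Replicas(context.Background())
    if err != nil {
        panic(err)
    }
    for _, r := range replicas {
        fmt.Printf("%s (%s): region %d, db latency %dµs, error=%q\n",
            r.Hostname, r.ID, r.RegionID, r.DatabaseLatency, r.Error)
    }
}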

View File

@ -21,7 +21,6 @@ import (
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/tailnet"
"github.com/coder/retry"
)
@ -316,7 +315,8 @@ func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, err
Value: c.SessionToken,
}})
httpClient := &http.Client{
Jar: jar,
Jar: jar,
Transport: c.HTTPClient.Transport,
}
// nolint:bodyclose
conn, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
@ -332,7 +332,17 @@ func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, err
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
}
func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (*AgentConn, error) {
// @typescript-ignore DialWorkspaceAgentOptions
type DialWorkspaceAgentOptions struct {
Logger slog.Logger
// BlockEndpoints forces the connection through DERP by blocking P2P endpoints.
BlockEndpoints bool
}
func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (*AgentConn, error) {
if options == nil {
options = &DialWorkspaceAgentOptions{}
}
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/connection", agentID), nil)
if err != nil {
return nil, err
@ -349,9 +359,10 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
ip := tailnet.IP()
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
DERPMap: connInfo.DERPMap,
Logger: logger,
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
DERPMap: connInfo.DERPMap,
Logger: options.Logger,
BlockEndpoints: options.BlockEndpoints,
})
if err != nil {
return nil, xerrors.Errorf("create tailnet: %w", err)
@ -370,7 +381,8 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
Value: c.SessionToken,
}})
httpClient := &http.Client{
Jar: jar,
Jar: jar,
Transport: c.HTTPClient.Transport,
}
ctx, cancelFunc := context.WithCancel(ctx)
closed := make(chan struct{})
@ -379,7 +391,7 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
defer close(closed)
isFirst := true
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
logger.Debug(ctx, "connecting")
options.Logger.Debug(ctx, "connecting")
// nolint:bodyclose
ws, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
HTTPClient: httpClient,
@ -398,21 +410,21 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
if errors.Is(err, context.Canceled) {
return
}
logger.Debug(ctx, "failed to dial", slog.Error(err))
options.Logger.Debug(ctx, "failed to dial", slog.Error(err))
continue
}
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error {
return conn.UpdateNodes(node)
})
conn.SetNodeCallback(sendNode)
logger.Debug(ctx, "serving coordinator")
options.Logger.Debug(ctx, "serving coordinator")
err = <-errChan
if errors.Is(err, context.Canceled) {
_ = ws.Close(websocket.StatusGoingAway, "")
return
}
if err != nil {
logger.Debug(ctx, "error serving coordinator", slog.Error(err))
options.Logger.Debug(ctx, "error serving coordinator", slog.Error(err))
_ = ws.Close(websocket.StatusGoingAway, "")
continue
}
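// Illustrative usage (not part of this change): dialing an agent with the
// new options struct, forcing all traffic through the DERP relay. The
// helper name is invented for this sketch.
func dialRelayOnly(ctx context.Context, client *Client, agentID uuid.UUID, logger slog.Logger) error {
	conn, err := client.DialWorkspaceAgent(ctx, agentID, &DialWorkspaceAgentOptions{
		Logger:         logger,
		BlockEndpoints: true, // never attempt direct peer-to-peer endpoints
	})
	if err != nil {
		return xerrors.Errorf("dial agent: %w", err)
	}
	defer conn.Close()
	// A successful ping proves the relayed path works end-to-end.
	_, err = conn.Ping(ctx)
	return err
}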

View File

@ -57,7 +57,7 @@ func TestFeaturesList(t *testing.T) {
var entitlements codersdk.Entitlements
err := json.Unmarshal(buf.Bytes(), &entitlements)
require.NoError(t, err, "unmarshal JSON output")
assert.Len(t, entitlements.Features, 6)
assert.Len(t, entitlements.Features, 7)
assert.Empty(t, entitlements.Warnings)
assert.Equal(t, codersdk.EntitlementNotEntitled,
entitlements.Features[codersdk.FeatureUserLimit].Entitlement)
@ -71,6 +71,8 @@ func TestFeaturesList(t *testing.T) {
entitlements.Features[codersdk.FeatureTemplateRBAC].Entitlement)
assert.Equal(t, codersdk.EntitlementNotEntitled,
entitlements.Features[codersdk.FeatureSCIM].Entitlement)
assert.Equal(t, codersdk.EntitlementNotEntitled,
entitlements.Features[codersdk.FeatureHighAvailability].Entitlement)
assert.False(t, entitlements.HasLicense)
assert.False(t, entitlements.Experimental)
})

View File

@ -2,11 +2,20 @@ package cli
import (
"context"
"database/sql"
"errors"
"io"
"net/url"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"tailscale.com/derp"
"tailscale.com/types/key"
"github.com/coder/coder/cli/deployment"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/enterprise/coderd"
"github.com/coder/coder/tailnet"
agpl "github.com/coder/coder/cli"
agplcoderd "github.com/coder/coder/coderd"
@ -14,23 +23,49 @@ import (
func server() *cobra.Command {
dflags := deployment.Flags()
cmd := agpl.Server(dflags, func(ctx context.Context, options *agplcoderd.Options) (*agplcoderd.API, error) {
cmd := agpl.Server(dflags, func(ctx context.Context, options *agplcoderd.Options) (*agplcoderd.API, io.Closer, error) {
if dflags.DerpServerRelayAddress.Value != "" {
_, err := url.Parse(dflags.DerpServerRelayAddress.Value)
if err != nil {
return nil, nil, xerrors.Errorf("derp-server-relay-address must be a valid HTTP URL: %w", err)
}
}
options.DERPServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp")))
meshKey, err := options.Database.GetDERPMeshKey(ctx)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
return nil, nil, xerrors.Errorf("get mesh key: %w", err)
}
meshKey, err = cryptorand.String(32)
if err != nil {
return nil, nil, xerrors.Errorf("generate mesh key: %w", err)
}
err = options.Database.InsertDERPMeshKey(ctx, meshKey)
if err != nil {
return nil, nil, xerrors.Errorf("insert mesh key: %w", err)
}
}
options.DERPServer.SetMeshKey(meshKey)
o := &coderd.Options{
AuditLogging: dflags.AuditLogging.Value,
BrowserOnly: dflags.BrowserOnly.Value,
SCIMAPIKey: []byte(dflags.SCIMAuthHeader.Value),
UserWorkspaceQuota: dflags.UserWorkspaceQuota.Value,
RBACEnabled: true,
Options: options,
AuditLogging: dflags.AuditLogging.Value,
BrowserOnly: dflags.BrowserOnly.Value,
SCIMAPIKey: []byte(dflags.SCIMAuthHeader.Value),
UserWorkspaceQuota: dflags.UserWorkspaceQuota.Value,
RBAC: true,
DERPServerRelayAddress: dflags.DerpServerRelayAddress.Value,
DERPServerRegionID: dflags.DerpServerRegionID.Value,
Options: options,
}
api, err := coderd.New(ctx, o)
if err != nil {
return nil, err
return nil, nil, err
}
return api.AGPL, nil
return api.AGPL, api, nil
})
deployment.AttachFlags(cmd.Flags(), dflags, true)
return cmd
}

View File

@ -28,7 +28,7 @@ func TestCheckACLPermissions(t *testing.T) {
// Create adminClient, member, and org adminClient
adminUser := coderdtest.CreateFirstUser(t, adminClient)
_ = coderdenttest.AddLicense(t, adminClient, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
memberClient := coderdtest.CreateAnotherUser(t, adminClient, adminUser.OrganizationID)

View File

@ -3,6 +3,8 @@ package coderd
import (
"context"
"crypto/ed25519"
"crypto/tls"
"crypto/x509"
"net/http"
"sync"
"time"
@ -23,6 +25,10 @@ import (
"github.com/coder/coder/enterprise/audit"
"github.com/coder/coder/enterprise/audit/backends"
"github.com/coder/coder/enterprise/coderd/license"
"github.com/coder/coder/enterprise/derpmesh"
"github.com/coder/coder/enterprise/replicasync"
"github.com/coder/coder/enterprise/tailnet"
agpltailnet "github.com/coder/coder/tailnet"
)
// New constructs an Enterprise coderd API instance.
@ -47,6 +53,7 @@ func New(ctx context.Context, options *Options) (*API, error) {
Options: options,
cancelEntitlementsLoop: cancelFunc,
}
oauthConfigs := &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
OIDC: options.OIDCConfig,
@ -59,6 +66,10 @@ func New(ctx context.Context, options *Options) (*API, error) {
api.AGPL.APIHandler.Group(func(r chi.Router) {
r.Get("/entitlements", api.serveEntitlements)
r.Route("/replicas", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Get("/", api.replicas)
})
r.Route("/licenses", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Post("/", api.postLicense)
@ -117,7 +128,40 @@ func New(ctx context.Context, options *Options) (*API, error) {
})
}
err := api.updateEntitlements(ctx)
meshRootCA := x509.NewCertPool()
for _, certificate := range options.TLSCertificates {
for _, certificatePart := range certificate.Certificate {
parsedCert, err := x509.ParseCertificate(certificatePart)
if err != nil {
// The parsed certificate is nil on error, so don't reference it here.
return nil, xerrors.Errorf("parse certificate: %w", err)
}
meshRootCA.AddCert(parsedCert)
}
}
// This TLS configuration dials replicas with the access URL hostname as
// the ServerName, assuming the provided certificates cover that hostname.
//
// Replica sync and DERP meshing require accessing replicas via their
// internal IP addresses, and if TLS is configured we use the same
// certificates.
meshTLSConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: options.TLSCertificates,
RootCAs: meshRootCA,
ServerName: options.AccessURL.Hostname(),
}
var err error
api.replicaManager, err = replicasync.New(ctx, options.Logger, options.Database, options.Pubsub, &replicasync.Options{
RelayAddress: options.DERPServerRelayAddress,
RegionID: int32(options.DERPServerRegionID),
TLSConfig: meshTLSConfig,
})
if err != nil {
return nil, xerrors.Errorf("initialize replica: %w", err)
}
api.derpMesh = derpmesh.New(options.Logger.Named("derpmesh"), api.DERPServer, meshTLSConfig)
err = api.updateEntitlements(ctx)
if err != nil {
return nil, xerrors.Errorf("update entitlements: %w", err)
}
@ -129,13 +173,17 @@ func New(ctx context.Context, options *Options) (*API, error) {
type Options struct {
*coderd.Options
RBACEnabled bool
RBAC bool
AuditLogging bool
// Whether to block non-browser connections.
BrowserOnly bool
SCIMAPIKey []byte
UserWorkspaceQuota int
// Used for high availability.
DERPServerRelayAddress string
DERPServerRegionID int
EntitlementsUpdateInterval time.Duration
Keys map[string]ed25519.PublicKey
}
@ -144,6 +192,11 @@ type API struct {
AGPL *coderd.API
*Options
// Detects multiple Coder replicas running at the same time.
replicaManager *replicasync.Manager
// Meshes DERP connections from multiple replicas.
derpMesh *derpmesh.Mesh
cancelEntitlementsLoop func()
entitlementsMu sync.RWMutex
entitlements codersdk.Entitlements
@ -151,6 +204,8 @@ type API struct {
func (api *API) Close() error {
api.cancelEntitlementsLoop()
_ = api.replicaManager.Close()
_ = api.derpMesh.Close()
return api.AGPL.Close()
}
@ -158,12 +213,13 @@ func (api *API) updateEntitlements(ctx context.Context) error {
api.entitlementsMu.Lock()
defer api.entitlementsMu.Unlock()
entitlements, err := license.Entitlements(ctx, api.Database, api.Logger, api.Keys, map[string]bool{
codersdk.FeatureAuditLog: api.AuditLogging,
codersdk.FeatureBrowserOnly: api.BrowserOnly,
codersdk.FeatureSCIM: len(api.SCIMAPIKey) != 0,
codersdk.FeatureWorkspaceQuota: api.UserWorkspaceQuota != 0,
codersdk.FeatureTemplateRBAC: api.RBACEnabled,
entitlements, err := license.Entitlements(ctx, api.Database, api.Logger, len(api.replicaManager.All()), api.Keys, map[string]bool{
codersdk.FeatureAuditLog: api.AuditLogging,
codersdk.FeatureBrowserOnly: api.BrowserOnly,
codersdk.FeatureSCIM: len(api.SCIMAPIKey) != 0,
codersdk.FeatureWorkspaceQuota: api.UserWorkspaceQuota != 0,
codersdk.FeatureHighAvailability: api.DERPServerRelayAddress != "",
codersdk.FeatureTemplateRBAC: api.RBAC,
})
if err != nil {
return err
@ -209,6 +265,46 @@ func (api *API) updateEntitlements(ctx context.Context) error {
api.AGPL.WorkspaceQuotaEnforcer.Store(&enforcer)
}
if changed, enabled := featureChanged(codersdk.FeatureHighAvailability); changed {
coordinator := agpltailnet.NewCoordinator()
if enabled {
haCoordinator, err := tailnet.NewCoordinator(api.Logger, api.Pubsub)
if err != nil {
api.Logger.Error(ctx, "unable to set up high availability coordinator", slog.Error(err))
// If setting up the HA coordinator fails, nothing has actually
// changed, so report the feature as unchanged.
changed = false
} else {
coordinator = haCoordinator
}
api.replicaManager.SetCallback(func() {
addresses := make([]string, 0)
for _, replica := range api.replicaManager.Regional() {
addresses = append(addresses, replica.RelayAddress)
}
api.derpMesh.SetAddresses(addresses, false)
_ = api.updateEntitlements(ctx)
})
} else {
api.derpMesh.SetAddresses([]string{}, false)
api.replicaManager.SetCallback(func() {
// If the number of replicas changes, so should our entitlements.
// This is to display a warning in the UI if the user is unlicensed.
_ = api.updateEntitlements(ctx)
})
}
// Recheck changed in case the HA coordinator failed to set up.
if changed {
oldCoordinator := *api.AGPL.TailnetCoordinator.Swap(&coordinator)
err := oldCoordinator.Close()
if err != nil {
api.Logger.Error(ctx, "close old tailnet coordinator", slog.Error(err))
}
}
}
api.entitlements = entitlements
return nil

View File

@ -41,9 +41,9 @@ func TestEntitlements(t *testing.T) {
})
_ = coderdtest.CreateFirstUser(t, client)
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
UserLimit: 100,
AuditLog: true,
TemplateRBACEnabled: true,
UserLimit: 100,
AuditLog: true,
TemplateRBAC: true,
})
res, err := client.Entitlements(context.Background())
require.NoError(t, err)
@ -85,7 +85,7 @@ func TestEntitlements(t *testing.T) {
assert.False(t, res.HasLicense)
al = res.Features[codersdk.FeatureAuditLog]
assert.Equal(t, codersdk.EntitlementNotEntitled, al.Entitlement)
assert.True(t, al.Enabled)
assert.False(t, al.Enabled)
})
t.Run("Pubsub", func(t *testing.T) {
t.Parallel()

View File

@ -4,7 +4,9 @@ import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"io"
"net/http"
"testing"
"time"
@ -60,19 +62,21 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
if options.Options == nil {
options.Options = &coderdtest.Options{}
}
srv, cancelFunc, oop := coderdtest.NewOptions(t, options.Options)
setHandler, cancelFunc, oop := coderdtest.NewOptions(t, options.Options)
coderAPI, err := coderd.New(context.Background(), &coderd.Options{
RBACEnabled: true,
RBAC: true,
AuditLogging: options.AuditLogging,
BrowserOnly: options.BrowserOnly,
SCIMAPIKey: options.SCIMAPIKey,
DERPServerRelayAddress: oop.AccessURL.String(),
DERPServerRegionID: oop.DERPMap.RegionIDs()[0],
UserWorkspaceQuota: options.UserWorkspaceQuota,
Options: oop,
EntitlementsUpdateInterval: options.EntitlementsUpdateInterval,
Keys: Keys,
})
assert.NoError(t, err)
srv.Config.Handler = coderAPI.AGPL.RootHandler
setHandler(coderAPI.AGPL.RootHandler)
var provisionerCloser io.Closer = nopcloser{}
if options.IncludeProvisionerDaemon {
provisionerCloser = coderdtest.NewProvisionerDaemon(t, coderAPI.AGPL)
@ -83,22 +87,32 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
_ = provisionerCloser.Close()
_ = coderAPI.Close()
})
return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI
client := codersdk.New(coderAPI.AccessURL)
client.HTTPClient = &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec
InsecureSkipVerify: true,
},
},
}
return client, provisionerCloser, coderAPI
}
type LicenseOptions struct {
AccountType string
AccountID string
Trial bool
AllFeatures bool
GraceAt time.Time
ExpiresAt time.Time
UserLimit int64
AuditLog bool
BrowserOnly bool
SCIM bool
WorkspaceQuota bool
TemplateRBACEnabled bool
AccountType string
AccountID string
Trial bool
AllFeatures bool
GraceAt time.Time
ExpiresAt time.Time
UserLimit int64
AuditLog bool
BrowserOnly bool
SCIM bool
WorkspaceQuota bool
TemplateRBAC bool
HighAvailability bool
}
// AddLicense generates a new license with the options provided and inserts it.
@ -134,9 +148,13 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
if options.WorkspaceQuota {
workspaceQuota = 1
}
highAvailability := int64(0)
if options.HighAvailability {
highAvailability = 1
}
rbacEnabled := int64(0)
if options.TemplateRBACEnabled {
if options.TemplateRBAC {
rbacEnabled = 1
}
@ -154,12 +172,13 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
Version: license.CurrentVersion,
AllFeatures: options.AllFeatures,
Features: license.Features{
UserLimit: options.UserLimit,
AuditLog: auditLog,
BrowserOnly: browserOnly,
SCIM: scim,
WorkspaceQuota: workspaceQuota,
TemplateRBAC: rbacEnabled,
UserLimit: options.UserLimit,
AuditLog: auditLog,
BrowserOnly: browserOnly,
SCIM: scim,
WorkspaceQuota: workspaceQuota,
HighAvailability: highAvailability,
TemplateRBAC: rbacEnabled,
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c)

View File

@ -33,7 +33,7 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
ctx, _ := testutil.Context(t)
admin := coderdtest.CreateFirstUser(t, client)
license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
group, err := client.CreateGroup(ctx, admin.OrganizationID, codersdk.CreateGroupRequest{
Name: "testgroup",
@ -58,6 +58,10 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceLicense,
}
assertRoute["GET:/api/v2/replicas"] = coderdtest.RouteCheck{
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceReplicas,
}
assertRoute["DELETE:/api/v2/licenses/{id}"] = coderdtest.RouteCheck{
AssertAction: rbac.ActionDelete,
AssertObject: rbac.ResourceLicense,

View File

@ -24,7 +24,7 @@ func TestCreateGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
ctx, _ := testutil.Context(t)
group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@ -43,7 +43,7 @@ func TestCreateGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
ctx, _ := testutil.Context(t)
_, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@ -67,7 +67,7 @@ func TestCreateGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
ctx, _ := testutil.Context(t)
_, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@ -90,7 +90,7 @@ func TestPatchGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
ctx, _ := testutil.Context(t)
group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@ -112,7 +112,7 @@ func TestPatchGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
_, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -138,7 +138,7 @@ func TestPatchGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
_, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -173,7 +173,7 @@ func TestPatchGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
ctx, _ := testutil.Context(t)
group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@ -197,7 +197,7 @@ func TestPatchGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
ctx, _ := testutil.Context(t)
group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@ -221,7 +221,7 @@ func TestPatchGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
ctx, _ := testutil.Context(t)
@ -247,7 +247,7 @@ func TestPatchGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
ctx, _ := testutil.Context(t)
group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@ -276,7 +276,7 @@ func TestGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
ctx, _ := testutil.Context(t)
group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@ -296,7 +296,7 @@ func TestGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
_, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -326,7 +326,7 @@ func TestGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
client1, _ := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -347,7 +347,7 @@ func TestGroup(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
_, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -380,7 +380,7 @@ func TestGroup(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
_, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -421,7 +421,7 @@ func TestGroups(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
_, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -467,7 +467,7 @@ func TestDeleteGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
ctx, _ := testutil.Context(t)
group1, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@ -492,7 +492,7 @@ func TestDeleteGroup(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
ctx, _ := testutil.Context(t)
err := client.DeleteGroup(ctx, user.OrganizationID)

View File

@ -17,12 +17,20 @@ import (
)
// Entitlements processes licenses to return whether features are enabled or not.
func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, keys map[string]ed25519.PublicKey, enablements map[string]bool) (codersdk.Entitlements, error) {
func Entitlements(
ctx context.Context,
db database.Store,
logger slog.Logger,
replicaCount int,
keys map[string]ed25519.PublicKey,
enablements map[string]bool,
) (codersdk.Entitlements, error) {
now := time.Now()
// Default all entitlements to be disabled.
entitlements := codersdk.Entitlements{
Features: map[string]codersdk.Feature{},
Warnings: []string{},
Errors: []string{},
}
for _, featureName := range codersdk.FeatureNames {
entitlements.Features[featureName] = codersdk.Feature{
@ -96,6 +104,12 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, ke
Enabled: enablements[codersdk.FeatureWorkspaceQuota],
}
}
if claims.Features.HighAvailability > 0 {
entitlements.Features[codersdk.FeatureHighAvailability] = codersdk.Feature{
Entitlement: entitlement,
Enabled: enablements[codersdk.FeatureHighAvailability],
}
}
if claims.Features.TemplateRBAC > 0 {
entitlements.Features[codersdk.FeatureTemplateRBAC] = codersdk.Feature{
Entitlement: entitlement,
@ -132,6 +146,10 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, ke
if featureName == codersdk.FeatureUserLimit {
continue
}
// High availability has its own warnings based on replica count.
if featureName == codersdk.FeatureHighAvailability {
continue
}
feature := entitlements.Features[featureName]
if !feature.Enabled {
continue
@ -141,9 +159,6 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, ke
case codersdk.EntitlementNotEntitled:
entitlements.Warnings = append(entitlements.Warnings,
fmt.Sprintf("%s is enabled but your license is not entitled to this feature.", niceName))
// Disable the feature and add a warning...
feature.Enabled = false
entitlements.Features[featureName] = feature
case codersdk.EntitlementGracePeriod:
entitlements.Warnings = append(entitlements.Warnings,
fmt.Sprintf("%s is enabled but your license for this feature is expired.", niceName))
@ -152,6 +167,32 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, ke
}
}
if replicaCount > 1 {
feature := entitlements.Features[codersdk.FeatureHighAvailability]
switch feature.Entitlement {
case codersdk.EntitlementNotEntitled:
if entitlements.HasLicense {
entitlements.Errors = append(entitlements.Errors,
"You have multiple replicas but your license is not entitled to high availability. You will be unable to connect to workspaces.")
} else {
entitlements.Errors = append(entitlements.Errors,
"You have multiple replicas but high availability is an Enterprise feature. You will be unable to connect to workspaces.")
}
case codersdk.EntitlementGracePeriod:
entitlements.Warnings = append(entitlements.Warnings,
"You have multiple replicas but your license for high availability is expired. Reduce to one replica or workspace connections will stop working.")
}
}
for _, featureName := range codersdk.FeatureNames {
feature := entitlements.Features[featureName]
if feature.Entitlement == codersdk.EntitlementNotEntitled {
feature.Enabled = false
entitlements.Features[featureName] = feature
}
}
return entitlements, nil
}
@ -171,12 +212,13 @@ var (
)
type Features struct {
UserLimit int64 `json:"user_limit"`
AuditLog int64 `json:"audit_log"`
BrowserOnly int64 `json:"browser_only"`
SCIM int64 `json:"scim"`
WorkspaceQuota int64 `json:"workspace_quota"`
TemplateRBAC int64 `json:"template_rbac"`
UserLimit int64 `json:"user_limit"`
AuditLog int64 `json:"audit_log"`
BrowserOnly int64 `json:"browser_only"`
SCIM int64 `json:"scim"`
WorkspaceQuota int64 `json:"workspace_quota"`
TemplateRBAC int64 `json:"template_rbac"`
HighAvailability int64 `json:"high_availability"`
}
type Claims struct {

View File

@ -20,17 +20,18 @@ import (
func TestEntitlements(t *testing.T) {
t.Parallel()
all := map[string]bool{
codersdk.FeatureAuditLog: true,
codersdk.FeatureBrowserOnly: true,
codersdk.FeatureSCIM: true,
codersdk.FeatureWorkspaceQuota: true,
codersdk.FeatureTemplateRBAC: true,
codersdk.FeatureAuditLog: true,
codersdk.FeatureBrowserOnly: true,
codersdk.FeatureSCIM: true,
codersdk.FeatureWorkspaceQuota: true,
codersdk.FeatureHighAvailability: true,
codersdk.FeatureTemplateRBAC: true,
}
t.Run("Defaults", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all)
require.NoError(t, err)
require.False(t, entitlements.HasLicense)
require.False(t, entitlements.Trial)
@ -46,7 +47,7 @@ func TestEntitlements(t *testing.T) {
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{}),
Exp: time.Now().Add(time.Hour),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{})
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.False(t, entitlements.Trial)
@ -60,16 +61,17 @@ func TestEntitlements(t *testing.T) {
db := databasefake.New()
db.InsertLicense(context.Background(), database.InsertLicenseParams{
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
UserLimit: 100,
AuditLog: true,
BrowserOnly: true,
SCIM: true,
WorkspaceQuota: true,
TemplateRBACEnabled: true,
UserLimit: 100,
AuditLog: true,
BrowserOnly: true,
SCIM: true,
WorkspaceQuota: true,
HighAvailability: true,
TemplateRBAC: true,
}),
Exp: time.Now().Add(time.Hour),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{})
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.False(t, entitlements.Trial)
@ -82,18 +84,19 @@ func TestEntitlements(t *testing.T) {
db := databasefake.New()
db.InsertLicense(context.Background(), database.InsertLicenseParams{
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
UserLimit: 100,
AuditLog: true,
BrowserOnly: true,
SCIM: true,
WorkspaceQuota: true,
TemplateRBACEnabled: true,
GraceAt: time.Now().Add(-time.Hour),
ExpiresAt: time.Now().Add(time.Hour),
UserLimit: 100,
AuditLog: true,
BrowserOnly: true,
SCIM: true,
WorkspaceQuota: true,
HighAvailability: true,
TemplateRBAC: true,
GraceAt: time.Now().Add(-time.Hour),
ExpiresAt: time.Now().Add(time.Hour),
}),
Exp: time.Now().Add(time.Hour),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all)
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all)
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.False(t, entitlements.Trial)
@ -101,6 +104,9 @@ func TestEntitlements(t *testing.T) {
if featureName == codersdk.FeatureUserLimit {
continue
}
if featureName == codersdk.FeatureHighAvailability {
continue
}
niceName := strings.Title(strings.ReplaceAll(featureName, "_", " "))
require.Equal(t, codersdk.EntitlementGracePeriod, entitlements.Features[featureName].Entitlement)
require.Contains(t, entitlements.Warnings, fmt.Sprintf("%s is enabled but your license for this feature is expired.", niceName))
@ -113,7 +119,7 @@ func TestEntitlements(t *testing.T) {
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{}),
Exp: time.Now().Add(time.Hour),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all)
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all)
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.False(t, entitlements.Trial)
@ -121,6 +127,9 @@ func TestEntitlements(t *testing.T) {
if featureName == codersdk.FeatureUserLimit {
continue
}
if featureName == codersdk.FeatureHighAvailability {
continue
}
niceName := strings.Title(strings.ReplaceAll(featureName, "_", " "))
// Ensures features that are not entitled are properly disabled.
require.False(t, entitlements.Features[featureName].Enabled)
@ -139,7 +148,7 @@ func TestEntitlements(t *testing.T) {
}),
Exp: time.Now().Add(time.Hour),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{})
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.Contains(t, entitlements.Warnings, "Your deployment has 2 active users but is only licensed for 1.")
@ -161,7 +170,7 @@ func TestEntitlements(t *testing.T) {
}),
Exp: time.Now().Add(time.Hour),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{})
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.Empty(t, entitlements.Warnings)
@ -184,7 +193,7 @@ func TestEntitlements(t *testing.T) {
}),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{})
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.False(t, entitlements.Trial)
@ -199,7 +208,7 @@ func TestEntitlements(t *testing.T) {
AllFeatures: true,
}),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all)
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all)
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.False(t, entitlements.Trial)
@ -211,4 +220,52 @@ func TestEntitlements(t *testing.T) {
require.Equal(t, codersdk.EntitlementEntitled, entitlements.Features[featureName].Entitlement)
}
})
t.Run("MultipleReplicasNoLicense", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 2, coderdenttest.Keys, all)
require.NoError(t, err)
require.False(t, entitlements.HasLicense)
require.Len(t, entitlements.Errors, 1)
require.Equal(t, "You have multiple replicas but high availability is an Enterprise feature. You will be unable to connect to workspaces.", entitlements.Errors[0])
})
t.Run("MultipleReplicasNotEntitled", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
db.InsertLicense(context.Background(), database.InsertLicenseParams{
Exp: time.Now().Add(time.Hour),
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
AuditLog: true,
}),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 2, coderdenttest.Keys, map[string]bool{
codersdk.FeatureHighAvailability: true,
})
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.Len(t, entitlements.Errors, 1)
require.Equal(t, "You have multiple replicas but your license is not entitled to high availability. You will be unable to connect to workspaces.", entitlements.Errors[0])
})
t.Run("MultipleReplicasGrace", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
db.InsertLicense(context.Background(), database.InsertLicenseParams{
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
HighAvailability: true,
GraceAt: time.Now().Add(-time.Hour),
ExpiresAt: time.Now().Add(time.Hour),
}),
Exp: time.Now().Add(time.Hour),
})
entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 2, coderdenttest.Keys, map[string]bool{
codersdk.FeatureHighAvailability: true,
})
require.NoError(t, err)
require.True(t, entitlements.HasLicense)
require.Len(t, entitlements.Warnings, 1)
require.Equal(t, "You have multiple replicas but your license for high availability is expired. Reduce to one replica or workspace connections will stop working.", entitlements.Warnings[0])
})
}

View File

@ -78,21 +78,21 @@ func TestGetLicense(t *testing.T) {
defer cancel()
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
AccountID: "testing",
AuditLog: true,
SCIM: true,
BrowserOnly: true,
TemplateRBACEnabled: true,
AccountID: "testing",
AuditLog: true,
SCIM: true,
BrowserOnly: true,
TemplateRBAC: true,
})
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
AccountID: "testing2",
AuditLog: true,
SCIM: true,
BrowserOnly: true,
Trial: true,
UserLimit: 200,
TemplateRBACEnabled: false,
AccountID: "testing2",
AuditLog: true,
SCIM: true,
BrowserOnly: true,
Trial: true,
UserLimit: 200,
TemplateRBAC: false,
})
licenses, err := client.Licenses(ctx)
@ -101,23 +101,25 @@ func TestGetLicense(t *testing.T) {
assert.Equal(t, int32(1), licenses[0].ID)
assert.Equal(t, "testing", licenses[0].Claims["account_id"])
assert.Equal(t, map[string]interface{}{
codersdk.FeatureUserLimit: json.Number("0"),
codersdk.FeatureAuditLog: json.Number("1"),
codersdk.FeatureSCIM: json.Number("1"),
codersdk.FeatureBrowserOnly: json.Number("1"),
codersdk.FeatureWorkspaceQuota: json.Number("0"),
codersdk.FeatureTemplateRBAC: json.Number("1"),
codersdk.FeatureUserLimit: json.Number("0"),
codersdk.FeatureAuditLog: json.Number("1"),
codersdk.FeatureSCIM: json.Number("1"),
codersdk.FeatureBrowserOnly: json.Number("1"),
codersdk.FeatureWorkspaceQuota: json.Number("0"),
codersdk.FeatureHighAvailability: json.Number("0"),
codersdk.FeatureTemplateRBAC: json.Number("1"),
}, licenses[0].Claims["features"])
assert.Equal(t, int32(2), licenses[1].ID)
assert.Equal(t, "testing2", licenses[1].Claims["account_id"])
assert.Equal(t, true, licenses[1].Claims["trial"])
assert.Equal(t, map[string]interface{}{
codersdk.FeatureUserLimit: json.Number("200"),
codersdk.FeatureAuditLog: json.Number("1"),
codersdk.FeatureSCIM: json.Number("1"),
codersdk.FeatureBrowserOnly: json.Number("1"),
codersdk.FeatureWorkspaceQuota: json.Number("0"),
codersdk.FeatureTemplateRBAC: json.Number("0"),
codersdk.FeatureUserLimit: json.Number("200"),
codersdk.FeatureAuditLog: json.Number("1"),
codersdk.FeatureSCIM: json.Number("1"),
codersdk.FeatureBrowserOnly: json.Number("1"),
codersdk.FeatureWorkspaceQuota: json.Number("0"),
codersdk.FeatureHighAvailability: json.Number("0"),
codersdk.FeatureTemplateRBAC: json.Number("0"),
}, licenses[1].Claims["features"])
})
}

View File

@ -0,0 +1,37 @@
package coderd
import (
"net/http"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
)
// replicas returns the list of replicas that are active in Coder.
func (api *API) replicas(rw http.ResponseWriter, r *http.Request) {
if !api.AGPL.Authorize(r, rbac.ActionRead, rbac.ResourceReplicas) {
httpapi.ResourceNotFound(rw)
return
}
replicas := api.replicaManager.All()
res := make([]codersdk.Replica, 0, len(replicas))
for _, replica := range replicas {
res = append(res, convertReplica(replica))
}
httpapi.Write(r.Context(), rw, http.StatusOK, res)
}
func convertReplica(replica database.Replica) codersdk.Replica {
return codersdk.Replica{
ID: replica.ID,
Hostname: replica.Hostname,
CreatedAt: replica.CreatedAt,
RelayAddress: replica.RelayAddress,
RegionID: replica.RegionID,
Error: replica.Error,
DatabaseLatency: replica.DatabaseLatency,
}
}

View File

@ -0,0 +1,138 @@
package coderd_test
import (
"context"
"crypto/tls"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/coderd/coderdenttest"
"github.com/coder/coder/testutil"
)
func TestReplicas(t *testing.T) {
t.Parallel()
t.Run("ErrorWithoutLicense", func(t *testing.T) {
t.Parallel()
db, pubsub := dbtestutil.NewDB(t)
firstClient := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
IncludeProvisionerDaemon: true,
Database: db,
Pubsub: pubsub,
},
})
_ = coderdtest.CreateFirstUser(t, firstClient)
secondClient, _, secondAPI := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
},
})
secondClient.SessionToken = firstClient.SessionToken
ents, err := secondClient.Entitlements(context.Background())
require.NoError(t, err)
require.Len(t, ents.Errors, 1)
_ = secondAPI.Close()
ents, err = firstClient.Entitlements(context.Background())
require.NoError(t, err)
require.Len(t, ents.Warnings, 0)
})
t.Run("ConnectAcrossMultiple", func(t *testing.T) {
t.Parallel()
db, pubsub := dbtestutil.NewDB(t)
firstClient := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
IncludeProvisionerDaemon: true,
Database: db,
Pubsub: pubsub,
},
})
firstUser := coderdtest.CreateFirstUser(t, firstClient)
coderdenttest.AddLicense(t, firstClient, coderdenttest.LicenseOptions{
HighAvailability: true,
})
secondClient := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
},
})
secondClient.SessionToken = firstClient.SessionToken
replicas, err := secondClient.Replicas(context.Background())
require.NoError(t, err)
require.Len(t, replicas, 2)
_, agent := setupWorkspaceAgent(t, firstClient, firstUser, 0)
conn, err := secondClient.DialWorkspaceAgent(context.Background(), agent.ID, &codersdk.DialWorkspaceAgentOptions{
BlockEndpoints: true,
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
})
require.NoError(t, err)
require.Eventually(t, func() bool {
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancelFunc()
_, err = conn.Ping(ctx)
return err == nil
}, testutil.WaitLong, testutil.IntervalFast)
_ = conn.Close()
})
t.Run("ConnectAcrossMultipleTLS", func(t *testing.T) {
t.Parallel()
db, pubsub := dbtestutil.NewDB(t)
certificates := []tls.Certificate{testutil.GenerateTLSCertificate(t, "localhost")}
firstClient := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
IncludeProvisionerDaemon: true,
Database: db,
Pubsub: pubsub,
TLSCertificates: certificates,
},
})
firstUser := coderdtest.CreateFirstUser(t, firstClient)
coderdenttest.AddLicense(t, firstClient, coderdenttest.LicenseOptions{
HighAvailability: true,
})
secondClient := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
TLSCertificates: certificates,
},
})
secondClient.SessionToken = firstClient.SessionToken
replicas, err := secondClient.Replicas(context.Background())
require.NoError(t, err)
require.Len(t, replicas, 2)
_, agent := setupWorkspaceAgent(t, firstClient, firstUser, 0)
conn, err := secondClient.DialWorkspaceAgent(context.Background(), agent.ID, &codersdk.DialWorkspaceAgentOptions{
BlockEndpoints: true,
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
require.Eventually(t, func() bool {
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.IntervalSlow)
defer cancelFunc()
_, err = conn.Ping(ctx)
return err == nil
}, testutil.WaitLong, testutil.IntervalFast)
_ = conn.Close()
replicas, err = secondClient.Replicas(context.Background())
require.NoError(t, err)
require.Len(t, replicas, 2)
for _, replica := range replicas {
require.Empty(t, replica.Error)
}
})
}

View File

@ -23,7 +23,7 @@ func TestTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -64,7 +64,7 @@ func TestTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
_, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -88,7 +88,7 @@ func TestTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
client1, _ := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -138,7 +138,7 @@ func TestTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
_, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -176,7 +176,7 @@ func TestTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
_, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -214,7 +214,7 @@ func TestTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
@ -262,7 +262,7 @@ func TestTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
client1, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -318,7 +318,7 @@ func TestUpdateTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -361,7 +361,7 @@ func TestUpdateTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -422,7 +422,7 @@ func TestUpdateTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
@ -447,7 +447,7 @@ func TestUpdateTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
@ -472,7 +472,7 @@ func TestUpdateTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -498,7 +498,7 @@ func TestUpdateTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
client2, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -533,7 +533,7 @@ func TestUpdateTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
client2, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -575,7 +575,7 @@ func TestUpdateTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
@ -597,7 +597,7 @@ func TestUpdateTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
client1, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@ -662,7 +662,7 @@ func TestUpdateTemplateACL(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
client1, _ := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)

View File

@ -2,6 +2,7 @@ package coderd_test
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"testing"
@ -9,7 +10,6 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/coderdtest"
@ -42,7 +42,7 @@ func TestBlockNonBrowser(t *testing.T) {
BrowserOnly: true,
})
_, agent := setupWorkspaceAgent(t, client, user, 0)
_, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, agent.ID)
_, err := client.DialWorkspaceAgent(context.Background(), agent.ID, nil)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusConflict, apiErr.StatusCode())
@ -59,7 +59,7 @@ func TestBlockNonBrowser(t *testing.T) {
BrowserOnly: false,
})
_, agent := setupWorkspaceAgent(t, client, user, 0)
conn, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, agent.ID)
conn, err := client.DialWorkspaceAgent(context.Background(), agent.ID, nil)
require.NoError(t, err)
_ = conn.Close()
})
@ -109,6 +109,14 @@ func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.Cr
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
agentClient := codersdk.New(client.URL)
agentClient.HTTPClient = &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec
InsecureSkipVerify: true,
},
},
}
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,

View File

@ -26,7 +26,7 @@ func TestCreateWorkspace(t *testing.T) {
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBACEnabled: true,
TemplateRBAC: true,
})
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)

View File

@ -0,0 +1,165 @@
package derpmesh
import (
"context"
"crypto/tls"
"net"
"net/url"
"sync"
"golang.org/x/xerrors"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/types/key"
"github.com/coder/coder/tailnet"
"cdr.dev/slog"
)
// New constructs a new mesh for DERP servers.
func New(logger slog.Logger, server *derp.Server, tlsConfig *tls.Config) *Mesh {
return &Mesh{
logger: logger,
server: server,
tlsConfig: tlsConfig,
ctx: context.Background(),
closed: make(chan struct{}),
active: make(map[string]context.CancelFunc),
}
}
type Mesh struct {
logger slog.Logger
server *derp.Server
ctx context.Context
tlsConfig *tls.Config
mutex sync.Mutex
closed chan struct{}
active map[string]context.CancelFunc
}
// SetAddresses performs a diff of the incoming addresses and adds
// or removes DERP clients from the mesh.
//
// The connect parameter is only used in testing to ensure DERPs are
// meshed before exchanging messages.
// nolint:revive
func (m *Mesh) SetAddresses(addresses []string, connect bool) {
total := make(map[string]struct{}, len(addresses))
for _, address := range addresses {
addressURL, err := url.Parse(address)
if err != nil {
m.logger.Error(m.ctx, "invalid address", slog.F("address", address), slog.Error(err))
continue
}
derpURL, err := addressURL.Parse("/derp")
if err != nil {
m.logger.Error(m.ctx, "parse derp", slog.F("address", address), slog.Error(err))
continue
}
address = derpURL.String()
total[address] = struct{}{}
added, err := m.addAddress(address, connect)
if err != nil {
m.logger.Error(m.ctx, "failed to add address", slog.F("address", address), slog.Error(err))
continue
}
if added {
m.logger.Debug(m.ctx, "added mesh address", slog.F("address", address))
}
}
m.mutex.Lock()
for address := range m.active {
_, found := total[address]
if found {
continue
}
removed := m.removeAddress(address)
if removed {
m.logger.Debug(m.ctx, "removed mesh address", slog.F("address", address))
}
}
m.mutex.Unlock()
}
// addAddress begins meshing with a new address. It returns false if the address is already being meshed with.
// It's expected that this is a full HTTP address with a path.
// e.g. http://127.0.0.1:8080/derp
// nolint:revive
func (m *Mesh) addAddress(address string, connect bool) (bool, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.isClosed() {
return false, nil
}
_, isActive := m.active[address]
if isActive {
return false, nil
}
client, err := derphttp.NewClient(m.server.PrivateKey(), address, tailnet.Logger(m.logger.Named("client")))
if err != nil {
return false, xerrors.Errorf("create derp client: %w", err)
}
client.TLSConfig = m.tlsConfig
client.MeshKey = m.server.MeshKey()
client.SetURLDialer(func(ctx context.Context, network, addr string) (net.Conn, error) {
var dialer net.Dialer
return dialer.DialContext(ctx, network, addr)
})
if connect {
_ = client.Connect(m.ctx)
}
ctx, cancelFunc := context.WithCancel(m.ctx)
closed := make(chan struct{})
closeFunc := func() {
cancelFunc()
_ = client.Close()
<-closed
}
m.active[address] = closeFunc
go func() {
defer close(closed)
client.RunWatchConnectionLoop(ctx, m.server.PublicKey(), tailnet.Logger(m.logger.Named("loop")), func(np key.NodePublic) {
m.server.AddPacketForwarder(np, client)
}, func(np key.NodePublic) {
m.server.RemovePacketForwarder(np, client)
})
}()
return true, nil
}
// removeAddress stops meshing with a given address.
func (m *Mesh) removeAddress(address string) bool {
cancelFunc, isActive := m.active[address]
if isActive {
cancelFunc()
}
return isActive
}
// Close ends all active meshes with the DERP server.
func (m *Mesh) Close() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.isClosed() {
return nil
}
close(m.closed)
for _, cancelFunc := range m.active {
cancelFunc()
}
return nil
}
func (m *Mesh) isClosed() bool {
select {
case <-m.closed:
return true
default:
}
return false
}
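// Illustrative wiring (not part of this change): meshing one replica's
// embedded DERP server with a peer. The address and mesh key are made up;
// the real code derives them from the database and replicasync callbacks.
func exampleMesh(logger slog.Logger, tlsConfig *tls.Config) *Mesh {
	server := derp.NewServer(key.NewNode(), tailnet.Logger(logger.Named("derp")))
	// Every replica must share the same mesh key.
	server.SetMeshKey("shared-mesh-key")
	mesh := New(logger.Named("derpmesh"), server, tlsConfig)
	// Peer base URLs; SetAddresses appends the /derp path itself and
	// diffs against the active set, adding and removing clients.
	mesh.SetAddresses([]string{"https://10.0.0.2:443"}, false)
	return mesh
}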

View File

@ -0,0 +1,219 @@
package derpmesh_test
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/types/key"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/enterprise/derpmesh"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/testutil"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestDERPMesh(t *testing.T) {
t.Parallel()
commonName := "something.org"
rawCert := testutil.GenerateTLSCertificate(t, commonName)
certificate, err := x509.ParseCertificate(rawCert.Certificate[0])
require.NoError(t, err)
pool := x509.NewCertPool()
pool.AddCert(certificate)
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: commonName,
RootCAs: pool,
Certificates: []tls.Certificate{rawCert},
}
t.Run("ExchangeMessages", func(t *testing.T) {
// This tests messages passing through multiple DERP servers.
t.Parallel()
firstServer, firstServerURL := startDERP(t, tlsConfig)
defer firstServer.Close()
secondServer, secondServerURL := startDERP(t, tlsConfig)
firstMesh := derpmesh.New(slogtest.Make(t, nil).Named("first").Leveled(slog.LevelDebug), firstServer, tlsConfig)
firstMesh.SetAddresses([]string{secondServerURL}, true)
secondMesh := derpmesh.New(slogtest.Make(t, nil).Named("second").Leveled(slog.LevelDebug), secondServer, tlsConfig)
secondMesh.SetAddresses([]string{firstServerURL}, true)
defer firstMesh.Close()
defer secondMesh.Close()
first := key.NewNode()
second := key.NewNode()
firstClient, err := derphttp.NewClient(first, secondServerURL, tailnet.Logger(slogtest.Make(t, nil)))
require.NoError(t, err)
firstClient.TLSConfig = tlsConfig
secondClient, err := derphttp.NewClient(second, firstServerURL, tailnet.Logger(slogtest.Make(t, nil)))
require.NoError(t, err)
secondClient.TLSConfig = tlsConfig
err = secondClient.Connect(context.Background())
require.NoError(t, err)
closed := make(chan struct{})
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
sent := []byte("hello world")
go func() {
defer close(closed)
ticker := time.NewTicker(50 * time.Millisecond)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
}
err = firstClient.Send(second.Public(), sent)
require.NoError(t, err)
}
}()
got := recvData(t, secondClient)
require.Equal(t, sent, got)
cancelFunc()
<-closed
})
t.Run("RemoveAddress", func(t *testing.T) {
// This tests that removing an address from the mesh doesn't break clients connected to the remaining server.
t.Parallel()
server, serverURL := startDERP(t, tlsConfig)
mesh := derpmesh.New(slogtest.Make(t, nil).Named("first").Leveled(slog.LevelDebug), server, tlsConfig)
mesh.SetAddresses([]string{"http://fake.com"}, false)
// This should trigger a removal...
mesh.SetAddresses([]string{}, false)
defer mesh.Close()
first := key.NewNode()
second := key.NewNode()
firstClient, err := derphttp.NewClient(first, serverURL, tailnet.Logger(slogtest.Make(t, nil)))
require.NoError(t, err)
firstClient.TLSConfig = tlsConfig
secondClient, err := derphttp.NewClient(second, serverURL, tailnet.Logger(slogtest.Make(t, nil)))
require.NoError(t, err)
secondClient.TLSConfig = tlsConfig
err = secondClient.Connect(context.Background())
require.NoError(t, err)
closed := make(chan struct{})
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
sent := []byte("hello world")
go func() {
defer close(closed)
ticker := time.NewTicker(50 * time.Millisecond)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
}
err = firstClient.Send(second.Public(), sent)
require.NoError(t, err)
}
}()
got := recvData(t, secondClient)
require.Equal(t, sent, got)
cancelFunc()
<-closed
})
t.Run("TwentyMeshes", func(t *testing.T) {
t.Parallel()
meshes := make([]*derpmesh.Mesh, 0, 20)
serverURLs := make([]string, 0, 20)
for i := 0; i < 20; i++ {
server, url := startDERP(t, tlsConfig)
mesh := derpmesh.New(slogtest.Make(t, nil).Named("mesh").Leveled(slog.LevelDebug), server, tlsConfig)
t.Cleanup(func() {
_ = server.Close()
_ = mesh.Close()
})
serverURLs = append(serverURLs, url)
meshes = append(meshes, mesh)
}
for _, mesh := range meshes {
mesh.SetAddresses(serverURLs, true)
}
first := key.NewNode()
second := key.NewNode()
firstClient, err := derphttp.NewClient(first, serverURLs[9], tailnet.Logger(slogtest.Make(t, nil)))
require.NoError(t, err)
firstClient.TLSConfig = tlsConfig
secondClient, err := derphttp.NewClient(second, serverURLs[16], tailnet.Logger(slogtest.Make(t, nil)))
require.NoError(t, err)
secondClient.TLSConfig = tlsConfig
err = secondClient.Connect(context.Background())
require.NoError(t, err)
closed := make(chan struct{})
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
sent := []byte("hello world")
go func() {
defer close(closed)
ticker := time.NewTicker(50 * time.Millisecond)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
}
err = firstClient.Send(second.Public(), sent)
require.NoError(t, err)
}
}()
got := recvData(t, secondClient)
require.Equal(t, sent, got)
cancelFunc()
<-closed
})
}
func recvData(t *testing.T, client *derphttp.Client) []byte {
for {
msg, err := client.Recv()
if errors.Is(err, io.EOF) {
return nil
}
assert.NoError(t, err)
t.Logf("derp: %T", msg)
switch msg := msg.(type) {
case derp.ReceivedPacket:
return msg.Data
default:
// Drop all others!
}
}
}
func startDERP(t *testing.T, tlsConfig *tls.Config) (*derp.Server, string) {
logf := tailnet.Logger(slogtest.Make(t, nil))
d := derp.NewServer(key.NewNode(), logf)
d.SetMeshKey("some-key")
server := httptest.NewUnstartedServer(derphttp.Handler(d))
server.TLS = tlsConfig
server.StartTLS()
t.Cleanup(func() {
_ = d.Close()
})
t.Cleanup(server.Close)
return d, server.URL
}

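The tests above pin down the mesh surface area: construct a Mesh per server, point each at the other's URL, and messages route across servers that share a mesh key. A minimal sketch of that wiring, assuming only the APIs exercised above (the second SetAddresses argument is passed exactly as the tests pass it; the helper name meshPair is illustrative):
import (
	"crypto/tls"
	"cdr.dev/slog"
	"tailscale.com/derp"
	"github.com/coder/coder/enterprise/derpmesh"
)
// meshPair meshes two DERP servers bidirectionally. Both servers must share a
// mesh key (startDERP above uses d.SetMeshKey("some-key")) or forwarding is refused.
func meshPair(logger slog.Logger, a, b *derp.Server, urlA, urlB string, tlsConfig *tls.Config) (closeAll func()) {
	meshA := derpmesh.New(logger.Named("a"), a, tlsConfig)
	meshA.SetAddresses([]string{urlB}, true)
	meshB := derpmesh.New(logger.Named("b"), b, tlsConfig)
	meshB.SetAddresses([]string{urlA}, true)
	return func() {
		_ = meshA.Close()
		_ = meshB.Close()
	}
}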
View File

@ -0,0 +1,391 @@
package replicasync
import (
"context"
"crypto/tls"
"database/sql"
"errors"
"fmt"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/buildinfo"
"github.com/coder/coder/coderd/database"
)
var (
PubsubEvent = "replica"
)
type Options struct {
CleanupInterval time.Duration
UpdateInterval time.Duration
PeerTimeout time.Duration
RelayAddress string
RegionID int32
TLSConfig *tls.Config
}
// New registers the replica with the database and periodically updates its row
// to signal that it's healthy. It contacts all other live replicas to ensure they are reachable.
func New(ctx context.Context, logger slog.Logger, db database.Store, pubsub database.Pubsub, options *Options) (*Manager, error) {
if options == nil {
options = &Options{}
}
if options.PeerTimeout == 0 {
options.PeerTimeout = 3 * time.Second
}
if options.UpdateInterval == 0 {
options.UpdateInterval = 5 * time.Second
}
if options.CleanupInterval == 0 {
// The cleanup interval can be quite long, because its
// primary purpose is to clean up dead replicas.
options.CleanupInterval = 30 * time.Minute
}
hostname, err := os.Hostname()
if err != nil {
return nil, xerrors.Errorf("get hostname: %w", err)
}
databaseLatency, err := db.Ping(ctx)
if err != nil {
return nil, xerrors.Errorf("ping database: %w", err)
}
id := uuid.New()
replica, err := db.InsertReplica(ctx, database.InsertReplicaParams{
ID: id,
CreatedAt: database.Now(),
StartedAt: database.Now(),
UpdatedAt: database.Now(),
Hostname: hostname,
RegionID: options.RegionID,
RelayAddress: options.RelayAddress,
Version: buildinfo.Version(),
DatabaseLatency: int32(databaseLatency.Microseconds()),
})
if err != nil {
return nil, xerrors.Errorf("insert replica: %w", err)
}
err = pubsub.Publish(PubsubEvent, []byte(id.String()))
if err != nil {
return nil, xerrors.Errorf("publish new replica: %w", err)
}
ctx, cancelFunc := context.WithCancel(ctx)
manager := &Manager{
id: id,
options: options,
db: db,
pubsub: pubsub,
self: replica,
logger: logger,
closed: make(chan struct{}),
closeCancel: cancelFunc,
}
err = manager.syncReplicas(ctx)
if err != nil {
return nil, xerrors.Errorf("run replica: %w", err)
}
peers := manager.Regional()
if len(peers) > 0 {
self := manager.Self()
if self.RelayAddress == "" {
return nil, xerrors.Errorf("a relay address must be specified when running multiple replicas in the same region")
}
}
err = manager.subscribe(ctx)
if err != nil {
return nil, xerrors.Errorf("subscribe: %w", err)
}
manager.closeWait.Add(1)
go manager.loop(ctx)
return manager, nil
}
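For orientation, a hypothetical caller of New (the function name runReplica and the relay URL are illustrative; db and pubsub come from wherever the server constructs them, and the surrounding package's imports are assumed):
func runReplica(ctx context.Context, logger slog.Logger, db database.Store, pubsub database.Pubsub) error {
	manager, err := replicasync.New(ctx, logger, db, pubsub, &replicasync.Options{
		// Must be reachable by peers in the same region when >1 replica runs.
		RelayAddress: "https://10.0.0.5:8443",
	})
	if err != nil {
		return xerrors.Errorf("start replica sync: %w", err)
	}
	defer manager.Close()
	for _, peer := range manager.Regional() {
		logger.Info(ctx, "regional peer", slog.F("hostname", peer.Hostname))
	}
	return nil
}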
// Manager keeps the replica up to date and in sync with other replicas.
type Manager struct {
id uuid.UUID
options *Options
db database.Store
pubsub database.Pubsub
logger slog.Logger
closeWait sync.WaitGroup
closeMutex sync.Mutex
closed chan struct{}
closeCancel context.CancelFunc
self database.Replica
mutex sync.Mutex
peers []database.Replica
callback func()
}
// updateInterval returns the staleness cutoff for a replica.
// A replica updated after this time is considered healthy;
// one updated before it is considered stale.
func (m *Manager) updateInterval() time.Time {
return database.Now().Add(-3 * m.options.UpdateInterval)
}
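Concretely, with the default 5-second UpdateInterval the cutoff sits 15 seconds in the past. A hypothetical helper spelling out the same check (isStale is not part of the package):
// isStale mirrors updateInterval's cutoff: a replica that has missed three
// consecutive update windows is presumed dead.
func isStale(r database.Replica, updateInterval time.Duration) bool {
	return r.UpdatedAt.Before(database.Now().Add(-3 * updateInterval))
}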
// loop runs the replica update sequence on an update interval.
func (m *Manager) loop(ctx context.Context) {
defer m.closeWait.Done()
updateTicker := time.NewTicker(m.options.UpdateInterval)
defer updateTicker.Stop()
deleteTicker := time.NewTicker(m.options.CleanupInterval)
defer deleteTicker.Stop()
for {
select {
case <-ctx.Done():
return
case <-deleteTicker.C:
err := m.db.DeleteReplicasUpdatedBefore(ctx, m.updateInterval())
if err != nil {
m.logger.Warn(ctx, "delete old replicas", slog.Error(err))
}
continue
case <-updateTicker.C:
}
err := m.syncReplicas(ctx)
if err != nil && !errors.Is(err, context.Canceled) {
m.logger.Warn(ctx, "run replica update loop", slog.Error(err))
}
}
}
// subscribe listens for new replica information!
func (m *Manager) subscribe(ctx context.Context) error {
var (
needsUpdate = false
updating = false
updateMutex = sync.Mutex{}
)
// This loop will continually update nodes as updates are processed.
// The intent is to always be up to date without spamming the run
// function, so if a new update comes in while one is being processed,
// it will reprocess afterwards.
var update func()
update = func() {
err := m.syncReplicas(ctx)
if err != nil && !errors.Is(err, context.Canceled) {
m.logger.Warn(ctx, "run replica from subscribe", slog.Error(err))
}
updateMutex.Lock()
if needsUpdate {
needsUpdate = false
updateMutex.Unlock()
update()
return
}
updating = false
updateMutex.Unlock()
}
cancelFunc, err := m.pubsub.Subscribe(PubsubEvent, func(ctx context.Context, message []byte) {
updateMutex.Lock()
defer updateMutex.Unlock()
id, err := uuid.Parse(string(message))
if err != nil {
return
}
// Don't process updates for ourselves!
if id == m.id {
return
}
if updating {
needsUpdate = true
return
}
updating = true
go update()
})
if err != nil {
return err
}
go func() {
<-ctx.Done()
cancelFunc()
}()
return nil
}
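The needsUpdate/updating pair is a coalescing debounce: one sync runs at a time, and any burst of notifications arriving mid-run collapses into a single trailing re-run. The same pattern in isolation, as a self-contained sketch (the type name coalescer is illustrative):
package main

import (
	"fmt"
	"sync"
	"time"
)

// coalescer runs work at most once concurrently; triggers that arrive while
// work is running collapse into exactly one follow-up run.
type coalescer struct {
	mu      sync.Mutex
	running bool
	pending bool
	work    func()
}

func (c *coalescer) trigger() {
	c.mu.Lock()
	if c.running {
		c.pending = true
		c.mu.Unlock()
		return
	}
	c.running = true
	c.mu.Unlock()
	go func() {
		for {
			c.work()
			c.mu.Lock()
			if !c.pending {
				c.running = false
				c.mu.Unlock()
				return
			}
			c.pending = false
			c.mu.Unlock()
		}
	}()
}

func main() {
	c := &coalescer{work: func() { fmt.Println("sync"); time.Sleep(10 * time.Millisecond) }}
	for i := 0; i < 5; i++ {
		c.trigger() // five rapid triggers yield at most two syncs
	}
	time.Sleep(100 * time.Millisecond)
}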
func (m *Manager) syncReplicas(ctx context.Context) error {
m.closeMutex.Lock()
m.closeWait.Add(1)
m.closeMutex.Unlock()
defer m.closeWait.Done()
// Replicas are expected to update within three update intervals.
// If they haven't, assume death!
replicas, err := m.db.GetReplicasUpdatedAfter(ctx, m.updateInterval())
if err != nil {
return xerrors.Errorf("get replicas: %w", err)
}
m.mutex.Lock()
m.peers = make([]database.Replica, 0, len(replicas))
for _, replica := range replicas {
if replica.ID == m.id {
continue
}
m.peers = append(m.peers, replica)
}
m.mutex.Unlock()
client := http.Client{
Timeout: m.options.PeerTimeout,
Transport: &http.Transport{
TLSClientConfig: m.options.TLSConfig,
},
}
defer client.CloseIdleConnections()
var wg sync.WaitGroup
var mu sync.Mutex
failed := make([]string, 0)
for _, peer := range m.Regional() {
wg.Add(1)
go func(peer database.Replica) {
defer wg.Done()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, peer.RelayAddress, nil)
if err != nil {
m.logger.Warn(ctx, "create http request for relay probe",
slog.F("relay_address", peer.RelayAddress), slog.Error(err))
return
}
res, err := client.Do(req)
if err != nil {
mu.Lock()
failed = append(failed, fmt.Sprintf("relay %s (%s): %s", peer.Hostname, peer.RelayAddress, err))
mu.Unlock()
return
}
_ = res.Body.Close()
}(peer)
}
wg.Wait()
replicaError := ""
if len(failed) > 0 {
replicaError = fmt.Sprintf("Failed to dial peers: %s", strings.Join(failed, ", "))
}
databaseLatency, err := m.db.Ping(ctx)
if err != nil {
return xerrors.Errorf("ping database: %w", err)
}
replica, err := m.db.UpdateReplica(ctx, database.UpdateReplicaParams{
ID: m.self.ID,
UpdatedAt: database.Now(),
StartedAt: m.self.StartedAt,
StoppedAt: m.self.StoppedAt,
RelayAddress: m.self.RelayAddress,
RegionID: m.self.RegionID,
Hostname: m.self.Hostname,
Version: m.self.Version,
Error: replicaError,
DatabaseLatency: int32(databaseLatency.Microseconds()),
})
if err != nil {
return xerrors.Errorf("update replica: %w", err)
}
m.mutex.Lock()
defer m.mutex.Unlock()
if m.self.Error != replica.Error {
// Publish an update occurred!
err = m.pubsub.Publish(PubsubEvent, []byte(m.self.ID.String()))
if err != nil {
return xerrors.Errorf("publish replica update: %w", err)
}
}
m.self = replica
if m.callback != nil {
go m.callback()
}
return nil
}
// Self represents the current replica.
func (m *Manager) Self() database.Replica {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.self
}
// All returns every replica, including itself.
func (m *Manager) All() []database.Replica {
m.mutex.Lock()
defer m.mutex.Unlock()
all := make([]database.Replica, 0, len(m.peers)+1)
all = append(all, m.peers...)
return append(all, m.self)
}
// Regional returns all replicas in the same region excluding itself.
func (m *Manager) Regional() []database.Replica {
m.mutex.Lock()
defer m.mutex.Unlock()
replicas := make([]database.Replica, 0)
for _, replica := range m.peers {
if replica.RegionID != m.self.RegionID {
continue
}
replicas = append(replicas, replica)
}
return replicas
}
// SetCallback sets a function to execute whenever new peers
// are refreshed or updated.
func (m *Manager) SetCallback(callback func()) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.callback = callback
// Instantly call the callback to inform replicas!
go callback()
}
func (m *Manager) Close() error {
m.closeMutex.Lock()
select {
case <-m.closed:
m.closeMutex.Unlock()
return nil
default:
}
close(m.closed)
m.closeCancel()
m.closeWait.Wait()
m.closeMutex.Unlock()
m.mutex.Lock()
defer m.mutex.Unlock()
ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFunc()
_, err := m.db.UpdateReplica(ctx, database.UpdateReplicaParams{
ID: m.self.ID,
UpdatedAt: database.Now(),
StartedAt: m.self.StartedAt,
StoppedAt: sql.NullTime{
Time: database.Now(),
Valid: true,
},
RelayAddress: m.self.RelayAddress,
RegionID: m.self.RegionID,
Hostname: m.self.Hostname,
Version: m.self.Version,
Error: m.self.Error,
})
if err != nil {
return xerrors.Errorf("update replica: %w", err)
}
err = m.pubsub.Publish(PubsubEvent, []byte(m.self.ID.String()))
if err != nil {
return xerrors.Errorf("publish replica update: %w", err)
}
return nil
}

View File

@ -0,0 +1,239 @@
package replicasync_test
import (
"context"
"crypto/tls"
"crypto/x509"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/enterprise/replicasync"
"github.com/coder/coder/testutil"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestReplica(t *testing.T) {
t.Parallel()
t.Run("CreateOnNew", func(t *testing.T) {
// This ensures that a new replica is created on New.
t.Parallel()
db, pubsub := dbtestutil.NewDB(t)
closeChan := make(chan struct{}, 1)
cancel, err := pubsub.Subscribe(replicasync.PubsubEvent, func(ctx context.Context, message []byte) {
closeChan <- struct{}{}
})
require.NoError(t, err)
defer cancel()
server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, nil)
require.NoError(t, err)
<-closeChan
_ = server.Close()
require.NoError(t, err)
})
t.Run("ErrorsWithoutRelayAddress", func(t *testing.T) {
// Ensures that New fails with a descriptive error when a peer
// exists in the same region but no relay address is configured.
t.Parallel()
db, pubsub := dbtestutil.NewDB(t)
_, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
ID: uuid.New(),
CreatedAt: database.Now(),
StartedAt: database.Now(),
UpdatedAt: database.Now(),
Hostname: "something",
})
require.NoError(t, err)
_, err = replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, nil)
require.Error(t, err)
require.Equal(t, "a relay address must be specified when running multiple replicas in the same region", err.Error())
})
t.Run("ConnectsToPeerReplica", func(t *testing.T) {
// Ensures that the replica reports a successful status for
// accessing all of its peers.
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
db, pubsub := dbtestutil.NewDB(t)
peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
ID: uuid.New(),
CreatedAt: database.Now(),
StartedAt: database.Now(),
UpdatedAt: database.Now(),
Hostname: "something",
RelayAddress: srv.URL,
})
require.NoError(t, err)
server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{
RelayAddress: "http://169.254.169.254",
})
require.NoError(t, err)
require.Len(t, server.Regional(), 1)
require.Equal(t, peer.ID, server.Regional()[0].ID)
require.Empty(t, server.Self().Error)
_ = server.Close()
})
t.Run("ConnectsToPeerReplicaTLS", func(t *testing.T) {
// Ensures that the replica reports a successful status for
// accessing all of its peers.
t.Parallel()
rawCert := testutil.GenerateTLSCertificate(t, "hello.org")
certificate, err := x509.ParseCertificate(rawCert.Certificate[0])
require.NoError(t, err)
pool := x509.NewCertPool()
pool.AddCert(certificate)
// nolint:gosec
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{rawCert},
ServerName: "hello.org",
RootCAs: pool,
}
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
srv.TLS = tlsConfig
srv.StartTLS()
defer srv.Close()
db, pubsub := dbtestutil.NewDB(t)
peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
ID: uuid.New(),
CreatedAt: database.Now(),
StartedAt: database.Now(),
UpdatedAt: database.Now(),
Hostname: "something",
RelayAddress: srv.URL,
})
require.NoError(t, err)
server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{
RelayAddress: "http://169.254.169.254",
TLSConfig: tlsConfig,
})
require.NoError(t, err)
require.Len(t, server.Regional(), 1)
require.Equal(t, peer.ID, server.Regional()[0].ID)
require.Empty(t, server.Self().Error)
_ = server.Close()
})
t.Run("ConnectsToFakePeerWithError", func(t *testing.T) {
t.Parallel()
db, pubsub := dbtestutil.NewDB(t)
peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
ID: uuid.New(),
CreatedAt: database.Now().Add(time.Minute),
StartedAt: database.Now().Add(time.Minute),
UpdatedAt: database.Now().Add(time.Minute),
Hostname: "something",
// Fake address to dial!
RelayAddress: "http://127.0.0.1:1",
})
require.NoError(t, err)
server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{
PeerTimeout: 1 * time.Millisecond,
RelayAddress: "http://127.0.0.1:1",
})
require.NoError(t, err)
require.Len(t, server.Regional(), 1)
require.Equal(t, peer.ID, server.Regional()[0].ID)
require.NotEmpty(t, server.Self().Error)
require.Contains(t, server.Self().Error, "Failed to dial peers")
_ = server.Close()
})
t.Run("RefreshOnPublish", func(t *testing.T) {
// Refresh when a new replica appears!
t.Parallel()
db, pubsub := dbtestutil.NewDB(t)
server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, nil)
require.NoError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
ID: uuid.New(),
RelayAddress: srv.URL,
UpdatedAt: database.Now(),
})
require.NoError(t, err)
// Publish multiple times to ensure it can handle that case.
err = pubsub.Publish(replicasync.PubsubEvent, []byte(peer.ID.String()))
require.NoError(t, err)
err = pubsub.Publish(replicasync.PubsubEvent, []byte(peer.ID.String()))
require.NoError(t, err)
require.Eventually(t, func() bool {
return len(server.Regional()) == 1
}, testutil.WaitShort, testutil.IntervalFast)
_ = server.Close()
})
t.Run("DeletesOld", func(t *testing.T) {
t.Parallel()
db, pubsub := dbtestutil.NewDB(t)
_, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
ID: uuid.New(),
UpdatedAt: database.Now().Add(-time.Hour),
})
require.NoError(t, err)
server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{
RelayAddress: "google.com",
CleanupInterval: time.Millisecond,
})
require.NoError(t, err)
defer server.Close()
require.Eventually(t, func() bool {
return len(server.Regional()) == 0
}, testutil.WaitShort, testutil.IntervalFast)
})
t.Run("TwentyConcurrent", func(t *testing.T) {
// Ensures that twenty concurrent replicas can spawn and all
// discover each other in parallel!
t.Parallel()
// This doesn't use a real PostgreSQL database because
// creating this many connections takes some
// configuration tweaking.
db := databasefake.New()
pubsub := database.NewPubsubInMemory()
logger := slogtest.Make(t, nil)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
var wg sync.WaitGroup
count := 20
wg.Add(count)
for i := 0; i < count; i++ {
server, err := replicasync.New(context.Background(), logger, db, pubsub, &replicasync.Options{
RelayAddress: srv.URL,
})
require.NoError(t, err)
t.Cleanup(func() {
_ = server.Close()
})
done := false
server.SetCallback(func() {
if len(server.All()) != count {
return
}
if done {
return
}
done = true
wg.Done()
})
}
wg.Wait()
})
}

View File

@ -0,0 +1,575 @@
package tailnet
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net"
"sync"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
agpl "github.com/coder/coder/tailnet"
)
// NewCoordinator creates a new high availability coordinator
// that uses PostgreSQL pubsub to exchange handshakes.
func NewCoordinator(logger slog.Logger, pubsub database.Pubsub) (agpl.Coordinator, error) {
ctx, cancelFunc := context.WithCancel(context.Background())
coord := &haCoordinator{
id: uuid.New(),
log: logger,
pubsub: pubsub,
closeFunc: cancelFunc,
close: make(chan struct{}),
nodes: map[uuid.UUID]*agpl.Node{},
agentSockets: map[uuid.UUID]net.Conn{},
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
}
if err := coord.runPubsub(ctx); err != nil {
return nil, xerrors.Errorf("run coordinator pubsub: %w", err)
}
return coord, nil
}
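Construction is intentionally symmetric with the in-memory agpl.NewCoordinator; the only extra dependency is the pubsub handle. A hypothetical caller (this fragment assumes a logger and pubsub are already in scope):
coordinator, err := tailnet.NewCoordinator(logger.Named("coordinator"), pubsub)
if err != nil {
	return xerrors.Errorf("create HA coordinator: %w", err)
}
defer coordinator.Close()
// ServeAgent/ServeClient are then wired to WebSocket connections exactly as
// with the in-memory coordinator.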
type haCoordinator struct {
id uuid.UUID
log slog.Logger
mutex sync.RWMutex
pubsub database.Pubsub
close chan struct{}
closeFunc context.CancelFunc
// nodes maps agent and connection IDs to their respective node.
nodes map[uuid.UUID]*agpl.Node
// agentSockets maps agent IDs to their open websocket.
agentSockets map[uuid.UUID]net.Conn
// agentToConnectionSockets maps agent IDs to connection IDs of conns that
// are subscribed to updates for that agent.
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn
}
// Node returns an in-memory node by ID.
func (c *haCoordinator) Node(id uuid.UUID) *agpl.Node {
c.mutex.Lock()
defer c.mutex.Unlock()
node := c.nodes[id]
return node
}
// ServeClient accepts a WebSocket connection that wants to connect to an agent
// with the specified ID.
func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
c.mutex.Lock()
// When a new connection is requested, we update it with the latest
// node of the agent, so the connection can be established.
node, ok := c.nodes[agent]
c.mutex.Unlock()
if ok {
data, err := json.Marshal([]*agpl.Node{node})
if err != nil {
return xerrors.Errorf("marshal node: %w", err)
}
_, err = conn.Write(data)
if err != nil {
return xerrors.Errorf("write nodes: %w", err)
}
} else {
err := c.publishClientHello(agent)
if err != nil {
return xerrors.Errorf("publish client hello: %w", err)
}
}
c.mutex.Lock()
connectionSockets, ok := c.agentToConnectionSockets[agent]
if !ok {
connectionSockets = map[uuid.UUID]net.Conn{}
c.agentToConnectionSockets[agent] = connectionSockets
}
// Insert this connection into a map so the agent can publish node updates.
connectionSockets[id] = conn
c.mutex.Unlock()
defer func() {
c.mutex.Lock()
defer c.mutex.Unlock()
// Clean all traces of this connection from the map.
delete(c.nodes, id)
connectionSockets, ok := c.agentToConnectionSockets[agent]
if !ok {
return
}
delete(connectionSockets, id)
if len(connectionSockets) != 0 {
return
}
delete(c.agentToConnectionSockets, agent)
}()
decoder := json.NewDecoder(conn)
// Indefinitely handle messages from the client websocket.
for {
err := c.handleNextClientMessage(id, agent, decoder)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
return nil
}
return xerrors.Errorf("handle next client message: %w", err)
}
}
}
func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
var node agpl.Node
err := decoder.Decode(&node)
if err != nil {
return xerrors.Errorf("read json: %w", err)
}
c.mutex.Lock()
// Update the node of this client in our in-memory map. If an agent entirely
// shuts down and reconnects, it needs to be aware of all clients attempting
// to establish connections.
c.nodes[id] = &node
// Write the new node from this client to the actively connected agent.
agentSocket, ok := c.agentSockets[agent]
c.mutex.Unlock()
if !ok {
// If we don't own the agent locally, send it over pubsub to a node that
// owns the agent.
err := c.publishNodesToAgent(agent, []*agpl.Node{&node})
if err != nil {
return xerrors.Errorf("publish node to agent")
}
return nil
}
// Write the new node from this client to the actively
// connected agent.
data, err := json.Marshal([]*agpl.Node{&node})
if err != nil {
return xerrors.Errorf("marshal nodes: %w", err)
}
_, err = agentSocket.Write(data)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
return nil
}
return xerrors.Errorf("write json: %w", err)
}
return nil
}
// ServeAgent accepts a WebSocket connection to an agent that listens to
// incoming connections and publishes node updates.
func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
// Tell clients on other instances to send a callmemaybe to us.
err := c.publishAgentHello(id)
if err != nil {
return xerrors.Errorf("publish agent hello: %w", err)
}
// Publish all nodes on this instance that want to connect to this agent.
nodes := c.nodesSubscribedToAgent(id)
if len(nodes) > 0 {
data, err := json.Marshal(nodes)
if err != nil {
return xerrors.Errorf("marshal json: %w", err)
}
_, err = conn.Write(data)
if err != nil {
return xerrors.Errorf("write nodes: %w", err)
}
}
// If an old agent socket is connected, we close it
// to avoid any leaks. This shouldn't ever occur because
// we expect one agent to be running.
c.mutex.Lock()
oldAgentSocket, ok := c.agentSockets[id]
if ok {
_ = oldAgentSocket.Close()
}
c.agentSockets[id] = conn
c.mutex.Unlock()
defer func() {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.agentSockets, id)
delete(c.nodes, id)
}()
decoder := json.NewDecoder(conn)
for {
node, err := c.handleAgentUpdate(id, decoder)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
return nil
}
return xerrors.Errorf("handle next agent message: %w", err)
}
err = c.publishAgentToNodes(id, node)
if err != nil {
return xerrors.Errorf("publish agent to nodes: %w", err)
}
}
}
func (c *haCoordinator) nodesSubscribedToAgent(agentID uuid.UUID) []*agpl.Node {
c.mutex.Lock()
defer c.mutex.Unlock()
sockets, ok := c.agentToConnectionSockets[agentID]
if !ok {
return nil
}
nodes := make([]*agpl.Node, 0, len(sockets))
for targetID := range sockets {
node, ok := c.nodes[targetID]
if !ok {
continue
}
nodes = append(nodes, node)
}
return nodes
}
func (c *haCoordinator) handleClientHello(id uuid.UUID) error {
c.mutex.Lock()
node, ok := c.nodes[id]
c.mutex.Unlock()
if !ok {
return nil
}
return c.publishAgentToNodes(id, node)
}
func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (*agpl.Node, error) {
var node agpl.Node
err := decoder.Decode(&node)
if err != nil {
return nil, xerrors.Errorf("read json: %w", err)
}
c.mutex.Lock()
oldNode := c.nodes[id]
if oldNode != nil {
if oldNode.AsOf.After(node.AsOf) {
c.mutex.Unlock()
return oldNode, nil
}
}
c.nodes[id] = &node
connectionSockets, ok := c.agentToConnectionSockets[id]
if !ok {
c.mutex.Unlock()
return &node, nil
}
data, err := json.Marshal([]*agpl.Node{&node})
if err != nil {
c.mutex.Unlock()
return nil, xerrors.Errorf("marshal nodes: %w", err)
}
// Publish the new node to every listening socket.
var wg sync.WaitGroup
wg.Add(len(connectionSockets))
for _, connectionSocket := range connectionSockets {
connectionSocket := connectionSocket
go func() {
defer wg.Done()
_ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second))
_, _ = connectionSocket.Write(data)
}()
}
c.mutex.Unlock()
wg.Wait()
return &node, nil
}
// Close closes all of the open connections in the coordinator and stops the
// coordinator from accepting new connections.
func (c *haCoordinator) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()
select {
case <-c.close:
return nil
default:
}
close(c.close)
c.closeFunc()
wg := sync.WaitGroup{}
wg.Add(len(c.agentSockets))
for _, socket := range c.agentSockets {
socket := socket
go func() {
_ = socket.Close()
wg.Done()
}()
}
for _, connMap := range c.agentToConnectionSockets {
wg.Add(len(connMap))
for _, socket := range connMap {
socket := socket
go func() {
_ = socket.Close()
wg.Done()
}()
}
}
wg.Wait()
return nil
}
func (c *haCoordinator) publishNodesToAgent(recipient uuid.UUID, nodes []*agpl.Node) error {
msg, err := c.formatCallMeMaybe(recipient, nodes)
if err != nil {
return xerrors.Errorf("format publish message: %w", err)
}
err = c.pubsub.Publish("wireguard_peers", msg)
if err != nil {
return xerrors.Errorf("publish message: %w", err)
}
return nil
}
func (c *haCoordinator) publishAgentHello(id uuid.UUID) error {
msg, err := c.formatAgentHello(id)
if err != nil {
return xerrors.Errorf("format publish message: %w", err)
}
err = c.pubsub.Publish("wireguard_peers", msg)
if err != nil {
return xerrors.Errorf("publish message: %w", err)
}
return nil
}
func (c *haCoordinator) publishClientHello(id uuid.UUID) error {
msg, err := c.formatClientHello(id)
if err != nil {
return xerrors.Errorf("format client hello: %w", err)
}
err = c.pubsub.Publish("wireguard_peers", msg)
if err != nil {
return xerrors.Errorf("publish client hello: %w", err)
}
return nil
}
func (c *haCoordinator) publishAgentToNodes(id uuid.UUID, node *agpl.Node) error {
msg, err := c.formatAgentUpdate(id, node)
if err != nil {
return xerrors.Errorf("format publish message: %w", err)
}
err = c.pubsub.Publish("wireguard_peers", msg)
if err != nil {
return xerrors.Errorf("publish message: %w", err)
}
return nil
}
func (c *haCoordinator) runPubsub(ctx context.Context) error {
messageQueue := make(chan []byte, 64)
cancelSub, err := c.pubsub.Subscribe("wireguard_peers", func(ctx context.Context, message []byte) {
select {
case messageQueue <- message:
case <-ctx.Done():
return
}
})
if err != nil {
return xerrors.Errorf("subscribe wireguard peers")
}
go func() {
for {
var message []byte
select {
case <-ctx.Done():
return
case message = <-messageQueue:
}
c.handlePubsubMessage(ctx, message)
}
}()
go func() {
defer cancelSub()
<-c.close
}()
return nil
}
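Routing messages through a buffered channel with a single consumer goroutine keeps handling strictly ordered and stops a slow handler from blocking the pubsub listener; once the 64-message buffer fills, the subscription callback blocks, applying backpressure. The shape, reduced to a self-contained sketch (consumeOrdered is an illustrative name, not part of the package):
package main

import "context"

// consumeOrdered subscribes via subscribe and processes messages one at a
// time, in arrival order, on a dedicated goroutine. The buffered queue
// absorbs bursts; once full, the subscription callback blocks.
func consumeOrdered(ctx context.Context, subscribe func(cb func([]byte)) (cancel func(), err error), handle func([]byte)) error {
	queue := make(chan []byte, 64)
	cancel, err := subscribe(func(msg []byte) {
		select {
		case queue <- msg:
		case <-ctx.Done():
		}
	})
	if err != nil {
		return err
	}
	go func() {
		defer cancel()
		for {
			select {
			case <-ctx.Done():
				return
			case msg := <-queue:
				handle(msg)
			}
		}
	}()
	return nil
}

func main() {} // illustrative only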
func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte) {
sp := bytes.Split(message, []byte("|"))
if len(sp) != 4 {
c.log.Error(ctx, "invalid wireguard peer message", slog.F("msg", string(message)))
return
}
var (
coordinatorID = sp[0]
eventType = sp[1]
agentID = sp[2]
nodeJSON = sp[3]
)
sender, err := uuid.ParseBytes(coordinatorID)
if err != nil {
c.log.Error(ctx, "invalid sender id", slog.F("id", string(coordinatorID)), slog.F("msg", string(message)))
return
}
// We sent this message!
if sender == c.id {
return
}
switch string(eventType) {
case "callmemaybe":
agentUUID, err := uuid.ParseBytes(agentID)
if err != nil {
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
return
}
c.mutex.Lock()
agentSocket, ok := c.agentSockets[agentUUID]
if !ok {
c.mutex.Unlock()
return
}
c.mutex.Unlock()
// The payload is already a JSON-encoded node array, so write it
// straight through to the agent.
_, err = agentSocket.Write(nodeJSON)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
return
}
c.log.Error(ctx, "send callmemaybe to agent", slog.Error(err))
return
}
case "clienthello":
agentUUID, err := uuid.ParseBytes(agentID)
if err != nil {
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
return
}
err = c.handleClientHello(agentUUID)
if err != nil {
c.log.Error(ctx, "handle agent request node", slog.Error(err))
return
}
case "agenthello":
agentUUID, err := uuid.ParseBytes(agentID)
if err != nil {
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
return
}
nodes := c.nodesSubscribedToAgent(agentUUID)
if len(nodes) > 0 {
err := c.publishNodesToAgent(agentUUID, nodes)
if err != nil {
c.log.Error(ctx, "publish nodes to agent", slog.Error(err))
return
}
}
case "agentupdate":
agentUUID, err := uuid.ParseBytes(agentID)
if err != nil {
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
return
}
decoder := json.NewDecoder(bytes.NewReader(nodeJSON))
_, err = c.handleAgentUpdate(agentUUID, decoder)
if err != nil {
c.log.Error(ctx, "handle agent update", slog.Error(err))
return
}
default:
c.log.Error(ctx, "unknown peer event", slog.F("name", string(eventType)))
}
}
// format: <coordinator id>|callmemaybe|<recipient id>|<node json>
func (c *haCoordinator) formatCallMeMaybe(recipient uuid.UUID, nodes []*agpl.Node) ([]byte, error) {
buf := bytes.Buffer{}
buf.WriteString(c.id.String() + "|")
buf.WriteString("callmemaybe|")
buf.WriteString(recipient.String() + "|")
err := json.NewEncoder(&buf).Encode(nodes)
if err != nil {
return nil, xerrors.Errorf("encode node: %w", err)
}
return buf.Bytes(), nil
}
// format: <coordinator id>|agenthello|<agent id>|
func (c *haCoordinator) formatAgentHello(id uuid.UUID) ([]byte, error) {
buf := bytes.Buffer{}
buf.WriteString(c.id.String() + "|")
buf.WriteString("agenthello|")
buf.WriteString(id.String() + "|")
return buf.Bytes(), nil
}
// format: <coordinator id>|clienthello|<agent id>|
func (c *haCoordinator) formatClientHello(id uuid.UUID) ([]byte, error) {
buf := bytes.Buffer{}
buf.WriteString(c.id.String() + "|")
buf.WriteString("clienthello|")
buf.WriteString(id.String() + "|")
return buf.Bytes(), nil
}
// format: <coordinator id>|agentupdate|<agent id>|<node json>
func (c *haCoordinator) formatAgentUpdate(id uuid.UUID, node *agpl.Node) ([]byte, error) {
buf := bytes.Buffer{}
buf.WriteString(c.id.String() + "|")
buf.WriteString("agentupdate|")
buf.WriteString(id.String() + "|")
err := json.NewEncoder(&buf).Encode(node)
if err != nil {
return nil, xerrors.Errorf("encode node: %w", err)
}
return buf.Bytes(), nil
}
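Putting the formats together, a callmemaybe message on the wire reads `<sender uuid>|callmemaybe|<recipient uuid>|[{...node json...}]` (json.Encoder appends a trailing newline). The delimiter split in handlePubsubMessage relies on the JSON payload containing no literal `|`, which appears to hold for Node's fields. A hypothetical round-trip, mirroring that split:
msg, _ := c.formatCallMeMaybe(recipient, []*agpl.Node{node})
parts := bytes.Split(msg, []byte("|"))
// parts[0]: sender coordinator ID (used to ignore our own messages)
// parts[1]: "callmemaybe"
// parts[2]: recipient ID
// parts[3]: JSON-encoded []*agpl.Node, written verbatim to the agent socket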

View File

@ -0,0 +1,261 @@
package tailnet_test
import (
"net"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/enterprise/tailnet"
agpl "github.com/coder/coder/tailnet"
"github.com/coder/coder/testutil"
)
func TestCoordinatorSingle(t *testing.T) {
t.Parallel()
t.Run("ClientWithoutAgent", func(t *testing.T) {
t.Parallel()
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory())
require.NoError(t, err)
defer coordinator.Close()
client, server := net.Pipe()
sendNode, errChan := agpl.ServeCoordinator(client, func(node []*agpl.Node) error {
return nil
})
id := uuid.New()
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeClient(server, id, uuid.New())
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&agpl.Node{})
require.Eventually(t, func() bool {
return coordinator.Node(id) != nil
}, testutil.WaitShort, testutil.IntervalFast)
err = client.Close()
require.NoError(t, err)
<-errChan
<-closeChan
})
t.Run("AgentWithoutClients", func(t *testing.T) {
t.Parallel()
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory())
require.NoError(t, err)
defer coordinator.Close()
client, server := net.Pipe()
sendNode, errChan := agpl.ServeCoordinator(client, func(node []*agpl.Node) error {
return nil
})
id := uuid.New()
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeAgent(server, id)
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&agpl.Node{})
require.Eventually(t, func() bool {
return coordinator.Node(id) != nil
}, testutil.WaitShort, testutil.IntervalFast)
err = client.Close()
require.NoError(t, err)
<-errChan
<-closeChan
})
t.Run("AgentWithClient", func(t *testing.T) {
t.Parallel()
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory())
require.NoError(t, err)
defer coordinator.Close()
agentWS, agentServerWS := net.Pipe()
defer agentWS.Close()
agentNodeChan := make(chan []*agpl.Node)
sendAgentNode, agentErrChan := agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error {
agentNodeChan <- nodes
return nil
})
agentID := uuid.New()
closeAgentChan := make(chan struct{})
go func() {
err := coordinator.ServeAgent(agentServerWS, agentID)
assert.NoError(t, err)
close(closeAgentChan)
}()
sendAgentNode(&agpl.Node{})
require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
clientWS, clientServerWS := net.Pipe()
defer clientWS.Close()
defer clientServerWS.Close()
clientNodeChan := make(chan []*agpl.Node)
sendClientNode, clientErrChan := agpl.ServeCoordinator(clientWS, func(nodes []*agpl.Node) error {
clientNodeChan <- nodes
return nil
})
clientID := uuid.New()
closeClientChan := make(chan struct{})
go func() {
err := coordinator.ServeClient(clientServerWS, clientID, agentID)
assert.NoError(t, err)
close(closeClientChan)
}()
agentNodes := <-clientNodeChan
require.Len(t, agentNodes, 1)
sendClientNode(&agpl.Node{})
clientNodes := <-agentNodeChan
require.Len(t, clientNodes, 1)
// Ensure an update to the agent node reaches the client!
sendAgentNode(&agpl.Node{})
agentNodes = <-clientNodeChan
require.Len(t, agentNodes, 1)
// Close the agent WebSocket so a new one can connect.
err = agentWS.Close()
require.NoError(t, err)
<-agentErrChan
<-closeAgentChan
// Create a new agent connection. This is to simulate a reconnect!
agentWS, agentServerWS = net.Pipe()
defer agentWS.Close()
agentNodeChan = make(chan []*agpl.Node)
_, agentErrChan = agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error {
agentNodeChan <- nodes
return nil
})
closeAgentChan = make(chan struct{})
go func() {
err := coordinator.ServeAgent(agentServerWS, agentID)
assert.NoError(t, err)
close(closeAgentChan)
}()
// Ensure the existing listening client sends its node immediately!
clientNodes = <-agentNodeChan
require.Len(t, clientNodes, 1)
err = agentWS.Close()
require.NoError(t, err)
<-agentErrChan
<-closeAgentChan
err = clientWS.Close()
require.NoError(t, err)
<-clientErrChan
<-closeClientChan
})
}
func TestCoordinatorHA(t *testing.T) {
t.Parallel()
t.Run("AgentWithClient", func(t *testing.T) {
t.Parallel()
_, pubsub := dbtestutil.NewDB(t)
coordinator1, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub)
require.NoError(t, err)
defer coordinator1.Close()
agentWS, agentServerWS := net.Pipe()
defer agentWS.Close()
agentNodeChan := make(chan []*agpl.Node)
sendAgentNode, agentErrChan := agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error {
agentNodeChan <- nodes
return nil
})
agentID := uuid.New()
closeAgentChan := make(chan struct{})
go func() {
err := coordinator1.ServeAgent(agentServerWS, agentID)
assert.NoError(t, err)
close(closeAgentChan)
}()
sendAgentNode(&agpl.Node{})
require.Eventually(t, func() bool {
return coordinator1.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
coordinator2, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub)
require.NoError(t, err)
defer coordinator2.Close()
clientWS, clientServerWS := net.Pipe()
defer clientWS.Close()
defer clientServerWS.Close()
clientNodeChan := make(chan []*agpl.Node)
sendClientNode, clientErrChan := agpl.ServeCoordinator(clientWS, func(nodes []*agpl.Node) error {
clientNodeChan <- nodes
return nil
})
clientID := uuid.New()
closeClientChan := make(chan struct{})
go func() {
err := coordinator2.ServeClient(clientServerWS, clientID, agentID)
assert.NoError(t, err)
close(closeClientChan)
}()
agentNodes := <-clientNodeChan
require.Len(t, agentNodes, 1)
sendClientNode(&agpl.Node{})
clientNodes := <-agentNodeChan
require.Len(t, clientNodes, 1)
// Ensure an update to the agent node reaches the client!
sendAgentNode(&agpl.Node{})
agentNodes = <-clientNodeChan
require.Len(t, agentNodes, 1)
// Close the agent WebSocket so a new one can connect.
require.NoError(t, agentWS.Close())
require.NoError(t, agentServerWS.Close())
<-agentErrChan
<-closeAgentChan
// Create a new agent connection. This is to simulate a reconnect!
agentWS, agentServerWS = net.Pipe()
defer agentWS.Close()
agentNodeChan = make(chan []*agpl.Node)
_, agentErrChan = agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error {
agentNodeChan <- nodes
return nil
})
closeAgentChan = make(chan struct{})
go func() {
err := coordinator1.ServeAgent(agentServerWS, agentID)
assert.NoError(t, err)
close(closeAgentChan)
}()
// Ensure the existing listening client sends its node immediately!
clientNodes = <-agentNodeChan
require.Len(t, clientNodes, 1)
err = agentWS.Close()
require.NoError(t, err)
<-agentErrChan
<-closeAgentChan
err = clientWS.Close()
require.NoError(t, err)
<-clientErrChan
<-closeClientChan
})
}

2
go.mod
View File

@ -40,7 +40,7 @@ replace github.com/tcnksm/go-httpstat => github.com/kylecarbs/go-httpstat v0.0.0
// There are a few minor changes we make to Tailscale that we're slowly upstreaming. Compare here:
// https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main
replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20220926024748-50f068456c6c
replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20221015033036-5861cbbf7bf5
// Switch to our fork that imports fixes from http://github.com/tailscale/ssh.
// See: https://github.com/coder/coder/issues/3371

4
go.sum
View File

@ -351,8 +351,8 @@ github.com/coder/retry v1.3.0 h1:5lAAwt/2Cm6lVmnfBY7sOMXcBOwcwJhmV5QGSELIVWY=
github.com/coder/retry v1.3.0/go.mod h1:tXuRgZgWjUnU5LZPT4lJh4ew2elUhexhlnXzrJWdyFY=
github.com/coder/ssh v0.0.0-20220811105153-fcea99919338 h1:tN5GKFT68YLVzJoA8AHuiMNJ0qlhoD3pGN3JY9gxSko=
github.com/coder/ssh v0.0.0-20220811105153-fcea99919338/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
github.com/coder/tailscale v1.1.1-0.20220926024748-50f068456c6c h1:xa6lr5Pj87Is26tgpzwBsEGKL7aVz7/fRGgY9QIbf3E=
github.com/coder/tailscale v1.1.1-0.20220926024748-50f068456c6c/go.mod h1:5amxy08qijEa8bcTW2SeIy4MIqcmd7LMsuOxqOlj2Ak=
github.com/coder/tailscale v1.1.1-0.20221015033036-5861cbbf7bf5 h1:WVH6e/qK3Wpl0wbmpORD2oQ1qLJborF3fsFHyO1ps0Y=
github.com/coder/tailscale v1.1.1-0.20221015033036-5861cbbf7bf5/go.mod h1:5amxy08qijEa8bcTW2SeIy4MIqcmd7LMsuOxqOlj2Ak=
github.com/containerd/aufs v0.0.0-20200908144142-dab0cbea06f4/go.mod h1:nukgQABAEopAHvB6j7cnP5zJ+/3aVcE7hCYqvIwAHyE=
github.com/containerd/aufs v0.0.0-20201003224125-76a6863f2989/go.mod h1:AkGGQs9NM2vtYHaUen+NljV0/baGCAPELGm2q9ZXpWU=
github.com/containerd/aufs v0.0.0-20210316121734-20793ff83c97/go.mod h1:kL5kd6KM5TzQjR79jljyi4olc1Vrx6XBlcyj3gNv2PU=

View File

@ -14,10 +14,7 @@ metadata:
{{- include "coder.labels" . | nindent 4 }}
annotations: {{ toYaml .Values.coder.annotations | nindent 4}}
spec:
# NOTE: this is currently not used as coder v2 does not support high
# availability yet.
# replicas: {{ .Values.coder.replicaCount }}
replicas: 1
replicas: {{ .Values.coder.replicaCount }}
selector:
matchLabels:
{{- include "coder.selectorLabels" . | nindent 6 }}
@ -38,6 +35,13 @@ spec:
env:
- name: CODER_ADDRESS
value: "0.0.0.0:{{ include "coder.port" . }}"
# Used for inter-pod communication with high-availability.
- name: KUBE_POD_IP
valueFrom:
fieldRef:
fieldPath: status.podIP
- name: CODER_DERP_SERVER_RELAY_ADDRESS
value: "{{ include "coder.portName" . }}://$(KUBE_POD_IP):{{ include "coder.port" . }}"
{{- include "coder.tlsEnv" . | nindent 12 }}
{{- with .Values.coder.env -}}
{{ toYaml . | nindent 12 }}

View File

@ -10,6 +10,7 @@ metadata:
{{- toYaml .Values.coder.service.annotations | nindent 4 }}
spec:
type: {{ .Values.coder.service.type }}
sessionAffinity: ClientIP
ports:
- name: {{ include "coder.portName" . | quote }}
port: {{ include "coder.servicePort" . }}

View File

@ -1,9 +1,9 @@
# coder -- Primary configuration for `coder server`.
coder:
# NOTE: this is currently not used as coder v2 does not support high
# availability yet.
# # coder.replicaCount -- The number of Kubernetes deployment replicas.
# replicaCount: 1
# coder.replicaCount -- The number of Kubernetes deployment replicas.
# This should only be increased if High Availability is enabled.
# This is an Enterprise feature. Contact sales@coder.com.
replicaCount: 1
# coder.image -- The image to use for Coder.
image:

View File

@ -28,6 +28,7 @@ export const defaultEntitlements = (): TypesGen.Entitlements => {
return {
features: features,
has_license: false,
errors: [],
warnings: [],
experimental: false,
trial: false,

View File

@ -274,6 +274,7 @@ export interface DeploymentFlags {
readonly derp_server_region_code: StringFlag
readonly derp_server_region_name: StringFlag
readonly derp_server_stun_address: StringArrayFlag
readonly derp_server_relay_address: StringFlag
readonly derp_config_url: StringFlag
readonly derp_config_path: StringFlag
readonly prom_enabled: BoolFlag
@ -337,6 +338,7 @@ export interface DurationFlag {
export interface Entitlements {
readonly features: Record<string, Feature>
readonly warnings: string[]
readonly errors: string[]
readonly has_license: boolean
readonly experimental: boolean
readonly trial: boolean
@ -528,6 +530,17 @@ export interface PutExtendWorkspaceRequest {
readonly deadline: string
}
// From codersdk/replicas.go
export interface Replica {
readonly id: string
readonly hostname: string
readonly created_at: string
readonly relay_address: string
readonly region_id: number
readonly error: string
readonly database_latency: number
}
// From codersdk/error.go
export interface Response {
readonly message: string

View File

@ -8,15 +8,15 @@ export const LicenseBanner: React.FC = () => {
const [entitlementsState, entitlementsSend] = useActor(
xServices.entitlementsXService,
)
const { warnings } = entitlementsState.context.entitlements
const { errors, warnings } = entitlementsState.context.entitlements
/** Gets license data on app mount because LicenseBanner is mounted in App */
useEffect(() => {
entitlementsSend("GET_ENTITLEMENTS")
}, [entitlementsSend])
if (warnings.length > 0) {
return <LicenseBannerView warnings={warnings} />
if (errors.length > 0 || warnings.length > 0) {
return <LicenseBannerView errors={errors} warnings={warnings} />
} else {
return null
}

View File

@ -12,13 +12,23 @@ const Template: Story<LicenseBannerViewProps> = (args) => (
export const OneWarning = Template.bind({})
OneWarning.args = {
errors: [],
warnings: ["You have exceeded the number of seats in your license."],
}
export const TwoWarnings = Template.bind({})
TwoWarnings.args = {
errors: [],
warnings: [
"You have exceeded the number of seats in your license.",
"You are flying too close to the sun.",
],
}
export const OneError = Template.bind({})
OneError.args = {
errors: [
"You have multiple replicas but high availability is an Enterprise feature. You will be unable to connect to workspaces.",
],
warnings: [],
}

View File

@ -2,47 +2,56 @@ import { makeStyles } from "@material-ui/core/styles"
import { Expander } from "components/Expander/Expander"
import { Pill } from "components/Pill/Pill"
import { useState } from "react"
import { colors } from "theme/colors"
export const Language = {
licenseIssue: "License Issue",
licenseIssues: (num: number): string => `${num} License Issues`,
upgrade: "Contact us to upgrade your license.",
upgrade: "Contact sales@coder.com.",
exceeded: "It looks like you've exceeded some limits of your license.",
lessDetails: "Less",
moreDetails: "More",
}
export interface LicenseBannerViewProps {
errors: string[]
warnings: string[]
}
export const LicenseBannerView: React.FC<LicenseBannerViewProps> = ({
errors,
warnings,
}) => {
const styles = useStyles()
const [showDetails, setShowDetails] = useState(false)
if (warnings.length === 1) {
const isError = errors.length > 0
const messages = [...errors, ...warnings]
const type = isError ? "error" : "warning"
if (messages.length === 1) {
return (
<div className={styles.container}>
<Pill text={Language.licenseIssue} type="warning" lightBorder />
<span className={styles.text}>{warnings[0]}</span>
&nbsp;
<a href="mailto:sales@coder.com" className={styles.link}>
{Language.upgrade}
</a>
<div className={`${styles.container} ${type}`}>
<Pill text={Language.licenseIssue} type={type} lightBorder />
<div className={styles.leftContent}>
<span>{messages[0]}</span>
&nbsp;
<a href="mailto:sales@coder.com" className={styles.link}>
{Language.upgrade}
</a>
</div>
</div>
)
} else {
return (
<div className={styles.container}>
<div className={styles.flex}>
<div className={styles.leftContent}>
<Pill
text={Language.licenseIssues(warnings.length)}
type="warning"
lightBorder
/>
<span className={styles.text}>{Language.exceeded}</span>
<div className={`${styles.container} ${type}`}>
<Pill
text={Language.licenseIssues(messages.length)}
type={type}
lightBorder
/>
<div className={styles.leftContent}>
<div>
{Language.exceeded}
&nbsp;
<a href="mailto:sales@coder.com" className={styles.link}>
{Language.upgrade}
@ -50,9 +59,9 @@ export const LicenseBannerView: React.FC<LicenseBannerViewProps> = ({
</div>
<Expander expanded={showDetails} setExpanded={setShowDetails}>
<ul className={styles.list}>
{warnings.map((warning) => (
<li className={styles.listItem} key={`${warning}`}>
{warning}
{messages.map((message) => (
<li className={styles.listItem} key={`${message}`}>
{message}
</li>
))}
</ul>
@ -67,14 +76,18 @@ const useStyles = makeStyles((theme) => ({
container: {
padding: theme.spacing(1.5),
backgroundColor: theme.palette.warning.main,
display: "flex",
alignItems: "center",
"&.error": {
backgroundColor: colors.red[12],
},
},
flex: {
display: "flex",
display: "column",
},
leftContent: {
marginRight: theme.spacing(1),
},
text: {
marginLeft: theme.spacing(1),
},
link: {
@ -83,9 +96,10 @@ const useStyles = makeStyles((theme) => ({
fontWeight: "bold",
},
list: {
margin: theme.spacing(1.5),
padding: theme.spacing(1),
margin: 0,
},
listItem: {
margin: theme.spacing(1),
margin: theme.spacing(0.5),
},
}))

View File

@ -821,6 +821,7 @@ export const makeMockApiError = ({
})
export const MockEntitlements: TypesGen.Entitlements = {
errors: [],
warnings: [],
has_license: false,
features: {},
@ -829,6 +830,7 @@ export const MockEntitlements: TypesGen.Entitlements = {
}
export const MockEntitlementsWithWarnings: TypesGen.Entitlements = {
errors: [],
warnings: ["You are over your active user limit.", "And another thing."],
has_license: true,
experimental: false,
@ -852,6 +854,7 @@ export const MockEntitlementsWithWarnings: TypesGen.Entitlements = {
}
export const MockEntitlementsWithAuditLog: TypesGen.Entitlements = {
errors: [],
warnings: [],
has_license: true,
experimental: false,

View File

@ -20,6 +20,7 @@ export type EntitlementsEvent =
| { type: "HIDE_MOCK_BANNER" }
const emptyEntitlements = {
errors: [],
warnings: [],
features: {},
has_license: false,

View File

@ -48,7 +48,10 @@ type Options struct {
Addresses []netip.Prefix
DERPMap *tailcfg.DERPMap
Logger slog.Logger
// BlockEndpoints specifies whether P2P endpoints are blocked.
// If so, only DERPs can establish connections.
BlockEndpoints bool
Logger slog.Logger
}
// NewConn constructs a new Wireguard server that will accept connections from the addresses provided.
@ -175,6 +178,7 @@ func NewConn(options *Options) (*Conn, error) {
wireguardEngine.SetFilter(filter.New(netMap.PacketFilter, localIPs, logIPs, nil, Logger(options.Logger.Named("packet-filter"))))
dialContext, dialCancel := context.WithCancel(context.Background())
server := &Conn{
blockEndpoints: options.BlockEndpoints,
dialContext: dialContext,
dialCancel: dialCancel,
closed: make(chan struct{}),
@ -240,11 +244,12 @@ func IP() netip.Addr {
// Conn is an actively listening Wireguard connection.
type Conn struct {
dialContext context.Context
dialCancel context.CancelFunc
mutex sync.Mutex
closed chan struct{}
logger slog.Logger
dialContext context.Context
dialCancel context.CancelFunc
mutex sync.Mutex
closed chan struct{}
logger slog.Logger
blockEndpoints bool
dialer *tsdial.Dialer
tunDevice *tstun.Wrapper
@ -323,6 +328,8 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
delete(c.peerMap, peer.ID)
}
for _, node := range nodes {
c.logger.Debug(context.Background(), "adding node", slog.F("node", node))
peerStatus, ok := status.Peer[node.Key]
peerNode := &tailcfg.Node{
ID: node.ID,
@ -339,6 +346,13 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
// reason. TODO: @kylecarbs debug this!
KeepAlive: ok && peerStatus.Active,
}
// If no preferred DERP is provided, don't set an IP!
if node.PreferredDERP == 0 {
peerNode.DERP = ""
}
if c.blockEndpoints {
peerNode.Endpoints = nil
}
c.peerMap[node.ID] = peerNode
}
c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
@ -421,6 +435,7 @@ func (c *Conn) sendNode() {
}
node := &Node{
ID: c.netMap.SelfNode.ID,
AsOf: c.lastStatus,
Key: c.netMap.SelfNode.Key,
Addresses: c.netMap.SelfNode.Addresses,
AllowedIPs: c.netMap.SelfNode.AllowedIPs,
@ -429,6 +444,9 @@ func (c *Conn) sendNode() {
PreferredDERP: c.lastPreferredDERP,
DERPLatency: c.lastDERPLatency,
}
if c.blockEndpoints {
node.Endpoints = nil
}
nodeCallback := c.nodeCallback
if nodeCallback == nil {
return

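With the new flag, forcing relay-only connectivity is a construction-time choice. A sketch, assuming a derpMap and logger are in scope and that Conn exposes a Close method as suggested by its closed channel:
conn, err := tailnet.NewConn(&tailnet.Options{
	Addresses:      []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
	DERPMap:        derpMap,
	Logger:         logger.Named("tailnet"),
	BlockEndpoints: true, // advertise no P2P endpoints; traffic stays on DERP
})
if err != nil {
	return xerrors.Errorf("create tailnet conn: %w", err)
}
defer conn.Close()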
View File

@ -7,6 +7,7 @@ import (
"net"
"net/netip"
"sync"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
@ -14,10 +15,30 @@ import (
"tailscale.com/types/key"
)
// Coordinator exchanges nodes with agents to establish connections.
// ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐
// │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│
// └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘
// Coordinators have different guarantees for HA support.
type Coordinator interface {
// Node returns an in-memory node by ID.
Node(id uuid.UUID) *Node
// ServeClient accepts a WebSocket connection that wants to connect to an agent
// with the specified ID.
ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error
// ServeAgent accepts a WebSocket connection to an agent that listens to
// incoming connections and publishes node updates.
ServeAgent(conn net.Conn, id uuid.UUID) error
// Close closes the coordinator.
Close() error
}
// Node represents a node in the network.
type Node struct {
// ID is used to identify the connection.
ID tailcfg.NodeID `json:"id"`
// AsOf is the time this node state was captured; it is used to discard stale updates.
AsOf time.Time `json:"as_of"`
// Key is the Wireguard public key of the node.
Key key.NodePublic `json:"key"`
// DiscoKey is used for discovery messages over DERP to establish peer-to-peer connections.
@ -75,48 +96,59 @@ func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func
}, errChan
}
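From the caller's side (as the coordinator tests do throughout), ServeCoordinator pumps both directions of the handshake: it returns a function for sending this peer's node and an error channel that closes out the session. A hypothetical client wiring, where wsConn, conn, logger, and ctx are assumed to be in scope and SetNodeCallback is an assumed setter matching the nodeCallback field seen in Conn:
sendNode, errChan := tailnet.ServeCoordinator(wsConn, func(nodes []*tailnet.Node) error {
	// Invoked whenever the coordinator delivers updated peer nodes.
	return conn.UpdateNodes(nodes)
})
conn.SetNodeCallback(sendNode) // hypothetical: push our node on every change
if err := <-errChan; err != nil {
	logger.Warn(ctx, "coordination ended", slog.Error(err))
}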
// NewCoordinator constructs a new in-memory connection coordinator.
func NewCoordinator() *Coordinator {
return &Coordinator{
// NewCoordinator constructs a new in-memory connection coordinator. This
// coordinator is incompatible with multiple Coder replicas as all node data is
// in-memory.
func NewCoordinator() Coordinator {
return &coordinator{
closed: false,
nodes: map[uuid.UUID]*Node{},
agentSockets: map[uuid.UUID]net.Conn{},
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
}
}
// Coordinator exchanges nodes with agents to establish connections.
// coordinator exchanges nodes with agents to establish connections entirely in-memory.
// The Enterprise implementation provides this for high-availability.
// ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐
// │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│
// └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘
// This coordinator is incompatible with multiple Coder
// replicas as all node data is in-memory.
type Coordinator struct {
mutex sync.Mutex
type coordinator struct {
mutex sync.Mutex
closed bool
// Maps agent and connection IDs to a node.
// nodes maps agent and connection IDs to their respective node.
nodes map[uuid.UUID]*Node
// Maps agent ID to an open socket.
// agentSockets maps agent IDs to their open websocket.
agentSockets map[uuid.UUID]net.Conn
// Maps agent ID to connection ID for sending
// new node data as it comes in!
// agentToConnectionSockets maps agent IDs to connection IDs of conns that
// are subscribed to updates for that agent.
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn
}
// Node returns an in-memory node by ID.
func (c *Coordinator) Node(id uuid.UUID) *Node {
// If the node does not exist, nil is returned.
func (c *coordinator) Node(id uuid.UUID) *Node {
c.mutex.Lock()
defer c.mutex.Unlock()
node := c.nodes[id]
return node
return c.nodes[id]
}
// ServeClient accepts a WebSocket connection that wants to
// connect to an agent with the specified ID.
func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
// ServeClient accepts a WebSocket connection that wants to connect to an agent
// with the specified ID.
func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
c.mutex.Lock()
if c.closed {
c.mutex.Unlock()
return xerrors.New("coordinator is closed")
}
// When a new connection is requested, we update it with the latest
// node of the agent, so the connection can be established.
node, ok := c.nodes[agent]
c.mutex.Unlock()
if ok {
data, err := json.Marshal([]*Node{node})
if err != nil {
@ -129,6 +161,7 @@ func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID)
return xerrors.Errorf("write nodes: %w", err)
}
}
c.mutex.Lock()
connectionSockets, ok := c.agentToConnectionSockets[agent]
if !ok {
connectionSockets = map[uuid.UUID]net.Conn{}
@ -156,47 +189,62 @@ func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID)
decoder := json.NewDecoder(conn)
for {
var node Node
err := decoder.Decode(&node)
if errors.Is(err, io.EOF) {
return nil
}
err := c.handleNextClientMessage(id, agent, decoder)
if err != nil {
return xerrors.Errorf("read json: %w", err)
}
c.mutex.Lock()
// Update the node of this client in our in-memory map.
// If an agent entirely shuts down and reconnects, it
// needs to be aware of all clients attempting to
// establish connections.
c.nodes[id] = &node
agentSocket, ok := c.agentSockets[agent]
if !ok {
c.mutex.Unlock()
continue
}
c.mutex.Unlock()
// Write the new node from this client to the actively
// connected agent.
data, err := json.Marshal([]*Node{&node})
if err != nil {
c.mutex.Unlock()
return xerrors.Errorf("marshal nodes: %w", err)
}
_, err = agentSocket.Write(data)
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return xerrors.Errorf("write json: %w", err)
if errors.Is(err, io.EOF) {
return nil
}
return xerrors.Errorf("handle next client message: %w", err)
}
}
}
func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
var node Node
err := decoder.Decode(&node)
if err != nil {
return xerrors.Errorf("read json: %w", err)
}
c.mutex.Lock()
// Update the node of this client in our in-memory map. If an agent entirely
// shuts down and reconnects, it needs to be aware of all clients attempting
// to establish connections.
c.nodes[id] = &node
agentSocket, ok := c.agentSockets[agent]
if !ok {
c.mutex.Unlock()
return nil
}
c.mutex.Unlock()
// Write the new node from this client to the actively connected agent.
data, err := json.Marshal([]*Node{&node})
if err != nil {
return xerrors.Errorf("marshal nodes: %w", err)
}
_, err = agentSocket.Write(data)
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return xerrors.Errorf("write json: %w", err)
}
return nil
}
// ServeAgent accepts a WebSocket connection to an agent that
// listens to incoming connections and publishes node updates.
func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
c.mutex.Lock()
if c.closed {
c.mutex.Unlock()
return xerrors.New("coordinator is closed")
}
sockets, ok := c.agentToConnectionSockets[id]
if ok {
// Publish all nodes that want to connect to the
@@ -209,16 +257,16 @@ func (c *Coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
}
nodes = append(nodes, node)
}
		c.mutex.Unlock()
		data, err := json.Marshal(nodes)
		if err != nil {
			return xerrors.Errorf("marshal json: %w", err)
		}
		_, err = conn.Write(data)
		if err != nil {
			return xerrors.Errorf("write nodes: %w", err)
		}
		c.mutex.Lock()
}
// If an old agent socket is connected, we close it
@@ -239,36 +287,84 @@ func (c *Coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
decoder := json.NewDecoder(conn)
for {
		err := c.handleNextAgentMessage(id, decoder)
		if err != nil {
			if errors.Is(err, io.EOF) {
				return nil
			}
			return xerrors.Errorf("handle next agent message: %w", err)
		}
}
}
func (c *coordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.Decoder) error {
var node Node
err := decoder.Decode(&node)
if err != nil {
return xerrors.Errorf("read json: %w", err)
}
c.mutex.Lock()
c.nodes[id] = &node
connectionSockets, ok := c.agentToConnectionSockets[id]
if !ok {
c.mutex.Unlock()
return nil
}
data, err := json.Marshal([]*Node{&node})
if err != nil {
return xerrors.Errorf("marshal nodes: %w", err)
}
// Publish the new node to every listening socket.
var wg sync.WaitGroup
wg.Add(len(connectionSockets))
for _, connectionSocket := range connectionSockets {
connectionSocket := connectionSocket
go func() {
_ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second))
_, _ = connectionSocket.Write(data)
wg.Done()
}()
}
c.mutex.Unlock()
wg.Wait()
return nil
}
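
Note the write pattern above: one goroutine per subscriber, each bounded by SetWriteDeadline, so a single stalled client costs at most five seconds rather than blocking the agent's update loop indefinitely. The same idea in isolation (broadcast is a hypothetical helper, not in this PR):

	package example

	import (
		"net"
		"sync"
		"time"
	)

	// broadcast writes data to every socket concurrently; a slow or dead
	// peer only costs the write deadline, never an unbounded wait.
	func broadcast(sockets []net.Conn, data []byte, timeout time.Duration) {
		var wg sync.WaitGroup
		wg.Add(len(sockets))
		for _, socket := range sockets {
			socket := socket // capture the loop variable for the goroutine
			go func() {
				defer wg.Done()
				_ = socket.SetWriteDeadline(time.Now().Add(timeout))
				_, _ = socket.Write(data)
			}()
		}
		wg.Wait()
	}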
// Close closes all of the open connections in the coordinator and stops the
// coordinator from accepting new connections.
func (c *coordinator) Close() error {
	c.mutex.Lock()
	if c.closed {
		c.mutex.Unlock()
		return nil
	}
	c.closed = true

	wg := sync.WaitGroup{}
	wg.Add(len(c.agentSockets))
	for _, socket := range c.agentSockets {
		socket := socket
		go func() {
			_ = socket.Close()
			wg.Done()
		}()
	}
	for _, connMap := range c.agentToConnectionSockets {
		wg.Add(len(connMap))
		for _, socket := range connMap {
			socket := socket
			go func() {
				_ = socket.Close()
				wg.Done()
			}()
		}
	}
	c.mutex.Unlock()

	wg.Wait()
	return nil
}
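
Since Close closes every tracked socket, goroutines blocked in ServeAgent or ServeClient unblock with EOF shortly after it returns. A possible shutdown sequence in a test (names illustrative):

	coord := tailnet.NewCoordinator()
	go func() { _ = coord.ServeAgent(agentConn, agentID) }()
	// ... exercise the coordinator ...
	require.NoError(t, coord.Close()) // closes agentConn, unblocking ServeAgent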


@@ -32,8 +32,8 @@ func TestCoordinator(t *testing.T) {
require.Eventually(t, func() bool {
return coordinator.Node(id) != nil
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, client.Close())
require.NoError(t, server.Close())
<-errChan
<-closeChan
})

testutil/certificate.go (new file)

@@ -0,0 +1,53 @@
package testutil
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
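// GenerateTLSCertificate returns a self-signed ECDSA P-256 certificate for
// the given common name (plus 127.0.0.1), valid for 180 days. Test use only.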
func GenerateTLSCertificate(t testing.TB, commonName string) tls.Certificate {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Acme Co"},
CommonName: commonName,
},
DNSNames: []string{commonName},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 180),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
var certFile bytes.Buffer
_, err = certFile.Write(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
require.NoError(t, err)
privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
require.NoError(t, err)
var keyFile bytes.Buffer
err = pem.Encode(&keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes})
require.NoError(t, err)
cert, err := tls.X509KeyPair(certFile.Bytes(), keyFile.Bytes())
require.NoError(t, err)
return cert
}
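
A plausible way to consume this helper, pairing the self-signed certificate with httptest and trusting it client-side via a certificate pool (handler is a placeholder):

	cert := testutil.GenerateTLSCertificate(t, "localhost")
	srv := httptest.NewUnstartedServer(handler)
	srv.TLS = &tls.Config{Certificates: []tls.Certificate{cert}}
	srv.StartTLS()
	defer srv.Close()

	// Trust the generated certificate instead of disabling verification.
	pool := x509.NewCertPool()
	pool.AddCert(srv.Certificate())
	client := &http.Client{Transport: &http.Transport{
		TLSClientConfig: &tls.Config{RootCAs: pool},
	}}
	res, err := client.Get(srv.URL)
	require.NoError(t, err)
	_ = res.Body.Close()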