mirror of https://github.com/coder/coder.git
feat: Add high availability for multiple replicas (#4555)
* feat: HA tailnet coordinator
* fixup! feat: HA tailnet coordinator
* fixup! feat: HA tailnet coordinator
* remove printlns
* close all connections on coordinator
* implement high availability feature
* fixup! implement high availability feature
* fixup! implement high availability feature
* fixup! implement high availability feature
* fixup! implement high availability feature
* Add replicas
* Add DERP meshing to arbitrary addresses
* Move packages to highavailability folder
* Move coordinator to high availability package
* Add flags for HA
* Rename to replicasync
* Denest packages for replicas
* Add test for multiple replicas
* Fix coordination test
* Add HA to the helm chart
* Rename function pointer
* Add warnings for HA
* Add the ability to block endpoints
* Add flag to disable P2P connections
* Wow, I made the tests pass
* Add replicas endpoint
* Ensure close kills replica
* Update sql
* Add database latency to high availability
* Pipe TLS to DERP mesh
* Fix DERP mesh with TLS
* Add tests for TLS
* Fix replica sync TLS
* Fix RootCA for replica meshing
* Remove ID from replicasync
* Fix getting certificates for meshing
* Remove excessive locking
* Fix linting
* Store mesh key in the database
* Fix replica key for tests
* Fix types gen
* Fix unlocking unlocked
* Fix race in tests
* Update enterprise/derpmesh/derpmesh.go

Co-authored-by: Colin Adler <colin1adler@gmail.com>

* Rename to syncReplicas
* Reuse http client
* Delete old replicas on a CRON
* Fix race condition in connection tests
* Fix linting
* Fix nil type
* Move pubsub to in-memory for twenty test
* Add comment for configuration tweaking
* Fix leak with transport
* Fix close leak in derpmesh
* Fix race when creating server
* Remove handler update
* Skip test on Windows
* Fix DERP mesh test
* Wrap HTTP handler replacement in mutex
* Fix error message for relay
* Fix API handler for normal tests
* Fix speedtest
* Fix replica resend
* Fix derpmesh send
* Ping async
* Increase wait time of template version job
* Fix race when closing replica sync
* Add name to client
* Log the derpmap being used
* Don't connect if DERP is empty
* Improve agent coordinator logging
* Fix lock in coordinator
* Fix relay addr
* Fix race when updating durations
* Fix client publish race
* Run pubsub loop in a queue
* Store agent nodes in order
* Fix coordinator locking
* Check for closed pipe

Co-authored-by: Colin Adler <colin1adler@gmail.com>
This commit is contained in:
parent dc3519e973
commit 2ba4a62a0d
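Taken together, this change introduces a replicas table that every coderd process registers itself in, DERP meshing between replicas (optionally over TLS), a mesh key stored in the database, and a tailnet coordinator that can be swapped at runtime. A minimal sketch of the heartbeat loop the new queries imply — the interval, the pruning cadence, and the helper shape below are assumptions for illustration, not code from this PR:

package replicasync

import (
	"context"
	"time"

	"github.com/google/uuid"

	"github.com/coder/coder/coderd/database"
)

// heartbeat periodically refreshes this replica's row so peers can see it
// as alive. The interval and pruning cadence here are illustrative.
func heartbeat(ctx context.Context, db database.Store, id uuid.UUID, interval time.Duration) error {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
		}
		// Ping measures one database round-trip; the latency is advertised
		// to peers through the replicas table (stored in microseconds).
		latency, err := db.Ping(ctx)
		if err != nil {
			return err
		}
		if _, err := db.UpdateReplica(ctx, database.UpdateReplicaParams{
			ID:              id,
			UpdatedAt:       database.Now(),
			DatabaseLatency: int32(latency.Microseconds()),
			// Remaining fields would be carried over from the current row.
		}); err != nil {
			return err
		}
		// Live peers are rows that checked in recently; rows that stop
		// updating are removed by DeleteReplicasUpdatedBefore on a cron.
		peers, err := db.GetReplicasUpdatedAfter(ctx, database.Now().Add(-2*interval))
		if err != nil {
			return err
		}
		_ = peers // e.g. feed the DERP mesh with the peers' relay addresses
	}
}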
@@ -19,6 +19,7 @@
    "derphttp",
    "derpmap",
    "devel",
    "dflags",
    "drpc",
    "drpcconn",
    "drpcmux",
@@ -86,8 +87,10 @@
    "ptytest",
    "quickstart",
    "reconfig",
    "replicasync",
    "retrier",
    "rpty",
    "SCIM",
    "sdkproto",
    "sdktrace",
    "Signup",
@@ -170,6 +170,7 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
 	if a.isClosed() {
 		return
 	}
+	a.logger.Debug(ctx, "running tailnet with derpmap", slog.F("derpmap", derpMap))
 	if a.network != nil {
 		a.network.SetDERPMap(derpMap)
 		return
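The early SetDERPMap return above is what lets a running agent absorb DERP topology changes — such as a replica joining or leaving the mesh — without rebuilding its tailnet. A hypothetical update loop on the agent side (the derpMapUpdates channel is illustrative, not an API in this PR):

// Illustrative only: once a.network exists, re-invoking runTailnet with a
// fresh map is a cheap, idempotent update.
for derpMap := range derpMapUpdates {
	a.runTailnet(ctx, derpMap)
}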
@@ -465,7 +465,7 @@ func TestAgent(t *testing.T) {
 		conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
 		require.Eventually(t, func() bool {
-			_, err := conn.Ping()
+			_, err := conn.Ping(context.Background())
 			return err == nil
 		}, testutil.WaitMedium, testutil.IntervalFast)
 		conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
@@ -483,9 +483,7 @@ func TestAgent(t *testing.T) {
 	t.Run("Speedtest", func(t *testing.T) {
 		t.Parallel()
-		if testing.Short() {
-			t.Skip("The minimum duration for a speedtest is hardcoded in Tailscale to 5s!")
-		}
+		t.Skip("This test is relatively flakey because of Tailscale's speedtest code...")
 		derpMap := tailnettest.RunDERPAndSTUN(t)
 		conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{
 			DERPMap: derpMap,
@@ -7,8 +7,6 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
-	"cdr.dev/slog"
-
 	"github.com/coder/coder/cli/clitest"
 	"github.com/coder/coder/coderd/coderdtest"
 	"github.com/coder/coder/provisioner/echo"
@@ -67,11 +65,11 @@ func TestWorkspaceAgent(t *testing.T) {
 		if assert.NotEmpty(t, workspace.LatestBuild.Resources) && assert.NotEmpty(t, resources[0].Agents) {
 			assert.NotEmpty(t, resources[0].Agents[0].Version)
 		}
-		dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
+		dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
 		require.NoError(t, err)
 		defer dialer.Close()
 		require.Eventually(t, func() bool {
-			_, err := dialer.Ping()
+			_, err := dialer.Ping(ctx)
 			return err == nil
 		}, testutil.WaitMedium, testutil.IntervalFast)
 		cancelFunc()
@@ -128,11 +126,11 @@ func TestWorkspaceAgent(t *testing.T) {
 		if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
 			assert.NotEmpty(t, resources[0].Agents[0].Version)
 		}
-		dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
+		dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
 		require.NoError(t, err)
 		defer dialer.Close()
 		require.Eventually(t, func() bool {
-			_, err := dialer.Ping()
+			_, err := dialer.Ping(ctx)
 			return err == nil
 		}, testutil.WaitMedium, testutil.IntervalFast)
 		cancelFunc()
@@ -189,11 +187,11 @@ func TestWorkspaceAgent(t *testing.T) {
 		if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
 			assert.NotEmpty(t, resources[0].Agents[0].Version)
 		}
-		dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
+		dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
 		require.NoError(t, err)
 		defer dialer.Close()
 		require.Eventually(t, func() bool {
-			_, err := dialer.Ping()
+			_, err := dialer.Ping(ctx)
 			return err == nil
 		}, testutil.WaitMedium, testutil.IntervalFast)
 		cancelFunc()
@@ -13,6 +13,11 @@ func (r Root) Session() File {
 	return File(filepath.Join(string(r), "session"))
 }
 
+// ReplicaID is a unique identifier for the Coder server.
+func (r Root) ReplicaID() File {
+	return File(filepath.Join(string(r), "replica_id"))
+}
+
 func (r Root) URL() File {
 	return File(filepath.Join(string(r), "url"))
 }
@@ -19,7 +19,6 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
-	"cdr.dev/slog"
 	"cdr.dev/slog/sloggers/slogtest"
 
 	"github.com/coder/coder/agent"
@@ -115,7 +114,7 @@ func TestConfigSSH(t *testing.T) {
 		_ = agentCloser.Close()
 	}()
 	resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
-	agentConn, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, resources[0].Agents[0].ID)
+	agentConn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil)
 	require.NoError(t, err)
 	defer agentConn.Close()
 
@@ -85,6 +85,13 @@ func Flags() *codersdk.DeploymentFlags {
 			Description: "Addresses for STUN servers to establish P2P connections. Set empty to disable P2P connections.",
 			Default:     []string{"stun.l.google.com:19302"},
 		},
+		DerpServerRelayAddress: &codersdk.StringFlag{
+			Name:        "DERP Server Relay Address",
+			Flag:        "derp-server-relay-address",
+			EnvVar:      "CODER_DERP_SERVER_RELAY_ADDRESS",
+			Description: "An HTTP address that is accessible by other replicas to relay DERP traffic. Required for high availability.",
+			Enterprise:  true,
+		},
 		DerpConfigURL: &codersdk.StringFlag{
 			Name: "DERP Config URL",
 			Flag: "derp-config-url",
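The relay address is what peer replicas dial to mesh DERP traffic. A sketch of how it might be consumed, modeled on the enterprise/derpmesh and replicasync packages named in the commit message — the method names are assumptions from that context, not code shown in this diff:

// Sketch: keep the DERP mesh pointed at the relay addresses of the other
// replicas in the same region. SetCallback, Regional, and SetAddresses are
// assumed APIs of the replicasync and derpmesh packages.
replicaManager.SetCallback(func() {
	addresses := make([]string, 0)
	for _, replica := range replicaManager.Regional() {
		addresses = append(addresses, replica.RelayAddress)
	}
	derpMesh.SetAddresses(addresses, false)
})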
@@ -16,7 +16,6 @@ import (
 	"github.com/spf13/cobra"
 	"golang.org/x/xerrors"
 
-	"cdr.dev/slog"
 	"github.com/coder/coder/agent"
 	"github.com/coder/coder/cli/cliflag"
 	"github.com/coder/coder/cli/cliui"
@@ -96,7 +95,7 @@ func portForward() *cobra.Command {
 				return xerrors.Errorf("await agent: %w", err)
 			}
 
-			conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID)
+			conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
 			if err != nil {
 				return err
 			}
@@ -156,7 +155,7 @@ func portForward() *cobra.Command {
 				case <-ticker.C:
 				}
 
-				_, err = conn.Ping()
+				_, err = conn.Ping(ctx)
 				if err != nil {
 					continue
 				}
@@ -4,6 +4,7 @@ import (
 	"context"
 	"flag"
 	"fmt"
+	"io"
 	"net/http"
 	"net/url"
 	"os"
@@ -100,8 +101,9 @@ func Core() []*cobra.Command {
 }
 
 func AGPL() []*cobra.Command {
-	all := append(Core(), Server(deployment.Flags(), func(_ context.Context, o *coderd.Options) (*coderd.API, error) {
-		return coderd.New(o), nil
+	all := append(Core(), Server(deployment.Flags(), func(_ context.Context, o *coderd.Options) (*coderd.API, io.Closer, error) {
+		api := coderd.New(o)
+		return api, api, nil
 	}))
 	return all
 }
@@ -69,7 +69,7 @@ import (
 )
 
 // nolint:gocyclo
-func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *coderd.Options) (*coderd.API, error)) *cobra.Command {
+func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *cobra.Command {
 	root := &cobra.Command{
 		Use:   "server",
 		Short: "Start a Coder server",
@@ -167,9 +167,10 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *coderd.Options) (*coderd.API, error)) *cobra.Command {
 			}
 			defer listener.Close()
 
+			var tlsConfig *tls.Config
 			if dflags.TLSEnable.Value {
-				listener, err = configureServerTLS(
-					listener, dflags.TLSMinVersion.Value,
+				tlsConfig, err = configureTLS(
+					dflags.TLSMinVersion.Value,
 					dflags.TLSClientAuth.Value,
 					dflags.TLSCertFiles.Value,
 					dflags.TLSKeyFiles.Value,
@@ -178,6 +179,7 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *coderd.Options) (*coderd.API, error)) *cobra.Command {
 				if err != nil {
 					return xerrors.Errorf("configure tls: %w", err)
 				}
+				listener = tls.NewListener(listener, tlsConfig)
 			}
 
 			tcpAddr, valid := listener.Addr().(*net.TCPAddr)
@@ -328,6 +330,9 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *coderd.Options) (*coderd.API, error)) *cobra.Command {
 				Experimental:    ExperimentalEnabled(cmd),
 				DeploymentFlags: dflags,
 			}
+			if tlsConfig != nil {
+				options.TLSCertificates = tlsConfig.Certificates
+			}
 
 			if dflags.OAuth2GithubClientSecret.Value != "" {
 				options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed,
@@ -471,11 +476,14 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *coderd.Options) (*coderd.API, error)) *cobra.Command {
 				), dflags.PromAddress.Value, "prometheus")()
 			}
 
-			coderAPI, err := newAPI(ctx, options)
+			// We use a separate closer so the Enterprise API
+			// can have its own close functions. This is cleaner
+			// than abstracting the Coder API itself.
+			coderAPI, closer, err := newAPI(ctx, options)
 			if err != nil {
 				return err
 			}
-			defer coderAPI.Close()
+			defer closer.Close()
 
 			client := codersdk.New(localURL)
 			if dflags.TLSEnable.Value {
@@ -893,7 +901,7 @@ func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, error) {
 	return certs, nil
 }
 
-func configureServerTLS(listener net.Listener, tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (net.Listener, error) {
+func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (*tls.Config, error) {
 	tlsConfig := &tls.Config{
 		MinVersion: tls.VersionTLS12,
 	}
@@ -929,6 +937,7 @@ func configureServerTLS(listener net.Listener, tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (net.Listener, error) {
 	if err != nil {
 		return nil, xerrors.Errorf("load certificates: %w", err)
 	}
+	tlsConfig.Certificates = certs
 	tlsConfig.GetCertificate = func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
 		// If there's only one certificate, return it.
 		if len(certs) == 1 {
@@ -963,7 +972,7 @@ func configureServerTLS(listener net.Listener, tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (net.Listener, error) {
 		tlsConfig.ClientCAs = caPool
 	}
 
-	return tls.NewListener(listener, tlsConfig), nil
+	return tlsConfig, nil
 }
 
 func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups bool, allowOrgs []string, rawTeams []string, enterpriseBaseURL string) (*coderd.GithubOAuth2Config, error) {
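Returning a *tls.Config instead of a wrapped listener lets the same certificates serve the HTTPS listener and authenticate DERP meshing between replicas. Condensed from the hunks above (the client-CA argument is cut off in the diff, so its flag name is assumed):

var tlsConfig *tls.Config
if dflags.TLSEnable.Value {
	tlsConfig, err = configureTLS(
		dflags.TLSMinVersion.Value,
		dflags.TLSClientAuth.Value,
		dflags.TLSCertFiles.Value,
		dflags.TLSKeyFiles.Value,
		dflags.TLSClientCAFile.Value, // assumed flag backing tlsClientCAFile
	)
	if err != nil {
		return xerrors.Errorf("configure tls: %w", err)
	}
	listener = tls.NewListener(listener, tlsConfig)
}
// The certificates then flow into coderd so replicas can mesh over TLS:
if tlsConfig != nil {
	options.TLSCertificates = tlsConfig.Certificates
}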
@@ -55,7 +55,9 @@ func speedtest() *cobra.Command {
 			if cliflag.IsSetBool(cmd, varVerbose) {
 				logger = logger.Leveled(slog.LevelDebug)
 			}
-			conn, err := client.DialWorkspaceAgentTailnet(ctx, logger, workspaceAgent.ID)
+			conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, &codersdk.DialWorkspaceAgentOptions{
+				Logger: logger,
+			})
 			if err != nil {
 				return err
 			}
@@ -68,7 +70,7 @@
 					return ctx.Err()
 				case <-ticker.C:
 				}
-				dur, err := conn.Ping()
+				dur, err := conn.Ping(ctx)
 				if err != nil {
 					continue
 				}
@@ -20,8 +20,6 @@ import (
 	"golang.org/x/term"
 	"golang.org/x/xerrors"
 
-	"cdr.dev/slog"
-
 	"github.com/coder/coder/cli/cliflag"
 	"github.com/coder/coder/cli/cliui"
 	"github.com/coder/coder/coderd/autobuild/notify"
@@ -86,7 +84,7 @@ func ssh() *cobra.Command {
 				return xerrors.Errorf("await agent: %w", err)
 			}
 
-			conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID)
+			conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
 			if err != nil {
 				return err
 			}
@@ -72,7 +72,7 @@ func TestWorkspaceActivityBump(t *testing.T) {
 			"deadline %v never updated", firstDeadline,
 		)
 
-		require.WithinDuration(t, database.Now().Add(time.Hour), workspace.LatestBuild.Deadline.Time, time.Second)
+		require.WithinDuration(t, database.Now().Add(time.Hour), workspace.LatestBuild.Deadline.Time, 3*time.Second)
 	}
 }
 
@@ -82,7 +82,9 @@ func TestWorkspaceActivityBump(t *testing.T) {
 	client, workspace, assertBumped := setupActivityTest(t)
 
 	resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
-	conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil), resources[0].Agents[0].ID)
+	conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{
+		Logger: slogtest.Make(t, nil),
+	})
 	require.NoError(t, err)
 	defer conn.Close()
 
@@ -1,6 +1,7 @@
 package coderd
 
 import (
+	"crypto/tls"
 	"crypto/x509"
 	"io"
 	"net/http"
@@ -82,7 +83,10 @@ type Options struct {
 	TracerProvider      trace.TracerProvider
 	AutoImportTemplates []AutoImportTemplate
 
-	TailnetCoordinator *tailnet.Coordinator
+	// TLSCertificates is used to mesh DERP servers securely.
+	TLSCertificates    []tls.Certificate
+	TailnetCoordinator tailnet.Coordinator
+	DERPServer         *derp.Server
 	DERPMap            *tailcfg.DERPMap
 
 	MetricsCacheRefreshInterval time.Duration
@@ -130,6 +134,9 @@ func New(options *Options) *API {
 	if options.TailnetCoordinator == nil {
 		options.TailnetCoordinator = tailnet.NewCoordinator()
 	}
+	if options.DERPServer == nil {
+		options.DERPServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp")))
+	}
 	if options.Auditor == nil {
 		options.Auditor = audit.NewNop()
 	}
@@ -168,7 +175,7 @@ func New(options *Options) *API {
 	api.Auditor.Store(&options.Auditor)
 	api.WorkspaceQuotaEnforcer.Store(&options.WorkspaceQuotaEnforcer)
 	api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
-	api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger))
+	api.TailnetCoordinator.Store(&options.TailnetCoordinator)
 	oauthConfigs := &httpmw.OAuth2Configs{
 		Github: options.GithubOAuth2Config,
 		OIDC:   options.OIDCConfig,
@@ -246,7 +253,7 @@
 			r.Route("/%40{user}/{workspace_and_agent}/apps/{workspaceapp}", apps)
 			r.Route("/@{user}/{workspace_and_agent}/apps/{workspaceapp}", apps)
 			r.Route("/derp", func(r chi.Router) {
-				r.Get("/", derphttp.Handler(api.derpServer).ServeHTTP)
+				r.Get("/", derphttp.Handler(api.DERPServer).ServeHTTP)
 				// This is used when UDP is blocked, and latency must be checked via HTTP(s).
 				r.Get("/latency-check", func(w http.ResponseWriter, r *http.Request) {
 					w.WriteHeader(http.StatusOK)
@@ -550,6 +557,7 @@ type API struct {
 	Auditor                           atomic.Pointer[audit.Auditor]
 	WorkspaceClientCoordinateOverride atomic.Pointer[func(rw http.ResponseWriter) bool]
 	WorkspaceQuotaEnforcer            atomic.Pointer[workspacequota.Enforcer]
+	TailnetCoordinator                atomic.Pointer[tailnet.Coordinator]
 	HTTPAuth                          *HTTPAuthorizer
 
 	// APIHandler serves "/api/v2"
@@ -557,7 +565,6 @@ type API struct {
 	// RootHandler serves "/"
 	RootHandler chi.Router
 
-	derpServer   *derp.Server
 	metricsCache *metricscache.Cache
 	siteHandler  http.Handler
 	websocketWaitMutex sync.Mutex
@@ -572,7 +579,10 @@ func (api *API) Close() error {
 	api.websocketWaitMutex.Unlock()
 
 	api.metricsCache.Close()
-
+	coordinator := api.TailnetCoordinator.Load()
+	if coordinator != nil {
+		_ = (*coordinator).Close()
+	}
 	return api.workspaceAgentCache.Close()
 }
 
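Storing the coordinator behind an atomic.Pointer is what allows an enterprise deployment to swap in an HA-aware coordinator at runtime without tearing the API down. A sketch of the swap this enables — only the AGPL default appears in this diff, so the enterprise constructor below is an assumption:

// The AGPL default, as New sets it above:
var coordinator tailnet.Coordinator = tailnet.NewCoordinator()
api.TailnetCoordinator.Store(&coordinator)

// An enterprise build with HA enabled could later replace it in place
// (pubsub-backed constructor assumed, not shown in this diff):
var haCoordinator tailnet.Coordinator
haCoordinator, err := enterprisetailnet.NewCoordinator(logger, pubsub)
if err == nil {
	api.TailnetCoordinator.Store(&haCoordinator)
}

// Call sites load at use time, so they observe the swap:
_ = (*api.TailnetCoordinator.Load()).ServeAgent(conn, agentID)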
@@ -7,6 +7,7 @@ import (
 	"crypto/rand"
 	"crypto/rsa"
 	"crypto/sha256"
+	"crypto/tls"
 	"crypto/x509"
 	"crypto/x509/pkix"
 	"encoding/base64"
@@ -23,6 +24,7 @@ import (
 	"regexp"
 	"strconv"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -37,8 +39,10 @@ import (
 	"golang.org/x/xerrors"
 	"google.golang.org/api/idtoken"
 	"google.golang.org/api/option"
+	"tailscale.com/derp"
 	"tailscale.com/net/stun/stuntest"
 	"tailscale.com/tailcfg"
+	"tailscale.com/types/key"
 	"tailscale.com/types/nettype"
 
 	"cdr.dev/slog"
@@ -60,6 +64,7 @@ import (
 	"github.com/coder/coder/provisionerd"
 	"github.com/coder/coder/provisionersdk"
 	"github.com/coder/coder/provisionersdk/proto"
+	"github.com/coder/coder/tailnet"
 	"github.com/coder/coder/testutil"
 )
 
@@ -77,12 +82,19 @@ type Options struct {
 	AutobuildTicker <-chan time.Time
 	AutobuildStats  chan<- executor.Stats
 	Auditor         audit.Auditor
+	TLSCertificates []tls.Certificate
 
 	// IncludeProvisionerDaemon when true means to start an in-memory provisionerD
 	IncludeProvisionerDaemon    bool
 	MetricsCacheRefreshInterval time.Duration
 	AgentStatsRefreshInterval   time.Duration
 	DeploymentFlags             *codersdk.DeploymentFlags
+
+	// Overriding the database is heavily discouraged.
+	// It should only be used in cases where multiple Coder
+	// test instances are running against the same database.
+	Database database.Store
+	Pubsub   database.Pubsub
 }
 
 // New constructs a codersdk client connected to an in-memory API instance.
@@ -116,7 +128,7 @@ func newWithCloser(t *testing.T, options *Options) (*codersdk.Client, io.Closer)
 	return client, closer
 }
 
-func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.CancelFunc, *coderd.Options) {
+func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.CancelFunc, *coderd.Options) {
 	if options == nil {
 		options = &Options{}
 	}
@@ -137,23 +149,40 @@
 			close(options.AutobuildStats)
 		})
 	}
 
-	db, pubsub := dbtestutil.NewDB(t)
+	if options.Database == nil {
+		options.Database, options.Pubsub = dbtestutil.NewDB(t)
+	}
 
 	ctx, cancelFunc := context.WithCancel(context.Background())
 	lifecycleExecutor := executor.New(
 		ctx,
-		db,
+		options.Database,
 		slogtest.Make(t, nil).Named("autobuild.executor").Leveled(slog.LevelDebug),
 		options.AutobuildTicker,
 	).WithStatsChannel(options.AutobuildStats)
 	lifecycleExecutor.Run()
 
-	srv := httptest.NewUnstartedServer(nil)
+	var mutex sync.RWMutex
+	var handler http.Handler
+	srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		mutex.RLock()
+		defer mutex.RUnlock()
+		if handler != nil {
+			handler.ServeHTTP(w, r)
+		}
+	}))
 	srv.Config.BaseContext = func(_ net.Listener) context.Context {
 		return ctx
 	}
-	srv.Start()
+	if options.TLSCertificates != nil {
+		srv.TLS = &tls.Config{
+			Certificates: options.TLSCertificates,
+			MinVersion:   tls.VersionTLS12,
+		}
+		srv.StartTLS()
+	} else {
+		srv.Start()
+	}
 	t.Cleanup(srv.Close)
 
 	tcpAddr, ok := srv.Listener.Addr().(*net.TCPAddr)
@@ -169,6 +198,9 @@
 	stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, nettype.Std{})
 	t.Cleanup(stunCleanup)
 
+	derpServer := derp.NewServer(key.NewNode(), tailnet.Logger(slogtest.Make(t, nil).Named("derp")))
+	derpServer.SetMeshKey("test-key")
+
 	// match default with cli default
 	if options.SSHKeygenAlgorithm == "" {
 		options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519
@@ -181,53 +213,59 @@
 		require.NoError(t, err)
 	}
 
-	return srv, cancelFunc, &coderd.Options{
-		AgentConnectionUpdateFrequency: 150 * time.Millisecond,
-		// Force a long disconnection timeout to ensure
-		// agents are not marked as disconnected during slow tests.
-		AgentInactiveDisconnectTimeout: testutil.WaitShort,
-		AccessURL:                      serverURL,
-		AppHostname:                    options.AppHostname,
-		AppHostnameRegex:               appHostnameRegex,
-		Logger:                         slogtest.Make(t, nil).Leveled(slog.LevelDebug),
-		CacheDir:                       t.TempDir(),
-		Database:                       db,
-		Pubsub:                         pubsub,
-
-		Auditor:              options.Auditor,
-		AWSCertificates:      options.AWSCertificates,
-		AzureCertificates:    options.AzureCertificates,
-		GithubOAuth2Config:   options.GithubOAuth2Config,
-		OIDCConfig:           options.OIDCConfig,
-		GoogleTokenValidator: options.GoogleTokenValidator,
-		SSHKeygenAlgorithm:   options.SSHKeygenAlgorithm,
-		APIRateLimit:         options.APIRateLimit,
-		Authorizer:           options.Authorizer,
-		Telemetry:            telemetry.NewNoop(),
-		DERPMap: &tailcfg.DERPMap{
-			Regions: map[int]*tailcfg.DERPRegion{
-				1: {
-					EmbeddedRelay: true,
-					RegionID:      1,
-					RegionCode:    "coder",
-					RegionName:    "Coder",
-					Nodes: []*tailcfg.DERPNode{{
-						Name:             "1a",
-						RegionID:         1,
-						IPv4:             "127.0.0.1",
-						DERPPort:         derpPort,
-						STUNPort:         stunAddr.Port,
-						InsecureForTests: true,
-						ForceHTTP:        true,
-					}},
-				},
-			},
-		},
-		AutoImportTemplates:         options.AutoImportTemplates,
-		MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
-		AgentStatsRefreshInterval:   options.AgentStatsRefreshInterval,
-		DeploymentFlags:             options.DeploymentFlags,
-	}
+	return func(h http.Handler) {
+			mutex.Lock()
+			defer mutex.Unlock()
+			handler = h
+		}, cancelFunc, &coderd.Options{
+		AgentConnectionUpdateFrequency: 150 * time.Millisecond,
+		// Force a long disconnection timeout to ensure
+		// agents are not marked as disconnected during slow tests.
+		AgentInactiveDisconnectTimeout: testutil.WaitShort,
+		AccessURL:                      serverURL,
+		AppHostname:                    options.AppHostname,
+		AppHostnameRegex:               appHostnameRegex,
+		Logger:                         slogtest.Make(t, nil).Leveled(slog.LevelDebug),
+		CacheDir:                       t.TempDir(),
+		Database:                       options.Database,
+		Pubsub:                         options.Pubsub,
+
+		Auditor:              options.Auditor,
+		AWSCertificates:      options.AWSCertificates,
+		AzureCertificates:    options.AzureCertificates,
+		GithubOAuth2Config:   options.GithubOAuth2Config,
+		OIDCConfig:           options.OIDCConfig,
+		GoogleTokenValidator: options.GoogleTokenValidator,
+		SSHKeygenAlgorithm:   options.SSHKeygenAlgorithm,
+		DERPServer:           derpServer,
+		APIRateLimit:         options.APIRateLimit,
+		Authorizer:           options.Authorizer,
+		Telemetry:            telemetry.NewNoop(),
+		TLSCertificates:      options.TLSCertificates,
+		DERPMap: &tailcfg.DERPMap{
+			Regions: map[int]*tailcfg.DERPRegion{
+				1: {
+					EmbeddedRelay: true,
+					RegionID:      1,
+					RegionCode:    "coder",
+					RegionName:    "Coder",
+					Nodes: []*tailcfg.DERPNode{{
+						Name:             "1a",
+						RegionID:         1,
+						IPv4:             "127.0.0.1",
+						DERPPort:         derpPort,
+						STUNPort:         stunAddr.Port,
+						InsecureForTests: true,
+						ForceHTTP:        options.TLSCertificates == nil,
+					}},
+				},
+			},
+		},
+		AutoImportTemplates:         options.AutoImportTemplates,
+		MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
+		AgentStatsRefreshInterval:   options.AgentStatsRefreshInterval,
+		DeploymentFlags:             options.DeploymentFlags,
+	}
 }
 
 // NewWithAPI constructs an in-memory API instance and returns a client to talk to it.
@@ -237,10 +275,10 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) {
 	if options == nil {
 		options = &Options{}
 	}
-	srv, cancelFunc, newOptions := NewOptions(t, options)
+	setHandler, cancelFunc, newOptions := NewOptions(t, options)
 	// We set the handler after server creation for the access URL.
 	coderAPI := coderd.New(newOptions)
-	srv.Config.Handler = coderAPI.RootHandler
+	setHandler(coderAPI.RootHandler)
 	var provisionerCloser io.Closer = nopcloser{}
 	if options.IncludeProvisionerDaemon {
 		provisionerCloser = NewProvisionerDaemon(t, coderAPI)
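Why the setter: the httptest server must start first so its URL can go into coderd.Options, yet the router only exists after coderd.New — and the RWMutex keeps an in-flight request from racing the swap (the "Wrap HTTP handler replacement in mutex" item in the commit message). Usage stays a one-liner:

setHandler, cancel, newOptions := coderdtest.NewOptions(t, nil)
defer cancel()
coderAPI := coderd.New(newOptions)
setHandler(coderAPI.RootHandler)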
@@ -459,7 +497,7 @@ func AwaitTemplateVersionJob(t *testing.T, client *codersdk.Client, version uuid.UUID) codersdk.TemplateVersion {
 		var err error
 		templateVersion, err = client.TemplateVersion(context.Background(), version)
 		return assert.NoError(t, err) && templateVersion.Job.CompletedAt != nil
-	}, testutil.WaitShort, testutil.IntervalFast)
+	}, testutil.WaitMedium, testutil.IntervalFast)
 	return templateVersion
 }
 
@@ -107,11 +107,17 @@ type data struct {
 	workspaceApps []database.WorkspaceApp
 	workspaces    []database.Workspace
 	licenses      []database.License
+	replicas      []database.Replica
 
 	deploymentID  string
+	derpMeshKey   string
 	lastLicenseID int32
 }
 
+func (*fakeQuerier) Ping(_ context.Context) (time.Duration, error) {
+	return 0, nil
+}
+
 // InTx doesn't rollback data properly for in-memory yet.
 func (q *fakeQuerier) InTx(fn func(database.Store) error) error {
 	q.mutex.Lock()
@@ -2931,6 +2937,21 @@ func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) {
 	return q.deploymentID, nil
 }
 
+func (q *fakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error {
+	q.mutex.Lock()
+	defer q.mutex.Unlock()
+
+	q.derpMeshKey = id
+	return nil
+}
+
+func (q *fakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) {
+	q.mutex.RLock()
+	defer q.mutex.RUnlock()
+
+	return q.derpMeshKey, nil
+}
+
 func (q *fakeQuerier) InsertLicense(
 	_ context.Context, arg database.InsertLicenseParams,
 ) (database.License, error) {
@@ -3196,3 +3217,70 @@ func (q *fakeQuerier) DeleteGroupByID(_ context.Context, id uuid.UUID) error {
 
 	return sql.ErrNoRows
 }
+
+func (q *fakeQuerier) DeleteReplicasUpdatedBefore(_ context.Context, before time.Time) error {
+	q.mutex.Lock()
+	defer q.mutex.Unlock()
+
+	for i, replica := range q.replicas {
+		if replica.UpdatedAt.Before(before) {
+			q.replicas = append(q.replicas[:i], q.replicas[i+1:]...)
+		}
+	}
+
+	return nil
+}
+
+func (q *fakeQuerier) InsertReplica(_ context.Context, arg database.InsertReplicaParams) (database.Replica, error) {
+	q.mutex.Lock()
+	defer q.mutex.Unlock()
+
+	replica := database.Replica{
+		ID:              arg.ID,
+		CreatedAt:       arg.CreatedAt,
+		StartedAt:       arg.StartedAt,
+		UpdatedAt:       arg.UpdatedAt,
+		Hostname:        arg.Hostname,
+		RegionID:        arg.RegionID,
+		RelayAddress:    arg.RelayAddress,
+		Version:         arg.Version,
+		DatabaseLatency: arg.DatabaseLatency,
+	}
+	q.replicas = append(q.replicas, replica)
+	return replica, nil
+}
+
+func (q *fakeQuerier) UpdateReplica(_ context.Context, arg database.UpdateReplicaParams) (database.Replica, error) {
+	q.mutex.Lock()
+	defer q.mutex.Unlock()
+
+	for index, replica := range q.replicas {
+		if replica.ID != arg.ID {
+			continue
+		}
+		replica.Hostname = arg.Hostname
+		replica.StartedAt = arg.StartedAt
+		replica.StoppedAt = arg.StoppedAt
+		replica.UpdatedAt = arg.UpdatedAt
+		replica.RelayAddress = arg.RelayAddress
+		replica.RegionID = arg.RegionID
+		replica.Version = arg.Version
+		replica.Error = arg.Error
+		replica.DatabaseLatency = arg.DatabaseLatency
+		q.replicas[index] = replica
+		return replica, nil
+	}
+	return database.Replica{}, sql.ErrNoRows
+}
+
+func (q *fakeQuerier) GetReplicasUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.Replica, error) {
+	q.mutex.RLock()
+	defer q.mutex.RUnlock()
+	replicas := make([]database.Replica, 0)
+	for _, replica := range q.replicas {
+		if replica.UpdatedAt.After(updatedAt) && !replica.StoppedAt.Valid {
+			replicas = append(replicas, replica)
+		}
+	}
+	return replicas, nil
+}
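One caveat in the in-memory DeleteReplicasUpdatedBefore above: splicing q.replicas while ranging over it means the entry that shifts into the removed slot is never re-examined, so adjacent expired rows can survive a pass. The fake is test-only, but a filter-in-place sketch avoids the gotcha:

// Safer in-memory delete: filter instead of splicing mid-iteration.
alive := q.replicas[:0]
for _, replica := range q.replicas {
	if !replica.UpdatedAt.Before(before) {
		alive = append(alive, replica)
	}
}
q.replicas = alive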
@@ -12,6 +12,7 @@ import (
 	"context"
 	"database/sql"
 	"errors"
+	"time"
 
 	"github.com/jmoiron/sqlx"
 	"golang.org/x/xerrors"
@@ -24,6 +25,7 @@ type Store interface {
 	// customQuerier contains custom queries that are not generated.
 	customQuerier
 
+	Ping(ctx context.Context) (time.Duration, error)
 	InTx(func(Store) error) error
 }
 
@@ -58,6 +60,13 @@ type sqlQuerier struct {
 	db DBTX
 }
 
+// Ping returns the time it takes to ping the database.
+func (q *sqlQuerier) Ping(ctx context.Context) (time.Duration, error) {
+	start := time.Now()
+	err := q.sdb.PingContext(ctx)
+	return time.Since(start), err
+}
+
 // InTx performs database operations inside a transaction.
 func (q *sqlQuerier) InTx(function func(Store) error) error {
 	if _, ok := q.db.(*sqlx.Tx); ok {
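Store.Ping is what feeds each replica's database_latency column (microseconds, per the migration below). A usage sketch:

// Measure one round-trip and record it on the replica row.
latency, err := db.Ping(ctx)
if err != nil {
	return xerrors.Errorf("ping database: %w", err)
}
databaseLatency := int32(latency.Microseconds())
_ = databaseLatency // stored via InsertReplica/UpdateReplica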
@@ -256,7 +256,8 @@ CREATE TABLE provisioner_daemons (
     created_at timestamp with time zone NOT NULL,
     updated_at timestamp with time zone,
     name character varying(64) NOT NULL,
-    provisioners provisioner_type[] NOT NULL
+    provisioners provisioner_type[] NOT NULL,
+    replica_id uuid
 );
 
 CREATE TABLE provisioner_job_logs (
@@ -287,6 +288,20 @@ CREATE TABLE provisioner_jobs (
     file_id uuid NOT NULL
 );
 
+CREATE TABLE replicas (
+    id uuid NOT NULL,
+    created_at timestamp with time zone NOT NULL,
+    started_at timestamp with time zone NOT NULL,
+    stopped_at timestamp with time zone,
+    updated_at timestamp with time zone NOT NULL,
+    hostname text NOT NULL,
+    region_id integer NOT NULL,
+    relay_address text NOT NULL,
+    database_latency integer NOT NULL,
+    version text NOT NULL,
+    error text DEFAULT ''::text NOT NULL
+);
+
 CREATE TABLE site_configs (
     key character varying(256) NOT NULL,
     value character varying(8192) NOT NULL
@@ -0,0 +1,2 @@
+DROP TABLE replicas;
+ALTER TABLE provisioner_daemons DROP COLUMN replica_id;
@@ -0,0 +1,28 @@
+CREATE TABLE IF NOT EXISTS replicas (
+    -- A unique identifier for the replica that is stored on disk.
+    -- For persistent replicas, this will be reused.
+    -- For ephemeral replicas, this will be a new UUID for each one.
+    id uuid NOT NULL,
+    created_at timestamp with time zone NOT NULL,
+    -- The time the replica was created.
+    started_at timestamp with time zone NOT NULL,
+    -- The time the replica was last seen.
+    stopped_at timestamp with time zone,
+    -- Updated periodically to ensure the replica is still alive.
+    updated_at timestamp with time zone NOT NULL,
+    -- Hostname is the hostname of the replica.
+    hostname text NOT NULL,
+    -- Region is the region the replica is in.
+    -- We only DERP mesh to the same region ID of a running replica.
+    region_id integer NOT NULL,
+    -- An address that should be accessible to other replicas.
+    relay_address text NOT NULL,
+    -- The latency of the replica to the database in microseconds.
+    database_latency int NOT NULL,
+    -- Version is the Coder version of the replica.
+    version text NOT NULL,
+    error text NOT NULL DEFAULT ''
+);
+
+-- Associates a provisioner daemon with a replica.
+ALTER TABLE provisioner_daemons ADD COLUMN replica_id uuid;
@@ -508,6 +508,7 @@ type ProvisionerDaemon struct {
 	UpdatedAt    sql.NullTime      `db:"updated_at" json:"updated_at"`
 	Name         string            `db:"name" json:"name"`
 	Provisioners []ProvisionerType `db:"provisioners" json:"provisioners"`
+	ReplicaID    uuid.NullUUID     `db:"replica_id" json:"replica_id"`
 }
 
 type ProvisionerJob struct {
@@ -538,6 +539,20 @@ type ProvisionerJobLog struct {
 	Output string `db:"output" json:"output"`
 }
 
+type Replica struct {
+	ID              uuid.UUID    `db:"id" json:"id"`
+	CreatedAt       time.Time    `db:"created_at" json:"created_at"`
+	StartedAt       time.Time    `db:"started_at" json:"started_at"`
+	StoppedAt       sql.NullTime `db:"stopped_at" json:"stopped_at"`
+	UpdatedAt       time.Time    `db:"updated_at" json:"updated_at"`
+	Hostname        string       `db:"hostname" json:"hostname"`
+	RegionID        int32        `db:"region_id" json:"region_id"`
+	RelayAddress    string       `db:"relay_address" json:"relay_address"`
+	DatabaseLatency int32        `db:"database_latency" json:"database_latency"`
+	Version         string       `db:"version" json:"version"`
+	Error           string       `db:"error" json:"error"`
+}
+
 type SiteConfig struct {
 	Key   string `db:"key" json:"key"`
 	Value string `db:"value" json:"value"`
@@ -47,8 +47,9 @@ func (m *memoryPubsub) Publish(event string, message []byte) error {
 		return nil
 	}
 	for _, listener := range listeners {
-		listener(context.Background(), message)
+		go listener(context.Background(), message)
 	}
 
 	return nil
 }
 
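Dispatching each listener in its own goroutine mirrors the asynchrony of the real Postgres pubsub and keeps a listener that blocks — or publishes again — from stalling Publish. The trade-off is that delivery order is no longer guaranteed. The re-entrancy hazard it avoids, sketched:

// If Publish held its subscriber lock while invoking listeners
// synchronously, this listener's nested Publish could deadlock the
// in-memory pubsub; asynchronous dispatch sidesteps that.
cancel, err := pubsub.Subscribe("workspace", func(ctx context.Context, message []byte) {
	_ = pubsub.Publish("audit", message)
})
if err != nil {
	return err
}
defer cancel()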
@@ -26,6 +26,7 @@ type sqlcQuerier interface {
 	DeleteLicense(ctx context.Context, id int32) (int32, error)
 	DeleteOldAgentStats(ctx context.Context) error
 	DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error
+	DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error
 	GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
 	GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error)
 	GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
@@ -38,6 +39,7 @@ type sqlcQuerier interface {
 	// This function returns roles for authorization purposes. Implied member roles
 	// are included.
 	GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error)
+	GetDERPMeshKey(ctx context.Context) (string, error)
 	GetDeploymentID(ctx context.Context) (string, error)
 	GetFileByHashAndCreator(ctx context.Context, arg GetFileByHashAndCreatorParams) (File, error)
 	GetFileByID(ctx context.Context, id uuid.UUID) (File, error)
@@ -67,6 +69,7 @@ type sqlcQuerier interface {
 	GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error)
 	GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]ProvisionerJob, error)
 	GetProvisionerLogsByIDBetween(ctx context.Context, arg GetProvisionerLogsByIDBetweenParams) ([]ProvisionerJobLog, error)
+	GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error)
 	GetTemplateAverageBuildTime(ctx context.Context, arg GetTemplateAverageBuildTimeParams) (GetTemplateAverageBuildTimeRow, error)
 	GetTemplateByID(ctx context.Context, id uuid.UUID) (Template, error)
 	GetTemplateByOrganizationAndName(ctx context.Context, arg GetTemplateByOrganizationAndNameParams) (Template, error)
@@ -123,6 +126,7 @@ type sqlcQuerier interface {
 	// every member of the org.
 	InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
 	InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
+	InsertDERPMeshKey(ctx context.Context, value string) error
 	InsertDeploymentID(ctx context.Context, value string) error
 	InsertFile(ctx context.Context, arg InsertFileParams) (File, error)
 	InsertGitSSHKey(ctx context.Context, arg InsertGitSSHKeyParams) (GitSSHKey, error)
@@ -136,6 +140,7 @@ type sqlcQuerier interface {
 	InsertProvisionerDaemon(ctx context.Context, arg InsertProvisionerDaemonParams) (ProvisionerDaemon, error)
 	InsertProvisionerJob(ctx context.Context, arg InsertProvisionerJobParams) (ProvisionerJob, error)
 	InsertProvisionerJobLogs(ctx context.Context, arg InsertProvisionerJobLogsParams) ([]ProvisionerJobLog, error)
+	InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error)
 	InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error)
 	InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) (TemplateVersion, error)
 	InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
@@ -156,6 +161,7 @@ type sqlcQuerier interface {
 	UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error
 	UpdateProvisionerJobWithCancelByID(ctx context.Context, arg UpdateProvisionerJobWithCancelByIDParams) error
 	UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error
+	UpdateReplica(ctx context.Context, arg UpdateReplicaParams) (Replica, error)
 	UpdateTemplateActiveVersionByID(ctx context.Context, arg UpdateTemplateActiveVersionByIDParams) error
 	UpdateTemplateDeletedByID(ctx context.Context, arg UpdateTemplateDeletedByIDParams) error
 	UpdateTemplateMetaByID(ctx context.Context, arg UpdateTemplateMetaByIDParams) (Template, error)
|
|||
|
||||
const getProvisionerDaemonByID = `-- name: GetProvisionerDaemonByID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, name, provisioners
|
||||
id, created_at, updated_at, name, provisioners, replica_id
|
||||
FROM
|
||||
provisioner_daemons
|
||||
WHERE
|
||||
|
@ -2047,13 +2047,14 @@ func (q *sqlQuerier) GetProvisionerDaemonByID(ctx context.Context, id uuid.UUID)
|
|||
&i.UpdatedAt,
|
||||
&i.Name,
|
||||
pq.Array(&i.Provisioners),
|
||||
&i.ReplicaID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getProvisionerDaemons = `-- name: GetProvisionerDaemons :many
|
||||
SELECT
|
||||
id, created_at, updated_at, name, provisioners
|
||||
id, created_at, updated_at, name, provisioners, replica_id
|
||||
FROM
|
||||
provisioner_daemons
|
||||
`
|
||||
|
@ -2073,6 +2074,7 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa
|
|||
&i.UpdatedAt,
|
||||
&i.Name,
|
||||
pq.Array(&i.Provisioners),
|
||||
&i.ReplicaID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -2096,7 +2098,7 @@ INSERT INTO
|
|||
provisioners
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4) RETURNING id, created_at, updated_at, name, provisioners
|
||||
($1, $2, $3, $4) RETURNING id, created_at, updated_at, name, provisioners, replica_id
|
||||
`
|
||||
|
||||
type InsertProvisionerDaemonParams struct {
|
||||
|
@ -2120,6 +2122,7 @@ func (q *sqlQuerier) InsertProvisionerDaemon(ctx context.Context, arg InsertProv
|
|||
&i.UpdatedAt,
|
||||
&i.Name,
|
||||
pq.Array(&i.Provisioners),
|
||||
&i.ReplicaID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
@@ -2577,6 +2580,177 @@ func (q *sqlQuerier) UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error {
 	return err
 }
 
+const deleteReplicasUpdatedBefore = `-- name: DeleteReplicasUpdatedBefore :exec
+DELETE FROM replicas WHERE updated_at < $1
+`
+
+func (q *sqlQuerier) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error {
+	_, err := q.db.ExecContext(ctx, deleteReplicasUpdatedBefore, updatedAt)
+	return err
+}
+
+const getReplicasUpdatedAfter = `-- name: GetReplicasUpdatedAfter :many
+SELECT id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error FROM replicas WHERE updated_at > $1 AND stopped_at IS NULL
+`
+
+func (q *sqlQuerier) GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error) {
+	rows, err := q.db.QueryContext(ctx, getReplicasUpdatedAfter, updatedAt)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	var items []Replica
+	for rows.Next() {
+		var i Replica
+		if err := rows.Scan(
+			&i.ID,
+			&i.CreatedAt,
+			&i.StartedAt,
+			&i.StoppedAt,
+			&i.UpdatedAt,
+			&i.Hostname,
+			&i.RegionID,
+			&i.RelayAddress,
+			&i.DatabaseLatency,
+			&i.Version,
+			&i.Error,
+		); err != nil {
+			return nil, err
+		}
+		items = append(items, i)
+	}
+	if err := rows.Close(); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return items, nil
+}
+
+const insertReplica = `-- name: InsertReplica :one
+INSERT INTO replicas (
+	id,
+	created_at,
+	started_at,
+	updated_at,
+	hostname,
+	region_id,
+	relay_address,
+	version,
+	database_latency
+) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error
+`
+
+type InsertReplicaParams struct {
+	ID              uuid.UUID `db:"id" json:"id"`
+	CreatedAt       time.Time `db:"created_at" json:"created_at"`
+	StartedAt       time.Time `db:"started_at" json:"started_at"`
+	UpdatedAt       time.Time `db:"updated_at" json:"updated_at"`
+	Hostname        string    `db:"hostname" json:"hostname"`
+	RegionID        int32     `db:"region_id" json:"region_id"`
+	RelayAddress    string    `db:"relay_address" json:"relay_address"`
+	Version         string    `db:"version" json:"version"`
+	DatabaseLatency int32     `db:"database_latency" json:"database_latency"`
+}
+
+func (q *sqlQuerier) InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error) {
+	row := q.db.QueryRowContext(ctx, insertReplica,
+		arg.ID,
+		arg.CreatedAt,
+		arg.StartedAt,
+		arg.UpdatedAt,
+		arg.Hostname,
+		arg.RegionID,
+		arg.RelayAddress,
+		arg.Version,
+		arg.DatabaseLatency,
+	)
+	var i Replica
+	err := row.Scan(
+		&i.ID,
+		&i.CreatedAt,
+		&i.StartedAt,
+		&i.StoppedAt,
+		&i.UpdatedAt,
+		&i.Hostname,
+		&i.RegionID,
+		&i.RelayAddress,
+		&i.DatabaseLatency,
+		&i.Version,
+		&i.Error,
+	)
+	return i, err
+}
+
+const updateReplica = `-- name: UpdateReplica :one
+UPDATE replicas SET
+	updated_at = $2,
+	started_at = $3,
+	stopped_at = $4,
+	relay_address = $5,
+	region_id = $6,
+	hostname = $7,
+	version = $8,
+	error = $9,
+	database_latency = $10
+WHERE id = $1 RETURNING id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error
+`
+
+type UpdateReplicaParams struct {
+	ID              uuid.UUID    `db:"id" json:"id"`
+	UpdatedAt       time.Time    `db:"updated_at" json:"updated_at"`
+	StartedAt       time.Time    `db:"started_at" json:"started_at"`
+	StoppedAt       sql.NullTime `db:"stopped_at" json:"stopped_at"`
+	RelayAddress    string       `db:"relay_address" json:"relay_address"`
+	RegionID        int32        `db:"region_id" json:"region_id"`
+	Hostname        string       `db:"hostname" json:"hostname"`
+	Version         string       `db:"version" json:"version"`
+	Error           string       `db:"error" json:"error"`
+	DatabaseLatency int32        `db:"database_latency" json:"database_latency"`
+}
+
+func (q *sqlQuerier) UpdateReplica(ctx context.Context, arg UpdateReplicaParams) (Replica, error) {
+	row := q.db.QueryRowContext(ctx, updateReplica,
+		arg.ID,
+		arg.UpdatedAt,
+		arg.StartedAt,
+		arg.StoppedAt,
+		arg.RelayAddress,
+		arg.RegionID,
+		arg.Hostname,
+		arg.Version,
+		arg.Error,
+		arg.DatabaseLatency,
+	)
+	var i Replica
+	err := row.Scan(
+		&i.ID,
+		&i.CreatedAt,
+		&i.StartedAt,
+		&i.StoppedAt,
+		&i.UpdatedAt,
+		&i.Hostname,
+		&i.RegionID,
+		&i.RelayAddress,
+		&i.DatabaseLatency,
+		&i.Version,
+		&i.Error,
+	)
+	return i, err
+}
+
+const getDERPMeshKey = `-- name: GetDERPMeshKey :one
+SELECT value FROM site_configs WHERE key = 'derp_mesh_key'
+`
+
+func (q *sqlQuerier) GetDERPMeshKey(ctx context.Context) (string, error) {
+	row := q.db.QueryRowContext(ctx, getDERPMeshKey)
+	var value string
+	err := row.Scan(&value)
+	return value, err
+}
+
 const getDeploymentID = `-- name: GetDeploymentID :one
 SELECT value FROM site_configs WHERE key = 'deployment_id'
 `
@@ -2588,6 +2762,15 @@ func (q *sqlQuerier) GetDeploymentID(ctx context.Context) (string, error) {
 	return value, err
 }
 
+const insertDERPMeshKey = `-- name: InsertDERPMeshKey :exec
+INSERT INTO site_configs (key, value) VALUES ('derp_mesh_key', $1)
+`
+
+func (q *sqlQuerier) InsertDERPMeshKey(ctx context.Context, value string) error {
+	_, err := q.db.ExecContext(ctx, insertDERPMeshKey, value)
+	return err
+}
+
 const insertDeploymentID = `-- name: InsertDeploymentID :exec
 INSERT INTO site_configs (key, value) VALUES ('deployment_id', $1)
 `
@@ -0,0 +1,31 @@
+-- name: GetReplicasUpdatedAfter :many
+SELECT * FROM replicas WHERE updated_at > $1 AND stopped_at IS NULL;
+
+-- name: InsertReplica :one
+INSERT INTO replicas (
+    id,
+    created_at,
+    started_at,
+    updated_at,
+    hostname,
+    region_id,
+    relay_address,
+    version,
+    database_latency
+) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *;
+
+-- name: UpdateReplica :one
+UPDATE replicas SET
+    updated_at = $2,
+    started_at = $3,
+    stopped_at = $4,
+    relay_address = $5,
+    region_id = $6,
+    hostname = $7,
+    version = $8,
+    error = $9,
+    database_latency = $10
+WHERE id = $1 RETURNING *;
+
+-- name: DeleteReplicasUpdatedBefore :exec
+DELETE FROM replicas WHERE updated_at < $1;
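With these queries generated, replica self-registration at boot reduces to a single call. A sketch — the field sources are assumptions (for example, the on-disk ID file added via config.Root.ReplicaID() earlier in this diff):

replica, err := db.InsertReplica(ctx, database.InsertReplicaParams{
	ID:              replicaID, // persisted across restarts via config.Root.ReplicaID()
	CreatedAt:       database.Now(),
	StartedAt:       database.Now(),
	UpdatedAt:       database.Now(),
	Hostname:        hostname,
	RegionID:        0,
	RelayAddress:    dflags.DerpServerRelayAddress.Value,
	Version:         version, // however the server reports its build version
	DatabaseLatency: int32(latency.Microseconds()),
})
if err != nil {
	return xerrors.Errorf("insert replica: %w", err)
}
_ = replica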
@@ -3,3 +3,9 @@ INSERT INTO site_configs (key, value) VALUES ('deployment_id', $1);
 
 -- name: GetDeploymentID :one
 SELECT value FROM site_configs WHERE key = 'deployment_id';
+
+-- name: InsertDERPMeshKey :exec
+INSERT INTO site_configs (key, value) VALUES ('derp_mesh_key', $1);
+
+-- name: GetDERPMeshKey :one
+SELECT value FROM site_configs WHERE key = 'derp_mesh_key';
@@ -270,7 +270,7 @@ func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
 		}
 	}
 
-	apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
+	apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
 	if err != nil {
 		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
 			Message: "Internal error reading job agent.",
@@ -146,6 +146,10 @@ var (
 	ResourceDeploymentFlags = Object{
 		Type: "deployment_flags",
 	}
+
+	ResourceReplicas = Object{
+		Type: "replicas",
+	}
 )
 
 // Object is used to create objects for authz checks when you have none in
@@ -627,7 +627,9 @@ func TestTemplateMetrics(t *testing.T) {
 	require.NoError(t, err)
 	assert.Zero(t, workspaces[0].LastUsedAt)
 
-	conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil).Named("tailnet"), resources[0].Agents[0].ID)
+	conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{
+		Logger: slogtest.Make(t, nil).Named("tailnet"),
+	})
 	require.NoError(t, err)
 	defer func() {
 		_ = conn.Close()
@ -49,7 +49,7 @@ func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) {
|
|||
})
|
||||
return
|
||||
}
|
||||
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
|
||||
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error reading workspace agent.",
|
||||
|
@ -78,7 +78,7 @@ func (api *API) workspaceAgentApps(rw http.ResponseWriter, r *http.Request) {
|
|||
func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
|
||||
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error reading workspace agent.",
|
||||
|
@ -98,7 +98,7 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request)
|
|||
func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Request) {
 	ctx := r.Context()
 	workspaceAgent := httpmw.WorkspaceAgent(r)
-	apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
+	apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
 	if err != nil {
 		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
 			Message: "Internal error reading workspace agent.",

@@ -152,7 +152,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
 	httpapi.ResourceNotFound(rw)
 	return
 	}
-	apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
+	apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
 	if err != nil {
 		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
 			Message: "Internal error reading workspace agent.",

@@ -229,7 +229,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Request) {
 		return
 	}

-	apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
+	apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
 	if err != nil {
 		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
 			Message: "Internal error reading workspace agent.",

@@ -376,8 +376,9 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
 	})
 	conn.SetNodeCallback(sendNodes)
 	go func() {
-		err := api.TailnetCoordinator.ServeClient(serverConn, uuid.New(), agentID)
+		err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
 		if err != nil {
 			api.Logger.Warn(r.Context(), "tailnet coordinator client error", slog.Error(err))
 			_ = conn.Close()
 		}
 	}()

@@ -514,8 +515,9 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request) {
 	closeChan := make(chan struct{})
 	go func() {
 		defer close(closeChan)
-		err := api.TailnetCoordinator.ServeAgent(wsNetConn, workspaceAgent.ID)
+		err := (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID)
 		if err != nil {
 			api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err))
 			_ = conn.Close(websocket.StatusInternalError, err.Error())
 			return
 		}

@@ -583,7 +585,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.Request) {
 	go httpapi.Heartbeat(ctx, conn)

 	defer conn.Close(websocket.StatusNormalClosure, "")
-	err = api.TailnetCoordinator.ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID)
+	err = (*api.TailnetCoordinator.Load()).ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID)
 	if err != nil {
 		_ = conn.Close(websocket.StatusInternalError, err.Error())
 		return

@@ -611,7 +613,7 @@ func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp {
 	return apps
 }

-func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator *tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) {
+func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) {
 	var envs map[string]string
 	if dbAgent.EnvironmentVariables.Valid {
 		err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs)
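Every coordinator access above becomes a Load() dereference because later hunks in this commit swap the coordinator at runtime when the high-availability entitlement changes. A minimal sketch of the pattern, assuming Go's generic sync/atomic.Pointer (the Coordinator interface below is a stand-in, not the repository's type):

package example

import "sync/atomic"

// Coordinator is a stand-in for tailnet.Coordinator; only the methods the
// sketch needs are declared.
type Coordinator interface {
	ServeAgent() error
	Close() error
}

// Handlers dereference the current coordinator on every request, so a swap
// takes effect without restarting in-flight HTTP serving.
func handle(coord *atomic.Pointer[Coordinator]) error {
	return (*coord.Load()).ServeAgent()
}

// An upgrade path atomically publishes the replacement, then closes the
// previous coordinator (mirroring the Swap in updateEntitlements below).
func upgrade(coord *atomic.Pointer[Coordinator], next Coordinator) error {
	old := coord.Swap(&next)
	return (*old).Close()
}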
@@ -123,13 +123,13 @@ func TestWorkspaceAgentListen(t *testing.T) {
 		defer cancel()

 		resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
-		conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
+		conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
 		require.NoError(t, err)
 		defer func() {
 			_ = conn.Close()
 		}()
 		require.Eventually(t, func() bool {
-			_, err := conn.Ping()
+			_, err := conn.Ping(ctx)
 			return err == nil
 		}, testutil.WaitLong, testutil.IntervalFast)
 	})

@@ -253,7 +253,9 @@ func TestWorkspaceAgentTailnet(t *testing.T) {

 	ctx, cancelFunc := context.WithCancel(context.Background())
 	defer cancelFunc()
-	conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), resources[0].Agents[0].ID)
+	conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{
+		Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
+	})
 	require.NoError(t, err)
 	defer conn.Close()
 	sshClient, err := conn.SSHClient()
@@ -861,7 +861,7 @@ func (api *API) convertWorkspaceBuild(
 	apiAgents := make([]codersdk.WorkspaceAgent, 0)
 	for _, agent := range agents {
 		apps := appsByAgentID[agent.ID]
-		apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(apps), api.AgentInactiveDisconnectTimeout)
+		apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), agent, convertApps(apps), api.AgentInactiveDisconnectTimeout)
 		if err != nil {
 			return codersdk.WorkspaceBuild{}, xerrors.Errorf("converting workspace agent: %w", err)
 		}
@@ -128,7 +128,9 @@ func TestCache(t *testing.T) {
 			return
 		}
 		defer release()
-		proxy.Transport = conn.HTTPTransport()
+		transport := conn.HTTPTransport()
+		defer transport.CloseIdleConnections()
+		proxy.Transport = transport
 		res := httptest.NewRecorder()
 		proxy.ServeHTTP(res, req)
 		resp := res.Result()
@@ -132,10 +132,10 @@ type AgentConn struct {
 	CloseFunc func()
 }

-func (c *AgentConn) Ping() (time.Duration, error) {
+func (c *AgentConn) Ping(ctx context.Context) (time.Duration, error) {
 	errCh := make(chan error, 1)
 	durCh := make(chan time.Duration, 1)
-	c.Conn.Ping(TailnetIP, tailcfg.PingDisco, func(pr *ipnstate.PingResult) {
+	go c.Conn.Ping(TailnetIP, tailcfg.PingDisco, func(pr *ipnstate.PingResult) {
 		if pr.Err != "" {
 			errCh <- xerrors.New(pr.Err)
 			return

@@ -145,6 +145,8 @@ func (c *AgentConn) Ping() (time.Duration, error) {
 	select {
 	case err := <-errCh:
 		return 0, err
+	case <-ctx.Done():
+		return 0, ctx.Err()
 	case dur := <-durCh:
 		return dur, nil
 	}
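With the context-aware signature above, callers can bound each ping attempt instead of blocking on a lost DERP round-trip. A sketch of the calling pattern, assuming only the Ping method shown in this hunk; the helper name and timeout values are illustrative:

package example

import (
	"context"
	"time"

	"github.com/coder/coder/codersdk"
)

// waitForAgent is a hypothetical helper: it retries Ping with a short
// per-attempt timeout until the tunnel responds or ctx expires.
func waitForAgent(ctx context.Context, conn *codersdk.AgentConn) error {
	for {
		pingCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
		_, err := conn.Ping(pingCtx)
		cancel()
		if err == nil {
			return nil
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(250 * time.Millisecond):
		}
	}
}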
@@ -15,12 +15,13 @@ const (
 )

 const (
-	FeatureUserLimit      = "user_limit"
-	FeatureAuditLog       = "audit_log"
-	FeatureBrowserOnly    = "browser_only"
-	FeatureSCIM           = "scim"
-	FeatureWorkspaceQuota = "workspace_quota"
-	FeatureTemplateRBAC   = "template_rbac"
+	FeatureUserLimit        = "user_limit"
+	FeatureAuditLog         = "audit_log"
+	FeatureBrowserOnly      = "browser_only"
+	FeatureSCIM             = "scim"
+	FeatureWorkspaceQuota   = "workspace_quota"
+	FeatureTemplateRBAC     = "template_rbac"
+	FeatureHighAvailability = "high_availability"
 )

 var FeatureNames = []string{

@@ -30,6 +31,7 @@ var FeatureNames = []string{
 	FeatureSCIM,
 	FeatureWorkspaceQuota,
 	FeatureTemplateRBAC,
+	FeatureHighAvailability,
 }

 type Feature struct {

@@ -42,6 +44,7 @@ type Feature struct {
 type Entitlements struct {
 	Features     map[string]Feature `json:"features"`
 	Warnings     []string           `json:"warnings"`
+	Errors       []string           `json:"errors"`
 	HasLicense   bool               `json:"has_license"`
 	Experimental bool               `json:"experimental"`
 	Trial        bool               `json:"trial"`
@@ -19,6 +19,7 @@ type DeploymentFlags struct {
 	DerpServerRegionCode    *StringFlag      `json:"derp_server_region_code" typescript:",notnull"`
 	DerpServerRegionName    *StringFlag      `json:"derp_server_region_name" typescript:",notnull"`
 	DerpServerSTUNAddresses *StringArrayFlag `json:"derp_server_stun_address" typescript:",notnull"`
+	DerpServerRelayAddress  *StringFlag      `json:"derp_server_relay_address" typescript:",notnull"`
 	DerpConfigURL           *StringFlag      `json:"derp_config_url" typescript:",notnull"`
 	DerpConfigPath          *StringFlag      `json:"derp_config_path" typescript:",notnull"`
 	PromEnabled             *BoolFlag        `json:"prom_enabled" typescript:",notnull"`
@@ -0,0 +1,44 @@
+package codersdk
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"time"
+
+	"github.com/google/uuid"
+	"golang.org/x/xerrors"
+)
+
+type Replica struct {
+	// ID is the unique identifier for the replica.
+	ID uuid.UUID `json:"id"`
+	// Hostname is the hostname of the replica.
+	Hostname string `json:"hostname"`
+	// CreatedAt is when the replica was first seen.
+	CreatedAt time.Time `json:"created_at"`
+	// RelayAddress is the accessible address to relay DERP connections.
+	RelayAddress string `json:"relay_address"`
+	// RegionID is the region of the replica.
+	RegionID int32 `json:"region_id"`
+	// Error is the replica's last reported error, if any.
+	Error string `json:"error"`
+	// DatabaseLatency is the latency in microseconds to the database.
+	DatabaseLatency int32 `json:"database_latency"`
+}
+
+// Replicas fetches the list of replicas.
+func (c *Client) Replicas(ctx context.Context) ([]Replica, error) {
+	res, err := c.Request(ctx, http.MethodGet, "/api/v2/replicas", nil)
+	if err != nil {
+		return nil, xerrors.Errorf("execute request: %w", err)
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode != http.StatusOK {
+		return nil, readBodyAsError(res)
+	}
+
+	var replicas []Replica
+	return replicas, json.NewDecoder(res.Body).Decode(&replicas)
+}
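A sketch of consuming the new endpoint through the SDK. The deployment URL and token handling are placeholder assumptions; only Replicas itself comes from this commit:

package example

import (
	"context"
	"fmt"
	"net/url"

	"github.com/coder/coder/codersdk"
)

// listReplicas prints each replica's health. The token is assumed to belong
// to a user permitted to read replicas.
func listReplicas(token string) error {
	u, err := url.Parse("https://coder.example.com") // placeholder deployment URL
	if err != nil {
		return err
	}
	client := codersdk.New(u)
	client.SessionToken = token

	replicas, err := client.Replicas(context.Background())
	if err != nil {
		return err
	}
	for _, r := range replicas {
		fmt.Printf("%s relay=%s db=%dµs err=%q\n",
			r.Hostname, r.RelayAddress, r.DatabaseLatency, r.Error)
	}
	return nil
}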
@@ -21,7 +21,6 @@ import (
 	"tailscale.com/tailcfg"

 	"cdr.dev/slog"
-
 	"github.com/coder/coder/tailnet"
 	"github.com/coder/retry"
 )
@@ -316,7 +315,8 @@ func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, error) {
 		Value: c.SessionToken,
 	}})
 	httpClient := &http.Client{
-		Jar: jar,
+		Jar:       jar,
+		Transport: c.HTTPClient.Transport,
 	}
 	// nolint:bodyclose
 	conn, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
@@ -332,7 +332,17 @@ func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, error) {
 	return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
 }

-func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (*AgentConn, error) {
+// @typescript-ignore DialWorkspaceAgentOptions
+type DialWorkspaceAgentOptions struct {
+	Logger slog.Logger
+	// BlockEndpoints forces the connection to stay relayed through DERP.
+	BlockEndpoints bool
+}
+
+func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (*AgentConn, error) {
+	if options == nil {
+		options = &DialWorkspaceAgentOptions{}
+	}
 	res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/connection", agentID), nil)
 	if err != nil {
 		return nil, err
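For reference, a sketch of the new call shape; the wrapper function is hypothetical, while DialWorkspaceAgent and DialWorkspaceAgentOptions are as introduced above:

package example

import (
	"context"

	"cdr.dev/slog"
	"github.com/google/uuid"

	"github.com/coder/coder/codersdk"
)

// dialRelayed opens an agent connection that never exchanges direct
// endpoints, so every packet traverses the (possibly meshed) DERP servers.
func dialRelayed(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, logger slog.Logger) (*codersdk.AgentConn, error) {
	return client.DialWorkspaceAgent(ctx, agentID, &codersdk.DialWorkspaceAgentOptions{
		Logger:         logger,
		BlockEndpoints: true, // stay on DERP; useful for exercising the mesh
	})
}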
@@ -349,9 +359,10 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (*AgentConn, error) {

 	ip := tailnet.IP()
 	conn, err := tailnet.NewConn(&tailnet.Options{
-		Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
-		DERPMap:   connInfo.DERPMap,
-		Logger:    logger,
+		Addresses:      []netip.Prefix{netip.PrefixFrom(ip, 128)},
+		DERPMap:        connInfo.DERPMap,
+		Logger:         options.Logger,
+		BlockEndpoints: options.BlockEndpoints,
 	})
 	if err != nil {
 		return nil, xerrors.Errorf("create tailnet: %w", err)

@@ -370,7 +381,8 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (*AgentConn, error) {
 		Value: c.SessionToken,
 	}})
 	httpClient := &http.Client{
-		Jar: jar,
+		Jar:       jar,
+		Transport: c.HTTPClient.Transport,
 	}
 	ctx, cancelFunc := context.WithCancel(ctx)
 	closed := make(chan struct{})

@@ -379,7 +391,7 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (*AgentConn, error) {
 		defer close(closed)
 		isFirst := true
 		for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
-			logger.Debug(ctx, "connecting")
+			options.Logger.Debug(ctx, "connecting")
 			// nolint:bodyclose
 			ws, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
 				HTTPClient: httpClient,

@@ -398,21 +410,21 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (*AgentConn, error) {
 			if errors.Is(err, context.Canceled) {
 				return
 			}
-			logger.Debug(ctx, "failed to dial", slog.Error(err))
+			options.Logger.Debug(ctx, "failed to dial", slog.Error(err))
 			continue
 		}
 		sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error {
 			return conn.UpdateNodes(node)
 		})
 		conn.SetNodeCallback(sendNode)
-		logger.Debug(ctx, "serving coordinator")
+		options.Logger.Debug(ctx, "serving coordinator")
 		err = <-errChan
 		if errors.Is(err, context.Canceled) {
 			_ = ws.Close(websocket.StatusGoingAway, "")
 			return
 		}
 		if err != nil {
-			logger.Debug(ctx, "error serving coordinator", slog.Error(err))
+			options.Logger.Debug(ctx, "error serving coordinator", slog.Error(err))
 			_ = ws.Close(websocket.StatusGoingAway, "")
 			continue
 		}
@@ -57,7 +57,7 @@ func TestFeaturesList(t *testing.T) {
 		var entitlements codersdk.Entitlements
 		err := json.Unmarshal(buf.Bytes(), &entitlements)
 		require.NoError(t, err, "unmarshal JSON output")
-		assert.Len(t, entitlements.Features, 6)
+		assert.Len(t, entitlements.Features, 7)
 		assert.Empty(t, entitlements.Warnings)
 		assert.Equal(t, codersdk.EntitlementNotEntitled,
 			entitlements.Features[codersdk.FeatureUserLimit].Entitlement)

@@ -71,6 +71,8 @@ func TestFeaturesList(t *testing.T) {
 			entitlements.Features[codersdk.FeatureTemplateRBAC].Entitlement)
 		assert.Equal(t, codersdk.EntitlementNotEntitled,
 			entitlements.Features[codersdk.FeatureSCIM].Entitlement)
+		assert.Equal(t, codersdk.EntitlementNotEntitled,
+			entitlements.Features[codersdk.FeatureHighAvailability].Entitlement)
 		assert.False(t, entitlements.HasLicense)
 		assert.False(t, entitlements.Experimental)
 	})
@@ -2,11 +2,20 @@ package cli

 import (
 	"context"
+	"database/sql"
+	"errors"
+	"io"
+	"net/url"

 	"github.com/spf13/cobra"
+	"golang.org/x/xerrors"
+	"tailscale.com/derp"
+	"tailscale.com/types/key"

 	"github.com/coder/coder/cli/deployment"
+	"github.com/coder/coder/cryptorand"
 	"github.com/coder/coder/enterprise/coderd"
+	"github.com/coder/coder/tailnet"

 	agpl "github.com/coder/coder/cli"
 	agplcoderd "github.com/coder/coder/coderd"

@@ -14,23 +23,49 @@ import (

 func server() *cobra.Command {
 	dflags := deployment.Flags()
-	cmd := agpl.Server(dflags, func(ctx context.Context, options *agplcoderd.Options) (*agplcoderd.API, error) {
+	cmd := agpl.Server(dflags, func(ctx context.Context, options *agplcoderd.Options) (*agplcoderd.API, io.Closer, error) {
+		if dflags.DerpServerRelayAddress.Value != "" {
+			_, err := url.Parse(dflags.DerpServerRelayAddress.Value)
+			if err != nil {
+				return nil, nil, xerrors.Errorf("derp-server-relay-address must be a valid HTTP URL: %w", err)
+			}
+		}
+
+		options.DERPServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp")))
+		meshKey, err := options.Database.GetDERPMeshKey(ctx)
+		if err != nil {
+			if !errors.Is(err, sql.ErrNoRows) {
+				return nil, nil, xerrors.Errorf("get mesh key: %w", err)
+			}
+			meshKey, err = cryptorand.String(32)
+			if err != nil {
+				return nil, nil, xerrors.Errorf("generate mesh key: %w", err)
+			}
+			err = options.Database.InsertDERPMeshKey(ctx, meshKey)
+			if err != nil {
+				return nil, nil, xerrors.Errorf("insert mesh key: %w", err)
+			}
+		}
+		options.DERPServer.SetMeshKey(meshKey)
+
 		o := &coderd.Options{
-			AuditLogging:       dflags.AuditLogging.Value,
-			BrowserOnly:        dflags.BrowserOnly.Value,
-			SCIMAPIKey:         []byte(dflags.SCIMAuthHeader.Value),
-			UserWorkspaceQuota: dflags.UserWorkspaceQuota.Value,
-			RBACEnabled:        true,
-			Options:            options,
+			AuditLogging:           dflags.AuditLogging.Value,
+			BrowserOnly:            dflags.BrowserOnly.Value,
+			SCIMAPIKey:             []byte(dflags.SCIMAuthHeader.Value),
+			UserWorkspaceQuota:     dflags.UserWorkspaceQuota.Value,
+			RBAC:                   true,
+			DERPServerRelayAddress: dflags.DerpServerRelayAddress.Value,
+			DERPServerRegionID:     dflags.DerpServerRegionID.Value,
+
+			Options: options,
 		}
 		api, err := coderd.New(ctx, o)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
 		}
-		return api.AGPL, nil
+		return api.AGPL, api, nil
 	})

 	deployment.AttachFlags(cmd.Flags(), dflags, true)

 	return cmd
 }
@@ -28,7 +28,7 @@ func TestCheckACLPermissions(t *testing.T) {
 	// Create adminClient, member, and org adminClient
 	adminUser := coderdtest.CreateFirstUser(t, adminClient)
 	_ = coderdenttest.AddLicense(t, adminClient, coderdenttest.LicenseOptions{
-		TemplateRBACEnabled: true,
+		TemplateRBAC: true,
 	})

 	memberClient := coderdtest.CreateAnotherUser(t, adminClient, adminUser.OrganizationID)
@@ -3,6 +3,8 @@ package coderd

 import (
 	"context"
 	"crypto/ed25519"
+	"crypto/tls"
+	"crypto/x509"
 	"net/http"
 	"sync"
 	"time"

@@ -23,6 +25,10 @@ import (
 	"github.com/coder/coder/enterprise/audit"
 	"github.com/coder/coder/enterprise/audit/backends"
 	"github.com/coder/coder/enterprise/coderd/license"
+	"github.com/coder/coder/enterprise/derpmesh"
+	"github.com/coder/coder/enterprise/replicasync"
+	"github.com/coder/coder/enterprise/tailnet"
+	agpltailnet "github.com/coder/coder/tailnet"
 )

 // New constructs an Enterprise coderd API instance.
@@ -47,6 +53,7 @@ func New(ctx context.Context, options *Options) (*API, error) {
 		Options:                options,
 		cancelEntitlementsLoop: cancelFunc,
 	}
+
 	oauthConfigs := &httpmw.OAuth2Configs{
 		Github: options.GithubOAuth2Config,
 		OIDC:   options.OIDCConfig,

@@ -59,6 +66,10 @@ func New(ctx context.Context, options *Options) (*API, error) {

 	api.AGPL.APIHandler.Group(func(r chi.Router) {
 		r.Get("/entitlements", api.serveEntitlements)
+		r.Route("/replicas", func(r chi.Router) {
+			r.Use(apiKeyMiddleware)
+			r.Get("/", api.replicas)
+		})
 		r.Route("/licenses", func(r chi.Router) {
 			r.Use(apiKeyMiddleware)
 			r.Post("/", api.postLicense)
@@ -117,7 +128,40 @@ func New(ctx context.Context, options *Options) (*API, error) {
 		})
 	}

-	err := api.updateEntitlements(ctx)
+	meshRootCA := x509.NewCertPool()
+	for _, certificate := range options.TLSCertificates {
+		for _, certificatePart := range certificate.Certificate {
+			certificate, err := x509.ParseCertificate(certificatePart)
+			if err != nil {
+				return nil, xerrors.Errorf("parse certificate %s: %w", certificate.Subject.CommonName, err)
+			}
+			meshRootCA.AddCert(certificate)
+		}
+	}
+	// This TLS configuration spoofs access from the access URL hostname
+	// assuming that the certificates provided will cover that hostname.
+	//
+	// Replica sync and DERP meshing require accessing replicas via their
+	// internal IP addresses, and if TLS is configured we use the same
+	// certificates.
+	meshTLSConfig := &tls.Config{
+		MinVersion:   tls.VersionTLS12,
+		Certificates: options.TLSCertificates,
+		RootCAs:      meshRootCA,
+		ServerName:   options.AccessURL.Hostname(),
+	}
+	var err error
+	api.replicaManager, err = replicasync.New(ctx, options.Logger, options.Database, options.Pubsub, &replicasync.Options{
+		RelayAddress: options.DERPServerRelayAddress,
+		RegionID:     int32(options.DERPServerRegionID),
+		TLSConfig:    meshTLSConfig,
+	})
+	if err != nil {
+		return nil, xerrors.Errorf("initialize replica: %w", err)
+	}
+	api.derpMesh = derpmesh.New(options.Logger.Named("derpmesh"), api.DERPServer, meshTLSConfig)
+
+	err = api.updateEntitlements(ctx)
 	if err != nil {
 		return nil, xerrors.Errorf("update entitlements: %w", err)
 	}
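The ServerName pinning above relies on standard crypto/tls behavior: verification can target a hostname other than the dialed address. A stripped-down sketch, with placeholder addresses:

package example

import (
	"crypto/tls"
	"crypto/x509"
)

// dialPeer connects to a replica by its internal IP while verifying the
// presented certificate against the public access URL's hostname, since
// the certificate will not list the internal address as a SAN.
func dialPeer(meshRootCA *x509.CertPool) (*tls.Conn, error) {
	return tls.Dial("tcp", "10.0.0.7:443", &tls.Config{ // placeholder internal address
		MinVersion: tls.VersionTLS12,
		RootCAs:    meshRootCA,          // pool built from the replica's own certs
		ServerName: "coder.example.com", // placeholder access URL host
	})
}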
@@ -129,13 +173,17 @@ func New(ctx context.Context, options *Options) (*API, error) {
 type Options struct {
 	*coderd.Options

-	RBACEnabled  bool
+	RBAC         bool
 	AuditLogging bool
 	// Whether to block non-browser connections.
 	BrowserOnly        bool
 	SCIMAPIKey         []byte
 	UserWorkspaceQuota int

+	// Used for high availability.
+	DERPServerRelayAddress string
+	DERPServerRegionID     int
+
 	EntitlementsUpdateInterval time.Duration
 	Keys                       map[string]ed25519.PublicKey
 }
@@ -144,6 +192,11 @@ type API struct {
 	AGPL *coderd.API
 	*Options

+	// Detects multiple Coder replicas running at the same time.
+	replicaManager *replicasync.Manager
+	// Meshes DERP connections from multiple replicas.
+	derpMesh *derpmesh.Mesh
+
 	cancelEntitlementsLoop func()
 	entitlementsMu         sync.RWMutex
 	entitlements           codersdk.Entitlements
@@ -151,6 +204,8 @@

 func (api *API) Close() error {
 	api.cancelEntitlementsLoop()
+	_ = api.replicaManager.Close()
+	_ = api.derpMesh.Close()
 	return api.AGPL.Close()
 }
@@ -158,12 +213,13 @@ func (api *API) updateEntitlements(ctx context.Context) error {
 	api.entitlementsMu.Lock()
 	defer api.entitlementsMu.Unlock()

-	entitlements, err := license.Entitlements(ctx, api.Database, api.Logger, api.Keys, map[string]bool{
-		codersdk.FeatureAuditLog:       api.AuditLogging,
-		codersdk.FeatureBrowserOnly:    api.BrowserOnly,
-		codersdk.FeatureSCIM:           len(api.SCIMAPIKey) != 0,
-		codersdk.FeatureWorkspaceQuota: api.UserWorkspaceQuota != 0,
-		codersdk.FeatureTemplateRBAC:   api.RBACEnabled,
+	entitlements, err := license.Entitlements(ctx, api.Database, api.Logger, len(api.replicaManager.All()), api.Keys, map[string]bool{
+		codersdk.FeatureAuditLog:         api.AuditLogging,
+		codersdk.FeatureBrowserOnly:      api.BrowserOnly,
+		codersdk.FeatureSCIM:             len(api.SCIMAPIKey) != 0,
+		codersdk.FeatureWorkspaceQuota:   api.UserWorkspaceQuota != 0,
+		codersdk.FeatureHighAvailability: api.DERPServerRelayAddress != "",
+		codersdk.FeatureTemplateRBAC:     api.RBAC,
 	})
 	if err != nil {
 		return err
@@ -209,6 +265,46 @@ func (api *API) updateEntitlements(ctx context.Context) error {
 		api.AGPL.WorkspaceQuotaEnforcer.Store(&enforcer)
 	}

+	if changed, enabled := featureChanged(codersdk.FeatureHighAvailability); changed {
+		coordinator := agpltailnet.NewCoordinator()
+		if enabled {
+			haCoordinator, err := tailnet.NewCoordinator(api.Logger, api.Pubsub)
+			if err != nil {
+				api.Logger.Error(ctx, "unable to set up high availability coordinator", slog.Error(err))
+				// If we try to set up the HA coordinator and it fails,
+				// nothing actually changes.
+				changed = false
+			} else {
+				coordinator = haCoordinator
+			}
+
+			api.replicaManager.SetCallback(func() {
+				addresses := make([]string, 0)
+				for _, replica := range api.replicaManager.Regional() {
+					addresses = append(addresses, replica.RelayAddress)
+				}
+				api.derpMesh.SetAddresses(addresses, false)
+				_ = api.updateEntitlements(ctx)
+			})
+		} else {
+			api.derpMesh.SetAddresses([]string{}, false)
+			api.replicaManager.SetCallback(func() {
+				// If the number of replicas changes, so should our entitlements.
+				// This is to display a warning in the UI if the user is unlicensed.
+				_ = api.updateEntitlements(ctx)
+			})
+		}
+
+		// Recheck changed in case the HA coordinator failed to set up.
+		if changed {
+			oldCoordinator := *api.AGPL.TailnetCoordinator.Swap(&coordinator)
+			err := oldCoordinator.Close()
+			if err != nil {
+				api.Logger.Error(ctx, "close old tailnet coordinator", slog.Error(err))
+			}
+		}
+	}
+
 	api.entitlements = entitlements

 	return nil
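featureChanged is defined elsewhere in this file and not shown in the diff; a hypothetical reconstruction of its contract, to make the hunk above easier to follow:

package example

// Feature mirrors the shape used by the entitlements code above.
type Feature struct {
	Enabled     bool
	Entitlement string
}

// featureChanged is a hypothetical sketch of the helper the hunk relies on:
// report whether a feature's state differs between the previous and the
// newly computed entitlement snapshots.
func featureChanged(prev, next map[string]Feature, name string) (changed bool, enabled bool) {
	p, n := prev[name], next[name]
	return p.Enabled != n.Enabled || p.Entitlement != n.Entitlement, n.Enabled
}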
@@ -41,9 +41,9 @@ func TestEntitlements(t *testing.T) {
 		})
 		_ = coderdtest.CreateFirstUser(t, client)
 		coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			UserLimit:           100,
-			AuditLog:            true,
-			TemplateRBACEnabled: true,
+			UserLimit:    100,
+			AuditLog:     true,
+			TemplateRBAC: true,
 		})
 		res, err := client.Entitlements(context.Background())
 		require.NoError(t, err)

@@ -85,7 +85,7 @@ func TestEntitlements(t *testing.T) {
 		assert.False(t, res.HasLicense)
 		al = res.Features[codersdk.FeatureAuditLog]
 		assert.Equal(t, codersdk.EntitlementNotEntitled, al.Entitlement)
-		assert.True(t, al.Enabled)
+		assert.False(t, al.Enabled)
 	})
 	t.Run("Pubsub", func(t *testing.T) {
 		t.Parallel()
@@ -4,7 +4,9 @@ import (
 	"context"
 	"crypto/ed25519"
 	"crypto/rand"
+	"crypto/tls"
 	"io"
+	"net/http"
 	"testing"
 	"time"
@@ -60,19 +62,21 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) {
 	if options.Options == nil {
 		options.Options = &coderdtest.Options{}
 	}
-	srv, cancelFunc, oop := coderdtest.NewOptions(t, options.Options)
+	setHandler, cancelFunc, oop := coderdtest.NewOptions(t, options.Options)
 	coderAPI, err := coderd.New(context.Background(), &coderd.Options{
-		RBACEnabled:                true,
+		RBAC:                       true,
 		AuditLogging:               options.AuditLogging,
 		BrowserOnly:                options.BrowserOnly,
 		SCIMAPIKey:                 options.SCIMAPIKey,
+		DERPServerRelayAddress:     oop.AccessURL.String(),
+		DERPServerRegionID:         oop.DERPMap.RegionIDs()[0],
 		UserWorkspaceQuota:         options.UserWorkspaceQuota,
 		Options:                    oop,
 		EntitlementsUpdateInterval: options.EntitlementsUpdateInterval,
 		Keys:                       Keys,
 	})
 	assert.NoError(t, err)
-	srv.Config.Handler = coderAPI.AGPL.RootHandler
+	setHandler(coderAPI.AGPL.RootHandler)
 	var provisionerCloser io.Closer = nopcloser{}
 	if options.IncludeProvisionerDaemon {
 		provisionerCloser = coderdtest.NewProvisionerDaemon(t, coderAPI.AGPL)
@@ -83,22 +87,32 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) {
 		_ = provisionerCloser.Close()
 		_ = coderAPI.Close()
 	})
-	return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI
+	client := codersdk.New(coderAPI.AccessURL)
+	client.HTTPClient = &http.Client{
+		Transport: &http.Transport{
+			TLSClientConfig: &tls.Config{
+				//nolint:gosec
+				InsecureSkipVerify: true,
+			},
+		},
+	}
+	return client, provisionerCloser, coderAPI
 }

 type LicenseOptions struct {
-	AccountType         string
-	AccountID           string
-	Trial               bool
-	AllFeatures         bool
-	GraceAt             time.Time
-	ExpiresAt           time.Time
-	UserLimit           int64
-	AuditLog            bool
-	BrowserOnly         bool
-	SCIM                bool
-	WorkspaceQuota      bool
-	TemplateRBACEnabled bool
+	AccountType      string
+	AccountID        string
+	Trial            bool
+	AllFeatures      bool
+	GraceAt          time.Time
+	ExpiresAt        time.Time
+	UserLimit        int64
+	AuditLog         bool
+	BrowserOnly      bool
+	SCIM             bool
+	WorkspaceQuota   bool
+	TemplateRBAC     bool
+	HighAvailability bool
 }

 // AddLicense generates a new license with the options provided and inserts it.
@@ -134,9 +148,13 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
 	if options.WorkspaceQuota {
 		workspaceQuota = 1
 	}
+	highAvailability := int64(0)
+	if options.HighAvailability {
+		highAvailability = 1
+	}

 	rbacEnabled := int64(0)
-	if options.TemplateRBACEnabled {
+	if options.TemplateRBAC {
 		rbacEnabled = 1
 	}
@@ -154,12 +172,13 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
 		Version:     license.CurrentVersion,
 		AllFeatures: options.AllFeatures,
 		Features: license.Features{
-			UserLimit:      options.UserLimit,
-			AuditLog:       auditLog,
-			BrowserOnly:    browserOnly,
-			SCIM:           scim,
-			WorkspaceQuota: workspaceQuota,
-			TemplateRBAC:   rbacEnabled,
+			UserLimit:        options.UserLimit,
+			AuditLog:         auditLog,
+			BrowserOnly:      browserOnly,
+			SCIM:             scim,
+			WorkspaceQuota:   workspaceQuota,
+			HighAvailability: highAvailability,
+			TemplateRBAC:     rbacEnabled,
 		},
 	}
 	tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c)
@@ -33,7 +33,7 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
 	ctx, _ := testutil.Context(t)
 	admin := coderdtest.CreateFirstUser(t, client)
 	license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-		TemplateRBACEnabled: true,
+		TemplateRBAC: true,
 	})
 	group, err := client.CreateGroup(ctx, admin.OrganizationID, codersdk.CreateGroupRequest{
 		Name: "testgroup",

@@ -58,6 +58,10 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
 		AssertAction: rbac.ActionRead,
 		AssertObject: rbac.ResourceLicense,
 	}
+	assertRoute["GET:/api/v2/replicas"] = coderdtest.RouteCheck{
+		AssertAction: rbac.ActionRead,
+		AssertObject: rbac.ResourceReplicas,
+	}
 	assertRoute["DELETE:/api/v2/licenses/{id}"] = coderdtest.RouteCheck{
 		AssertAction: rbac.ActionDelete,
 		AssertObject: rbac.ResourceLicense,
@@ -24,7 +24,7 @@ func TestCreateGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{

@@ -43,7 +43,7 @@ func TestCreateGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		_, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{

@@ -67,7 +67,7 @@ func TestCreateGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		_, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{

@@ -90,7 +90,7 @@ func TestPatchGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{

@@ -112,7 +112,7 @@ func TestPatchGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
 		_, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)

@@ -138,7 +138,7 @@ func TestPatchGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
 		_, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)

@@ -173,7 +173,7 @@ func TestPatchGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{

@@ -197,7 +197,7 @@ func TestPatchGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{

@@ -221,7 +221,7 @@ func TestPatchGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
 		ctx, _ := testutil.Context(t)

@@ -247,7 +247,7 @@ func TestPatchGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{

@@ -276,7 +276,7 @@ func TestGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		group, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{

@@ -296,7 +296,7 @@ func TestGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
 		_, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)

@@ -326,7 +326,7 @@ func TestGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		client1, _ := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)

@@ -347,7 +347,7 @@ func TestGroup(t *testing.T) {
 		client := coderdenttest.New(t, nil)
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})

 		_, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)

@@ -380,7 +380,7 @@ func TestGroup(t *testing.T) {
 		client := coderdenttest.New(t, nil)
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})

 		_, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)

@@ -421,7 +421,7 @@ func TestGroups(t *testing.T) {
 		client := coderdenttest.New(t, nil)
 		user := coderdtest.CreateFirstUser(t, client)
 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
 		_, user3 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)

@@ -467,7 +467,7 @@ func TestDeleteGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		group1, err := client.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{

@@ -492,7 +492,7 @@ func TestDeleteGroup(t *testing.T) {
 		user := coderdtest.CreateFirstUser(t, client)

 		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
 		})
 		ctx, _ := testutil.Context(t)
 		err := client.DeleteGroup(ctx, user.OrganizationID)
@@ -17,12 +17,20 @@ import (
 )

 // Entitlements processes licenses to return whether features are enabled or not.
-func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, keys map[string]ed25519.PublicKey, enablements map[string]bool) (codersdk.Entitlements, error) {
+func Entitlements(
+	ctx context.Context,
+	db database.Store,
+	logger slog.Logger,
+	replicaCount int,
+	keys map[string]ed25519.PublicKey,
+	enablements map[string]bool,
+) (codersdk.Entitlements, error) {
 	now := time.Now()
 	// Default all entitlements to be disabled.
 	entitlements := codersdk.Entitlements{
 		Features: map[string]codersdk.Feature{},
 		Warnings: []string{},
+		Errors:   []string{},
 	}
 	for _, featureName := range codersdk.FeatureNames {
 		entitlements.Features[featureName] = codersdk.Feature{
@@ -96,6 +104,12 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, keys map[string]ed25519.PublicKey, enablements map[string]bool) (codersdk.Entitlements, error) {
 			Enabled:     enablements[codersdk.FeatureWorkspaceQuota],
 		}
 	}
+	if claims.Features.HighAvailability > 0 {
+		entitlements.Features[codersdk.FeatureHighAvailability] = codersdk.Feature{
+			Entitlement: entitlement,
+			Enabled:     enablements[codersdk.FeatureHighAvailability],
+		}
+	}
 	if claims.Features.TemplateRBAC > 0 {
 		entitlements.Features[codersdk.FeatureTemplateRBAC] = codersdk.Feature{
 			Entitlement: entitlement,
@@ -132,6 +146,10 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, keys map[string]ed25519.PublicKey, enablements map[string]bool) (codersdk.Entitlements, error) {
 		if featureName == codersdk.FeatureUserLimit {
 			continue
 		}
+		// High availability has its own warnings based on replica count!
+		if featureName == codersdk.FeatureHighAvailability {
+			continue
+		}
 		feature := entitlements.Features[featureName]
 		if !feature.Enabled {
 			continue
@@ -141,9 +159,6 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, keys map[string]ed25519.PublicKey, enablements map[string]bool) (codersdk.Entitlements, error) {
 		case codersdk.EntitlementNotEntitled:
 			entitlements.Warnings = append(entitlements.Warnings,
 				fmt.Sprintf("%s is enabled but your license is not entitled to this feature.", niceName))
-			// Disable the feature and add a warning...
-			feature.Enabled = false
-			entitlements.Features[featureName] = feature
 		case codersdk.EntitlementGracePeriod:
 			entitlements.Warnings = append(entitlements.Warnings,
 				fmt.Sprintf("%s is enabled but your license for this feature is expired.", niceName))
@@ -152,6 +167,32 @@ func Entitlements(ctx context.Context, db database.Store, logger slog.Logger, keys map[string]ed25519.PublicKey, enablements map[string]bool) (codersdk.Entitlements, error) {
 		}
 	}

+	if replicaCount > 1 {
+		feature := entitlements.Features[codersdk.FeatureHighAvailability]
+
+		switch feature.Entitlement {
+		case codersdk.EntitlementNotEntitled:
+			if entitlements.HasLicense {
+				entitlements.Errors = append(entitlements.Errors,
+					"You have multiple replicas but your license is not entitled to high availability. You will be unable to connect to workspaces.")
+			} else {
+				entitlements.Errors = append(entitlements.Errors,
+					"You have multiple replicas but high availability is an Enterprise feature. You will be unable to connect to workspaces.")
+			}
+		case codersdk.EntitlementGracePeriod:
+			entitlements.Warnings = append(entitlements.Warnings,
+				"You have multiple replicas but your license for high availability is expired. Reduce to one replica or workspace connections will stop working.")
+		}
+	}
+
+	for _, featureName := range codersdk.FeatureNames {
+		feature := entitlements.Features[featureName]
+		if feature.Entitlement == codersdk.EntitlementNotEntitled {
+			feature.Enabled = false
+			entitlements.Features[featureName] = feature
+		}
+	}
+
 	return entitlements, nil
 }
@@ -171,12 +212,13 @@ var (
 )

 type Features struct {
-	UserLimit      int64 `json:"user_limit"`
-	AuditLog       int64 `json:"audit_log"`
-	BrowserOnly    int64 `json:"browser_only"`
-	SCIM           int64 `json:"scim"`
-	WorkspaceQuota int64 `json:"workspace_quota"`
-	TemplateRBAC   int64 `json:"template_rbac"`
+	UserLimit        int64 `json:"user_limit"`
+	AuditLog         int64 `json:"audit_log"`
+	BrowserOnly      int64 `json:"browser_only"`
+	SCIM             int64 `json:"scim"`
+	WorkspaceQuota   int64 `json:"workspace_quota"`
+	TemplateRBAC     int64 `json:"template_rbac"`
+	HighAvailability int64 `json:"high_availability"`
 }

 type Claims struct {
@@ -20,17 +20,18 @@ import (
 func TestEntitlements(t *testing.T) {
 	t.Parallel()
 	all := map[string]bool{
-		codersdk.FeatureAuditLog:       true,
-		codersdk.FeatureBrowserOnly:    true,
-		codersdk.FeatureSCIM:           true,
-		codersdk.FeatureWorkspaceQuota: true,
-		codersdk.FeatureTemplateRBAC:   true,
+		codersdk.FeatureAuditLog:         true,
+		codersdk.FeatureBrowserOnly:      true,
+		codersdk.FeatureSCIM:             true,
+		codersdk.FeatureWorkspaceQuota:   true,
+		codersdk.FeatureHighAvailability: true,
+		codersdk.FeatureTemplateRBAC:     true,
 	}

 	t.Run("Defaults", func(t *testing.T) {
 		t.Parallel()
 		db := databasefake.New()
-		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all)
 		require.NoError(t, err)
 		require.False(t, entitlements.HasLicense)
 		require.False(t, entitlements.Trial)

@@ -46,7 +47,7 @@ func TestEntitlements(t *testing.T) {
 			JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{}),
 			Exp: time.Now().Add(time.Hour),
 		})
-		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{})
 		require.NoError(t, err)
 		require.True(t, entitlements.HasLicense)
 		require.False(t, entitlements.Trial)

@@ -60,16 +61,17 @@ func TestEntitlements(t *testing.T) {
 		db := databasefake.New()
 		db.InsertLicense(context.Background(), database.InsertLicenseParams{
 			JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
-				UserLimit:           100,
-				AuditLog:            true,
-				BrowserOnly:         true,
-				SCIM:                true,
-				WorkspaceQuota:      true,
-				TemplateRBACEnabled: true,
+				UserLimit:        100,
+				AuditLog:         true,
+				BrowserOnly:      true,
+				SCIM:             true,
+				WorkspaceQuota:   true,
+				HighAvailability: true,
+				TemplateRBAC:     true,
 			}),
 			Exp: time.Now().Add(time.Hour),
 		})
-		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{})
 		require.NoError(t, err)
 		require.True(t, entitlements.HasLicense)
 		require.False(t, entitlements.Trial)

@@ -82,18 +84,19 @@ func TestEntitlements(t *testing.T) {
 		db := databasefake.New()
 		db.InsertLicense(context.Background(), database.InsertLicenseParams{
 			JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
-				UserLimit:           100,
-				AuditLog:            true,
-				BrowserOnly:         true,
-				SCIM:                true,
-				WorkspaceQuota:      true,
-				TemplateRBACEnabled: true,
-				GraceAt:             time.Now().Add(-time.Hour),
-				ExpiresAt:           time.Now().Add(time.Hour),
+				UserLimit:        100,
+				AuditLog:         true,
+				BrowserOnly:      true,
+				SCIM:             true,
+				WorkspaceQuota:   true,
+				HighAvailability: true,
+				TemplateRBAC:     true,
+				GraceAt:          time.Now().Add(-time.Hour),
+				ExpiresAt:        time.Now().Add(time.Hour),
 			}),
 			Exp: time.Now().Add(time.Hour),
 		})
-		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all)
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all)
 		require.NoError(t, err)
 		require.True(t, entitlements.HasLicense)
 		require.False(t, entitlements.Trial)

@@ -101,6 +104,9 @@ func TestEntitlements(t *testing.T) {
 			if featureName == codersdk.FeatureUserLimit {
 				continue
 			}
+			if featureName == codersdk.FeatureHighAvailability {
+				continue
+			}
 			niceName := strings.Title(strings.ReplaceAll(featureName, "_", " "))
 			require.Equal(t, codersdk.EntitlementGracePeriod, entitlements.Features[featureName].Entitlement)
 			require.Contains(t, entitlements.Warnings, fmt.Sprintf("%s is enabled but your license for this feature is expired.", niceName))

@@ -113,7 +119,7 @@ func TestEntitlements(t *testing.T) {
 			JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{}),
 			Exp: time.Now().Add(time.Hour),
 		})
-		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all)
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all)
 		require.NoError(t, err)
 		require.True(t, entitlements.HasLicense)
 		require.False(t, entitlements.Trial)

@@ -121,6 +127,9 @@ func TestEntitlements(t *testing.T) {
 			if featureName == codersdk.FeatureUserLimit {
 				continue
 			}
+			if featureName == codersdk.FeatureHighAvailability {
+				continue
+			}
 			niceName := strings.Title(strings.ReplaceAll(featureName, "_", " "))
 			// Ensures features that are not entitled are properly disabled.
 			require.False(t, entitlements.Features[featureName].Enabled)

@@ -139,7 +148,7 @@ func TestEntitlements(t *testing.T) {
 			}),
 			Exp: time.Now().Add(time.Hour),
 		})
-		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{})
 		require.NoError(t, err)
 		require.True(t, entitlements.HasLicense)
 		require.Contains(t, entitlements.Warnings, "Your deployment has 2 active users but is only licensed for 1.")

@@ -161,7 +170,7 @@ func TestEntitlements(t *testing.T) {
 			}),
 			Exp: time.Now().Add(time.Hour),
 		})
-		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{})
 		require.NoError(t, err)
 		require.True(t, entitlements.HasLicense)
 		require.Empty(t, entitlements.Warnings)

@@ -184,7 +193,7 @@ func TestEntitlements(t *testing.T) {
 			}),
 		})

-		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, map[string]bool{})
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, map[string]bool{})
 		require.NoError(t, err)
 		require.True(t, entitlements.HasLicense)
 		require.False(t, entitlements.Trial)

@@ -199,7 +208,7 @@ func TestEntitlements(t *testing.T) {
 				AllFeatures: true,
 			}),
 		})
-		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, coderdenttest.Keys, all)
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 1, coderdenttest.Keys, all)
 		require.NoError(t, err)
 		require.True(t, entitlements.HasLicense)
 		require.False(t, entitlements.Trial)

@@ -211,4 +220,52 @@ func TestEntitlements(t *testing.T) {
 			require.Equal(t, codersdk.EntitlementEntitled, entitlements.Features[featureName].Entitlement)
 		}
 	})
+
+	t.Run("MultipleReplicasNoLicense", func(t *testing.T) {
+		t.Parallel()
+		db := databasefake.New()
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 2, coderdenttest.Keys, all)
+		require.NoError(t, err)
+		require.False(t, entitlements.HasLicense)
+		require.Len(t, entitlements.Errors, 1)
+		require.Equal(t, "You have multiple replicas but high availability is an Enterprise feature. You will be unable to connect to workspaces.", entitlements.Errors[0])
+	})
+
+	t.Run("MultipleReplicasNotEntitled", func(t *testing.T) {
+		t.Parallel()
+		db := databasefake.New()
+		db.InsertLicense(context.Background(), database.InsertLicenseParams{
+			Exp: time.Now().Add(time.Hour),
+			JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
+				AuditLog: true,
+			}),
+		})
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 2, coderdenttest.Keys, map[string]bool{
+			codersdk.FeatureHighAvailability: true,
+		})
+		require.NoError(t, err)
+		require.True(t, entitlements.HasLicense)
+		require.Len(t, entitlements.Errors, 1)
+		require.Equal(t, "You have multiple replicas but your license is not entitled to high availability. You will be unable to connect to workspaces.", entitlements.Errors[0])
+	})
+
+	t.Run("MultipleReplicasGrace", func(t *testing.T) {
+		t.Parallel()
+		db := databasefake.New()
+		db.InsertLicense(context.Background(), database.InsertLicenseParams{
+			JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
+				HighAvailability: true,
+				GraceAt:          time.Now().Add(-time.Hour),
+				ExpiresAt:        time.Now().Add(time.Hour),
+			}),
+			Exp: time.Now().Add(time.Hour),
+		})
+		entitlements, err := license.Entitlements(context.Background(), db, slog.Logger{}, 2, coderdenttest.Keys, map[string]bool{
+			codersdk.FeatureHighAvailability: true,
+		})
+		require.NoError(t, err)
+		require.True(t, entitlements.HasLicense)
+		require.Len(t, entitlements.Warnings, 1)
+		require.Equal(t, "You have multiple replicas but your license for high availability is expired. Reduce to one replica or workspace connections will stop working.", entitlements.Warnings[0])
+	})
 }
@@ -78,21 +78,21 @@ func TestGetLicense(t *testing.T) {
 		defer cancel()

 		coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			AccountID:           "testing",
-			AuditLog:            true,
-			SCIM:                true,
-			BrowserOnly:         true,
-			TemplateRBACEnabled: true,
+			AccountID:    "testing",
+			AuditLog:     true,
+			SCIM:         true,
+			BrowserOnly:  true,
+			TemplateRBAC: true,
 		})

 		coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			AccountID:           "testing2",
-			AuditLog:            true,
-			SCIM:                true,
-			BrowserOnly:         true,
-			Trial:               true,
-			UserLimit:           200,
-			TemplateRBACEnabled: false,
+			AccountID:    "testing2",
+			AuditLog:     true,
+			SCIM:         true,
+			BrowserOnly:  true,
+			Trial:        true,
+			UserLimit:    200,
+			TemplateRBAC: false,
 		})

 		licenses, err := client.Licenses(ctx)

@@ -101,23 +101,25 @@ func TestGetLicense(t *testing.T) {
 		assert.Equal(t, int32(1), licenses[0].ID)
 		assert.Equal(t, "testing", licenses[0].Claims["account_id"])
 		assert.Equal(t, map[string]interface{}{
-			codersdk.FeatureUserLimit:      json.Number("0"),
-			codersdk.FeatureAuditLog:       json.Number("1"),
-			codersdk.FeatureSCIM:           json.Number("1"),
-			codersdk.FeatureBrowserOnly:    json.Number("1"),
-			codersdk.FeatureWorkspaceQuota: json.Number("0"),
-			codersdk.FeatureTemplateRBAC:   json.Number("1"),
+			codersdk.FeatureUserLimit:        json.Number("0"),
+			codersdk.FeatureAuditLog:         json.Number("1"),
+			codersdk.FeatureSCIM:             json.Number("1"),
+			codersdk.FeatureBrowserOnly:      json.Number("1"),
+			codersdk.FeatureWorkspaceQuota:   json.Number("0"),
+			codersdk.FeatureHighAvailability: json.Number("0"),
+			codersdk.FeatureTemplateRBAC:     json.Number("1"),
 		}, licenses[0].Claims["features"])
 		assert.Equal(t, int32(2), licenses[1].ID)
 		assert.Equal(t, "testing2", licenses[1].Claims["account_id"])
 		assert.Equal(t, true, licenses[1].Claims["trial"])
 		assert.Equal(t, map[string]interface{}{
-			codersdk.FeatureUserLimit:      json.Number("200"),
-			codersdk.FeatureAuditLog:       json.Number("1"),
-			codersdk.FeatureSCIM:           json.Number("1"),
-			codersdk.FeatureBrowserOnly:    json.Number("1"),
-			codersdk.FeatureWorkspaceQuota: json.Number("0"),
-			codersdk.FeatureTemplateRBAC:   json.Number("0"),
+			codersdk.FeatureUserLimit:        json.Number("200"),
+			codersdk.FeatureAuditLog:         json.Number("1"),
+			codersdk.FeatureSCIM:             json.Number("1"),
+			codersdk.FeatureBrowserOnly:      json.Number("1"),
+			codersdk.FeatureWorkspaceQuota:   json.Number("0"),
+			codersdk.FeatureHighAvailability: json.Number("0"),
+			codersdk.FeatureTemplateRBAC:     json.Number("0"),
 		}, licenses[1].Claims["features"])
 	})
 }
@@ -0,0 +1,37 @@
+package coderd
+
+import (
+	"net/http"
+
+	"github.com/coder/coder/coderd/database"
+	"github.com/coder/coder/coderd/httpapi"
+	"github.com/coder/coder/coderd/rbac"
+	"github.com/coder/coder/codersdk"
+)
+
+// replicas returns the replicas that are active in Coder.
+func (api *API) replicas(rw http.ResponseWriter, r *http.Request) {
+	if !api.AGPL.Authorize(r, rbac.ActionRead, rbac.ResourceReplicas) {
+		httpapi.ResourceNotFound(rw)
+		return
+	}
+
+	replicas := api.replicaManager.All()
+	res := make([]codersdk.Replica, 0, len(replicas))
+	for _, replica := range replicas {
+		res = append(res, convertReplica(replica))
+	}
+	httpapi.Write(r.Context(), rw, http.StatusOK, res)
+}
+
+func convertReplica(replica database.Replica) codersdk.Replica {
+	return codersdk.Replica{
+		ID:              replica.ID,
+		Hostname:        replica.Hostname,
+		CreatedAt:       replica.CreatedAt,
+		RelayAddress:    replica.RelayAddress,
+		RegionID:        replica.RegionID,
+		Error:           replica.Error,
+		DatabaseLatency: replica.DatabaseLatency,
+	}
+}
@ -0,0 +1,138 @@
|
|||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/enterprise/coderd/coderdenttest"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestReplicas(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("ErrorWithoutLicense", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, pubsub := dbtestutil.NewDB(t)
|
||||
firstClient := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
},
|
||||
})
|
||||
_ = coderdtest.CreateFirstUser(t, firstClient)
|
||||
secondClient, _, secondAPI := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
},
|
||||
})
|
||||
secondClient.SessionToken = firstClient.SessionToken
|
||||
ents, err := secondClient.Entitlements(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, ents.Errors, 1)
|
||||
_ = secondAPI.Close()
|
||||
|
||||
ents, err = firstClient.Entitlements(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, ents.Warnings, 0)
|
||||
})
|
||||
t.Run("ConnectAcrossMultiple", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, pubsub := dbtestutil.NewDB(t)
|
||||
firstClient := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
},
|
||||
})
|
||||
firstUser := coderdtest.CreateFirstUser(t, firstClient)
|
||||
coderdenttest.AddLicense(t, firstClient, coderdenttest.LicenseOptions{
|
||||
HighAvailability: true,
|
||||
})
|
||||
|
||||
secondClient := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
},
|
||||
})
|
||||
secondClient.SessionToken = firstClient.SessionToken
|
||||
replicas, err := secondClient.Replicas(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, replicas, 2)
|
||||
|
||||
_, agent := setupWorkspaceAgent(t, firstClient, firstUser, 0)
|
||||
conn, err := secondClient.DialWorkspaceAgent(context.Background(), agent.ID, &codersdk.DialWorkspaceAgentOptions{
|
||||
BlockEndpoints: true,
|
||||
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Eventually(t, func() bool {
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancelFunc()
|
||||
_, err = conn.Ping(ctx)
|
||||
return err == nil
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
_ = conn.Close()
|
||||
})
|
||||
t.Run("ConnectAcrossMultipleTLS", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, pubsub := dbtestutil.NewDB(t)
|
||||
certificates := []tls.Certificate{testutil.GenerateTLSCertificate(t, "localhost")}
|
||||
firstClient := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
TLSCertificates: certificates,
|
||||
},
|
||||
})
|
||||
firstUser := coderdtest.CreateFirstUser(t, firstClient)
|
||||
coderdenttest.AddLicense(t, firstClient, coderdenttest.LicenseOptions{
|
||||
HighAvailability: true,
|
||||
})
|
||||
|
||||
secondClient := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
TLSCertificates: certificates,
|
||||
},
|
||||
})
|
||||
secondClient.SessionToken = firstClient.SessionToken
|
||||
replicas, err := secondClient.Replicas(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, replicas, 2)
|
||||
|
||||
_, agent := setupWorkspaceAgent(t, firstClient, firstUser, 0)
|
||||
conn, err := secondClient.DialWorkspaceAgent(context.Background(), agent.ID, &codersdk.DialWorkspaceAgentOptions{
|
||||
BlockEndpoints: true,
|
||||
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Eventually(t, func() bool {
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.IntervalSlow)
|
||||
defer cancelFunc()
|
||||
_, err = conn.Ping(ctx)
|
||||
return err == nil
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
_ = conn.Close()
|
||||
replicas, err = secondClient.Replicas(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, replicas, 2)
|
||||
for _, replica := range replicas {
|
||||
require.Empty(t, replica.Error)
|
||||
}
|
||||
})
|
||||
}
|

@@ -23,7 +23,7 @@ func TestTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -64,7 +64,7 @@ func TestTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		_, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -88,7 +88,7 @@ func TestTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		client1, _ := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -138,7 +138,7 @@ func TestTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		_, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -176,7 +176,7 @@ func TestTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		_, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -214,7 +214,7 @@ func TestTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
@@ -262,7 +262,7 @@ func TestTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		client1, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -318,7 +318,7 @@ func TestUpdateTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -361,7 +361,7 @@ func TestUpdateTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -422,7 +422,7 @@ func TestUpdateTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
@@ -447,7 +447,7 @@ func TestUpdateTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
@@ -472,7 +472,7 @@ func TestUpdateTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		_, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -498,7 +498,7 @@ func TestUpdateTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		client2, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -533,7 +533,7 @@ func TestUpdateTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		client2, user2 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -575,7 +575,7 @@ func TestUpdateTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
@@ -597,7 +597,7 @@ func TestUpdateTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		client1, user1 := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -662,7 +662,7 @@ func TestUpdateTemplateACL(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		client1, _ := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
@@ -2,6 +2,7 @@ package coderd_test

import (
	"context"
+	"crypto/tls"
	"fmt"
	"net/http"
	"testing"
@@ -9,7 +10,6 @@ import (
	"github.com/google/uuid"
	"github.com/stretchr/testify/require"

-	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"
	"github.com/coder/coder/agent"
	"github.com/coder/coder/coderd/coderdtest"
@@ -42,7 +42,7 @@ func TestBlockNonBrowser(t *testing.T) {
			BrowserOnly: true,
		})
		_, agent := setupWorkspaceAgent(t, client, user, 0)
-		_, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, agent.ID)
+		_, err := client.DialWorkspaceAgent(context.Background(), agent.ID, nil)
		var apiErr *codersdk.Error
		require.ErrorAs(t, err, &apiErr)
		require.Equal(t, http.StatusConflict, apiErr.StatusCode())
@@ -59,7 +59,7 @@ func TestBlockNonBrowser(t *testing.T) {
			BrowserOnly: false,
		})
		_, agent := setupWorkspaceAgent(t, client, user, 0)
-		conn, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, agent.ID)
+		conn, err := client.DialWorkspaceAgent(context.Background(), agent.ID, nil)
		require.NoError(t, err)
		_ = conn.Close()
	})
@@ -109,6 +109,14 @@ func setupWorkspaceAgent(t *testing.T, client *codersdk.Client, user codersdk.Cr
	workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
	coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
	agentClient := codersdk.New(client.URL)
+	agentClient.HTTPClient = &http.Client{
+		Transport: &http.Transport{
+			TLSClientConfig: &tls.Config{
+				//nolint:gosec
+				InsecureSkipVerify: true,
+			},
+		},
+	}
	agentClient.SessionToken = authToken
	agentCloser := agent.New(agent.Options{
		FetchMetadata: agentClient.WorkspaceAgentMetadata,
@@ -26,7 +26,7 @@ func TestCreateWorkspace(t *testing.T) {
		client := coderdenttest.New(t, nil)
		user := coderdtest.CreateFirstUser(t, client)
		_ = coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
-			TemplateRBACEnabled: true,
+			TemplateRBAC: true,
		})

		version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
@@ -0,0 +1,165 @@
package derpmesh

import (
	"context"
	"crypto/tls"
	"net"
	"net/url"
	"sync"

	"golang.org/x/xerrors"
	"tailscale.com/derp"
	"tailscale.com/derp/derphttp"
	"tailscale.com/types/key"

	"github.com/coder/coder/tailnet"

	"cdr.dev/slog"
)

// New constructs a new mesh for DERP servers.
func New(logger slog.Logger, server *derp.Server, tlsConfig *tls.Config) *Mesh {
	return &Mesh{
		logger:    logger,
		server:    server,
		tlsConfig: tlsConfig,
		ctx:       context.Background(),
		closed:    make(chan struct{}),
		active:    make(map[string]context.CancelFunc),
	}
}

type Mesh struct {
	logger    slog.Logger
	server    *derp.Server
	ctx       context.Context
	tlsConfig *tls.Config

	mutex  sync.Mutex
	closed chan struct{}
	active map[string]context.CancelFunc
}

// SetAddresses performs a diff of the incoming addresses and adds
// or removes DERP clients from the mesh.
//
// connect is only used in testing to ensure DERPs are meshed before
// exchanging messages.
// nolint:revive
func (m *Mesh) SetAddresses(addresses []string, connect bool) {
	total := make(map[string]struct{}, len(addresses))
	for _, address := range addresses {
		addressURL, err := url.Parse(address)
		if err != nil {
			m.logger.Error(m.ctx, "invalid address", slog.F("address", address), slog.Error(err))
			continue
		}
		derpURL, err := addressURL.Parse("/derp")
		if err != nil {
			m.logger.Error(m.ctx, "parse derp path", slog.F("address", address), slog.Error(err))
			continue
		}
		address = derpURL.String()

		total[address] = struct{}{}
		added, err := m.addAddress(address, connect)
		if err != nil {
			m.logger.Error(m.ctx, "failed to add address", slog.F("address", address), slog.Error(err))
			continue
		}
		if added {
			m.logger.Debug(m.ctx, "added mesh address", slog.F("address", address))
		}
	}

	m.mutex.Lock()
	for address := range m.active {
		_, found := total[address]
		if found {
			continue
		}
		removed := m.removeAddress(address)
		if removed {
			m.logger.Debug(m.ctx, "removed mesh address", slog.F("address", address))
		}
	}
	m.mutex.Unlock()
}

// addAddress begins meshing with a new address. It returns false if the
// address is already being meshed with. It's expected that this is a full
// HTTP address with a path, e.g. http://127.0.0.1:8080/derp.
// nolint:revive
func (m *Mesh) addAddress(address string, connect bool) (bool, error) {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	if m.isClosed() {
		return false, nil
	}
	_, isActive := m.active[address]
	if isActive {
		return false, nil
	}
	client, err := derphttp.NewClient(m.server.PrivateKey(), address, tailnet.Logger(m.logger.Named("client")))
	if err != nil {
		return false, xerrors.Errorf("create derp client: %w", err)
	}
	client.TLSConfig = m.tlsConfig
	client.MeshKey = m.server.MeshKey()
	client.SetURLDialer(func(ctx context.Context, network, addr string) (net.Conn, error) {
		var dialer net.Dialer
		return dialer.DialContext(ctx, network, addr)
	})
	if connect {
		_ = client.Connect(m.ctx)
	}
	ctx, cancelFunc := context.WithCancel(m.ctx)
	closed := make(chan struct{})
	closeFunc := func() {
		cancelFunc()
		_ = client.Close()
		<-closed
	}
	m.active[address] = closeFunc
	go func() {
		defer close(closed)
		client.RunWatchConnectionLoop(ctx, m.server.PublicKey(), tailnet.Logger(m.logger.Named("loop")), func(np key.NodePublic) {
			m.server.AddPacketForwarder(np, client)
		}, func(np key.NodePublic) {
			m.server.RemovePacketForwarder(np, client)
		})
	}()
	return true, nil
}

// removeAddress stops meshing with a given address.
func (m *Mesh) removeAddress(address string) bool {
	cancelFunc, isActive := m.active[address]
	if isActive {
		cancelFunc()
	}
	return isActive
}

// Close ends all active meshes with the DERP server.
func (m *Mesh) Close() error {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	if m.isClosed() {
		return nil
	}
	close(m.closed)
	for _, cancelFunc := range m.active {
		cancelFunc()
	}
	return nil
}

func (m *Mesh) isClosed() bool {
	select {
	case <-m.closed:
		return true
	default:
	}
	return false
}
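A minimal sketch of wiring two DERP servers into a mesh, mirroring how the tests that follow set things up. The mesh key value and the use of plain-HTTP test servers are illustrative; both servers are assumed to share the same mesh key so each accepts the other as a packet forwarder:

	// Sketch: mesh two DERP servers so packets forward between them.
	logf := tailnet.Logger(logger)
	first := derp.NewServer(key.NewNode(), logf)
	first.SetMeshKey("some-key") // illustrative; must match on every server
	second := derp.NewServer(key.NewNode(), logf)
	second.SetMeshKey("some-key")
	firstSrv := httptest.NewServer(derphttp.Handler(first))
	secondSrv := httptest.NewServer(derphttp.Handler(second))
	firstMesh := derpmesh.New(logger.Named("first"), first, nil) // nil TLS for plain HTTP
	firstMesh.SetAddresses([]string{secondSrv.URL}, false)
	secondMesh := derpmesh.New(logger.Named("second"), second, nil)
	secondMesh.SetAddresses([]string{firstSrv.URL}, false)
	defer firstMesh.Close()
	defer secondMesh.Close()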

@@ -0,0 +1,219 @@
package derpmesh_test

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"io"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"
	"tailscale.com/derp"
	"tailscale.com/derp/derphttp"
	"tailscale.com/types/key"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"
	"github.com/coder/coder/enterprise/derpmesh"
	"github.com/coder/coder/tailnet"
	"github.com/coder/coder/testutil"
)

func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}

func TestDERPMesh(t *testing.T) {
	t.Parallel()
	commonName := "something.org"
	rawCert := testutil.GenerateTLSCertificate(t, commonName)
	certificate, err := x509.ParseCertificate(rawCert.Certificate[0])
	require.NoError(t, err)
	pool := x509.NewCertPool()
	pool.AddCert(certificate)
	tlsConfig := &tls.Config{
		MinVersion:   tls.VersionTLS12,
		ServerName:   commonName,
		RootCAs:      pool,
		Certificates: []tls.Certificate{rawCert},
	}

	t.Run("ExchangeMessages", func(t *testing.T) {
		// This tests messages passing through multiple DERP servers.
		t.Parallel()
		firstServer, firstServerURL := startDERP(t, tlsConfig)
		defer firstServer.Close()
		secondServer, secondServerURL := startDERP(t, tlsConfig)
		firstMesh := derpmesh.New(slogtest.Make(t, nil).Named("first").Leveled(slog.LevelDebug), firstServer, tlsConfig)
		firstMesh.SetAddresses([]string{secondServerURL}, true)
		secondMesh := derpmesh.New(slogtest.Make(t, nil).Named("second").Leveled(slog.LevelDebug), secondServer, tlsConfig)
		secondMesh.SetAddresses([]string{firstServerURL}, true)
		defer firstMesh.Close()
		defer secondMesh.Close()

		first := key.NewNode()
		second := key.NewNode()
		firstClient, err := derphttp.NewClient(first, secondServerURL, tailnet.Logger(slogtest.Make(t, nil)))
		require.NoError(t, err)
		firstClient.TLSConfig = tlsConfig
		secondClient, err := derphttp.NewClient(second, firstServerURL, tailnet.Logger(slogtest.Make(t, nil)))
		require.NoError(t, err)
		secondClient.TLSConfig = tlsConfig
		err = secondClient.Connect(context.Background())
		require.NoError(t, err)

		closed := make(chan struct{})
		ctx, cancelFunc := context.WithCancel(context.Background())
		defer cancelFunc()
		sent := []byte("hello world")
		go func() {
			defer close(closed)
			ticker := time.NewTicker(50 * time.Millisecond)
			for {
				select {
				case <-ctx.Done():
					return
				case <-ticker.C:
				}
				err = firstClient.Send(second.Public(), sent)
				require.NoError(t, err)
			}
		}()

		got := recvData(t, secondClient)
		require.Equal(t, sent, got)
		cancelFunc()
		<-closed
	})
	t.Run("RemoveAddress", func(t *testing.T) {
		// This tests removing a meshed address, then verifies that
		// messages still pass through the remaining server.
		t.Parallel()
		server, serverURL := startDERP(t, tlsConfig)
		mesh := derpmesh.New(slogtest.Make(t, nil).Named("first").Leveled(slog.LevelDebug), server, tlsConfig)
		mesh.SetAddresses([]string{"http://fake.com"}, false)
		// This should trigger a removal...
		mesh.SetAddresses([]string{}, false)
		defer mesh.Close()

		first := key.NewNode()
		second := key.NewNode()
		firstClient, err := derphttp.NewClient(first, serverURL, tailnet.Logger(slogtest.Make(t, nil)))
		require.NoError(t, err)
		firstClient.TLSConfig = tlsConfig
		secondClient, err := derphttp.NewClient(second, serverURL, tailnet.Logger(slogtest.Make(t, nil)))
		require.NoError(t, err)
		secondClient.TLSConfig = tlsConfig
		err = secondClient.Connect(context.Background())
		require.NoError(t, err)

		closed := make(chan struct{})
		ctx, cancelFunc := context.WithCancel(context.Background())
		defer cancelFunc()
		sent := []byte("hello world")
		go func() {
			defer close(closed)
			ticker := time.NewTicker(50 * time.Millisecond)
			for {
				select {
				case <-ctx.Done():
					return
				case <-ticker.C:
				}
				err = firstClient.Send(second.Public(), sent)
				require.NoError(t, err)
			}
		}()
		got := recvData(t, secondClient)
		require.Equal(t, sent, got)
		cancelFunc()
		<-closed
	})
	t.Run("TwentyMeshes", func(t *testing.T) {
		t.Parallel()
		meshes := make([]*derpmesh.Mesh, 0, 20)
		serverURLs := make([]string, 0, 20)
		for i := 0; i < 20; i++ {
			server, url := startDERP(t, tlsConfig)
			mesh := derpmesh.New(slogtest.Make(t, nil).Named("mesh").Leveled(slog.LevelDebug), server, tlsConfig)
			t.Cleanup(func() {
				_ = server.Close()
				_ = mesh.Close()
			})
			serverURLs = append(serverURLs, url)
			meshes = append(meshes, mesh)
		}
		for _, mesh := range meshes {
			mesh.SetAddresses(serverURLs, true)
		}

		first := key.NewNode()
		second := key.NewNode()
		firstClient, err := derphttp.NewClient(first, serverURLs[9], tailnet.Logger(slogtest.Make(t, nil)))
		require.NoError(t, err)
		firstClient.TLSConfig = tlsConfig
		secondClient, err := derphttp.NewClient(second, serverURLs[16], tailnet.Logger(slogtest.Make(t, nil)))
		require.NoError(t, err)
		secondClient.TLSConfig = tlsConfig
		err = secondClient.Connect(context.Background())
		require.NoError(t, err)

		closed := make(chan struct{})
		ctx, cancelFunc := context.WithCancel(context.Background())
		defer cancelFunc()
		sent := []byte("hello world")
		go func() {
			defer close(closed)
			ticker := time.NewTicker(50 * time.Millisecond)
			for {
				select {
				case <-ctx.Done():
					return
				case <-ticker.C:
				}
				err = firstClient.Send(second.Public(), sent)
				require.NoError(t, err)
			}
		}()

		got := recvData(t, secondClient)
		require.Equal(t, sent, got)
		cancelFunc()
		<-closed
	})
}

func recvData(t *testing.T, client *derphttp.Client) []byte {
	for {
		msg, err := client.Recv()
		if errors.Is(err, io.EOF) {
			return nil
		}
		assert.NoError(t, err)
		t.Logf("derp: %T", msg)
		switch msg := msg.(type) {
		case derp.ReceivedPacket:
			return msg.Data
		default:
			// Drop all others!
		}
	}
}

func startDERP(t *testing.T, tlsConfig *tls.Config) (*derp.Server, string) {
	logf := tailnet.Logger(slogtest.Make(t, nil))
	d := derp.NewServer(key.NewNode(), logf)
	d.SetMeshKey("some-key")
	server := httptest.NewUnstartedServer(derphttp.Handler(d))
	server.TLS = tlsConfig
	server.StartTLS()
	t.Cleanup(func() {
		_ = d.Close()
	})
	t.Cleanup(server.Close)
	return d, server.URL
}
@@ -0,0 +1,391 @@
package replicasync

import (
	"context"
	"crypto/tls"
	"database/sql"
	"errors"
	"fmt"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
	"golang.org/x/xerrors"

	"cdr.dev/slog"

	"github.com/coder/coder/buildinfo"
	"github.com/coder/coder/coderd/database"
)

var (
	PubsubEvent = "replica"
)

type Options struct {
	CleanupInterval time.Duration
	UpdateInterval  time.Duration
	PeerTimeout     time.Duration
	RelayAddress    string
	RegionID        int32
	TLSConfig       *tls.Config
}

// New registers the replica with the database and periodically updates it to
// ensure it's healthy. It contacts all other alive replicas to ensure they
// are reachable.
func New(ctx context.Context, logger slog.Logger, db database.Store, pubsub database.Pubsub, options *Options) (*Manager, error) {
	if options == nil {
		options = &Options{}
	}
	if options.PeerTimeout == 0 {
		options.PeerTimeout = 3 * time.Second
	}
	if options.UpdateInterval == 0 {
		options.UpdateInterval = 5 * time.Second
	}
	if options.CleanupInterval == 0 {
		// The cleanup interval can be quite long, because its
		// primary purpose is to clean up dead replicas.
		options.CleanupInterval = 30 * time.Minute
	}
	hostname, err := os.Hostname()
	if err != nil {
		return nil, xerrors.Errorf("get hostname: %w", err)
	}
	databaseLatency, err := db.Ping(ctx)
	if err != nil {
		return nil, xerrors.Errorf("ping database: %w", err)
	}
	id := uuid.New()
	replica, err := db.InsertReplica(ctx, database.InsertReplicaParams{
		ID:              id,
		CreatedAt:       database.Now(),
		StartedAt:       database.Now(),
		UpdatedAt:       database.Now(),
		Hostname:        hostname,
		RegionID:        options.RegionID,
		RelayAddress:    options.RelayAddress,
		Version:         buildinfo.Version(),
		DatabaseLatency: int32(databaseLatency.Microseconds()),
	})
	if err != nil {
		return nil, xerrors.Errorf("insert replica: %w", err)
	}
	err = pubsub.Publish(PubsubEvent, []byte(id.String()))
	if err != nil {
		return nil, xerrors.Errorf("publish new replica: %w", err)
	}
	ctx, cancelFunc := context.WithCancel(ctx)
	manager := &Manager{
		id:          id,
		options:     options,
		db:          db,
		pubsub:      pubsub,
		self:        replica,
		logger:      logger,
		closed:      make(chan struct{}),
		closeCancel: cancelFunc,
	}
	err = manager.syncReplicas(ctx)
	if err != nil {
		return nil, xerrors.Errorf("run replica: %w", err)
	}
	peers := manager.Regional()
	if len(peers) > 0 {
		self := manager.Self()
		if self.RelayAddress == "" {
			return nil, xerrors.Errorf("a relay address must be specified when running multiple replicas in the same region")
		}
	}

	err = manager.subscribe(ctx)
	if err != nil {
		return nil, xerrors.Errorf("subscribe: %w", err)
	}
	manager.closeWait.Add(1)
	go manager.loop(ctx)
	return manager, nil
}

// Manager keeps the replica up to date and in sync with other replicas.
type Manager struct {
	id      uuid.UUID
	options *Options
	db      database.Store
	pubsub  database.Pubsub
	logger  slog.Logger

	closeWait   sync.WaitGroup
	closeMutex  sync.Mutex
	closed      chan struct{}
	closeCancel context.CancelFunc

	self     database.Replica
	mutex    sync.Mutex
	peers    []database.Replica
	callback func()
}

// updateInterval returns the cutoff time used to determine a replica's state.
// A replica updated after this time is considered healthy; one updated before
// it is considered stale. The cutoff is three update intervals in the past,
// e.g. 15 seconds with the default 5-second update interval.
func (m *Manager) updateInterval() time.Time {
	return database.Now().Add(-3 * m.options.UpdateInterval)
}

// loop runs the replica update sequence on an update interval.
func (m *Manager) loop(ctx context.Context) {
	defer m.closeWait.Done()
	updateTicker := time.NewTicker(m.options.UpdateInterval)
	defer updateTicker.Stop()
	deleteTicker := time.NewTicker(m.options.CleanupInterval)
	defer deleteTicker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-deleteTicker.C:
			err := m.db.DeleteReplicasUpdatedBefore(ctx, m.updateInterval())
			if err != nil {
				m.logger.Warn(ctx, "delete old replicas", slog.Error(err))
			}
			continue
		case <-updateTicker.C:
		}
		err := m.syncReplicas(ctx)
		if err != nil && !errors.Is(err, context.Canceled) {
			m.logger.Warn(ctx, "run replica update loop", slog.Error(err))
		}
	}
}

// subscribe listens for new replica information!
func (m *Manager) subscribe(ctx context.Context) error {
	var (
		needsUpdate = false
		updating    = false
		updateMutex = sync.Mutex{}
	)

	// This loop will continually update nodes as updates are processed.
	// The intent is to always be up to date without spamming the run
	// function, so if a new update comes in while one is being processed,
	// it will reprocess afterwards.
	var update func()
	update = func() {
		err := m.syncReplicas(ctx)
		if err != nil && !errors.Is(err, context.Canceled) {
			m.logger.Warn(ctx, "run replica from subscribe", slog.Error(err))
		}
		updateMutex.Lock()
		if needsUpdate {
			needsUpdate = false
			updateMutex.Unlock()
			update()
			return
		}
		updating = false
		updateMutex.Unlock()
	}
	cancelFunc, err := m.pubsub.Subscribe(PubsubEvent, func(ctx context.Context, message []byte) {
		updateMutex.Lock()
		defer updateMutex.Unlock()
		id, err := uuid.Parse(string(message))
		if err != nil {
			return
		}
		// Don't process updates for ourself!
		if id == m.id {
			return
		}
		if updating {
			needsUpdate = true
			return
		}
		updating = true
		go update()
	})
	if err != nil {
		return err
	}
	go func() {
		<-ctx.Done()
		cancelFunc()
	}()
	return nil
}

func (m *Manager) syncReplicas(ctx context.Context) error {
	m.closeMutex.Lock()
	m.closeWait.Add(1)
	m.closeMutex.Unlock()
	defer m.closeWait.Done()
	// Expect replicas to update once every three times the interval...
	// If they don't, assume death!
	replicas, err := m.db.GetReplicasUpdatedAfter(ctx, m.updateInterval())
	if err != nil {
		return xerrors.Errorf("get replicas: %w", err)
	}

	m.mutex.Lock()
	m.peers = make([]database.Replica, 0, len(replicas))
	for _, replica := range replicas {
		if replica.ID == m.id {
			continue
		}
		m.peers = append(m.peers, replica)
	}
	m.mutex.Unlock()

	client := http.Client{
		Timeout: m.options.PeerTimeout,
		Transport: &http.Transport{
			TLSClientConfig: m.options.TLSConfig,
		},
	}
	defer client.CloseIdleConnections()
	var wg sync.WaitGroup
	var mu sync.Mutex
	failed := make([]string, 0)
	for _, peer := range m.Regional() {
		wg.Add(1)
		go func(peer database.Replica) {
			defer wg.Done()
			req, err := http.NewRequestWithContext(ctx, http.MethodGet, peer.RelayAddress, nil)
			if err != nil {
				m.logger.Warn(ctx, "create http request for relay probe",
					slog.F("relay_address", peer.RelayAddress), slog.Error(err))
				return
			}
			res, err := client.Do(req)
			if err != nil {
				mu.Lock()
				failed = append(failed, fmt.Sprintf("relay %s (%s): %s", peer.Hostname, peer.RelayAddress, err))
				mu.Unlock()
				return
			}
			_ = res.Body.Close()
		}(peer)
	}
	wg.Wait()
	replicaError := ""
	if len(failed) > 0 {
		replicaError = fmt.Sprintf("Failed to dial peers: %s", strings.Join(failed, ", "))
	}

	databaseLatency, err := m.db.Ping(ctx)
	if err != nil {
		return xerrors.Errorf("ping database: %w", err)
	}

	replica, err := m.db.UpdateReplica(ctx, database.UpdateReplicaParams{
		ID:              m.self.ID,
		UpdatedAt:       database.Now(),
		StartedAt:       m.self.StartedAt,
		StoppedAt:       m.self.StoppedAt,
		RelayAddress:    m.self.RelayAddress,
		RegionID:        m.self.RegionID,
		Hostname:        m.self.Hostname,
		Version:         m.self.Version,
		Error:           replicaError,
		DatabaseLatency: int32(databaseLatency.Microseconds()),
	})
	if err != nil {
		return xerrors.Errorf("update replica: %w", err)
	}
	m.mutex.Lock()
	defer m.mutex.Unlock()
	if m.self.Error != replica.Error {
		// Publish an update occurred!
		err = m.pubsub.Publish(PubsubEvent, []byte(m.self.ID.String()))
		if err != nil {
			return xerrors.Errorf("publish replica update: %w", err)
		}
	}
	m.self = replica
	if m.callback != nil {
		go m.callback()
	}
	return nil
}

// Self represents the current replica.
func (m *Manager) Self() database.Replica {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	return m.self
}

// All returns every replica, including itself.
func (m *Manager) All() []database.Replica {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	return append(m.peers[:], m.self)
}

// Regional returns all replicas in the same region excluding itself.
func (m *Manager) Regional() []database.Replica {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	replicas := make([]database.Replica, 0)
	for _, replica := range m.peers {
		if replica.RegionID != m.self.RegionID {
			continue
		}
		replicas = append(replicas, replica)
	}
	return replicas
}

// SetCallback sets a function to execute whenever new peers
// are refreshed or updated.
func (m *Manager) SetCallback(callback func()) {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	m.callback = callback
	// Instantly call the callback to inform replicas!
	go callback()
}

func (m *Manager) Close() error {
	m.closeMutex.Lock()
	select {
	case <-m.closed:
		m.closeMutex.Unlock()
		return nil
	default:
	}
	close(m.closed)
	m.closeCancel()
	m.closeWait.Wait()
	m.closeMutex.Unlock()
	m.mutex.Lock()
	defer m.mutex.Unlock()
	ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancelFunc()
	_, err := m.db.UpdateReplica(ctx, database.UpdateReplicaParams{
		ID:        m.self.ID,
		UpdatedAt: database.Now(),
		StartedAt: m.self.StartedAt,
		StoppedAt: sql.NullTime{
			Time:  database.Now(),
			Valid: true,
		},
		RelayAddress: m.self.RelayAddress,
		RegionID:     m.self.RegionID,
		Hostname:     m.self.Hostname,
		Version:      m.self.Version,
		Error:        m.self.Error,
	})
	if err != nil {
		return xerrors.Errorf("update replica: %w", err)
	}
	err = m.pubsub.Publish(PubsubEvent, []byte(m.self.ID.String()))
	if err != nil {
		return xerrors.Errorf("publish replica update: %w", err)
	}
	return nil
}
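A minimal sketch of how a coderd replica might start synchronization using the API above. The relay address and region are illustrative values that would normally come from server flags; the db and pubsub are assumed to come from server startup:

	// Sketch: register this replica and watch its regional peers.
	manager, err := replicasync.New(ctx, logger, db, pubsub, &replicasync.Options{
		RelayAddress: "https://coder-0.internal", // illustrative
		RegionID:     0,
	})
	if err != nil {
		return err
	}
	// Close marks the replica stopped and publishes the change to peers.
	defer manager.Close()
	manager.SetCallback(func() {
		for _, peer := range manager.Regional() {
			logger.Info(ctx, "peer updated", slog.F("hostname", peer.Hostname))
		}
	})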

@@ -0,0 +1,239 @@
package replicasync_test

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"

	"cdr.dev/slog/sloggers/slogtest"
	"github.com/coder/coder/coderd/database"
	"github.com/coder/coder/coderd/database/databasefake"
	"github.com/coder/coder/coderd/database/dbtestutil"
	"github.com/coder/coder/enterprise/replicasync"
	"github.com/coder/coder/testutil"
)

func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}

func TestReplica(t *testing.T) {
	t.Parallel()
	t.Run("CreateOnNew", func(t *testing.T) {
		// This ensures that a new replica is created on New.
		t.Parallel()
		db, pubsub := dbtestutil.NewDB(t)
		closeChan := make(chan struct{}, 1)
		cancel, err := pubsub.Subscribe(replicasync.PubsubEvent, func(ctx context.Context, message []byte) {
			closeChan <- struct{}{}
		})
		require.NoError(t, err)
		defer cancel()
		server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, nil)
		require.NoError(t, err)
		<-closeChan
		_ = server.Close()
		require.NoError(t, err)
	})
	t.Run("ErrorsWithoutRelayAddress", func(t *testing.T) {
		// Ensures that New fails when a peer replica already exists in
		// the same region but no relay address is configured.
		t.Parallel()
		db, pubsub := dbtestutil.NewDB(t)
		_, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
			ID:        uuid.New(),
			CreatedAt: database.Now(),
			StartedAt: database.Now(),
			UpdatedAt: database.Now(),
			Hostname:  "something",
		})
		require.NoError(t, err)
		_, err = replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, nil)
		require.Error(t, err)
		require.Equal(t, "a relay address must be specified when running multiple replicas in the same region", err.Error())
	})
	t.Run("ConnectsToPeerReplica", func(t *testing.T) {
		// Ensures that the replica reports a successful status for
		// accessing all of its peers.
		t.Parallel()
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
		}))
		defer srv.Close()
		db, pubsub := dbtestutil.NewDB(t)
		peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
			ID:           uuid.New(),
			CreatedAt:    database.Now(),
			StartedAt:    database.Now(),
			UpdatedAt:    database.Now(),
			Hostname:     "something",
			RelayAddress: srv.URL,
		})
		require.NoError(t, err)
		server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{
			RelayAddress: "http://169.254.169.254",
		})
		require.NoError(t, err)
		require.Len(t, server.Regional(), 1)
		require.Equal(t, peer.ID, server.Regional()[0].ID)
		require.Empty(t, server.Self().Error)
		_ = server.Close()
	})
	t.Run("ConnectsToPeerReplicaTLS", func(t *testing.T) {
		// Ensures that the replica reports a successful status for
		// accessing all of its peers over TLS.
		t.Parallel()
		rawCert := testutil.GenerateTLSCertificate(t, "hello.org")
		certificate, err := x509.ParseCertificate(rawCert.Certificate[0])
		require.NoError(t, err)
		pool := x509.NewCertPool()
		pool.AddCert(certificate)
		// nolint:gosec
		tlsConfig := &tls.Config{
			Certificates: []tls.Certificate{rawCert},
			ServerName:   "hello.org",
			RootCAs:      pool,
		}
		srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
		}))
		srv.TLS = tlsConfig
		srv.StartTLS()
		defer srv.Close()
		db, pubsub := dbtestutil.NewDB(t)
		peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
			ID:           uuid.New(),
			CreatedAt:    database.Now(),
			StartedAt:    database.Now(),
			UpdatedAt:    database.Now(),
			Hostname:     "something",
			RelayAddress: srv.URL,
		})
		require.NoError(t, err)
		server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{
			RelayAddress: "http://169.254.169.254",
			TLSConfig:    tlsConfig,
		})
		require.NoError(t, err)
		require.Len(t, server.Regional(), 1)
		require.Equal(t, peer.ID, server.Regional()[0].ID)
		require.Empty(t, server.Self().Error)
		_ = server.Close()
	})
	t.Run("ConnectsToFakePeerWithError", func(t *testing.T) {
		t.Parallel()
		db, pubsub := dbtestutil.NewDB(t)
		peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
			ID:        uuid.New(),
			CreatedAt: database.Now().Add(time.Minute),
			StartedAt: database.Now().Add(time.Minute),
			UpdatedAt: database.Now().Add(time.Minute),
			Hostname:  "something",
			// Fake address to dial!
			RelayAddress: "http://127.0.0.1:1",
		})
		require.NoError(t, err)
		server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{
			PeerTimeout:  1 * time.Millisecond,
			RelayAddress: "http://127.0.0.1:1",
		})
		require.NoError(t, err)
		require.Len(t, server.Regional(), 1)
		require.Equal(t, peer.ID, server.Regional()[0].ID)
		require.NotEmpty(t, server.Self().Error)
		require.Contains(t, server.Self().Error, "Failed to dial peers")
		_ = server.Close()
	})
	t.Run("RefreshOnPublish", func(t *testing.T) {
		// Refresh when a new replica appears!
		t.Parallel()
		db, pubsub := dbtestutil.NewDB(t)
		server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, nil)
		require.NoError(t, err)
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
		}))
		defer srv.Close()
		peer, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
			ID:           uuid.New(),
			RelayAddress: srv.URL,
			UpdatedAt:    database.Now(),
		})
		require.NoError(t, err)
		// Publish multiple times to ensure it can handle that case.
		err = pubsub.Publish(replicasync.PubsubEvent, []byte(peer.ID.String()))
		require.NoError(t, err)
		err = pubsub.Publish(replicasync.PubsubEvent, []byte(peer.ID.String()))
		require.NoError(t, err)
		require.Eventually(t, func() bool {
			return len(server.Regional()) == 1
		}, testutil.WaitShort, testutil.IntervalFast)
		_ = server.Close()
	})
	t.Run("DeletesOld", func(t *testing.T) {
		t.Parallel()
		db, pubsub := dbtestutil.NewDB(t)
		_, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{
			ID:        uuid.New(),
			UpdatedAt: database.Now().Add(-time.Hour),
		})
		require.NoError(t, err)
		server, err := replicasync.New(context.Background(), slogtest.Make(t, nil), db, pubsub, &replicasync.Options{
			RelayAddress:    "google.com",
			CleanupInterval: time.Millisecond,
		})
		require.NoError(t, err)
		defer server.Close()
		require.Eventually(t, func() bool {
			return len(server.Regional()) == 0
		}, testutil.WaitShort, testutil.IntervalFast)
	})
	t.Run("TwentyConcurrent", func(t *testing.T) {
		// Ensures that twenty concurrent replicas can spawn and all
		// discover each other in parallel!
		t.Parallel()
		// This doesn't use a real PostgreSQL database because creating
		// this many connections takes some configuration tweaking.
		db := databasefake.New()
		pubsub := database.NewPubsubInMemory()
		logger := slogtest.Make(t, nil)
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
		}))
		defer srv.Close()
		var wg sync.WaitGroup
		count := 20
		wg.Add(count)
		for i := 0; i < count; i++ {
			server, err := replicasync.New(context.Background(), logger, db, pubsub, &replicasync.Options{
				RelayAddress: srv.URL,
			})
			require.NoError(t, err)
			t.Cleanup(func() {
				_ = server.Close()
			})
			done := false
			server.SetCallback(func() {
				if len(server.All()) != count {
					return
				}
				if done {
					return
				}
				done = true
				wg.Done()
			})
		}
		wg.Wait()
	})
}
@@ -0,0 +1,575 @@
package tailnet

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"io"
	"net"
	"sync"
	"time"

	"github.com/google/uuid"
	"golang.org/x/xerrors"

	"cdr.dev/slog"
	"github.com/coder/coder/coderd/database"
	agpl "github.com/coder/coder/tailnet"
)

// NewCoordinator creates a new high availability coordinator
// that uses PostgreSQL pubsub to exchange handshakes.
func NewCoordinator(logger slog.Logger, pubsub database.Pubsub) (agpl.Coordinator, error) {
	ctx, cancelFunc := context.WithCancel(context.Background())
	coord := &haCoordinator{
		id:                       uuid.New(),
		log:                      logger,
		pubsub:                   pubsub,
		closeFunc:                cancelFunc,
		close:                    make(chan struct{}),
		nodes:                    map[uuid.UUID]*agpl.Node{},
		agentSockets:             map[uuid.UUID]net.Conn{},
		agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
	}

	if err := coord.runPubsub(ctx); err != nil {
		return nil, xerrors.Errorf("run coordinator pubsub: %w", err)
	}

	return coord, nil
}
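// Illustrative usage sketch (not part of this file): this coordinator
// satisfies the same agpl.Coordinator interface as the in-memory AGPL
// implementation, so a server can swap it in whenever a PostgreSQL pubsub
// is available. The connection and ID names below are placeholders:
//
//	coordinator, err := tailnet.NewCoordinator(logger, pubsub)
//	if err != nil {
//		return err
//	}
//	defer coordinator.Close()
//	go func() { _ = coordinator.ServeAgent(agentConn, agentID) }()
//	_ = coordinator.ServeClient(clientConn, connectionID, agentID)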

type haCoordinator struct {
	id        uuid.UUID
	log       slog.Logger
	mutex     sync.RWMutex
	pubsub    database.Pubsub
	close     chan struct{}
	closeFunc context.CancelFunc

	// nodes maps agent and connection IDs to their respective node.
	nodes map[uuid.UUID]*agpl.Node
	// agentSockets maps agent IDs to their open websocket.
	agentSockets map[uuid.UUID]net.Conn
	// agentToConnectionSockets maps agent IDs to connection IDs of conns that
	// are subscribed to updates for that agent.
	agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn
}

// Node returns an in-memory node by ID.
func (c *haCoordinator) Node(id uuid.UUID) *agpl.Node {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	node := c.nodes[id]
	return node
}

// ServeClient accepts a WebSocket connection that wants to connect to an agent
// with the specified ID.
func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
	c.mutex.Lock()
	// When a new connection is requested, we update it with the latest
	// node of the agent. This allows the connection to establish.
	node, ok := c.nodes[agent]
	c.mutex.Unlock()
	if ok {
		data, err := json.Marshal([]*agpl.Node{node})
		if err != nil {
			return xerrors.Errorf("marshal node: %w", err)
		}
		_, err = conn.Write(data)
		if err != nil {
			return xerrors.Errorf("write nodes: %w", err)
		}
	} else {
		err := c.publishClientHello(agent)
		if err != nil {
			return xerrors.Errorf("publish client hello: %w", err)
		}
	}

	c.mutex.Lock()
	connectionSockets, ok := c.agentToConnectionSockets[agent]
	if !ok {
		connectionSockets = map[uuid.UUID]net.Conn{}
		c.agentToConnectionSockets[agent] = connectionSockets
	}

	// Insert this connection into a map so the agent can publish node updates.
	connectionSockets[id] = conn
	c.mutex.Unlock()

	defer func() {
		c.mutex.Lock()
		defer c.mutex.Unlock()
		// Clean all traces of this connection from the map.
		delete(c.nodes, id)
		connectionSockets, ok := c.agentToConnectionSockets[agent]
		if !ok {
			return
		}
		delete(connectionSockets, id)
		if len(connectionSockets) != 0 {
			return
		}
		delete(c.agentToConnectionSockets, agent)
	}()

	decoder := json.NewDecoder(conn)
	// Indefinitely handle messages from the client websocket.
	for {
		err := c.handleNextClientMessage(id, agent, decoder)
		if err != nil {
			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
				return nil
			}
			return xerrors.Errorf("handle next client message: %w", err)
		}
	}
}

func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
	var node agpl.Node
	err := decoder.Decode(&node)
	if err != nil {
		return xerrors.Errorf("read json: %w", err)
	}

	c.mutex.Lock()
	// Update the node of this client in our in-memory map. If an agent entirely
	// shuts down and reconnects, it needs to be aware of all clients attempting
	// to establish connections.
	c.nodes[id] = &node
	// Write the new node from this client to the actively connected agent.
	agentSocket, ok := c.agentSockets[agent]
	c.mutex.Unlock()
	if !ok {
		// If we don't own the agent locally, send it over pubsub to a node that
		// owns the agent.
		err := c.publishNodesToAgent(agent, []*agpl.Node{&node})
		if err != nil {
			return xerrors.Errorf("publish node to agent: %w", err)
		}
		return nil
	}

	// Write the new node from this client to the actively
	// connected agent.
	data, err := json.Marshal([]*agpl.Node{&node})
	if err != nil {
		return xerrors.Errorf("marshal nodes: %w", err)
	}

	_, err = agentSocket.Write(data)
	if err != nil {
		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
			return nil
		}
		return xerrors.Errorf("write json: %w", err)
	}

	return nil
}

// ServeAgent accepts a WebSocket connection to an agent that listens to
// incoming connections and publishes node updates.
func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
	// Tell clients on other instances to send a callmemaybe to us.
	err := c.publishAgentHello(id)
	if err != nil {
		return xerrors.Errorf("publish agent hello: %w", err)
	}

	// Publish all nodes on this instance that want to connect to this agent.
	nodes := c.nodesSubscribedToAgent(id)
	if len(nodes) > 0 {
		data, err := json.Marshal(nodes)
		if err != nil {
			return xerrors.Errorf("marshal json: %w", err)
		}
		_, err = conn.Write(data)
		if err != nil {
			return xerrors.Errorf("write nodes: %w", err)
		}
	}

	// If an old agent socket is connected, we close it
	// to avoid any leaks. This shouldn't ever occur because
	// we expect one agent to be running.
	c.mutex.Lock()
	oldAgentSocket, ok := c.agentSockets[id]
	if ok {
		_ = oldAgentSocket.Close()
	}
	c.agentSockets[id] = conn
	c.mutex.Unlock()
	defer func() {
		c.mutex.Lock()
		defer c.mutex.Unlock()
		delete(c.agentSockets, id)
		delete(c.nodes, id)
	}()

	decoder := json.NewDecoder(conn)
	for {
		node, err := c.handleAgentUpdate(id, decoder)
		if err != nil {
			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
				return nil
			}
			return xerrors.Errorf("handle next agent message: %w", err)
		}

		err = c.publishAgentToNodes(id, node)
		if err != nil {
			return xerrors.Errorf("publish agent to nodes: %w", err)
		}
	}
}

func (c *haCoordinator) nodesSubscribedToAgent(agentID uuid.UUID) []*agpl.Node {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	sockets, ok := c.agentToConnectionSockets[agentID]
	if !ok {
		return nil
	}

	nodes := make([]*agpl.Node, 0, len(sockets))
	for targetID := range sockets {
		node, ok := c.nodes[targetID]
		if !ok {
			continue
		}
		nodes = append(nodes, node)
	}

	return nodes
}

func (c *haCoordinator) handleClientHello(id uuid.UUID) error {
	c.mutex.Lock()
	node, ok := c.nodes[id]
	c.mutex.Unlock()
	if !ok {
		return nil
	}
	return c.publishAgentToNodes(id, node)
}

func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (*agpl.Node, error) {
	var node agpl.Node
	err := decoder.Decode(&node)
	if err != nil {
		return nil, xerrors.Errorf("read json: %w", err)
	}

	c.mutex.Lock()
	oldNode := c.nodes[id]
	if oldNode != nil {
		if oldNode.AsOf.After(node.AsOf) {
			c.mutex.Unlock()
			return oldNode, nil
		}
	}
	c.nodes[id] = &node
	connectionSockets, ok := c.agentToConnectionSockets[id]
	if !ok {
		c.mutex.Unlock()
		return &node, nil
	}

	data, err := json.Marshal([]*agpl.Node{&node})
	if err != nil {
		c.mutex.Unlock()
		return nil, xerrors.Errorf("marshal nodes: %w", err)
	}

	// Publish the new node to every listening socket.
	var wg sync.WaitGroup
	wg.Add(len(connectionSockets))
	for _, connectionSocket := range connectionSockets {
		connectionSocket := connectionSocket
		go func() {
			defer wg.Done()
			_ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second))
			_, _ = connectionSocket.Write(data)
		}()
	}
	c.mutex.Unlock()
	wg.Wait()
	return &node, nil
}

// Close closes all of the open connections in the coordinator and stops the
// coordinator from accepting new connections.
func (c *haCoordinator) Close() error {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	select {
	case <-c.close:
		return nil
	default:
	}
	close(c.close)
	c.closeFunc()

	wg := sync.WaitGroup{}

	wg.Add(len(c.agentSockets))
	for _, socket := range c.agentSockets {
		socket := socket
		go func() {
			_ = socket.Close()
			wg.Done()
		}()
	}

	for _, connMap := range c.agentToConnectionSockets {
		wg.Add(len(connMap))
		for _, socket := range connMap {
			socket := socket
			go func() {
				_ = socket.Close()
				wg.Done()
			}()
		}
	}

	wg.Wait()
	return nil
}

func (c *haCoordinator) publishNodesToAgent(recipient uuid.UUID, nodes []*agpl.Node) error {
	msg, err := c.formatCallMeMaybe(recipient, nodes)
	if err != nil {
		return xerrors.Errorf("format publish message: %w", err)
	}

	err = c.pubsub.Publish("wireguard_peers", msg)
	if err != nil {
		return xerrors.Errorf("publish message: %w", err)
	}

	return nil
}

func (c *haCoordinator) publishAgentHello(id uuid.UUID) error {
	msg, err := c.formatAgentHello(id)
	if err != nil {
		return xerrors.Errorf("format publish message: %w", err)
	}

	err = c.pubsub.Publish("wireguard_peers", msg)
	if err != nil {
		return xerrors.Errorf("publish message: %w", err)
	}

	return nil
}

func (c *haCoordinator) publishClientHello(id uuid.UUID) error {
	msg, err := c.formatClientHello(id)
	if err != nil {
		return xerrors.Errorf("format client hello: %w", err)
	}
	err = c.pubsub.Publish("wireguard_peers", msg)
	if err != nil {
		return xerrors.Errorf("publish client hello: %w", err)
	}
	return nil
}

func (c *haCoordinator) publishAgentToNodes(id uuid.UUID, node *agpl.Node) error {
	msg, err := c.formatAgentUpdate(id, node)
	if err != nil {
		return xerrors.Errorf("format publish message: %w", err)
	}

	err = c.pubsub.Publish("wireguard_peers", msg)
	if err != nil {
		return xerrors.Errorf("publish message: %w", err)
	}

	return nil
}

func (c *haCoordinator) runPubsub(ctx context.Context) error {
	messageQueue := make(chan []byte, 64)
	cancelSub, err := c.pubsub.Subscribe("wireguard_peers", func(ctx context.Context, message []byte) {
		select {
		case messageQueue <- message:
		case <-ctx.Done():
			return
		}
	})
	if err != nil {
		return xerrors.Errorf("subscribe wireguard peers: %w", err)
	}
	go func() {
		for {
			var message []byte
			select {
			case <-ctx.Done():
				return
			case message = <-messageQueue:
			}
			c.handlePubsubMessage(ctx, message)
		}
	}()

	go func() {
		defer cancelSub()
		<-c.close
	}()

	return nil
}
|
||||
|
||||
func (c *haCoordinator) handlePubsubMessage(ctx context.Context, message []byte) {
|
||||
sp := bytes.Split(message, []byte("|"))
|
||||
if len(sp) != 4 {
|
||||
c.log.Error(ctx, "invalid wireguard peer message", slog.F("msg", string(message)))
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
coordinatorID = sp[0]
|
||||
eventType = sp[1]
|
||||
agentID = sp[2]
|
||||
nodeJSON = sp[3]
|
||||
)
|
||||
|
||||
sender, err := uuid.ParseBytes(coordinatorID)
|
||||
if err != nil {
|
||||
c.log.Error(ctx, "invalid sender id", slog.F("id", string(coordinatorID)), slog.F("msg", string(message)))
|
||||
return
|
||||
}
|
||||
|
||||
// We sent this message!
|
||||
if sender == c.id {
|
||||
return
|
||||
}
|
||||
|
||||
switch string(eventType) {
|
||||
case "callmemaybe":
|
||||
agentUUID, err := uuid.ParseBytes(agentID)
|
||||
if err != nil {
|
||||
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
|
||||
return
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
agentSocket, ok := c.agentSockets[agentUUID]
|
||||
if !ok {
|
||||
c.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
|
||||
// We get a single node over pubsub, so turn into an array.
|
||||
_, err = agentSocket.Write(nodeJSON)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
|
||||
return
|
||||
}
|
||||
c.log.Error(ctx, "send callmemaybe to agent", slog.Error(err))
|
||||
return
|
||||
}
|
||||
case "clienthello":
|
||||
agentUUID, err := uuid.ParseBytes(agentID)
|
||||
if err != nil {
|
||||
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
|
||||
return
|
||||
}
|
||||
|
||||
err = c.handleClientHello(agentUUID)
|
||||
if err != nil {
|
||||
c.log.Error(ctx, "handle agent request node", slog.Error(err))
|
||||
return
|
||||
}
|
||||
case "agenthello":
|
||||
agentUUID, err := uuid.ParseBytes(agentID)
|
||||
if err != nil {
|
||||
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
|
||||
return
|
||||
}
|
||||
|
||||
nodes := c.nodesSubscribedToAgent(agentUUID)
|
||||
if len(nodes) > 0 {
|
||||
err := c.publishNodesToAgent(agentUUID, nodes)
|
||||
if err != nil {
|
||||
c.log.Error(ctx, "publish nodes to agent", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
case "agentupdate":
|
||||
agentUUID, err := uuid.ParseBytes(agentID)
|
||||
if err != nil {
|
||||
c.log.Error(ctx, "invalid agent id", slog.F("id", string(agentID)))
|
||||
return
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(bytes.NewReader(nodeJSON))
|
||||
_, err = c.handleAgentUpdate(agentUUID, decoder)
|
||||
if err != nil {
|
||||
c.log.Error(ctx, "handle agent update", slog.Error(err))
|
||||
return
|
||||
}
|
||||
default:
|
||||
c.log.Error(ctx, "unknown peer event", slog.F("name", string(eventType)))
|
||||
}
|
||||
}
|
||||
|
||||
// format: <coordinator id>|callmemaybe|<recipient id>|<node json>
|
||||
func (c *haCoordinator) formatCallMeMaybe(recipient uuid.UUID, nodes []*agpl.Node) ([]byte, error) {
|
||||
buf := bytes.Buffer{}
|
||||
|
||||
buf.WriteString(c.id.String() + "|")
|
||||
buf.WriteString("callmemaybe|")
|
||||
buf.WriteString(recipient.String() + "|")
|
||||
err := json.NewEncoder(&buf).Encode(nodes)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("encode node: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// format: <coordinator id>|agenthello|<node id>|
|
||||
func (c *haCoordinator) formatAgentHello(id uuid.UUID) ([]byte, error) {
|
||||
buf := bytes.Buffer{}
|
||||
|
||||
buf.WriteString(c.id.String() + "|")
|
||||
buf.WriteString("agenthello|")
|
||||
buf.WriteString(id.String() + "|")
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// format: <coordinator id>|clienthello|<agent id>|
|
||||
func (c *haCoordinator) formatClientHello(id uuid.UUID) ([]byte, error) {
|
||||
buf := bytes.Buffer{}
|
||||
|
||||
buf.WriteString(c.id.String() + "|")
|
||||
buf.WriteString("clienthello|")
|
||||
buf.WriteString(id.String() + "|")
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// format: <coordinator id>|agentupdate|<node id>|<node json>
|
||||
func (c *haCoordinator) formatAgentUpdate(id uuid.UUID, node *agpl.Node) ([]byte, error) {
|
||||
buf := bytes.Buffer{}
|
||||
|
||||
buf.WriteString(c.id.String() + "|")
|
||||
buf.WriteString("agentupdate|")
|
||||
buf.WriteString(id.String() + "|")
|
||||
err := json.NewEncoder(&buf).Encode(node)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("encode node: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
|
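Editor's note: the format* helpers above all share one framing convention: <coordinator id>|<event>|<agent id>|<payload>. As a hedged illustration of that wire format, here is a small self-contained sketch of building and parsing such a frame. buildFrame and parseFrame are hypothetical names for this example only and are not part of the package:

package main

import (
	"bytes"
	"fmt"

	"github.com/google/uuid"
)

// buildFrame assembles a pubsub frame using the same convention as the
// haCoordinator format* helpers: <coordinator id>|<event>|<agent id>|<payload>.
// This is an illustrative helper, not part of the coder codebase.
func buildFrame(coordinator, agent uuid.UUID, event string, payload []byte) []byte {
	var buf bytes.Buffer
	buf.WriteString(coordinator.String() + "|")
	buf.WriteString(event + "|")
	buf.WriteString(agent.String() + "|")
	buf.Write(payload)
	return buf.Bytes()
}

// parseFrame splits a frame into its four fields. SplitN keeps the payload
// intact even when the JSON itself contains "|" characters.
func parseFrame(frame []byte) (sender uuid.UUID, event string, agent uuid.UUID, payload []byte, err error) {
	sp := bytes.SplitN(frame, []byte("|"), 4)
	if len(sp) != 4 {
		return uuid.Nil, "", uuid.Nil, nil, fmt.Errorf("expected 4 fields, got %d", len(sp))
	}
	sender, err = uuid.ParseBytes(sp[0])
	if err != nil {
		return uuid.Nil, "", uuid.Nil, nil, fmt.Errorf("parse sender: %w", err)
	}
	agent, err = uuid.ParseBytes(sp[2])
	if err != nil {
		return uuid.Nil, "", uuid.Nil, nil, fmt.Errorf("parse agent: %w", err)
	}
	return sender, string(sp[1]), agent, sp[3], nil
}

func main() {
	// Round-trip one agentupdate frame, as a replica would publish it.
	frame := buildFrame(uuid.New(), uuid.New(), "agentupdate", []byte(`{"id":1}`))
	sender, event, agent, payload, err := parseFrame(frame)
	if err != nil {
		panic(err)
	}
	fmt.Println(sender, event, agent, string(payload))
}
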
@ -0,0 +1,261 @@
package tailnet_test

import (
	"net"
	"testing"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"cdr.dev/slog/sloggers/slogtest"

	"github.com/coder/coder/coderd/database"
	"github.com/coder/coder/coderd/database/dbtestutil"
	"github.com/coder/coder/enterprise/tailnet"
	agpl "github.com/coder/coder/tailnet"
	"github.com/coder/coder/testutil"
)

func TestCoordinatorSingle(t *testing.T) {
	t.Parallel()
	t.Run("ClientWithoutAgent", func(t *testing.T) {
		t.Parallel()
		coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory())
		require.NoError(t, err)
		defer coordinator.Close()

		client, server := net.Pipe()
		sendNode, errChan := agpl.ServeCoordinator(client, func(node []*agpl.Node) error {
			return nil
		})
		id := uuid.New()
		closeChan := make(chan struct{})
		go func() {
			err := coordinator.ServeClient(server, id, uuid.New())
			assert.NoError(t, err)
			close(closeChan)
		}()
		sendNode(&agpl.Node{})
		require.Eventually(t, func() bool {
			return coordinator.Node(id) != nil
		}, testutil.WaitShort, testutil.IntervalFast)

		err = client.Close()
		require.NoError(t, err)
		<-errChan
		<-closeChan
	})

	t.Run("AgentWithoutClients", func(t *testing.T) {
		t.Parallel()
		coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory())
		require.NoError(t, err)
		defer coordinator.Close()

		client, server := net.Pipe()
		sendNode, errChan := agpl.ServeCoordinator(client, func(node []*agpl.Node) error {
			return nil
		})
		id := uuid.New()
		closeChan := make(chan struct{})
		go func() {
			err := coordinator.ServeAgent(server, id)
			assert.NoError(t, err)
			close(closeChan)
		}()
		sendNode(&agpl.Node{})
		require.Eventually(t, func() bool {
			return coordinator.Node(id) != nil
		}, testutil.WaitShort, testutil.IntervalFast)
		err = client.Close()
		require.NoError(t, err)
		<-errChan
		<-closeChan
	})

	t.Run("AgentWithClient", func(t *testing.T) {
		t.Parallel()

		coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory())
		require.NoError(t, err)
		defer coordinator.Close()

		agentWS, agentServerWS := net.Pipe()
		defer agentWS.Close()
		agentNodeChan := make(chan []*agpl.Node)
		sendAgentNode, agentErrChan := agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error {
			agentNodeChan <- nodes
			return nil
		})
		agentID := uuid.New()
		closeAgentChan := make(chan struct{})
		go func() {
			err := coordinator.ServeAgent(agentServerWS, agentID)
			assert.NoError(t, err)
			close(closeAgentChan)
		}()
		sendAgentNode(&agpl.Node{})
		require.Eventually(t, func() bool {
			return coordinator.Node(agentID) != nil
		}, testutil.WaitShort, testutil.IntervalFast)

		clientWS, clientServerWS := net.Pipe()
		defer clientWS.Close()
		defer clientServerWS.Close()
		clientNodeChan := make(chan []*agpl.Node)
		sendClientNode, clientErrChan := agpl.ServeCoordinator(clientWS, func(nodes []*agpl.Node) error {
			clientNodeChan <- nodes
			return nil
		})
		clientID := uuid.New()
		closeClientChan := make(chan struct{})
		go func() {
			err := coordinator.ServeClient(clientServerWS, clientID, agentID)
			assert.NoError(t, err)
			close(closeClientChan)
		}()
		agentNodes := <-clientNodeChan
		require.Len(t, agentNodes, 1)
		sendClientNode(&agpl.Node{})
		clientNodes := <-agentNodeChan
		require.Len(t, clientNodes, 1)

		// Ensure an update to the agent node reaches the client!
		sendAgentNode(&agpl.Node{})
		agentNodes = <-clientNodeChan
		require.Len(t, agentNodes, 1)

		// Close the agent WebSocket so a new one can connect.
		err = agentWS.Close()
		require.NoError(t, err)
		<-agentErrChan
		<-closeAgentChan

		// Create a new agent connection. This is to simulate a reconnect!
		agentWS, agentServerWS = net.Pipe()
		defer agentWS.Close()
		agentNodeChan = make(chan []*agpl.Node)
		_, agentErrChan = agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error {
			agentNodeChan <- nodes
			return nil
		})
		closeAgentChan = make(chan struct{})
		go func() {
			err := coordinator.ServeAgent(agentServerWS, agentID)
			assert.NoError(t, err)
			close(closeAgentChan)
		}()
		// Ensure the existing listening client sends its node immediately!
		clientNodes = <-agentNodeChan
		require.Len(t, clientNodes, 1)

		err = agentWS.Close()
		require.NoError(t, err)
		<-agentErrChan
		<-closeAgentChan

		err = clientWS.Close()
		require.NoError(t, err)
		<-clientErrChan
		<-closeClientChan
	})
}

func TestCoordinatorHA(t *testing.T) {
	t.Parallel()

	t.Run("AgentWithClient", func(t *testing.T) {
		t.Parallel()

		_, pubsub := dbtestutil.NewDB(t)

		coordinator1, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub)
		require.NoError(t, err)
		defer coordinator1.Close()

		agentWS, agentServerWS := net.Pipe()
		defer agentWS.Close()
		agentNodeChan := make(chan []*agpl.Node)
		sendAgentNode, agentErrChan := agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error {
			agentNodeChan <- nodes
			return nil
		})
		agentID := uuid.New()
		closeAgentChan := make(chan struct{})
		go func() {
			err := coordinator1.ServeAgent(agentServerWS, agentID)
			assert.NoError(t, err)
			close(closeAgentChan)
		}()
		sendAgentNode(&agpl.Node{})
		require.Eventually(t, func() bool {
			return coordinator1.Node(agentID) != nil
		}, testutil.WaitShort, testutil.IntervalFast)

		coordinator2, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub)
		require.NoError(t, err)
		defer coordinator2.Close()

		clientWS, clientServerWS := net.Pipe()
		defer clientWS.Close()
		defer clientServerWS.Close()
		clientNodeChan := make(chan []*agpl.Node)
		sendClientNode, clientErrChan := agpl.ServeCoordinator(clientWS, func(nodes []*agpl.Node) error {
			clientNodeChan <- nodes
			return nil
		})
		clientID := uuid.New()
		closeClientChan := make(chan struct{})
		go func() {
			err := coordinator2.ServeClient(clientServerWS, clientID, agentID)
			assert.NoError(t, err)
			close(closeClientChan)
		}()
		agentNodes := <-clientNodeChan
		require.Len(t, agentNodes, 1)
		sendClientNode(&agpl.Node{})
		clientNodes := <-agentNodeChan
		require.Len(t, clientNodes, 1)

		// Ensure an update to the agent node reaches the client!
		sendAgentNode(&agpl.Node{})
		agentNodes = <-clientNodeChan
		require.Len(t, agentNodes, 1)

		// Close the agent WebSocket so a new one can connect.
		require.NoError(t, agentWS.Close())
		require.NoError(t, agentServerWS.Close())
		<-agentErrChan
		<-closeAgentChan

		// Create a new agent connection. This is to simulate a reconnect!
		agentWS, agentServerWS = net.Pipe()
		defer agentWS.Close()
		agentNodeChan = make(chan []*agpl.Node)
		_, agentErrChan = agpl.ServeCoordinator(agentWS, func(nodes []*agpl.Node) error {
			agentNodeChan <- nodes
			return nil
		})
		closeAgentChan = make(chan struct{})
		go func() {
			err := coordinator1.ServeAgent(agentServerWS, agentID)
			assert.NoError(t, err)
			close(closeAgentChan)
		}()
		// Ensure the existing listening client sends its node immediately!
		clientNodes = <-agentNodeChan
		require.Len(t, clientNodes, 1)

		err = agentWS.Close()
		require.NoError(t, err)
		<-agentErrChan
		<-closeAgentChan

		err = clientWS.Close()
		require.NoError(t, err)
		<-clientErrChan
		<-closeClientChan
	})
}

go.mod (2 changed lines)

@ -40,7 +40,7 @@ replace github.com/tcnksm/go-httpstat => github.com/kylecarbs/go-httpstat v0.0.0

 // There are a few minor changes we make to Tailscale that we're slowly upstreaming. Compare here:
 // https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main
-replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20220926024748-50f068456c6c
+replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20221015033036-5861cbbf7bf5

 // Switch to our fork that imports fixes from http://github.com/tailscale/ssh.
 // See: https://github.com/coder/coder/issues/3371

go.sum (4 changed lines)

@ -351,8 +351,8 @@ github.com/coder/retry v1.3.0 h1:5lAAwt/2Cm6lVmnfBY7sOMXcBOwcwJhmV5QGSELIVWY=
 github.com/coder/retry v1.3.0/go.mod h1:tXuRgZgWjUnU5LZPT4lJh4ew2elUhexhlnXzrJWdyFY=
 github.com/coder/ssh v0.0.0-20220811105153-fcea99919338 h1:tN5GKFT68YLVzJoA8AHuiMNJ0qlhoD3pGN3JY9gxSko=
 github.com/coder/ssh v0.0.0-20220811105153-fcea99919338/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
-github.com/coder/tailscale v1.1.1-0.20220926024748-50f068456c6c h1:xa6lr5Pj87Is26tgpzwBsEGKL7aVz7/fRGgY9QIbf3E=
-github.com/coder/tailscale v1.1.1-0.20220926024748-50f068456c6c/go.mod h1:5amxy08qijEa8bcTW2SeIy4MIqcmd7LMsuOxqOlj2Ak=
+github.com/coder/tailscale v1.1.1-0.20221015033036-5861cbbf7bf5 h1:WVH6e/qK3Wpl0wbmpORD2oQ1qLJborF3fsFHyO1ps0Y=
+github.com/coder/tailscale v1.1.1-0.20221015033036-5861cbbf7bf5/go.mod h1:5amxy08qijEa8bcTW2SeIy4MIqcmd7LMsuOxqOlj2Ak=
 github.com/containerd/aufs v0.0.0-20200908144142-dab0cbea06f4/go.mod h1:nukgQABAEopAHvB6j7cnP5zJ+/3aVcE7hCYqvIwAHyE=
 github.com/containerd/aufs v0.0.0-20201003224125-76a6863f2989/go.mod h1:AkGGQs9NM2vtYHaUen+NljV0/baGCAPELGm2q9ZXpWU=
 github.com/containerd/aufs v0.0.0-20210316121734-20793ff83c97/go.mod h1:kL5kd6KM5TzQjR79jljyi4olc1Vrx6XBlcyj3gNv2PU=

@ -14,10 +14,7 @@ metadata:
   {{- include "coder.labels" . | nindent 4 }}
   annotations: {{ toYaml .Values.coder.annotations | nindent 4}}
 spec:
-  # NOTE: this is currently not used as coder v2 does not support high
-  # availability yet.
-  # replicas: {{ .Values.coder.replicaCount }}
-  replicas: 1
+  replicas: {{ .Values.coder.replicaCount }}
   selector:
     matchLabels:
       {{- include "coder.selectorLabels" . | nindent 6 }}

@ -38,6 +35,13 @@ spec:
           env:
             - name: CODER_ADDRESS
               value: "0.0.0.0:{{ include "coder.port" . }}"
+            # Used for inter-pod communication with high-availability.
+            - name: KUBE_POD_IP
+              valueFrom:
+                fieldRef:
+                  fieldPath: status.podIP
+            - name: CODER_DERP_SERVER_RELAY_ADDRESS
+              value: "{{ include "coder.portName" . }}://$(KUBE_POD_IP):{{ include "coder.port" . }}"
             {{- include "coder.tlsEnv" . | nindent 12 }}
             {{- with .Values.coder.env -}}
             {{ toYaml . | nindent 12 }}

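Editor's note: the template above injects the pod IP into each replica so that replicas can mesh DERP traffic to one another directly. As a rough, hedged sketch of what a replica process might do with that variable (the actual handling lives in the coder server flag parsing, which is not part of this diff; this snippet only demonstrates the env-var contract):

package main

import (
	"fmt"
	"net/url"
	"os"
)

func main() {
	// CODER_DERP_SERVER_RELAY_ADDRESS is the env var the Helm chart sets from
	// the pod IP; each replica advertises it so peers can reach its DERP mesh.
	relay := os.Getenv("CODER_DERP_SERVER_RELAY_ADDRESS")
	u, err := url.Parse(relay)
	if err != nil || u.Scheme == "" || u.Host == "" {
		fmt.Fprintf(os.Stderr, "invalid relay address %q: %v\n", relay, err)
		os.Exit(1)
	}
	fmt.Printf("advertising relay %s://%s to other replicas\n", u.Scheme, u.Host)
}
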
@ -10,6 +10,7 @@ metadata:
   {{- toYaml .Values.coder.service.annotations | nindent 4 }}
 spec:
   type: {{ .Values.coder.service.type }}
+  sessionAffinity: ClientIP
   ports:
     - name: {{ include "coder.portName" . | quote }}
       port: {{ include "coder.servicePort" . }}

@ -1,9 +1,9 @@
 # coder -- Primary configuration for `coder server`.
 coder:
-  # NOTE: this is currently not used as coder v2 does not support high
-  # availability yet.
-  # # coder.replicaCount -- The number of Kubernetes deployment replicas.
-  # replicaCount: 1
+  # coder.replicaCount -- The number of Kubernetes deployment replicas.
+  # This should only be increased if High Availability is enabled.
+  # This is an Enterprise feature. Contact sales@coder.com.
+  replicaCount: 1

   # coder.image -- The image to use for Coder.
   image:

@ -28,6 +28,7 @@ export const defaultEntitlements = (): TypesGen.Entitlements => {
   return {
     features: features,
     has_license: false,
+    errors: [],
     warnings: [],
     experimental: false,
     trial: false,

@ -274,6 +274,7 @@ export interface DeploymentFlags {
   readonly derp_server_region_code: StringFlag
   readonly derp_server_region_name: StringFlag
   readonly derp_server_stun_address: StringArrayFlag
+  readonly derp_server_relay_address: StringFlag
   readonly derp_config_url: StringFlag
   readonly derp_config_path: StringFlag
   readonly prom_enabled: BoolFlag

@ -337,6 +338,7 @@ export interface DurationFlag {
 export interface Entitlements {
   readonly features: Record<string, Feature>
   readonly warnings: string[]
+  readonly errors: string[]
   readonly has_license: boolean
   readonly experimental: boolean
   readonly trial: boolean

@ -528,6 +530,17 @@ export interface PutExtendWorkspaceRequest {
   readonly deadline: string
 }

+// From codersdk/replicas.go
+export interface Replica {
+  readonly id: string
+  readonly hostname: string
+  readonly created_at: string
+  readonly relay_address: string
+  readonly region_id: number
+  readonly error: string
+  readonly database_latency: number
+}
+
 // From codersdk/error.go
 export interface Response {
   readonly message: string

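Editor's note: the generated interface above mirrors a Go struct in codersdk/replicas.go. A hedged sketch of what that struct plausibly looks like — the JSON tags follow the generated fields, but the exact Go field types are assumptions, not copied from the source:

package codersdk

import (
	"time"

	"github.com/google/uuid"
)

// Replica is a sketch of the Go counterpart of the generated TypeScript
// interface above; field types are inferred assumptions.
type Replica struct {
	ID              uuid.UUID `json:"id"`
	Hostname        string    `json:"hostname"`
	CreatedAt       time.Time `json:"created_at"`
	RelayAddress    string    `json:"relay_address"`
	RegionID        int32     `json:"region_id"`
	Error           string    `json:"error"`
	DatabaseLatency int32     `json:"database_latency"`
}
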
@ -8,15 +8,15 @@ export const LicenseBanner: React.FC = () => {
   const [entitlementsState, entitlementsSend] = useActor(
     xServices.entitlementsXService,
   )
-  const { warnings } = entitlementsState.context.entitlements
+  const { errors, warnings } = entitlementsState.context.entitlements

   /** Gets license data on app mount because LicenseBanner is mounted in App */
   useEffect(() => {
     entitlementsSend("GET_ENTITLEMENTS")
   }, [entitlementsSend])

-  if (warnings.length > 0) {
-    return <LicenseBannerView warnings={warnings} />
+  if (errors.length > 0 || warnings.length > 0) {
+    return <LicenseBannerView errors={errors} warnings={warnings} />
   } else {
     return null
   }

@ -12,13 +12,23 @@ const Template: Story<LicenseBannerViewProps> = (args) => (

 export const OneWarning = Template.bind({})
 OneWarning.args = {
+  errors: [],
   warnings: ["You have exceeded the number of seats in your license."],
 }

 export const TwoWarnings = Template.bind({})
 TwoWarnings.args = {
+  errors: [],
   warnings: [
     "You have exceeded the number of seats in your license.",
     "You are flying too close to the sun.",
   ],
 }
+
+export const OneError = Template.bind({})
+OneError.args = {
+  errors: [
+    "You have multiple replicas but high availability is an Enterprise feature. You will be unable to connect to workspaces.",
+  ],
+  warnings: [],
+}

@ -2,47 +2,56 @@ import { makeStyles } from "@material-ui/core/styles"
 import { Expander } from "components/Expander/Expander"
 import { Pill } from "components/Pill/Pill"
 import { useState } from "react"
+import { colors } from "theme/colors"

 export const Language = {
   licenseIssue: "License Issue",
   licenseIssues: (num: number): string => `${num} License Issues`,
-  upgrade: "Contact us to upgrade your license.",
+  upgrade: "Contact sales@coder.com.",
   exceeded: "It looks like you've exceeded some limits of your license.",
   lessDetails: "Less",
   moreDetails: "More",
 }

 export interface LicenseBannerViewProps {
+  errors: string[]
   warnings: string[]
 }

 export const LicenseBannerView: React.FC<LicenseBannerViewProps> = ({
+  errors,
   warnings,
 }) => {
   const styles = useStyles()
   const [showDetails, setShowDetails] = useState(false)
-  if (warnings.length === 1) {
+  const isError = errors.length > 0
+  const messages = [...errors, ...warnings]
+  const type = isError ? "error" : "warning"
+
+  if (messages.length === 1) {
     return (
-      <div className={styles.container}>
-        <Pill text={Language.licenseIssue} type="warning" lightBorder />
-        <span className={styles.text}>{warnings[0]}</span>
-
-        <a href="mailto:sales@coder.com" className={styles.link}>
-          {Language.upgrade}
-        </a>
+      <div className={`${styles.container} ${type}`}>
+        <Pill text={Language.licenseIssue} type={type} lightBorder />
+        <div className={styles.leftContent}>
+          <span>{messages[0]}</span>
+
+          <a href="mailto:sales@coder.com" className={styles.link}>
+            {Language.upgrade}
+          </a>
+        </div>
       </div>
     )
   } else {
     return (
-      <div className={styles.container}>
-        <div className={styles.flex}>
-          <div className={styles.leftContent}>
-            <Pill
-              text={Language.licenseIssues(warnings.length)}
-              type="warning"
-              lightBorder
-            />
-            <span className={styles.text}>{Language.exceeded}</span>
+      <div className={`${styles.container} ${type}`}>
+        <Pill
+          text={Language.licenseIssues(messages.length)}
+          type={type}
+          lightBorder
+        />
+        <div className={styles.leftContent}>
+          <div>
+            {Language.exceeded}

             <a href="mailto:sales@coder.com" className={styles.link}>
               {Language.upgrade}

@ -50,9 +59,9 @@ export const LicenseBannerView: React.FC<LicenseBannerViewProps> = ({
           </div>
           <Expander expanded={showDetails} setExpanded={setShowDetails}>
             <ul className={styles.list}>
-              {warnings.map((warning) => (
-                <li className={styles.listItem} key={`${warning}`}>
-                  {warning}
+              {messages.map((message) => (
+                <li className={styles.listItem} key={`${message}`}>
+                  {message}
                 </li>
               ))}
             </ul>

@ -67,14 +76,18 @@ const useStyles = makeStyles((theme) => ({
   container: {
     padding: theme.spacing(1.5),
     backgroundColor: theme.palette.warning.main,
+    display: "flex",
+    alignItems: "center",
+
+    "&.error": {
+      backgroundColor: colors.red[12],
+    },
   },
   flex: {
-    display: "flex",
+    display: "column",
   },
   leftContent: {
     marginRight: theme.spacing(1),
   },
   text: {
     marginLeft: theme.spacing(1),
   },
   link: {

@ -83,9 +96,10 @@ const useStyles = makeStyles((theme) => ({
     fontWeight: "bold",
   },
   list: {
-    margin: theme.spacing(1.5),
+    padding: theme.spacing(1),
+    margin: 0,
   },
   listItem: {
-    margin: theme.spacing(1),
+    margin: theme.spacing(0.5),
   },
 }))

@ -821,6 +821,7 @@ export const makeMockApiError = ({
 })

 export const MockEntitlements: TypesGen.Entitlements = {
+  errors: [],
   warnings: [],
   has_license: false,
   features: {},

@ -829,6 +830,7 @@ export const MockEntitlements: TypesGen.Entitlements = {
 }

 export const MockEntitlementsWithWarnings: TypesGen.Entitlements = {
+  errors: [],
   warnings: ["You are over your active user limit.", "And another thing."],
   has_license: true,
   experimental: false,

@ -852,6 +854,7 @@ export const MockEntitlementsWithWarnings: TypesGen.Entitlements = {
 }

 export const MockEntitlementsWithAuditLog: TypesGen.Entitlements = {
+  errors: [],
   warnings: [],
   has_license: true,
   experimental: false,

@ -20,6 +20,7 @@ export type EntitlementsEvent =
   | { type: "HIDE_MOCK_BANNER" }

 const emptyEntitlements = {
+  errors: [],
   warnings: [],
   features: {},
   has_license: false,

@ -48,7 +48,10 @@ type Options struct {
 	Addresses []netip.Prefix
 	DERPMap   *tailcfg.DERPMap

-	Logger slog.Logger
+	// BlockEndpoints specifies whether P2P endpoints are blocked.
+	// If so, only DERPs can establish connections.
+	BlockEndpoints bool
+	Logger         slog.Logger
 }

 // NewConn constructs a new Wireguard server that will accept connections from the addresses provided.

@ -175,6 +178,7 @@ func NewConn(options *Options) (*Conn, error) {
 	wireguardEngine.SetFilter(filter.New(netMap.PacketFilter, localIPs, logIPs, nil, Logger(options.Logger.Named("packet-filter"))))
 	dialContext, dialCancel := context.WithCancel(context.Background())
 	server := &Conn{
+		blockEndpoints: options.BlockEndpoints,
 		dialContext:    dialContext,
 		dialCancel:     dialCancel,
 		closed:         make(chan struct{}),

@ -240,11 +244,12 @@ func IP() netip.Addr {

 // Conn is an actively listening Wireguard connection.
 type Conn struct {
-	dialContext context.Context
-	dialCancel  context.CancelFunc
-	mutex       sync.Mutex
-	closed      chan struct{}
-	logger      slog.Logger
+	dialContext    context.Context
+	dialCancel     context.CancelFunc
+	mutex          sync.Mutex
+	closed         chan struct{}
+	logger         slog.Logger
+	blockEndpoints bool

 	dialer    *tsdial.Dialer
 	tunDevice *tstun.Wrapper

@ -323,6 +328,8 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
 		delete(c.peerMap, peer.ID)
 	}
 	for _, node := range nodes {
+		c.logger.Debug(context.Background(), "adding node", slog.F("node", node))
+
 		peerStatus, ok := status.Peer[node.Key]
 		peerNode := &tailcfg.Node{
 			ID: node.ID,

@ -339,6 +346,13 @@ func (c *Conn) UpdateNodes(nodes []*Node) error {
 			// reason. TODO: @kylecarbs debug this!
 			KeepAlive: ok && peerStatus.Active,
 		}
+		// If no preferred DERP is provided, don't set an IP!
+		if node.PreferredDERP == 0 {
+			peerNode.DERP = ""
+		}
+		if c.blockEndpoints {
+			peerNode.Endpoints = nil
+		}
 		c.peerMap[node.ID] = peerNode
 	}
 	c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))

@ -421,6 +435,7 @@ func (c *Conn) sendNode() {
 	}
 	node := &Node{
 		ID:    c.netMap.SelfNode.ID,
+		AsOf:  c.lastStatus,
 		Key:   c.netMap.SelfNode.Key,
 		Addresses:  c.netMap.SelfNode.Addresses,
 		AllowedIPs: c.netMap.SelfNode.AllowedIPs,

@ -429,6 +444,9 @@ func (c *Conn) sendNode() {
 		PreferredDERP: c.lastPreferredDERP,
 		DERPLatency:   c.lastDERPLatency,
 	}
+	if c.blockEndpoints {
+		node.Endpoints = nil
+	}
 	nodeCallback := c.nodeCallback
 	if nodeCallback == nil {
 		return

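Editor's note: a minimal sketch of how a caller might opt into DERP-only connectivity with the new BlockEndpoints option above. The DERP map and logger construction are assumptions, not shown in this diff; the Options fields are the ones introduced here:

package main

import (
	"net/netip"

	"cdr.dev/slog"
	"tailscale.com/tailcfg"

	"github.com/coder/coder/tailnet"
)

// newDERPOnlyConn builds a tailnet connection that never advertises P2P
// endpoints, so all traffic is relayed through DERP servers.
func newDERPOnlyConn(derpMap *tailcfg.DERPMap, logger slog.Logger) (*tailnet.Conn, error) {
	return tailnet.NewConn(&tailnet.Options{
		// tailnet.IP() generates a random in-range address for this node.
		Addresses:      []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		DERPMap:        derpMap,
		BlockEndpoints: true,
		Logger:         logger,
	})
}
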
@ -7,6 +7,7 @@ import (
 	"net"
 	"net/netip"
 	"sync"
+	"time"

 	"github.com/google/uuid"
 	"golang.org/x/xerrors"

@ -14,10 +15,30 @@ import (
 	"tailscale.com/types/key"
 )

+// Coordinator exchanges nodes with agents to establish connections.
+// ┌──────────────────┐   ┌────────────────────┐   ┌───────────────────┐   ┌──────────────────┐
+// │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│
+// └──────────────────┘   └────────────────────┘   └───────────────────┘   └──────────────────┘
+// Coordinators have different guarantees for HA support.
+type Coordinator interface {
+	// Node returns an in-memory node by ID.
+	Node(id uuid.UUID) *Node
+	// ServeClient accepts a WebSocket connection that wants to connect to an agent
+	// with the specified ID.
+	ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error
+	// ServeAgent accepts a WebSocket connection to an agent that listens to
+	// incoming connections and publishes node updates.
+	ServeAgent(conn net.Conn, id uuid.UUID) error
+	// Close closes the coordinator.
+	Close() error
+}

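Editor's note: to make the interface concrete, here is a minimal sketch of wiring a client to the in-memory coordinator over an in-process pipe, in the spirit of the tests added by this commit. The main wrapper is illustrative only; without an agent connected, the update callback simply never fires:

package main

import (
	"fmt"
	"net"

	"github.com/google/uuid"

	"github.com/coder/coder/tailnet"
)

func main() {
	coordinator := tailnet.NewCoordinator()
	defer coordinator.Close()

	// net.Pipe stands in for the WebSocket the client would normally dial.
	clientConn, serverConn := net.Pipe()
	defer clientConn.Close()

	agentID := uuid.New()
	go func() {
		// The coordinator relays nodes between this client and the agent.
		_ = coordinator.ServeClient(serverConn, uuid.New(), agentID)
	}()

	// ServeCoordinator returns a function for sending our node and a channel
	// that reports a read error once the connection ends.
	sendNode, errChan := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error {
		fmt.Printf("received %d node update(s)\n", len(nodes))
		return nil
	})
	sendNode(&tailnet.Node{})
	_ = errChan
}
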
 // Node represents a node in the network.
 type Node struct {
+	// ID is used to identify the connection.
 	ID tailcfg.NodeID `json:"id"`
+	// AsOf is the time the node was created.
+	AsOf time.Time `json:"as_of"`
+	// Key is the Wireguard public key of the node.
 	Key key.NodePublic `json:"key"`
+	// DiscoKey is used for discovery messages over DERP to establish peer-to-peer connections.

@ -75,48 +96,59 @@ func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func
 	}, errChan
 }

-// NewCoordinator constructs a new in-memory connection coordinator.
-func NewCoordinator() *Coordinator {
-	return &Coordinator{
+// NewCoordinator constructs a new in-memory connection coordinator. This
+// coordinator is incompatible with multiple Coder replicas as all node data is
+// in-memory.
+func NewCoordinator() Coordinator {
+	return &coordinator{
+		closed:                   false,
 		nodes:                    map[uuid.UUID]*Node{},
 		agentSockets:             map[uuid.UUID]net.Conn{},
 		agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]net.Conn{},
 	}
 }

-// Coordinator exchanges nodes with agents to establish connections.
+// coordinator exchanges nodes with agents to establish connections entirely in-memory.
+// The Enterprise implementation provides this for high-availability.
 // ┌──────────────────┐   ┌────────────────────┐   ┌───────────────────┐   ┌──────────────────┐
 // │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│
 // └──────────────────┘   └────────────────────┘   └───────────────────┘   └──────────────────┘
-type Coordinator struct {
-	mutex sync.Mutex
+// This coordinator is incompatible with multiple Coder
+// replicas as all node data is in-memory.
+type coordinator struct {
+	mutex  sync.Mutex
+	closed bool

-	// Maps agent and connection IDs to a node.
+	// nodes maps agent and connection IDs to their respective node.
 	nodes map[uuid.UUID]*Node
-	// Maps agent ID to an open socket.
+	// agentSockets maps agent IDs to their open websocket.
 	agentSockets map[uuid.UUID]net.Conn
-	// Maps agent ID to connection ID for sending
-	// new node data as it comes in!
+	// agentToConnectionSockets maps agent IDs to connection IDs of conns that
+	// are subscribed to updates for that agent.
 	agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]net.Conn
 }

 // Node returns an in-memory node by ID.
-func (c *Coordinator) Node(id uuid.UUID) *Node {
+// If the node does not exist, nil is returned.
+func (c *coordinator) Node(id uuid.UUID) *Node {
 	c.mutex.Lock()
 	defer c.mutex.Unlock()
-	node := c.nodes[id]
-	return node
+	return c.nodes[id]
 }

-// ServeClient accepts a WebSocket connection that wants to
-// connect to an agent with the specified ID.
-func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
+// ServeClient accepts a WebSocket connection that wants to connect to an agent
+// with the specified ID.
+func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
 	c.mutex.Lock()
+	if c.closed {
+		c.mutex.Unlock()
+		return xerrors.New("coordinator is closed")
+	}
+
 	// When a new connection is requested, we update it with the latest
 	// node of the agent. This allows the connection to establish.
 	node, ok := c.nodes[agent]
+	c.mutex.Unlock()
 	if ok {
 		data, err := json.Marshal([]*Node{node})
 		if err != nil {

@ -129,6 +161,7 @@ func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID)
 			return xerrors.Errorf("write nodes: %w", err)
 		}
 	}
+	c.mutex.Lock()
 	connectionSockets, ok := c.agentToConnectionSockets[agent]
 	if !ok {
 		connectionSockets = map[uuid.UUID]net.Conn{}

@ -156,47 +189,62 @@ func (c *Coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID)

 	decoder := json.NewDecoder(conn)
 	for {
-		var node Node
-		err := decoder.Decode(&node)
-		if errors.Is(err, io.EOF) {
-			return nil
-		}
+		err := c.handleNextClientMessage(id, agent, decoder)
 		if err != nil {
-			return xerrors.Errorf("read json: %w", err)
-		}
-		c.mutex.Lock()
-		// Update the node of this client in our in-memory map.
-		// If an agent entirely shuts down and reconnects, it
-		// needs to be aware of all clients attempting to
-		// establish connections.
-		c.nodes[id] = &node
-		agentSocket, ok := c.agentSockets[agent]
-		if !ok {
-			c.mutex.Unlock()
-			continue
-		}
-		c.mutex.Unlock()
-		// Write the new node from this client to the actively
-		// connected agent.
-		data, err := json.Marshal([]*Node{&node})
-		if err != nil {
-			c.mutex.Unlock()
-			return xerrors.Errorf("marshal nodes: %w", err)
-		}
-		_, err = agentSocket.Write(data)
-		if errors.Is(err, io.EOF) {
-			return nil
-		}
-		if err != nil {
-			return xerrors.Errorf("write json: %w", err)
+			if errors.Is(err, io.EOF) {
+				return nil
+			}
+			return xerrors.Errorf("handle next client message: %w", err)
 		}
 	}
 }

+func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
+	var node Node
+	err := decoder.Decode(&node)
+	if err != nil {
+		return xerrors.Errorf("read json: %w", err)
+	}
+
+	c.mutex.Lock()
+	// Update the node of this client in our in-memory map. If an agent entirely
+	// shuts down and reconnects, it needs to be aware of all clients attempting
+	// to establish connections.
+	c.nodes[id] = &node
+
+	agentSocket, ok := c.agentSockets[agent]
+	if !ok {
+		c.mutex.Unlock()
+		return nil
+	}
+	c.mutex.Unlock()
+
+	// Write the new node from this client to the actively connected agent.
+	data, err := json.Marshal([]*Node{&node})
+	if err != nil {
+		return xerrors.Errorf("marshal nodes: %w", err)
+	}
+
+	_, err = agentSocket.Write(data)
+	if err != nil {
+		if errors.Is(err, io.EOF) {
+			return nil
+		}
+		return xerrors.Errorf("write json: %w", err)
+	}
+
+	return nil
+}
+
 // ServeAgent accepts a WebSocket connection to an agent that
 // listens to incoming connections and publishes node updates.
-func (c *Coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
+func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
 	c.mutex.Lock()
+	if c.closed {
+		c.mutex.Unlock()
+		return xerrors.New("coordinator is closed")
+	}
+
 	sockets, ok := c.agentToConnectionSockets[id]
 	if ok {
 		// Publish all nodes that want to connect to the

@ -209,16 +257,16 @@ func (c *Coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {
 		}
 			nodes = append(nodes, node)
 		}
-		c.mutex.Unlock()
 		data, err := json.Marshal(nodes)
 		if err != nil {
+			c.mutex.Unlock()
 			return xerrors.Errorf("marshal json: %w", err)
 		}
 		_, err = conn.Write(data)
 		if err != nil {
+			c.mutex.Unlock()
 			return xerrors.Errorf("write nodes: %w", err)
 		}
-		c.mutex.Lock()
 	}

 	// If an old agent socket is connected, we close it

@ -239,36 +287,84 @@ func (c *Coordinator) ServeAgent(conn net.Conn, id uuid.UUID) error {

 	decoder := json.NewDecoder(conn)
 	for {
-		var node Node
-		err := decoder.Decode(&node)
-		if errors.Is(err, io.EOF) {
-			return nil
-		}
+		err := c.handleNextAgentMessage(id, decoder)
 		if err != nil {
-			return xerrors.Errorf("read json: %w", err)
+			if errors.Is(err, io.EOF) {
+				return nil
+			}
+			return xerrors.Errorf("handle next agent message: %w", err)
 		}
-		c.mutex.Lock()
-		c.nodes[id] = &node
-		connectionSockets, ok := c.agentToConnectionSockets[id]
-		if !ok {
-			c.mutex.Unlock()
-			continue
-		}
-		data, err := json.Marshal([]*Node{&node})
-		if err != nil {
-			return xerrors.Errorf("marshal nodes: %w", err)
-		}
-		// Publish the new node to every listening socket.
-		var wg sync.WaitGroup
-		wg.Add(len(connectionSockets))
-		for _, connectionSocket := range connectionSockets {
-			connectionSocket := connectionSocket
-			go func() {
-				_, _ = connectionSocket.Write(data)
-				wg.Done()
-			}()
-		}
-		c.mutex.Unlock()
-		wg.Wait()
 	}
 }

+func (c *coordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.Decoder) error {
+	var node Node
+	err := decoder.Decode(&node)
+	if err != nil {
+		return xerrors.Errorf("read json: %w", err)
+	}
+
+	c.mutex.Lock()
+	c.nodes[id] = &node
+	connectionSockets, ok := c.agentToConnectionSockets[id]
+	if !ok {
+		c.mutex.Unlock()
+		return nil
+	}
+	data, err := json.Marshal([]*Node{&node})
+	if err != nil {
+		c.mutex.Unlock()
+		return xerrors.Errorf("marshal nodes: %w", err)
+	}
+
+	// Publish the new node to every listening socket.
+	var wg sync.WaitGroup
+	wg.Add(len(connectionSockets))
+	for _, connectionSocket := range connectionSockets {
+		connectionSocket := connectionSocket
+		go func() {
+			_ = connectionSocket.SetWriteDeadline(time.Now().Add(5 * time.Second))
+			_, _ = connectionSocket.Write(data)
+			wg.Done()
+		}()
+	}
+
+	c.mutex.Unlock()
+	wg.Wait()
+	return nil
+}
+
+// Close closes all of the open connections in the coordinator and stops the
+// coordinator from accepting new connections.
+func (c *coordinator) Close() error {
+	c.mutex.Lock()
+	if c.closed {
+		c.mutex.Unlock()
+		return nil
+	}
+	c.closed = true
+	c.mutex.Unlock()
+
+	wg := sync.WaitGroup{}
+
+	wg.Add(len(c.agentSockets))
+	for _, socket := range c.agentSockets {
+		socket := socket
+		go func() {
+			_ = socket.Close()
+			wg.Done()
+		}()
+	}
+
+	for _, connMap := range c.agentToConnectionSockets {
+		wg.Add(len(connMap))
+		for _, socket := range connMap {
+			socket := socket
+			go func() {
+				_ = socket.Close()
+				wg.Done()
+			}()
+		}
+	}
+
+	wg.Wait()
+	return nil
+}

@ -32,8 +32,8 @@ func TestCoordinator(t *testing.T) {
 		require.Eventually(t, func() bool {
 			return coordinator.Node(id) != nil
 		}, testutil.WaitShort, testutil.IntervalFast)
-		err := client.Close()
-		require.NoError(t, err)
+		require.NoError(t, client.Close())
+		require.NoError(t, server.Close())
 		<-errChan
 		<-closeChan
 	})

@ -0,0 +1,53 @@
package testutil

import (
	"bytes"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"math/big"
	"net"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

func GenerateTLSCertificate(t testing.TB, commonName string) tls.Certificate {
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)
	template := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			Organization: []string{"Acme Co"},
			CommonName:   commonName,
		},
		DNSNames:  []string{commonName},
		NotBefore: time.Now(),
		NotAfter:  time.Now().Add(time.Hour * 24 * 180),

		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
	}

	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	require.NoError(t, err)
	var certFile bytes.Buffer
	_, err = certFile.Write(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
	require.NoError(t, err)
	privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
	require.NoError(t, err)
	var keyFile bytes.Buffer
	err = pem.Encode(&keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes})
	require.NoError(t, err)
	cert, err := tls.X509KeyPair(certFile.Bytes(), keyFile.Bytes())
	require.NoError(t, err)
	return cert
}
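Editor's note: a hypothetical test using the helper above, showing how the generated certificate can back a TLS server. httptest and tls are from the standard library; the test name and structure are illustrative only:

package testutil_test

import (
	"crypto/tls"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/coder/coder/testutil"
)

func TestGenerateTLSCertificate(t *testing.T) {
	t.Parallel()
	cert := testutil.GenerateTLSCertificate(t, "localhost")

	// Serve TLS with the generated certificate. The certificate's SANs cover
	// "localhost" and 127.0.0.1, which is where httptest listens.
	srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	srv.TLS = &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}
	srv.StartTLS()
	defer srv.Close()

	// srv.Client() is preconfigured to trust the server's certificate.
	res, err := srv.Client().Get(srv.URL)
	require.NoError(t, err)
	defer res.Body.Close()
	require.Equal(t, http.StatusOK, res.StatusCode)
}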