feat: add single tailnet support to moons (#8587)

This commit is contained in:
Colin Adler 2023-07-19 11:11:11 -05:00 committed by GitHub
parent cc8d0af027
commit 517fb19474
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 1195 additions and 80 deletions

71
coderd/apidoc/docs.go generated
View File

@ -4693,6 +4693,44 @@ const docTemplate = `{
}
}
},
"/workspaceagents/{workspaceagent}/legacy": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": [
"application/json"
],
"tags": [
"Enterprise"
],
"summary": "Agent is legacy",
"operationId": "agent-is-legacy",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Workspace Agent ID",
"name": "workspaceagent",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/wsproxysdk.AgentIsLegacyResponse"
}
}
},
"x-apidocgen": {
"skip": true
}
}
},
"/workspaceagents/{workspaceagent}/listening-ports": {
"get": {
"security": [
@ -5147,6 +5185,28 @@ const docTemplate = `{
}
}
},
"/workspaceproxies/me/coordinate": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"tags": [
"Enterprise"
],
"summary": "Workspace Proxy Coordinate",
"operationId": "workspace-proxy-coordinate",
"responses": {
"101": {
"description": "Switching Protocols"
}
},
"x-apidocgen": {
"skip": true
}
}
},
"/workspaceproxies/me/goingaway": {
"post": {
"security": [
@ -10881,6 +10941,17 @@ const docTemplate = `{
}
}
},
"wsproxysdk.AgentIsLegacyResponse": {
"type": "object",
"properties": {
"found": {
"type": "boolean"
},
"legacy": {
"type": "boolean"
}
}
},
"wsproxysdk.IssueSignedAppTokenResponse": {
"type": "object",
"properties": {

View File

@ -4129,6 +4129,40 @@
}
}
},
"/workspaceagents/{workspaceagent}/legacy": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": ["application/json"],
"tags": ["Enterprise"],
"summary": "Agent is legacy",
"operationId": "agent-is-legacy",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Workspace Agent ID",
"name": "workspaceagent",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/wsproxysdk.AgentIsLegacyResponse"
}
}
},
"x-apidocgen": {
"skip": true
}
}
},
"/workspaceagents/{workspaceagent}/listening-ports": {
"get": {
"security": [
@ -4537,6 +4571,26 @@
}
}
},
"/workspaceproxies/me/coordinate": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"tags": ["Enterprise"],
"summary": "Workspace Proxy Coordinate",
"operationId": "workspace-proxy-coordinate",
"responses": {
"101": {
"description": "Switching Protocols"
}
},
"x-apidocgen": {
"skip": true
}
}
},
"/workspaceproxies/me/goingaway": {
"post": {
"security": [
@ -9912,6 +9966,17 @@
}
}
},
"wsproxysdk.AgentIsLegacyResponse": {
"type": "object",
"properties": {
"found": {
"type": "boolean"
},
"legacy": {
"type": "boolean"
}
}
},
"wsproxysdk.IssueSignedAppTokenResponse": {
"type": "object",
"properties": {

View File

@ -199,7 +199,7 @@ func New(options *Options) *API {
options.Authorizer,
options.Logger.Named("authz_querier"),
)
experiments := initExperiments(
experiments := ReadExperiments(
options.Logger, options.DeploymentValues.Experiments.Value(),
)
if options.AppHostname != "" && options.AppHostnameRegex == nil || options.AppHostname == "" && options.AppHostnameRegex != nil {
@ -370,7 +370,9 @@ func New(options *Options) *API {
options.Logger,
options.DERPServer,
options.DERPMap,
&api.TailnetCoordinator,
func(context.Context) (tailnet.MultiAgentConn, error) {
return (*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()), nil
},
wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
)
if err != nil {
@ -1081,7 +1083,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
}
// nolint:revive
func initExperiments(log slog.Logger, raw []string) codersdk.Experiments {
func ReadExperiments(log slog.Logger, raw []string) codersdk.Experiments {
exps := make([]codersdk.Experiment, 0, len(raw))
for _, v := range raw {
switch v {

View File

@ -384,6 +384,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
TemplateScheduleStore: &templateScheduleStore,
TLSCertificates: options.TLSCertificates,
TrialGenerator: options.TrialGenerator,
TailnetCoordinator: options.Coordinator,
DERPMap: derpMap,
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,

View File

@ -25,3 +25,24 @@ func Heartbeat(ctx context.Context, conn *websocket.Conn) {
}
}
}
// Heartbeat loops to ping a WebSocket to keep it alive. It kills the connection
// on ping failure.
func HeartbeatClose(ctx context.Context, exit func(), conn *websocket.Conn) {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
}
err := conn.Ping(ctx)
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "Ping failed")
exit()
return
}
}
}

View File

@ -64,7 +64,7 @@ func ExtractGroupParam(db database.Store) func(http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
groupID, parsed := parseUUID(rw, r, "group")
groupID, parsed := ParseUUIDParam(rw, r, "group")
if !parsed {
return
}

View File

@ -11,8 +11,8 @@ import (
"github.com/coder/coder/codersdk"
)
// parseUUID consumes a url parameter and parses it as a UUID.
func parseUUID(rw http.ResponseWriter, r *http.Request, param string) (uuid.UUID, bool) {
// ParseUUIDParam consumes a url parameter and parses it as a UUID.
func ParseUUIDParam(rw http.ResponseWriter, r *http.Request, param string) (uuid.UUID, bool) {
rawID := chi.URLParam(r, param)
if rawID == "" {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{

View File

@ -29,7 +29,7 @@ func TestParseUUID_Valid(t *testing.T) {
ctx.URLParams.Add(testParam, testWorkspaceAgentID)
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
parsed, ok := parseUUID(rw, r, "workspaceagent")
parsed, ok := ParseUUIDParam(rw, r, "workspaceagent")
assert.True(t, ok, "UUID should be parsed")
assert.Equal(t, testWorkspaceAgentID, parsed.String())
}
@ -44,7 +44,7 @@ func TestParseUUID_Invalid(t *testing.T) {
ctx.URLParams.Add(testParam, "wrong-id")
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
_, ok := parseUUID(rw, r, "workspaceagent")
_, ok := ParseUUIDParam(rw, r, "workspaceagent")
assert.False(t, ok, "UUID should not be parsed")
assert.Equal(t, http.StatusBadRequest, rw.Code)

View File

@ -39,7 +39,7 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
orgID, ok := parseUUID(rw, r, "organization")
orgID, ok := ParseUUIDParam(rw, r, "organization")
if !ok {
return
}

View File

@ -27,7 +27,7 @@ func ExtractTemplateParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
templateID, parsed := parseUUID(rw, r, "template")
templateID, parsed := ParseUUIDParam(rw, r, "template")
if !parsed {
return
}

View File

@ -29,7 +29,7 @@ func ExtractTemplateVersionParam(db database.Store) func(http.Handler) http.Hand
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
templateVersionID, parsed := parseUUID(rw, r, "templateversion")
templateVersionID, parsed := ParseUUIDParam(rw, r, "templateversion")
if !parsed {
return
}

View File

@ -29,7 +29,7 @@ func ExtractWorkspaceAgentParam(db database.Store) func(http.Handler) http.Handl
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
agentUUID, parsed := parseUUID(rw, r, "workspaceagent")
agentUUID, parsed := ParseUUIDParam(rw, r, "workspaceagent")
if !parsed {
return
}

View File

@ -27,7 +27,7 @@ func ExtractWorkspaceBuildParam(db database.Store) func(http.Handler) http.Handl
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceBuildID, parsed := parseUUID(rw, r, "workspacebuild")
workspaceBuildID, parsed := ParseUUIDParam(rw, r, "workspacebuild")
if !parsed {
return
}

View File

@ -30,7 +30,7 @@ func ExtractWorkspaceParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceID, parsed := parseUUID(rw, r, "workspace")
workspaceID, parsed := ParseUUIDParam(rw, r, "workspace")
if !parsed {
return
}

View File

@ -29,7 +29,7 @@ func ExtractWorkspaceResourceParam(db database.Store) func(http.Handler) http.Ha
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
resourceUUID, parsed := parseUUID(rw, r, "workspaceresource")
resourceUUID, parsed := ParseUUIDParam(rw, r, "workspaceresource")
if !parsed {
return
}

View File

@ -22,6 +22,7 @@ import (
"github.com/coder/coder/codersdk"
"github.com/coder/coder/site"
"github.com/coder/coder/tailnet"
"github.com/coder/retry"
)
var tailnetTransport *http.Transport
@ -41,7 +42,7 @@ func NewServerTailnet(
logger slog.Logger,
derpServer *derp.Server,
derpMap *tailcfg.DERPMap,
coord *atomic.Pointer[tailnet.Coordinator],
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
cache *wsconncache.Cache,
) (*ServerTailnet, error) {
logger = logger.Named("servertailnet")
@ -56,20 +57,23 @@ func NewServerTailnet(
serverCtx, cancel := context.WithCancel(ctx)
tn := &ServerTailnet{
ctx: serverCtx,
cancel: cancel,
logger: logger,
conn: conn,
coord: coord,
cache: cache,
agentNodes: map[uuid.UUID]time.Time{},
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
transport: tailnetTransport.Clone(),
ctx: serverCtx,
cancel: cancel,
logger: logger,
conn: conn,
getMultiAgent: getMultiAgent,
cache: cache,
agentNodes: map[uuid.UUID]time.Time{},
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
transport: tailnetTransport.Clone(),
}
tn.transport.DialContext = tn.dialContext
tn.transport.MaxIdleConnsPerHost = 10
tn.transport.MaxIdleConns = 0
agentConn := (*coord.Load()).ServeMultiAgent(uuid.New())
agentConn, err := getMultiAgent(ctx)
if err != nil {
return nil, xerrors.Errorf("get initial multi agent: %w", err)
}
tn.agentConn.Store(&agentConn)
err = tn.getAgentConn().UpdateSelf(conn.Node())
@ -86,19 +90,21 @@ func NewServerTailnet(
// This is set to allow local DERP traffic to be proxied through memory
// instead of needing to hit the external access URL. Don't use the ctx
// given in this callback, it's only valid while connecting.
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
if !region.EmbeddedRelay {
return nil
}
left, right := net.Pipe()
go func() {
defer left.Close()
defer right.Close()
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
derpServer.Accept(ctx, right, brw, "internal")
}()
return left
})
if derpServer != nil {
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
if !region.EmbeddedRelay {
return nil
}
left, right := net.Pipe()
go func() {
defer left.Close()
defer right.Close()
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
derpServer.Accept(ctx, right, brw, "internal")
}()
return left
})
}
go tn.watchAgentUpdates()
go tn.expireOldAgents()
@ -167,30 +173,38 @@ func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
}
func (s *ServerTailnet) reinitCoordinator() {
s.nodesMu.Lock()
agentConn := (*s.coord.Load()).ServeMultiAgent(uuid.New())
s.agentConn.Store(&agentConn)
// Resubscribe to all of the agents we're tracking.
for agentID := range s.agentNodes {
err := agentConn.SubscribeAgent(agentID)
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(s.ctx); {
s.nodesMu.Lock()
agentConn, err := s.getMultiAgent(s.ctx)
if err != nil {
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
s.nodesMu.Unlock()
s.logger.Error(s.ctx, "reinit multi agent", slog.Error(err))
continue
}
s.agentConn.Store(&agentConn)
// Resubscribe to all of the agents we're tracking.
for agentID := range s.agentNodes {
err := agentConn.SubscribeAgent(agentID)
if err != nil {
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
}
}
s.nodesMu.Unlock()
return
}
s.nodesMu.Unlock()
}
type ServerTailnet struct {
ctx context.Context
cancel func()
logger slog.Logger
conn *tailnet.Conn
coord *atomic.Pointer[tailnet.Coordinator]
agentConn atomic.Pointer[tailnet.MultiAgentConn]
cache *wsconncache.Cache
nodesMu sync.Mutex
logger slog.Logger
conn *tailnet.Conn
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
agentConn atomic.Pointer[tailnet.MultiAgentConn]
cache *wsconncache.Cache
nodesMu sync.Mutex
// agentNodes is a map of agent tailnetNodes the server wants to keep a
// connection to. It contains the last time the agent was connected to.
agentNodes map[uuid.UUID]time.Time

View File

@ -8,7 +8,6 @@ import (
"net/http/httptest"
"net/netip"
"net/url"
"sync/atomic"
"testing"
"github.com/google/uuid"
@ -133,9 +132,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
DERPMap: derpMap,
}
var coordPtr atomic.Pointer[tailnet.Coordinator]
coord := tailnet.NewCoordinator(logger)
coordPtr.Store(&coord)
t.Cleanup(func() {
_ = coord.Close()
})
@ -194,7 +191,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
logger,
derpServer,
manifest.DERPMap,
&coordPtr,
func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil },
cache,
)
require.NoError(t, err)

View File

@ -13,17 +13,14 @@ import (
"time"
"cloud.google.com/go/compute/metadata"
"github.com/google/uuid"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"tailscale.com/tailcfg"
"github.com/coder/retry"
"cdr.dev/slog"
"github.com/google/uuid"
"github.com/coder/coder/codersdk"
"github.com/coder/retry"
)
// New returns a client that is used to interact with the

View File

@ -6919,6 +6919,22 @@ _None_
| `username_or_id` | string | false | | For the following fields, if the AccessMethod is AccessMethodTerminal, then only AgentNameOrID may be set and it must be a UUID. The other fields must be left blank. |
| `workspace_name_or_id` | string | false | | |
## wsproxysdk.AgentIsLegacyResponse
```json
{
"found": true,
"legacy": true
}
```
### Properties
| Name | Type | Required | Restrictions | Description |
| -------- | ------- | -------- | ------------ | ----------- |
| `found` | boolean | false | | |
| `legacy` | boolean | false | | |
## wsproxysdk.IssueSignedAppTokenResponse
```json

View File

@ -10,12 +10,12 @@ import (
"strings"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
"github.com/google/uuid"
)
var jwtRegexp = regexp.MustCompile(`^[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+$`)

View File

@ -25,6 +25,7 @@ import (
"github.com/coder/coder/cli"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
@ -220,6 +221,7 @@ func (*RootCmd) proxyServer() *clibase.Cmd {
proxy, err := wsproxy.New(ctx, &wsproxy.Options{
Logger: logger,
Experiments: coderd.ReadExperiments(logger, cfg.Experiments.Value()),
HTTPClient: httpClient,
DashboardURL: primaryAccessURL.Value(),
AccessURL: cfg.AccessURL.Value(),

View File

@ -125,6 +125,15 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
r.Use(apiKeyMiddleware)
r.Post("/", api.reconnectingPTYSignedToken)
})
r.With(
apiKeyMiddlewareOptional,
httpmw.ExtractWorkspaceProxy(httpmw.ExtractWorkspaceProxyConfig{
DB: options.Database,
Optional: true,
}),
httpmw.RequireAPIKeyOrWorkspaceProxyAuth(),
).Get("/workspaceagents/{workspaceagent}/legacy", api.agentIsLegacy)
r.Route("/workspaceproxies", func(r chi.Router) {
r.Use(
api.moonsEnabledMW,
@ -143,6 +152,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
Optional: false,
}),
)
r.Get("/coordinate", api.workspaceProxyCoordinate)
r.Post("/issue-signed-app-token", api.workspaceProxyIssueSignedAppToken)
r.Post("/register", api.workspaceProxyRegister)
r.Post("/goingaway", api.workspaceProxyGoingAway)

View File

@ -25,7 +25,8 @@ import (
)
type ProxyOptions struct {
Name string
Name string
Experiments codersdk.Experiments
TLSCertificates []tls.Certificate
AppHostname string
@ -118,6 +119,7 @@ func NewWorkspaceProxy(t *testing.T, coderdAPI *coderd.API, owner *codersdk.Clie
wssrv, err := wsproxy.New(ctx, &wsproxy.Options{
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
Experiments: options.Experiments,
DashboardURL: coderdAPI.AccessURL,
AccessURL: accessURL,
AppHostname: options.AppHostname,

View File

@ -0,0 +1,78 @@
package coderd
import (
"net/http"
"github.com/google/uuid"
"nhooyr.io/websocket"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/tailnet"
"github.com/coder/coder/enterprise/wsproxy/wsproxysdk"
)
// @Summary Agent is legacy
// @ID agent-is-legacy
// @Security CoderSessionToken
// @Produce json
// @Tags Enterprise
// @Param workspaceagent path string true "Workspace Agent ID" format(uuid)
// @Success 200 {object} wsproxysdk.AgentIsLegacyResponse
// @Router /workspaceagents/{workspaceagent}/legacy [get]
// @x-apidocgen {"skip": true}
func (api *API) agentIsLegacy(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
agentID, ok := httpmw.ParseUUIDParam(rw, r, "workspaceagent")
if !ok {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing UUID in URL.",
})
return
}
node := (*api.AGPL.TailnetCoordinator.Load()).Node(agentID)
httpapi.Write(ctx, rw, http.StatusOK, wsproxysdk.AgentIsLegacyResponse{
Found: node != nil,
Legacy: node != nil &&
len(node.Addresses) > 0 &&
node.Addresses[0].Addr() == codersdk.WorkspaceAgentIP,
})
}
// @Summary Workspace Proxy Coordinate
// @ID workspace-proxy-coordinate
// @Security CoderSessionToken
// @Tags Enterprise
// @Success 101
// @Router /workspaceproxies/me/coordinate [get]
// @x-apidocgen {"skip": true}
func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
api.AGPL.WebsocketWaitMutex.Lock()
api.AGPL.WebsocketWaitGroup.Add(1)
api.AGPL.WebsocketWaitMutex.Unlock()
defer api.AGPL.WebsocketWaitGroup.Done()
conn, err := websocket.Accept(rw, r, nil)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
id := uuid.New()
sub := (*api.AGPL.TailnetCoordinator.Load()).ServeMultiAgent(id)
nc := websocket.NetConn(ctx, conn, websocket.MessageText)
defer nc.Close()
err = tailnet.ServeWorkspaceProxy(ctx, nc, sub)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, err.Error())
}
}

View File

@ -0,0 +1,158 @@
package coderd_test
import (
"context"
"net/netip"
"testing"
"time"
"github.com/google/uuid"
"github.com/moby/moby/pkg/namesgenerator"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/types/key"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/coderd/coderdenttest"
"github.com/coder/coder/enterprise/coderd/license"
"github.com/coder/coder/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/tailnet"
"github.com/coder/coder/testutil"
)
// workspaceProxyCoordinate and agentIsLegacy are both tested by wsproxy tests.
func Test_agentIsLegacy(t *testing.T) {
t.Parallel()
t.Run("Legacy", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{
string(codersdk.ExperimentMoons),
"*",
}
var (
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
db, pubsub = dbtestutil.NewDB(t)
logger = slogtest.Make(t, nil)
coordinator = agpl.NewCoordinator(logger)
client, _ = coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
DeploymentValues: dv,
Coordinator: coordinator,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureWorkspaceProxy: 1,
},
},
})
)
defer cancel()
nodeID := uuid.New()
ma := coordinator.ServeMultiAgent(nodeID)
defer ma.Close()
require.NoError(t, ma.UpdateSelf(&agpl.Node{
ID: 55,
AsOf: time.Unix(1689653252, 0),
Key: key.NewNode().Public(),
DiscoKey: key.NewDisco().Public(),
PreferredDERP: 0,
DERPLatency: map[string]float64{
"0": 1.0,
},
DERPForcedWebsocket: map[int]string{},
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
AllowedIPs: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
Endpoints: []string{"192.168.1.1:18842"},
}))
proxyRes, err := client.CreateWorkspaceProxy(ctx, codersdk.CreateWorkspaceProxyRequest{
Name: namesgenerator.GetRandomName(1),
Icon: "/emojis/flag.png",
})
require.NoError(t, err)
proxyClient := wsproxysdk.New(client.URL)
proxyClient.SetSessionToken(proxyRes.ProxyToken)
legacyRes, err := proxyClient.AgentIsLegacy(ctx, nodeID)
require.NoError(t, err)
assert.True(t, legacyRes.Found)
assert.True(t, legacyRes.Legacy)
})
t.Run("NotLegacy", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{
string(codersdk.ExperimentMoons),
"*",
}
var (
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
db, pubsub = dbtestutil.NewDB(t)
logger = slogtest.Make(t, nil)
coordinator = agpl.NewCoordinator(logger)
client, _ = coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
DeploymentValues: dv,
Coordinator: coordinator,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureWorkspaceProxy: 1,
},
},
})
)
defer cancel()
nodeID := uuid.New()
ma := coordinator.ServeMultiAgent(nodeID)
defer ma.Close()
require.NoError(t, ma.UpdateSelf(&agpl.Node{
ID: 55,
AsOf: time.Unix(1689653252, 0),
Key: key.NewNode().Public(),
DiscoKey: key.NewDisco().Public(),
PreferredDERP: 0,
DERPLatency: map[string]float64{
"0": 1.0,
},
DERPForcedWebsocket: map[int]string{},
Addresses: []netip.Prefix{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128)},
AllowedIPs: []netip.Prefix{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128)},
Endpoints: []string{"192.168.1.1:18842"},
}))
proxyRes, err := client.CreateWorkspaceProxy(ctx, codersdk.CreateWorkspaceProxyRequest{
Name: namesgenerator.GetRandomName(1),
Icon: "/emojis/flag.png",
})
require.NoError(t, err)
proxyClient := wsproxysdk.New(client.URL)
proxyClient.SetSessionToken(proxyRes.ProxyToken)
legacyRes, err := proxyClient.AgentIsLegacy(ctx, nodeID)
require.NoError(t, err)
assert.True(t, legacyRes.Found)
assert.False(t, legacyRes.Legacy)
})
}

View File

@ -56,7 +56,6 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err
func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
m := (&agpl.MultiAgent{
ID: id,
Logger: c.log,
AgentIsLegacyFunc: c.agentIsLegacy,
OnSubscribe: c.clientSubscribeToAgent,
OnNodeUpdate: c.clientNodeUpdate,

View File

@ -0,0 +1,95 @@
package tailnet
import (
"bytes"
"context"
"encoding/json"
"net"
"time"
"golang.org/x/xerrors"
"github.com/coder/coder/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/tailnet"
)
func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
go func() {
err := forwardNodesToWorkspaceProxy(ctx, conn, ma)
if err != nil {
_ = conn.Close()
}
}()
decoder := json.NewDecoder(conn)
for {
var msg wsproxysdk.CoordinateMessage
err := decoder.Decode(&msg)
if err != nil {
return xerrors.Errorf("read json: %w", err)
}
switch msg.Type {
case wsproxysdk.CoordinateMessageTypeSubscribe:
err := ma.SubscribeAgent(msg.AgentID)
if err != nil {
return xerrors.Errorf("subscribe agent: %w", err)
}
case wsproxysdk.CoordinateMessageTypeUnsubscribe:
err := ma.UnsubscribeAgent(msg.AgentID)
if err != nil {
return xerrors.Errorf("unsubscribe agent: %w", err)
}
case wsproxysdk.CoordinateMessageTypeNodeUpdate:
err := ma.UpdateSelf(msg.Node)
if err != nil {
return xerrors.Errorf("update self: %w", err)
}
default:
return xerrors.Errorf("unknown message type %q", msg.Type)
}
}
}
func forwardNodesToWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
var lastData []byte
for {
nodes, ok := ma.NextUpdate(ctx)
if !ok {
return xerrors.New("multiagent is closed")
}
data, err := json.Marshal(wsproxysdk.CoordinateNodes{Nodes: nodes})
if err != nil {
return err
}
if bytes.Equal(lastData, data) {
continue
}
// Set a deadline so that hung connections don't put back pressure on the system.
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
err = conn.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout))
if err != nil {
// often, this is just because the connection is closed/broken, so only log at debug.
return err
}
_, err = conn.Write(data)
if err != nil {
// often, this is just because the connection is closed/broken, so only log at debug.
return err
}
// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are
// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write()
// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context.
// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after
// our successful write, it is important that we reset the deadline before it fires.
err = conn.SetWriteDeadline(time.Time{})
if err != nil {
return err
}
lastData = data
}
}

View File

@ -27,10 +27,12 @@ import (
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/wsproxy/wsproxysdk"
"github.com/coder/coder/site"
agpl "github.com/coder/coder/tailnet"
)
type Options struct {
Logger slog.Logger
Logger slog.Logger
Experiments codersdk.Experiments
HTTPClient *http.Client
// DashboardURL is the URL of the primary coderd instance.
@ -168,6 +170,30 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
cancel: cancel,
}
connInfo, err := client.SDKClient.WorkspaceAgentConnectionInfo(ctx)
if err != nil {
return nil, xerrors.Errorf("get derpmap: %w", err)
}
var agentProvider workspaceapps.AgentProvider
if opts.Experiments.Enabled(codersdk.ExperimentSingleTailnet) {
stn, err := coderd.NewServerTailnet(ctx,
s.Logger.Named("server_tailnet"),
nil,
connInfo.DERPMap,
s.DialCoordinator,
wsconncache.New(s.DialWorkspaceAgent, 0),
)
if err != nil {
return nil, xerrors.Errorf("create server tailnet: %w", err)
}
agentProvider = stn
} else {
agentProvider = &wsconncache.AgentProvider{
Cache: wsconncache.New(s.DialWorkspaceAgent, 0),
}
}
s.AppServer = &workspaceapps.Server{
Logger: opts.Logger.Named("workspaceapps"),
DashboardURL: opts.DashboardURL,
@ -185,10 +211,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
},
AppSecurityKey: secKey,
// TODO: Convert wsproxy to use coderd.ServerTailnet.
AgentProvider: &wsconncache.AgentProvider{
Cache: wsconncache.New(s.DialWorkspaceAgent, 0),
},
AgentProvider: agentProvider,
DisablePathApps: opts.DisablePathApps,
SecureAuthCookie: opts.SecureAuthCookie,
}
@ -285,6 +308,10 @@ func (s *Server) DialWorkspaceAgent(id uuid.UUID) (*codersdk.WorkspaceAgentConn,
return s.SDKClient.DialWorkspaceAgent(s.ctx, id, nil)
}
func (s *Server) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, error) {
return s.SDKClient.DialCoordinator(ctx)
}
func (s *Server) buildInfo(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.BuildInfoResponse{
ExternalURL: buildinfo.ExternalURL(),

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/workspaceapps/apptest"
@ -13,7 +14,7 @@ import (
"github.com/coder/coder/enterprise/coderd/license"
)
func TestWorkspaceProxyWorkspaceApps(t *testing.T) {
func TestWorkspaceProxyWorkspaceApps_Wsconncache(t *testing.T) {
t.Parallel()
apptest.Run(t, false, func(t *testing.T, opts *apptest.DeploymentOptions) *apptest.Deployment {
@ -66,3 +67,59 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) {
}
})
}
func TestWorkspaceProxyWorkspaceApps_SingleTailnet(t *testing.T) {
t.Parallel()
apptest.Run(t, false, func(t *testing.T, opts *apptest.DeploymentOptions) *apptest.Deployment {
deploymentValues := coderdtest.DeploymentValues(t)
deploymentValues.DisablePathApps = clibase.Bool(opts.DisablePathApps)
deploymentValues.Dangerous.AllowPathAppSharing = clibase.Bool(opts.DangerousAllowPathAppSharing)
deploymentValues.Dangerous.AllowPathAppSiteOwnerAccess = clibase.Bool(opts.DangerousAllowPathAppSiteOwnerAccess)
deploymentValues.Experiments = []string{
string(codersdk.ExperimentMoons),
string(codersdk.ExperimentSingleTailnet),
"*",
}
client, _, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: deploymentValues,
AppHostname: "*.primary.test.coder.com",
IncludeProvisionerDaemon: true,
RealIPConfig: &httpmw.RealIPConfig{
TrustedOrigins: []*net.IPNet{{
IP: net.ParseIP("127.0.0.1"),
Mask: net.CIDRMask(8, 32),
}},
TrustedHeaders: []string{
"CF-Connecting-IP",
},
},
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureWorkspaceProxy: 1,
},
},
})
// Create the external proxy
if opts.DisableSubdomainApps {
opts.AppHost = ""
}
proxyAPI := coderdenttest.NewWorkspaceProxy(t, api, client, &coderdenttest.ProxyOptions{
Name: "best-proxy",
Experiments: coderd.ReadExperiments(api.Logger, deploymentValues.Experiments.Value()),
AppHostname: opts.AppHost,
DisablePathApps: opts.DisablePathApps,
})
return &apptest.Deployment{
Options: opts,
SDKClient: client,
FirstUser: user,
PathAppBaseURL: proxyAPI.Options.AccessURL,
}
})
}

View File

@ -3,16 +3,25 @@ package wsproxysdk
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"sync"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"tailscale.com/util/singleflight"
"cdr.dev/slog"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/workspaceapps"
"github.com/coder/coder/codersdk"
agpl "github.com/coder/coder/tailnet"
)
// Client is a HTTP client for a subset of Coder API routes that external
@ -186,3 +195,206 @@ func (c *Client) WorkspaceProxyGoingAway(ctx context.Context) error {
}
return nil
}
type CoordinateMessageType int
const (
CoordinateMessageTypeSubscribe CoordinateMessageType = 1 + iota
CoordinateMessageTypeUnsubscribe
CoordinateMessageTypeNodeUpdate
)
type CoordinateMessage struct {
Type CoordinateMessageType `json:"type"`
AgentID uuid.UUID `json:"agent_id"`
Node *agpl.Node `json:"node"`
}
type CoordinateNodes struct {
Nodes []*agpl.Node
}
func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, error) {
ctx, cancel := context.WithCancel(ctx)
coordinateURL, err := c.SDKClient.URL.Parse("/api/v2/workspaceproxies/me/coordinate")
if err != nil {
cancel()
return nil, xerrors.Errorf("parse url: %w", err)
}
coordinateHeaders := make(http.Header)
tokenHeader := codersdk.SessionTokenHeader
if c.SDKClient.SessionTokenHeader != "" {
tokenHeader = c.SDKClient.SessionTokenHeader
}
coordinateHeaders.Set(tokenHeader, c.SessionToken())
//nolint:bodyclose
conn, _, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
HTTPClient: c.SDKClient.HTTPClient,
HTTPHeader: coordinateHeaders,
})
if err != nil {
cancel()
return nil, xerrors.Errorf("dial coordinate websocket: %w", err)
}
go httpapi.HeartbeatClose(ctx, cancel, conn)
nc := websocket.NetConn(ctx, conn, websocket.MessageText)
rma := remoteMultiAgentHandler{
sdk: c,
nc: nc,
legacyAgentCache: map[uuid.UUID]bool{},
}
ma := (&agpl.MultiAgent{
ID: uuid.New(),
AgentIsLegacyFunc: rma.AgentIsLegacy,
OnSubscribe: rma.OnSubscribe,
OnUnsubscribe: rma.OnUnsubscribe,
OnNodeUpdate: rma.OnNodeUpdate,
OnRemove: func(uuid.UUID) { conn.Close(websocket.StatusGoingAway, "closed") },
}).Init()
go func() {
defer cancel()
dec := json.NewDecoder(nc)
for {
var msg CoordinateNodes
err := dec.Decode(&msg)
if err != nil {
if xerrors.Is(err, io.EOF) {
return
}
c.SDKClient.Logger().Error(ctx, "failed to decode coordinator nodes", slog.Error(err))
return
}
err = ma.Enqueue(msg.Nodes)
if err != nil {
c.SDKClient.Logger().Error(ctx, "enqueue nodes from coordinator", slog.Error(err))
continue
}
}
}()
return ma, nil
}
type remoteMultiAgentHandler struct {
sdk *Client
nc net.Conn
legacyMu sync.RWMutex
legacyAgentCache map[uuid.UUID]bool
legacySingleflight singleflight.Group[uuid.UUID, AgentIsLegacyResponse]
}
func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error {
data, err := json.Marshal(v)
if err != nil {
return xerrors.Errorf("json marshal message: %w", err)
}
// Set a deadline so that hung connections don't put back pressure on the system.
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
err = a.nc.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout))
if err != nil {
return xerrors.Errorf("set write deadline: %w", err)
}
_, err = a.nc.Write(data)
if err != nil {
return xerrors.Errorf("write message: %w", err)
}
// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are
// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write()
// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context.
// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after
// our successful write, it is important that we reset the deadline before it fires.
err = a.nc.SetWriteDeadline(time.Time{})
if err != nil {
return xerrors.Errorf("clear write deadline: %w", err)
}
return nil
}
func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *agpl.Node) error {
return a.writeJSON(CoordinateMessage{
Type: CoordinateMessageTypeNodeUpdate,
Node: node,
})
}
func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) {
return nil, a.writeJSON(CoordinateMessage{
Type: CoordinateMessageTypeSubscribe,
AgentID: agentID,
})
}
func (a *remoteMultiAgentHandler) OnUnsubscribe(_ agpl.Queue, agentID uuid.UUID) error {
return a.writeJSON(CoordinateMessage{
Type: CoordinateMessageTypeUnsubscribe,
AgentID: agentID,
})
}
func (a *remoteMultiAgentHandler) AgentIsLegacy(agentID uuid.UUID) bool {
a.legacyMu.RLock()
if isLegacy, ok := a.legacyAgentCache[agentID]; ok {
a.legacyMu.RUnlock()
return isLegacy
}
a.legacyMu.RUnlock()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
resp, err, _ := a.legacySingleflight.Do(agentID, func() (AgentIsLegacyResponse, error) {
return a.sdk.AgentIsLegacy(ctx, agentID)
})
if err != nil {
a.sdk.SDKClient.Logger().Error(ctx, "failed to check agent legacy status", slog.Error(err))
// Assume that the agent is legacy since this failed, while less
// efficient it will always work.
return true
}
// Assume legacy since the agent didn't exist.
if !resp.Found {
return true
}
a.legacyMu.Lock()
a.legacyAgentCache[agentID] = resp.Legacy
a.legacyMu.Unlock()
return resp.Legacy
}
type AgentIsLegacyResponse struct {
Found bool `json:"found"`
Legacy bool `json:"legacy"`
}
func (c *Client) AgentIsLegacy(ctx context.Context, agentID uuid.UUID) (AgentIsLegacyResponse, error) {
res, err := c.Request(ctx, http.MethodGet,
fmt.Sprintf("/api/v2/workspaceagents/%s/legacy", agentID.String()),
nil,
)
if err != nil {
return AgentIsLegacyResponse{}, xerrors.Errorf("make request: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return AgentIsLegacyResponse{}, codersdk.ReadBodyAsError(res)
}
var resp AgentIsLegacyResponse
return resp, json.NewDecoder(res.Body).Decode(&resp)
}

View File

@ -1,21 +1,36 @@
package wsproxysdk_test
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/netip"
"net/url"
"sync/atomic"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"tailscale.com/types/key"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/workspaceapps"
"github.com/coder/coder/enterprise/tailnet"
"github.com/coder/coder/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/tailnet"
"github.com/coder/coder/tailnet/tailnettest"
"github.com/coder/coder/testutil"
)
@ -136,6 +151,135 @@ func Test_IssueSignedAppTokenHTML(t *testing.T) {
})
}
func TestDialCoordinator(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
var (
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
agentID = uuid.New()
serverMultiAgent = tailnettest.NewMockMultiAgentConn(gomock.NewController(t))
r = chi.NewRouter()
srv = httptest.NewServer(r)
)
defer cancel()
r.Get("/api/v2/workspaceproxies/me/coordinate", func(w http.ResponseWriter, r *http.Request) {
conn, err := websocket.Accept(w, r, nil)
require.NoError(t, err)
nc := websocket.NetConn(r.Context(), conn, websocket.MessageText)
defer serverMultiAgent.Close()
err = tailnet.ServeWorkspaceProxy(ctx, nc, serverMultiAgent)
if !xerrors.Is(err, io.EOF) {
assert.NoError(t, err)
}
})
r.Get("/api/v2/workspaceagents/{workspaceagent}/legacy", func(w http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, w, http.StatusOK, wsproxysdk.AgentIsLegacyResponse{
Found: true,
Legacy: true,
})
})
u, err := url.Parse(srv.URL)
require.NoError(t, err)
client := wsproxysdk.New(u)
client.SDKClient.SetLogger(logger)
expected := []*agpl.Node{{
ID: 55,
AsOf: time.Unix(1689653252, 0),
Key: key.NewNode().Public(),
DiscoKey: key.NewDisco().Public(),
PreferredDERP: 0,
DERPLatency: map[string]float64{
"0": 1.0,
},
DERPForcedWebsocket: map[int]string{},
Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128)},
AllowedIPs: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128)},
Endpoints: []string{"192.168.1.1:18842"},
}}
sendNode := make(chan struct{})
serverMultiAgent.EXPECT().NextUpdate(gomock.Any()).AnyTimes().
DoAndReturn(func(ctx context.Context) ([]*agpl.Node, bool) {
select {
case <-sendNode:
return expected, true
case <-ctx.Done():
return nil, false
}
})
rma, err := client.DialCoordinator(ctx)
require.NoError(t, err)
// Subscribe
{
ch := make(chan struct{})
serverMultiAgent.EXPECT().SubscribeAgent(agentID).Do(func(uuid.UUID) {
close(ch)
})
require.NoError(t, rma.SubscribeAgent(agentID))
waitOrCancel(ctx, t, ch)
}
// Read updated agent node
{
sendNode <- struct{}{}
got, ok := rma.NextUpdate(ctx)
assert.True(t, ok)
got[0].AsOf = got[0].AsOf.In(time.Local)
assert.Equal(t, *expected[0], *got[0])
}
// Check legacy
{
isLegacy := rma.AgentIsLegacy(agentID)
assert.True(t, isLegacy)
}
// UpdateSelf
{
ch := make(chan struct{})
serverMultiAgent.EXPECT().UpdateSelf(gomock.Any()).Do(func(node *agpl.Node) {
node.AsOf = node.AsOf.In(time.Local)
assert.Equal(t, expected[0], node)
close(ch)
})
require.NoError(t, rma.UpdateSelf(expected[0]))
waitOrCancel(ctx, t, ch)
}
// Unsubscribe
{
ch := make(chan struct{})
serverMultiAgent.EXPECT().UnsubscribeAgent(agentID).Do(func(uuid.UUID) {
close(ch)
})
require.NoError(t, rma.UnsubscribeAgent(agentID))
waitOrCancel(ctx, t, ch)
}
// Close
{
ch := make(chan struct{})
serverMultiAgent.EXPECT().Close().Do(func() {
close(ch)
})
require.NoError(t, rma.Close())
waitOrCancel(ctx, t, ch)
}
})
}
func waitOrCancel(ctx context.Context, t testing.TB, ch <-chan struct{}) {
t.Helper()
select {
case <-ch:
case <-ctx.Done():
t.Fatal("timed out waiting for channel")
}
}
type ResponseRecorder struct {
rw *httptest.ResponseRecorder
wasWritten atomic.Bool

2
go.mod
View File

@ -135,6 +135,7 @@ require (
github.com/mitchellh/go-wordwrap v1.0.1
github.com/mitchellh/mapstructure v1.5.0
github.com/moby/moby v24.0.1+incompatible
github.com/muesli/termenv v0.15.1
github.com/open-policy-agent/opa v0.51.0
github.com/ory/dockertest/v3 v3.10.0
github.com/pion/udp v0.1.2
@ -305,7 +306,6 @@ require (
github.com/muesli/ansi v0.0.0-20221106050444-61f0cd9a192a // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.15.1
github.com/niklasfasching/go-org v1.7.0 // indirect
github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect

View File

@ -140,7 +140,6 @@ type coordinator struct {
func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn {
m := (&MultiAgent{
ID: id,
Logger: c.core.logger,
AgentIsLegacyFunc: c.core.agentIsLegacy,
OnSubscribe: c.core.clientSubscribeToAgent,
OnUnsubscribe: c.core.clientUnsubscribeFromAgent,

View File

@ -8,8 +8,6 @@ import (
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
)
type MultiAgentConn interface {
@ -25,10 +23,7 @@ type MultiAgentConn interface {
type MultiAgent struct {
mu sync.RWMutex
closed bool
ID uuid.UUID
Logger slog.Logger
ID uuid.UUID
AgentIsLegacyFunc func(agentID uuid.UUID) bool
OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error)
@ -36,6 +31,7 @@ type MultiAgent struct {
OnNodeUpdate func(id uuid.UUID, node *Node) error
OnRemove func(id uuid.UUID)
closed bool
updates chan []*Node
closeOnce sync.Once
start int64

View File

@ -0,0 +1,150 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/tailnet/tailnettest (interfaces: MultiAgentConn)
// Package tailnettest is a generated GoMock package.
package tailnettest
import (
context "context"
reflect "reflect"
tailnet "github.com/coder/coder/tailnet"
gomock "github.com/golang/mock/gomock"
uuid "github.com/google/uuid"
)
// MockMultiAgentConn is a mock of MultiAgentConn interface.
type MockMultiAgentConn struct {
ctrl *gomock.Controller
recorder *MockMultiAgentConnMockRecorder
}
// MockMultiAgentConnMockRecorder is the mock recorder for MockMultiAgentConn.
type MockMultiAgentConnMockRecorder struct {
mock *MockMultiAgentConn
}
// NewMockMultiAgentConn creates a new mock instance.
func NewMockMultiAgentConn(ctrl *gomock.Controller) *MockMultiAgentConn {
mock := &MockMultiAgentConn{ctrl: ctrl}
mock.recorder = &MockMultiAgentConnMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockMultiAgentConn) EXPECT() *MockMultiAgentConnMockRecorder {
return m.recorder
}
// AgentIsLegacy mocks base method.
func (m *MockMultiAgentConn) AgentIsLegacy(arg0 uuid.UUID) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AgentIsLegacy", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// AgentIsLegacy indicates an expected call of AgentIsLegacy.
func (mr *MockMultiAgentConnMockRecorder) AgentIsLegacy(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AgentIsLegacy", reflect.TypeOf((*MockMultiAgentConn)(nil).AgentIsLegacy), arg0)
}
// Close mocks base method.
func (m *MockMultiAgentConn) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockMultiAgentConnMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMultiAgentConn)(nil).Close))
}
// Enqueue mocks base method.
func (m *MockMultiAgentConn) Enqueue(arg0 []*tailnet.Node) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Enqueue", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Enqueue indicates an expected call of Enqueue.
func (mr *MockMultiAgentConnMockRecorder) Enqueue(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Enqueue", reflect.TypeOf((*MockMultiAgentConn)(nil).Enqueue), arg0)
}
// IsClosed mocks base method.
func (m *MockMultiAgentConn) IsClosed() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClosed")
ret0, _ := ret[0].(bool)
return ret0
}
// IsClosed indicates an expected call of IsClosed.
func (mr *MockMultiAgentConnMockRecorder) IsClosed() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockMultiAgentConn)(nil).IsClosed))
}
// NextUpdate mocks base method.
func (m *MockMultiAgentConn) NextUpdate(arg0 context.Context) ([]*tailnet.Node, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NextUpdate", arg0)
ret0, _ := ret[0].([]*tailnet.Node)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// NextUpdate indicates an expected call of NextUpdate.
func (mr *MockMultiAgentConnMockRecorder) NextUpdate(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextUpdate", reflect.TypeOf((*MockMultiAgentConn)(nil).NextUpdate), arg0)
}
// SubscribeAgent mocks base method.
func (m *MockMultiAgentConn) SubscribeAgent(arg0 uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SubscribeAgent", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SubscribeAgent indicates an expected call of SubscribeAgent.
func (mr *MockMultiAgentConnMockRecorder) SubscribeAgent(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).SubscribeAgent), arg0)
}
// UnsubscribeAgent mocks base method.
func (m *MockMultiAgentConn) UnsubscribeAgent(arg0 uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnsubscribeAgent", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// UnsubscribeAgent indicates an expected call of UnsubscribeAgent.
func (mr *MockMultiAgentConnMockRecorder) UnsubscribeAgent(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).UnsubscribeAgent), arg0)
}
// UpdateSelf mocks base method.
func (m *MockMultiAgentConn) UpdateSelf(arg0 *tailnet.Node) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateSelf", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateSelf indicates an expected call of UpdateSelf.
func (mr *MockMultiAgentConnMockRecorder) UpdateSelf(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSelf", reflect.TypeOf((*MockMultiAgentConn)(nil).UpdateSelf), arg0)
}

View File

@ -21,6 +21,8 @@ import (
"github.com/coder/coder/tailnet"
)
//go:generate mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/tailnet MultiAgentConn
// RunDERPAndSTUN creates a DERP mapping for tests.
func RunDERPAndSTUN(t *testing.T) (*tailcfg.DERPMap, *derp.Server) {
logf := tailnet.Logger(slogtest.Make(t, nil))