diff --git a/Makefile b/Makefile index f1da02f2f5..d771fb0233 100644 --- a/Makefile +++ b/Makefile @@ -475,7 +475,8 @@ gen: \ site/.eslintignore \ site/e2e/provisionerGenerated.ts \ site/src/theme/icons.json \ - examples/examples.gen.json + examples/examples.gen.json \ + tailnet/tailnettest/coordinatormock.go .PHONY: gen # Mark all generated files as fresh so make thinks they're up-to-date. This is @@ -502,6 +503,7 @@ gen/mark-fresh: site/e2e/provisionerGenerated.ts \ site/src/theme/icons.json \ examples/examples.gen.json \ + tailnet/tailnettest/coordinatormock.go \ " for file in $$files; do echo "$$file" @@ -529,6 +531,9 @@ coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $ coderd/database/dbmock/dbmock.go: coderd/database/db.go coderd/database/querier.go go generate ./coderd/database/dbmock/ +tailnet/tailnettest/coordinatormock.go: tailnet/coordinator.go + go generate ./tailnet/tailnettest/ + tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto protoc \ --go_out=. \ diff --git a/agent/agent.go b/agent/agent.go index 514e10a7af..25e24215d9 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -30,6 +30,7 @@ import ( "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" + "storj.io/drpc" "tailscale.com/net/speedtest" "tailscale.com/tailcfg" "tailscale.com/types/netlogtype" @@ -47,6 +48,7 @@ import ( "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/tailnet" + tailnetproto "github.com/coder/coder/v2/tailnet/proto" ) const ( @@ -86,7 +88,7 @@ type Options struct { type Client interface { Manifest(ctx context.Context) (agentsdk.Manifest, error) - Listen(ctx context.Context) (net.Conn, error) + Listen(ctx context.Context) (drpc.Conn, error) DERPMapUpdates(ctx context.Context) (<-chan agentsdk.DERPMapUpdate, io.Closer, error) ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error @@ -1058,20 +1060,34 @@ func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error ctx, cancel := context.WithCancel(ctx) defer cancel() - coordinator, err := a.client.Listen(ctx) + conn, err := a.client.Listen(ctx) if err != nil { return err } - defer coordinator.Close() + defer func() { + cErr := conn.Close() + if cErr != nil { + a.logger.Debug(ctx, "error closing drpc connection", slog.Error(err)) + } + }() + + tClient := tailnetproto.NewDRPCTailnetClient(conn) + coordinate, err := tClient.Coordinate(ctx) + if err != nil { + return xerrors.Errorf("failed to connect to the coordinate endpoint: %w", err) + } + defer func() { + cErr := coordinate.Close() + if cErr != nil { + a.logger.Debug(ctx, "error closing Coordinate client", slog.Error(err)) + } + }() a.logger.Info(ctx, "connected to coordination endpoint") - sendNodes, errChan := tailnet.ServeCoordinator(coordinator, func(nodes []*tailnet.Node) error { - return network.UpdateNodes(nodes, false) - }) - network.SetNodeCallback(sendNodes) + coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil) select { case <-ctx.Done(): return ctx.Err() - case err := <-errChan: + case err := <-coordination.Error(): return err } } diff --git a/agent/agent_test.go b/agent/agent_test.go index f884918c83..163c64b788 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -1664,9 +1664,11 @@ func TestAgent_UpdatedDERP(t *testing.T) { require.NotNil(t, 
originalDerpMap) coordinator := tailnet.NewCoordinator(logger) - defer func() { + // use t.Cleanup so the coordinator closing doesn't deadlock with in-memory + // coordination + t.Cleanup(func() { _ = coordinator.Close() - }() + }) agentID := uuid.New() statsCh := make(chan *agentsdk.Stats, 50) fs := afero.NewMemMapFs() @@ -1681,41 +1683,42 @@ func TestAgent_UpdatedDERP(t *testing.T) { statsCh, coordinator, ) - closer := agent.New(agent.Options{ + uut := agent.New(agent.Options{ Client: client, Filesystem: fs, Logger: logger.Named("agent"), ReconnectingPTYTimeout: time.Minute, }) - defer func() { - _ = closer.Close() - }() + t.Cleanup(func() { + _ = uut.Close() + }) // Setup a client connection. - newClientConn := func(derpMap *tailcfg.DERPMap) *codersdk.WorkspaceAgentConn { + newClientConn := func(derpMap *tailcfg.DERPMap, name string) *codersdk.WorkspaceAgentConn { conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, DERPMap: derpMap, - Logger: logger.Named("client"), + Logger: logger.Named(name), }) require.NoError(t, err) - clientConn, serverConn := net.Pipe() - serveClientDone := make(chan struct{}) t.Cleanup(func() { - _ = clientConn.Close() - _ = serverConn.Close() + t.Logf("closing conn %s", name) _ = conn.Close() - <-serveClientDone }) - go func() { - defer close(serveClientDone) - err := coordinator.ServeClient(serverConn, uuid.New(), agentID) - assert.NoError(t, err) - }() - sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error { - return conn.UpdateNodes(nodes, false) + testCtx, testCtxCancel := context.WithCancel(context.Background()) + t.Cleanup(testCtxCancel) + clientID := uuid.New() + coordination := tailnet.NewInMemoryCoordination( + testCtx, logger, + clientID, agentID, + coordinator, conn) + t.Cleanup(func() { + t.Logf("closing coordination %s", name) + err := coordination.Close() + if err != nil { + t.Logf("error closing in-memory coordination: %s", err.Error()) + } }) - conn.SetNodeCallback(sendNode) // Force DERP. conn.SetBlockEndpoints(true) @@ -1724,6 +1727,7 @@ func TestAgent_UpdatedDERP(t *testing.T) { CloseFunc: func() error { return codersdk.ErrSkipClose }, }) t.Cleanup(func() { + t.Logf("closing sdkConn %s", name) _ = sdkConn.Close() }) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) @@ -1734,7 +1738,7 @@ func TestAgent_UpdatedDERP(t *testing.T) { return sdkConn } - conn1 := newClientConn(originalDerpMap) + conn1 := newClientConn(originalDerpMap, "client1") // Change the DERP map. newDerpMap, _ := tailnettest.RunDERPAndSTUN(t) @@ -1753,27 +1757,34 @@ func TestAgent_UpdatedDERP(t *testing.T) { DERPMap: newDerpMap, }) require.NoError(t, err) + t.Logf("client Pushed DERPMap update") require.Eventually(t, func() bool { - conn := closer.TailnetConn() + conn := uut.TailnetConn() if conn == nil { return false } regionIDs := conn.DERPMap().RegionIDs() - return len(regionIDs) == 1 && regionIDs[0] == 2 && conn.Node().PreferredDERP == 2 + preferredDERP := conn.Node().PreferredDERP + t.Logf("agent Conn DERPMap with regionIDs %v, PreferredDERP %d", regionIDs, preferredDERP) + return len(regionIDs) == 1 && regionIDs[0] == 2 && preferredDERP == 2 }, testutil.WaitLong, testutil.IntervalFast) + t.Logf("agent got the new DERPMap") // Connect from a second client and make sure it uses the new DERP map. 
- conn2 := newClientConn(newDerpMap) + conn2 := newClientConn(newDerpMap, "client2") require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs()) + t.Log("conn2 got the new DERPMap") // If the first client gets a DERP map update, it should be able to // reconnect just fine. conn1.SetDERPMap(newDerpMap) require.Equal(t, []int{2}, conn1.DERPMap().RegionIDs()) + t.Log("set the new DERPMap on conn1") ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() require.True(t, conn1.AwaitReachable(ctx)) + t.Log("conn1 reached agent with new DERP") } func TestAgent_Speedtest(t *testing.T) { @@ -2050,22 +2061,22 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati Logger: logger.Named("client"), }) require.NoError(t, err) - clientConn, serverConn := net.Pipe() - serveClientDone := make(chan struct{}) t.Cleanup(func() { - _ = clientConn.Close() - _ = serverConn.Close() _ = conn.Close() - <-serveClientDone }) - go func() { - defer close(serveClientDone) - coordinator.ServeClient(serverConn, uuid.New(), metadata.AgentID) - }() - sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error { - return conn.UpdateNodes(nodes, false) + testCtx, testCtxCancel := context.WithCancel(context.Background()) + t.Cleanup(testCtxCancel) + clientID := uuid.New() + coordination := tailnet.NewInMemoryCoordination( + testCtx, logger, + clientID, metadata.AgentID, + coordinator, conn) + t.Cleanup(func() { + err := coordination.Close() + if err != nil { + t.Logf("error closing in-mem coordination: %s", err.Error()) + } }) - conn.SetNodeCallback(sendNode) agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ AgentID: metadata.AgentID, }) diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go index c63962f0e4..ddea2d749e 100644 --- a/agent/agenttest/client.go +++ b/agent/agenttest/client.go @@ -3,19 +3,26 @@ package agenttest import ( "context" "io" - "net" "sync" + "sync/atomic" "testing" "time" "github.com/google/uuid" + "github.com/stretchr/testify/require" "golang.org/x/exp/maps" "golang.org/x/xerrors" + "storj.io/drpc" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + "tailscale.com/tailcfg" "cdr.dev/slog" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" + drpcsdk "github.com/coder/coder/v2/codersdk/drpc" "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/testutil" ) @@ -24,11 +31,31 @@ func NewClient(t testing.TB, agentID uuid.UUID, manifest agentsdk.Manifest, statsChan chan *agentsdk.Stats, - coordinator tailnet.CoordinatorV1, + coordinator tailnet.Coordinator, ) *Client { if manifest.AgentID == uuid.Nil { manifest.AgentID = agentID } + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coordinator) + mux := drpcmux.New() + drpcService := &tailnet.DRPCService{ + CoordPtr: &coordPtr, + Logger: logger, + // TODO: handle DERPMap too! 
+ DerpMapUpdateFrequency: time.Hour, + DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") }, + } + err := proto.DRPCRegisterTailnet(mux, drpcService) + require.NoError(t, err) + server := drpcserver.NewWithOptions(mux, drpcserver.Options{ + Log: func(err error) { + if xerrors.Is(err, io.EOF) { + return + } + logger.Debug(context.Background(), "drpc server error", slog.Error(err)) + }, + }) return &Client{ t: t, logger: logger.Named("client"), @@ -36,6 +63,7 @@ func NewClient(t testing.TB, manifest: manifest, statsChan: statsChan, coordinator: coordinator, + server: server, derpMapUpdates: make(chan agentsdk.DERPMapUpdate), } } @@ -47,7 +75,8 @@ type Client struct { manifest agentsdk.Manifest metadata map[string]agentsdk.Metadata statsChan chan *agentsdk.Stats - coordinator tailnet.CoordinatorV1 + coordinator tailnet.Coordinator + server *drpcserver.Server LastWorkspaceAgent func() PatchWorkspaceLogs func() error GetServiceBannerFunc func() (codersdk.ServiceBannerConfig, error) @@ -63,20 +92,29 @@ func (c *Client) Manifest(_ context.Context) (agentsdk.Manifest, error) { return c.manifest, nil } -func (c *Client) Listen(_ context.Context) (net.Conn, error) { - clientConn, serverConn := net.Pipe() +func (c *Client) Listen(_ context.Context) (drpc.Conn, error) { + conn, lis := drpcsdk.MemTransportPipe() closed := make(chan struct{}) c.LastWorkspaceAgent = func() { - _ = serverConn.Close() - _ = clientConn.Close() + _ = conn.Close() + _ = lis.Close() <-closed } c.t.Cleanup(c.LastWorkspaceAgent) + serveCtx, cancel := context.WithCancel(context.Background()) + c.t.Cleanup(cancel) + auth := tailnet.AgentTunnelAuth{} + streamID := tailnet.StreamID{ + Name: "agenttest", + ID: c.agentID, + Auth: auth, + } + serveCtx = tailnet.WithStreamID(serveCtx, streamID) go func() { - _ = c.coordinator.ServeAgent(serverConn, c.agentID, "") + _ = c.server.Serve(serveCtx, lis) close(closed) }() - return clientConn, nil + return conn, nil } func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) { diff --git a/coderd/coderd_test.go b/coderd/coderd_test.go index 8d7c129746..4c98feffb7 100644 --- a/coderd/coderd_test.go +++ b/coderd/coderd_test.go @@ -33,6 +33,7 @@ import ( "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisioner/echo" "github.com/coder/coder/v2/tailnet" + tailnetproto "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/testutil" ) @@ -98,14 +99,32 @@ func TestDERP(t *testing.T) { w2Ready := make(chan struct{}) w2ReadyOnce := sync.Once{} + w1ID := uuid.New() w1.SetNodeCallback(func(node *tailnet.Node) { - w2.UpdateNodes([]*tailnet.Node{node}, false) + pn, err := tailnet.NodeToProto(node) + if !assert.NoError(t, err) { + return + } + w2.UpdatePeers([]*tailnetproto.CoordinateResponse_PeerUpdate{{ + Id: w1ID[:], + Node: pn, + Kind: tailnetproto.CoordinateResponse_PeerUpdate_NODE, + }}) w2ReadyOnce.Do(func() { close(w2Ready) }) }) + w2ID := uuid.New() w2.SetNodeCallback(func(node *tailnet.Node) { - w1.UpdateNodes([]*tailnet.Node{node}, false) + pn, err := tailnet.NodeToProto(node) + if !assert.NoError(t, err) { + return + } + w1.UpdatePeers([]*tailnetproto.CoordinateResponse_PeerUpdate{{ + Id: w2ID[:], + Node: pn, + Kind: tailnetproto.CoordinateResponse_PeerUpdate_NODE, + }}) }) conn := make(chan struct{}) @@ -199,7 +218,11 @@ func TestDERPForceWebSockets(t *testing.T) { defer cancel() resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) - conn, err 
:= client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) + conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, + &codersdk.DialWorkspaceAgentOptions{ + Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug).Named("client"), + }, + ) require.NoError(t, err) defer func() { _ = conn.Close() diff --git a/coderd/tailnet.go b/coderd/tailnet.go index 6521d79149..5f3300711a 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -121,12 +121,23 @@ func NewServerTailnet( } tn.agentConn.Store(&agentConn) - err = tn.getAgentConn().UpdateSelf(conn.Node()) + pn, err := tailnet.NodeToProto(conn.Node()) if err != nil { - tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err)) + tn.logger.Critical(context.Background(), "failed to convert node", slog.Error(err)) + } else { + err = tn.getAgentConn().UpdateSelf(pn) + if err != nil { + tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err)) + } } + conn.SetNodeCallback(func(node *tailnet.Node) { - err := tn.getAgentConn().UpdateSelf(node) + pn, err := tailnet.NodeToProto(node) + if err != nil { + tn.logger.Critical(context.Background(), "failed to convert node", slog.Error(err)) + return + } + err = tn.getAgentConn().UpdateSelf(pn) if err != nil { tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err)) } @@ -191,21 +202,9 @@ func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) { // If no one has connected since the cutoff and there are no active // connections, remove the agent. if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 { - deleted, err := s.conn.RemovePeer(tailnet.PeerSelector{ - ID: tailnet.NodeID(agentID), - IP: netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128), - }) - if err != nil { - s.logger.Warn(ctx, "failed to remove peer from server tailnet", slog.Error(err)) - continue - } - if !deleted { - s.logger.Warn(ctx, "peer didn't exist in tailnet", slog.Error(err)) - } - deletedCount++ delete(s.agentConnectionTimes, agentID) - err = agentConn.UnsubscribeAgent(agentID) + err := agentConn.UnsubscribeAgent(agentID) if err != nil { s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID)) } @@ -221,7 +220,7 @@ func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) { func (s *ServerTailnet) watchAgentUpdates() { for { conn := s.getAgentConn() - nodes, ok := conn.NextUpdate(s.ctx) + resp, ok := conn.NextUpdate(s.ctx) if !ok { if conn.IsClosed() && s.ctx.Err() == nil { s.logger.Warn(s.ctx, "multiagent closed, reinitializing") @@ -231,7 +230,7 @@ func (s *ServerTailnet) watchAgentUpdates() { return } - err := s.conn.UpdateNodes(nodes, false) + err := s.conn.UpdatePeers(resp.GetPeerUpdates()) if err != nil { if xerrors.Is(err, tailnet.ErrConnClosed) { s.logger.Warn(context.Background(), "tailnet conn closed, exiting watchAgentUpdates", slog.Error(err)) diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go index 2a0b0dfdba..392bc8d306 100644 --- a/coderd/tailnet_test.go +++ b/coderd/tailnet_test.go @@ -3,7 +3,6 @@ package coderd_test import ( "context" "fmt" - "net" "net/http" "net/http/httptest" "net/netip" @@ -204,22 +203,20 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A Logger: logger.Named("client"), }) require.NoError(t, err) - clientConn, serverConn := net.Pipe() - serveClientDone := make(chan struct{}) t.Cleanup(func() { - _ = clientConn.Close() - _ = serverConn.Close() _ = conn.Close() - <-serveClientDone }) 
- go func() { - defer close(serveClientDone) - coord.ServeClient(serverConn, uuid.New(), manifest.AgentID) - }() - sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { - return conn.UpdateNodes(node, false) + clientID := uuid.New() + testCtx, testCtxCancel := context.WithCancel(context.Background()) + t.Cleanup(testCtxCancel) + coordination := tailnet.NewInMemoryCoordination( + testCtx, logger, + clientID, manifest.AgentID, + coord, conn, + ) + t.Cleanup(func() { + _ = coordination.Close() }) - conn.SetNodeCallback(sendNode) return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ AgentID: manifest.AgentID, AgentIP: codersdk.WorkspaceAgentIP, diff --git a/coderd/util/apiversion/apiversion.go b/coderd/util/apiversion/apiversion.go index 7decaeab32..225fe01785 100644 --- a/coderd/util/apiversion/apiversion.go +++ b/coderd/util/apiversion/apiversion.go @@ -30,6 +30,10 @@ func (v *APIVersion) WithBackwardCompat(majs ...int) *APIVersion { return v } +func (v *APIVersion) String() string { + return fmt.Sprintf("%d.%d", v.supportedMajor, v.supportedMinor) +} + // Validate validates the given version against the given constraints: // A given major.minor version is valid iff: // 1. The requested major version is contained within v.supportedMajors @@ -42,10 +46,6 @@ func (v *APIVersion) WithBackwardCompat(majs ...int) *APIVersion { // - 1.x is supported, // - 2.0, 2.1, and 2.2 are supported, // - 2.3+ is not supported. -func (v *APIVersion) String() string { - return fmt.Sprintf("%d.%d", v.supportedMajor, v.supportedMinor) -} - func (v *APIVersion) Validate(version string) error { major, minor, err := Parse(version) if err != nil { diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 1e48ea0e7a..ad508eebed 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -857,8 +857,6 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req // Deprecated: use api.tailnet.AgentConn instead. // See: https://github.com/coder/coder/issues/8218 func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { - clientConn, serverConn := net.Pipe() - derpMap := api.DERPMap() conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, @@ -868,8 +866,6 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa BlockEndpoints: api.DeploymentValues.DERP.Config.BlockDirect.Value(), }) if err != nil { - _ = clientConn.Close() - _ = serverConn.Close() return nil, xerrors.Errorf("create tailnet conn: %w", err) } ctx, cancel := context.WithCancel(api.ctx) @@ -887,10 +883,10 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa return left }) - sendNodes, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error { - return conn.UpdateNodes(nodes, true) - }) - conn.SetNodeCallback(sendNodes) + clientID := uuid.New() + coordination := tailnet.NewInMemoryCoordination(ctx, api.Logger, + clientID, agentID, + *(api.TailnetCoordinator.Load()), conn) // Check for updated DERP map every 5 seconds. 
go func() { @@ -920,27 +916,13 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa AgentID: agentID, AgentIP: codersdk.WorkspaceAgentIP, CloseFunc: func() error { + _ = coordination.Close() cancel() - _ = clientConn.Close() - _ = serverConn.Close() return nil }, }) - go func() { - err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID) - if err != nil { - // Sometimes, we get benign closed pipe errors when the server is - // shutting down. - if api.ctx.Err() == nil { - api.Logger.Warn(ctx, "tailnet coordinator client error", slog.Error(err)) - } - _ = agentConn.Close() - } - }() if !agentConn.AwaitReachable(ctx) { _ = agentConn.Close() - _ = serverConn.Close() - _ = clientConn.Close() cancel() return nil, xerrors.Errorf("agent not reachable") } diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 0d620c991e..9d5fd8da1b 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -535,7 +535,6 @@ func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) { }) require.NoError(t, err) defer conn.Close() - require.True(t, conn.BlockEndpoints()) require.True(t, conn.AwaitReachable(ctx)) _, p2p, _, err := conn.Ping(ctx) diff --git a/coderd/wsconncache/wsconncache_test.go b/coderd/wsconncache/wsconncache_test.go index c824159a81..8a66e3ba03 100644 --- a/coderd/wsconncache/wsconncache_test.go +++ b/coderd/wsconncache/wsconncache_test.go @@ -12,14 +12,19 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "testing" "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/atomic" "go.uber.org/goleak" + "golang.org/x/xerrors" + "storj.io/drpc" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + "tailscale.com/tailcfg" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -27,7 +32,9 @@ import ( "github.com/coder/coder/v2/coderd/wsconncache" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" + drpcsdk "github.com/coder/coder/v2/codersdk/drpc" "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/testutil" ) @@ -41,7 +48,7 @@ func TestCache(t *testing.T) { t.Run("Same", func(t *testing.T) { t.Parallel() cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { - return setupAgent(t, agentsdk.Manifest{}, 0), nil + return setupAgent(t, agentsdk.Manifest{}, 0) }, 0) defer func() { _ = cache.Close() @@ -54,10 +61,10 @@ func TestCache(t *testing.T) { }) t.Run("Expire", func(t *testing.T) { t.Parallel() - called := atomic.NewInt32(0) + called := int32(0) cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { - called.Add(1) - return setupAgent(t, agentsdk.Manifest{}, 0), nil + atomic.AddInt32(&called, 1) + return setupAgent(t, agentsdk.Manifest{}, 0) }, time.Microsecond) defer func() { _ = cache.Close() @@ -70,12 +77,12 @@ func TestCache(t *testing.T) { require.NoError(t, err) release() <-conn.Closed() - require.Equal(t, int32(2), called.Load()) + require.Equal(t, int32(2), called) }) t.Run("NoExpireWhenLocked", func(t *testing.T) { t.Parallel() cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { - return setupAgent(t, agentsdk.Manifest{}, 0), nil + return setupAgent(t, agentsdk.Manifest{}, 0) }, time.Microsecond) defer func() { _ = cache.Close() @@ -108,7 +115,7 @@ func TestCache(t *testing.T) { go 
server.Serve(random) cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { - return setupAgent(t, agentsdk.Manifest{}, 0), nil + return setupAgent(t, agentsdk.Manifest{}, 0) }, time.Microsecond) defer func() { _ = cache.Close() @@ -154,7 +161,7 @@ func TestCache(t *testing.T) { }) } -func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn { +func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) (*codersdk.WorkspaceAgentConn, error) { t.Helper() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t) @@ -184,18 +191,25 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati DERPForceWebSockets: manifest.DERPForceWebSockets, Logger: slogtest.Make(t, nil).Named("tailnet").Leveled(slog.LevelDebug), }) - require.NoError(t, err) - clientConn, serverConn := net.Pipe() + // setupAgent is called by wsconncache Dialer, so we can't use require here as it will end the + // test, which in turn closes the wsconncache, which in turn waits for the Dialer and deadlocks. + if !assert.NoError(t, err) { + return nil, err + } t.Cleanup(func() { - _ = clientConn.Close() - _ = serverConn.Close() _ = conn.Close() }) - go coordinator.ServeClient(serverConn, uuid.New(), manifest.AgentID) - sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error { - return conn.UpdateNodes(nodes, false) + clientID := uuid.New() + testCtx, testCtxCancel := context.WithCancel(context.Background()) + t.Cleanup(testCtxCancel) + coordination := tailnet.NewInMemoryCoordination( + testCtx, logger, + clientID, manifest.AgentID, + coordinator, conn, + ) + t.Cleanup(func() { + _ = coordination.Close() }) - conn.SetNodeCallback(sendNode) agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ AgentID: manifest.AgentID, AgentIP: codersdk.WorkspaceAgentIP, @@ -206,16 +220,20 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) defer cancel() if !agentConn.AwaitReachable(ctx) { - t.Fatal("agent not reachable") + // setupAgent is called by wsconncache Dialer, so we can't use t.Fatal here as it will end + // the test, which in turn closes the wsconncache, which in turn waits for the Dialer and + // deadlocks. 
+ t.Error("agent not reachable") + return nil, xerrors.New("agent not reachable") } - return agentConn + return agentConn, nil } type client struct { t *testing.T agentID uuid.UUID manifest agentsdk.Manifest - coordinator tailnet.CoordinatorV1 + coordinator tailnet.Coordinator } func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) { @@ -240,19 +258,53 @@ func (*client) DERPMapUpdates(_ context.Context) (<-chan agentsdk.DERPMapUpdate, }, nil } -func (c *client) Listen(_ context.Context) (net.Conn, error) { - clientConn, serverConn := net.Pipe() +func (c *client) Listen(_ context.Context) (drpc.Conn, error) { + logger := slogtest.Make(c.t, nil).Leveled(slog.LevelDebug).Named("drpc") + conn, lis := drpcsdk.MemTransportPipe() closed := make(chan struct{}) c.t.Cleanup(func() { - _ = serverConn.Close() - _ = clientConn.Close() + _ = conn.Close() + _ = lis.Close() <-closed }) + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&c.coordinator) + mux := drpcmux.New() + drpcService := &tailnet.DRPCService{ + CoordPtr: &coordPtr, + Logger: logger, + // TODO: handle DERPMap too! + DerpMapUpdateFrequency: time.Hour, + DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") }, + } + err := proto.DRPCRegisterTailnet(mux, drpcService) + if err != nil { + return nil, xerrors.Errorf("register DRPC service: %w", err) + } + server := drpcserver.NewWithOptions(mux, drpcserver.Options{ + Log: func(err error) { + if xerrors.Is(err, io.EOF) || + xerrors.Is(err, context.Canceled) || + xerrors.Is(err, context.DeadlineExceeded) { + return + } + logger.Debug(context.Background(), "drpc server error", slog.Error(err)) + }, + }) + serveCtx, cancel := context.WithCancel(context.Background()) + c.t.Cleanup(cancel) + auth := tailnet.AgentTunnelAuth{} + streamID := tailnet.StreamID{ + Name: "wsconncache_test-agent", + ID: c.agentID, + Auth: auth, + } + serveCtx = tailnet.WithStreamID(serveCtx, streamID) go func() { - _ = c.coordinator.ServeAgent(serverConn, c.agentID, "") + server.Serve(serveCtx, lis) close(closed) }() - return clientConn, nil + return conn, nil } func (*client) ReportStats(_ context.Context, _ slog.Logger, _ <-chan *agentsdk.Stats, _ func(time.Duration)) (io.Closer, error) { diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index b1960bc7d2..2b65f3a316 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -14,12 +14,15 @@ import ( "cloud.google.com/go/compute/metadata" "github.com/google/uuid" + "github.com/hashicorp/yamux" "golang.org/x/xerrors" "nhooyr.io/websocket" + "storj.io/drpc" "tailscale.com/tailcfg" "cdr.dev/slog" "github.com/coder/coder/v2/codersdk" + drpcsdk "github.com/coder/coder/v2/codersdk/drpc" "github.com/coder/retry" ) @@ -280,8 +283,8 @@ func (c *Client) DERPMapUpdates(ctx context.Context) (<-chan DERPMapUpdate, io.C // Listen connects to the workspace agent coordinate WebSocket // that handles connection negotiation. 
-func (c *Client) Listen(ctx context.Context) (net.Conn, error) { - coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/coordinate") +func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) { + coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc") if err != nil { return nil, xerrors.Errorf("parse url: %w", err) } @@ -312,14 +315,21 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) { ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) pingClosed := pingWebSocket(ctx, c.SDK.Logger(), conn, "coordinate") - return &closeNetConn{ + netConn := &closeNetConn{ Conn: wsNetConn, closeFunc: func() { cancelFunc() _ = conn.Close(websocket.StatusGoingAway, "Listen closed") <-pingClosed }, - }, nil + } + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + session, err := yamux.Client(netConn, config) + if err != nil { + return nil, xerrors.Errorf("multiplex client: %w", err) + } + return drpcsdk.MultiplexedConn(session), nil } type PostAppHealthsRequest struct { diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index d3cfabcb63..8165b78c12 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -313,6 +313,9 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID, if err != nil { return nil, xerrors.Errorf("parse url: %w", err) } + q := coordinateURL.Query() + q.Add("version", tailnet.CurrentVersion.String()) + coordinateURL.RawQuery = q.Encode() closedCoordinator := make(chan struct{}) // Must only ever be used once, send error OR close to avoid // reassignment race. Buffered so we don't hang in goroutine. @@ -344,12 +347,22 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID, options.Logger.Debug(ctx, "failed to dial", slog.Error(err)) continue } - sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(nodes []*tailnet.Node) error { - return conn.UpdateNodes(nodes, false) - }) - conn.SetNodeCallback(sendNode) + client, err := tailnet.NewDRPCClient(websocket.NetConn(ctx, ws, websocket.MessageBinary)) + if err != nil { + options.Logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err)) + _ = ws.Close(websocket.StatusInternalError, "") + continue + } + coordinate, err := client.Coordinate(ctx) + if err != nil { + options.Logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err)) + _ = ws.Close(websocket.StatusInternalError, "") + continue + } + + coordination := tailnet.NewRemoteCoordination(options.Logger, coordinate, conn, agentID) options.Logger.Debug(ctx, "serving coordinator") - err = <-errChan + err = <-coordination.Error() if errors.Is(err, context.Canceled) { _ = ws.Close(websocket.StatusGoingAway, "") return diff --git a/enterprise/coderd/workspaceproxycoordinate.go b/enterprise/coderd/workspaceproxycoordinate.go index bf291e45ce..4fe25827b5 100644 --- a/enterprise/coderd/workspaceproxycoordinate.go +++ b/enterprise/coderd/workspaceproxycoordinate.go @@ -8,6 +8,7 @@ import ( "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/util/apiversion" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" agpl "github.com/coder/coder/v2/tailnet" @@ -53,6 +54,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request ctx := r.Context() version := "1.0" + msgType := websocket.MessageText qv := 
r.URL.Query().Get("version") if qv != "" { version = qv @@ -66,6 +68,11 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request }) return } + maj, _, _ := apiversion.Parse(version) + if maj >= 2 { + // Versions 2+ use dRPC over a binary connection + msgType = websocket.MessageBinary + } api.AGPL.WebsocketWaitMutex.Lock() api.AGPL.WebsocketWaitGroup.Add(1) @@ -81,7 +88,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request return } - ctx, nc := websocketNetConn(ctx, conn, websocket.MessageText) + ctx, nc := websocketNetConn(ctx, conn, msgType) defer nc.Close() id := uuid.New() diff --git a/enterprise/coderd/workspaceproxycoordinator_test.go b/enterprise/coderd/workspaceproxycoordinator_test.go index 83bbb5c49d..38ba957bf6 100644 --- a/enterprise/coderd/workspaceproxycoordinator_test.go +++ b/enterprise/coderd/workspaceproxycoordinator_test.go @@ -10,6 +10,7 @@ import ( "github.com/moby/moby/pkg/namesgenerator" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" "tailscale.com/types/key" "cdr.dev/slog/sloggers/slogtest" @@ -20,6 +21,7 @@ import ( "github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" agpl "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/testutil" ) @@ -27,6 +29,12 @@ import ( func Test_agentIsLegacy(t *testing.T) { t.Parallel() + nodeKey := key.NewNode().Public() + discoKey := key.NewDisco().Public() + nkBin, err := nodeKey.MarshalBinary() + require.NoError(t, err) + dkBin, err := discoKey.MarshalText() + require.NoError(t, err) t.Run("Legacy", func(t *testing.T) { t.Parallel() @@ -54,18 +62,18 @@ func Test_agentIsLegacy(t *testing.T) { nodeID := uuid.New() ma := coordinator.ServeMultiAgent(nodeID) defer ma.Close() - require.NoError(t, ma.UpdateSelf(&agpl.Node{ - ID: 55, - AsOf: time.Unix(1689653252, 0), - Key: key.NewNode().Public(), - DiscoKey: key.NewDisco().Public(), - PreferredDERP: 0, - DERPLatency: map[string]float64{ + require.NoError(t, ma.UpdateSelf(&proto.Node{ + Id: 55, + AsOf: timestamppb.New(time.Unix(1689653252, 0)), + Key: nkBin, + Disco: string(dkBin), + PreferredDerp: 0, + DerpLatency: map[string]float64{ "0": 1.0, }, - DERPForcedWebsocket: map[int]string{}, - Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, - AllowedIPs: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, + DerpForcedWebsocket: map[int32]string{}, + Addresses: []string{codersdk.WorkspaceAgentIP.String() + "/128"}, + AllowedIps: []string{codersdk.WorkspaceAgentIP.String() + "/128"}, Endpoints: []string{"192.168.1.1:18842"}, })) require.Eventually(t, func() bool { @@ -114,18 +122,18 @@ func Test_agentIsLegacy(t *testing.T) { nodeID := uuid.New() ma := coordinator.ServeMultiAgent(nodeID) defer ma.Close() - require.NoError(t, ma.UpdateSelf(&agpl.Node{ - ID: 55, - AsOf: time.Unix(1689653252, 0), - Key: key.NewNode().Public(), - DiscoKey: key.NewDisco().Public(), - PreferredDERP: 0, - DERPLatency: map[string]float64{ + require.NoError(t, ma.UpdateSelf(&proto.Node{ + Id: 55, + AsOf: timestamppb.New(time.Unix(1689653252, 0)), + Key: nkBin, + Disco: string(dkBin), + PreferredDerp: 0, + DerpLatency: map[string]float64{ "0": 1.0, }, - DERPForcedWebsocket: map[int]string{}, - Addresses: []netip.Prefix{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128)}, - AllowedIPs: 
[]netip.Prefix{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128)}, + DerpForcedWebsocket: map[int32]string{}, + Addresses: []string{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128).String()}, + AllowedIps: []string{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128).String()}, Endpoints: []string{"192.168.1.1:18842"}, })) require.Eventually(t, func() bool { diff --git a/enterprise/tailnet/multiagent_test.go b/enterprise/tailnet/multiagent_test.go index e51cab8814..c9f8f73fe9 100644 --- a/enterprise/tailnet/multiagent_test.go +++ b/enterprise/tailnet/multiagent_test.go @@ -6,12 +6,15 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" + "tailscale.com/types/key" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/enterprise/tailnet" agpl "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/testutil" ) @@ -39,22 +42,19 @@ func TestPGCoordinator_MultiAgent(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - id := uuid.New() - ma1 := coord1.ServeMultiAgent(id) - defer ma1.Close() + ma1 := newTestMultiAgent(t, coord1) + defer ma1.close() - err = ma1.SubscribeAgent(agent1.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + ma1.subscribeAgent(agent1.id) + ma1.assertEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + ma1.assertEventuallyHasDERPs(ctx, 1) - err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) - require.NoError(t, err) + ma1.sendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - require.NoError(t, ma1.Close()) + ma1.close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -86,23 +86,20 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - id := uuid.New() - ma1 := coord1.ServeMultiAgent(id) - defer ma1.Close() + ma1 := newTestMultiAgent(t, coord1) + defer ma1.close() - err = ma1.SubscribeAgent(agent1.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + ma1.subscribeAgent(agent1.id) + ma1.assertEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + ma1.assertEventuallyHasDERPs(ctx, 1) - err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) - require.NoError(t, err) + ma1.sendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - require.NoError(t, ma1.UnsubscribeAgent(agent1.id)) - require.NoError(t, ma1.Close()) + ma1.unsubscribeAgent(agent1.id) + ma1.close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -134,37 +131,35 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - id := uuid.New() - ma1 := coord1.ServeMultiAgent(id) - defer ma1.Close() + ma1 := newTestMultiAgent(t, coord1) + defer ma1.close() - err = ma1.SubscribeAgent(agent1.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + ma1.subscribeAgent(agent1.id) + ma1.assertEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + ma1.assertEventuallyHasDERPs(ctx, 1) - require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})) + 
ma1.sendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - require.NoError(t, ma1.UnsubscribeAgent(agent1.id)) + ma1.unsubscribeAgent(agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) func() { ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) defer cancel() - require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 9})) + ma1.sendNodeWithDERP(9) assertNeverHasDERPs(ctx, t, agent1, 9) }() func() { ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) defer cancel() agent1.sendNode(&agpl.Node{PreferredDERP: 8}) - assertMultiAgentNeverHasDERPs(ctx, t, ma1, 8) + ma1.assertNeverHasDERPs(ctx, 8) }() - require.NoError(t, ma1.Close()) + ma1.close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -201,22 +196,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) { defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - id := uuid.New() - ma1 := coord2.ServeMultiAgent(id) - defer ma1.Close() + ma1 := newTestMultiAgent(t, coord2) + defer ma1.close() - err = ma1.SubscribeAgent(agent1.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + ma1.subscribeAgent(agent1.id) + ma1.assertEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + ma1.assertEventuallyHasDERPs(ctx, 1) - err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) - require.NoError(t, err) + ma1.sendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) - require.NoError(t, ma1.Close()) + ma1.close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -254,22 +246,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test defer agent1.close() agent1.sendNode(&agpl.Node{PreferredDERP: 5}) - id := uuid.New() - ma1 := coord2.ServeMultiAgent(id) - defer ma1.Close() + ma1 := newTestMultiAgent(t, coord2) + defer ma1.close() - err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) - require.NoError(t, err) + ma1.sendNodeWithDERP(3) - err = ma1.SubscribeAgent(agent1.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + ma1.subscribeAgent(agent1.id) + ma1.assertEventuallyHasDERPs(ctx, 5) assertEventuallyHasDERPs(ctx, t, agent1, 3) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + ma1.assertEventuallyHasDERPs(ctx, 1) - require.NoError(t, ma1.Close()) + ma1.close() require.NoError(t, agent1.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) @@ -316,33 +305,129 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) { defer agent1.close() agent2.sendNode(&agpl.Node{PreferredDERP: 6}) - id := uuid.New() - ma1 := coord3.ServeMultiAgent(id) - defer ma1.Close() + ma1 := newTestMultiAgent(t, coord3) + defer ma1.close() - err = ma1.SubscribeAgent(agent1.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5) + ma1.subscribeAgent(agent1.id) + ma1.assertEventuallyHasDERPs(ctx, 5) agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) + ma1.assertEventuallyHasDERPs(ctx, 1) - err = ma1.SubscribeAgent(agent2.id) - require.NoError(t, err) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 6) + ma1.subscribeAgent(agent2.id) + ma1.assertEventuallyHasDERPs(ctx, 6) agent2.sendNode(&agpl.Node{PreferredDERP: 2}) - assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 2) + 
ma1.assertEventuallyHasDERPs(ctx, 2) - err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) - require.NoError(t, err) + ma1.sendNodeWithDERP(3) assertEventuallyHasDERPs(ctx, t, agent1, 3) assertEventuallyHasDERPs(ctx, t, agent2, 3) - require.NoError(t, ma1.Close()) + ma1.close() require.NoError(t, agent1.close()) require.NoError(t, agent2.close()) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyLost(ctx, t, store, agent1.id) } + +type testMultiAgent struct { + t testing.TB + id uuid.UUID + a agpl.MultiAgentConn + nodeKey []byte + discoKey string +} + +func newTestMultiAgent(t testing.TB, coord agpl.Coordinator) *testMultiAgent { + nk, err := key.NewNode().Public().MarshalBinary() + require.NoError(t, err) + dk, err := key.NewDisco().Public().MarshalText() + require.NoError(t, err) + m := &testMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)} + m.a = coord.ServeMultiAgent(m.id) + return m +} + +func (m *testMultiAgent) sendNodeWithDERP(derp int32) { + m.t.Helper() + err := m.a.UpdateSelf(&proto.Node{ + Key: m.nodeKey, + Disco: m.discoKey, + PreferredDerp: derp, + }) + require.NoError(m.t, err) +} + +func (m *testMultiAgent) close() { + m.t.Helper() + err := m.a.Close() + require.NoError(m.t, err) +} + +func (m *testMultiAgent) subscribeAgent(id uuid.UUID) { + m.t.Helper() + err := m.a.SubscribeAgent(id) + require.NoError(m.t, err) +} + +func (m *testMultiAgent) unsubscribeAgent(id uuid.UUID) { + m.t.Helper() + err := m.a.UnsubscribeAgent(id) + require.NoError(m.t, err) +} + +func (m *testMultiAgent) assertEventuallyHasDERPs(ctx context.Context, expected ...int) { + m.t.Helper() + for { + resp, ok := m.a.NextUpdate(ctx) + require.True(m.t, ok) + nodes, err := agpl.OnlyNodeUpdates(resp) + require.NoError(m.t, err) + if len(nodes) != len(expected) { + m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) + continue + } + + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + if !slices.Contains(derps, e) { + m.t.Logf("expected DERP %d to be in %v", e, derps) + continue + } + return + } + } +} + +func (m *testMultiAgent) assertNeverHasDERPs(ctx context.Context, expected ...int) { + m.t.Helper() + for { + resp, ok := m.a.NextUpdate(ctx) + if !ok { + return + } + nodes, err := agpl.OnlyNodeUpdates(resp) + require.NoError(m.t, err) + if len(nodes) != len(expected) { + m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) + continue + } + + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + if !slices.Contains(derps, e) { + m.t.Logf("expected DERP %d to be in %v", e, derps) + continue + } + return + } + } +} diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index 63ee818eae..c26418ff66 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -818,56 +818,6 @@ func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expecte } } -func assertMultiAgentEventuallyHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) { - t.Helper() - for { - nodes, ok := ma.NextUpdate(ctx) - require.True(t, ok) - if len(nodes) != len(expected) { - t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) - continue - } - - derps := make([]int, 0, len(nodes)) - for _, n := range nodes { - derps = append(derps, n.PreferredDERP) - } - for _, e := range expected { 
- if !slices.Contains(derps, e) { - t.Logf("expected DERP %d to be in %v", e, derps) - continue - } - return - } - } -} - -func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) { - t.Helper() - for { - nodes, ok := ma.NextUpdate(ctx) - if !ok { - return - } - if len(nodes) != len(expected) { - t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) - continue - } - - derps := make([]int, 0, len(nodes)) - for _, n := range nodes { - derps = append(derps, n.PreferredDERP) - } - for _, e := range expected { - if !slices.Contains(derps, e) { - t.Logf("expected DERP %d to be in %v", e, derps) - continue - } - return - } - } -} - func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { t.Helper() assert.Eventually(t, func() bool { diff --git a/enterprise/tailnet/workspaceproxy.go b/enterprise/tailnet/workspaceproxy.go index 0471c076b0..d8f64aa398 100644 --- a/enterprise/tailnet/workspaceproxy.go +++ b/enterprise/tailnet/workspaceproxy.go @@ -96,7 +96,11 @@ func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentC return xerrors.Errorf("unsubscribe agent: %w", err) } case wsproxysdk.CoordinateMessageTypeNodeUpdate: - err := ma.UpdateSelf(msg.Node) + pn, err := agpl.NodeToProto(msg.Node) + if err != nil { + return err + } + err = ma.UpdateSelf(pn) if err != nil { return xerrors.Errorf("update self: %w", err) } @@ -110,11 +114,14 @@ func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentC func forwardNodesToWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error { var lastData []byte for { - nodes, ok := ma.NextUpdate(ctx) + resp, ok := ma.NextUpdate(ctx) if !ok { return xerrors.New("multiagent is closed") } - + nodes, err := agpl.OnlyNodeUpdates(resp) + if err != nil { + return xerrors.Errorf("failed to convert response: %w", err) + } data, err := json.Marshal(wsproxysdk.CoordinateNodes{Nodes: nodes}) if err != nil { return err diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index cbf9695bd7..fe4b1d3b22 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -158,7 +158,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) { // TODO: Probably do some version checking here info, err := client.SDKClient.BuildInfo(ctx) if err != nil { - return nil, fmt.Errorf("buildinfo: %w", errors.Join( + return nil, xerrors.Errorf("buildinfo: %w", errors.Join( xerrors.Errorf("unable to fetch build info from primary coderd. 
Are you sure %q is a coderd instance?", opts.DashboardURL), err, )) diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index 142d0b5c1e..f8d8c22543 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io" - "net" "net/http" "net/url" "sync" @@ -23,6 +22,7 @@ import ( "github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/codersdk" agpl "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" ) // Client is a HTTP client for a subset of Coder API routes that external @@ -438,6 +438,9 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro cancel() return nil, xerrors.Errorf("parse url: %w", err) } + q := coordinateURL.Query() + q.Add("version", agpl.CurrentVersion.String()) + coordinateURL.RawQuery = q.Encode() coordinateHeaders := make(http.Header) tokenHeader := codersdk.SessionTokenHeader if c.SDKClient.SessionTokenHeader != "" { @@ -457,10 +460,24 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro go httpapi.HeartbeatClose(ctx, logger, cancel, conn) - nc := websocket.NetConn(ctx, conn, websocket.MessageText) + nc := websocket.NetConn(ctx, conn, websocket.MessageBinary) + client, err := agpl.NewDRPCClient(nc) + if err != nil { + logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err)) + _ = conn.Close(websocket.StatusInternalError, "") + return nil, xerrors.Errorf("failed to create DRPCClient: %w", err) + } + protocol, err := client.Coordinate(ctx) + if err != nil { + logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err)) + _ = conn.Close(websocket.StatusInternalError, "") + return nil, xerrors.Errorf("failed to reach the Coordinate endpoint: %w", err) + } + rma := remoteMultiAgentHandler{ sdk: c, - nc: nc, + logger: logger, + protocol: protocol, cancel: cancel, legacyAgentCache: map[uuid.UUID]bool{}, } @@ -471,103 +488,75 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro OnSubscribe: rma.OnSubscribe, OnUnsubscribe: rma.OnUnsubscribe, OnNodeUpdate: rma.OnNodeUpdate, - OnRemove: func(agpl.Queue) { conn.Close(websocket.StatusGoingAway, "closed") }, + OnRemove: rma.OnRemove, }).Init() go func() { <-ctx.Done() ma.Close() + _ = conn.Close(websocket.StatusGoingAway, "closed") }() - go func() { - defer cancel() - dec := json.NewDecoder(nc) - for { - var msg CoordinateNodes - err := dec.Decode(&msg) - if err != nil { - if xerrors.Is(err, io.EOF) { - logger.Info(ctx, "websocket connection severed", slog.Error(err)) - return - } - - logger.Error(ctx, "decode coordinator nodes", slog.Error(err)) - return - } - - err = ma.Enqueue(msg.Nodes) - if err != nil { - logger.Error(ctx, "enqueue nodes from coordinator", slog.Error(err)) - continue - } - } - }() + rma.ma = ma + go rma.respLoop() return ma, nil } type remoteMultiAgentHandler struct { - sdk *Client - nc net.Conn - cancel func() + sdk *Client + logger slog.Logger + protocol proto.DRPCTailnet_CoordinateClient + ma *agpl.MultiAgent + cancel func() legacyMu sync.RWMutex legacyAgentCache map[uuid.UUID]bool legacySingleflight singleflight.Group[uuid.UUID, AgentIsLegacyResponse] } -func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error { - data, err := json.Marshal(v) - if err != nil { - return xerrors.Errorf("json marshal message: %w", err) - } +func (a *remoteMultiAgentHandler) respLoop() { + { + defer a.cancel() + for { 
+ resp, err := a.protocol.Recv() + if err != nil { + if xerrors.Is(err, io.EOF) { + a.logger.Info(context.Background(), "remote multiagent connection severed", slog.Error(err)) + return + } - // Set a deadline so that hung connections don't put back pressure on the system. - // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. - err = a.nc.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout)) - if err != nil { - a.cancel() - return xerrors.Errorf("set write deadline: %w", err) - } - _, err = a.nc.Write(data) - if err != nil { - a.cancel() - return xerrors.Errorf("write message: %w", err) - } + a.logger.Error(context.Background(), "error receiving multiagent responses", slog.Error(err)) + return + } - // nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are - // *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write() - // fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context. - // If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after - // our successful write, it is important that we reset the deadline before it fires. - err = a.nc.SetWriteDeadline(time.Time{}) - if err != nil { - a.cancel() - return xerrors.Errorf("clear write deadline: %w", err) + err = a.ma.Enqueue(resp) + if err != nil { + a.logger.Error(context.Background(), "enqueue response from coordinator", slog.Error(err)) + continue + } + } } - - return nil } -func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *agpl.Node) error { - return a.writeJSON(CoordinateMessage{ - Type: CoordinateMessageTypeNodeUpdate, - Node: node, - }) +func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *proto.Node) error { + return a.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: node}}) } -func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) { - return nil, a.writeJSON(CoordinateMessage{ - Type: CoordinateMessageTypeSubscribe, - AgentID: agentID, - }) +func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) error { + return a.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}}) } func (a *remoteMultiAgentHandler) OnUnsubscribe(_ agpl.Queue, agentID uuid.UUID) error { - return a.writeJSON(CoordinateMessage{ - Type: CoordinateMessageTypeUnsubscribe, - AgentID: agentID, - }) + return a.protocol.Send(&proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}}) +} + +func (a *remoteMultiAgentHandler) OnRemove(_ agpl.Queue) { + err := a.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}) + if err != nil { + a.logger.Warn(context.Background(), "failed to gracefully disconnect", slog.Error(err)) + } + _ = a.protocol.CloseSend() } func (a *remoteMultiAgentHandler) AgentIsLegacy(agentID uuid.UUID) bool { diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go index 1901b3207b..8cf8b1ee18 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go @@ -18,8 +18,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" - "golang.org/x/xerrors" + "google.golang.org/protobuf/types/known/timestamppb" "nhooyr.io/websocket" + 
"tailscale.com/tailcfg" "tailscale.com/types/key" "cdr.dev/slog" @@ -30,6 +31,7 @@ import ( "github.com/coder/coder/v2/enterprise/tailnet" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" agpl "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/testutil" ) @@ -156,25 +158,48 @@ func TestDialCoordinator(t *testing.T) { t.Run("OK", func(t *testing.T) { t.Parallel() var ( - ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort) - logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) - agentID = uuid.New() - serverMultiAgent = tailnettest.NewMockMultiAgentConn(gomock.NewController(t)) - r = chi.NewRouter() - srv = httptest.NewServer(r) + ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort) + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + agentID = uuid.UUID{33} + proxyID = uuid.UUID{44} + mCoord = tailnettest.NewMockCoordinator(gomock.NewController(t)) + coord agpl.Coordinator = mCoord + r = chi.NewRouter() + srv = httptest.NewServer(r) ) defer cancel() + defer srv.Close() + coordPtr := atomic.Pointer[agpl.Coordinator]{} + coordPtr.Store(&coord) + cSrv, err := tailnet.NewClientService( + logger, &coordPtr, + time.Hour, + func() *tailcfg.DERPMap { panic("not implemented") }, + ) + require.NoError(t, err) + + // buffer the channels here, so we don't need to read and write in goroutines to + // avoid blocking + reqs := make(chan *proto.CoordinateRequest, 100) + resps := make(chan *proto.CoordinateResponse, 100) + mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetTunnelAuth{}). + Times(1). + Return(reqs, resps) + + serveMACErr := make(chan error, 1) r.Get("/api/v2/workspaceproxies/me/coordinate", func(w http.ResponseWriter, r *http.Request) { conn, err := websocket.Accept(w, r, nil) - require.NoError(t, err) - nc := websocket.NetConn(r.Context(), conn, websocket.MessageText) - defer serverMultiAgent.Close() - - err = tailnet.ServeWorkspaceProxy(ctx, nc, serverMultiAgent) - if !xerrors.Is(err, io.EOF) { - assert.NoError(t, err) + if !assert.NoError(t, err) { + return } + version := r.URL.Query().Get("version") + if !assert.Equal(t, version, agpl.CurrentVersion.String()) { + return + } + nc := websocket.NetConn(r.Context(), conn, websocket.MessageBinary) + err = cSrv.ServeMultiAgentClient(ctx, version, nc, proxyID) + serveMACErr <- err }) r.Get("/api/v2/workspaceagents/{workspaceagent}/legacy", func(w http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, w, http.StatusOK, wsproxysdk.AgentIsLegacyResponse{ @@ -188,51 +213,50 @@ func TestDialCoordinator(t *testing.T) { client := wsproxysdk.New(u) client.SDKClient.SetLogger(logger) - expected := []*agpl.Node{{ - ID: 55, - AsOf: time.Unix(1689653252, 0), - Key: key.NewNode().Public(), - DiscoKey: key.NewDisco().Public(), - PreferredDERP: 0, - DERPLatency: map[string]float64{ - "0": 1.0, + peerID := uuid.UUID{55} + peerNodeKey, err := key.NewNode().Public().MarshalBinary() + require.NoError(t, err) + peerDiscoKey, err := key.NewDisco().Public().MarshalText() + require.NoError(t, err) + expected := &proto.CoordinateResponse{PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{ + Id: peerID[:], + Node: &proto.Node{ + Id: 55, + AsOf: timestamppb.New(time.Unix(1689653252, 0)), + Key: peerNodeKey[:], + Disco: string(peerDiscoKey), + PreferredDerp: 0, + DerpLatency: map[string]float64{ + "0": 1.0, + }, + DerpForcedWebsocket: map[int32]string{}, + Addresses: 
[]string{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128).String()}, + AllowedIps: []string{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128).String()}, + Endpoints: []string{"192.168.1.1:18842"}, }, - DERPForcedWebsocket: map[int]string{}, - Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128)}, - AllowedIPs: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128)}, - Endpoints: []string{"192.168.1.1:18842"}, - }} - sendNode := make(chan struct{}) - - serverMultiAgent.EXPECT().NextUpdate(gomock.Any()).AnyTimes(). - DoAndReturn(func(ctx context.Context) ([]*agpl.Node, bool) { - select { - case <-sendNode: - return expected, true - case <-ctx.Done(): - return nil, false - } - }) + }}} rma, err := client.DialCoordinator(ctx) require.NoError(t, err) // Subscribe { - ch := make(chan struct{}) - serverMultiAgent.EXPECT().SubscribeAgent(agentID).Do(func(uuid.UUID) { - close(ch) - }) require.NoError(t, rma.SubscribeAgent(agentID)) - waitOrCancel(ctx, t, ch) + + req := testutil.RequireRecvCtx(ctx, t, reqs) + require.Equal(t, agentID[:], req.GetAddTunnel().GetId()) } // Read updated agent node { - sendNode <- struct{}{} - got, ok := rma.NextUpdate(ctx) + resps <- expected + + resp, ok := rma.NextUpdate(ctx) assert.True(t, ok) - got[0].AsOf = got[0].AsOf.In(time.Local) - assert.Equal(t, *expected[0], *got[0]) + updates := resp.GetPeerUpdates() + assert.Len(t, updates, 1) + eq, err := updates[0].GetNode().Equal(expected.GetPeerUpdates()[0].GetNode()) + assert.NoError(t, err) + assert.True(t, eq) } // Check legacy { @@ -241,45 +265,38 @@ func TestDialCoordinator(t *testing.T) { } // UpdateSelf { - ch := make(chan struct{}) - serverMultiAgent.EXPECT().UpdateSelf(gomock.Any()).Do(func(node *agpl.Node) { - node.AsOf = node.AsOf.In(time.Local) - assert.Equal(t, expected[0], node) - close(ch) - }) - require.NoError(t, rma.UpdateSelf(expected[0])) - waitOrCancel(ctx, t, ch) + require.NoError(t, rma.UpdateSelf(expected.PeerUpdates[0].GetNode())) + + req := testutil.RequireRecvCtx(ctx, t, reqs) + eq, err := req.GetUpdateSelf().GetNode().Equal(expected.PeerUpdates[0].GetNode()) + require.NoError(t, err) + require.True(t, eq) } // Unsubscribe { - ch := make(chan struct{}) - serverMultiAgent.EXPECT().UnsubscribeAgent(agentID).Do(func(uuid.UUID) { - close(ch) - }) require.NoError(t, rma.UnsubscribeAgent(agentID)) - waitOrCancel(ctx, t, ch) + + req := testutil.RequireRecvCtx(ctx, t, reqs) + require.Equal(t, agentID[:], req.GetRemoveTunnel().GetId()) } // Close { - ch := make(chan struct{}) - serverMultiAgent.EXPECT().Close().Do(func() { - close(ch) - }) require.NoError(t, rma.Close()) - waitOrCancel(ctx, t, ch) + + req := testutil.RequireRecvCtx(ctx, t, reqs) + require.NotNil(t, req.Disconnect) + close(resps) + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for req close") + case _, ok := <-reqs: + require.False(t, ok, "didn't close requests") + } + require.Error(t, testutil.RequireRecvCtx(ctx, t, serveMACErr)) } }) } -func waitOrCancel(ctx context.Context, t testing.TB, ch <-chan struct{}) { - t.Helper() - select { - case <-ch: - case <-ctx.Done(): - t.Fatal("timed out waiting for channel") - } -} - type ResponseRecorder struct { rw *httptest.ResponseRecorder wasWritten atomic.Bool diff --git a/tailnet/configmaps.go b/tailnet/configmaps.go index 49200aa5fd..9c9fe7ee8d 100644 --- a/tailnet/configmaps.go +++ b/tailnet/configmaps.go @@ -490,6 +490,18 @@ func (c *configMaps) protoNodeToTailcfg(p *proto.Node) (*tailcfg.Node, 
error) { }, nil } +// nodeAddresses returns the addresses for the peer with the given publicKey, if known. +func (c *configMaps) nodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) { + c.L.Lock() + defer c.L.Unlock() + for _, lc := range c.peers { + if lc.node.Key == publicKey { + return lc.node.Addresses, true + } + } + return nil, false +} + type peerLifecycle struct { peerID uuid.UUID node *tailcfg.Node diff --git a/tailnet/conn.go b/tailnet/conn.go index 34712ee0ff..0b4b942f8f 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -3,48 +3,40 @@ package tailnet import ( "context" "encoding/binary" - "errors" "fmt" "net" "net/http" "net/netip" "os" - "reflect" "strconv" "sync" "time" "github.com/cenkalti/backoff/v4" "github.com/google/uuid" - "go4.org/netipx" "golang.org/x/xerrors" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "tailscale.com/envknob" "tailscale.com/ipn/ipnstate" "tailscale.com/net/connstats" - "tailscale.com/net/dns" "tailscale.com/net/netmon" "tailscale.com/net/netns" "tailscale.com/net/tsdial" "tailscale.com/net/tstun" "tailscale.com/tailcfg" "tailscale.com/tsd" - "tailscale.com/types/ipproto" "tailscale.com/types/key" tslogger "tailscale.com/types/logger" "tailscale.com/types/netlogtype" - "tailscale.com/types/netmap" "tailscale.com/wgengine" - "tailscale.com/wgengine/filter" "tailscale.com/wgengine/magicsock" "tailscale.com/wgengine/netstack" "tailscale.com/wgengine/router" - "tailscale.com/wgengine/wgcfg/nmcfg" "cdr.dev/slog" - "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/cryptorand" + "github.com/coder/coder/v2/tailnet/proto" ) var ErrConnClosed = xerrors.New("connection closed") @@ -128,42 +120,6 @@ func NewConn(options *Options) (conn *Conn, err error) { } nodePrivateKey := key.NewNode() - nodePublicKey := nodePrivateKey.Public() - - netMap := &netmap.NetworkMap{ - DERPMap: options.DERPMap, - NodeKey: nodePublicKey, - PrivateKey: nodePrivateKey, - Addresses: options.Addresses, - PacketFilter: []filter.Match{{ - // Allow any protocol! - IPProto: []ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, ipproto.SCTP}, - // Allow traffic sourced from anywhere. - Srcs: []netip.Prefix{ - netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0), - netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0), - }, - // Allow traffic to route anywhere. - Dsts: []filter.NetPortRange{ - { - Net: netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0), - Ports: filter.PortRange{ - First: 0, - Last: 65535, - }, - }, - { - Net: netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0), - Ports: filter.PortRange{ - First: 0, - Last: 65535, - }, - }, - }, - Caps: []filter.CapMatch{}, - }}, - } - var nodeID tailcfg.NodeID // If we're provided with a UUID, use it to populate our node ID. 
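For orientation while reading the rest of this file's refactor: the hunks below drop the hand-maintained netmap and, further down, replace the UpdateNodes/RemovePeer pair with a single UpdatePeers call that takes protocol peer updates. A minimal sketch of the new call shape, assuming two already-constructed Conns; the helper name, package, and IDs are illustrative and not part of the patch:

package example

import (
	"github.com/google/uuid"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
)

// pushPeer is a hypothetical helper: it tells dst about src's current node,
// roughly what a Coordinator does when it relays a peer update.
func pushPeer(dst, src *tailnet.Conn, srcID uuid.UUID) error {
	// Convert the legacy Node into its protocol representation...
	pn, err := tailnet.NodeToProto(src.Node())
	if err != nil {
		return err
	}
	// ...and hand it to the destination Conn as a NODE peer update.
	return dst.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
		Id:   srcID[:],
		Node: pn,
		Kind: proto.CoordinateResponse_PeerUpdate_NODE,
	}})
}

This is essentially what the stitch test helper added to conn_test.go later in this diff does, minus the node-callback plumbing.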
@@ -177,14 +133,6 @@ func NewConn(options *Options) (conn *Conn, err error) { nodeID = tailcfg.NodeID(uid) } - // This is used by functions below to identify the node via key - netMap.SelfNode = &tailcfg.Node{ - ID: nodeID, - Key: nodePublicKey, - Addresses: options.Addresses, - AllowedIPs: options.Addresses, - } - wireguardMonitor, err := netmon.New(Logger(options.Logger.Named("net.wgmonitor"))) if err != nil { return nil, xerrors.Errorf("create wireguard link monitor: %w", err) @@ -243,7 +191,6 @@ func NewConn(options *Options) (conn *Conn, err error) { if err != nil { return nil, xerrors.Errorf("set node private key: %w", err) } - netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey() netStack, err := netstack.Create( Logger(options.Logger.Named("net.netstack")), @@ -262,44 +209,46 @@ func NewConn(options *Options) (conn *Conn, err error) { } netStack.ProcessLocalIPs = true wireguardEngine = wgengine.NewWatchdog(wireguardEngine) - wireguardEngine.SetDERPMap(options.DERPMap) - netMapCopy := *netMap - options.Logger.Debug(context.Background(), "updating network map") - wireguardEngine.SetNetworkMap(&netMapCopy) - localIPSet := netipx.IPSetBuilder{} - for _, addr := range netMap.Addresses { - localIPSet.AddPrefix(addr) - } - localIPs, _ := localIPSet.IPSet() - logIPSet := netipx.IPSetBuilder{} - logIPs, _ := logIPSet.IPSet() - wireguardEngine.SetFilter(filter.New( - netMap.PacketFilter, - localIPs, - logIPs, + cfgMaps := newConfigMaps( + options.Logger, + wireguardEngine, + nodeID, + nodePrivateKey, + magicConn.DiscoPublicKey(), + ) + cfgMaps.setAddresses(options.Addresses) + cfgMaps.setDERPMap(DERPMapToProto(options.DERPMap)) + cfgMaps.setBlockEndpoints(options.BlockEndpoints) + + nodeUp := newNodeUpdater( + options.Logger, nil, - Logger(options.Logger.Named("net.packet-filter")), - )) + nodeID, + nodePrivateKey.Public(), + magicConn.DiscoPublicKey(), + ) + nodeUp.setAddresses(options.Addresses) + nodeUp.setBlockEndpoints(options.BlockEndpoints) + wireguardEngine.SetStatusCallback(nodeUp.setStatus) + wireguardEngine.SetNetInfoCallback(nodeUp.setNetInfo) + magicConn.SetDERPForcedWebsocketCallback(nodeUp.setDERPForcedWebsocket) server := &Conn{ - blockEndpoints: options.BlockEndpoints, - derpForceWebSockets: options.DERPForceWebSockets, - closed: make(chan struct{}), - logger: options.Logger, - magicConn: magicConn, - dialer: dialer, - listeners: map[listenKey]*listener{}, - peerMap: map[tailcfg.NodeID]*tailcfg.Node{}, - lastDERPForcedWebSockets: map[int]string{}, - tunDevice: sys.Tun.Get(), - netMap: netMap, - netStack: netStack, - wireguardMonitor: wireguardMonitor, + closed: make(chan struct{}), + logger: options.Logger, + magicConn: magicConn, + dialer: dialer, + listeners: map[listenKey]*listener{}, + tunDevice: sys.Tun.Get(), + netStack: netStack, + wireguardMonitor: wireguardMonitor, wireguardRouter: &router.Config{ - LocalAddrs: netMap.Addresses, + LocalAddrs: options.Addresses, }, wireguardEngine: wireguardEngine, + configMaps: cfgMaps, + nodeUpdater: nodeUp, } defer func() { if err != nil { @@ -307,52 +256,6 @@ func NewConn(options *Options) (conn *Conn, err error) { } }() - wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) { - server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.Error(err)) - if err != nil { - return - } - server.lastMutex.Lock() - if s.AsOf.Before(server.lastStatus) { - // Don't process outdated status! 
- server.lastMutex.Unlock() - return - } - server.lastStatus = s.AsOf - if endpointsEqual(s.LocalAddrs, server.lastEndpoints) { - // No need to update the node if nothing changed! - server.lastMutex.Unlock() - return - } - server.lastEndpoints = append([]tailcfg.Endpoint{}, s.LocalAddrs...) - server.lastMutex.Unlock() - server.sendNode() - }) - - wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) { - server.logger.Debug(context.Background(), "netinfo callback", slog.F("netinfo", ni)) - server.lastMutex.Lock() - if reflect.DeepEqual(server.lastNetInfo, ni) { - server.lastMutex.Unlock() - return - } - server.lastNetInfo = ni.Clone() - server.lastMutex.Unlock() - server.sendNode() - }) - - magicConn.SetDERPForcedWebsocketCallback(func(region int, reason string) { - server.logger.Debug(context.Background(), "derp forced websocket", slog.F("region", region), slog.F("reason", reason)) - server.lastMutex.Lock() - if server.lastDERPForcedWebSockets[region] == reason { - server.lastMutex.Unlock() - return - } - server.lastDERPForcedWebSockets[region] = reason - server.lastMutex.Unlock() - server.sendNode() - }) - netStack.GetTCPHandlerForFlow = server.forwardTCP err = netStack.Start(nil) @@ -389,16 +292,14 @@ func IPFromUUID(uid uuid.UUID) netip.Addr { // Conn is an actively listening Wireguard connection. type Conn struct { - mutex sync.Mutex - closed chan struct{} - logger slog.Logger - blockEndpoints bool - derpForceWebSockets bool + mutex sync.Mutex + closed chan struct{} + logger slog.Logger dialer *tsdial.Dialer tunDevice *tstun.Wrapper - peerMap map[tailcfg.NodeID]*tailcfg.Node - netMap *netmap.NetworkMap + configMaps *configMaps + nodeUpdater *nodeUpdater netStack *netstack.Impl magicConn *magicsock.Conn wireguardMonitor *netmon.Monitor @@ -406,17 +307,6 @@ type Conn struct { wireguardEngine wgengine.Engine listeners map[listenKey]*listener - lastMutex sync.Mutex - nodeSending bool - nodeChanged bool - // It's only possible to store these values via status functions, - // so the values must be stored for retrieval later on. - lastStatus time.Time - lastEndpoints []tailcfg.Endpoint - lastDERPForcedWebSockets map[int]string - lastNetInfo *tailcfg.NetInfo - nodeCallback func(node *Node) - trafficStats *connstats.Statistics } @@ -425,57 +315,30 @@ func (c *Conn) MagicsockSetDebugLoggingEnabled(enabled bool) { } func (c *Conn) SetAddresses(ips []netip.Prefix) error { - c.mutex.Lock() - defer c.mutex.Unlock() - - c.netMap.Addresses = ips - - netMapCopy := *c.netMap - c.logger.Debug(context.Background(), "updating network map") - c.wireguardEngine.SetNetworkMap(&netMapCopy) - err := c.reconfig() - if err != nil { - return xerrors.Errorf("reconfig: %w", err) - } - + c.configMaps.setAddresses(ips) + c.nodeUpdater.setAddresses(ips) return nil } -func (c *Conn) Addresses() []netip.Prefix { - c.mutex.Lock() - defer c.mutex.Unlock() - return c.netMap.Addresses -} - func (c *Conn) SetNodeCallback(callback func(node *Node)) { - c.lastMutex.Lock() - c.nodeCallback = callback - c.lastMutex.Unlock() - c.sendNode() + c.nodeUpdater.setCallback(callback) } // SetDERPMap updates the DERPMap of a connection. 
func (c *Conn) SetDERPMap(derpMap *tailcfg.DERPMap) { - c.mutex.Lock() - defer c.mutex.Unlock() - c.logger.Debug(context.Background(), "updating derp map", slog.F("derp_map", derpMap)) - c.wireguardEngine.SetDERPMap(derpMap) - c.netMap.DERPMap = derpMap - netMapCopy := *c.netMap - c.logger.Debug(context.Background(), "updating network map") - c.wireguardEngine.SetNetworkMap(&netMapCopy) + c.configMaps.setDERPMap(DERPMapToProto(derpMap)) } func (c *Conn) SetDERPForceWebSockets(v bool) { + c.logger.Info(context.Background(), "setting DERP Force Websockets", slog.F("force_derp_websockets", v)) c.magicConn.SetDERPForceWebsockets(v) } -// SetBlockEndpoints sets whether or not to block P2P endpoints. This setting +// SetBlockEndpoints sets whether to block P2P endpoints. This setting // will only apply to new peers. func (c *Conn) SetBlockEndpoints(blockEndpoints bool) { - c.mutex.Lock() - defer c.mutex.Unlock() - c.blockEndpoints = blockEndpoints + c.configMaps.setBlockEndpoints(blockEndpoints) + c.nodeUpdater.setBlockEndpoints(blockEndpoints) } // SetDERPRegionDialer updates the dialer to use for connecting to DERP regions. @@ -483,186 +346,24 @@ func (c *Conn) SetDERPRegionDialer(dialer func(ctx context.Context, region *tail c.magicConn.SetDERPRegionDialer(dialer) } -// UpdateNodes connects with a set of peers. This can be constantly updated, -// and peers will continually be reconnected as necessary. If replacePeers is -// true, all peers will be removed before adding the new ones. -// -//nolint:revive // Complains about replacePeers. -func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error { - c.mutex.Lock() - defer c.mutex.Unlock() - +// UpdatePeers connects with a set of peers. This can be constantly updated, +// and peers will continually be reconnected as necessary. +func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error { if c.isClosed() { return ErrConnClosed } - - status := c.Status() - if replacePeers { - c.netMap.Peers = []*tailcfg.Node{} - c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{} - } - for _, peer := range c.netMap.Peers { - peerStatus, ok := status.Peer[peer.Key] - if !ok { - continue - } - // If this peer was added in the last 5 minutes, assume it - // could still be active. - if time.Since(peer.Created) < 5*time.Minute { - continue - } - // We double-check that it's safe to remove by ensuring no - // handshake has been sent in the past 5 minutes as well. Connections that - // are actively exchanging IP traffic will handshake every 2 minutes. - if time.Since(peerStatus.LastHandshake) < 5*time.Minute { - continue - } - - c.logger.Debug(context.Background(), "removing peer, last handshake >5m ago", - slog.F("peer", peer.Key), slog.F("last_handshake", peerStatus.LastHandshake), - ) - delete(c.peerMap, peer.ID) - } - - for _, node := range nodes { - // If no preferred DERP is provided, we can't reach the node. 
- if node.PreferredDERP == 0 { - c.logger.Debug(context.Background(), "no preferred DERP, skipping node", slog.F("node", node)) - continue - } - c.logger.Debug(context.Background(), "adding node", slog.F("node", node)) - - peerStatus, ok := status.Peer[node.Key] - peerNode := &tailcfg.Node{ - ID: node.ID, - Created: time.Now(), - Key: node.Key, - DiscoKey: node.DiscoKey, - Addresses: node.Addresses, - AllowedIPs: node.AllowedIPs, - Endpoints: node.Endpoints, - DERP: fmt.Sprintf("%s:%d", tailcfg.DerpMagicIP, node.PreferredDERP), - Hostinfo: (&tailcfg.Hostinfo{}).View(), - // Starting KeepAlive messages at the initialization of a connection - // causes a race condition. If we handshake before the peer has our - // node, we'll have wait for 5 seconds before trying again. Ideally, - // the first handshake starts when the user first initiates a - // connection to the peer. After a successful connection we enable - // keep alives to persist the connection and keep it from becoming - // idle. SSH connections don't send send packets while idle, so we - // use keep alives to avoid random hangs while we set up the - // connection again after inactivity. - KeepAlive: ok && peerStatus.Active, - } - if c.blockEndpoints { - peerNode.Endpoints = nil - } - c.peerMap[node.ID] = peerNode - } - - c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap)) - for _, peer := range c.peerMap { - c.netMap.Peers = append(c.netMap.Peers, peer.Clone()) - } - - netMapCopy := *c.netMap - c.logger.Debug(context.Background(), "updating network map") - c.wireguardEngine.SetNetworkMap(&netMapCopy) - err := c.reconfig() - if err != nil { - return xerrors.Errorf("reconfig: %w", err) - } - - return nil -} - -// PeerSelector is used to select a peer from within a Tailnet. -type PeerSelector struct { - ID tailcfg.NodeID - IP netip.Prefix -} - -func (c *Conn) RemovePeer(selector PeerSelector) (deleted bool, err error) { - c.mutex.Lock() - defer c.mutex.Unlock() - - if c.isClosed() { - return false, ErrConnClosed - } - - deleted = false - for _, peer := range c.peerMap { - if peer.ID == selector.ID { - delete(c.peerMap, peer.ID) - deleted = true - break - } - - for _, peerIP := range peer.Addresses { - if peerIP.Bits() == selector.IP.Bits() && peerIP.Addr().Compare(selector.IP.Addr()) == 0 { - delete(c.peerMap, peer.ID) - deleted = true - break - } - } - } - if !deleted { - return false, nil - } - - c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap)) - for _, peer := range c.peerMap { - c.netMap.Peers = append(c.netMap.Peers, peer.Clone()) - } - - netMapCopy := *c.netMap - c.logger.Debug(context.Background(), "updating network map") - c.wireguardEngine.SetNetworkMap(&netMapCopy) - err = c.reconfig() - if err != nil { - return false, xerrors.Errorf("reconfig: %w", err) - } - - return true, nil -} - -func (c *Conn) reconfig() error { - cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("net.wgconfig")), netmap.AllowSingleHosts, "") - if err != nil { - return xerrors.Errorf("update wireguard config: %w", err) - } - - err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{}) - if err != nil { - if c.isClosed() { - return nil - } - if errors.Is(err, wgengine.ErrNoChanges) { - return nil - } - return xerrors.Errorf("reconfig: %w", err) - } - + c.configMaps.updatePeers(updates) return nil } // NodeAddresses returns the addresses of a node from the NetworkMap. 
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) { - c.mutex.Lock() - defer c.mutex.Unlock() - for _, node := range c.netMap.Peers { - if node.Key == publicKey { - return node.Addresses, true - } - } - return nil, false + return c.configMaps.nodeAddresses(publicKey) } // Status returns the current ipnstate of a connection. func (c *Conn) Status() *ipnstate.Status { - sb := &ipnstate.StatusBuilder{WantPeers: true} - c.wireguardEngine.UpdateStatus(sb) - return sb.Status() + return c.configMaps.status() } // Ping sends a ping to the Wireguard engine. @@ -689,16 +390,9 @@ func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, bool, *i // DERPMap returns the currently set DERP mapping. func (c *Conn) DERPMap() *tailcfg.DERPMap { - c.mutex.Lock() - defer c.mutex.Unlock() - return c.netMap.DERPMap -} - -// BlockEndpoints returns whether or not P2P is blocked. -func (c *Conn) BlockEndpoints() bool { - c.mutex.Lock() - defer c.mutex.Unlock() - return c.blockEndpoints + c.configMaps.L.Lock() + defer c.configMaps.L.Unlock() + return c.configMaps.derpMapLocked() } // AwaitReachable pings the provided IP continually until the @@ -759,6 +453,9 @@ func (c *Conn) Closed() <-chan struct{} { // Close shuts down the Wireguard connection. func (c *Conn) Close() error { + c.logger.Info(context.Background(), "closing tailnet Conn") + c.configMaps.close() + c.nodeUpdater.close() c.mutex.Lock() select { case <-c.closed: @@ -808,91 +505,11 @@ func (c *Conn) isClosed() bool { } } -func (c *Conn) sendNode() { - c.lastMutex.Lock() - defer c.lastMutex.Unlock() - if c.nodeSending { - c.nodeChanged = true - return - } - node := c.selfNode() - // Conn.UpdateNodes will skip any nodes that don't have the PreferredDERP - // set to non-zero, since we cannot reach nodes without DERP for discovery. - // Therefore, there is no point in sending the node without this, and we can - // save ourselves from churn in the tailscale/wireguard layer. - if node.PreferredDERP == 0 { - c.logger.Debug(context.Background(), "skipped sending node; no PreferredDERP", slog.F("node", node)) - return - } - nodeCallback := c.nodeCallback - if nodeCallback == nil { - return - } - c.nodeSending = true - go func() { - c.logger.Debug(context.Background(), "sending node", slog.F("node", node)) - nodeCallback(node) - c.lastMutex.Lock() - c.nodeSending = false - if c.nodeChanged { - c.nodeChanged = false - c.lastMutex.Unlock() - c.sendNode() - return - } - c.lastMutex.Unlock() - }() -} - // Node returns the last node that was sent to the node callback. func (c *Conn) Node() *Node { - c.lastMutex.Lock() - defer c.lastMutex.Unlock() - return c.selfNode() -} - -func (c *Conn) selfNode() *Node { - endpoints := make([]string, 0, len(c.lastEndpoints)) - for _, addr := range c.lastEndpoints { - endpoints = append(endpoints, addr.Addr.String()) - } - var preferredDERP int - var derpLatency map[string]float64 - derpForcedWebsocket := make(map[int]string, 0) - if c.lastNetInfo != nil { - preferredDERP = c.lastNetInfo.PreferredDERP - derpLatency = c.lastNetInfo.DERPLatency - - if c.derpForceWebSockets { - // We only need to store this for a single region, since this is - // mostly used for debugging purposes and doesn't actually have a - // code purpose. 
- derpForcedWebsocket[preferredDERP] = "DERP is configured to always fallback to WebSockets" - } else { - for k, v := range c.lastDERPForcedWebSockets { - derpForcedWebsocket[k] = v - } - } - } - - node := &Node{ - ID: c.netMap.SelfNode.ID, - AsOf: dbtime.Now(), - Key: c.netMap.SelfNode.Key, - Addresses: c.netMap.SelfNode.Addresses, - AllowedIPs: c.netMap.SelfNode.AllowedIPs, - DiscoKey: c.magicConn.DiscoPublicKey(), - Endpoints: endpoints, - PreferredDERP: preferredDERP, - DERPLatency: derpLatency, - DERPForcedWebsocket: derpForcedWebsocket, - } - c.mutex.Lock() - if c.blockEndpoints { - node.Endpoints = nil - } - c.mutex.Unlock() - return node + c.nodeUpdater.L.Lock() + defer c.nodeUpdater.L.Unlock() + return c.nodeUpdater.nodeLocked() } // This and below is taken _mostly_ verbatim from Tailscale: @@ -1056,15 +673,3 @@ func Logger(logger slog.Logger) tslogger.Logf { logger.Debug(context.Background(), fmt.Sprintf(format, args...)) }) } - -func endpointsEqual(x, y []tailcfg.Endpoint) bool { - if len(x) != len(y) { - return false - } - for i := range x { - if x[i] != y[i] { - return false - } - } - return true -} diff --git a/tailnet/conn_test.go b/tailnet/conn_test.go index 7554c94cec..f3bc96e242 100644 --- a/tailnet/conn_test.go +++ b/tailnet/conn_test.go @@ -5,6 +5,7 @@ import ( "net/netip" "testing" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -12,6 +13,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/testutil" ) @@ -22,10 +24,10 @@ func TestMain(m *testing.M) { func TestTailnet(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) derpMap, _ := tailnettest.RunDERPAndSTUN(t) t.Run("InstantClose", func(t *testing.T) { t.Parallel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, Logger: logger.Named("w1"), @@ -37,6 +39,8 @@ func TestTailnet(t *testing.T) { }) t.Run("Connect", func(t *testing.T) { t.Parallel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx := testutil.Context(t, testutil.WaitLong) w1IP := tailnet.IP() w1, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)}, @@ -55,14 +59,8 @@ func TestTailnet(t *testing.T) { _ = w1.Close() _ = w2.Close() }) - w1.SetNodeCallback(func(node *tailnet.Node) { - err := w2.UpdateNodes([]*tailnet.Node{node}, false) - assert.NoError(t, err) - }) - w2.SetNodeCallback(func(node *tailnet.Node) { - err := w1.UpdateNodes([]*tailnet.Node{node}, false) - assert.NoError(t, err) - }) + stitch(t, w2, w1) + stitch(t, w1, w2) require.True(t, w2.AwaitReachable(context.Background(), w1IP)) conn := make(chan struct{}, 1) go func() { @@ -89,7 +87,7 @@ func TestTailnet(t *testing.T) { default: } }) - node := <-nodes + node := testutil.RequireRecvCtx(ctx, t, nodes) // Ensure this connected over DERP! 
require.Len(t, node.DERPForcedWebsocket, 0) @@ -99,6 +97,7 @@ func TestTailnet(t *testing.T) { t.Run("ForcesWebSockets", func(t *testing.T) { t.Parallel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) ctx := testutil.Context(t, testutil.WaitMedium) w1IP := tailnet.IP() @@ -122,14 +121,8 @@ func TestTailnet(t *testing.T) { _ = w1.Close() _ = w2.Close() }) - w1.SetNodeCallback(func(node *tailnet.Node) { - err := w2.UpdateNodes([]*tailnet.Node{node}, false) - assert.NoError(t, err) - }) - w2.SetNodeCallback(func(node *tailnet.Node) { - err := w1.UpdateNodes([]*tailnet.Node{node}, false) - assert.NoError(t, err) - }) + stitch(t, w2, w1) + stitch(t, w1, w2) require.True(t, w2.AwaitReachable(ctx, w1IP)) conn := make(chan struct{}, 1) go func() { @@ -243,11 +236,16 @@ func TestConn_UpdateDERP(t *testing.T) { err := client1.Close() assert.NoError(t, err) }() - client1.SetNodeCallback(func(node *tailnet.Node) { - err := conn.UpdateNodes([]*tailnet.Node{node}, false) - assert.NoError(t, err) - }) - client1.UpdateNodes([]*tailnet.Node{conn.Node()}, false) + stitch(t, conn, client1) + pn, err := tailnet.NodeToProto(conn.Node()) + require.NoError(t, err) + connID := uuid.New() + err = client1.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{ + Id: connID[:], + Node: pn, + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + }}) + require.NoError(t, err) awaitReachableCtx1, awaitReachableCancel1 := context.WithTimeout(context.Background(), testutil.WaitShort) defer awaitReachableCancel1() @@ -288,7 +286,13 @@ parentLoop: // ... unless the client updates it's derp map and nodes. client1.SetDERPMap(derpMap2) - client1.UpdateNodes([]*tailnet.Node{conn.Node()}, false) + pn, err = tailnet.NodeToProto(conn.Node()) + require.NoError(t, err) + client1.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{ + Id: connID[:], + Node: pn, + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + }}) awaitReachableCtx3, awaitReachableCancel3 := context.WithTimeout(context.Background(), testutil.WaitShort) defer awaitReachableCancel3() require.True(t, client1.AwaitReachable(awaitReachableCtx3, ip)) @@ -306,13 +310,34 @@ parentLoop: err := client2.Close() assert.NoError(t, err) }() - client2.SetNodeCallback(func(node *tailnet.Node) { - err := conn.UpdateNodes([]*tailnet.Node{node}, false) - assert.NoError(t, err) - }) - client2.UpdateNodes([]*tailnet.Node{conn.Node()}, false) + stitch(t, conn, client2) + pn, err = tailnet.NodeToProto(conn.Node()) + require.NoError(t, err) + client2.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{ + Id: connID[:], + Node: pn, + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + }}) awaitReachableCtx4, awaitReachableCancel4 := context.WithTimeout(context.Background(), testutil.WaitShort) defer awaitReachableCancel4() require.True(t, client2.AwaitReachable(awaitReachableCtx4, ip)) } + +// stitch sends node updates from src Conn as peer updates to dst Conn. Sort of +// like the Coordinator would, but without actually needing a Coordinator. 
+func stitch(t *testing.T, dst, src *tailnet.Conn) { + srcID := uuid.New() + src.SetNodeCallback(func(node *tailnet.Node) { + pn, err := tailnet.NodeToProto(node) + if !assert.NoError(t, err) { + return + } + err = dst.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{ + Id: srcID[:], + Node: pn, + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + }}) + assert.NoError(t, err) + }) +} diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 04c5fc0ee3..0fa62fc922 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -3,6 +3,7 @@ package tailnet import ( "context" "encoding/json" + "fmt" "html/template" "io" "net" @@ -92,6 +93,237 @@ type Node struct { Endpoints []string `json:"endpoints"` } +// Coordinatee is something that can be coordinated over the Coordinate protocol. Usually this is a +// Conn. +type Coordinatee interface { + UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error + SetNodeCallback(func(*Node)) +} + +type Coordination interface { + io.Closer + Error() <-chan error +} + +type remoteCoordination struct { + sync.Mutex + closed bool + errChan chan error + coordinatee Coordinatee + logger slog.Logger + protocol proto.DRPCTailnet_CoordinateClient +} + +func (c *remoteCoordination) Close() error { + c.Lock() + defer c.Unlock() + if c.closed { + return nil + } + c.closed = true + err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}) + if err != nil { + return xerrors.Errorf("send disconnect: %w", err) + } + return nil +} + +func (c *remoteCoordination) Error() <-chan error { + return c.errChan +} + +func (c *remoteCoordination) sendErr(err error) { + select { + case c.errChan <- err: + default: + } +} + +func (c *remoteCoordination) respLoop() { + for { + resp, err := c.protocol.Recv() + if err != nil { + c.sendErr(xerrors.Errorf("read: %w", err)) + return + } + err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates()) + if err != nil { + c.sendErr(xerrors.Errorf("update peers: %w", err)) + return + } + } +} + +// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinee (usually a +// Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as +// a client---agents should NOT set this!). 
+func NewRemoteCoordination(logger slog.Logger, + protocol proto.DRPCTailnet_CoordinateClient, coordinatee Coordinatee, + tunnelTarget uuid.UUID, +) Coordination { + c := &remoteCoordination{ + errChan: make(chan error, 1), + coordinatee: coordinatee, + logger: logger, + protocol: protocol, + } + if tunnelTarget != uuid.Nil { + c.Lock() + err := c.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tunnelTarget[:]}}) + c.Unlock() + if err != nil { + c.sendErr(err) + } + } + + coordinatee.SetNodeCallback(func(node *Node) { + pn, err := NodeToProto(node) + if err != nil { + c.logger.Critical(context.Background(), "failed to convert node", slog.Error(err)) + c.sendErr(err) + return + } + c.Lock() + defer c.Unlock() + if c.closed { + c.logger.Debug(context.Background(), "ignored node update because coordination is closed") + return + } + err = c.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}) + if err != nil { + c.sendErr(xerrors.Errorf("write: %w", err)) + } + }) + go c.respLoop() + return c +} + +type inMemoryCoordination struct { + sync.Mutex + ctx context.Context + errChan chan error + closed bool + closedCh chan struct{} + coordinatee Coordinatee + logger slog.Logger + resps <-chan *proto.CoordinateResponse + reqs chan<- *proto.CoordinateRequest +} + +func (c *inMemoryCoordination) sendErr(err error) { + select { + case c.errChan <- err: + default: + } +} + +func (c *inMemoryCoordination) Error() <-chan error { + return c.errChan +} + +// NewInMemoryCoordination connects a Coordinatee (usually Conn) to an in memory Coordinator, for testing +// or local clients. Set ClientID to uuid.Nil for an agent. +func NewInMemoryCoordination( + ctx context.Context, logger slog.Logger, + clientID, agentID uuid.UUID, + coordinator Coordinator, coordinatee Coordinatee, +) Coordination { + thisID := agentID + logger = logger.With(slog.F("agent_id", agentID)) + var auth TunnelAuth = AgentTunnelAuth{} + if clientID != uuid.Nil { + // this is a client connection + auth = ClientTunnelAuth{AgentID: agentID} + logger = logger.With(slog.F("client_id", clientID)) + thisID = clientID + } + c := &inMemoryCoordination{ + ctx: ctx, + errChan: make(chan error, 1), + coordinatee: coordinatee, + logger: logger, + closedCh: make(chan struct{}), + } + + // use the background context since we will depend exclusively on closing the req channel to + // tell the coordinator we are done. + c.reqs, c.resps = coordinator.Coordinate(context.Background(), + thisID, fmt.Sprintf("inmemory%s", thisID), + auth, + ) + go c.respLoop() + if agentID != uuid.Nil { + select { + case <-ctx.Done(): + c.logger.Warn(ctx, "context expired before we could add tunnel", slog.Error(ctx.Err())) + return c + case c.reqs <- &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}}: + // OK! 
+ } + } + coordinatee.SetNodeCallback(func(n *Node) { + pn, err := NodeToProto(n) + if err != nil { + c.logger.Critical(ctx, "failed to convert node", slog.Error(err)) + c.sendErr(err) + return + } + c.Lock() + defer c.Unlock() + if c.closed { + return + } + select { + case <-ctx.Done(): + c.logger.Info(ctx, "context expired before sending node update") + return + case c.reqs <- &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}: + c.logger.Debug(ctx, "sent node in-memory to coordinator") + } + }) + return c +} + +func (c *inMemoryCoordination) respLoop() { + for { + select { + case <-c.closedCh: + c.logger.Debug(context.Background(), "in-memory coordination closed") + return + case resp, ok := <-c.resps: + if !ok { + c.logger.Debug(context.Background(), "in-memory response channel closed") + return + } + c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp)) + err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates()) + if err != nil { + c.sendErr(xerrors.Errorf("failed to update peers: %w", err)) + return + } + } + } +} + +func (c *inMemoryCoordination) Close() error { + c.Lock() + defer c.Unlock() + c.logger.Debug(context.Background(), "closing in-memory coordination") + if c.closed { + return nil + } + defer close(c.reqs) + c.closed = true + close(c.closedCh) + select { + case <-c.ctx.Done(): + return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err()) + case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}: + c.logger.Debug(context.Background(), "sent graceful disconnect in-memory") + return nil + } +} + // ServeCoordinator matches the RW structure of a coordinator to exchange node messages. func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func(node *Node), <-chan error) { errChan := make(chan error, 1) @@ -237,21 +469,17 @@ func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAge } return false }, - OnSubscribe: func(enq Queue, agent uuid.UUID) (*Node, error) { + OnSubscribe: func(enq Queue, agent uuid.UUID) error { err := SendCtx(ctx, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}}) - return c.Node(agent), err + return err }, OnUnsubscribe: func(enq Queue, agent uuid.UUID) error { err := SendCtx(ctx, reqs, &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}}) return err }, - OnNodeUpdate: func(id uuid.UUID, node *Node) error { - pn, err := NodeToProto(node) - if err != nil { - return err - } + OnNodeUpdate: func(id uuid.UUID, node *proto.Node) error { return SendCtx(ctx, reqs, &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{ - Node: pn, + Node: node, }}) }, OnRemove: func(_ Queue) { @@ -285,7 +513,7 @@ const ( type Queue interface { UniqueID() uuid.UUID Kind() QueueKind - Enqueue(n []*Node) error + Enqueue(resp *proto.CoordinateResponse) error Name() string Stats() (start, lastWrite int64) Overwrites() int64 @@ -793,18 +1021,7 @@ func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logg return } logger.Debug(ctx, "v1RespLoop got response", slog.F("resp", resp)) - nodes, err := OnlyNodeUpdates(resp) - if err != nil { - logger.Critical(ctx, "v1RespLoop failed to decode resp", slog.F("resp", resp), slog.Error(err)) - _ = q.CoordinatorClose() - return - } - // don't send empty updates - if len(nodes) == 0 { - logger.Debug(ctx, "v1RespLoop skipping enqueueing 
0-length v1 update") - continue - } - err = q.Enqueue(nodes) + err = q.Enqueue(resp) if err != nil && !xerrors.Is(err, context.Canceled) { logger.Error(ctx, "v1RespLoop failed to enqueue v1 update", slog.Error(err)) } diff --git a/tailnet/multiagent.go b/tailnet/multiagent.go index 5c3412a595..621f6bc6b1 100644 --- a/tailnet/multiagent.go +++ b/tailnet/multiagent.go @@ -8,13 +8,15 @@ import ( "github.com/google/uuid" "golang.org/x/xerrors" + + "github.com/coder/coder/v2/tailnet/proto" ) type MultiAgentConn interface { - UpdateSelf(node *Node) error + UpdateSelf(node *proto.Node) error SubscribeAgent(agentID uuid.UUID) error UnsubscribeAgent(agentID uuid.UUID) error - NextUpdate(ctx context.Context) ([]*Node, bool) + NextUpdate(ctx context.Context) (*proto.CoordinateResponse, bool) AgentIsLegacy(agentID uuid.UUID) bool Close() error IsClosed() bool @@ -26,16 +28,16 @@ type MultiAgent struct { ID uuid.UUID AgentIsLegacyFunc func(agentID uuid.UUID) bool - OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error) + OnSubscribe func(enq Queue, agent uuid.UUID) error OnUnsubscribe func(enq Queue, agent uuid.UUID) error - OnNodeUpdate func(id uuid.UUID, node *Node) error + OnNodeUpdate func(id uuid.UUID, node *proto.Node) error OnRemove func(enq Queue) ctx context.Context ctxCancel func() closed bool - updates chan []*Node + updates chan *proto.CoordinateResponse closeOnce sync.Once start int64 lastWrite int64 @@ -45,7 +47,7 @@ type MultiAgent struct { } func (m *MultiAgent) Init() *MultiAgent { - m.updates = make(chan []*Node, 128) + m.updates = make(chan *proto.CoordinateResponse, 128) m.start = time.Now().Unix() m.ctx, m.ctxCancel = context.WithCancel(context.Background()) return m @@ -65,7 +67,7 @@ func (m *MultiAgent) AgentIsLegacy(agentID uuid.UUID) bool { var ErrMultiAgentClosed = xerrors.New("multiagent is closed") -func (m *MultiAgent) UpdateSelf(node *Node) error { +func (m *MultiAgent) UpdateSelf(node *proto.Node) error { m.mu.RLock() defer m.mu.RUnlock() if m.closed { @@ -82,15 +84,11 @@ func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) error { return ErrMultiAgentClosed } - node, err := m.OnSubscribe(m, agentID) + err := m.OnSubscribe(m, agentID) if err != nil { return err } - if node != nil { - return m.enqueueLocked([]*Node{node}) - } - return nil } @@ -104,17 +102,17 @@ func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) error { return m.OnUnsubscribe(m, agentID) } -func (m *MultiAgent) NextUpdate(ctx context.Context) ([]*Node, bool) { +func (m *MultiAgent) NextUpdate(ctx context.Context) (*proto.CoordinateResponse, bool) { select { case <-ctx.Done(): return nil, false - case nodes, ok := <-m.updates: - return nodes, ok + case resp, ok := <-m.updates: + return resp, ok } } -func (m *MultiAgent) Enqueue(nodes []*Node) error { +func (m *MultiAgent) Enqueue(resp *proto.CoordinateResponse) error { m.mu.RLock() defer m.mu.RUnlock() @@ -122,14 +120,14 @@ func (m *MultiAgent) Enqueue(nodes []*Node) error { return nil } - return m.enqueueLocked(nodes) + return m.enqueueLocked(resp) } -func (m *MultiAgent) enqueueLocked(nodes []*Node) error { +func (m *MultiAgent) enqueueLocked(resp *proto.CoordinateResponse) error { atomic.StoreInt64(&m.lastWrite, time.Now().Unix()) select { - case m.updates <- nodes: + case m.updates <- resp: return nil default: return ErrWouldBlock diff --git a/tailnet/service.go b/tailnet/service.go index 191319d16c..7347afbb32 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -75,7 +75,9 @@ func NewClientService( } server := 
drpcserver.NewWithOptions(mux, drpcserver.Options{ Log: func(err error) { - if xerrors.Is(err, io.EOF) { + if xerrors.Is(err, io.EOF) || + xerrors.Is(err, context.Canceled) || + xerrors.Is(err, context.DeadlineExceeded) { return } logger.Debug(context.Background(), "drpc server error", slog.Error(err)) diff --git a/tailnet/tailnettest/coordinatormock.go b/tailnet/tailnettest/coordinatormock.go new file mode 100644 index 0000000000..b0ae36d3f8 --- /dev/null +++ b/tailnet/tailnettest/coordinatormock.go @@ -0,0 +1,142 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coder/coder/v2/tailnet (interfaces: Coordinator) +// +// Generated by this command: +// +// mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator +// + +// Package tailnettest is a generated GoMock package. +package tailnettest + +import ( + context "context" + net "net" + http "net/http" + reflect "reflect" + + tailnet "github.com/coder/coder/v2/tailnet" + proto "github.com/coder/coder/v2/tailnet/proto" + uuid "github.com/google/uuid" + gomock "go.uber.org/mock/gomock" +) + +// MockCoordinator is a mock of Coordinator interface. +type MockCoordinator struct { + ctrl *gomock.Controller + recorder *MockCoordinatorMockRecorder +} + +// MockCoordinatorMockRecorder is the mock recorder for MockCoordinator. +type MockCoordinatorMockRecorder struct { + mock *MockCoordinator +} + +// NewMockCoordinator creates a new mock instance. +func NewMockCoordinator(ctrl *gomock.Controller) *MockCoordinator { + mock := &MockCoordinator{ctrl: ctrl} + mock.recorder = &MockCoordinatorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCoordinator) EXPECT() *MockCoordinatorMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockCoordinator) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockCoordinatorMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCoordinator)(nil).Close)) +} + +// Coordinate mocks base method. +func (m *MockCoordinator) Coordinate(arg0 context.Context, arg1 uuid.UUID, arg2 string, arg3 tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Coordinate", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(chan<- *proto.CoordinateRequest) + ret1, _ := ret[1].(<-chan *proto.CoordinateResponse) + return ret0, ret1 +} + +// Coordinate indicates an expected call of Coordinate. +func (mr *MockCoordinatorMockRecorder) Coordinate(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Coordinate", reflect.TypeOf((*MockCoordinator)(nil).Coordinate), arg0, arg1, arg2, arg3) +} + +// Node mocks base method. +func (m *MockCoordinator) Node(arg0 uuid.UUID) *tailnet.Node { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Node", arg0) + ret0, _ := ret[0].(*tailnet.Node) + return ret0 +} + +// Node indicates an expected call of Node. +func (mr *MockCoordinatorMockRecorder) Node(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Node", reflect.TypeOf((*MockCoordinator)(nil).Node), arg0) +} + +// ServeAgent mocks base method. 
+func (m *MockCoordinator) ServeAgent(arg0 net.Conn, arg1 uuid.UUID, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ServeAgent", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// ServeAgent indicates an expected call of ServeAgent. +func (mr *MockCoordinatorMockRecorder) ServeAgent(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeAgent), arg0, arg1, arg2) +} + +// ServeClient mocks base method. +func (m *MockCoordinator) ServeClient(arg0 net.Conn, arg1, arg2 uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ServeClient", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// ServeClient indicates an expected call of ServeClient. +func (mr *MockCoordinatorMockRecorder) ServeClient(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeClient", reflect.TypeOf((*MockCoordinator)(nil).ServeClient), arg0, arg1, arg2) +} + +// ServeHTTPDebug mocks base method. +func (m *MockCoordinator) ServeHTTPDebug(arg0 http.ResponseWriter, arg1 *http.Request) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ServeHTTPDebug", arg0, arg1) +} + +// ServeHTTPDebug indicates an expected call of ServeHTTPDebug. +func (mr *MockCoordinatorMockRecorder) ServeHTTPDebug(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeHTTPDebug", reflect.TypeOf((*MockCoordinator)(nil).ServeHTTPDebug), arg0, arg1) +} + +// ServeMultiAgent mocks base method. +func (m *MockCoordinator) ServeMultiAgent(arg0 uuid.UUID) tailnet.MultiAgentConn { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ServeMultiAgent", arg0) + ret0, _ := ret[0].(tailnet.MultiAgentConn) + return ret0 +} + +// ServeMultiAgent indicates an expected call of ServeMultiAgent. +func (mr *MockCoordinatorMockRecorder) ServeMultiAgent(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeMultiAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeMultiAgent), arg0) +} diff --git a/tailnet/tailnettest/multiagentmock.go b/tailnet/tailnettest/multiagentmock.go deleted file mode 100644 index fd03a0e7f2..0000000000 --- a/tailnet/tailnettest/multiagentmock.go +++ /dev/null @@ -1,141 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/coder/coder/v2/tailnet (interfaces: MultiAgentConn) -// -// Generated by this command: -// -// mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn -// - -// Package tailnettest is a generated GoMock package. -package tailnettest - -import ( - context "context" - reflect "reflect" - - tailnet "github.com/coder/coder/v2/tailnet" - uuid "github.com/google/uuid" - gomock "go.uber.org/mock/gomock" -) - -// MockMultiAgentConn is a mock of MultiAgentConn interface. -type MockMultiAgentConn struct { - ctrl *gomock.Controller - recorder *MockMultiAgentConnMockRecorder -} - -// MockMultiAgentConnMockRecorder is the mock recorder for MockMultiAgentConn. -type MockMultiAgentConnMockRecorder struct { - mock *MockMultiAgentConn -} - -// NewMockMultiAgentConn creates a new mock instance. 
-func NewMockMultiAgentConn(ctrl *gomock.Controller) *MockMultiAgentConn { - mock := &MockMultiAgentConn{ctrl: ctrl} - mock.recorder = &MockMultiAgentConnMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMultiAgentConn) EXPECT() *MockMultiAgentConnMockRecorder { - return m.recorder -} - -// AgentIsLegacy mocks base method. -func (m *MockMultiAgentConn) AgentIsLegacy(arg0 uuid.UUID) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AgentIsLegacy", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// AgentIsLegacy indicates an expected call of AgentIsLegacy. -func (mr *MockMultiAgentConnMockRecorder) AgentIsLegacy(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AgentIsLegacy", reflect.TypeOf((*MockMultiAgentConn)(nil).AgentIsLegacy), arg0) -} - -// Close mocks base method. -func (m *MockMultiAgentConn) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockMultiAgentConnMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMultiAgentConn)(nil).Close)) -} - -// IsClosed mocks base method. -func (m *MockMultiAgentConn) IsClosed() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsClosed") - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsClosed indicates an expected call of IsClosed. -func (mr *MockMultiAgentConnMockRecorder) IsClosed() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockMultiAgentConn)(nil).IsClosed)) -} - -// NextUpdate mocks base method. -func (m *MockMultiAgentConn) NextUpdate(arg0 context.Context) ([]*tailnet.Node, bool) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NextUpdate", arg0) - ret0, _ := ret[0].([]*tailnet.Node) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// NextUpdate indicates an expected call of NextUpdate. -func (mr *MockMultiAgentConnMockRecorder) NextUpdate(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextUpdate", reflect.TypeOf((*MockMultiAgentConn)(nil).NextUpdate), arg0) -} - -// SubscribeAgent mocks base method. -func (m *MockMultiAgentConn) SubscribeAgent(arg0 uuid.UUID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SubscribeAgent", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SubscribeAgent indicates an expected call of SubscribeAgent. -func (mr *MockMultiAgentConnMockRecorder) SubscribeAgent(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).SubscribeAgent), arg0) -} - -// UnsubscribeAgent mocks base method. -func (m *MockMultiAgentConn) UnsubscribeAgent(arg0 uuid.UUID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UnsubscribeAgent", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// UnsubscribeAgent indicates an expected call of UnsubscribeAgent. -func (mr *MockMultiAgentConnMockRecorder) UnsubscribeAgent(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).UnsubscribeAgent), arg0) -} - -// UpdateSelf mocks base method. 
-func (m *MockMultiAgentConn) UpdateSelf(arg0 *tailnet.Node) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateSelf", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// UpdateSelf indicates an expected call of UpdateSelf. -func (mr *MockMultiAgentConnMockRecorder) UpdateSelf(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSelf", reflect.TypeOf((*MockMultiAgentConn)(nil).UpdateSelf), arg0) -} diff --git a/tailnet/tailnettest/tailnettest.go b/tailnet/tailnettest/tailnettest.go index e9eb45ad96..e7ed6361a1 100644 --- a/tailnet/tailnettest/tailnettest.go +++ b/tailnet/tailnettest/tailnettest.go @@ -21,7 +21,7 @@ import ( "github.com/coder/coder/v2/tailnet" ) -//go:generate mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn +//go:generate mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator // RunDERPAndSTUN creates a DERP mapping for tests. func RunDERPAndSTUN(t *testing.T) (*tailcfg.DERPMap, *derp.Server) { diff --git a/tailnet/trackedconn.go b/tailnet/trackedconn.go index 3b3feaa132..a801cdfae0 100644 --- a/tailnet/trackedconn.go +++ b/tailnet/trackedconn.go @@ -11,6 +11,7 @@ import ( "github.com/google/uuid" "cdr.dev/slog" + "github.com/coder/coder/v2/tailnet/proto" ) const ( @@ -29,7 +30,7 @@ type TrackedConn struct { cancel func() kind QueueKind conn net.Conn - updates chan []*Node + updates chan *proto.CoordinateResponse logger slog.Logger lastData []byte @@ -55,7 +56,7 @@ func NewTrackedConn(ctx context.Context, cancel func(), // coordinator mutex while queuing. Node updates don't // come quickly, so 512 should be plenty for all but // the most pathological cases. - updates := make(chan []*Node, ResponseBufferSize) + updates := make(chan *proto.CoordinateResponse, ResponseBufferSize) now := time.Now().Unix() return &TrackedConn{ ctx: ctx, @@ -72,10 +73,10 @@ func NewTrackedConn(ctx context.Context, cancel func(), } } -func (t *TrackedConn) Enqueue(n []*Node) (err error) { +func (t *TrackedConn) Enqueue(resp *proto.CoordinateResponse) (err error) { atomic.StoreInt64(&t.lastWrite, time.Now().Unix()) select { - case t.updates <- n: + case t.updates <- resp: return nil default: return ErrWouldBlock @@ -124,7 +125,16 @@ func (t *TrackedConn) SendUpdates() { case <-t.ctx.Done(): t.logger.Debug(t.ctx, "done sending updates") return - case nodes := <-t.updates: + case resp := <-t.updates: + nodes, err := OnlyNodeUpdates(resp) + if err != nil { + t.logger.Critical(t.ctx, "unable to parse response", slog.Error(err)) + return + } + if len(nodes) == 0 { + t.logger.Debug(t.ctx, "skipping response with no nodes") + continue + } data, err := json.Marshal(nodes) if err != nil { t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes))
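Pulling the pieces of this patch together: a Coordinatee (normally a Conn) is now driven by a Coordination, either in-memory against a Coordinator for testing and local clients, or over the dRPC Coordinate stream via NewRemoteCoordination. A minimal, self-contained sketch of the in-memory wiring; the helper, its arguments, and the address/DERP setup are assumptions for illustration, not part of the patch:

package example

import (
	"context"
	"net/netip"

	"github.com/google/uuid"
	"tailscale.com/tailcfg"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/tailnet"
)

// connectClient is a hypothetical helper: it brings up a client Conn and
// coordinates it with agentID through an in-memory Coordinator.
func connectClient(
	logger slog.Logger,
	coordinator tailnet.Coordinator,
	agentID uuid.UUID,
	derpMap *tailcfg.DERPMap,
) (*tailnet.Conn, tailnet.Coordination, error) {
	conn, err := tailnet.NewConn(&tailnet.Options{
		Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		DERPMap:   derpMap,
		Logger:    logger,
	})
	if err != nil {
		return nil, nil, err
	}
	// Passing uuid.Nil as the client ID would mark this Conn as an agent
	// instead of a client.
	clientID := uuid.New()
	coordination := tailnet.NewInMemoryCoordination(
		context.Background(), logger,
		clientID, agentID,
		coordinator, conn,
	)
	// The caller should Close the Coordination (which sends a graceful
	// Disconnect) before closing the Conn when tearing down.
	return conn, coordination, nil
}

The remote path is analogous: dial the Coordinate RPC, then pass the resulting client to tailnet.NewRemoteCoordination together with the target agent ID, or uuid.Nil when running inside an agent.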