diff --git a/.prettierignore b/.prettierignore index 37cbd3fef3..9be32290ac 100644 --- a/.prettierignore +++ b/.prettierignore @@ -83,6 +83,7 @@ helm/**/templates/*.yaml # Testdata shouldn't be formatted. scripts/apitypings/testdata/**/*.ts enterprise/tailnet/testdata/*.golden.html +tailnet/testdata/*.golden.html # Generated files shouldn't be formatted. site/e2e/provisionerGenerated.ts diff --git a/.prettierignore.include b/.prettierignore.include index fd7f94f13d..7efd582e15 100644 --- a/.prettierignore.include +++ b/.prettierignore.include @@ -9,6 +9,7 @@ helm/**/templates/*.yaml # Testdata shouldn't be formatted. scripts/apitypings/testdata/**/*.ts enterprise/tailnet/testdata/*.golden.html +tailnet/testdata/*.golden.html # Generated files shouldn't be formatted. site/e2e/provisionerGenerated.ts diff --git a/Makefile b/Makefile index d16263b517..8ee86d1584 100644 --- a/Makefile +++ b/Makefile @@ -602,6 +602,7 @@ update-golden-files: \ scripts/ci-report/testdata/.gen-golden \ enterprise/cli/testdata/.gen-golden \ enterprise/tailnet/testdata/.gen-golden \ + tailnet/testdata/.gen-golden \ coderd/.gen-golden \ provisioner/terraform/testdata/.gen-golden .PHONY: update-golden-files @@ -614,6 +615,10 @@ enterprise/cli/testdata/.gen-golden: $(wildcard enterprise/cli/testdata/*.golden go test ./enterprise/cli -run="TestEnterpriseCommandHelp" -update touch "$@" +tailnet/testdata/.gen-golden: $(wildcard tailnet/testdata/*.golden.html) $(GO_SRC_FILES) $(wildcard tailnet/*_test.go) + go test ./tailnet -run="TestDebugTemplate" -update + touch "$@" + enterprise/tailnet/testdata/.gen-golden: $(wildcard enterprise/tailnet/testdata/*.golden.html) $(GO_SRC_FILES) $(wildcard enterprise/tailnet/*_test.go) go test ./enterprise/tailnet -run="TestDebugTemplate" -update touch "$@" diff --git a/enterprise/tailnet/connio.go b/enterprise/tailnet/connio.go index 83c1d8a2b9..94d080e219 100644 --- a/enterprise/tailnet/connio.go +++ b/enterprise/tailnet/connio.go @@ -86,7 +86,7 @@ func (c *connIO) recvLoop() { if c.disconnected { b.kind = proto.CoordinateResponse_PeerUpdate_DISCONNECTED } - if err := sendCtx(c.coordCtx, c.bindings, b); err != nil { + if err := agpl.SendCtx(c.coordCtx, c.bindings, b); err != nil { c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err)) } // only remove tunnels on graceful disconnect. 
If we remove tunnels for lost peers, then @@ -97,14 +97,14 @@ func (c *connIO) recvLoop() { tKey: tKey{src: c.UniqueID()}, active: false, } - if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil { + if err := agpl.SendCtx(c.coordCtx, c.tunnels, t); err != nil { c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err)) } } }() defer c.Close() for { - req, err := recvCtx(c.peerCtx, c.requests) + req, err := agpl.RecvCtx(c.peerCtx, c.requests) if err != nil { if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) || @@ -132,7 +132,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error { node: req.UpdateSelf.Node, kind: proto.CoordinateResponse_PeerUpdate_NODE, } - if err := sendCtx(c.coordCtx, c.bindings, b); err != nil { + if err := agpl.SendCtx(c.coordCtx, c.bindings, b); err != nil { c.logger.Debug(c.peerCtx, "failed to send binding", slog.Error(err)) return err } @@ -156,7 +156,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error { }, active: true, } - if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil { + if err := agpl.SendCtx(c.coordCtx, c.tunnels, t); err != nil { c.logger.Debug(c.peerCtx, "failed to send add tunnel", slog.Error(err)) return err } @@ -177,7 +177,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error { }, active: false, } - if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil { + if err := agpl.SendCtx(c.coordCtx, c.tunnels, t); err != nil { c.logger.Debug(c.peerCtx, "failed to send remove tunnel", slog.Error(err)) return err } diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index 5a26cdc92a..068b91160f 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -5,17 +5,22 @@ import ( "context" "encoding/json" "errors" + "fmt" + "html/template" "io" "net" "net/http" "sync" + "time" "github.com/google/uuid" lru "github.com/hashicorp/golang-lru/v2" + "golang.org/x/exp/slices" "golang.org/x/xerrors" "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/codersdk" agpl "github.com/coder/coder/v2/tailnet" ) @@ -719,7 +724,209 @@ func (c *haCoordinator) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { c.mutex.RLock() defer c.mutex.RUnlock() - agpl.CoordinatorHTTPDebug( - agpl.HTTPDebugFromLocal(true, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache), + CoordinatorHTTPDebug( + HTTPDebugFromLocal(true, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache), )(w, r) } + +func HTTPDebugFromLocal( + ha bool, + agentSocketsMap map[uuid.UUID]agpl.Queue, + agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]agpl.Queue, + nodesMap map[uuid.UUID]*agpl.Node, + agentNameCache *lru.Cache[uuid.UUID, string], +) HTMLDebugHA { + now := time.Now() + data := HTMLDebugHA{HA: ha} + for id, conn := range agentSocketsMap { + start, lastWrite := conn.Stats() + agent := &HTMLAgent{ + Name: conn.Name(), + ID: id, + CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second), + LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), + Overwrites: int(conn.Overwrites()), + } + + for id, conn := range agentToConnectionSocketsMap[id] { + start, lastWrite := conn.Stats() + agent.Connections = append(agent.Connections, &HTMLClient{ + Name: conn.Name(), + ID: id, + CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second), + LastWriteAge: now.Sub(time.Unix(lastWrite, 
0)).Round(time.Second), + }) + } + slices.SortFunc(agent.Connections, func(a, b *HTMLClient) int { + return slice.Ascending(a.Name, b.Name) + }) + + data.Agents = append(data.Agents, agent) + } + slices.SortFunc(data.Agents, func(a, b *HTMLAgent) int { + return slice.Ascending(a.Name, b.Name) + }) + + for agentID, conns := range agentToConnectionSocketsMap { + if len(conns) == 0 { + continue + } + + if _, ok := agentSocketsMap[agentID]; ok { + continue + } + + agentName, ok := agentNameCache.Get(agentID) + if !ok { + agentName = "unknown" + } + agent := &HTMLAgent{ + Name: agentName, + ID: agentID, + } + for id, conn := range conns { + start, lastWrite := conn.Stats() + agent.Connections = append(agent.Connections, &HTMLClient{ + Name: conn.Name(), + ID: id, + CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second), + LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), + }) + } + slices.SortFunc(agent.Connections, func(a, b *HTMLClient) int { + return slice.Ascending(a.Name, b.Name) + }) + + data.MissingAgents = append(data.MissingAgents, agent) + } + slices.SortFunc(data.MissingAgents, func(a, b *HTMLAgent) int { + return slice.Ascending(a.Name, b.Name) + }) + + for id, node := range nodesMap { + name, _ := agentNameCache.Get(id) + data.Nodes = append(data.Nodes, &HTMLNode{ + ID: id, + Name: name, + Node: node, + }) + } + slices.SortFunc(data.Nodes, func(a, b *HTMLNode) int { + return slice.Ascending(a.Name+a.ID.String(), b.Name+b.ID.String()) + }) + + return data +} + +func CoordinatorHTTPDebug(data HTMLDebugHA) func(w http.ResponseWriter, _ *http.Request) { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + + tmpl, err := template.New("coordinator_debug").Funcs(template.FuncMap{ + "marshal": func(v any) template.JS { + a, err := json.MarshalIndent(v, "", " ") + if err != nil { + //nolint:gosec + return template.JS(fmt.Sprintf(`{"err": %q}`, err)) + } + //nolint:gosec + return template.JS(a) + }, + }).Parse(haCoordinatorDebugTmpl) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(err.Error())) + return + } + + err = tmpl.Execute(w, data) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(err.Error())) + return + } + } +} + +type HTMLDebugHA struct { + HA bool + Agents []*HTMLAgent + MissingAgents []*HTMLAgent + Nodes []*HTMLNode +} + +type HTMLAgent struct { + Name string + ID uuid.UUID + CreatedAge time.Duration + LastWriteAge time.Duration + Overwrites int + Connections []*HTMLClient +} + +type HTMLClient struct { + Name string + ID uuid.UUID + CreatedAge time.Duration + LastWriteAge time.Duration +} + +type HTMLNode struct { + ID uuid.UUID + Name string + Node any +} + +var haCoordinatorDebugTmpl = ` + + + + + + + {{- if .HA }} +

+ high-availability wireguard coordinator debug
+ warning: this only provides info from the node that served the request, if there are multiple replicas this data may be incomplete
+ {{- else }}
+ in-memory wireguard coordinator debug
+ {{- end }}
+
+ # agents: total {{ len .Agents }}
+
+ # missing agents: total {{ len .MissingAgents }}
+
+ # nodes: total {{ len .Nodes }}
+ + + +` diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 635d1d4fce..7bc915801e 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -3,17 +3,12 @@ package tailnet import ( "context" "database/sql" - "encoding/json" - "io" "net" - "net/netip" "strings" "sync" "sync/atomic" "time" - "nhooyr.io/websocket" - "github.com/coder/coder/v2/tailnet/proto" "github.com/cenkalti/backoff/v4" @@ -30,17 +25,16 @@ import ( ) const ( - EventHeartbeats = "tailnet_coordinator_heartbeat" - eventPeerUpdate = "tailnet_peer_update" - eventTunnelUpdate = "tailnet_tunnel_update" - HeartbeatPeriod = time.Second * 2 - MissedHeartbeats = 3 - numQuerierWorkers = 10 - numBinderWorkers = 10 - numTunnelerWorkers = 10 - dbMaxBackoff = 10 * time.Second - cleanupPeriod = time.Hour - requestResponseBuffSize = 32 + EventHeartbeats = "tailnet_coordinator_heartbeat" + eventPeerUpdate = "tailnet_peer_update" + eventTunnelUpdate = "tailnet_tunnel_update" + HeartbeatPeriod = time.Second * 2 + MissedHeartbeats = 3 + numQuerierWorkers = 10 + numBinderWorkers = 10 + numTunnelerWorkers = 10 + dbMaxBackoff = 10 * time.Second + cleanupPeriod = time.Hour ) // pgCoord is a postgres-backed coordinator @@ -161,55 +155,8 @@ func NewPGCoordV2(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, sto return newPGCoordInternal(ctx, logger, ps, store) } -// This is copied from codersdk because importing it here would cause an import -// cycle. This is just temporary until wsconncache is phased out. -var legacyAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4") - func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { - logger := c.logger.With(slog.F("client_id", id)).Named("multiagent") - ctx, cancel := context.WithCancel(c.ctx) - reqs, resps := c.Coordinate(ctx, id, id.String(), agpl.SingleTailnetTunnelAuth{}) - ma := (&agpl.MultiAgent{ - ID: id, - AgentIsLegacyFunc: func(agentID uuid.UUID) bool { - if n := c.Node(agentID); n == nil { - // If we don't have the node at all assume it's legacy for - // safety. - return true - } else if len(n.Addresses) > 0 && n.Addresses[0].Addr() == legacyAgentIP { - // An agent is determined to be "legacy" if it's first IP is the - // legacy IP. Agents with only the legacy IP aren't compatible - // with single_tailnet and must be routed through wsconncache. 
- return true - } else { - return false - } - }, - OnSubscribe: func(enq agpl.Queue, agent uuid.UUID) (*agpl.Node, error) { - err := sendCtx(ctx, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)}}) - return c.Node(agent), err - }, - OnUnsubscribe: func(enq agpl.Queue, agent uuid.UUID) error { - err := sendCtx(ctx, reqs, &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)}}) - return err - }, - OnNodeUpdate: func(id uuid.UUID, node *agpl.Node) error { - pn, err := agpl.NodeToProto(node) - if err != nil { - return err - } - return sendCtx(c.ctx, reqs, &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{ - Node: pn, - }}) - }, - OnRemove: func(_ agpl.Queue) { - _ = sendCtx(c.ctx, reqs, &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}) - cancel() - }, - }).Init() - - go v1SendLoop(ctx, cancel, logger, ma, resps) - return ma + return agpl.ServeMultiAgent(c, c.logger, id) } func (c *pgCoord) Node(id uuid.UUID) *agpl.Node { @@ -253,116 +200,11 @@ func (c *pgCoord) Node(id uuid.UUID) *agpl.Node { } func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { - logger := c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent)) - defer func() { - err := conn.Close() - if err != nil { - logger.Debug(c.ctx, "closing client connection", slog.Error(err)) - } - }() - ctx, cancel := context.WithCancel(c.ctx) - defer cancel() - reqs, resps := c.Coordinate(ctx, id, id.String(), agpl.ClientTunnelAuth{AgentID: agent}) - err := sendCtx(ctx, reqs, &proto.CoordinateRequest{ - AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)}, - }) - if err != nil { - // can only be a context error, no need to log here. 
- return err - } - defer func() { - _ = sendCtx(ctx, reqs, &proto.CoordinateRequest{ - RemoveTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)}, - }) - }() - - tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, agpl.QueueKindClient) - go tc.SendUpdates() - go v1SendLoop(ctx, cancel, logger, tc, resps) - go v1RecvLoop(ctx, cancel, logger, conn, reqs) - <-ctx.Done() - return nil + return agpl.ServeClientV1(c.ctx, c.logger, c, conn, id, agent) } func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { - logger := c.logger.With(slog.F("agent_id", id), slog.F("name", name)) - defer func() { - logger.Debug(c.ctx, "closing agent connection") - err := conn.Close() - logger.Debug(c.ctx, "closed agent connection", slog.Error(err)) - }() - ctx, cancel := context.WithCancel(c.ctx) - defer cancel() - reqs, resps := c.Coordinate(ctx, id, name, agpl.AgentTunnelAuth{}) - tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, agpl.QueueKindAgent) - go tc.SendUpdates() - go v1SendLoop(ctx, cancel, logger, tc, resps) - go v1RecvLoop(ctx, cancel, logger, conn, reqs) - <-ctx.Done() - return nil -} - -func v1RecvLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, - conn net.Conn, reqs chan<- *proto.CoordinateRequest, -) { - defer cancel() - decoder := json.NewDecoder(conn) - for { - var node agpl.Node - err := decoder.Decode(&node) - if err != nil { - if xerrors.Is(err, io.EOF) || - xerrors.Is(err, io.ErrClosedPipe) || - xerrors.Is(err, context.Canceled) || - xerrors.Is(err, context.DeadlineExceeded) || - websocket.CloseStatus(err) > 0 { - logger.Debug(ctx, "exiting recvLoop", slog.Error(err)) - } else { - logger.Error(ctx, "failed to decode Node update", slog.Error(err)) - } - return - } - logger.Debug(ctx, "got node update", slog.F("node", node)) - pn, err := agpl.NodeToProto(&node) - if err != nil { - logger.Critical(ctx, "failed to convert v1 node", slog.F("node", node), slog.Error(err)) - return - } - req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{ - Node: pn, - }} - if err := sendCtx(ctx, reqs, req); err != nil { - logger.Debug(ctx, "recvLoop ctx expired", slog.Error(err)) - return - } - } -} - -func v1SendLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q agpl.Queue, resps <-chan *proto.CoordinateResponse) { - defer cancel() - for { - resp, err := recvCtx(ctx, resps) - if err != nil { - logger.Debug(ctx, "done reading responses", slog.Error(err)) - return - } - logger.Debug(ctx, "v1: got response", slog.F("resp", resp)) - nodes, err := agpl.OnlyNodeUpdates(resp) - if err != nil { - logger.Critical(ctx, "failed to decode resp", slog.F("resp", resp), slog.Error(err)) - _ = q.CoordinatorClose() - return - } - // don't send empty updates - if len(nodes) == 0 { - logger.Debug(ctx, "skipping enqueueing 0-length v1 update") - continue - } - err = q.Enqueue(nodes) - if err != nil { - logger.Error(ctx, "failed to enqueue v1 update", slog.Error(err)) - } - } + return agpl.ServeAgentV1(c.ctx, c.logger, c, conn, id, name) } func (c *pgCoord) Close() error { @@ -378,43 +220,22 @@ func (c *pgCoord) Coordinate( chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse, ) { logger := c.logger.With(slog.F("peer_id", id)) - reqs := make(chan *proto.CoordinateRequest, requestResponseBuffSize) + reqs := make(chan *proto.CoordinateRequest, agpl.RequestBufferSize) resps := make(chan *proto.CoordinateResponse, agpl.ResponseBufferSize) cIO := newConnIO(c.ctx, ctx, 
logger, c.bindings, c.tunnelerCh, reqs, resps, id, name, a) - err := sendCtx(c.ctx, c.newConnections, cIO) + err := agpl.SendCtx(c.ctx, c.newConnections, cIO) if err != nil { // this can only happen if the context is canceled, no need to log return reqs, resps } go func() { <-cIO.Done() - _ = sendCtx(c.ctx, c.closeConnections, cIO) + _ = agpl.SendCtx(c.ctx, c.closeConnections, cIO) }() return reqs, resps } -func sendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) { - select { - case <-ctx.Done(): - return ctx.Err() - case c <- a: - return nil - } -} - -func recvCtx[A any](ctx context.Context, c <-chan A) (a A, err error) { - select { - case <-ctx.Done(): - return a, ctx.Err() - case a, ok := <-c: - if ok { - return a, nil - } - return a, io.EOF - } -} - type tKey struct { src uuid.UUID dst uuid.UUID @@ -873,7 +694,7 @@ func (m *mapper) bestToUpdate(best map[uuid.UUID]mapping) *proto.CoordinateRespo case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_NODE && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE: eq, err := sm.node.Equal(mpng.node) if err != nil { - m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sm.node), slog.F("new", mpng.kind)) + m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sm.node), slog.F("new", mpng.node)) continue } if eq { @@ -1303,7 +1124,7 @@ func (q *querier) updateAll() { go func(m *mapper) { // make sure we send on the _mapper_ context, not our own in case the mapper is // shutting down or shut down. - _ = sendCtx(m.ctx, m.update, struct{}{}) + _ = agpl.SendCtx(m.ctx, m.update, struct{}{}) }(mpr) } } @@ -1603,7 +1424,7 @@ func (h *heartbeats) recvBeat(id uuid.UUID) { h.logger.Info(h.ctx, "heartbeats (re)started", slog.F("other_coordinator_id", id)) // send on a separate goroutine to avoid holding lock. Triggering update can be async go func() { - _ = sendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated}) + _ = agpl.SendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated}) }() } h.coordinators[id] = time.Now() @@ -1650,7 +1471,7 @@ func (h *heartbeats) checkExpiry() { if expired { // send on a separate goroutine to avoid holding lock. Triggering update can be async go func() { - _ = sendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated}) + _ = agpl.SendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated}) }() } // we need to reset the timer for when the next oldest coordinator will expire, if any. 
@@ -1685,14 +1506,14 @@ func (h *heartbeats) sendBeat() { h.failedHeartbeats++ if h.failedHeartbeats == 3 { h.logger.Error(h.ctx, "coordinator failed 3 heartbeats and is unhealthy") - _ = sendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateUnhealthy}) + _ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateUnhealthy}) } return } h.logger.Debug(h.ctx, "sent heartbeat") if h.failedHeartbeats >= 3 { h.logger.Info(h.ctx, "coordinator sent heartbeat and is healthy") - _ = sendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy}) + _ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy}) } h.failedHeartbeats = 0 } diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index 920bb8a2c8..8a581f48b6 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + agpltest "github.com/coder/coder/v2/tailnet/test" + "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -612,18 +614,7 @@ func TestPGCoordinator_BidirectionalTunnels(t *testing.T) { coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store) require.NoError(t, err) defer coordinator.Close() - - p1 := newTestPeer(ctx, t, coordinator, "p1") - defer p1.close(ctx) - p2 := newTestPeer(ctx, t, coordinator, "p2") - defer p2.close(ctx) - p1.addTunnel(p2.id) - p2.addTunnel(p1.id) - p1.updateDERP(1) - p2.updateDERP(2) - - p1.assertEventuallyHasDERP(p2.id, 2) - p2.assertEventuallyHasDERP(p1.id, 1) + agpltest.BidirectionalTunnels(ctx, t, coordinator) } func TestPGCoordinator_GracefulDisconnect(t *testing.T) { @@ -638,21 +629,7 @@ func TestPGCoordinator_GracefulDisconnect(t *testing.T) { coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store) require.NoError(t, err) defer coordinator.Close() - - p1 := newTestPeer(ctx, t, coordinator, "p1") - defer p1.close(ctx) - p2 := newTestPeer(ctx, t, coordinator, "p2") - defer p2.close(ctx) - p1.addTunnel(p2.id) - p1.updateDERP(1) - p2.updateDERP(2) - - p1.assertEventuallyHasDERP(p2.id, 2) - p2.assertEventuallyHasDERP(p1.id, 1) - - p2.disconnect() - p1.assertEventuallyDisconnected(p2.id) - p2.assertEventuallyResponsesClosed() + agpltest.GracefulDisconnectTest(ctx, t, coordinator) } func TestPGCoordinator_Lost(t *testing.T) { @@ -667,20 +644,7 @@ func TestPGCoordinator_Lost(t *testing.T) { coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store) require.NoError(t, err) defer coordinator.Close() - - p1 := newTestPeer(ctx, t, coordinator, "p1") - defer p1.close(ctx) - p2 := newTestPeer(ctx, t, coordinator, "p2") - defer p2.close(ctx) - p1.addTunnel(p2.id) - p1.updateDERP(1) - p2.updateDERP(2) - - p1.assertEventuallyHasDERP(p2.id, 2) - p2.assertEventuallyHasDERP(p1.id, 1) - - p2.close(ctx) - p1.assertEventuallyLost(p2.id) + agpltest.LostTest(ctx, t, coordinator) } type testConn struct { @@ -918,177 +882,6 @@ func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store }, testutil.WaitShort, testutil.IntervalFast) } -type peerStatus struct { - preferredDERP int32 - status proto.CoordinateResponse_PeerUpdate_Kind -} - -type testPeer struct { - ctx context.Context - cancel context.CancelFunc - t testing.TB - id uuid.UUID - name string - resps <-chan *proto.CoordinateResponse - reqs chan<- *proto.CoordinateRequest - peers map[uuid.UUID]peerStatus -} - -func newTestPeer(ctx context.Context, t testing.TB, coord agpl.CoordinatorV2, name string, id ...uuid.UUID) *testPeer { - p := &testPeer{t: t, name: name, 
peers: make(map[uuid.UUID]peerStatus)} - p.ctx, p.cancel = context.WithCancel(ctx) - if len(id) > 1 { - t.Fatal("too many") - } - if len(id) == 1 { - p.id = id[0] - } else { - p.id = uuid.New() - } - // SingleTailnetTunnelAuth allows connections to arbitrary peers - p.reqs, p.resps = coord.Coordinate(p.ctx, p.id, name, agpl.SingleTailnetTunnelAuth{}) - return p -} - -func (p *testPeer) addTunnel(other uuid.UUID) { - p.t.Helper() - req := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(other)}} - select { - case <-p.ctx.Done(): - p.t.Errorf("timeout adding tunnel for %s", p.name) - return - case p.reqs <- req: - return - } -} - -func (p *testPeer) updateDERP(derp int32) { - p.t.Helper() - req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: derp}}} - select { - case <-p.ctx.Done(): - p.t.Errorf("timeout updating node for %s", p.name) - return - case p.reqs <- req: - return - } -} - -func (p *testPeer) disconnect() { - p.t.Helper() - req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}} - select { - case <-p.ctx.Done(): - p.t.Errorf("timeout updating node for %s", p.name) - return - case p.reqs <- req: - return - } -} - -func (p *testPeer) assertEventuallyHasDERP(other uuid.UUID, derp int32) { - p.t.Helper() - for { - o, ok := p.peers[other] - if ok && o.preferredDERP == derp { - return - } - if err := p.handleOneResp(); err != nil { - assert.NoError(p.t, err) - return - } - } -} - -func (p *testPeer) assertEventuallyDisconnected(other uuid.UUID) { - p.t.Helper() - for { - _, ok := p.peers[other] - if !ok { - return - } - if err := p.handleOneResp(); err != nil { - assert.NoError(p.t, err) - return - } - } -} - -func (p *testPeer) assertEventuallyLost(other uuid.UUID) { - p.t.Helper() - for { - o := p.peers[other] - if o.status == proto.CoordinateResponse_PeerUpdate_LOST { - return - } - if err := p.handleOneResp(); err != nil { - assert.NoError(p.t, err) - return - } - } -} - -func (p *testPeer) assertEventuallyResponsesClosed() { - p.t.Helper() - for { - err := p.handleOneResp() - if xerrors.Is(err, responsesClosed) { - return - } - if !assert.NoError(p.t, err) { - return - } - } -} - -var responsesClosed = xerrors.New("responses closed") - -func (p *testPeer) handleOneResp() error { - select { - case <-p.ctx.Done(): - return p.ctx.Err() - case resp, ok := <-p.resps: - if !ok { - return responsesClosed - } - for _, update := range resp.PeerUpdates { - id, err := uuid.FromBytes(update.Uuid) - if err != nil { - return err - } - switch update.Kind { - case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST: - p.peers[id] = peerStatus{ - preferredDERP: update.GetNode().GetPreferredDerp(), - status: update.Kind, - } - case proto.CoordinateResponse_PeerUpdate_DISCONNECTED: - delete(p.peers, id) - default: - return xerrors.Errorf("unhandled update kind %s", update.Kind) - } - } - } - return nil -} - -func (p *testPeer) close(ctx context.Context) { - p.t.Helper() - p.cancel() - for { - select { - case <-ctx.Done(): - p.t.Errorf("timeout waiting for responses to close for %s", p.name) - return - case _, ok := <-p.resps: - if ok { - continue - } - return - } - } -} - type fakeCoordinator struct { ctx context.Context t *testing.T diff --git a/site/.eslintignore b/site/.eslintignore index 033d259091..fc74ff6dd9 100644 --- a/site/.eslintignore +++ b/site/.eslintignore @@ -83,6 +83,7 @@ result # Testdata shouldn't be formatted. 
../scripts/apitypings/testdata/**/*.ts ../enterprise/tailnet/testdata/*.golden.html +../tailnet/testdata/*.golden.html # Generated files shouldn't be formatted. e2e/provisionerGenerated.ts diff --git a/site/.prettierignore b/site/.prettierignore index 033d259091..fc74ff6dd9 100644 --- a/site/.prettierignore +++ b/site/.prettierignore @@ -83,6 +83,7 @@ result # Testdata shouldn't be formatted. ../scripts/apitypings/testdata/**/*.ts ../enterprise/tailnet/testdata/*.golden.html +../tailnet/testdata/*.golden.html # Generated files shouldn't be formatted. e2e/provisionerGenerated.ts diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 2da96bc444..2c67e5ff72 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -3,8 +3,6 @@ package tailnet import ( "context" "encoding/json" - "errors" - "fmt" "html/template" "io" "net" @@ -14,14 +12,12 @@ import ( "time" "github.com/google/uuid" - lru "github.com/hashicorp/golang-lru/v2" - "golang.org/x/exp/slices" "golang.org/x/xerrors" + "nhooyr.io/websocket" "tailscale.com/tailcfg" "tailscale.com/types/key" "cdr.dev/slog" - "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/tailnet/proto" ) @@ -131,12 +127,29 @@ func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func const LoggerName = "coord" +var ( + ErrClosed = xerrors.New("coordinator is closed") + ErrWouldBlock = xerrors.New("would block") + ErrAlreadyRemoved = xerrors.New("already removed") +) + // NewCoordinator constructs a new in-memory connection coordinator. This // coordinator is incompatible with multiple Coder replicas as all node data is // in-memory. func NewCoordinator(logger slog.Logger) Coordinator { return &coordinator{ - core: newCore(logger), + core: newCore(logger.Named(LoggerName)), + closedChan: make(chan struct{}), + } +} + +// NewCoordinatorV2 constructs a new in-memory connection coordinator. This +// coordinator is incompatible with multiple Coder replicas as all node data is +// in-memory. +func NewCoordinatorV2(logger slog.Logger) CoordinatorV2 { + return &coordinator{ + core: newCore(logger.Named(LoggerName)), + closedChan: make(chan struct{}), } } @@ -149,26 +162,112 @@ func NewCoordinator(logger slog.Logger) Coordinator { // data is in-memory. type coordinator struct { core *core + + mu sync.Mutex + closed bool + wg sync.WaitGroup + closedChan chan struct{} +} + +func (c *coordinator) Coordinate( + ctx context.Context, id uuid.UUID, name string, a TunnelAuth, +) ( + chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse, +) { + logger := c.core.logger.With( + slog.F("peer_id", id), + slog.F("peer_name", name), + ) + reqs := make(chan *proto.CoordinateRequest, RequestBufferSize) + resps := make(chan *proto.CoordinateResponse, ResponseBufferSize) + + p := &peer{ + logger: logger, + id: id, + name: name, + resps: resps, + reqs: reqs, + auth: a, + sent: make(map[uuid.UUID]*proto.Node), + } + err := c.core.initPeer(p) + if err != nil { + if xerrors.Is(err, ErrClosed) { + logger.Debug(ctx, "coordinate failed: Coordinator is closed") + } else { + logger.Critical(ctx, "coordinate failed: %s", err.Error()) + } + } + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + // don't start the readLoop if we are closed. 
+ return reqs, resps + } + c.wg.Add(1) + go func() { + defer c.wg.Done() + p.reqLoop(ctx, logger, c.core.handleRequest) + err := c.core.lostPeer(p) + if xerrors.Is(err, ErrClosed) || xerrors.Is(err, ErrAlreadyRemoved) { + return + } + if err != nil { + logger.Error(context.Background(), "failed to process lost peer", slog.Error(err)) + } + }() + return reqs, resps } func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { - m := (&MultiAgent{ - ID: id, - AgentIsLegacyFunc: c.core.agentIsLegacy, - OnSubscribe: c.core.clientSubscribeToAgent, - OnUnsubscribe: c.core.clientUnsubscribeFromAgent, - OnNodeUpdate: c.core.clientNodeUpdate, - OnRemove: func(enq Queue) { c.core.clientDisconnected(enq.UniqueID()) }, - }).Init() - c.core.addClient(id, m) - return m + return ServeMultiAgent(c, c.core.logger, id) } -func (c *core) addClient(id uuid.UUID, ma Queue) { - c.mutex.Lock() - c.clients[id] = ma - c.clientsToAgents[id] = map[uuid.UUID]Queue{} - c.mutex.Unlock() +func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAgentConn { + logger = logger.With(slog.F("client_id", id)).Named("multiagent") + ctx, cancel := context.WithCancel(context.Background()) + reqs, resps := c.Coordinate(ctx, id, id.String(), SingleTailnetTunnelAuth{}) + m := (&MultiAgent{ + ID: id, + AgentIsLegacyFunc: func(agentID uuid.UUID) bool { + if n := c.Node(agentID); n == nil { + // If we don't have the node at all assume it's legacy for + // safety. + return true + } else if len(n.Addresses) > 0 && n.Addresses[0].Addr() == legacyAgentIP { + // An agent is determined to be "legacy" if it's first IP is the + // legacy IP. Agents with only the legacy IP aren't compatible + // with single_tailnet and must be routed through wsconncache. + return true + } else { + return false + } + }, + OnSubscribe: func(enq Queue, agent uuid.UUID) (*Node, error) { + err := SendCtx(ctx, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: UUIDToByteSlice(agent)}}) + return c.Node(agent), err + }, + OnUnsubscribe: func(enq Queue, agent uuid.UUID) error { + err := SendCtx(ctx, reqs, &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Uuid: UUIDToByteSlice(agent)}}) + return err + }, + OnNodeUpdate: func(id uuid.UUID, node *Node) error { + pn, err := NodeToProto(node) + if err != nil { + return err + } + return SendCtx(ctx, reqs, &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{ + Node: pn, + }}) + }, + OnRemove: func(_ Queue) { + _ = SendCtx(ctx, reqs, &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}) + cancel() + }, + }).Init() + + go v1RespLoop(ctx, cancel, logger, m, resps) + return m } // core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines; @@ -178,29 +277,8 @@ type core struct { mutex sync.RWMutex closed bool - // nodes maps agent and connection IDs their respective node. - nodes map[uuid.UUID]*Node - // agentSockets maps agent IDs to their open websocket. - agentSockets map[uuid.UUID]Queue - // agentToConnectionSockets maps agent IDs to connection IDs of conns that - // are subscribed to updates for that agent. - agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]Queue - - // clients holds a map of all clients connected to the coordinator. This is - // necessary because a client may not be subscribed into any agents. - clients map[uuid.UUID]Queue - // clientsToAgents is an index of clients to all of their subscribed agents. 
- clientsToAgents map[uuid.UUID]map[uuid.UUID]Queue - - // agentNameCache holds a cache of agent names. If one of them disappears, - // it's helpful to have a name cached for debugging. - agentNameCache *lru.Cache[uuid.UUID, string] - - // legacyAgents holda a mapping of all agents detected as legacy, meaning - // they only listen on codersdk.WorkspaceAgentIP. They aren't compatible - // with the new ServerTailnet, so they must be connected through - // wsconncache. - legacyAgents map[uuid.UUID]struct{} + peers map[uuid.UUID]*peer + tunnels *tunnelStore } type QueueKind int @@ -225,26 +303,14 @@ type Queue interface { } func newCore(logger slog.Logger) *core { - nameCache, err := lru.New[uuid.UUID, string](512) - if err != nil { - panic("make lru cache: " + err.Error()) - } - return &core{ - logger: logger, - closed: false, - nodes: map[uuid.UUID]*Node{}, - agentSockets: map[uuid.UUID]Queue{}, - agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]Queue{}, - agentNameCache: nameCache, - legacyAgents: map[uuid.UUID]struct{}{}, - clients: map[uuid.UUID]Queue{}, - clientsToAgents: map[uuid.UUID]map[uuid.UUID]Queue{}, + logger: logger, + closed: false, + peers: make(map[uuid.UUID]*peer), + tunnels: newTunnelStore(), } } -var ErrWouldBlock = xerrors.New("would block") - // Node returns an in-memory node by ID. // If the node does not exist, nil is returned. func (c *coordinator) Node(id uuid.UUID) *Node { @@ -254,27 +320,17 @@ func (c *coordinator) Node(id uuid.UUID) *Node { func (c *core) node(id uuid.UUID) *Node { c.mutex.Lock() defer c.mutex.Unlock() - return c.nodes[id] -} - -func (c *coordinator) NodeCount() int { - return c.core.nodeCount() -} - -func (c *core) nodeCount() int { - c.mutex.Lock() - defer c.mutex.Unlock() - return len(c.nodes) -} - -func (c *coordinator) AgentCount() int { - return c.core.agentCount() -} - -func (c *core) agentCount() int { - c.mutex.Lock() - defer c.mutex.Unlock() - return len(c.agentSockets) + p := c.peers[id] + if p == nil || p.node == nil { + return nil + } + v1Node, err := ProtoToNode(p.node) + if err != nil { + c.logger.Critical(context.Background(), + "failed to convert node", slog.Error(err), slog.F("node", p.node)) + return nil + } + return v1Node } // ServeClient accepts a WebSocket connection that wants to connect to an agent @@ -282,180 +338,229 @@ func (c *core) agentCount() int { func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := c.core.clientLogger(id, agentID) - logger.Debug(ctx, "coordinating client") + return ServeClientV1(ctx, c.core.logger, c, conn, id, agentID) +} + +// ServeClientV1 adapts a v1 Client to a v2 Coordinator +func ServeClientV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, conn net.Conn, id uuid.UUID, agent uuid.UUID) error { + logger = logger.With(slog.F("client_id", id), slog.F("agent_id", agent)) + defer func() { + err := conn.Close() + if err != nil { + logger.Debug(ctx, "closing client connection", slog.Error(err)) + } + }() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + reqs, resps := c.Coordinate(ctx, id, id.String(), ClientTunnelAuth{AgentID: agent}) + err := SendCtx(ctx, reqs, &proto.CoordinateRequest{ + AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: UUIDToByteSlice(agent)}, + }) + if err != nil { + // can only be a context error, no need to log here. 
+ return err + } tc := NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, QueueKindClient) - defer tc.Close() - - c.core.addClient(id, tc) - defer c.core.clientDisconnected(id) - - agentNode, err := c.core.clientSubscribeToAgent(tc, agentID) - if err != nil { - return xerrors.Errorf("subscribe agent: %w", err) - } - - if agentNode != nil { - err := tc.Enqueue([]*Node{agentNode}) - if err != nil { - logger.Debug(ctx, "enqueue initial node", slog.Error(err)) - } - } - - // On this goroutine, we read updates from the client and publish them. We start a second goroutine - // to write updates back to the client. go tc.SendUpdates() - - decoder := json.NewDecoder(conn) - for { - err := c.handleNextClientMessage(id, decoder) - if err != nil { - logger.Debug(ctx, "unable to read client update, connection may be closed", slog.Error(err)) - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) { - return nil - } - return xerrors.Errorf("handle next client message: %w", err) - } - } -} - -func (c *core) clientLogger(id, agent uuid.UUID) slog.Logger { - return c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent)) -} - -func (c *core) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq Queue) { - connectionSockets, ok := c.agentToConnectionSockets[agentID] - if !ok { - connectionSockets = map[uuid.UUID]Queue{} - c.agentToConnectionSockets[agentID] = connectionSockets - } - connectionSockets[enq.UniqueID()] = enq - - c.clientsToAgents[enq.UniqueID()][agentID] = c.agentSockets[agentID] -} - -func (c *core) clientDisconnected(id uuid.UUID) { - logger := c.clientLogger(id, uuid.Nil) - c.mutex.Lock() - defer c.mutex.Unlock() - // Clean all traces of this connection from the map. - delete(c.nodes, id) - logger.Debug(context.Background(), "deleted client node") - - for agentID := range c.clientsToAgents[id] { - connectionSockets, ok := c.agentToConnectionSockets[agentID] - if !ok { - continue - } - delete(connectionSockets, id) - logger.Debug(context.Background(), "deleted client connectionSocket from map", slog.F("agent_id", agentID)) - - if len(connectionSockets) == 0 { - delete(c.agentToConnectionSockets, agentID) - logger.Debug(context.Background(), "deleted last client connectionSocket from map", slog.F("agent_id", agentID)) - } - } - - delete(c.clients, id) - delete(c.clientsToAgents, id) - logger.Debug(context.Background(), "deleted client agents") -} - -func (c *coordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error { - logger := c.core.clientLogger(id, uuid.Nil) - - var node Node - err := decoder.Decode(&node) - if err != nil { - return xerrors.Errorf("read json: %w", err) - } - - logger.Debug(context.Background(), "got client node update", slog.F("node", node)) - return c.core.clientNodeUpdate(id, &node) -} - -func (c *core) clientNodeUpdate(id uuid.UUID, node *Node) error { - c.mutex.Lock() - defer c.mutex.Unlock() - - // Update the node of this client in our in-memory map. If an agent entirely - // shuts down and reconnects, it needs to be aware of all clients attempting - // to establish connections. 
- c.nodes[id] = node - - return c.clientNodeUpdateLocked(id, node) -} - -func (c *core) clientNodeUpdateLocked(id uuid.UUID, node *Node) error { - logger := c.clientLogger(id, uuid.Nil) - - agents := []uuid.UUID{} - for agentID, agentSocket := range c.clientsToAgents[id] { - if agentSocket == nil { - logger.Debug(context.Background(), "enqueue node to agent; socket is nil", slog.F("agent_id", agentID)) - continue - } - - err := agentSocket.Enqueue([]*Node{node}) - if err != nil { - logger.Debug(context.Background(), "unable to Enqueue node to agent", slog.Error(err), slog.F("agent_id", agentID)) - continue - } - agents = append(agents, agentID) - } - - logger.Debug(context.Background(), "enqueued node to agents", slog.F("agent_ids", agents)) + go v1RespLoop(ctx, cancel, logger, tc, resps) + go v1ReqLoop(ctx, cancel, logger, conn, reqs) + <-ctx.Done() return nil } -func (c *core) clientSubscribeToAgent(enq Queue, agentID uuid.UUID) (*Node, error) { +func (c *core) handleRequest(p *peer, req *proto.CoordinateRequest) error { c.mutex.Lock() defer c.mutex.Unlock() + if c.closed { + return ErrClosed + } + pr, ok := c.peers[p.id] + if !ok || pr != p { + return ErrAlreadyRemoved + } + if req.UpdateSelf != nil { + err := c.nodeUpdateLocked(p, req.UpdateSelf.Node) + if xerrors.Is(err, ErrAlreadyRemoved) || xerrors.Is(err, ErrClosed) { + return nil + } + if err != nil { + return xerrors.Errorf("node update failed: %w", err) + } + } + if req.AddTunnel != nil { + dstID, err := uuid.FromBytes(req.AddTunnel.Uuid) + if err != nil { + // this shouldn't happen unless there is a client error. Close the connection so the client + // doesn't just happily continue thinking everything is fine. + return xerrors.Errorf("unable to convert bytes to UUID: %w", err) + } + err = c.addTunnelLocked(p, dstID) + if xerrors.Is(err, ErrAlreadyRemoved) || xerrors.Is(err, ErrClosed) { + return nil + } + if err != nil { + return xerrors.Errorf("add tunnel failed: %w", err) + } + } + if req.RemoveTunnel != nil { + dstID, err := uuid.FromBytes(req.RemoveTunnel.Uuid) + if err != nil { + // this shouldn't happen unless there is a client error. Close the connection so the client + // doesn't just happily continue thinking everything is fine. 
+ return xerrors.Errorf("unable to convert bytes to UUID: %w", err) + } + err = c.removeTunnelLocked(p, dstID) + if xerrors.Is(err, ErrAlreadyRemoved) || xerrors.Is(err, ErrClosed) { + return nil + } + if err != nil { + return xerrors.Errorf("remove tunnel failed: %w", err) + } + } + if req.Disconnect != nil { + c.removePeerLocked(p.id, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "graceful disconnect") + } + return nil +} - logger := c.clientLogger(enq.UniqueID(), agentID) +func (c *core) nodeUpdateLocked(p *peer, node *proto.Node) error { + c.logger.Debug(context.Background(), "processing node update", + slog.F("peer_id", p.id), + slog.F("node", node.String())) - c.initOrSetAgentConnectionSocketLocked(agentID, enq) + p.node = node + c.updateTunnelPeersLocked(p.id, node, proto.CoordinateResponse_PeerUpdate_NODE, "node update") + return nil +} - node, ok := c.nodes[enq.UniqueID()] +func (c *core) updateTunnelPeersLocked(id uuid.UUID, n *proto.Node, k proto.CoordinateResponse_PeerUpdate_Kind, reason string) { + tp := c.tunnels.findTunnelPeers(id) + c.logger.Debug(context.Background(), "got tunnel peers", slog.F("peer_id", id), slog.F("tunnel_peers", tp)) + for _, otherID := range tp { + other, ok := c.peers[otherID] + if !ok { + continue + } + err := other.updateMappingLocked(id, n, k, reason) + if err != nil { + other.logger.Error(context.Background(), "failed to update mapping", slog.Error(err)) + c.removePeerLocked(other.id, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "failed update") + } + } +} + +func (c *core) addTunnelLocked(src *peer, dstID uuid.UUID) error { + if !src.auth.Authorize(dstID) { + return xerrors.Errorf("src %s is not allowed to tunnel to %s", src.id, dstID) + } + c.tunnels.add(src.id, dstID) + c.logger.Debug(context.Background(), "adding tunnel", + slog.F("src_id", src.id), + slog.F("dst_id", dstID)) + dst, ok := c.peers[dstID] if ok { - // If we have the client node, send it to the agent. If not, it will be - // sent async. - agentSocket, ok := c.agentSockets[agentID] - if !ok { - logger.Debug(context.Background(), "subscribe to agent; socket is nil") - } else { - err := agentSocket.Enqueue([]*Node{node}) + if dst.node != nil { + err := src.updateMappingLocked(dstID, dst.node, proto.CoordinateResponse_PeerUpdate_NODE, "add tunnel") if err != nil { - return nil, xerrors.Errorf("enqueue client to agent: %w", err) + src.logger.Error(context.Background(), "failed update of tunnel src", slog.Error(err)) + c.removePeerLocked(src.id, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "failed update") + // if the source fails, then the tunnel is also removed and there is no reason to continue + // processing. + return err + } + } + if src.node != nil { + err := dst.updateMappingLocked(src.id, src.node, proto.CoordinateResponse_PeerUpdate_NODE, "add tunnel") + if err != nil { + dst.logger.Error(context.Background(), "failed update of tunnel dst", slog.Error(err)) + c.removePeerLocked(dst.id, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "failed update") } } - } else { - logger.Debug(context.Background(), "multiagent node doesn't exist") } - - agentNode, ok := c.nodes[agentID] - if !ok { - // This is ok, once the agent connects the node will be sent over. - logger.Debug(context.Background(), "agent node doesn't exist", slog.F("agent_id", agentID)) - } - - // Send the subscribed agent back to the multi agent. 
- return agentNode, nil -} - -func (c *core) clientUnsubscribeFromAgent(enq Queue, agentID uuid.UUID) error { - c.mutex.Lock() - defer c.mutex.Unlock() - - delete(c.clientsToAgents[enq.UniqueID()], agentID) - delete(c.agentToConnectionSockets[agentID], enq.UniqueID()) - return nil } -func (c *core) agentLogger(id uuid.UUID) slog.Logger { - return c.logger.With(slog.F("agent_id", id)) +func (c *core) removeTunnelLocked(src *peer, dstID uuid.UUID) error { + err := src.updateMappingLocked(dstID, nil, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "remove tunnel") + if err != nil { + src.logger.Error(context.Background(), "failed to update", slog.Error(err)) + c.removePeerLocked(src.id, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "failed update") + // removing the peer also removes all other tunnels and notifies destinations, so it's safe to + // return here. + return err + } + dst, ok := c.peers[dstID] + if ok { + err = dst.updateMappingLocked(src.id, nil, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "remove tunnel") + if err != nil { + dst.logger.Error(context.Background(), "failed to update", slog.Error(err)) + c.removePeerLocked(dst.id, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "failed update") + // don't return here because we still want to remove the tunnel, and an error at the + // destination doesn't count as an error removing the tunnel at the source. + } + } + c.tunnels.remove(src.id, dstID) + return nil +} + +func (c *core) initPeer(p *peer) error { + c.mutex.Lock() + defer c.mutex.Unlock() + p.logger.Debug(context.Background(), "initPeer") + if c.closed { + return ErrClosed + } + if p.node != nil { + return xerrors.Errorf("peer (%s) must be initialized with nil node", p.id) + } + if old, ok := c.peers[p.id]; ok { + // rare and interesting enough to log at Info, but it isn't an error per se + old.logger.Info(context.Background(), "overwritten by new connection") + close(old.resps) + p.overwrites = old.overwrites + 1 + } + now := time.Now() + p.start = now + p.lastWrite = now + c.peers[p.id] = p + + tp := c.tunnels.findTunnelPeers(p.id) + p.logger.Debug(context.Background(), "initial tunnel peers", slog.F("tunnel_peers", tp)) + var others []*peer + for _, otherID := range tp { + // ok to append nil here because the batch call below filters them out + others = append(others, c.peers[otherID]) + } + return p.batchUpdateMappingLocked(others, proto.CoordinateResponse_PeerUpdate_NODE, "init") +} + +// removePeer removes and cleans up a lost peer. It updates all peers it shares a tunnel with, deletes +// all tunnels from which the removed peer is the source. 
+func (c *core) lostPeer(p *peer) error { + c.mutex.Lock() + defer c.mutex.Unlock() + c.logger.Debug(context.Background(), "lostPeer", slog.F("peer_id", p.id)) + if c.closed { + return ErrClosed + } + if existing, ok := c.peers[p.id]; !ok || existing != p { + return ErrAlreadyRemoved + } + c.removePeerLocked(p.id, proto.CoordinateResponse_PeerUpdate_LOST, "lost") + return nil +} + +func (c *core) removePeerLocked(id uuid.UUID, kind proto.CoordinateResponse_PeerUpdate_Kind, reason string) { + p, ok := c.peers[id] + if !ok { + c.logger.Critical(context.Background(), "removed non-existent peer %s", id) + return + } + c.updateTunnelPeersLocked(id, nil, kind, reason) + c.tunnels.removeAll(id) + close(p.resps) + delete(c.peers, id) } // ServeAgent accepts a WebSocket connection to an agent that @@ -463,204 +568,63 @@ func (c *core) agentLogger(id uuid.UUID) slog.Logger { func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := c.core.agentLogger(id) - logger.Debug(context.Background(), "coordinating agent") - // This uniquely identifies a connection that belongs to this goroutine. - unique := uuid.New() - tc, err := c.core.initAndTrackAgent(ctx, cancel, conn, id, unique, name) - if err != nil { - return err - } + return ServeAgentV1(ctx, c.core.logger, c, conn, id, name) +} - // On this goroutine, we read updates from the agent and publish them. We start a second goroutine - // to write updates back to the agent. +func ServeAgentV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, conn net.Conn, id uuid.UUID, name string) error { + logger = logger.With(slog.F("agent_id", id), slog.F("name", name)) + defer func() { + logger.Debug(ctx, "closing agent connection") + err := conn.Close() + logger.Debug(ctx, "closed agent connection", slog.Error(err)) + }() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + logger.Debug(ctx, "starting new agent connection") + reqs, resps := c.Coordinate(ctx, id, name, AgentTunnelAuth{}) + tc := NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, QueueKindAgent) go tc.SendUpdates() - - defer c.core.agentDisconnected(id, unique) - - decoder := json.NewDecoder(conn) - for { - err := c.handleNextAgentMessage(id, decoder) - if err != nil { - logger.Debug(ctx, "unable to read agent update, connection may be closed", slog.Error(err)) - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) { - return nil - } - return xerrors.Errorf("handle next agent message: %w", err) - } - } -} - -func (c *core) agentDisconnected(id, unique uuid.UUID) { - logger := c.agentLogger(id) - c.mutex.Lock() - defer c.mutex.Unlock() - - // Only delete the connection if it's ours. It could have been - // overwritten. - if idConn, ok := c.agentSockets[id]; ok && idConn.UniqueID() == unique { - delete(c.agentSockets, id) - delete(c.nodes, id) - logger.Debug(context.Background(), "deleted agent socket and node") - } - for clientID := range c.agentToConnectionSockets[id] { - c.clientsToAgents[clientID][id] = nil - } -} - -// initAndTrackAgent creates a TrackedConn for the agent, and sends any initial nodes updates if we have any. It is -// one function that does two things because it is critical that we hold the mutex for both things, lest we miss some -// updates. 
-func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Conn, id, unique uuid.UUID, name string) (*TrackedConn, error) { - logger := c.logger.With(slog.F("agent_id", id)) - c.mutex.Lock() - defer c.mutex.Unlock() - if c.closed { - return nil, xerrors.New("coordinator is closed") - } - - overwrites := int64(0) - // If an old agent socket is connected, we Close it to avoid any leaks. This - // shouldn't ever occur because we expect one agent to be running, but it's - // possible for a race condition to happen when an agent is disconnected and - // attempts to reconnect before the server realizes the old connection is - // dead. - oldAgentSocket, ok := c.agentSockets[id] - if ok { - overwrites = oldAgentSocket.Overwrites() + 1 - _ = oldAgentSocket.Close() - } - tc := NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites, QueueKindAgent) - c.agentNameCache.Add(id, name) - - sockets, ok := c.agentToConnectionSockets[id] - if ok { - // Publish all nodes that want to connect to the - // desired agent ID. - nodes := make([]*Node, 0, len(sockets)) - for targetID := range sockets { - node, ok := c.nodes[targetID] - if !ok { - continue - } - nodes = append(nodes, node) - } - err := tc.Enqueue(nodes) - // this should never error since we're still the only goroutine that - // knows about the TrackedConn. If we hit an error something really - // wrong is happening - if err != nil { - logger.Critical(ctx, "unable to queue initial nodes", slog.Error(err)) - return nil, err - } - logger.Debug(ctx, "wrote initial client(s) to agent", slog.F("nodes", nodes)) - } - - c.agentSockets[id] = tc - for clientID := range c.agentToConnectionSockets[id] { - c.clientsToAgents[clientID][id] = tc - } - - logger.Debug(ctx, "added agent socket") - return tc, nil -} - -func (c *coordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.Decoder) error { - logger := c.core.agentLogger(id) - var node Node - err := decoder.Decode(&node) - if err != nil { - return xerrors.Errorf("read json: %w", err) - } - logger.Debug(context.Background(), "decoded agent node", slog.F("node", node)) - return c.core.agentNodeUpdate(id, &node) + go v1RespLoop(ctx, cancel, logger, tc, resps) + go v1ReqLoop(ctx, cancel, logger, conn, reqs) + <-ctx.Done() + logger.Debug(ctx, "ending agent connection") + return nil } // This is copied from codersdk because importing it here would cause an import // cycle. This is just temporary until wsconncache is phased out. var legacyAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4") -// This is temporary until we no longer need to detect for agent backwards -// compatibility. -// See: https://github.com/coder/coder/issues/8218 -func (c *core) agentIsLegacy(agentID uuid.UUID) bool { - c.mutex.RLock() - _, ok := c.legacyAgents[agentID] - c.mutex.RUnlock() - return ok -} - -func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error { - logger := c.agentLogger(id) - c.mutex.Lock() - defer c.mutex.Unlock() - c.nodes[id] = node - - // Keep a cache of all legacy agents. - if len(node.Addresses) > 0 && node.Addresses[0].Addr() == legacyAgentIP { - c.legacyAgents[id] = struct{}{} - } - - connectionSockets, ok := c.agentToConnectionSockets[id] - if !ok { - logger.Debug(context.Background(), "no client sockets; unable to send node") - return nil - } - - // Publish the new node to every listening socket. 
- for clientID, connectionSocket := range connectionSockets { - err := connectionSocket.Enqueue([]*Node{node}) - if err == nil { - logger.Debug(context.Background(), "enqueued agent node to client", - slog.F("client_id", clientID)) - } else { - // queue is backed up for some reason. This is bad, but we don't want to drop - // updates to other clients over it. Log and move on. - logger.Error(context.Background(), "failed to Enqueue", - slog.F("client_id", clientID), slog.Error(err)) - } - } - - return nil -} - // Close closes all of the open connections in the coordinator and stops the // coordinator from accepting new connections. func (c *coordinator) Close() error { - return c.core.close() + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil + } + c.closed = true + close(c.closedChan) + c.mu.Unlock() + + err := c.core.close() + // wait for all request loops to complete + c.wg.Wait() + return err } func (c *core) close() error { c.mutex.Lock() + defer c.mutex.Unlock() if c.closed { - c.mutex.Unlock() return nil } c.closed = true - - wg := sync.WaitGroup{} - - wg.Add(len(c.agentSockets)) - for _, socket := range c.agentSockets { - socket := socket - go func() { - _ = socket.CoordinatorClose() - wg.Done() - }() + for id := range c.peers { + // when closing, mark them as LOST so that we don't disrupt in-progress + // connections. + c.removePeerLocked(id, proto.CoordinateResponse_PeerUpdate_LOST, "coordinator close") } - - wg.Add(len(c.clients)) - for _, client := range c.clients { - client := client - go func() { - _ = client.CoordinatorClose() - wg.Done() - }() - } - - c.mutex.Unlock() - - wg.Wait() return nil } @@ -668,161 +632,43 @@ func (c *coordinator) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { c.core.serveHTTPDebug(w, r) } -func (c *core) serveHTTPDebug(w http.ResponseWriter, r *http.Request) { +func (c *core) serveHTTPDebug(w http.ResponseWriter, _ *http.Request) { + debug := c.getHTMLDebug() + w.Header().Set("Content-Type", "text/html; charset=utf-8") + err := debugTempl.Execute(w, debug) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(err.Error())) + return + } +} + +func (c *core) getHTMLDebug() HTMLDebug { c.mutex.RLock() defer c.mutex.RUnlock() - - CoordinatorHTTPDebug( - HTTPDebugFromLocal(false, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache), - )(w, r) -} - -func HTTPDebugFromLocal( - ha bool, - agentSocketsMap map[uuid.UUID]Queue, - agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]Queue, - nodesMap map[uuid.UUID]*Node, - agentNameCache *lru.Cache[uuid.UUID, string], -) HTMLDebug { - now := time.Now() - data := HTMLDebug{HA: ha} - for id, conn := range agentSocketsMap { - start, lastWrite := conn.Stats() - agent := &HTMLAgent{ - Name: conn.Name(), - ID: id, - CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second), - LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), - Overwrites: int(conn.Overwrites()), - } - - for id, conn := range agentToConnectionSocketsMap[id] { - start, lastWrite := conn.Stats() - agent.Connections = append(agent.Connections, &HTMLClient{ - Name: conn.Name(), - ID: id, - CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second), - LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), - }) - } - slices.SortFunc(agent.Connections, func(a, b *HTMLClient) int { - return slice.Ascending(a.Name, b.Name) - }) - - data.Agents = append(data.Agents, agent) - } - slices.SortFunc(data.Agents, func(a, b *HTMLAgent) int { - return 
slice.Ascending(a.Name, b.Name) - }) - - for agentID, conns := range agentToConnectionSocketsMap { - if len(conns) == 0 { - continue - } - - if _, ok := agentSocketsMap[agentID]; ok { - continue - } - - agentName, ok := agentNameCache.Get(agentID) - if !ok { - agentName = "unknown" - } - agent := &HTMLAgent{ - Name: agentName, - ID: agentID, - } - for id, conn := range conns { - start, lastWrite := conn.Stats() - agent.Connections = append(agent.Connections, &HTMLClient{ - Name: conn.Name(), - ID: id, - CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second), - LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), - }) - } - slices.SortFunc(agent.Connections, func(a, b *HTMLClient) int { - return slice.Ascending(a.Name, b.Name) - }) - - data.MissingAgents = append(data.MissingAgents, agent) - } - slices.SortFunc(data.MissingAgents, func(a, b *HTMLAgent) int { - return slice.Ascending(a.Name, b.Name) - }) - - for id, node := range nodesMap { - name, _ := agentNameCache.Get(id) - data.Nodes = append(data.Nodes, &HTMLNode{ - ID: id, - Name: name, - Node: node, - }) - } - slices.SortFunc(data.Nodes, func(a, b *HTMLNode) int { - return slice.Ascending(a.Name+a.ID.String(), b.Name+b.ID.String()) - }) - - return data -} - -func CoordinatorHTTPDebug(data HTMLDebug) func(w http.ResponseWriter, _ *http.Request) { - return func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - - tmpl, err := template.New("coordinator_debug").Funcs(template.FuncMap{ - "marshal": func(v any) template.JS { - a, err := json.MarshalIndent(v, "", " ") - if err != nil { - //nolint:gosec - return template.JS(fmt.Sprintf(`{"err": %q}`, err)) - } - //nolint:gosec - return template.JS(a) - }, - }).Parse(coordinatorDebugTmpl) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte(err.Error())) - return - } - - err = tmpl.Execute(w, data) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte(err.Error())) - return - } + debug := HTMLDebug{Tunnels: c.tunnels.htmlDebug()} + for _, p := range c.peers { + debug.Peers = append(debug.Peers, p.htmlDebug()) } + return debug } type HTMLDebug struct { - HA bool - Agents []*HTMLAgent - MissingAgents []*HTMLAgent - Nodes []*HTMLNode + Peers []HTMLPeer + Tunnels []HTMLTunnel } -type HTMLAgent struct { - Name string +type HTMLPeer struct { ID uuid.UUID + Name string CreatedAge time.Duration LastWriteAge time.Duration Overwrites int - Connections []*HTMLClient + Node string } -type HTMLClient struct { - Name string - ID uuid.UUID - CreatedAge time.Duration - LastWriteAge time.Duration -} - -type HTMLNode struct { - ID uuid.UUID - Name string - Node any +type HTMLTunnel struct { + Src, Dst uuid.UUID } var coordinatorDebugTmpl = ` @@ -830,51 +676,143 @@ var coordinatorDebugTmpl = ` + - {{- if .HA }} -

- high-availability wireguard coordinator debug
- warning: this only provides info from the node that served the request, if there are multiple replicas this data may be incomplete
- {{- else }}
- in-memory wireguard coordinator debug
- {{- end }}
-
- # agents: total {{ len .Agents }}
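
The refactor above funnels the v1 entry points (ServeClient, ServeAgent, ServeMultiAgent) through the single v2 Coordinate API: a peer writes proto.CoordinateRequest messages to one channel and reads batched proto.CoordinateResponse peer updates from another, with the now-exported SendCtx/RecvCtx as context-aware channel helpers. The following is a minimal sketch of driving that API directly, assuming only the exported names this diff introduces (NewCoordinatorV2, SendCtx, RecvCtx, UUIDToByteSlice, SingleTailnetTunnelAuth); it is illustrative, not part of the change.

package main

import (
	"context"
	"fmt"

	"github.com/google/uuid"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	coord := tailnet.NewCoordinatorV2(slog.Make())
	defer coord.Close()

	p1, p2 := uuid.New(), uuid.New()
	// SingleTailnetTunnelAuth allows tunnels to arbitrary peers, as in the tests above.
	reqs1, resps1 := coord.Coordinate(ctx, p1, "p1", tailnet.SingleTailnetTunnelAuth{})
	reqs2, _ := coord.Coordinate(ctx, p2, "p2", tailnet.SingleTailnetTunnelAuth{})

	// p1 requests a tunnel to p2, then both peers publish their node info.
	_ = tailnet.SendCtx(ctx, reqs1, &proto.CoordinateRequest{
		AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: tailnet.UUIDToByteSlice(p2)},
	})
	_ = tailnet.SendCtx(ctx, reqs1, &proto.CoordinateRequest{
		UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: 1}},
	})
	_ = tailnet.SendCtx(ctx, reqs2, &proto.CoordinateRequest{
		UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: 2}},
	})

	// Once the tunnel exists and p2 has a node, p1 receives p2's node as a peer update.
	resp, err := tailnet.RecvCtx(ctx, resps1)
	if err != nil {
		return // context canceled or response channel closed
	}
	for _, update := range resp.PeerUpdates {
		id, _ := uuid.FromBytes(update.Uuid)
		fmt.Println(id, update.Kind, update.GetNode().GetPreferredDerp())
	}
}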