diff --git a/.prettierignore b/.prettierignore
index 37cbd3fef3..9be32290ac 100644
--- a/.prettierignore
+++ b/.prettierignore
@@ -83,6 +83,7 @@ helm/**/templates/*.yaml
# Testdata shouldn't be formatted.
scripts/apitypings/testdata/**/*.ts
enterprise/tailnet/testdata/*.golden.html
+tailnet/testdata/*.golden.html
# Generated files shouldn't be formatted.
site/e2e/provisionerGenerated.ts
diff --git a/.prettierignore.include b/.prettierignore.include
index fd7f94f13d..7efd582e15 100644
--- a/.prettierignore.include
+++ b/.prettierignore.include
@@ -9,6 +9,7 @@ helm/**/templates/*.yaml
# Testdata shouldn't be formatted.
scripts/apitypings/testdata/**/*.ts
enterprise/tailnet/testdata/*.golden.html
+tailnet/testdata/*.golden.html
# Generated files shouldn't be formatted.
site/e2e/provisionerGenerated.ts
diff --git a/Makefile b/Makefile
index d16263b517..8ee86d1584 100644
--- a/Makefile
+++ b/Makefile
@@ -602,6 +602,7 @@ update-golden-files: \
scripts/ci-report/testdata/.gen-golden \
enterprise/cli/testdata/.gen-golden \
enterprise/tailnet/testdata/.gen-golden \
+ tailnet/testdata/.gen-golden \
coderd/.gen-golden \
provisioner/terraform/testdata/.gen-golden
.PHONY: update-golden-files
@@ -614,6 +615,10 @@ enterprise/cli/testdata/.gen-golden: $(wildcard enterprise/cli/testdata/*.golden
go test ./enterprise/cli -run="TestEnterpriseCommandHelp" -update
touch "$@"
+tailnet/testdata/.gen-golden: $(wildcard tailnet/testdata/*.golden.html) $(GO_SRC_FILES) $(wildcard tailnet/*_test.go)
+ go test ./tailnet -run="TestDebugTemplate" -update
+ touch "$@"
+
enterprise/tailnet/testdata/.gen-golden: $(wildcard enterprise/tailnet/testdata/*.golden.html) $(GO_SRC_FILES) $(wildcard enterprise/tailnet/*_test.go)
go test ./enterprise/tailnet -run="TestDebugTemplate" -update
touch "$@"
diff --git a/enterprise/tailnet/connio.go b/enterprise/tailnet/connio.go
index 83c1d8a2b9..94d080e219 100644
--- a/enterprise/tailnet/connio.go
+++ b/enterprise/tailnet/connio.go
@@ -86,7 +86,7 @@ func (c *connIO) recvLoop() {
if c.disconnected {
b.kind = proto.CoordinateResponse_PeerUpdate_DISCONNECTED
}
- if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
+ if err := agpl.SendCtx(c.coordCtx, c.bindings, b); err != nil {
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err))
}
// only remove tunnels on graceful disconnect. If we remove tunnels for lost peers, then
@@ -97,14 +97,14 @@ func (c *connIO) recvLoop() {
tKey: tKey{src: c.UniqueID()},
active: false,
}
- if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
+ if err := agpl.SendCtx(c.coordCtx, c.tunnels, t); err != nil {
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
}
}
}()
defer c.Close()
for {
- req, err := recvCtx(c.peerCtx, c.requests)
+ req, err := agpl.RecvCtx(c.peerCtx, c.requests)
if err != nil {
if xerrors.Is(err, context.Canceled) ||
xerrors.Is(err, context.DeadlineExceeded) ||
@@ -132,7 +132,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
node: req.UpdateSelf.Node,
kind: proto.CoordinateResponse_PeerUpdate_NODE,
}
- if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
+ if err := agpl.SendCtx(c.coordCtx, c.bindings, b); err != nil {
c.logger.Debug(c.peerCtx, "failed to send binding", slog.Error(err))
return err
}
@@ -156,7 +156,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
},
active: true,
}
- if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
+ if err := agpl.SendCtx(c.coordCtx, c.tunnels, t); err != nil {
c.logger.Debug(c.peerCtx, "failed to send add tunnel", slog.Error(err))
return err
}
@@ -177,7 +177,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
},
active: false,
}
- if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
+ if err := agpl.SendCtx(c.coordCtx, c.tunnels, t); err != nil {
c.logger.Debug(c.peerCtx, "failed to send remove tunnel", slog.Error(err))
return err
}
diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go
index 5a26cdc92a..068b91160f 100644
--- a/enterprise/tailnet/coordinator.go
+++ b/enterprise/tailnet/coordinator.go
@@ -5,17 +5,22 @@ import (
"context"
"encoding/json"
"errors"
+ "fmt"
+ "html/template"
"io"
"net"
"net/http"
"sync"
+ "time"
"github.com/google/uuid"
lru "github.com/hashicorp/golang-lru/v2"
+ "golang.org/x/exp/slices"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database/pubsub"
+ "github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
agpl "github.com/coder/coder/v2/tailnet"
)
@@ -719,7 +724,209 @@ func (c *haCoordinator) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
c.mutex.RLock()
defer c.mutex.RUnlock()
- agpl.CoordinatorHTTPDebug(
- agpl.HTTPDebugFromLocal(true, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache),
+ CoordinatorHTTPDebug(
+ HTTPDebugFromLocal(true, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache),
)(w, r)
}
+
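+// HTTPDebugFromLocal builds an HTMLDebugHA snapshot from the coordinator's agent, connection, and
+// node maps for rendering by CoordinatorHTTPDebug.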
+func HTTPDebugFromLocal(
+ ha bool,
+ agentSocketsMap map[uuid.UUID]agpl.Queue,
+ agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]agpl.Queue,
+ nodesMap map[uuid.UUID]*agpl.Node,
+ agentNameCache *lru.Cache[uuid.UUID, string],
+) HTMLDebugHA {
+ now := time.Now()
+ data := HTMLDebugHA{HA: ha}
+ for id, conn := range agentSocketsMap {
+ start, lastWrite := conn.Stats()
+ agent := &HTMLAgent{
+ Name: conn.Name(),
+ ID: id,
+ CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second),
+ LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
+ Overwrites: int(conn.Overwrites()),
+ }
+
+ for id, conn := range agentToConnectionSocketsMap[id] {
+ start, lastWrite := conn.Stats()
+ agent.Connections = append(agent.Connections, &HTMLClient{
+ Name: conn.Name(),
+ ID: id,
+ CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second),
+ LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
+ })
+ }
+ slices.SortFunc(agent.Connections, func(a, b *HTMLClient) int {
+ return slice.Ascending(a.Name, b.Name)
+ })
+
+ data.Agents = append(data.Agents, agent)
+ }
+ slices.SortFunc(data.Agents, func(a, b *HTMLAgent) int {
+ return slice.Ascending(a.Name, b.Name)
+ })
+
+ for agentID, conns := range agentToConnectionSocketsMap {
+ if len(conns) == 0 {
+ continue
+ }
+
+ if _, ok := agentSocketsMap[agentID]; ok {
+ continue
+ }
+
+ agentName, ok := agentNameCache.Get(agentID)
+ if !ok {
+ agentName = "unknown"
+ }
+ agent := &HTMLAgent{
+ Name: agentName,
+ ID: agentID,
+ }
+ for id, conn := range conns {
+ start, lastWrite := conn.Stats()
+ agent.Connections = append(agent.Connections, &HTMLClient{
+ Name: conn.Name(),
+ ID: id,
+ CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second),
+ LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
+ })
+ }
+ slices.SortFunc(agent.Connections, func(a, b *HTMLClient) int {
+ return slice.Ascending(a.Name, b.Name)
+ })
+
+ data.MissingAgents = append(data.MissingAgents, agent)
+ }
+ slices.SortFunc(data.MissingAgents, func(a, b *HTMLAgent) int {
+ return slice.Ascending(a.Name, b.Name)
+ })
+
+ for id, node := range nodesMap {
+ name, _ := agentNameCache.Get(id)
+ data.Nodes = append(data.Nodes, &HTMLNode{
+ ID: id,
+ Name: name,
+ Node: node,
+ })
+ }
+ slices.SortFunc(data.Nodes, func(a, b *HTMLNode) int {
+ return slice.Ascending(a.Name+a.ID.String(), b.Name+b.ID.String())
+ })
+
+ return data
+}
+
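+// CoordinatorHTTPDebug returns an http.HandlerFunc that renders the given debug snapshot using haCoordinatorDebugTmpl.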
+func CoordinatorHTTPDebug(data HTMLDebugHA) func(w http.ResponseWriter, _ *http.Request) {
+ return func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+
+ tmpl, err := template.New("coordinator_debug").Funcs(template.FuncMap{
+ "marshal": func(v any) template.JS {
+ a, err := json.MarshalIndent(v, "", " ")
+ if err != nil {
+ //nolint:gosec
+ return template.JS(fmt.Sprintf(`{"err": %q}`, err))
+ }
+ //nolint:gosec
+ return template.JS(a)
+ },
+ }).Parse(haCoordinatorDebugTmpl)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte(err.Error()))
+ return
+ }
+
+ err = tmpl.Execute(w, data)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte(err.Error()))
+ return
+ }
+ }
+}
+
+type HTMLDebugHA struct {
+ HA bool
+ Agents []*HTMLAgent
+ MissingAgents []*HTMLAgent
+ Nodes []*HTMLNode
+}
+
+type HTMLAgent struct {
+ Name string
+ ID uuid.UUID
+ CreatedAge time.Duration
+ LastWriteAge time.Duration
+ Overwrites int
+ Connections []*HTMLClient
+}
+
+type HTMLClient struct {
+ Name string
+ ID uuid.UUID
+ CreatedAge time.Duration
+ LastWriteAge time.Duration
+}
+
+type HTMLNode struct {
+ ID uuid.UUID
+ Name string
+ Node any
+}
+
+var haCoordinatorDebugTmpl = `
+<!DOCTYPE html>
+<html>
+	<head>
+		<meta charset="UTF-8">
+	</head>
+	<body>
+		{{- if .HA }}
+		<h1>high-availability wireguard coordinator debug</h1>
+		<h4 style="margin-top:-25px">warning: this only provides info from the node that served the request, if there are multiple replicas this data may be incomplete</h4>
+		{{- else }}
+		<h1>in-memory wireguard coordinator debug</h1>
+		{{- end }}
+
+		<h2 id=agents><a href=#agents>#</a> agents: total {{ len .Agents }}</h2>
+		<ul>
+		{{- range .Agents }}
+			<li style="margin-top:4px">
+				<b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created {{ .CreatedAge }} ago, write {{ .LastWriteAge }} ago, overwrites {{ .Overwrites }}
+				<h3 style="margin:0px;font-size:16px;font-weight:400">connections: total {{ len .Connections}}</h3>
+				<ul>
+				{{- range .Connections }}
+					<li><b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created {{ .CreatedAge }} ago, write {{ .LastWriteAge }} ago</li>
+				{{- end }}
+				</ul>
+			</li>
+		{{- end }}
+		</ul>
+
+		<h2 id=missing-agents><a href=#missing-agents>#</a> missing agents: total {{ len .MissingAgents }}</h2>
+		<ul>
+		{{- range .MissingAgents}}
+			<li style="margin-top:4px"><b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created ? ago, write ? ago, overwrites ?</li>
+			<h3 style="margin:0px;font-size:16px;font-weight:400">connections: total {{ len .Connections }}</h3>
+			<ul>
+			{{- range .Connections }}
+				<li><b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created {{ .CreatedAge }} ago, write {{ .LastWriteAge }} ago</li>
+			{{- end }}
+			</ul>
+		{{- end }}
+		</ul>
+
+		<h2 id=nodes><a href=#nodes>#</a> nodes: total {{ len .Nodes }}</h2>
+		<ul>
+		{{- range .Nodes }}
+			<li style="margin-top:4px"><b>{{ .Name }}</b> (<code>{{ .ID }}</code>):
+				<span style="white-space: pre;"><code>{{ marshal .Node }}</code></span>
+			</li>
+		{{- end }}
+		</ul>
+	</body>
+</html>
+`
diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go
index 635d1d4fce..7bc915801e 100644
--- a/enterprise/tailnet/pgcoord.go
+++ b/enterprise/tailnet/pgcoord.go
@@ -3,17 +3,12 @@ package tailnet
import (
"context"
"database/sql"
- "encoding/json"
- "io"
"net"
- "net/netip"
"strings"
"sync"
"sync/atomic"
"time"
- "nhooyr.io/websocket"
-
"github.com/coder/coder/v2/tailnet/proto"
"github.com/cenkalti/backoff/v4"
@@ -30,17 +25,16 @@ import (
)
const (
- EventHeartbeats = "tailnet_coordinator_heartbeat"
- eventPeerUpdate = "tailnet_peer_update"
- eventTunnelUpdate = "tailnet_tunnel_update"
- HeartbeatPeriod = time.Second * 2
- MissedHeartbeats = 3
- numQuerierWorkers = 10
- numBinderWorkers = 10
- numTunnelerWorkers = 10
- dbMaxBackoff = 10 * time.Second
- cleanupPeriod = time.Hour
- requestResponseBuffSize = 32
+ EventHeartbeats = "tailnet_coordinator_heartbeat"
+ eventPeerUpdate = "tailnet_peer_update"
+ eventTunnelUpdate = "tailnet_tunnel_update"
+ HeartbeatPeriod = time.Second * 2
+ MissedHeartbeats = 3
+ numQuerierWorkers = 10
+ numBinderWorkers = 10
+ numTunnelerWorkers = 10
+ dbMaxBackoff = 10 * time.Second
+ cleanupPeriod = time.Hour
)
// pgCoord is a postgres-backed coordinator
@@ -161,55 +155,8 @@ func NewPGCoordV2(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, sto
return newPGCoordInternal(ctx, logger, ps, store)
}
-// This is copied from codersdk because importing it here would cause an import
-// cycle. This is just temporary until wsconncache is phased out.
-var legacyAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")
-
func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
- logger := c.logger.With(slog.F("client_id", id)).Named("multiagent")
- ctx, cancel := context.WithCancel(c.ctx)
- reqs, resps := c.Coordinate(ctx, id, id.String(), agpl.SingleTailnetTunnelAuth{})
- ma := (&agpl.MultiAgent{
- ID: id,
- AgentIsLegacyFunc: func(agentID uuid.UUID) bool {
- if n := c.Node(agentID); n == nil {
- // If we don't have the node at all assume it's legacy for
- // safety.
- return true
- } else if len(n.Addresses) > 0 && n.Addresses[0].Addr() == legacyAgentIP {
- // An agent is determined to be "legacy" if it's first IP is the
- // legacy IP. Agents with only the legacy IP aren't compatible
- // with single_tailnet and must be routed through wsconncache.
- return true
- } else {
- return false
- }
- },
- OnSubscribe: func(enq agpl.Queue, agent uuid.UUID) (*agpl.Node, error) {
- err := sendCtx(ctx, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)}})
- return c.Node(agent), err
- },
- OnUnsubscribe: func(enq agpl.Queue, agent uuid.UUID) error {
- err := sendCtx(ctx, reqs, &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)}})
- return err
- },
- OnNodeUpdate: func(id uuid.UUID, node *agpl.Node) error {
- pn, err := agpl.NodeToProto(node)
- if err != nil {
- return err
- }
- return sendCtx(c.ctx, reqs, &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
- Node: pn,
- }})
- },
- OnRemove: func(_ agpl.Queue) {
- _ = sendCtx(c.ctx, reqs, &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
- cancel()
- },
- }).Init()
-
- go v1SendLoop(ctx, cancel, logger, ma, resps)
- return ma
+ return agpl.ServeMultiAgent(c, c.logger, id)
}
func (c *pgCoord) Node(id uuid.UUID) *agpl.Node {
@@ -253,116 +200,11 @@ func (c *pgCoord) Node(id uuid.UUID) *agpl.Node {
}
func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
- logger := c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent))
- defer func() {
- err := conn.Close()
- if err != nil {
- logger.Debug(c.ctx, "closing client connection", slog.Error(err))
- }
- }()
- ctx, cancel := context.WithCancel(c.ctx)
- defer cancel()
- reqs, resps := c.Coordinate(ctx, id, id.String(), agpl.ClientTunnelAuth{AgentID: agent})
- err := sendCtx(ctx, reqs, &proto.CoordinateRequest{
- AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)},
- })
- if err != nil {
- // can only be a context error, no need to log here.
- return err
- }
- defer func() {
- _ = sendCtx(ctx, reqs, &proto.CoordinateRequest{
- RemoveTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)},
- })
- }()
-
- tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, agpl.QueueKindClient)
- go tc.SendUpdates()
- go v1SendLoop(ctx, cancel, logger, tc, resps)
- go v1RecvLoop(ctx, cancel, logger, conn, reqs)
- <-ctx.Done()
- return nil
+ return agpl.ServeClientV1(c.ctx, c.logger, c, conn, id, agent)
}
func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error {
- logger := c.logger.With(slog.F("agent_id", id), slog.F("name", name))
- defer func() {
- logger.Debug(c.ctx, "closing agent connection")
- err := conn.Close()
- logger.Debug(c.ctx, "closed agent connection", slog.Error(err))
- }()
- ctx, cancel := context.WithCancel(c.ctx)
- defer cancel()
- reqs, resps := c.Coordinate(ctx, id, name, agpl.AgentTunnelAuth{})
- tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, agpl.QueueKindAgent)
- go tc.SendUpdates()
- go v1SendLoop(ctx, cancel, logger, tc, resps)
- go v1RecvLoop(ctx, cancel, logger, conn, reqs)
- <-ctx.Done()
- return nil
-}
-
-func v1RecvLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger,
- conn net.Conn, reqs chan<- *proto.CoordinateRequest,
-) {
- defer cancel()
- decoder := json.NewDecoder(conn)
- for {
- var node agpl.Node
- err := decoder.Decode(&node)
- if err != nil {
- if xerrors.Is(err, io.EOF) ||
- xerrors.Is(err, io.ErrClosedPipe) ||
- xerrors.Is(err, context.Canceled) ||
- xerrors.Is(err, context.DeadlineExceeded) ||
- websocket.CloseStatus(err) > 0 {
- logger.Debug(ctx, "exiting recvLoop", slog.Error(err))
- } else {
- logger.Error(ctx, "failed to decode Node update", slog.Error(err))
- }
- return
- }
- logger.Debug(ctx, "got node update", slog.F("node", node))
- pn, err := agpl.NodeToProto(&node)
- if err != nil {
- logger.Critical(ctx, "failed to convert v1 node", slog.F("node", node), slog.Error(err))
- return
- }
- req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
- Node: pn,
- }}
- if err := sendCtx(ctx, reqs, req); err != nil {
- logger.Debug(ctx, "recvLoop ctx expired", slog.Error(err))
- return
- }
- }
-}
-
-func v1SendLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q agpl.Queue, resps <-chan *proto.CoordinateResponse) {
- defer cancel()
- for {
- resp, err := recvCtx(ctx, resps)
- if err != nil {
- logger.Debug(ctx, "done reading responses", slog.Error(err))
- return
- }
- logger.Debug(ctx, "v1: got response", slog.F("resp", resp))
- nodes, err := agpl.OnlyNodeUpdates(resp)
- if err != nil {
- logger.Critical(ctx, "failed to decode resp", slog.F("resp", resp), slog.Error(err))
- _ = q.CoordinatorClose()
- return
- }
- // don't send empty updates
- if len(nodes) == 0 {
- logger.Debug(ctx, "skipping enqueueing 0-length v1 update")
- continue
- }
- err = q.Enqueue(nodes)
- if err != nil {
- logger.Error(ctx, "failed to enqueue v1 update", slog.Error(err))
- }
- }
+ return agpl.ServeAgentV1(c.ctx, c.logger, c, conn, id, name)
}
func (c *pgCoord) Close() error {
@@ -378,43 +220,22 @@ func (c *pgCoord) Coordinate(
chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse,
) {
logger := c.logger.With(slog.F("peer_id", id))
- reqs := make(chan *proto.CoordinateRequest, requestResponseBuffSize)
+ reqs := make(chan *proto.CoordinateRequest, agpl.RequestBufferSize)
resps := make(chan *proto.CoordinateResponse, agpl.ResponseBufferSize)
cIO := newConnIO(c.ctx, ctx, logger, c.bindings, c.tunnelerCh, reqs, resps, id, name, a)
- err := sendCtx(c.ctx, c.newConnections, cIO)
+ err := agpl.SendCtx(c.ctx, c.newConnections, cIO)
if err != nil {
// this can only happen if the context is canceled, no need to log
return reqs, resps
}
go func() {
<-cIO.Done()
- _ = sendCtx(c.ctx, c.closeConnections, cIO)
+ _ = agpl.SendCtx(c.ctx, c.closeConnections, cIO)
}()
return reqs, resps
}
-func sendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) {
- select {
- case <-ctx.Done():
- return ctx.Err()
- case c <- a:
- return nil
- }
-}
-
-func recvCtx[A any](ctx context.Context, c <-chan A) (a A, err error) {
- select {
- case <-ctx.Done():
- return a, ctx.Err()
- case a, ok := <-c:
- if ok {
- return a, nil
- }
- return a, io.EOF
- }
-}
-
type tKey struct {
src uuid.UUID
dst uuid.UUID
@@ -873,7 +694,7 @@ func (m *mapper) bestToUpdate(best map[uuid.UUID]mapping) *proto.CoordinateRespo
case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_NODE && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
eq, err := sm.node.Equal(mpng.node)
if err != nil {
- m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sm.node), slog.F("new", mpng.kind))
+ m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sm.node), slog.F("new", mpng.node))
continue
}
if eq {
@@ -1303,7 +1124,7 @@ func (q *querier) updateAll() {
go func(m *mapper) {
// make sure we send on the _mapper_ context, not our own in case the mapper is
// shutting down or shut down.
- _ = sendCtx(m.ctx, m.update, struct{}{})
+ _ = agpl.SendCtx(m.ctx, m.update, struct{}{})
}(mpr)
}
}
@@ -1603,7 +1424,7 @@ func (h *heartbeats) recvBeat(id uuid.UUID) {
h.logger.Info(h.ctx, "heartbeats (re)started", slog.F("other_coordinator_id", id))
// send on a separate goroutine to avoid holding lock. Triggering update can be async
go func() {
- _ = sendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
+ _ = agpl.SendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
}()
}
h.coordinators[id] = time.Now()
@@ -1650,7 +1471,7 @@ func (h *heartbeats) checkExpiry() {
if expired {
// send on a separate goroutine to avoid holding lock. Triggering update can be async
go func() {
- _ = sendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
+ _ = agpl.SendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
}()
}
// we need to reset the timer for when the next oldest coordinator will expire, if any.
@@ -1685,14 +1506,14 @@ func (h *heartbeats) sendBeat() {
h.failedHeartbeats++
if h.failedHeartbeats == 3 {
h.logger.Error(h.ctx, "coordinator failed 3 heartbeats and is unhealthy")
- _ = sendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateUnhealthy})
+ _ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateUnhealthy})
}
return
}
h.logger.Debug(h.ctx, "sent heartbeat")
if h.failedHeartbeats >= 3 {
h.logger.Info(h.ctx, "coordinator sent heartbeat and is healthy")
- _ = sendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy})
+ _ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy})
}
h.failedHeartbeats = 0
}
diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go
index 920bb8a2c8..8a581f48b6 100644
--- a/enterprise/tailnet/pgcoord_test.go
+++ b/enterprise/tailnet/pgcoord_test.go
@@ -9,6 +9,8 @@ import (
"testing"
"time"
+ agpltest "github.com/coder/coder/v2/tailnet/test"
+
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -612,18 +614,7 @@ func TestPGCoordinator_BidirectionalTunnels(t *testing.T) {
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
-
- p1 := newTestPeer(ctx, t, coordinator, "p1")
- defer p1.close(ctx)
- p2 := newTestPeer(ctx, t, coordinator, "p2")
- defer p2.close(ctx)
- p1.addTunnel(p2.id)
- p2.addTunnel(p1.id)
- p1.updateDERP(1)
- p2.updateDERP(2)
-
- p1.assertEventuallyHasDERP(p2.id, 2)
- p2.assertEventuallyHasDERP(p1.id, 1)
+ agpltest.BidirectionalTunnels(ctx, t, coordinator)
}
func TestPGCoordinator_GracefulDisconnect(t *testing.T) {
@@ -638,21 +629,7 @@ func TestPGCoordinator_GracefulDisconnect(t *testing.T) {
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
-
- p1 := newTestPeer(ctx, t, coordinator, "p1")
- defer p1.close(ctx)
- p2 := newTestPeer(ctx, t, coordinator, "p2")
- defer p2.close(ctx)
- p1.addTunnel(p2.id)
- p1.updateDERP(1)
- p2.updateDERP(2)
-
- p1.assertEventuallyHasDERP(p2.id, 2)
- p2.assertEventuallyHasDERP(p1.id, 1)
-
- p2.disconnect()
- p1.assertEventuallyDisconnected(p2.id)
- p2.assertEventuallyResponsesClosed()
+ agpltest.GracefulDisconnectTest(ctx, t, coordinator)
}
func TestPGCoordinator_Lost(t *testing.T) {
@@ -667,20 +644,7 @@ func TestPGCoordinator_Lost(t *testing.T) {
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
-
- p1 := newTestPeer(ctx, t, coordinator, "p1")
- defer p1.close(ctx)
- p2 := newTestPeer(ctx, t, coordinator, "p2")
- defer p2.close(ctx)
- p1.addTunnel(p2.id)
- p1.updateDERP(1)
- p2.updateDERP(2)
-
- p1.assertEventuallyHasDERP(p2.id, 2)
- p2.assertEventuallyHasDERP(p1.id, 1)
-
- p2.close(ctx)
- p1.assertEventuallyLost(p2.id)
+ agpltest.LostTest(ctx, t, coordinator)
}
type testConn struct {
@@ -918,177 +882,6 @@ func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store
}, testutil.WaitShort, testutil.IntervalFast)
}
-type peerStatus struct {
- preferredDERP int32
- status proto.CoordinateResponse_PeerUpdate_Kind
-}
-
-type testPeer struct {
- ctx context.Context
- cancel context.CancelFunc
- t testing.TB
- id uuid.UUID
- name string
- resps <-chan *proto.CoordinateResponse
- reqs chan<- *proto.CoordinateRequest
- peers map[uuid.UUID]peerStatus
-}
-
-func newTestPeer(ctx context.Context, t testing.TB, coord agpl.CoordinatorV2, name string, id ...uuid.UUID) *testPeer {
- p := &testPeer{t: t, name: name, peers: make(map[uuid.UUID]peerStatus)}
- p.ctx, p.cancel = context.WithCancel(ctx)
- if len(id) > 1 {
- t.Fatal("too many")
- }
- if len(id) == 1 {
- p.id = id[0]
- } else {
- p.id = uuid.New()
- }
- // SingleTailnetTunnelAuth allows connections to arbitrary peers
- p.reqs, p.resps = coord.Coordinate(p.ctx, p.id, name, agpl.SingleTailnetTunnelAuth{})
- return p
-}
-
-func (p *testPeer) addTunnel(other uuid.UUID) {
- p.t.Helper()
- req := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(other)}}
- select {
- case <-p.ctx.Done():
- p.t.Errorf("timeout adding tunnel for %s", p.name)
- return
- case p.reqs <- req:
- return
- }
-}
-
-func (p *testPeer) updateDERP(derp int32) {
- p.t.Helper()
- req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: derp}}}
- select {
- case <-p.ctx.Done():
- p.t.Errorf("timeout updating node for %s", p.name)
- return
- case p.reqs <- req:
- return
- }
-}
-
-func (p *testPeer) disconnect() {
- p.t.Helper()
- req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
- select {
- case <-p.ctx.Done():
- p.t.Errorf("timeout updating node for %s", p.name)
- return
- case p.reqs <- req:
- return
- }
-}
-
-func (p *testPeer) assertEventuallyHasDERP(other uuid.UUID, derp int32) {
- p.t.Helper()
- for {
- o, ok := p.peers[other]
- if ok && o.preferredDERP == derp {
- return
- }
- if err := p.handleOneResp(); err != nil {
- assert.NoError(p.t, err)
- return
- }
- }
-}
-
-func (p *testPeer) assertEventuallyDisconnected(other uuid.UUID) {
- p.t.Helper()
- for {
- _, ok := p.peers[other]
- if !ok {
- return
- }
- if err := p.handleOneResp(); err != nil {
- assert.NoError(p.t, err)
- return
- }
- }
-}
-
-func (p *testPeer) assertEventuallyLost(other uuid.UUID) {
- p.t.Helper()
- for {
- o := p.peers[other]
- if o.status == proto.CoordinateResponse_PeerUpdate_LOST {
- return
- }
- if err := p.handleOneResp(); err != nil {
- assert.NoError(p.t, err)
- return
- }
- }
-}
-
-func (p *testPeer) assertEventuallyResponsesClosed() {
- p.t.Helper()
- for {
- err := p.handleOneResp()
- if xerrors.Is(err, responsesClosed) {
- return
- }
- if !assert.NoError(p.t, err) {
- return
- }
- }
-}
-
-var responsesClosed = xerrors.New("responses closed")
-
-func (p *testPeer) handleOneResp() error {
- select {
- case <-p.ctx.Done():
- return p.ctx.Err()
- case resp, ok := <-p.resps:
- if !ok {
- return responsesClosed
- }
- for _, update := range resp.PeerUpdates {
- id, err := uuid.FromBytes(update.Uuid)
- if err != nil {
- return err
- }
- switch update.Kind {
- case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
- p.peers[id] = peerStatus{
- preferredDERP: update.GetNode().GetPreferredDerp(),
- status: update.Kind,
- }
- case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
- delete(p.peers, id)
- default:
- return xerrors.Errorf("unhandled update kind %s", update.Kind)
- }
- }
- }
- return nil
-}
-
-func (p *testPeer) close(ctx context.Context) {
- p.t.Helper()
- p.cancel()
- for {
- select {
- case <-ctx.Done():
- p.t.Errorf("timeout waiting for responses to close for %s", p.name)
- return
- case _, ok := <-p.resps:
- if ok {
- continue
- }
- return
- }
- }
-}
-
type fakeCoordinator struct {
ctx context.Context
t *testing.T
diff --git a/site/.eslintignore b/site/.eslintignore
index 033d259091..fc74ff6dd9 100644
--- a/site/.eslintignore
+++ b/site/.eslintignore
@@ -83,6 +83,7 @@ result
# Testdata shouldn't be formatted.
../scripts/apitypings/testdata/**/*.ts
../enterprise/tailnet/testdata/*.golden.html
+../tailnet/testdata/*.golden.html
# Generated files shouldn't be formatted.
e2e/provisionerGenerated.ts
diff --git a/site/.prettierignore b/site/.prettierignore
index 033d259091..fc74ff6dd9 100644
--- a/site/.prettierignore
+++ b/site/.prettierignore
@@ -83,6 +83,7 @@ result
# Testdata shouldn't be formatted.
../scripts/apitypings/testdata/**/*.ts
../enterprise/tailnet/testdata/*.golden.html
+../tailnet/testdata/*.golden.html
# Generated files shouldn't be formatted.
e2e/provisionerGenerated.ts
diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go
index 2da96bc444..2c67e5ff72 100644
--- a/tailnet/coordinator.go
+++ b/tailnet/coordinator.go
@@ -3,8 +3,6 @@ package tailnet
import (
"context"
"encoding/json"
- "errors"
- "fmt"
"html/template"
"io"
"net"
@@ -14,14 +12,12 @@ import (
"time"
"github.com/google/uuid"
- lru "github.com/hashicorp/golang-lru/v2"
- "golang.org/x/exp/slices"
"golang.org/x/xerrors"
+ "nhooyr.io/websocket"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"cdr.dev/slog"
- "github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/tailnet/proto"
)
@@ -131,12 +127,29 @@ func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func
const LoggerName = "coord"
+var (
+ ErrClosed = xerrors.New("coordinator is closed")
+ ErrWouldBlock = xerrors.New("would block")
+ ErrAlreadyRemoved = xerrors.New("already removed")
+)
+
// NewCoordinator constructs a new in-memory connection coordinator. This
// coordinator is incompatible with multiple Coder replicas as all node data is
// in-memory.
func NewCoordinator(logger slog.Logger) Coordinator {
return &coordinator{
- core: newCore(logger),
+ core: newCore(logger.Named(LoggerName)),
+ closedChan: make(chan struct{}),
+ }
+}
+
+// NewCoordinatorV2 constructs a new in-memory connection coordinator. This
+// coordinator is incompatible with multiple Coder replicas as all node data is
+// in-memory.
+func NewCoordinatorV2(logger slog.Logger) CoordinatorV2 {
+ return &coordinator{
+ core: newCore(logger.Named(LoggerName)),
+ closedChan: make(chan struct{}),
}
}
@@ -149,26 +162,112 @@ func NewCoordinator(logger slog.Logger) Coordinator {
// data is in-memory.
type coordinator struct {
core *core
+
+ mu sync.Mutex
+ closed bool
+ wg sync.WaitGroup
+ closedChan chan struct{}
+}
+
+func (c *coordinator) Coordinate(
+ ctx context.Context, id uuid.UUID, name string, a TunnelAuth,
+) (
+ chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse,
+) {
+ logger := c.core.logger.With(
+ slog.F("peer_id", id),
+ slog.F("peer_name", name),
+ )
+ reqs := make(chan *proto.CoordinateRequest, RequestBufferSize)
+ resps := make(chan *proto.CoordinateResponse, ResponseBufferSize)
+
+ p := &peer{
+ logger: logger,
+ id: id,
+ name: name,
+ resps: resps,
+ reqs: reqs,
+ auth: a,
+ sent: make(map[uuid.UUID]*proto.Node),
+ }
+ err := c.core.initPeer(p)
+ if err != nil {
+ if xerrors.Is(err, ErrClosed) {
+ logger.Debug(ctx, "coordinate failed: Coordinator is closed")
+ } else {
+ logger.Critical(ctx, "coordinate failed: %s", err.Error())
+ }
+ }
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ // don't start the reqLoop if we are closed.
+ return reqs, resps
+ }
+ c.wg.Add(1)
+ go func() {
+ defer c.wg.Done()
+ p.reqLoop(ctx, logger, c.core.handleRequest)
+ err := c.core.lostPeer(p)
+ if xerrors.Is(err, ErrClosed) || xerrors.Is(err, ErrAlreadyRemoved) {
+ return
+ }
+ if err != nil {
+ logger.Error(context.Background(), "failed to process lost peer", slog.Error(err))
+ }
+ }()
+ return reqs, resps
}
func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn {
- m := (&MultiAgent{
- ID: id,
- AgentIsLegacyFunc: c.core.agentIsLegacy,
- OnSubscribe: c.core.clientSubscribeToAgent,
- OnUnsubscribe: c.core.clientUnsubscribeFromAgent,
- OnNodeUpdate: c.core.clientNodeUpdate,
- OnRemove: func(enq Queue) { c.core.clientDisconnected(enq.UniqueID()) },
- }).Init()
- c.core.addClient(id, m)
- return m
+ return ServeMultiAgent(c, c.core.logger, id)
}
-func (c *core) addClient(id uuid.UUID, ma Queue) {
- c.mutex.Lock()
- c.clients[id] = ma
- c.clientsToAgents[id] = map[uuid.UUID]Queue{}
- c.mutex.Unlock()
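+// ServeMultiAgent adapts a v2 Coordinator to the v1 MultiAgentConn interface by opening a
+// Coordinate stream on behalf of the multi-agent and translating node updates between the two protocols.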
+func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAgentConn {
+ logger = logger.With(slog.F("client_id", id)).Named("multiagent")
+ ctx, cancel := context.WithCancel(context.Background())
+ reqs, resps := c.Coordinate(ctx, id, id.String(), SingleTailnetTunnelAuth{})
+ m := (&MultiAgent{
+ ID: id,
+ AgentIsLegacyFunc: func(agentID uuid.UUID) bool {
+ if n := c.Node(agentID); n == nil {
+ // If we don't have the node at all assume it's legacy for
+ // safety.
+ return true
+ } else if len(n.Addresses) > 0 && n.Addresses[0].Addr() == legacyAgentIP {
+ // An agent is determined to be "legacy" if its first IP is the
+ // legacy IP. Agents with only the legacy IP aren't compatible
+ // with single_tailnet and must be routed through wsconncache.
+ return true
+ } else {
+ return false
+ }
+ },
+ OnSubscribe: func(enq Queue, agent uuid.UUID) (*Node, error) {
+ err := SendCtx(ctx, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: UUIDToByteSlice(agent)}})
+ return c.Node(agent), err
+ },
+ OnUnsubscribe: func(enq Queue, agent uuid.UUID) error {
+ err := SendCtx(ctx, reqs, &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Uuid: UUIDToByteSlice(agent)}})
+ return err
+ },
+ OnNodeUpdate: func(id uuid.UUID, node *Node) error {
+ pn, err := NodeToProto(node)
+ if err != nil {
+ return err
+ }
+ return SendCtx(ctx, reqs, &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
+ Node: pn,
+ }})
+ },
+ OnRemove: func(_ Queue) {
+ _ = SendCtx(ctx, reqs, &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
+ cancel()
+ },
+ }).Init()
+
+ go v1RespLoop(ctx, cancel, logger, m, resps)
+ return m
}
// core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines;
@@ -178,29 +277,8 @@ type core struct {
mutex sync.RWMutex
closed bool
- // nodes maps agent and connection IDs their respective node.
- nodes map[uuid.UUID]*Node
- // agentSockets maps agent IDs to their open websocket.
- agentSockets map[uuid.UUID]Queue
- // agentToConnectionSockets maps agent IDs to connection IDs of conns that
- // are subscribed to updates for that agent.
- agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]Queue
-
- // clients holds a map of all clients connected to the coordinator. This is
- // necessary because a client may not be subscribed into any agents.
- clients map[uuid.UUID]Queue
- // clientsToAgents is an index of clients to all of their subscribed agents.
- clientsToAgents map[uuid.UUID]map[uuid.UUID]Queue
-
- // agentNameCache holds a cache of agent names. If one of them disappears,
- // it's helpful to have a name cached for debugging.
- agentNameCache *lru.Cache[uuid.UUID, string]
-
- // legacyAgents holda a mapping of all agents detected as legacy, meaning
- // they only listen on codersdk.WorkspaceAgentIP. They aren't compatible
- // with the new ServerTailnet, so they must be connected through
- // wsconncache.
- legacyAgents map[uuid.UUID]struct{}
+ peers map[uuid.UUID]*peer
+ tunnels *tunnelStore
}
type QueueKind int
@@ -225,26 +303,14 @@ type Queue interface {
}
func newCore(logger slog.Logger) *core {
- nameCache, err := lru.New[uuid.UUID, string](512)
- if err != nil {
- panic("make lru cache: " + err.Error())
- }
-
return &core{
- logger: logger,
- closed: false,
- nodes: map[uuid.UUID]*Node{},
- agentSockets: map[uuid.UUID]Queue{},
- agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]Queue{},
- agentNameCache: nameCache,
- legacyAgents: map[uuid.UUID]struct{}{},
- clients: map[uuid.UUID]Queue{},
- clientsToAgents: map[uuid.UUID]map[uuid.UUID]Queue{},
+ logger: logger,
+ closed: false,
+ peers: make(map[uuid.UUID]*peer),
+ tunnels: newTunnelStore(),
}
}
-var ErrWouldBlock = xerrors.New("would block")
-
// Node returns an in-memory node by ID.
// If the node does not exist, nil is returned.
func (c *coordinator) Node(id uuid.UUID) *Node {
@@ -254,27 +320,17 @@ func (c *coordinator) Node(id uuid.UUID) *Node {
func (c *core) node(id uuid.UUID) *Node {
c.mutex.Lock()
defer c.mutex.Unlock()
- return c.nodes[id]
-}
-
-func (c *coordinator) NodeCount() int {
- return c.core.nodeCount()
-}
-
-func (c *core) nodeCount() int {
- c.mutex.Lock()
- defer c.mutex.Unlock()
- return len(c.nodes)
-}
-
-func (c *coordinator) AgentCount() int {
- return c.core.agentCount()
-}
-
-func (c *core) agentCount() int {
- c.mutex.Lock()
- defer c.mutex.Unlock()
- return len(c.agentSockets)
+ p := c.peers[id]
+ if p == nil || p.node == nil {
+ return nil
+ }
+ v1Node, err := ProtoToNode(p.node)
+ if err != nil {
+ c.logger.Critical(context.Background(),
+ "failed to convert node", slog.Error(err), slog.F("node", p.node))
+ return nil
+ }
+ return v1Node
}
// ServeClient accepts a WebSocket connection that wants to connect to an agent
@@ -282,180 +338,229 @@ func (c *core) agentCount() int {
func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- logger := c.core.clientLogger(id, agentID)
- logger.Debug(ctx, "coordinating client")
+ return ServeClientV1(ctx, c.core.logger, c, conn, id, agentID)
+}
+
+// ServeClientV1 adapts a v1 Client to a v2 Coordinator
+func ServeClientV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
+ logger = logger.With(slog.F("client_id", id), slog.F("agent_id", agent))
+ defer func() {
+ err := conn.Close()
+ if err != nil {
+ logger.Debug(ctx, "closing client connection", slog.Error(err))
+ }
+ }()
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ reqs, resps := c.Coordinate(ctx, id, id.String(), ClientTunnelAuth{AgentID: agent})
+ err := SendCtx(ctx, reqs, &proto.CoordinateRequest{
+ AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: UUIDToByteSlice(agent)},
+ })
+ if err != nil {
+ // can only be a context error, no need to log here.
+ return err
+ }
tc := NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, QueueKindClient)
- defer tc.Close()
-
- c.core.addClient(id, tc)
- defer c.core.clientDisconnected(id)
-
- agentNode, err := c.core.clientSubscribeToAgent(tc, agentID)
- if err != nil {
- return xerrors.Errorf("subscribe agent: %w", err)
- }
-
- if agentNode != nil {
- err := tc.Enqueue([]*Node{agentNode})
- if err != nil {
- logger.Debug(ctx, "enqueue initial node", slog.Error(err))
- }
- }
-
- // On this goroutine, we read updates from the client and publish them. We start a second goroutine
- // to write updates back to the client.
go tc.SendUpdates()
-
- decoder := json.NewDecoder(conn)
- for {
- err := c.handleNextClientMessage(id, decoder)
- if err != nil {
- logger.Debug(ctx, "unable to read client update, connection may be closed", slog.Error(err))
- if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) {
- return nil
- }
- return xerrors.Errorf("handle next client message: %w", err)
- }
- }
-}
-
-func (c *core) clientLogger(id, agent uuid.UUID) slog.Logger {
- return c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent))
-}
-
-func (c *core) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq Queue) {
- connectionSockets, ok := c.agentToConnectionSockets[agentID]
- if !ok {
- connectionSockets = map[uuid.UUID]Queue{}
- c.agentToConnectionSockets[agentID] = connectionSockets
- }
- connectionSockets[enq.UniqueID()] = enq
-
- c.clientsToAgents[enq.UniqueID()][agentID] = c.agentSockets[agentID]
-}
-
-func (c *core) clientDisconnected(id uuid.UUID) {
- logger := c.clientLogger(id, uuid.Nil)
- c.mutex.Lock()
- defer c.mutex.Unlock()
- // Clean all traces of this connection from the map.
- delete(c.nodes, id)
- logger.Debug(context.Background(), "deleted client node")
-
- for agentID := range c.clientsToAgents[id] {
- connectionSockets, ok := c.agentToConnectionSockets[agentID]
- if !ok {
- continue
- }
- delete(connectionSockets, id)
- logger.Debug(context.Background(), "deleted client connectionSocket from map", slog.F("agent_id", agentID))
-
- if len(connectionSockets) == 0 {
- delete(c.agentToConnectionSockets, agentID)
- logger.Debug(context.Background(), "deleted last client connectionSocket from map", slog.F("agent_id", agentID))
- }
- }
-
- delete(c.clients, id)
- delete(c.clientsToAgents, id)
- logger.Debug(context.Background(), "deleted client agents")
-}
-
-func (c *coordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error {
- logger := c.core.clientLogger(id, uuid.Nil)
-
- var node Node
- err := decoder.Decode(&node)
- if err != nil {
- return xerrors.Errorf("read json: %w", err)
- }
-
- logger.Debug(context.Background(), "got client node update", slog.F("node", node))
- return c.core.clientNodeUpdate(id, &node)
-}
-
-func (c *core) clientNodeUpdate(id uuid.UUID, node *Node) error {
- c.mutex.Lock()
- defer c.mutex.Unlock()
-
- // Update the node of this client in our in-memory map. If an agent entirely
- // shuts down and reconnects, it needs to be aware of all clients attempting
- // to establish connections.
- c.nodes[id] = node
-
- return c.clientNodeUpdateLocked(id, node)
-}
-
-func (c *core) clientNodeUpdateLocked(id uuid.UUID, node *Node) error {
- logger := c.clientLogger(id, uuid.Nil)
-
- agents := []uuid.UUID{}
- for agentID, agentSocket := range c.clientsToAgents[id] {
- if agentSocket == nil {
- logger.Debug(context.Background(), "enqueue node to agent; socket is nil", slog.F("agent_id", agentID))
- continue
- }
-
- err := agentSocket.Enqueue([]*Node{node})
- if err != nil {
- logger.Debug(context.Background(), "unable to Enqueue node to agent", slog.Error(err), slog.F("agent_id", agentID))
- continue
- }
- agents = append(agents, agentID)
- }
-
- logger.Debug(context.Background(), "enqueued node to agents", slog.F("agent_ids", agents))
+ go v1RespLoop(ctx, cancel, logger, tc, resps)
+ go v1ReqLoop(ctx, cancel, logger, conn, reqs)
+ <-ctx.Done()
return nil
}
-func (c *core) clientSubscribeToAgent(enq Queue, agentID uuid.UUID) (*Node, error) {
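+// handleRequest applies a single CoordinateRequest from peer p: self node updates, tunnel
+// additions and removals, and graceful disconnects.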
+func (c *core) handleRequest(p *peer, req *proto.CoordinateRequest) error {
c.mutex.Lock()
defer c.mutex.Unlock()
+ if c.closed {
+ return ErrClosed
+ }
+ pr, ok := c.peers[p.id]
+ if !ok || pr != p {
+ return ErrAlreadyRemoved
+ }
+ if req.UpdateSelf != nil {
+ err := c.nodeUpdateLocked(p, req.UpdateSelf.Node)
+ if xerrors.Is(err, ErrAlreadyRemoved) || xerrors.Is(err, ErrClosed) {
+ return nil
+ }
+ if err != nil {
+ return xerrors.Errorf("node update failed: %w", err)
+ }
+ }
+ if req.AddTunnel != nil {
+ dstID, err := uuid.FromBytes(req.AddTunnel.Uuid)
+ if err != nil {
+ // this shouldn't happen unless there is a client error. Close the connection so the client
+ // doesn't just happily continue thinking everything is fine.
+ return xerrors.Errorf("unable to convert bytes to UUID: %w", err)
+ }
+ err = c.addTunnelLocked(p, dstID)
+ if xerrors.Is(err, ErrAlreadyRemoved) || xerrors.Is(err, ErrClosed) {
+ return nil
+ }
+ if err != nil {
+ return xerrors.Errorf("add tunnel failed: %w", err)
+ }
+ }
+ if req.RemoveTunnel != nil {
+ dstID, err := uuid.FromBytes(req.RemoveTunnel.Uuid)
+ if err != nil {
+ // this shouldn't happen unless there is a client error. Close the connection so the client
+ // doesn't just happily continue thinking everything is fine.
+ return xerrors.Errorf("unable to convert bytes to UUID: %w", err)
+ }
+ err = c.removeTunnelLocked(p, dstID)
+ if xerrors.Is(err, ErrAlreadyRemoved) || xerrors.Is(err, ErrClosed) {
+ return nil
+ }
+ if err != nil {
+ return xerrors.Errorf("remove tunnel failed: %w", err)
+ }
+ }
+ if req.Disconnect != nil {
+ c.removePeerLocked(p.id, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "graceful disconnect")
+ }
+ return nil
+}
- logger := c.clientLogger(enq.UniqueID(), agentID)
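+// nodeUpdateLocked stores the peer's latest node and fans the update out to its tunnel peers.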
+func (c *core) nodeUpdateLocked(p *peer, node *proto.Node) error {
+ c.logger.Debug(context.Background(), "processing node update",
+ slog.F("peer_id", p.id),
+ slog.F("node", node.String()))
- c.initOrSetAgentConnectionSocketLocked(agentID, enq)
+ p.node = node
+ c.updateTunnelPeersLocked(p.id, node, proto.CoordinateResponse_PeerUpdate_NODE, "node update")
+ return nil
+}
- node, ok := c.nodes[enq.UniqueID()]
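+// updateTunnelPeersLocked sends an updated mapping for peer id to every peer that shares a tunnel
+// with it; peers that fail to accept the update are removed as disconnected.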
+func (c *core) updateTunnelPeersLocked(id uuid.UUID, n *proto.Node, k proto.CoordinateResponse_PeerUpdate_Kind, reason string) {
+ tp := c.tunnels.findTunnelPeers(id)
+ c.logger.Debug(context.Background(), "got tunnel peers", slog.F("peer_id", id), slog.F("tunnel_peers", tp))
+ for _, otherID := range tp {
+ other, ok := c.peers[otherID]
+ if !ok {
+ continue
+ }
+ err := other.updateMappingLocked(id, n, k, reason)
+ if err != nil {
+ other.logger.Error(context.Background(), "failed to update mapping", slog.Error(err))
+ c.removePeerLocked(other.id, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "failed update")
+ }
+ }
+}
+
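+// addTunnelLocked authorizes and records a tunnel from src to dstID, then exchanges the current
+// node mappings between the two peers where they are known.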
+func (c *core) addTunnelLocked(src *peer, dstID uuid.UUID) error {
+ if !src.auth.Authorize(dstID) {
+ return xerrors.Errorf("src %s is not allowed to tunnel to %s", src.id, dstID)
+ }
+ c.tunnels.add(src.id, dstID)
+ c.logger.Debug(context.Background(), "adding tunnel",
+ slog.F("src_id", src.id),
+ slog.F("dst_id", dstID))
+ dst, ok := c.peers[dstID]
if ok {
- // If we have the client node, send it to the agent. If not, it will be
- // sent async.
- agentSocket, ok := c.agentSockets[agentID]
- if !ok {
- logger.Debug(context.Background(), "subscribe to agent; socket is nil")
- } else {
- err := agentSocket.Enqueue([]*Node{node})
+ if dst.node != nil {
+ err := src.updateMappingLocked(dstID, dst.node, proto.CoordinateResponse_PeerUpdate_NODE, "add tunnel")
if err != nil {
- return nil, xerrors.Errorf("enqueue client to agent: %w", err)
+ src.logger.Error(context.Background(), "failed update of tunnel src", slog.Error(err))
+ c.removePeerLocked(src.id, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "failed update")
+ // if the source fails, then the tunnel is also removed and there is no reason to continue
+ // processing.
+ return err
+ }
+ }
+ if src.node != nil {
+ err := dst.updateMappingLocked(src.id, src.node, proto.CoordinateResponse_PeerUpdate_NODE, "add tunnel")
+ if err != nil {
+ dst.logger.Error(context.Background(), "failed update of tunnel dst", slog.Error(err))
+ c.removePeerLocked(dst.id, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "failed update")
}
}
- } else {
- logger.Debug(context.Background(), "multiagent node doesn't exist")
}
-
- agentNode, ok := c.nodes[agentID]
- if !ok {
- // This is ok, once the agent connects the node will be sent over.
- logger.Debug(context.Background(), "agent node doesn't exist", slog.F("agent_id", agentID))
- }
-
- // Send the subscribed agent back to the multi agent.
- return agentNode, nil
-}
-
-func (c *core) clientUnsubscribeFromAgent(enq Queue, agentID uuid.UUID) error {
- c.mutex.Lock()
- defer c.mutex.Unlock()
-
- delete(c.clientsToAgents[enq.UniqueID()], agentID)
- delete(c.agentToConnectionSockets[agentID], enq.UniqueID())
-
return nil
}
-func (c *core) agentLogger(id uuid.UUID) slog.Logger {
- return c.logger.With(slog.F("agent_id", id))
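+// removeTunnelLocked notifies both ends of the tunnel that it is being removed, then deletes it from the tunnel store.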
+func (c *core) removeTunnelLocked(src *peer, dstID uuid.UUID) error {
+ err := src.updateMappingLocked(dstID, nil, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "remove tunnel")
+ if err != nil {
+ src.logger.Error(context.Background(), "failed to update", slog.Error(err))
+ c.removePeerLocked(src.id, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "failed update")
+ // removing the peer also removes all other tunnels and notifies destinations, so it's safe to
+ // return here.
+ return err
+ }
+ dst, ok := c.peers[dstID]
+ if ok {
+ err = dst.updateMappingLocked(src.id, nil, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "remove tunnel")
+ if err != nil {
+ dst.logger.Error(context.Background(), "failed to update", slog.Error(err))
+ c.removePeerLocked(dst.id, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "failed update")
+ // don't return here because we still want to remove the tunnel, and an error at the
+ // destination doesn't count as an error removing the tunnel at the source.
+ }
+ }
+ c.tunnels.remove(src.id, dstID)
+ return nil
+}
+
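+// initPeer registers a newly coordinated peer, closing out any previous connection with the same
+// ID, and replays the node mappings of peers it already shares tunnels with.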
+func (c *core) initPeer(p *peer) error {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ p.logger.Debug(context.Background(), "initPeer")
+ if c.closed {
+ return ErrClosed
+ }
+ if p.node != nil {
+ return xerrors.Errorf("peer (%s) must be initialized with nil node", p.id)
+ }
+ if old, ok := c.peers[p.id]; ok {
+ // rare and interesting enough to log at Info, but it isn't an error per se
+ old.logger.Info(context.Background(), "overwritten by new connection")
+ close(old.resps)
+ p.overwrites = old.overwrites + 1
+ }
+ now := time.Now()
+ p.start = now
+ p.lastWrite = now
+ c.peers[p.id] = p
+
+ tp := c.tunnels.findTunnelPeers(p.id)
+ p.logger.Debug(context.Background(), "initial tunnel peers", slog.F("tunnel_peers", tp))
+ var others []*peer
+ for _, otherID := range tp {
+ // ok to append nil here because the batch call below filters them out
+ others = append(others, c.peers[otherID])
+ }
+ return p.batchUpdateMappingLocked(others, proto.CoordinateResponse_PeerUpdate_NODE, "init")
+}
+
+// lostPeer removes and cleans up a lost peer. It updates all peers it shares a tunnel with and
+// deletes all tunnels for which the lost peer is the source.
+func (c *core) lostPeer(p *peer) error {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ c.logger.Debug(context.Background(), "lostPeer", slog.F("peer_id", p.id))
+ if c.closed {
+ return ErrClosed
+ }
+ if existing, ok := c.peers[p.id]; !ok || existing != p {
+ return ErrAlreadyRemoved
+ }
+ c.removePeerLocked(p.id, proto.CoordinateResponse_PeerUpdate_LOST, "lost")
+ return nil
+}
+
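+// removePeerLocked notifies the peer's tunnel peers with the given kind and reason, removes all of
+// its tunnels, closes its response channel, and deletes it from the peer map.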
+func (c *core) removePeerLocked(id uuid.UUID, kind proto.CoordinateResponse_PeerUpdate_Kind, reason string) {
+ p, ok := c.peers[id]
+ if !ok {
+ c.logger.Critical(context.Background(), "removed non-existent peer %s", id)
+ return
+ }
+ c.updateTunnelPeersLocked(id, nil, kind, reason)
+ c.tunnels.removeAll(id)
+ close(p.resps)
+ delete(c.peers, id)
}
// ServeAgent accepts a WebSocket connection to an agent that
@@ -463,204 +568,63 @@ func (c *core) agentLogger(id uuid.UUID) slog.Logger {
func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- logger := c.core.agentLogger(id)
- logger.Debug(context.Background(), "coordinating agent")
- // This uniquely identifies a connection that belongs to this goroutine.
- unique := uuid.New()
- tc, err := c.core.initAndTrackAgent(ctx, cancel, conn, id, unique, name)
- if err != nil {
- return err
- }
+ return ServeAgentV1(ctx, c.core.logger, c, conn, id, name)
+}
- // On this goroutine, we read updates from the agent and publish them. We start a second goroutine
- // to write updates back to the agent.
+func ServeAgentV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, conn net.Conn, id uuid.UUID, name string) error {
+ logger = logger.With(slog.F("agent_id", id), slog.F("name", name))
+ defer func() {
+ logger.Debug(ctx, "closing agent connection")
+ err := conn.Close()
+ logger.Debug(ctx, "closed agent connection", slog.Error(err))
+ }()
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ logger.Debug(ctx, "starting new agent connection")
+ reqs, resps := c.Coordinate(ctx, id, name, AgentTunnelAuth{})
+ tc := NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, QueueKindAgent)
go tc.SendUpdates()
-
- defer c.core.agentDisconnected(id, unique)
-
- decoder := json.NewDecoder(conn)
- for {
- err := c.handleNextAgentMessage(id, decoder)
- if err != nil {
- logger.Debug(ctx, "unable to read agent update, connection may be closed", slog.Error(err))
- if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) {
- return nil
- }
- return xerrors.Errorf("handle next agent message: %w", err)
- }
- }
-}
-
-func (c *core) agentDisconnected(id, unique uuid.UUID) {
- logger := c.agentLogger(id)
- c.mutex.Lock()
- defer c.mutex.Unlock()
-
- // Only delete the connection if it's ours. It could have been
- // overwritten.
- if idConn, ok := c.agentSockets[id]; ok && idConn.UniqueID() == unique {
- delete(c.agentSockets, id)
- delete(c.nodes, id)
- logger.Debug(context.Background(), "deleted agent socket and node")
- }
- for clientID := range c.agentToConnectionSockets[id] {
- c.clientsToAgents[clientID][id] = nil
- }
-}
-
-// initAndTrackAgent creates a TrackedConn for the agent, and sends any initial nodes updates if we have any. It is
-// one function that does two things because it is critical that we hold the mutex for both things, lest we miss some
-// updates.
-func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Conn, id, unique uuid.UUID, name string) (*TrackedConn, error) {
- logger := c.logger.With(slog.F("agent_id", id))
- c.mutex.Lock()
- defer c.mutex.Unlock()
- if c.closed {
- return nil, xerrors.New("coordinator is closed")
- }
-
- overwrites := int64(0)
- // If an old agent socket is connected, we Close it to avoid any leaks. This
- // shouldn't ever occur because we expect one agent to be running, but it's
- // possible for a race condition to happen when an agent is disconnected and
- // attempts to reconnect before the server realizes the old connection is
- // dead.
- oldAgentSocket, ok := c.agentSockets[id]
- if ok {
- overwrites = oldAgentSocket.Overwrites() + 1
- _ = oldAgentSocket.Close()
- }
- tc := NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites, QueueKindAgent)
- c.agentNameCache.Add(id, name)
-
- sockets, ok := c.agentToConnectionSockets[id]
- if ok {
- // Publish all nodes that want to connect to the
- // desired agent ID.
- nodes := make([]*Node, 0, len(sockets))
- for targetID := range sockets {
- node, ok := c.nodes[targetID]
- if !ok {
- continue
- }
- nodes = append(nodes, node)
- }
- err := tc.Enqueue(nodes)
- // this should never error since we're still the only goroutine that
- // knows about the TrackedConn. If we hit an error something really
- // wrong is happening
- if err != nil {
- logger.Critical(ctx, "unable to queue initial nodes", slog.Error(err))
- return nil, err
- }
- logger.Debug(ctx, "wrote initial client(s) to agent", slog.F("nodes", nodes))
- }
-
- c.agentSockets[id] = tc
- for clientID := range c.agentToConnectionSockets[id] {
- c.clientsToAgents[clientID][id] = tc
- }
-
- logger.Debug(ctx, "added agent socket")
- return tc, nil
-}
-
-func (c *coordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.Decoder) error {
- logger := c.core.agentLogger(id)
- var node Node
- err := decoder.Decode(&node)
- if err != nil {
- return xerrors.Errorf("read json: %w", err)
- }
- logger.Debug(context.Background(), "decoded agent node", slog.F("node", node))
- return c.core.agentNodeUpdate(id, &node)
+ go v1RespLoop(ctx, cancel, logger, tc, resps)
+ go v1ReqLoop(ctx, cancel, logger, conn, reqs)
+ <-ctx.Done()
+ logger.Debug(ctx, "ending agent connection")
+ return nil
}
// This is copied from codersdk because importing it here would cause an import
// cycle. This is just temporary until wsconncache is phased out.
var legacyAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")
-// This is temporary until we no longer need to detect for agent backwards
-// compatibility.
-// See: https://github.com/coder/coder/issues/8218
-func (c *core) agentIsLegacy(agentID uuid.UUID) bool {
- c.mutex.RLock()
- _, ok := c.legacyAgents[agentID]
- c.mutex.RUnlock()
- return ok
-}
-
-func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error {
- logger := c.agentLogger(id)
- c.mutex.Lock()
- defer c.mutex.Unlock()
- c.nodes[id] = node
-
- // Keep a cache of all legacy agents.
- if len(node.Addresses) > 0 && node.Addresses[0].Addr() == legacyAgentIP {
- c.legacyAgents[id] = struct{}{}
- }
-
- connectionSockets, ok := c.agentToConnectionSockets[id]
- if !ok {
- logger.Debug(context.Background(), "no client sockets; unable to send node")
- return nil
- }
-
- // Publish the new node to every listening socket.
- for clientID, connectionSocket := range connectionSockets {
- err := connectionSocket.Enqueue([]*Node{node})
- if err == nil {
- logger.Debug(context.Background(), "enqueued agent node to client",
- slog.F("client_id", clientID))
- } else {
- // queue is backed up for some reason. This is bad, but we don't want to drop
- // updates to other clients over it. Log and move on.
- logger.Error(context.Background(), "failed to Enqueue",
- slog.F("client_id", clientID), slog.Error(err))
- }
- }
-
- return nil
-}
-
// Close closes all of the open connections in the coordinator and stops the
// coordinator from accepting new connections.
func (c *coordinator) Close() error {
- return c.core.close()
+ c.mu.Lock()
+ if c.closed {
+ c.mu.Unlock()
+ return nil
+ }
+ c.closed = true
+ close(c.closedChan)
+ c.mu.Unlock()
+
+ err := c.core.close()
+ // wait for all request loops to complete
+ c.wg.Wait()
+ return err
}
func (c *core) close() error {
c.mutex.Lock()
+ defer c.mutex.Unlock()
if c.closed {
- c.mutex.Unlock()
return nil
}
c.closed = true
-
- wg := sync.WaitGroup{}
-
- wg.Add(len(c.agentSockets))
- for _, socket := range c.agentSockets {
- socket := socket
- go func() {
- _ = socket.CoordinatorClose()
- wg.Done()
- }()
+ for id := range c.peers {
+ // when closing, mark them as LOST so that we don't disrupt in-progress
+ // connections.
+ c.removePeerLocked(id, proto.CoordinateResponse_PeerUpdate_LOST, "coordinator close")
}
-
- wg.Add(len(c.clients))
- for _, client := range c.clients {
- client := client
- go func() {
- _ = client.CoordinatorClose()
- wg.Done()
- }()
- }
-
- c.mutex.Unlock()
-
- wg.Wait()
return nil
}
@@ -668,161 +632,43 @@ func (c *coordinator) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
c.core.serveHTTPDebug(w, r)
}
-func (c *core) serveHTTPDebug(w http.ResponseWriter, r *http.Request) {
+func (c *core) serveHTTPDebug(w http.ResponseWriter, _ *http.Request) {
+ debug := c.getHTMLDebug()
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ err := debugTempl.Execute(w, debug)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte(err.Error()))
+ return
+ }
+}
+
+func (c *core) getHTMLDebug() HTMLDebug {
c.mutex.RLock()
defer c.mutex.RUnlock()
-
- CoordinatorHTTPDebug(
- HTTPDebugFromLocal(false, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache),
- )(w, r)
-}
-
-func HTTPDebugFromLocal(
- ha bool,
- agentSocketsMap map[uuid.UUID]Queue,
- agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]Queue,
- nodesMap map[uuid.UUID]*Node,
- agentNameCache *lru.Cache[uuid.UUID, string],
-) HTMLDebug {
- now := time.Now()
- data := HTMLDebug{HA: ha}
- for id, conn := range agentSocketsMap {
- start, lastWrite := conn.Stats()
- agent := &HTMLAgent{
- Name: conn.Name(),
- ID: id,
- CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second),
- LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
- Overwrites: int(conn.Overwrites()),
- }
-
- for id, conn := range agentToConnectionSocketsMap[id] {
- start, lastWrite := conn.Stats()
- agent.Connections = append(agent.Connections, &HTMLClient{
- Name: conn.Name(),
- ID: id,
- CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second),
- LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
- })
- }
- slices.SortFunc(agent.Connections, func(a, b *HTMLClient) int {
- return slice.Ascending(a.Name, b.Name)
- })
-
- data.Agents = append(data.Agents, agent)
- }
- slices.SortFunc(data.Agents, func(a, b *HTMLAgent) int {
- return slice.Ascending(a.Name, b.Name)
- })
-
- for agentID, conns := range agentToConnectionSocketsMap {
- if len(conns) == 0 {
- continue
- }
-
- if _, ok := agentSocketsMap[agentID]; ok {
- continue
- }
-
- agentName, ok := agentNameCache.Get(agentID)
- if !ok {
- agentName = "unknown"
- }
- agent := &HTMLAgent{
- Name: agentName,
- ID: agentID,
- }
- for id, conn := range conns {
- start, lastWrite := conn.Stats()
- agent.Connections = append(agent.Connections, &HTMLClient{
- Name: conn.Name(),
- ID: id,
- CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second),
- LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
- })
- }
- slices.SortFunc(agent.Connections, func(a, b *HTMLClient) int {
- return slice.Ascending(a.Name, b.Name)
- })
-
- data.MissingAgents = append(data.MissingAgents, agent)
- }
- slices.SortFunc(data.MissingAgents, func(a, b *HTMLAgent) int {
- return slice.Ascending(a.Name, b.Name)
- })
-
- for id, node := range nodesMap {
- name, _ := agentNameCache.Get(id)
- data.Nodes = append(data.Nodes, &HTMLNode{
- ID: id,
- Name: name,
- Node: node,
- })
- }
- slices.SortFunc(data.Nodes, func(a, b *HTMLNode) int {
- return slice.Ascending(a.Name+a.ID.String(), b.Name+b.ID.String())
- })
-
- return data
-}
-
-func CoordinatorHTTPDebug(data HTMLDebug) func(w http.ResponseWriter, _ *http.Request) {
- return func(w http.ResponseWriter, _ *http.Request) {
- w.Header().Set("Content-Type", "text/html; charset=utf-8")
-
- tmpl, err := template.New("coordinator_debug").Funcs(template.FuncMap{
- "marshal": func(v any) template.JS {
- a, err := json.MarshalIndent(v, "", " ")
- if err != nil {
- //nolint:gosec
- return template.JS(fmt.Sprintf(`{"err": %q}`, err))
- }
- //nolint:gosec
- return template.JS(a)
- },
- }).Parse(coordinatorDebugTmpl)
- if err != nil {
- w.WriteHeader(http.StatusInternalServerError)
- _, _ = w.Write([]byte(err.Error()))
- return
- }
-
- err = tmpl.Execute(w, data)
- if err != nil {
- w.WriteHeader(http.StatusInternalServerError)
- _, _ = w.Write([]byte(err.Error()))
- return
- }
+ debug := HTMLDebug{Tunnels: c.tunnels.htmlDebug()}
+ for _, p := range c.peers {
+ debug.Peers = append(debug.Peers, p.htmlDebug())
}
+ return debug
}
type HTMLDebug struct {
- HA bool
- Agents []*HTMLAgent
- MissingAgents []*HTMLAgent
- Nodes []*HTMLNode
+ Peers []HTMLPeer
+ Tunnels []HTMLTunnel
}
-type HTMLAgent struct {
- Name string
+type HTMLPeer struct {
ID uuid.UUID
+ Name string
CreatedAge time.Duration
LastWriteAge time.Duration
Overwrites int
- Connections []*HTMLClient
+ Node string
}
-type HTMLClient struct {
- Name string
- ID uuid.UUID
- CreatedAge time.Duration
- LastWriteAge time.Duration
-}
-
-type HTMLNode struct {
- ID uuid.UUID
- Name string
- Node any
+type HTMLTunnel struct {
+ Src, Dst uuid.UUID
}
var coordinatorDebugTmpl = `
@@ -830,51 +676,143 @@ var coordinatorDebugTmpl = `
+
-		{{- if .HA }}
-		<h1>high-availability wireguard coordinator debug</h1>
-		<h4>warning: this only provides info from the node that served the request, if there are multiple replicas this data may be incomplete</h4>
-		{{- else }}
		<h1>in-memory wireguard coordinator debug</h1>
-		{{- end }}

-		<h2 id=agents><a href=#agents>#</a> agents: total {{ len .Agents }}</h2>
-		<ul>
-		{{- range .Agents }}
-			<li>
-				<b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created {{ .CreatedAge }} ago, write {{ .LastWriteAge }} ago, overwrites {{ .Overwrites }}
-				<h3>connections: total {{ len .Connections}}</h3>
-				<ul>
-				{{- range .Connections }}
-					<li><b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created {{ .CreatedAge }} ago, write {{ .LastWriteAge }} ago</li>
-				{{- end }}
-				</ul>
-			</li>
+		<h2 id=peers><a href=#peers>#</a> peers: total {{ len .Peers }}</h2>
+		<table>
+			<tr>
+				<th>Name</th>
+				<th>ID</th>
+				<th>Created Age</th>
+				<th>Last Write Age</th>
+				<th>Overwrites</th>
+				<th>Node</th>
+			</tr>
+		{{- range .Peers }}
+			<tr>
+				<td>{{ .Name }}</td>
+				<td>{{ .ID }}</td>
+				<td>{{ .CreatedAge }}</td>
+				<td>{{ .LastWriteAge }} ago</td>
+				<td>{{ .Overwrites }}</td>
+				<td>{{ .Node }}</td>
+			</tr>
		{{- end }}
-		</ul>
+		</table>

-		<h2 id=missing-agents><a href=#missing-agents>#</a> missing agents: total {{ len .MissingAgents }}</h2>
-		<ul>
-		{{- range .MissingAgents}}
-			<li><b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created ? ago, write ? ago, overwrites ?</li>
-			<h3>connections: total {{ len .Connections }}</h3>
-			<ul>
-			{{- range .Connections }}
-				<li><b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created {{ .CreatedAge }} ago, write {{ .LastWriteAge }} ago</li>
-			{{- end }}
-			</ul>
+		<h2 id=tunnels><a href=#tunnels>#</a> tunnels: total {{ len .Tunnels }}</h2>
+		<table>
+			<tr>
+				<th>SrcID</th>
+				<th>DstID</th>
+			</tr>
+		{{- range .Tunnels }}
+			<tr>
+				<td>{{ .Src }}</td>
+				<td>{{ .Dst }}</td>
+			</tr>
		{{- end }}
-		</ul>
-
-		<h2 id=nodes><a href=#nodes>#</a> nodes: total {{ len .Nodes }}</h2>
-		<ul>
-		{{- range .Nodes }}
-			<li><b>{{ .Name }}</b> (<code>{{ .ID }}</code>):
-				{{ marshal .Node }}
-			</li>
-		{{- end }}
-		</ul>
+		</table>
	</body>
</html>
`
+var debugTempl = template.Must(template.New("coordinator_debug").Parse(coordinatorDebugTmpl))
+
+func SendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case c <- a:
+ return nil
+ }
+}
+
+func RecvCtx[A any](ctx context.Context, c <-chan A) (a A, err error) {
+ select {
+ case <-ctx.Done():
+ return a, ctx.Err()
+ case a, ok := <-c:
+ if ok {
+ return a, nil
+ }
+ return a, io.EOF
+ }
+}
+
+func v1ReqLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger,
+ conn net.Conn, reqs chan<- *proto.CoordinateRequest,
+) {
+ defer close(reqs)
+ defer cancel()
+ decoder := json.NewDecoder(conn)
+ for {
+ var node Node
+ err := decoder.Decode(&node)
+ if err != nil {
+ if xerrors.Is(err, io.EOF) ||
+ xerrors.Is(err, io.ErrClosedPipe) ||
+ xerrors.Is(err, context.Canceled) ||
+ xerrors.Is(err, context.DeadlineExceeded) ||
+ websocket.CloseStatus(err) > 0 {
+ logger.Debug(ctx, "v1ReqLoop exiting", slog.Error(err))
+ } else {
+ logger.Info(ctx, "v1ReqLoop failed to decode Node update", slog.Error(err))
+ }
+ return
+ }
+ logger.Debug(ctx, "v1ReqLoop got node update", slog.F("node", node))
+ pn, err := NodeToProto(&node)
+ if err != nil {
+ logger.Critical(ctx, "v1ReqLoop failed to convert v1 node", slog.F("node", node), slog.Error(err))
+ return
+ }
+ req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
+ Node: pn,
+ }}
+ if err := SendCtx(ctx, reqs, req); err != nil {
+ logger.Debug(ctx, "v1ReqLoop ctx expired", slog.Error(err))
+ return
+ }
+ }
+}
+
+func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) {
+ defer cancel()
+ for {
+ resp, err := RecvCtx(ctx, resps)
+ if err != nil {
+ logger.Debug(ctx, "v1RespLoop done reading responses", slog.Error(err))
+ return
+ }
+ logger.Debug(ctx, "v1RespLoop got response", slog.F("resp", resp))
+ nodes, err := OnlyNodeUpdates(resp)
+ if err != nil {
+ logger.Critical(ctx, "v1RespLoop failed to decode resp", slog.F("resp", resp), slog.Error(err))
+ _ = q.CoordinatorClose()
+ return
+ }
+ // don't send empty updates
+ if len(nodes) == 0 {
+ logger.Debug(ctx, "v1RespLoop skipping enqueueing 0-length v1 update")
+ continue
+ }
+ err = q.Enqueue(nodes)
+ if err != nil {
+ logger.Error(ctx, "v1RespLoop failed to enqueue v1 update", slog.Error(err))
+ }
+ }
+}
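
SendCtx and RecvCtx are exported so that other coordinator implementations can reuse the same context-aware channel plumbing that v1ReqLoop and v1RespLoop rely on. A minimal, self-contained sketch of their semantics, assuming nothing beyond the signatures added above; the explicit [int] instantiation is only there because the sketch uses a bidirectional channel:

package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"time"

	"github.com/coder/coder/v2/tailnet"
)

func main() {
	ch := make(chan int, 1)
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()

	// The buffered send completes immediately.
	fmt.Println(tailnet.SendCtx[int](ctx, ch, 1)) // <nil>

	// The buffer is now full, so this send blocks until the context expires.
	err := tailnet.SendCtx[int](ctx, ch, 2)
	fmt.Println(errors.Is(err, context.DeadlineExceeded)) // true

	// Receiving drains the buffered value; a closed, empty channel yields io.EOF.
	v, _ := tailnet.RecvCtx[int](context.Background(), ch)
	fmt.Println(v) // 1
	close(ch)
	_, err = tailnet.RecvCtx[int](context.Background(), ch)
	fmt.Println(errors.Is(err, io.EOF)) // true
}
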
diff --git a/tailnet/coordinator_internal_test.go b/tailnet/coordinator_internal_test.go
new file mode 100644
index 0000000000..2344bf2723
--- /dev/null
+++ b/tailnet/coordinator_internal_test.go
@@ -0,0 +1,74 @@
+package tailnet
+
+import (
+ "bytes"
+ "flag"
+ "os"
+ "path/filepath"
+ "runtime"
+ "testing"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/require"
+)
+
+// UpdateGoldenFiles indicates golden files should be updated.
+// To update the golden files:
+// make update-golden-files
+var UpdateGoldenFiles = flag.Bool("update", false, "update .golden files")
+
+func TestDebugTemplate(t *testing.T) {
+ t.Parallel()
+ if runtime.GOOS == "windows" {
+ t.Skip("newlines screw up golden files on windows")
+ }
+ p1 := uuid.MustParse("01000000-2222-2222-2222-222222222222")
+ p2 := uuid.MustParse("02000000-2222-2222-2222-222222222222")
+ in := HTMLDebug{
+ Peers: []HTMLPeer{
+ {
+ Name: "Peer 1",
+ ID: p1,
+ LastWriteAge: 5 * time.Second,
+ Node: `id:1 preferred_derp:999 endpoints:"192.168.0.49:4449"`,
+ CreatedAge: 87 * time.Second,
+ Overwrites: 0,
+ },
+ {
+ Name: "Peer 2",
+ ID: p2,
+ LastWriteAge: 7 * time.Second,
+ Node: `id:2 preferred_derp:999 endpoints:"192.168.0.33:4449"`,
+ CreatedAge: time.Hour,
+ Overwrites: 2,
+ },
+ },
+ Tunnels: []HTMLTunnel{
+ {
+ Src: p1,
+ Dst: p2,
+ },
+ },
+ }
+ buf := new(bytes.Buffer)
+ err := debugTempl.Execute(buf, in)
+ require.NoError(t, err)
+ actual := buf.Bytes()
+
+ goldenPath := filepath.Join("testdata", "debug.golden.html")
+ if *UpdateGoldenFiles {
+ t.Logf("update golden file %s", goldenPath)
+ err := os.WriteFile(goldenPath, actual, 0o600)
+ require.NoError(t, err, "update golden file")
+ }
+
+ expected, err := os.ReadFile(goldenPath)
+ require.NoError(t, err, "read golden file, run \"make update-golden-files\" and commit the changes")
+
+ require.Equal(
+ t, string(expected), string(actual),
+ "golden file mismatch: %s, run \"make update-golden-files\", verify and commit the changes",
+ goldenPath,
+ )
+}
diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go
index 1435af3688..278f4b32ca 100644
--- a/tailnet/coordinator_test.go
+++ b/tailnet/coordinator_test.go
@@ -19,6 +19,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/tailnet"
+ "github.com/coder/coder/v2/tailnet/test"
"github.com/coder/coder/v2/testutil"
)
@@ -27,7 +28,12 @@ func TestCoordinator(t *testing.T) {
t.Run("ClientWithoutAgent", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+ ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
+ defer func() {
+ err := coordinator.Close()
+ require.NoError(t, err)
+ }()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
@@ -45,14 +51,19 @@ func TestCoordinator(t *testing.T) {
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, client.Close())
require.NoError(t, server.Close())
- <-errChan
- <-closeChan
+ _ = testutil.RequireRecvCtx(ctx, t, errChan)
+ _ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
t.Run("AgentWithoutClients", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+ ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
+ defer func() {
+ err := coordinator.Close()
+ require.NoError(t, err)
+ }()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
@@ -70,14 +81,18 @@ func TestCoordinator(t *testing.T) {
}, testutil.WaitShort, testutil.IntervalFast)
err := client.Close()
require.NoError(t, err)
- <-errChan
- <-closeChan
+ _ = testutil.RequireRecvCtx(ctx, t, errChan)
+ _ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
t.Run("AgentWithClient", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
+ defer func() {
+ err := coordinator.Close()
+ require.NoError(t, err)
+ }()
// in this test we use real websockets to test use of deadlines
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
@@ -116,14 +131,11 @@ func TestCoordinator(t *testing.T) {
assert.NoError(t, err)
close(closeClientChan)
}()
- select {
- case agentNodes := <-clientNodeChan:
- require.Len(t, agentNodes, 1)
- case <-ctx.Done():
- t.Fatal("timed out")
- }
+ agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan)
+ require.Len(t, agentNodes, 1)
+
sendClientNode(&tailnet.Node{PreferredDERP: 2})
- clientNodes := <-agentNodeChan
+ clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan)
require.Len(t, clientNodes, 1)
// wait longer than the internal wait timeout.
@@ -132,18 +144,14 @@ func TestCoordinator(t *testing.T) {
// Ensure an update to the agent node reaches the client!
sendAgentNode(&tailnet.Node{PreferredDERP: 3})
- select {
- case agentNodes := <-clientNodeChan:
- require.Len(t, agentNodes, 1)
- case <-ctx.Done():
- t.Fatal("timed out")
- }
+ agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan)
+ require.Len(t, agentNodes, 1)
// Close the agent WebSocket so a new one can connect.
err := agentWS.Close()
require.NoError(t, err)
- <-agentErrChan
- <-closeAgentChan
+ _ = testutil.RequireRecvCtx(ctx, t, agentErrChan)
+ _ = testutil.RequireRecvCtx(ctx, t, closeAgentChan)
// Create a new agent connection. This is to simulate a reconnect!
agentWS, agentServerWS = net.Pipe()
@@ -159,30 +167,32 @@ func TestCoordinator(t *testing.T) {
assert.NoError(t, err)
close(closeAgentChan)
}()
- // Ensure the existing listening client sends it's node immediately!
- clientNodes = <-agentNodeChan
+ // Ensure the existing listening client sends its node immediately!
+ clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan)
require.Len(t, clientNodes, 1)
err = agentWS.Close()
require.NoError(t, err)
- <-agentErrChan
- <-closeAgentChan
+ _ = testutil.RequireRecvCtx(ctx, t, agentErrChan)
+ _ = testutil.RequireRecvCtx(ctx, t, closeAgentChan)
err = clientWS.Close()
require.NoError(t, err)
- <-clientErrChan
- <-closeClientChan
+ _ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
+ _ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
})
t.Run("AgentDoubleConnect", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
+ ctx := testutil.Context(t, testutil.WaitLong)
agentWS1, agentServerWS1 := net.Pipe()
defer agentWS1.Close()
agentNodeChan1 := make(chan []*tailnet.Node)
sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error {
+ t.Logf("agent1 got node update: %v", nodes)
agentNodeChan1 <- nodes
return nil
})
@@ -203,6 +213,7 @@ func TestCoordinator(t *testing.T) {
defer clientServerWS.Close()
clientNodeChan := make(chan []*tailnet.Node)
sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
+ t.Logf("client got node update: %v", nodes)
clientNodeChan <- nodes
return nil
})
@@ -213,15 +224,15 @@ func TestCoordinator(t *testing.T) {
assert.NoError(t, err)
close(closeClientChan)
}()
- agentNodes := <-clientNodeChan
+ agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
sendClientNode(&tailnet.Node{PreferredDERP: 2})
- clientNodes := <-agentNodeChan1
+ clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan1)
require.Len(t, clientNodes, 1)
// Ensure an update to the agent node reaches the client!
sendAgentNode1(&tailnet.Node{PreferredDERP: 3})
- agentNodes = <-clientNodeChan
+ agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
// Create a new agent connection without disconnecting the old one.
@@ -229,6 +240,7 @@ func TestCoordinator(t *testing.T) {
defer agentWS2.Close()
agentNodeChan2 := make(chan []*tailnet.Node)
_, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error {
+ t.Logf("agent2 got node update: %v", nodes)
agentNodeChan2 <- nodes
return nil
})
@@ -240,33 +252,22 @@ func TestCoordinator(t *testing.T) {
}()
// Ensure the existing listening client sends it's node immediately!
- clientNodes = <-agentNodeChan2
+ clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan2)
require.Len(t, clientNodes, 1)
- counts, ok := coordinator.(interface {
- NodeCount() int
- AgentCount() int
- })
- if !ok {
- t.Fatal("coordinator should have NodeCount() and AgentCount()")
- }
-
- assert.Equal(t, 2, counts.NodeCount())
- assert.Equal(t, 1, counts.AgentCount())
+ // This original agent websocket should've been closed forcefully.
+ _ = testutil.RequireRecvCtx(ctx, t, agentErrChan1)
+ _ = testutil.RequireRecvCtx(ctx, t, closeAgentChan1)
err := agentWS2.Close()
require.NoError(t, err)
- <-agentErrChan2
- <-closeAgentChan2
+ _ = testutil.RequireRecvCtx(ctx, t, agentErrChan2)
+ _ = testutil.RequireRecvCtx(ctx, t, closeAgentChan2)
err = clientWS.Close()
require.NoError(t, err)
- <-clientErrChan
- <-closeClientChan
-
- // This original agent websocket should've been closed forcefully.
- <-agentErrChan1
- <-closeAgentChan1
+ _ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
+ _ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
})
}
@@ -334,9 +335,8 @@ func TestCoordinator_AgentUpdateWhileClientConnects(t *testing.T) {
require.NoError(t, err)
n, err = clientWS.Read(buf[1:])
require.NoError(t, err)
- require.Equal(t, len(buf)-1, n)
var cNodes []*tailnet.Node
- err = json.Unmarshal(buf, &cNodes)
+ err = json.Unmarshal(buf[:n+1], &cNodes)
require.NoError(t, err)
require.Len(t, cNodes, 1)
require.Equal(t, 0, cNodes[0].PreferredDERP)
@@ -348,13 +348,36 @@ func TestCoordinator_AgentUpdateWhileClientConnects(t *testing.T) {
require.NoError(t, err)
n, err = clientWS.Read(buf)
require.NoError(t, err)
- require.Equal(t, len(buf), n)
- err = json.Unmarshal(buf, &cNodes)
+ err = json.Unmarshal(buf[:n], &cNodes)
require.NoError(t, err)
require.Len(t, cNodes, 1)
require.Equal(t, 1, cNodes[0].PreferredDERP)
}
+func TestCoordinator_BidirectionalTunnels(t *testing.T) {
+ t.Parallel()
+ logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+ coordinator := tailnet.NewCoordinatorV2(logger)
+ ctx := testutil.Context(t, testutil.WaitShort)
+ test.BidirectionalTunnels(ctx, t, coordinator)
+}
+
+func TestCoordinator_GracefulDisconnect(t *testing.T) {
+ t.Parallel()
+ logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+ coordinator := tailnet.NewCoordinatorV2(logger)
+ ctx := testutil.Context(t, testutil.WaitShort)
+ test.GracefulDisconnectTest(ctx, t, coordinator)
+}
+
+func TestCoordinator_Lost(t *testing.T) {
+ t.Parallel()
+ logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+ coordinator := tailnet.NewCoordinatorV2(logger)
+ ctx := testutil.Context(t, testutil.WaitShort)
+ test.LostTest(ctx, t, coordinator)
+}
+
func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server net.Conn) {
t.Helper()
sc := make(chan net.Conn, 1)
diff --git a/tailnet/peer.go b/tailnet/peer.go
new file mode 100644
index 0000000000..e51aadaeac
--- /dev/null
+++ b/tailnet/peer.go
@@ -0,0 +1,160 @@
+package tailnet
+
+import (
+ "context"
+ "time"
+
+ "golang.org/x/xerrors"
+
+ "github.com/google/uuid"
+
+ "cdr.dev/slog"
+
+ "github.com/coder/coder/v2/tailnet/proto"
+)
+
+type peer struct {
+ logger slog.Logger
+ id uuid.UUID
+ node *proto.Node
+ resps chan<- *proto.CoordinateResponse
+ reqs <-chan *proto.CoordinateRequest
+ auth TunnelAuth
+ sent map[uuid.UUID]*proto.Node
+
+ name string
+ start time.Time
+ lastWrite time.Time
+ overwrites int
+}
+
+// updateMappingLocked updates the mapping for another peer linked to this one by a tunnel. This method
+// is NOT threadsafe and must be called while holding the core lock.
+func (p *peer) updateMappingLocked(id uuid.UUID, n *proto.Node, k proto.CoordinateResponse_PeerUpdate_Kind, reason string) error {
+ logger := p.logger.With(slog.F("from_id", id), slog.F("kind", k), slog.F("reason", reason))
+ update, err := p.storeMappingLocked(id, n, k, reason)
+ if xerrors.Is(err, noResp) {
+ logger.Debug(context.Background(), "skipping update")
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+
+ req := &proto.CoordinateResponse{PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{update}}
+ select {
+ case p.resps <- req:
+ p.lastWrite = time.Now()
+ logger.Debug(context.Background(), "wrote peer update")
+ return nil
+ default:
+ return ErrWouldBlock
+ }
+}
+
+// batchUpdateMappingLocked updates the mappings for a list of peers linked to this one by a tunnel. This
+// method is NOT threadsafe and must be called while holding the core lock.
+func (p *peer) batchUpdateMappingLocked(others []*peer, k proto.CoordinateResponse_PeerUpdate_Kind, reason string) error {
+ req := &proto.CoordinateResponse{}
+ for _, other := range others {
+ if other == nil || other.node == nil {
+ continue
+ }
+ update, err := p.storeMappingLocked(other.id, other.node, k, reason)
+ if xerrors.Is(err, noResp) {
+ continue
+ }
+ if err != nil {
+ return err
+ }
+ req.PeerUpdates = append(req.PeerUpdates, update)
+ }
+ if len(req.PeerUpdates) == 0 {
+ return nil
+ }
+ select {
+ case p.resps <- req:
+ p.lastWrite = time.Now()
+ p.logger.Debug(context.Background(), "wrote batched update", slog.F("num_peer_updates", len(req.PeerUpdates)))
+ return nil
+ default:
+ return ErrWouldBlock
+ }
+}
+
+var noResp = xerrors.New("no response needed")
+
+func (p *peer) storeMappingLocked(
+ id uuid.UUID, n *proto.Node, k proto.CoordinateResponse_PeerUpdate_Kind, reason string,
+) (
+ *proto.CoordinateResponse_PeerUpdate, error,
+) {
+ p.logger.Debug(context.Background(), "got updated mapping",
+ slog.F("from_id", id), slog.F("kind", k), slog.F("reason", reason))
+ sn, ok := p.sent[id]
+ switch {
+ case !ok && (k == proto.CoordinateResponse_PeerUpdate_LOST || k == proto.CoordinateResponse_PeerUpdate_DISCONNECTED):
+ // we don't need to send a lost/disconnect update if we've never sent an update about this peer
+ return nil, noResp
+ case !ok && k == proto.CoordinateResponse_PeerUpdate_NODE:
+ p.sent[id] = n
+ case ok && k == proto.CoordinateResponse_PeerUpdate_LOST:
+ delete(p.sent, id)
+ case ok && k == proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
+ delete(p.sent, id)
+ case ok && k == proto.CoordinateResponse_PeerUpdate_NODE:
+ eq, err := sn.Equal(n)
+ if err != nil {
+ p.logger.Critical(context.Background(), "failed to compare nodes", slog.F("old", sn), slog.F("new", n))
+ return nil, xerrors.Errorf("failed to compare nodes: %s", sn.String())
+ }
+ if eq {
+ return nil, noResp
+ }
+ p.sent[id] = n
+ }
+ return &proto.CoordinateResponse_PeerUpdate{
+ Uuid: id[:],
+ Kind: k,
+ Node: n,
+ Reason: reason,
+ }, nil
+}
+
+func (p *peer) reqLoop(ctx context.Context, logger slog.Logger, handler func(*peer, *proto.CoordinateRequest) error) {
+ for {
+ select {
+ case <-ctx.Done():
+ logger.Debug(ctx, "peerReadLoop context done")
+ return
+ case req, ok := <-p.reqs:
+ if !ok {
+ logger.Debug(ctx, "peerReadLoop channel closed")
+ return
+ }
+ logger.Debug(ctx, "peerReadLoop got request")
+ if err := handler(p, req); err != nil {
+ if xerrors.Is(err, ErrAlreadyRemoved) || xerrors.Is(err, ErrClosed) {
+ return
+ }
+ logger.Error(ctx, "peerReadLoop error handling request", slog.Error(err), slog.F("request", req))
+ return
+ }
+ }
+ }
+}
+
+func (p *peer) htmlDebug() HTMLPeer {
+ node := ""
+ if p.node != nil {
+ node = p.node.String()
+ }
+ return HTMLPeer{
+ ID: p.id,
+ Name: p.name,
+ CreatedAge: time.Since(p.start),
+ LastWriteAge: time.Since(p.lastWrite),
+ Overwrites: p.overwrites,
+ Node: node,
+ }
+}
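
storeMappingLocked is the piece that keeps a peer from re-sending state it has already delivered: LOST and DISCONNECTED updates for peers it never advertised are suppressed, and NODE updates are suppressed when the node is unchanged. A rough internal-test sketch of those rules (not part of this change; it assumes the zero-value slog.Logger is a safe no-op for the Debug calls on this path):

package tailnet

import (
	"testing"

	"github.com/google/uuid"
	"github.com/stretchr/testify/require"
	"golang.org/x/xerrors"

	"github.com/coder/coder/v2/tailnet/proto"
)

func TestStoreMappingLocked_Dedupe(t *testing.T) {
	t.Parallel()
	p := &peer{sent: make(map[uuid.UUID]*proto.Node)}
	other := uuid.New()

	// A LOST update about a peer we never told this peer about is suppressed.
	_, err := p.storeMappingLocked(other, nil, proto.CoordinateResponse_PeerUpdate_LOST, "test")
	require.True(t, xerrors.Is(err, noResp))

	// The first NODE update is passed through and remembered in p.sent.
	up, err := p.storeMappingLocked(other, &proto.Node{PreferredDerp: 1}, proto.CoordinateResponse_PeerUpdate_NODE, "test")
	require.NoError(t, err)
	require.NotNil(t, up)

	// An equal NODE update is suppressed.
	_, err = p.storeMappingLocked(other, &proto.Node{PreferredDerp: 1}, proto.CoordinateResponse_PeerUpdate_NODE, "test")
	require.True(t, xerrors.Is(err, noResp))

	// A DISCONNECTED update clears the remembered state and is passed through.
	up, err = p.storeMappingLocked(other, nil, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, "test")
	require.NoError(t, err)
	require.Equal(t, proto.CoordinateResponse_PeerUpdate_DISCONNECTED, up.Kind)
}
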
diff --git a/tailnet/test/cases.go b/tailnet/test/cases.go
new file mode 100644
index 0000000000..a54c1e9320
--- /dev/null
+++ b/tailnet/test/cases.go
@@ -0,0 +1,55 @@
+package test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/coder/coder/v2/tailnet"
+)
+
+func GracefulDisconnectTest(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
+ p1 := NewPeer(ctx, t, coordinator, "p1")
+ defer p1.Close(ctx)
+ p2 := NewPeer(ctx, t, coordinator, "p2")
+ defer p2.Close(ctx)
+ p1.AddTunnel(p2.ID)
+ p1.UpdateDERP(1)
+ p2.UpdateDERP(2)
+
+ p1.AssertEventuallyHasDERP(p2.ID, 2)
+ p2.AssertEventuallyHasDERP(p1.ID, 1)
+
+ p2.Disconnect()
+ p1.AssertEventuallyDisconnected(p2.ID)
+ p2.AssertEventuallyResponsesClosed()
+}
+
+func LostTest(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
+ p1 := NewPeer(ctx, t, coordinator, "p1")
+ defer p1.Close(ctx)
+ p2 := NewPeer(ctx, t, coordinator, "p2")
+ defer p2.Close(ctx)
+ p1.AddTunnel(p2.ID)
+ p1.UpdateDERP(1)
+ p2.UpdateDERP(2)
+
+ p1.AssertEventuallyHasDERP(p2.ID, 2)
+ p2.AssertEventuallyHasDERP(p1.ID, 1)
+
+ p2.Close(ctx)
+ p1.AssertEventuallyLost(p2.ID)
+}
+
+func BidirectionalTunnels(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
+ p1 := NewPeer(ctx, t, coordinator, "p1")
+ defer p1.Close(ctx)
+ p2 := NewPeer(ctx, t, coordinator, "p2")
+ defer p2.Close(ctx)
+ p1.AddTunnel(p2.ID)
+ p2.AddTunnel(p1.ID)
+ p1.UpdateDERP(1)
+ p2.UpdateDERP(2)
+
+ p1.AssertEventuallyHasDERP(p2.ID, 2)
+ p2.AssertEventuallyHasDERP(p1.ID, 1)
+}
diff --git a/tailnet/test/peer.go b/tailnet/test/peer.go
new file mode 100644
index 0000000000..dde979f813
--- /dev/null
+++ b/tailnet/test/peer.go
@@ -0,0 +1,184 @@
+package test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+ "golang.org/x/xerrors"
+
+ "github.com/coder/coder/v2/tailnet"
+ "github.com/coder/coder/v2/tailnet/proto"
+)
+
+type PeerStatus struct {
+ preferredDERP int32
+ status proto.CoordinateResponse_PeerUpdate_Kind
+}
+
+type Peer struct {
+ ctx context.Context
+ cancel context.CancelFunc
+ t testing.TB
+ ID uuid.UUID
+ name string
+ resps <-chan *proto.CoordinateResponse
+ reqs chan<- *proto.CoordinateRequest
+ peers map[uuid.UUID]PeerStatus
+}
+
+func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, id ...uuid.UUID) *Peer {
+ p := &Peer{t: t, name: name, peers: make(map[uuid.UUID]PeerStatus)}
+ p.ctx, p.cancel = context.WithCancel(ctx)
+ if len(id) > 1 {
+ t.Fatal("too many")
+ }
+ if len(id) == 1 {
+ p.ID = id[0]
+ } else {
+ p.ID = uuid.New()
+ }
+ // SingleTailnetTunnelAuth allows connections to arbitrary peers
+ p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, tailnet.SingleTailnetTunnelAuth{})
+ return p
+}
+
+func (p *Peer) AddTunnel(other uuid.UUID) {
+ p.t.Helper()
+ req := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: tailnet.UUIDToByteSlice(other)}}
+ select {
+ case <-p.ctx.Done():
+ p.t.Errorf("timeout adding tunnel for %s", p.name)
+ return
+ case p.reqs <- req:
+ return
+ }
+}
+
+func (p *Peer) UpdateDERP(derp int32) {
+ p.t.Helper()
+ req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: derp}}}
+ select {
+ case <-p.ctx.Done():
+ p.t.Errorf("timeout updating node for %s", p.name)
+ return
+ case p.reqs <- req:
+ return
+ }
+}
+
+func (p *Peer) Disconnect() {
+ p.t.Helper()
+ req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
+ select {
+ case <-p.ctx.Done():
+		p.t.Errorf("timeout sending disconnect for %s", p.name)
+ return
+ case p.reqs <- req:
+ return
+ }
+}
+
+func (p *Peer) AssertEventuallyHasDERP(other uuid.UUID, derp int32) {
+ p.t.Helper()
+ for {
+ o, ok := p.peers[other]
+ if ok && o.preferredDERP == derp {
+ return
+ }
+ if err := p.handleOneResp(); err != nil {
+ assert.NoError(p.t, err)
+ return
+ }
+ }
+}
+
+func (p *Peer) AssertEventuallyDisconnected(other uuid.UUID) {
+ p.t.Helper()
+ for {
+ _, ok := p.peers[other]
+ if !ok {
+ return
+ }
+ if err := p.handleOneResp(); err != nil {
+ assert.NoError(p.t, err)
+ return
+ }
+ }
+}
+
+func (p *Peer) AssertEventuallyLost(other uuid.UUID) {
+ p.t.Helper()
+ for {
+ o := p.peers[other]
+ if o.status == proto.CoordinateResponse_PeerUpdate_LOST {
+ return
+ }
+ if err := p.handleOneResp(); err != nil {
+ assert.NoError(p.t, err)
+ return
+ }
+ }
+}
+
+func (p *Peer) AssertEventuallyResponsesClosed() {
+ p.t.Helper()
+ for {
+ err := p.handleOneResp()
+ if xerrors.Is(err, responsesClosed) {
+ return
+ }
+ if !assert.NoError(p.t, err) {
+ return
+ }
+ }
+}
+
+var responsesClosed = xerrors.New("responses closed")
+
+func (p *Peer) handleOneResp() error {
+ select {
+ case <-p.ctx.Done():
+ return p.ctx.Err()
+ case resp, ok := <-p.resps:
+ if !ok {
+ return responsesClosed
+ }
+ for _, update := range resp.PeerUpdates {
+ id, err := uuid.FromBytes(update.Uuid)
+ if err != nil {
+ return err
+ }
+ switch update.Kind {
+ case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
+ p.peers[id] = PeerStatus{
+ preferredDERP: update.GetNode().GetPreferredDerp(),
+ status: update.Kind,
+ }
+ case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
+ delete(p.peers, id)
+ default:
+ return xerrors.Errorf("unhandled update kind %s", update.Kind)
+ }
+ }
+ }
+ return nil
+}
+
+func (p *Peer) Close(ctx context.Context) {
+ p.t.Helper()
+ p.cancel()
+ for {
+ select {
+ case <-ctx.Done():
+ p.t.Errorf("timeout waiting for responses to close for %s", p.name)
+ return
+ case _, ok := <-p.resps:
+ if ok {
+ continue
+ }
+ return
+ }
+ }
+}
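
The helper above composes with NewCoordinatorV2 the same way cases.go does. A hypothetical consumer test sketching the optional pinned-ID argument to NewPeer, which keeps IDs stable across runs (the test name and package placement are illustrative, not part of this change):

package tailnet_test

import (
	"testing"

	"github.com/google/uuid"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/test"
	"github.com/coder/coder/v2/testutil"
)

func TestPeerHelper_PinnedID(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	coordinator := tailnet.NewCoordinatorV2(logger)
	ctx := testutil.Context(t, testutil.WaitShort)

	// Pinning p2's UUID keeps assertions and debug output deterministic.
	p2ID := uuid.MustParse("02000000-2222-2222-2222-222222222222")
	p1 := test.NewPeer(ctx, t, coordinator, "p1")
	defer p1.Close(ctx)
	p2 := test.NewPeer(ctx, t, coordinator, "p2", p2ID)
	defer p2.Close(ctx)

	p1.AddTunnel(p2ID)
	p1.UpdateDERP(1)
	p2.UpdateDERP(2)
	p1.AssertEventuallyHasDERP(p2ID, 2)
	p2.AssertEventuallyHasDERP(p1.ID, 1)
}
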
diff --git a/tailnet/testdata/debug.golden.html b/tailnet/testdata/debug.golden.html
new file mode 100644
index 0000000000..bbf4656ec4
--- /dev/null
+++ b/tailnet/testdata/debug.golden.html
@@ -0,0 +1,62 @@
+
+<!DOCTYPE html>
+<html>
+	<head>
+		<meta charset="UTF-8">
+	</head>
+	<body>
+
+		<h1>in-memory wireguard coordinator debug</h1>
+
+		<h2 id=peers><a href=#peers>#</a> peers: total 2</h2>
+		<table>
+			<tr>
+				<th>Name</th>
+				<th>ID</th>
+				<th>Created Age</th>
+				<th>Last Write Age</th>
+				<th>Overwrites</th>
+				<th>Node</th>
+			</tr>
+			<tr>
+				<td>Peer 1</td>
+				<td>01000000-2222-2222-2222-222222222222</td>
+				<td>1m27s</td>
+				<td>5s ago</td>
+				<td>0</td>
+				<td>id:1 preferred_derp:999 endpoints:&#34;192.168.0.49:4449&#34;</td>
+			</tr>
+			<tr>
+				<td>Peer 2</td>
+				<td>02000000-2222-2222-2222-222222222222</td>
+				<td>1h0m0s</td>
+				<td>7s ago</td>
+				<td>2</td>
+				<td>id:2 preferred_derp:999 endpoints:&#34;192.168.0.33:4449&#34;</td>
+			</tr>
+		</table>
+
+		<h2 id=tunnels><a href=#tunnels>#</a> tunnels: total 1</h2>
+		<table>
+			<tr>
+				<th>SrcID</th>
+				<th>DstID</th>
+			</tr>
+			<tr>
+				<td>01000000-2222-2222-2222-222222222222</td>
+				<td>02000000-2222-2222-2222-222222222222</td>
+			</tr>
+		</table>
+	</body>
+</html>
diff --git a/tailnet/trackedconn.go b/tailnet/trackedconn.go
index d083c838b2..3b3feaa132 100644
--- a/tailnet/trackedconn.go
+++ b/tailnet/trackedconn.go
@@ -20,6 +20,8 @@ const (
// ResponseBufferSize is the max number of responses to buffer per connection before we start
// dropping updates
ResponseBufferSize = 512
+ // RequestBufferSize is the max number of requests to buffer per connection
+ RequestBufferSize = 32
)
type TrackedConn struct {
diff --git a/tailnet/tunnel.go b/tailnet/tunnel.go
index 19f4a485dc..6fe36ee419 100644
--- a/tailnet/tunnel.go
+++ b/tailnet/tunnel.go
@@ -28,3 +28,74 @@ type AgentTunnelAuth struct{}
func (AgentTunnelAuth) Authorize(uuid.UUID) bool {
return false
}
+
+// tunnelStore contains tunnel information and allows querying it. It is not threadsafe and all
+// methods must be serialized by holding, e.g. the core mutex.
+type tunnelStore struct {
+ bySrc map[uuid.UUID]map[uuid.UUID]struct{}
+ byDst map[uuid.UUID]map[uuid.UUID]struct{}
+}
+
+func newTunnelStore() *tunnelStore {
+ return &tunnelStore{
+ bySrc: make(map[uuid.UUID]map[uuid.UUID]struct{}),
+ byDst: make(map[uuid.UUID]map[uuid.UUID]struct{}),
+ }
+}
+
+func (s *tunnelStore) add(src, dst uuid.UUID) {
+ srcM, ok := s.bySrc[src]
+ if !ok {
+ srcM = make(map[uuid.UUID]struct{})
+ s.bySrc[src] = srcM
+ }
+ srcM[dst] = struct{}{}
+ dstM, ok := s.byDst[dst]
+ if !ok {
+ dstM = make(map[uuid.UUID]struct{})
+ s.byDst[dst] = dstM
+ }
+ dstM[src] = struct{}{}
+}
+
+func (s *tunnelStore) remove(src, dst uuid.UUID) {
+ delete(s.bySrc[src], dst)
+ if len(s.bySrc[src]) == 0 {
+ delete(s.bySrc, src)
+ }
+ delete(s.byDst[dst], src)
+ if len(s.byDst[dst]) == 0 {
+ delete(s.byDst, dst)
+ }
+}
+
+func (s *tunnelStore) removeAll(src uuid.UUID) {
+ for dst := range s.bySrc[src] {
+ s.remove(src, dst)
+ }
+}
+
+func (s *tunnelStore) findTunnelPeers(id uuid.UUID) []uuid.UUID {
+ set := make(map[uuid.UUID]struct{})
+ for dst := range s.bySrc[id] {
+ set[dst] = struct{}{}
+ }
+ for src := range s.byDst[id] {
+ set[src] = struct{}{}
+ }
+ out := make([]uuid.UUID, 0, len(set))
+ for id := range set {
+ out = append(out, id)
+ }
+ return out
+}
+
+func (s *tunnelStore) htmlDebug() []HTMLTunnel {
+ out := make([]HTMLTunnel, 0)
+ for src, dsts := range s.bySrc {
+ for dst := range dsts {
+ out = append(out, HTMLTunnel{Src: src, Dst: dst})
+ }
+ }
+ return out
+}
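
findTunnelPeers is what the coordinator core uses to work out which peers must hear about a departing peer. A rough sketch of that fan-out, under the assumption of a peers map keyed by UUID; notifyLostLocked is an illustrative name, not the actual core method:

package tailnet

import (
	"github.com/google/uuid"

	"github.com/coder/coder/v2/tailnet/proto"
)

// notifyLostLocked fans a LOST update out to every peer that shares a tunnel
// with the departing peer. Callers must hold the core mutex, mirroring the
// *Locked convention used above.
func notifyLostLocked(tunnels *tunnelStore, peers map[uuid.UUID]*peer, departing uuid.UUID) {
	gone, ok := peers[departing]
	if !ok || gone.node == nil {
		// Nothing was ever advertised for this peer, so there is nothing to lose.
		return
	}
	for _, id := range tunnels.findTunnelPeers(departing) {
		other, ok := peers[id]
		if !ok {
			continue
		}
		// The real core must decide what to do with a peer whose response
		// queue is full (ErrWouldBlock); it is ignored here for brevity.
		_ = other.updateMappingLocked(departing, gone.node,
			proto.CoordinateResponse_PeerUpdate_LOST, "departed")
	}
}
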
diff --git a/tailnet/tunnel_internal_test.go b/tailnet/tunnel_internal_test.go
new file mode 100644
index 0000000000..3ba7cc4165
--- /dev/null
+++ b/tailnet/tunnel_internal_test.go
@@ -0,0 +1,45 @@
+package tailnet
+
+import (
+ "testing"
+
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/require"
+)
+
+func TestTunnelStore_Bidir(t *testing.T) {
+ t.Parallel()
+ p1 := uuid.MustParse("00000001-1111-1111-1111-111111111111")
+ p2 := uuid.MustParse("00000002-1111-1111-1111-111111111111")
+ uut := newTunnelStore()
+ uut.add(p1, p2)
+ require.Equal(t, []uuid.UUID{p1}, uut.findTunnelPeers(p2))
+ require.Equal(t, []uuid.UUID{p2}, uut.findTunnelPeers(p1))
+ uut.remove(p1, p2)
+ require.Empty(t, uut.findTunnelPeers(p1))
+ require.Empty(t, uut.findTunnelPeers(p2))
+ require.Len(t, uut.byDst, 0)
+ require.Len(t, uut.bySrc, 0)
+}
+
+func TestTunnelStore_RemoveAll(t *testing.T) {
+ t.Parallel()
+ p1 := uuid.MustParse("00000001-1111-1111-1111-111111111111")
+ p2 := uuid.MustParse("00000002-1111-1111-1111-111111111111")
+ p3 := uuid.MustParse("00000003-1111-1111-1111-111111111111")
+ uut := newTunnelStore()
+ uut.add(p1, p2)
+ uut.add(p1, p3)
+ uut.add(p3, p1)
+ require.Len(t, uut.findTunnelPeers(p1), 2)
+ require.Len(t, uut.findTunnelPeers(p2), 1)
+ require.Len(t, uut.findTunnelPeers(p3), 1)
+ uut.removeAll(p1)
+ require.Len(t, uut.findTunnelPeers(p1), 1)
+ require.Len(t, uut.findTunnelPeers(p2), 0)
+ require.Len(t, uut.findTunnelPeers(p3), 1)
+ uut.removeAll(p3)
+ require.Len(t, uut.findTunnelPeers(p1), 0)
+ require.Len(t, uut.findTunnelPeers(p2), 0)
+ require.Len(t, uut.findTunnelPeers(p3), 0)
+}