feat: support v2 Tailnet API in AGPL coordinator (#11010)

Fixes #10529
Spike Curtis, 2023-12-06 15:04:28 +04:00 (committed by GitHub)
parent 38ed816207
commit 2c86d0bed0
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in the database)
19 changed files with 1509 additions and 1065 deletions
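For context, here is a minimal sketch of how a caller drives the v2 Coordinate API that this commit wires into the AGPL coordinator. It is modeled on the tailnet/test.Peer helper added later in this diff; the standalone main wrapper, logger setup, and peer IDs are illustrative assumptions rather than code from the commit.

```go
package main

import (
	"context"
	"os"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/sloghuman"
	"github.com/google/uuid"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	logger := slog.Make(sloghuman.Sink(os.Stderr))
	coord := tailnet.NewCoordinatorV2(logger)
	defer coord.Close()

	// Open a coordination stream. SingleTailnetTunnelAuth allows tunnels to arbitrary peers.
	id := uuid.New()
	reqs, resps := coord.Coordinate(ctx, id, "example", tailnet.SingleTailnetTunnelAuth{})

	// Advertise our own node and request a tunnel to another peer.
	other := uuid.New() // hypothetical peer ID, for illustration only
	reqs <- &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
		Node: &proto.Node{PreferredDerp: 1},
	}}
	reqs <- &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{
		Uuid: tailnet.UUIDToByteSlice(other),
	}}

	// Responses stream NODE, LOST, and DISCONNECTED updates for tunneled peers.
	for resp := range resps {
		for _, up := range resp.PeerUpdates {
			logger.Info(ctx, "peer update", slog.F("kind", up.Kind))
		}
	}
}
```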


@@ -83,6 +83,7 @@ helm/**/templates/*.yaml
# Testdata shouldn't be formatted.
scripts/apitypings/testdata/**/*.ts
enterprise/tailnet/testdata/*.golden.html
tailnet/testdata/*.golden.html
# Generated files shouldn't be formatted.
site/e2e/provisionerGenerated.ts


@@ -9,6 +9,7 @@ helm/**/templates/*.yaml
# Testdata shouldn't be formatted.
scripts/apitypings/testdata/**/*.ts
enterprise/tailnet/testdata/*.golden.html
tailnet/testdata/*.golden.html
# Generated files shouldn't be formatted.
site/e2e/provisionerGenerated.ts


@@ -602,6 +602,7 @@ update-golden-files: \
scripts/ci-report/testdata/.gen-golden \
enterprise/cli/testdata/.gen-golden \
enterprise/tailnet/testdata/.gen-golden \
tailnet/testdata/.gen-golden \
coderd/.gen-golden \
provisioner/terraform/testdata/.gen-golden
.PHONY: update-golden-files
@@ -614,6 +615,10 @@ enterprise/cli/testdata/.gen-golden: $(wildcard enterprise/cli/testdata/*.golden
go test ./enterprise/cli -run="TestEnterpriseCommandHelp" -update
touch "$@"
tailnet/testdata/.gen-golden: $(wildcard tailnet/testdata/*.golden.html) $(GO_SRC_FILES) $(wildcard tailnet/*_test.go)
go test ./tailnet -run="TestDebugTemplate" -update
touch "$@"
enterprise/tailnet/testdata/.gen-golden: $(wildcard enterprise/tailnet/testdata/*.golden.html) $(GO_SRC_FILES) $(wildcard enterprise/tailnet/*_test.go)
go test ./enterprise/tailnet -run="TestDebugTemplate" -update
touch "$@"


@@ -86,7 +86,7 @@ func (c *connIO) recvLoop() {
if c.disconnected {
b.kind = proto.CoordinateResponse_PeerUpdate_DISCONNECTED
}
if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
if err := agpl.SendCtx(c.coordCtx, c.bindings, b); err != nil {
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err))
}
// only remove tunnels on graceful disconnect. If we remove tunnels for lost peers, then
@@ -97,14 +97,14 @@ func (c *connIO) recvLoop() {
tKey: tKey{src: c.UniqueID()},
active: false,
}
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
if err := agpl.SendCtx(c.coordCtx, c.tunnels, t); err != nil {
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
}
}
}()
defer c.Close()
for {
req, err := recvCtx(c.peerCtx, c.requests)
req, err := agpl.RecvCtx(c.peerCtx, c.requests)
if err != nil {
if xerrors.Is(err, context.Canceled) ||
xerrors.Is(err, context.DeadlineExceeded) ||
@@ -132,7 +132,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
node: req.UpdateSelf.Node,
kind: proto.CoordinateResponse_PeerUpdate_NODE,
}
if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
if err := agpl.SendCtx(c.coordCtx, c.bindings, b); err != nil {
c.logger.Debug(c.peerCtx, "failed to send binding", slog.Error(err))
return err
}
@@ -156,7 +156,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
},
active: true,
}
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
if err := agpl.SendCtx(c.coordCtx, c.tunnels, t); err != nil {
c.logger.Debug(c.peerCtx, "failed to send add tunnel", slog.Error(err))
return err
}
@@ -177,7 +177,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
},
active: false,
}
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
if err := agpl.SendCtx(c.coordCtx, c.tunnels, t); err != nil {
c.logger.Debug(c.peerCtx, "failed to send remove tunnel", slog.Error(err))
return err
}
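Note: agpl.SendCtx and agpl.RecvCtx above are exported replacements for the unexported sendCtx/recvCtx helpers this commit deletes from the Postgres coordinator (visible further down in this diff). A minimal sketch of their presumable shape, mirroring those removed helpers:

```go
package tailnet

import (
	"context"
	"io"
)

// SendCtx sends a value on the channel unless ctx expires first.
func SendCtx[A any](ctx context.Context, c chan<- A, a A) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case c <- a:
		return nil
	}
}

// RecvCtx receives a value from the channel unless ctx expires first;
// a closed channel is reported as io.EOF.
func RecvCtx[A any](ctx context.Context, c <-chan A) (a A, err error) {
	select {
	case <-ctx.Done():
		return a, ctx.Err()
	case a, ok := <-c:
		if ok {
			return a, nil
		}
		return a, io.EOF
	}
}
```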


@@ -5,17 +5,22 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"html/template"
"io"
"net"
"net/http"
"sync"
"time"
"github.com/google/uuid"
lru "github.com/hashicorp/golang-lru/v2"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
agpl "github.com/coder/coder/v2/tailnet"
)
@@ -719,7 +724,209 @@ func (c *haCoordinator) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
c.mutex.RLock()
defer c.mutex.RUnlock()
agpl.CoordinatorHTTPDebug(
agpl.HTTPDebugFromLocal(true, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache),
CoordinatorHTTPDebug(
HTTPDebugFromLocal(true, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache),
)(w, r)
}
func HTTPDebugFromLocal(
ha bool,
agentSocketsMap map[uuid.UUID]agpl.Queue,
agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]agpl.Queue,
nodesMap map[uuid.UUID]*agpl.Node,
agentNameCache *lru.Cache[uuid.UUID, string],
) HTMLDebugHA {
now := time.Now()
data := HTMLDebugHA{HA: ha}
for id, conn := range agentSocketsMap {
start, lastWrite := conn.Stats()
agent := &HTMLAgent{
Name: conn.Name(),
ID: id,
CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second),
LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
Overwrites: int(conn.Overwrites()),
}
for id, conn := range agentToConnectionSocketsMap[id] {
start, lastWrite := conn.Stats()
agent.Connections = append(agent.Connections, &HTMLClient{
Name: conn.Name(),
ID: id,
CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second),
LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
})
}
slices.SortFunc(agent.Connections, func(a, b *HTMLClient) int {
return slice.Ascending(a.Name, b.Name)
})
data.Agents = append(data.Agents, agent)
}
slices.SortFunc(data.Agents, func(a, b *HTMLAgent) int {
return slice.Ascending(a.Name, b.Name)
})
for agentID, conns := range agentToConnectionSocketsMap {
if len(conns) == 0 {
continue
}
if _, ok := agentSocketsMap[agentID]; ok {
continue
}
agentName, ok := agentNameCache.Get(agentID)
if !ok {
agentName = "unknown"
}
agent := &HTMLAgent{
Name: agentName,
ID: agentID,
}
for id, conn := range conns {
start, lastWrite := conn.Stats()
agent.Connections = append(agent.Connections, &HTMLClient{
Name: conn.Name(),
ID: id,
CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second),
LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
})
}
slices.SortFunc(agent.Connections, func(a, b *HTMLClient) int {
return slice.Ascending(a.Name, b.Name)
})
data.MissingAgents = append(data.MissingAgents, agent)
}
slices.SortFunc(data.MissingAgents, func(a, b *HTMLAgent) int {
return slice.Ascending(a.Name, b.Name)
})
for id, node := range nodesMap {
name, _ := agentNameCache.Get(id)
data.Nodes = append(data.Nodes, &HTMLNode{
ID: id,
Name: name,
Node: node,
})
}
slices.SortFunc(data.Nodes, func(a, b *HTMLNode) int {
return slice.Ascending(a.Name+a.ID.String(), b.Name+b.ID.String())
})
return data
}
func CoordinatorHTTPDebug(data HTMLDebugHA) func(w http.ResponseWriter, _ *http.Request) {
return func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
tmpl, err := template.New("coordinator_debug").Funcs(template.FuncMap{
"marshal": func(v any) template.JS {
a, err := json.MarshalIndent(v, "", " ")
if err != nil {
//nolint:gosec
return template.JS(fmt.Sprintf(`{"err": %q}`, err))
}
//nolint:gosec
return template.JS(a)
},
}).Parse(haCoordinatorDebugTmpl)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(err.Error()))
return
}
err = tmpl.Execute(w, data)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(err.Error()))
return
}
}
}
type HTMLDebugHA struct {
HA bool
Agents []*HTMLAgent
MissingAgents []*HTMLAgent
Nodes []*HTMLNode
}
type HTMLAgent struct {
Name string
ID uuid.UUID
CreatedAge time.Duration
LastWriteAge time.Duration
Overwrites int
Connections []*HTMLClient
}
type HTMLClient struct {
Name string
ID uuid.UUID
CreatedAge time.Duration
LastWriteAge time.Duration
}
type HTMLNode struct {
ID uuid.UUID
Name string
Node any
}
var haCoordinatorDebugTmpl = `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
</head>
<body>
{{- if .HA }}
<h1>high-availability wireguard coordinator debug</h1>
<h4 style="margin-top:-25px">warning: this only provides info from the node that served the request, if there are multiple replicas this data may be incomplete</h4>
{{- else }}
<h1>in-memory wireguard coordinator debug</h1>
{{- end }}
<h2 id=agents> <a href=#agents>#</a> agents: total {{ len .Agents }} </h2>
<ul>
{{- range .Agents }}
<li style="margin-top:4px">
<b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created {{ .CreatedAge }} ago, write {{ .LastWriteAge }} ago, overwrites {{ .Overwrites }}
<h3 style="margin:0px;font-size:16px;font-weight:400"> connections: total {{ len .Connections}} </h3>
<ul>
{{- range .Connections }}
<li><b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created {{ .CreatedAge }} ago, write {{ .LastWriteAge }} ago </li>
{{- end }}
</ul>
</li>
{{- end }}
</ul>
<h2 id=missing-agents><a href=#missing-agents>#</a> missing agents: total {{ len .MissingAgents }}</h2>
<ul>
{{- range .MissingAgents}}
<li style="margin-top:4px"><b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created ? ago, write ? ago, overwrites ? </li>
<h3 style="margin:0px;font-size:16px;font-weight:400"> connections: total {{ len .Connections }} </h3>
<ul>
{{- range .Connections }}
<li><b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created {{ .CreatedAge }} ago, write {{ .LastWriteAge }} ago </li>
{{- end }}
</ul>
{{- end }}
</ul>
<h2 id=nodes><a href=#nodes>#</a> nodes: total {{ len .Nodes }}</h2>
<ul>
{{- range .Nodes }}
<li style="margin-top:4px"><b>{{ .Name }}</b> (<code>{{ .ID }}</code>):
<span style="white-space: pre;"><code>{{ marshal .Node }}</code></span>
</li>
{{- end }}
</ul>
</body>
</html>
`
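The debug handler above renders coordinator state with html/template and exposes a marshal function that pretty-prints arbitrary values as JSON inside the page. A self-contained sketch of that pattern follows; the template text and sample data are illustrative, not taken from the commit:

```go
package main

import (
	"encoding/json"
	"fmt"
	"html/template"
	"os"
)

func main() {
	tmpl := template.Must(template.New("debug").Funcs(template.FuncMap{
		"marshal": func(v any) template.JS {
			b, err := json.MarshalIndent(v, "", "  ")
			if err != nil {
				// Fall back to a small JSON object describing the error.
				return template.JS(fmt.Sprintf(`{"err": %q}`, err))
			}
			return template.JS(b)
		},
	}).Parse(`<span style="white-space: pre;"><code>{{ marshal . }}</code></span>` + "\n"))

	// Render an arbitrary value as indented JSON inside the HTML output.
	_ = tmpl.Execute(os.Stdout, map[string]any{"id": 1, "preferred_derp": 999})
}
```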


@@ -3,17 +3,12 @@ package tailnet
import (
"context"
"database/sql"
"encoding/json"
"io"
"net"
"net/netip"
"strings"
"sync"
"sync/atomic"
"time"
"nhooyr.io/websocket"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/cenkalti/backoff/v4"
@@ -30,17 +25,16 @@ import (
)
const (
EventHeartbeats = "tailnet_coordinator_heartbeat"
eventPeerUpdate = "tailnet_peer_update"
eventTunnelUpdate = "tailnet_tunnel_update"
HeartbeatPeriod = time.Second * 2
MissedHeartbeats = 3
numQuerierWorkers = 10
numBinderWorkers = 10
numTunnelerWorkers = 10
dbMaxBackoff = 10 * time.Second
cleanupPeriod = time.Hour
requestResponseBuffSize = 32
EventHeartbeats = "tailnet_coordinator_heartbeat"
eventPeerUpdate = "tailnet_peer_update"
eventTunnelUpdate = "tailnet_tunnel_update"
HeartbeatPeriod = time.Second * 2
MissedHeartbeats = 3
numQuerierWorkers = 10
numBinderWorkers = 10
numTunnelerWorkers = 10
dbMaxBackoff = 10 * time.Second
cleanupPeriod = time.Hour
)
// pgCoord is a postgres-backed coordinator
@@ -161,55 +155,8 @@ func NewPGCoordV2(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, sto
return newPGCoordInternal(ctx, logger, ps, store)
}
// This is copied from codersdk because importing it here would cause an import
// cycle. This is just temporary until wsconncache is phased out.
var legacyAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")
func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
logger := c.logger.With(slog.F("client_id", id)).Named("multiagent")
ctx, cancel := context.WithCancel(c.ctx)
reqs, resps := c.Coordinate(ctx, id, id.String(), agpl.SingleTailnetTunnelAuth{})
ma := (&agpl.MultiAgent{
ID: id,
AgentIsLegacyFunc: func(agentID uuid.UUID) bool {
if n := c.Node(agentID); n == nil {
// If we don't have the node at all assume it's legacy for
// safety.
return true
} else if len(n.Addresses) > 0 && n.Addresses[0].Addr() == legacyAgentIP {
// An agent is determined to be "legacy" if it's first IP is the
// legacy IP. Agents with only the legacy IP aren't compatible
// with single_tailnet and must be routed through wsconncache.
return true
} else {
return false
}
},
OnSubscribe: func(enq agpl.Queue, agent uuid.UUID) (*agpl.Node, error) {
err := sendCtx(ctx, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)}})
return c.Node(agent), err
},
OnUnsubscribe: func(enq agpl.Queue, agent uuid.UUID) error {
err := sendCtx(ctx, reqs, &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)}})
return err
},
OnNodeUpdate: func(id uuid.UUID, node *agpl.Node) error {
pn, err := agpl.NodeToProto(node)
if err != nil {
return err
}
return sendCtx(c.ctx, reqs, &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
Node: pn,
}})
},
OnRemove: func(_ agpl.Queue) {
_ = sendCtx(c.ctx, reqs, &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
cancel()
},
}).Init()
go v1SendLoop(ctx, cancel, logger, ma, resps)
return ma
return agpl.ServeMultiAgent(c, c.logger, id)
}
func (c *pgCoord) Node(id uuid.UUID) *agpl.Node {
@@ -253,116 +200,11 @@ func (c *pgCoord) Node(id uuid.UUID) *agpl.Node {
}
func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
logger := c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent))
defer func() {
err := conn.Close()
if err != nil {
logger.Debug(c.ctx, "closing client connection", slog.Error(err))
}
}()
ctx, cancel := context.WithCancel(c.ctx)
defer cancel()
reqs, resps := c.Coordinate(ctx, id, id.String(), agpl.ClientTunnelAuth{AgentID: agent})
err := sendCtx(ctx, reqs, &proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)},
})
if err != nil {
// can only be a context error, no need to log here.
return err
}
defer func() {
_ = sendCtx(ctx, reqs, &proto.CoordinateRequest{
RemoveTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)},
})
}()
tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, agpl.QueueKindClient)
go tc.SendUpdates()
go v1SendLoop(ctx, cancel, logger, tc, resps)
go v1RecvLoop(ctx, cancel, logger, conn, reqs)
<-ctx.Done()
return nil
return agpl.ServeClientV1(c.ctx, c.logger, c, conn, id, agent)
}
func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error {
logger := c.logger.With(slog.F("agent_id", id), slog.F("name", name))
defer func() {
logger.Debug(c.ctx, "closing agent connection")
err := conn.Close()
logger.Debug(c.ctx, "closed agent connection", slog.Error(err))
}()
ctx, cancel := context.WithCancel(c.ctx)
defer cancel()
reqs, resps := c.Coordinate(ctx, id, name, agpl.AgentTunnelAuth{})
tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, agpl.QueueKindAgent)
go tc.SendUpdates()
go v1SendLoop(ctx, cancel, logger, tc, resps)
go v1RecvLoop(ctx, cancel, logger, conn, reqs)
<-ctx.Done()
return nil
}
func v1RecvLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger,
conn net.Conn, reqs chan<- *proto.CoordinateRequest,
) {
defer cancel()
decoder := json.NewDecoder(conn)
for {
var node agpl.Node
err := decoder.Decode(&node)
if err != nil {
if xerrors.Is(err, io.EOF) ||
xerrors.Is(err, io.ErrClosedPipe) ||
xerrors.Is(err, context.Canceled) ||
xerrors.Is(err, context.DeadlineExceeded) ||
websocket.CloseStatus(err) > 0 {
logger.Debug(ctx, "exiting recvLoop", slog.Error(err))
} else {
logger.Error(ctx, "failed to decode Node update", slog.Error(err))
}
return
}
logger.Debug(ctx, "got node update", slog.F("node", node))
pn, err := agpl.NodeToProto(&node)
if err != nil {
logger.Critical(ctx, "failed to convert v1 node", slog.F("node", node), slog.Error(err))
return
}
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
Node: pn,
}}
if err := sendCtx(ctx, reqs, req); err != nil {
logger.Debug(ctx, "recvLoop ctx expired", slog.Error(err))
return
}
}
}
func v1SendLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q agpl.Queue, resps <-chan *proto.CoordinateResponse) {
defer cancel()
for {
resp, err := recvCtx(ctx, resps)
if err != nil {
logger.Debug(ctx, "done reading responses", slog.Error(err))
return
}
logger.Debug(ctx, "v1: got response", slog.F("resp", resp))
nodes, err := agpl.OnlyNodeUpdates(resp)
if err != nil {
logger.Critical(ctx, "failed to decode resp", slog.F("resp", resp), slog.Error(err))
_ = q.CoordinatorClose()
return
}
// don't send empty updates
if len(nodes) == 0 {
logger.Debug(ctx, "skipping enqueueing 0-length v1 update")
continue
}
err = q.Enqueue(nodes)
if err != nil {
logger.Error(ctx, "failed to enqueue v1 update", slog.Error(err))
}
}
return agpl.ServeAgentV1(c.ctx, c.logger, c, conn, id, name)
}
func (c *pgCoord) Close() error {
@@ -378,43 +220,22 @@ func (c *pgCoord) Coordinate(
chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse,
) {
logger := c.logger.With(slog.F("peer_id", id))
reqs := make(chan *proto.CoordinateRequest, requestResponseBuffSize)
reqs := make(chan *proto.CoordinateRequest, agpl.RequestBufferSize)
resps := make(chan *proto.CoordinateResponse, agpl.ResponseBufferSize)
cIO := newConnIO(c.ctx, ctx, logger, c.bindings, c.tunnelerCh, reqs, resps, id, name, a)
err := sendCtx(c.ctx, c.newConnections, cIO)
err := agpl.SendCtx(c.ctx, c.newConnections, cIO)
if err != nil {
// this can only happen if the context is canceled, no need to log
return reqs, resps
}
go func() {
<-cIO.Done()
_ = sendCtx(c.ctx, c.closeConnections, cIO)
_ = agpl.SendCtx(c.ctx, c.closeConnections, cIO)
}()
return reqs, resps
}
func sendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) {
select {
case <-ctx.Done():
return ctx.Err()
case c <- a:
return nil
}
}
func recvCtx[A any](ctx context.Context, c <-chan A) (a A, err error) {
select {
case <-ctx.Done():
return a, ctx.Err()
case a, ok := <-c:
if ok {
return a, nil
}
return a, io.EOF
}
}
type tKey struct {
src uuid.UUID
dst uuid.UUID
@@ -873,7 +694,7 @@ func (m *mapper) bestToUpdate(best map[uuid.UUID]mapping) *proto.CoordinateRespo
case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_NODE && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
eq, err := sm.node.Equal(mpng.node)
if err != nil {
m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sm.node), slog.F("new", mpng.kind))
m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sm.node), slog.F("new", mpng.node))
continue
}
if eq {
@@ -1303,7 +1124,7 @@ func (q *querier) updateAll() {
go func(m *mapper) {
// make sure we send on the _mapper_ context, not our own in case the mapper is
// shutting down or shut down.
_ = sendCtx(m.ctx, m.update, struct{}{})
_ = agpl.SendCtx(m.ctx, m.update, struct{}{})
}(mpr)
}
}
@@ -1603,7 +1424,7 @@ func (h *heartbeats) recvBeat(id uuid.UUID) {
h.logger.Info(h.ctx, "heartbeats (re)started", slog.F("other_coordinator_id", id))
// send on a separate goroutine to avoid holding lock. Triggering update can be async
go func() {
_ = sendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
_ = agpl.SendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
}()
}
h.coordinators[id] = time.Now()
@@ -1650,7 +1471,7 @@ func (h *heartbeats) checkExpiry() {
if expired {
// send on a separate goroutine to avoid holding lock. Triggering update can be async
go func() {
_ = sendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
_ = agpl.SendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
}()
}
// we need to reset the timer for when the next oldest coordinator will expire, if any.
@@ -1685,14 +1506,14 @@ func (h *heartbeats) sendBeat() {
h.failedHeartbeats++
if h.failedHeartbeats == 3 {
h.logger.Error(h.ctx, "coordinator failed 3 heartbeats and is unhealthy")
_ = sendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateUnhealthy})
_ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateUnhealthy})
}
return
}
h.logger.Debug(h.ctx, "sent heartbeat")
if h.failedHeartbeats >= 3 {
h.logger.Info(h.ctx, "coordinator sent heartbeat and is healthy")
_ = sendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy})
_ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy})
}
h.failedHeartbeats = 0
}


@@ -9,6 +9,8 @@ import (
"testing"
"time"
agpltest "github.com/coder/coder/v2/tailnet/test"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -612,18 +614,7 @@ func TestPGCoordinator_BidirectionalTunnels(t *testing.T) {
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
p1 := newTestPeer(ctx, t, coordinator, "p1")
defer p1.close(ctx)
p2 := newTestPeer(ctx, t, coordinator, "p2")
defer p2.close(ctx)
p1.addTunnel(p2.id)
p2.addTunnel(p1.id)
p1.updateDERP(1)
p2.updateDERP(2)
p1.assertEventuallyHasDERP(p2.id, 2)
p2.assertEventuallyHasDERP(p1.id, 1)
agpltest.BidirectionalTunnels(ctx, t, coordinator)
}
func TestPGCoordinator_GracefulDisconnect(t *testing.T) {
@@ -638,21 +629,7 @@ func TestPGCoordinator_GracefulDisconnect(t *testing.T) {
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
p1 := newTestPeer(ctx, t, coordinator, "p1")
defer p1.close(ctx)
p2 := newTestPeer(ctx, t, coordinator, "p2")
defer p2.close(ctx)
p1.addTunnel(p2.id)
p1.updateDERP(1)
p2.updateDERP(2)
p1.assertEventuallyHasDERP(p2.id, 2)
p2.assertEventuallyHasDERP(p1.id, 1)
p2.disconnect()
p1.assertEventuallyDisconnected(p2.id)
p2.assertEventuallyResponsesClosed()
agpltest.GracefulDisconnectTest(ctx, t, coordinator)
}
func TestPGCoordinator_Lost(t *testing.T) {
@@ -667,20 +644,7 @@ func TestPGCoordinator_Lost(t *testing.T) {
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
p1 := newTestPeer(ctx, t, coordinator, "p1")
defer p1.close(ctx)
p2 := newTestPeer(ctx, t, coordinator, "p2")
defer p2.close(ctx)
p1.addTunnel(p2.id)
p1.updateDERP(1)
p2.updateDERP(2)
p1.assertEventuallyHasDERP(p2.id, 2)
p2.assertEventuallyHasDERP(p1.id, 1)
p2.close(ctx)
p1.assertEventuallyLost(p2.id)
agpltest.LostTest(ctx, t, coordinator)
}
type testConn struct {
@@ -918,177 +882,6 @@ func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store
}, testutil.WaitShort, testutil.IntervalFast)
}
type peerStatus struct {
preferredDERP int32
status proto.CoordinateResponse_PeerUpdate_Kind
}
type testPeer struct {
ctx context.Context
cancel context.CancelFunc
t testing.TB
id uuid.UUID
name string
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
peers map[uuid.UUID]peerStatus
}
func newTestPeer(ctx context.Context, t testing.TB, coord agpl.CoordinatorV2, name string, id ...uuid.UUID) *testPeer {
p := &testPeer{t: t, name: name, peers: make(map[uuid.UUID]peerStatus)}
p.ctx, p.cancel = context.WithCancel(ctx)
if len(id) > 1 {
t.Fatal("too many")
}
if len(id) == 1 {
p.id = id[0]
} else {
p.id = uuid.New()
}
// SingleTailnetTunnelAuth allows connections to arbitrary peers
p.reqs, p.resps = coord.Coordinate(p.ctx, p.id, name, agpl.SingleTailnetTunnelAuth{})
return p
}
func (p *testPeer) addTunnel(other uuid.UUID) {
p.t.Helper()
req := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(other)}}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout adding tunnel for %s", p.name)
return
case p.reqs <- req:
return
}
}
func (p *testPeer) updateDERP(derp int32) {
p.t.Helper()
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: derp}}}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout updating node for %s", p.name)
return
case p.reqs <- req:
return
}
}
func (p *testPeer) disconnect() {
p.t.Helper()
req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout updating node for %s", p.name)
return
case p.reqs <- req:
return
}
}
func (p *testPeer) assertEventuallyHasDERP(other uuid.UUID, derp int32) {
p.t.Helper()
for {
o, ok := p.peers[other]
if ok && o.preferredDERP == derp {
return
}
if err := p.handleOneResp(); err != nil {
assert.NoError(p.t, err)
return
}
}
}
func (p *testPeer) assertEventuallyDisconnected(other uuid.UUID) {
p.t.Helper()
for {
_, ok := p.peers[other]
if !ok {
return
}
if err := p.handleOneResp(); err != nil {
assert.NoError(p.t, err)
return
}
}
}
func (p *testPeer) assertEventuallyLost(other uuid.UUID) {
p.t.Helper()
for {
o := p.peers[other]
if o.status == proto.CoordinateResponse_PeerUpdate_LOST {
return
}
if err := p.handleOneResp(); err != nil {
assert.NoError(p.t, err)
return
}
}
}
func (p *testPeer) assertEventuallyResponsesClosed() {
p.t.Helper()
for {
err := p.handleOneResp()
if xerrors.Is(err, responsesClosed) {
return
}
if !assert.NoError(p.t, err) {
return
}
}
}
var responsesClosed = xerrors.New("responses closed")
func (p *testPeer) handleOneResp() error {
select {
case <-p.ctx.Done():
return p.ctx.Err()
case resp, ok := <-p.resps:
if !ok {
return responsesClosed
}
for _, update := range resp.PeerUpdates {
id, err := uuid.FromBytes(update.Uuid)
if err != nil {
return err
}
switch update.Kind {
case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
p.peers[id] = peerStatus{
preferredDERP: update.GetNode().GetPreferredDerp(),
status: update.Kind,
}
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
delete(p.peers, id)
default:
return xerrors.Errorf("unhandled update kind %s", update.Kind)
}
}
}
return nil
}
func (p *testPeer) close(ctx context.Context) {
p.t.Helper()
p.cancel()
for {
select {
case <-ctx.Done():
p.t.Errorf("timeout waiting for responses to close for %s", p.name)
return
case _, ok := <-p.resps:
if ok {
continue
}
return
}
}
}
type fakeCoordinator struct {
ctx context.Context
t *testing.T


@@ -83,6 +83,7 @@ result
# Testdata shouldn't be formatted.
../scripts/apitypings/testdata/**/*.ts
../enterprise/tailnet/testdata/*.golden.html
../tailnet/testdata/*.golden.html
# Generated files shouldn't be formatted.
e2e/provisionerGenerated.ts


@@ -83,6 +83,7 @@ result
# Testdata shouldn't be formatted.
../scripts/apitypings/testdata/**/*.ts
../enterprise/tailnet/testdata/*.golden.html
../tailnet/testdata/*.golden.html
# Generated files shouldn't be formatted.
e2e/provisionerGenerated.ts

File diff suppressed because it is too large.


@@ -0,0 +1,74 @@
package tailnet
import (
"bytes"
"flag"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
// UpdateGoldenFiles indicates golden files should be updated.
// To update the golden files:
// make update-golden-files
var UpdateGoldenFiles = flag.Bool("update", false, "update .golden files")
func TestDebugTemplate(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("newlines screw up golden files on windows")
}
p1 := uuid.MustParse("01000000-2222-2222-2222-222222222222")
p2 := uuid.MustParse("02000000-2222-2222-2222-222222222222")
in := HTMLDebug{
Peers: []HTMLPeer{
{
Name: "Peer 1",
ID: p1,
LastWriteAge: 5 * time.Second,
Node: `id:1 preferred_derp:999 endpoints:"192.168.0.49:4449"`,
CreatedAge: 87 * time.Second,
Overwrites: 0,
},
{
Name: "Peer 2",
ID: p2,
LastWriteAge: 7 * time.Second,
Node: `id:2 preferred_derp:999 endpoints:"192.168.0.33:4449"`,
CreatedAge: time.Hour,
Overwrites: 2,
},
},
Tunnels: []HTMLTunnel{
{
Src: p1,
Dst: p2,
},
},
}
buf := new(bytes.Buffer)
err := debugTempl.Execute(buf, in)
require.NoError(t, err)
actual := buf.Bytes()
goldenPath := filepath.Join("testdata", "debug.golden.html")
if *UpdateGoldenFiles {
t.Logf("update golden file %s", goldenPath)
err := os.WriteFile(goldenPath, actual, 0o600)
require.NoError(t, err, "update golden file")
}
expected, err := os.ReadFile(goldenPath)
require.NoError(t, err, "read golden file, run \"make update-golden-files\" and commit the changes")
require.Equal(
t, string(expected), string(actual),
"golden file mismatch: %s, run \"make update-golden-files\", verify and commit the changes",
goldenPath,
)
}


@@ -19,6 +19,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/test"
"github.com/coder/coder/v2/testutil"
)
@@ -27,7 +28,12 @@ func TestCoordinator(t *testing.T) {
t.Run("ClientWithoutAgent", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
err := coordinator.Close()
require.NoError(t, err)
}()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
@@ -45,14 +51,19 @@
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, client.Close())
require.NoError(t, server.Close())
<-errChan
<-closeChan
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
t.Run("AgentWithoutClients", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
err := coordinator.Close()
require.NoError(t, err)
}()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
@@ -70,14 +81,18 @@ func TestCoordinator(t *testing.T) {
}, testutil.WaitShort, testutil.IntervalFast)
err := client.Close()
require.NoError(t, err)
<-errChan
<-closeChan
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
t.Run("AgentWithClient", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
err := coordinator.Close()
require.NoError(t, err)
}()
// in this test we use real websockets to test use of deadlines
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
@@ -116,14 +131,11 @@ func TestCoordinator(t *testing.T) {
assert.NoError(t, err)
close(closeClientChan)
}()
select {
case agentNodes := <-clientNodeChan:
require.Len(t, agentNodes, 1)
case <-ctx.Done():
t.Fatal("timed out")
}
agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
sendClientNode(&tailnet.Node{PreferredDERP: 2})
clientNodes := <-agentNodeChan
clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan)
require.Len(t, clientNodes, 1)
// wait longer than the internal wait timeout.
@@ -132,18 +144,14 @@ func TestCoordinator(t *testing.T) {
// Ensure an update to the agent node reaches the client!
sendAgentNode(&tailnet.Node{PreferredDERP: 3})
select {
case agentNodes := <-clientNodeChan:
require.Len(t, agentNodes, 1)
case <-ctx.Done():
t.Fatal("timed out")
}
agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
// Close the agent WebSocket so a new one can connect.
err := agentWS.Close()
require.NoError(t, err)
<-agentErrChan
<-closeAgentChan
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan)
// Create a new agent connection. This is to simulate a reconnect!
agentWS, agentServerWS = net.Pipe()
@@ -159,30 +167,32 @@ func TestCoordinator(t *testing.T) {
assert.NoError(t, err)
close(closeAgentChan)
}()
// Ensure the existing listening client sends it's node immediately!
clientNodes = <-agentNodeChan
// Ensure the existing listening client sends its node immediately!
clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan)
require.Len(t, clientNodes, 1)
err = agentWS.Close()
require.NoError(t, err)
<-agentErrChan
<-closeAgentChan
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan)
err = clientWS.Close()
require.NoError(t, err)
<-clientErrChan
<-closeClientChan
_ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
})
t.Run("AgentDoubleConnect", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
ctx := testutil.Context(t, testutil.WaitLong)
agentWS1, agentServerWS1 := net.Pipe()
defer agentWS1.Close()
agentNodeChan1 := make(chan []*tailnet.Node)
sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error {
t.Logf("agent1 got node update: %v", nodes)
agentNodeChan1 <- nodes
return nil
})
@@ -203,6 +213,7 @@ func TestCoordinator(t *testing.T) {
defer clientServerWS.Close()
clientNodeChan := make(chan []*tailnet.Node)
sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
t.Logf("client got node update: %v", nodes)
clientNodeChan <- nodes
return nil
})
@@ -213,15 +224,15 @@ func TestCoordinator(t *testing.T) {
assert.NoError(t, err)
close(closeClientChan)
}()
agentNodes := <-clientNodeChan
agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
sendClientNode(&tailnet.Node{PreferredDERP: 2})
clientNodes := <-agentNodeChan1
clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan1)
require.Len(t, clientNodes, 1)
// Ensure an update to the agent node reaches the client!
sendAgentNode1(&tailnet.Node{PreferredDERP: 3})
agentNodes = <-clientNodeChan
agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
// Create a new agent connection without disconnecting the old one.
@@ -229,6 +240,7 @@ func TestCoordinator(t *testing.T) {
defer agentWS2.Close()
agentNodeChan2 := make(chan []*tailnet.Node)
_, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error {
t.Logf("agent2 got node update: %v", nodes)
agentNodeChan2 <- nodes
return nil
})
@@ -240,33 +252,22 @@ func TestCoordinator(t *testing.T) {
}()
// Ensure the existing listening client sends its node immediately!
clientNodes = <-agentNodeChan2
clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan2)
require.Len(t, clientNodes, 1)
counts, ok := coordinator.(interface {
NodeCount() int
AgentCount() int
})
if !ok {
t.Fatal("coordinator should have NodeCount() and AgentCount()")
}
assert.Equal(t, 2, counts.NodeCount())
assert.Equal(t, 1, counts.AgentCount())
// This original agent websocket should've been closed forcefully.
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan1)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan1)
err := agentWS2.Close()
require.NoError(t, err)
<-agentErrChan2
<-closeAgentChan2
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan2)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan2)
err = clientWS.Close()
require.NoError(t, err)
<-clientErrChan
<-closeClientChan
// This original agent websocket should've been closed forcefully.
<-agentErrChan1
<-closeAgentChan1
_ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
})
}
@@ -334,9 +335,8 @@ func TestCoordinator_AgentUpdateWhileClientConnects(t *testing.T) {
require.NoError(t, err)
n, err = clientWS.Read(buf[1:])
require.NoError(t, err)
require.Equal(t, len(buf)-1, n)
var cNodes []*tailnet.Node
err = json.Unmarshal(buf, &cNodes)
err = json.Unmarshal(buf[:n+1], &cNodes)
require.NoError(t, err)
require.Len(t, cNodes, 1)
require.Equal(t, 0, cNodes[0].PreferredDERP)
@@ -348,13 +348,36 @@ func TestCoordinator_AgentUpdateWhileClientConnects(t *testing.T) {
require.NoError(t, err)
n, err = clientWS.Read(buf)
require.NoError(t, err)
require.Equal(t, len(buf), n)
err = json.Unmarshal(buf, &cNodes)
err = json.Unmarshal(buf[:n], &cNodes)
require.NoError(t, err)
require.Len(t, cNodes, 1)
require.Equal(t, 1, cNodes[0].PreferredDERP)
}
func TestCoordinator_BidirectionalTunnels(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinatorV2(logger)
ctx := testutil.Context(t, testutil.WaitShort)
test.BidirectionalTunnels(ctx, t, coordinator)
}
func TestCoordinator_GracefulDisconnect(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinatorV2(logger)
ctx := testutil.Context(t, testutil.WaitShort)
test.GracefulDisconnectTest(ctx, t, coordinator)
}
func TestCoordinator_Lost(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinatorV2(logger)
ctx := testutil.Context(t, testutil.WaitShort)
test.LostTest(ctx, t, coordinator)
}
func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server net.Conn) {
t.Helper()
sc := make(chan net.Conn, 1)

tailnet/peer.go (new file, 160 lines)

@@ -0,0 +1,160 @@
package tailnet
import (
"context"
"time"
"golang.org/x/xerrors"
"github.com/google/uuid"
"cdr.dev/slog"
"github.com/coder/coder/v2/tailnet/proto"
)
type peer struct {
logger slog.Logger
id uuid.UUID
node *proto.Node
resps chan<- *proto.CoordinateResponse
reqs <-chan *proto.CoordinateRequest
auth TunnelAuth
sent map[uuid.UUID]*proto.Node
name string
start time.Time
lastWrite time.Time
overwrites int
}
// updateMappingLocked updates the mapping for another peer linked to this one by a tunnel. This method
// is NOT threadsafe and must be called while holding the core lock.
func (p *peer) updateMappingLocked(id uuid.UUID, n *proto.Node, k proto.CoordinateResponse_PeerUpdate_Kind, reason string) error {
logger := p.logger.With(slog.F("from_id", id), slog.F("kind", k), slog.F("reason", reason))
update, err := p.storeMappingLocked(id, n, k, reason)
if xerrors.Is(err, noResp) {
logger.Debug(context.Background(), "skipping update")
return nil
}
if err != nil {
return err
}
req := &proto.CoordinateResponse{PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{update}}
select {
case p.resps <- req:
p.lastWrite = time.Now()
logger.Debug(context.Background(), "wrote peer update")
return nil
default:
return ErrWouldBlock
}
}
// batchUpdateMappingLocked updates the mappings for a list of peers linked to this one by a tunnel. This
// method is NOT threadsafe and must be called while holding the core lock.
func (p *peer) batchUpdateMappingLocked(others []*peer, k proto.CoordinateResponse_PeerUpdate_Kind, reason string) error {
req := &proto.CoordinateResponse{}
for _, other := range others {
if other == nil || other.node == nil {
continue
}
update, err := p.storeMappingLocked(other.id, other.node, k, reason)
if xerrors.Is(err, noResp) {
continue
}
if err != nil {
return err
}
req.PeerUpdates = append(req.PeerUpdates, update)
}
if len(req.PeerUpdates) == 0 {
return nil
}
select {
case p.resps <- req:
p.lastWrite = time.Now()
p.logger.Debug(context.Background(), "wrote batched update", slog.F("num_peer_updates", len(req.PeerUpdates)))
return nil
default:
return ErrWouldBlock
}
}
var noResp = xerrors.New("no response needed")
func (p *peer) storeMappingLocked(
id uuid.UUID, n *proto.Node, k proto.CoordinateResponse_PeerUpdate_Kind, reason string,
) (
*proto.CoordinateResponse_PeerUpdate, error,
) {
p.logger.Debug(context.Background(), "got updated mapping",
slog.F("from_id", id), slog.F("kind", k), slog.F("reason", reason))
sn, ok := p.sent[id]
switch {
case !ok && (k == proto.CoordinateResponse_PeerUpdate_LOST || k == proto.CoordinateResponse_PeerUpdate_DISCONNECTED):
// we don't need to send a lost/disconnect update if we've never sent an update about this peer
return nil, noResp
case !ok && k == proto.CoordinateResponse_PeerUpdate_NODE:
p.sent[id] = n
case ok && k == proto.CoordinateResponse_PeerUpdate_LOST:
delete(p.sent, id)
case ok && k == proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
delete(p.sent, id)
case ok && k == proto.CoordinateResponse_PeerUpdate_NODE:
eq, err := sn.Equal(n)
if err != nil {
p.logger.Critical(context.Background(), "failed to compare nodes", slog.F("old", sn), slog.F("new", n))
return nil, xerrors.Errorf("failed to compare nodes: %s", sn.String())
}
if eq {
return nil, noResp
}
p.sent[id] = n
}
return &proto.CoordinateResponse_PeerUpdate{
Uuid: id[:],
Kind: k,
Node: n,
Reason: reason,
}, nil
}
func (p *peer) reqLoop(ctx context.Context, logger slog.Logger, handler func(*peer, *proto.CoordinateRequest) error) {
for {
select {
case <-ctx.Done():
logger.Debug(ctx, "peerReadLoop context done")
return
case req, ok := <-p.reqs:
if !ok {
logger.Debug(ctx, "peerReadLoop channel closed")
return
}
logger.Debug(ctx, "peerReadLoop got request")
if err := handler(p, req); err != nil {
if xerrors.Is(err, ErrAlreadyRemoved) || xerrors.Is(err, ErrClosed) {
return
}
logger.Error(ctx, "peerReadLoop error handling request", slog.Error(err), slog.F("request", req))
return
}
}
}
}
func (p *peer) htmlDebug() HTMLPeer {
node := "<nil>"
if p.node != nil {
node = p.node.String()
}
return HTMLPeer{
ID: p.id,
Name: p.name,
CreatedAge: time.Since(p.start),
LastWriteAge: time.Since(p.lastWrite),
Overwrites: p.overwrites,
Node: node,
}
}

tailnet/test/cases.go (new file, 55 lines)

@@ -0,0 +1,55 @@
package test
import (
"context"
"testing"
"github.com/coder/coder/v2/tailnet"
)
func GracefulDisconnectTest(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
p1 := NewPeer(ctx, t, coordinator, "p1")
defer p1.Close(ctx)
p2 := NewPeer(ctx, t, coordinator, "p2")
defer p2.Close(ctx)
p1.AddTunnel(p2.ID)
p1.UpdateDERP(1)
p2.UpdateDERP(2)
p1.AssertEventuallyHasDERP(p2.ID, 2)
p2.AssertEventuallyHasDERP(p1.ID, 1)
p2.Disconnect()
p1.AssertEventuallyDisconnected(p2.ID)
p2.AssertEventuallyResponsesClosed()
}
func LostTest(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
p1 := NewPeer(ctx, t, coordinator, "p1")
defer p1.Close(ctx)
p2 := NewPeer(ctx, t, coordinator, "p2")
defer p2.Close(ctx)
p1.AddTunnel(p2.ID)
p1.UpdateDERP(1)
p2.UpdateDERP(2)
p1.AssertEventuallyHasDERP(p2.ID, 2)
p2.AssertEventuallyHasDERP(p1.ID, 1)
p2.Close(ctx)
p1.AssertEventuallyLost(p2.ID)
}
func BidirectionalTunnels(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
p1 := NewPeer(ctx, t, coordinator, "p1")
defer p1.Close(ctx)
p2 := NewPeer(ctx, t, coordinator, "p2")
defer p2.Close(ctx)
p1.AddTunnel(p2.ID)
p2.AddTunnel(p1.ID)
p1.UpdateDERP(1)
p2.UpdateDERP(2)
p1.AssertEventuallyHasDERP(p2.ID, 2)
p2.AssertEventuallyHasDERP(p1.ID, 1)
}

tailnet/test/peer.go (new file, 184 lines)

@@ -0,0 +1,184 @@
package test
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
)
type PeerStatus struct {
preferredDERP int32
status proto.CoordinateResponse_PeerUpdate_Kind
}
type Peer struct {
ctx context.Context
cancel context.CancelFunc
t testing.TB
ID uuid.UUID
name string
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
peers map[uuid.UUID]PeerStatus
}
func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, id ...uuid.UUID) *Peer {
p := &Peer{t: t, name: name, peers: make(map[uuid.UUID]PeerStatus)}
p.ctx, p.cancel = context.WithCancel(ctx)
if len(id) > 1 {
t.Fatal("too many")
}
if len(id) == 1 {
p.ID = id[0]
} else {
p.ID = uuid.New()
}
// SingleTailnetTunnelAuth allows connections to arbitrary peers
p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, tailnet.SingleTailnetTunnelAuth{})
return p
}
func (p *Peer) AddTunnel(other uuid.UUID) {
p.t.Helper()
req := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: tailnet.UUIDToByteSlice(other)}}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout adding tunnel for %s", p.name)
return
case p.reqs <- req:
return
}
}
func (p *Peer) UpdateDERP(derp int32) {
p.t.Helper()
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: derp}}}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout updating node for %s", p.name)
return
case p.reqs <- req:
return
}
}
func (p *Peer) Disconnect() {
p.t.Helper()
req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout updating node for %s", p.name)
return
case p.reqs <- req:
return
}
}
func (p *Peer) AssertEventuallyHasDERP(other uuid.UUID, derp int32) {
p.t.Helper()
for {
o, ok := p.peers[other]
if ok && o.preferredDERP == derp {
return
}
if err := p.handleOneResp(); err != nil {
assert.NoError(p.t, err)
return
}
}
}
func (p *Peer) AssertEventuallyDisconnected(other uuid.UUID) {
p.t.Helper()
for {
_, ok := p.peers[other]
if !ok {
return
}
if err := p.handleOneResp(); err != nil {
assert.NoError(p.t, err)
return
}
}
}
func (p *Peer) AssertEventuallyLost(other uuid.UUID) {
p.t.Helper()
for {
o := p.peers[other]
if o.status == proto.CoordinateResponse_PeerUpdate_LOST {
return
}
if err := p.handleOneResp(); err != nil {
assert.NoError(p.t, err)
return
}
}
}
func (p *Peer) AssertEventuallyResponsesClosed() {
p.t.Helper()
for {
err := p.handleOneResp()
if xerrors.Is(err, responsesClosed) {
return
}
if !assert.NoError(p.t, err) {
return
}
}
}
var responsesClosed = xerrors.New("responses closed")
func (p *Peer) handleOneResp() error {
select {
case <-p.ctx.Done():
return p.ctx.Err()
case resp, ok := <-p.resps:
if !ok {
return responsesClosed
}
for _, update := range resp.PeerUpdates {
id, err := uuid.FromBytes(update.Uuid)
if err != nil {
return err
}
switch update.Kind {
case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
p.peers[id] = PeerStatus{
preferredDERP: update.GetNode().GetPreferredDerp(),
status: update.Kind,
}
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
delete(p.peers, id)
default:
return xerrors.Errorf("unhandled update kind %s", update.Kind)
}
}
}
return nil
}
func (p *Peer) Close(ctx context.Context) {
p.t.Helper()
p.cancel()
for {
select {
case <-ctx.Done():
p.t.Errorf("timeout waiting for responses to close for %s", p.name)
return
case _, ok := <-p.resps:
if ok {
continue
}
return
}
}
}

tailnet/testdata/debug.golden.html (new vendored file, 62 lines)

@@ -0,0 +1,62 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
th, td {
padding-top: 6px;
padding-bottom: 6px;
padding-left: 10px;
padding-right: 10px;
text-align: left;
}
tr {
border-bottom: 1px solid #ddd;
}
</style>
</head>
<body>
<h1>in-memory wireguard coordinator debug</h1>
<h2 id=peers> <a href=#peers>#</a> peers: total 2 </h2>
<table>
<tr style="margin-top:4px">
<th>Name</th>
<th>ID</th>
<th>Created Age</th>
<th>Last Write Age</th>
<th>Overwrites</th>
<th>Node</th>
</tr>
<tr style="margin-top:4px">
<td>Peer 1</td>
<td>01000000-2222-2222-2222-222222222222</td>
<td>1m27s</td>
<td>5s ago</td>
<td>0</td>
<td style="white-space: pre;"><code>id:1 preferred_derp:999 endpoints:&#34;192.168.0.49:4449&#34;</code></td>
</tr>
<tr style="margin-top:4px">
<td>Peer 2</td>
<td>02000000-2222-2222-2222-222222222222</td>
<td>1h0m0s</td>
<td>7s ago</td>
<td>2</td>
<td style="white-space: pre;"><code>id:2 preferred_derp:999 endpoints:&#34;192.168.0.33:4449&#34;</code></td>
</tr>
</table>
<h2 id=tunnels><a href=#tunnels>#</a> tunnels: total 1</h2>
<table>
<tr style="margin-top:4px">
<th>SrcID</th>
<th>DstID</th>
</tr>
<tr style="margin-top:4px">
<td>01000000-2222-2222-2222-222222222222</td>
<td>02000000-2222-2222-2222-222222222222</td>
</tr>
</table>
</body>
</html>


@@ -20,6 +20,8 @@ const (
// ResponseBufferSize is the max number of responses to buffer per connection before we start
// dropping updates
ResponseBufferSize = 512
// RequestBufferSize is the max number of requests to buffer per connection
RequestBufferSize = 32
)
type TrackedConn struct {


@@ -28,3 +28,74 @@ type AgentTunnelAuth struct{}
func (AgentTunnelAuth) Authorize(uuid.UUID) bool {
return false
}
// tunnelStore contains tunnel information and allows querying it. It is not threadsafe and all
// methods must be serialized by holding, e.g. the core mutex.
type tunnelStore struct {
bySrc map[uuid.UUID]map[uuid.UUID]struct{}
byDst map[uuid.UUID]map[uuid.UUID]struct{}
}
func newTunnelStore() *tunnelStore {
return &tunnelStore{
bySrc: make(map[uuid.UUID]map[uuid.UUID]struct{}),
byDst: make(map[uuid.UUID]map[uuid.UUID]struct{}),
}
}
func (s *tunnelStore) add(src, dst uuid.UUID) {
srcM, ok := s.bySrc[src]
if !ok {
srcM = make(map[uuid.UUID]struct{})
s.bySrc[src] = srcM
}
srcM[dst] = struct{}{}
dstM, ok := s.byDst[dst]
if !ok {
dstM = make(map[uuid.UUID]struct{})
s.byDst[dst] = dstM
}
dstM[src] = struct{}{}
}
func (s *tunnelStore) remove(src, dst uuid.UUID) {
delete(s.bySrc[src], dst)
if len(s.bySrc[src]) == 0 {
delete(s.bySrc, src)
}
delete(s.byDst[dst], src)
if len(s.byDst[dst]) == 0 {
delete(s.byDst, dst)
}
}
func (s *tunnelStore) removeAll(src uuid.UUID) {
for dst := range s.bySrc[src] {
s.remove(src, dst)
}
}
func (s *tunnelStore) findTunnelPeers(id uuid.UUID) []uuid.UUID {
set := make(map[uuid.UUID]struct{})
for dst := range s.bySrc[id] {
set[dst] = struct{}{}
}
for src := range s.byDst[id] {
set[src] = struct{}{}
}
out := make([]uuid.UUID, 0, len(set))
for id := range set {
out = append(out, id)
}
return out
}
func (s *tunnelStore) htmlDebug() []HTMLTunnel {
out := make([]HTMLTunnel, 0)
for src, dsts := range s.bySrc {
for dst := range dsts {
out = append(out, HTMLTunnel{Src: src, Dst: dst})
}
}
return out
}


@@ -0,0 +1,45 @@
package tailnet
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
func TestTunnelStore_Bidir(t *testing.T) {
t.Parallel()
p1 := uuid.MustParse("00000001-1111-1111-1111-111111111111")
p2 := uuid.MustParse("00000002-1111-1111-1111-111111111111")
uut := newTunnelStore()
uut.add(p1, p2)
require.Equal(t, []uuid.UUID{p1}, uut.findTunnelPeers(p2))
require.Equal(t, []uuid.UUID{p2}, uut.findTunnelPeers(p1))
uut.remove(p1, p2)
require.Empty(t, uut.findTunnelPeers(p1))
require.Empty(t, uut.findTunnelPeers(p2))
require.Len(t, uut.byDst, 0)
require.Len(t, uut.bySrc, 0)
}
func TestTunnelStore_RemoveAll(t *testing.T) {
t.Parallel()
p1 := uuid.MustParse("00000001-1111-1111-1111-111111111111")
p2 := uuid.MustParse("00000002-1111-1111-1111-111111111111")
p3 := uuid.MustParse("00000003-1111-1111-1111-111111111111")
uut := newTunnelStore()
uut.add(p1, p2)
uut.add(p1, p3)
uut.add(p3, p1)
require.Len(t, uut.findTunnelPeers(p1), 2)
require.Len(t, uut.findTunnelPeers(p2), 1)
require.Len(t, uut.findTunnelPeers(p3), 1)
uut.removeAll(p1)
require.Len(t, uut.findTunnelPeers(p1), 1)
require.Len(t, uut.findTunnelPeers(p2), 0)
require.Len(t, uut.findTunnelPeers(p3), 1)
uut.removeAll(p3)
require.Len(t, uut.findTunnelPeers(p1), 0)
require.Len(t, uut.findTunnelPeers(p2), 0)
require.Len(t, uut.findTunnelPeers(p3), 0)
}