mirror of https://github.com/coder/coder.git
parent
38ed816207
commit
2c86d0bed0
|
@ -83,6 +83,7 @@ helm/**/templates/*.yaml
|
||||||
# Testdata shouldn't be formatted.
|
# Testdata shouldn't be formatted.
|
||||||
scripts/apitypings/testdata/**/*.ts
|
scripts/apitypings/testdata/**/*.ts
|
||||||
enterprise/tailnet/testdata/*.golden.html
|
enterprise/tailnet/testdata/*.golden.html
|
||||||
|
tailnet/testdata/*.golden.html
|
||||||
|
|
||||||
# Generated files shouldn't be formatted.
|
# Generated files shouldn't be formatted.
|
||||||
site/e2e/provisionerGenerated.ts
|
site/e2e/provisionerGenerated.ts
|
||||||
|
|
|
@ -9,6 +9,7 @@ helm/**/templates/*.yaml
|
||||||
# Testdata shouldn't be formatted.
|
# Testdata shouldn't be formatted.
|
||||||
scripts/apitypings/testdata/**/*.ts
|
scripts/apitypings/testdata/**/*.ts
|
||||||
enterprise/tailnet/testdata/*.golden.html
|
enterprise/tailnet/testdata/*.golden.html
|
||||||
|
tailnet/testdata/*.golden.html
|
||||||
|
|
||||||
# Generated files shouldn't be formatted.
|
# Generated files shouldn't be formatted.
|
||||||
site/e2e/provisionerGenerated.ts
|
site/e2e/provisionerGenerated.ts
|
||||||
|
|
5
Makefile
5
Makefile
|
@ -602,6 +602,7 @@ update-golden-files: \
|
||||||
scripts/ci-report/testdata/.gen-golden \
|
scripts/ci-report/testdata/.gen-golden \
|
||||||
enterprise/cli/testdata/.gen-golden \
|
enterprise/cli/testdata/.gen-golden \
|
||||||
enterprise/tailnet/testdata/.gen-golden \
|
enterprise/tailnet/testdata/.gen-golden \
|
||||||
|
tailnet/testdata/.gen-golden \
|
||||||
coderd/.gen-golden \
|
coderd/.gen-golden \
|
||||||
provisioner/terraform/testdata/.gen-golden
|
provisioner/terraform/testdata/.gen-golden
|
||||||
.PHONY: update-golden-files
|
.PHONY: update-golden-files
|
||||||
|
@ -614,6 +615,10 @@ enterprise/cli/testdata/.gen-golden: $(wildcard enterprise/cli/testdata/*.golden
|
||||||
go test ./enterprise/cli -run="TestEnterpriseCommandHelp" -update
|
go test ./enterprise/cli -run="TestEnterpriseCommandHelp" -update
|
||||||
touch "$@"
|
touch "$@"
|
||||||
|
|
||||||
|
tailnet/testdata/.gen-golden: $(wildcard tailnet/testdata/*.golden.html) $(GO_SRC_FILES) $(wildcard tailnet/*_test.go)
|
||||||
|
go test ./tailnet -run="TestDebugTemplate" -update
|
||||||
|
touch "$@"
|
||||||
|
|
||||||
enterprise/tailnet/testdata/.gen-golden: $(wildcard enterprise/tailnet/testdata/*.golden.html) $(GO_SRC_FILES) $(wildcard enterprise/tailnet/*_test.go)
|
enterprise/tailnet/testdata/.gen-golden: $(wildcard enterprise/tailnet/testdata/*.golden.html) $(GO_SRC_FILES) $(wildcard enterprise/tailnet/*_test.go)
|
||||||
go test ./enterprise/tailnet -run="TestDebugTemplate" -update
|
go test ./enterprise/tailnet -run="TestDebugTemplate" -update
|
||||||
touch "$@"
|
touch "$@"
|
||||||
|
|
|
@ -86,7 +86,7 @@ func (c *connIO) recvLoop() {
|
||||||
if c.disconnected {
|
if c.disconnected {
|
||||||
b.kind = proto.CoordinateResponse_PeerUpdate_DISCONNECTED
|
b.kind = proto.CoordinateResponse_PeerUpdate_DISCONNECTED
|
||||||
}
|
}
|
||||||
if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
|
if err := agpl.SendCtx(c.coordCtx, c.bindings, b); err != nil {
|
||||||
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err))
|
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err))
|
||||||
}
|
}
|
||||||
// only remove tunnels on graceful disconnect. If we remove tunnels for lost peers, then
|
// only remove tunnels on graceful disconnect. If we remove tunnels for lost peers, then
|
||||||
|
@ -97,14 +97,14 @@ func (c *connIO) recvLoop() {
|
||||||
tKey: tKey{src: c.UniqueID()},
|
tKey: tKey{src: c.UniqueID()},
|
||||||
active: false,
|
active: false,
|
||||||
}
|
}
|
||||||
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
|
if err := agpl.SendCtx(c.coordCtx, c.tunnels, t); err != nil {
|
||||||
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
|
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
for {
|
for {
|
||||||
req, err := recvCtx(c.peerCtx, c.requests)
|
req, err := agpl.RecvCtx(c.peerCtx, c.requests)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if xerrors.Is(err, context.Canceled) ||
|
if xerrors.Is(err, context.Canceled) ||
|
||||||
xerrors.Is(err, context.DeadlineExceeded) ||
|
xerrors.Is(err, context.DeadlineExceeded) ||
|
||||||
|
@ -132,7 +132,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
|
||||||
node: req.UpdateSelf.Node,
|
node: req.UpdateSelf.Node,
|
||||||
kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
||||||
}
|
}
|
||||||
if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
|
if err := agpl.SendCtx(c.coordCtx, c.bindings, b); err != nil {
|
||||||
c.logger.Debug(c.peerCtx, "failed to send binding", slog.Error(err))
|
c.logger.Debug(c.peerCtx, "failed to send binding", slog.Error(err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -156,7 +156,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
|
||||||
},
|
},
|
||||||
active: true,
|
active: true,
|
||||||
}
|
}
|
||||||
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
|
if err := agpl.SendCtx(c.coordCtx, c.tunnels, t); err != nil {
|
||||||
c.logger.Debug(c.peerCtx, "failed to send add tunnel", slog.Error(err))
|
c.logger.Debug(c.peerCtx, "failed to send add tunnel", slog.Error(err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -177,7 +177,7 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
|
||||||
},
|
},
|
||||||
active: false,
|
active: false,
|
||||||
}
|
}
|
||||||
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
|
if err := agpl.SendCtx(c.coordCtx, c.tunnels, t); err != nil {
|
||||||
c.logger.Debug(c.peerCtx, "failed to send remove tunnel", slog.Error(err))
|
c.logger.Debug(c.peerCtx, "failed to send remove tunnel", slog.Error(err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,17 +5,22 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
lru "github.com/hashicorp/golang-lru/v2"
|
lru "github.com/hashicorp/golang-lru/v2"
|
||||||
|
"golang.org/x/exp/slices"
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||||
|
"github.com/coder/coder/v2/coderd/util/slice"
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
agpl "github.com/coder/coder/v2/tailnet"
|
agpl "github.com/coder/coder/v2/tailnet"
|
||||||
)
|
)
|
||||||
|
@ -719,7 +724,209 @@ func (c *haCoordinator) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
|
||||||
c.mutex.RLock()
|
c.mutex.RLock()
|
||||||
defer c.mutex.RUnlock()
|
defer c.mutex.RUnlock()
|
||||||
|
|
||||||
agpl.CoordinatorHTTPDebug(
|
CoordinatorHTTPDebug(
|
||||||
agpl.HTTPDebugFromLocal(true, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache),
|
HTTPDebugFromLocal(true, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache),
|
||||||
)(w, r)
|
)(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func HTTPDebugFromLocal(
|
||||||
|
ha bool,
|
||||||
|
agentSocketsMap map[uuid.UUID]agpl.Queue,
|
||||||
|
agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]agpl.Queue,
|
||||||
|
nodesMap map[uuid.UUID]*agpl.Node,
|
||||||
|
agentNameCache *lru.Cache[uuid.UUID, string],
|
||||||
|
) HTMLDebugHA {
|
||||||
|
now := time.Now()
|
||||||
|
data := HTMLDebugHA{HA: ha}
|
||||||
|
for id, conn := range agentSocketsMap {
|
||||||
|
start, lastWrite := conn.Stats()
|
||||||
|
agent := &HTMLAgent{
|
||||||
|
Name: conn.Name(),
|
||||||
|
ID: id,
|
||||||
|
CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second),
|
||||||
|
LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
|
||||||
|
Overwrites: int(conn.Overwrites()),
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, conn := range agentToConnectionSocketsMap[id] {
|
||||||
|
start, lastWrite := conn.Stats()
|
||||||
|
agent.Connections = append(agent.Connections, &HTMLClient{
|
||||||
|
Name: conn.Name(),
|
||||||
|
ID: id,
|
||||||
|
CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second),
|
||||||
|
LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
slices.SortFunc(agent.Connections, func(a, b *HTMLClient) int {
|
||||||
|
return slice.Ascending(a.Name, b.Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
data.Agents = append(data.Agents, agent)
|
||||||
|
}
|
||||||
|
slices.SortFunc(data.Agents, func(a, b *HTMLAgent) int {
|
||||||
|
return slice.Ascending(a.Name, b.Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
for agentID, conns := range agentToConnectionSocketsMap {
|
||||||
|
if len(conns) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := agentSocketsMap[agentID]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
agentName, ok := agentNameCache.Get(agentID)
|
||||||
|
if !ok {
|
||||||
|
agentName = "unknown"
|
||||||
|
}
|
||||||
|
agent := &HTMLAgent{
|
||||||
|
Name: agentName,
|
||||||
|
ID: agentID,
|
||||||
|
}
|
||||||
|
for id, conn := range conns {
|
||||||
|
start, lastWrite := conn.Stats()
|
||||||
|
agent.Connections = append(agent.Connections, &HTMLClient{
|
||||||
|
Name: conn.Name(),
|
||||||
|
ID: id,
|
||||||
|
CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second),
|
||||||
|
LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
slices.SortFunc(agent.Connections, func(a, b *HTMLClient) int {
|
||||||
|
return slice.Ascending(a.Name, b.Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
data.MissingAgents = append(data.MissingAgents, agent)
|
||||||
|
}
|
||||||
|
slices.SortFunc(data.MissingAgents, func(a, b *HTMLAgent) int {
|
||||||
|
return slice.Ascending(a.Name, b.Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
for id, node := range nodesMap {
|
||||||
|
name, _ := agentNameCache.Get(id)
|
||||||
|
data.Nodes = append(data.Nodes, &HTMLNode{
|
||||||
|
ID: id,
|
||||||
|
Name: name,
|
||||||
|
Node: node,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
slices.SortFunc(data.Nodes, func(a, b *HTMLNode) int {
|
||||||
|
return slice.Ascending(a.Name+a.ID.String(), b.Name+b.ID.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func CoordinatorHTTPDebug(data HTMLDebugHA) func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
return func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
|
||||||
|
tmpl, err := template.New("coordinator_debug").Funcs(template.FuncMap{
|
||||||
|
"marshal": func(v any) template.JS {
|
||||||
|
a, err := json.MarshalIndent(v, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
//nolint:gosec
|
||||||
|
return template.JS(fmt.Sprintf(`{"err": %q}`, err))
|
||||||
|
}
|
||||||
|
//nolint:gosec
|
||||||
|
return template.JS(a)
|
||||||
|
},
|
||||||
|
}).Parse(haCoordinatorDebugTmpl)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tmpl.Execute(w, data)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type HTMLDebugHA struct {
|
||||||
|
HA bool
|
||||||
|
Agents []*HTMLAgent
|
||||||
|
MissingAgents []*HTMLAgent
|
||||||
|
Nodes []*HTMLNode
|
||||||
|
}
|
||||||
|
|
||||||
|
type HTMLAgent struct {
|
||||||
|
Name string
|
||||||
|
ID uuid.UUID
|
||||||
|
CreatedAge time.Duration
|
||||||
|
LastWriteAge time.Duration
|
||||||
|
Overwrites int
|
||||||
|
Connections []*HTMLClient
|
||||||
|
}
|
||||||
|
|
||||||
|
type HTMLClient struct {
|
||||||
|
Name string
|
||||||
|
ID uuid.UUID
|
||||||
|
CreatedAge time.Duration
|
||||||
|
LastWriteAge time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type HTMLNode struct {
|
||||||
|
ID uuid.UUID
|
||||||
|
Name string
|
||||||
|
Node any
|
||||||
|
}
|
||||||
|
|
||||||
|
var haCoordinatorDebugTmpl = `
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
{{- if .HA }}
|
||||||
|
<h1>high-availability wireguard coordinator debug</h1>
|
||||||
|
<h4 style="margin-top:-25px">warning: this only provides info from the node that served the request, if there are multiple replicas this data may be incomplete</h4>
|
||||||
|
{{- else }}
|
||||||
|
<h1>in-memory wireguard coordinator debug</h1>
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
<h2 id=agents> <a href=#agents>#</a> agents: total {{ len .Agents }} </h2>
|
||||||
|
<ul>
|
||||||
|
{{- range .Agents }}
|
||||||
|
<li style="margin-top:4px">
|
||||||
|
<b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created {{ .CreatedAge }} ago, write {{ .LastWriteAge }} ago, overwrites {{ .Overwrites }}
|
||||||
|
<h3 style="margin:0px;font-size:16px;font-weight:400"> connections: total {{ len .Connections}} </h3>
|
||||||
|
<ul>
|
||||||
|
{{- range .Connections }}
|
||||||
|
<li><b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created {{ .CreatedAge }} ago, write {{ .LastWriteAge }} ago </li>
|
||||||
|
{{- end }}
|
||||||
|
</ul>
|
||||||
|
</li>
|
||||||
|
{{- end }}
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
<h2 id=missing-agents><a href=#missing-agents>#</a> missing agents: total {{ len .MissingAgents }}</h2>
|
||||||
|
<ul>
|
||||||
|
{{- range .MissingAgents}}
|
||||||
|
<li style="margin-top:4px"><b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created ? ago, write ? ago, overwrites ? </li>
|
||||||
|
<h3 style="margin:0px;font-size:16px;font-weight:400"> connections: total {{ len .Connections }} </h3>
|
||||||
|
<ul>
|
||||||
|
{{- range .Connections }}
|
||||||
|
<li><b>{{ .Name }}</b> (<code>{{ .ID }}</code>): created {{ .CreatedAge }} ago, write {{ .LastWriteAge }} ago </li>
|
||||||
|
{{- end }}
|
||||||
|
</ul>
|
||||||
|
{{- end }}
|
||||||
|
</ul>
|
||||||
|
|
||||||
|
<h2 id=nodes><a href=#nodes>#</a> nodes: total {{ len .Nodes }}</h2>
|
||||||
|
<ul>
|
||||||
|
{{- range .Nodes }}
|
||||||
|
<li style="margin-top:4px"><b>{{ .Name }}</b> (<code>{{ .ID }}</code>):
|
||||||
|
<span style="white-space: pre;"><code>{{ marshal .Node }}</code></span>
|
||||||
|
</li>
|
||||||
|
{{- end }}
|
||||||
|
</ul>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`
|
||||||
|
|
|
@ -3,17 +3,12 @@ package tailnet
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"nhooyr.io/websocket"
|
|
||||||
|
|
||||||
"github.com/coder/coder/v2/tailnet/proto"
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
@ -30,17 +25,16 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
EventHeartbeats = "tailnet_coordinator_heartbeat"
|
EventHeartbeats = "tailnet_coordinator_heartbeat"
|
||||||
eventPeerUpdate = "tailnet_peer_update"
|
eventPeerUpdate = "tailnet_peer_update"
|
||||||
eventTunnelUpdate = "tailnet_tunnel_update"
|
eventTunnelUpdate = "tailnet_tunnel_update"
|
||||||
HeartbeatPeriod = time.Second * 2
|
HeartbeatPeriod = time.Second * 2
|
||||||
MissedHeartbeats = 3
|
MissedHeartbeats = 3
|
||||||
numQuerierWorkers = 10
|
numQuerierWorkers = 10
|
||||||
numBinderWorkers = 10
|
numBinderWorkers = 10
|
||||||
numTunnelerWorkers = 10
|
numTunnelerWorkers = 10
|
||||||
dbMaxBackoff = 10 * time.Second
|
dbMaxBackoff = 10 * time.Second
|
||||||
cleanupPeriod = time.Hour
|
cleanupPeriod = time.Hour
|
||||||
requestResponseBuffSize = 32
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// pgCoord is a postgres-backed coordinator
|
// pgCoord is a postgres-backed coordinator
|
||||||
|
@ -161,55 +155,8 @@ func NewPGCoordV2(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, sto
|
||||||
return newPGCoordInternal(ctx, logger, ps, store)
|
return newPGCoordInternal(ctx, logger, ps, store)
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is copied from codersdk because importing it here would cause an import
|
|
||||||
// cycle. This is just temporary until wsconncache is phased out.
|
|
||||||
var legacyAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")
|
|
||||||
|
|
||||||
func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
|
func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
|
||||||
logger := c.logger.With(slog.F("client_id", id)).Named("multiagent")
|
return agpl.ServeMultiAgent(c, c.logger, id)
|
||||||
ctx, cancel := context.WithCancel(c.ctx)
|
|
||||||
reqs, resps := c.Coordinate(ctx, id, id.String(), agpl.SingleTailnetTunnelAuth{})
|
|
||||||
ma := (&agpl.MultiAgent{
|
|
||||||
ID: id,
|
|
||||||
AgentIsLegacyFunc: func(agentID uuid.UUID) bool {
|
|
||||||
if n := c.Node(agentID); n == nil {
|
|
||||||
// If we don't have the node at all assume it's legacy for
|
|
||||||
// safety.
|
|
||||||
return true
|
|
||||||
} else if len(n.Addresses) > 0 && n.Addresses[0].Addr() == legacyAgentIP {
|
|
||||||
// An agent is determined to be "legacy" if it's first IP is the
|
|
||||||
// legacy IP. Agents with only the legacy IP aren't compatible
|
|
||||||
// with single_tailnet and must be routed through wsconncache.
|
|
||||||
return true
|
|
||||||
} else {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
},
|
|
||||||
OnSubscribe: func(enq agpl.Queue, agent uuid.UUID) (*agpl.Node, error) {
|
|
||||||
err := sendCtx(ctx, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)}})
|
|
||||||
return c.Node(agent), err
|
|
||||||
},
|
|
||||||
OnUnsubscribe: func(enq agpl.Queue, agent uuid.UUID) error {
|
|
||||||
err := sendCtx(ctx, reqs, &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)}})
|
|
||||||
return err
|
|
||||||
},
|
|
||||||
OnNodeUpdate: func(id uuid.UUID, node *agpl.Node) error {
|
|
||||||
pn, err := agpl.NodeToProto(node)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return sendCtx(c.ctx, reqs, &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
|
|
||||||
Node: pn,
|
|
||||||
}})
|
|
||||||
},
|
|
||||||
OnRemove: func(_ agpl.Queue) {
|
|
||||||
_ = sendCtx(c.ctx, reqs, &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
|
|
||||||
cancel()
|
|
||||||
},
|
|
||||||
}).Init()
|
|
||||||
|
|
||||||
go v1SendLoop(ctx, cancel, logger, ma, resps)
|
|
||||||
return ma
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *pgCoord) Node(id uuid.UUID) *agpl.Node {
|
func (c *pgCoord) Node(id uuid.UUID) *agpl.Node {
|
||||||
|
@ -253,116 +200,11 @@ func (c *pgCoord) Node(id uuid.UUID) *agpl.Node {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
|
func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
|
||||||
logger := c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent))
|
return agpl.ServeClientV1(c.ctx, c.logger, c, conn, id, agent)
|
||||||
defer func() {
|
|
||||||
err := conn.Close()
|
|
||||||
if err != nil {
|
|
||||||
logger.Debug(c.ctx, "closing client connection", slog.Error(err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
ctx, cancel := context.WithCancel(c.ctx)
|
|
||||||
defer cancel()
|
|
||||||
reqs, resps := c.Coordinate(ctx, id, id.String(), agpl.ClientTunnelAuth{AgentID: agent})
|
|
||||||
err := sendCtx(ctx, reqs, &proto.CoordinateRequest{
|
|
||||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
// can only be a context error, no need to log here.
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = sendCtx(ctx, reqs, &proto.CoordinateRequest{
|
|
||||||
RemoveTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agent)},
|
|
||||||
})
|
|
||||||
}()
|
|
||||||
|
|
||||||
tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, agpl.QueueKindClient)
|
|
||||||
go tc.SendUpdates()
|
|
||||||
go v1SendLoop(ctx, cancel, logger, tc, resps)
|
|
||||||
go v1RecvLoop(ctx, cancel, logger, conn, reqs)
|
|
||||||
<-ctx.Done()
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error {
|
func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error {
|
||||||
logger := c.logger.With(slog.F("agent_id", id), slog.F("name", name))
|
return agpl.ServeAgentV1(c.ctx, c.logger, c, conn, id, name)
|
||||||
defer func() {
|
|
||||||
logger.Debug(c.ctx, "closing agent connection")
|
|
||||||
err := conn.Close()
|
|
||||||
logger.Debug(c.ctx, "closed agent connection", slog.Error(err))
|
|
||||||
}()
|
|
||||||
ctx, cancel := context.WithCancel(c.ctx)
|
|
||||||
defer cancel()
|
|
||||||
reqs, resps := c.Coordinate(ctx, id, name, agpl.AgentTunnelAuth{})
|
|
||||||
tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, agpl.QueueKindAgent)
|
|
||||||
go tc.SendUpdates()
|
|
||||||
go v1SendLoop(ctx, cancel, logger, tc, resps)
|
|
||||||
go v1RecvLoop(ctx, cancel, logger, conn, reqs)
|
|
||||||
<-ctx.Done()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func v1RecvLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger,
|
|
||||||
conn net.Conn, reqs chan<- *proto.CoordinateRequest,
|
|
||||||
) {
|
|
||||||
defer cancel()
|
|
||||||
decoder := json.NewDecoder(conn)
|
|
||||||
for {
|
|
||||||
var node agpl.Node
|
|
||||||
err := decoder.Decode(&node)
|
|
||||||
if err != nil {
|
|
||||||
if xerrors.Is(err, io.EOF) ||
|
|
||||||
xerrors.Is(err, io.ErrClosedPipe) ||
|
|
||||||
xerrors.Is(err, context.Canceled) ||
|
|
||||||
xerrors.Is(err, context.DeadlineExceeded) ||
|
|
||||||
websocket.CloseStatus(err) > 0 {
|
|
||||||
logger.Debug(ctx, "exiting recvLoop", slog.Error(err))
|
|
||||||
} else {
|
|
||||||
logger.Error(ctx, "failed to decode Node update", slog.Error(err))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.Debug(ctx, "got node update", slog.F("node", node))
|
|
||||||
pn, err := agpl.NodeToProto(&node)
|
|
||||||
if err != nil {
|
|
||||||
logger.Critical(ctx, "failed to convert v1 node", slog.F("node", node), slog.Error(err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
|
|
||||||
Node: pn,
|
|
||||||
}}
|
|
||||||
if err := sendCtx(ctx, reqs, req); err != nil {
|
|
||||||
logger.Debug(ctx, "recvLoop ctx expired", slog.Error(err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func v1SendLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q agpl.Queue, resps <-chan *proto.CoordinateResponse) {
|
|
||||||
defer cancel()
|
|
||||||
for {
|
|
||||||
resp, err := recvCtx(ctx, resps)
|
|
||||||
if err != nil {
|
|
||||||
logger.Debug(ctx, "done reading responses", slog.Error(err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.Debug(ctx, "v1: got response", slog.F("resp", resp))
|
|
||||||
nodes, err := agpl.OnlyNodeUpdates(resp)
|
|
||||||
if err != nil {
|
|
||||||
logger.Critical(ctx, "failed to decode resp", slog.F("resp", resp), slog.Error(err))
|
|
||||||
_ = q.CoordinatorClose()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// don't send empty updates
|
|
||||||
if len(nodes) == 0 {
|
|
||||||
logger.Debug(ctx, "skipping enqueueing 0-length v1 update")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
err = q.Enqueue(nodes)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(ctx, "failed to enqueue v1 update", slog.Error(err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *pgCoord) Close() error {
|
func (c *pgCoord) Close() error {
|
||||||
|
@ -378,43 +220,22 @@ func (c *pgCoord) Coordinate(
|
||||||
chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse,
|
chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse,
|
||||||
) {
|
) {
|
||||||
logger := c.logger.With(slog.F("peer_id", id))
|
logger := c.logger.With(slog.F("peer_id", id))
|
||||||
reqs := make(chan *proto.CoordinateRequest, requestResponseBuffSize)
|
reqs := make(chan *proto.CoordinateRequest, agpl.RequestBufferSize)
|
||||||
resps := make(chan *proto.CoordinateResponse, agpl.ResponseBufferSize)
|
resps := make(chan *proto.CoordinateResponse, agpl.ResponseBufferSize)
|
||||||
cIO := newConnIO(c.ctx, ctx, logger, c.bindings, c.tunnelerCh, reqs, resps, id, name, a)
|
cIO := newConnIO(c.ctx, ctx, logger, c.bindings, c.tunnelerCh, reqs, resps, id, name, a)
|
||||||
err := sendCtx(c.ctx, c.newConnections, cIO)
|
err := agpl.SendCtx(c.ctx, c.newConnections, cIO)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// this can only happen if the context is canceled, no need to log
|
// this can only happen if the context is canceled, no need to log
|
||||||
return reqs, resps
|
return reqs, resps
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
<-cIO.Done()
|
<-cIO.Done()
|
||||||
_ = sendCtx(c.ctx, c.closeConnections, cIO)
|
_ = agpl.SendCtx(c.ctx, c.closeConnections, cIO)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return reqs, resps
|
return reqs, resps
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case c <- a:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func recvCtx[A any](ctx context.Context, c <-chan A) (a A, err error) {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return a, ctx.Err()
|
|
||||||
case a, ok := <-c:
|
|
||||||
if ok {
|
|
||||||
return a, nil
|
|
||||||
}
|
|
||||||
return a, io.EOF
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type tKey struct {
|
type tKey struct {
|
||||||
src uuid.UUID
|
src uuid.UUID
|
||||||
dst uuid.UUID
|
dst uuid.UUID
|
||||||
|
@ -873,7 +694,7 @@ func (m *mapper) bestToUpdate(best map[uuid.UUID]mapping) *proto.CoordinateRespo
|
||||||
case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_NODE && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
|
case ok && sm.kind == proto.CoordinateResponse_PeerUpdate_NODE && mpng.kind == proto.CoordinateResponse_PeerUpdate_NODE:
|
||||||
eq, err := sm.node.Equal(mpng.node)
|
eq, err := sm.node.Equal(mpng.node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sm.node), slog.F("new", mpng.kind))
|
m.logger.Critical(m.ctx, "failed to compare nodes", slog.F("old", sm.node), slog.F("new", mpng.node))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if eq {
|
if eq {
|
||||||
|
@ -1303,7 +1124,7 @@ func (q *querier) updateAll() {
|
||||||
go func(m *mapper) {
|
go func(m *mapper) {
|
||||||
// make sure we send on the _mapper_ context, not our own in case the mapper is
|
// make sure we send on the _mapper_ context, not our own in case the mapper is
|
||||||
// shutting down or shut down.
|
// shutting down or shut down.
|
||||||
_ = sendCtx(m.ctx, m.update, struct{}{})
|
_ = agpl.SendCtx(m.ctx, m.update, struct{}{})
|
||||||
}(mpr)
|
}(mpr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1603,7 +1424,7 @@ func (h *heartbeats) recvBeat(id uuid.UUID) {
|
||||||
h.logger.Info(h.ctx, "heartbeats (re)started", slog.F("other_coordinator_id", id))
|
h.logger.Info(h.ctx, "heartbeats (re)started", slog.F("other_coordinator_id", id))
|
||||||
// send on a separate goroutine to avoid holding lock. Triggering update can be async
|
// send on a separate goroutine to avoid holding lock. Triggering update can be async
|
||||||
go func() {
|
go func() {
|
||||||
_ = sendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
|
_ = agpl.SendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
h.coordinators[id] = time.Now()
|
h.coordinators[id] = time.Now()
|
||||||
|
@ -1650,7 +1471,7 @@ func (h *heartbeats) checkExpiry() {
|
||||||
if expired {
|
if expired {
|
||||||
// send on a separate goroutine to avoid holding lock. Triggering update can be async
|
// send on a separate goroutine to avoid holding lock. Triggering update can be async
|
||||||
go func() {
|
go func() {
|
||||||
_ = sendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
|
_ = agpl.SendCtx(h.ctx, h.update, hbUpdate{filter: filterUpdateUpdated})
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
// we need to reset the timer for when the next oldest coordinator will expire, if any.
|
// we need to reset the timer for when the next oldest coordinator will expire, if any.
|
||||||
|
@ -1685,14 +1506,14 @@ func (h *heartbeats) sendBeat() {
|
||||||
h.failedHeartbeats++
|
h.failedHeartbeats++
|
||||||
if h.failedHeartbeats == 3 {
|
if h.failedHeartbeats == 3 {
|
||||||
h.logger.Error(h.ctx, "coordinator failed 3 heartbeats and is unhealthy")
|
h.logger.Error(h.ctx, "coordinator failed 3 heartbeats and is unhealthy")
|
||||||
_ = sendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateUnhealthy})
|
_ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateUnhealthy})
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.logger.Debug(h.ctx, "sent heartbeat")
|
h.logger.Debug(h.ctx, "sent heartbeat")
|
||||||
if h.failedHeartbeats >= 3 {
|
if h.failedHeartbeats >= 3 {
|
||||||
h.logger.Info(h.ctx, "coordinator sent heartbeat and is healthy")
|
h.logger.Info(h.ctx, "coordinator sent heartbeat and is healthy")
|
||||||
_ = sendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy})
|
_ = agpl.SendCtx(h.ctx, h.update, hbUpdate{health: healthUpdateHealthy})
|
||||||
}
|
}
|
||||||
h.failedHeartbeats = 0
|
h.failedHeartbeats = 0
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
agpltest "github.com/coder/coder/v2/tailnet/test"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -612,18 +614,7 @@ func TestPGCoordinator_BidirectionalTunnels(t *testing.T) {
|
||||||
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
|
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer coordinator.Close()
|
defer coordinator.Close()
|
||||||
|
agpltest.BidirectionalTunnels(ctx, t, coordinator)
|
||||||
p1 := newTestPeer(ctx, t, coordinator, "p1")
|
|
||||||
defer p1.close(ctx)
|
|
||||||
p2 := newTestPeer(ctx, t, coordinator, "p2")
|
|
||||||
defer p2.close(ctx)
|
|
||||||
p1.addTunnel(p2.id)
|
|
||||||
p2.addTunnel(p1.id)
|
|
||||||
p1.updateDERP(1)
|
|
||||||
p2.updateDERP(2)
|
|
||||||
|
|
||||||
p1.assertEventuallyHasDERP(p2.id, 2)
|
|
||||||
p2.assertEventuallyHasDERP(p1.id, 1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPGCoordinator_GracefulDisconnect(t *testing.T) {
|
func TestPGCoordinator_GracefulDisconnect(t *testing.T) {
|
||||||
|
@ -638,21 +629,7 @@ func TestPGCoordinator_GracefulDisconnect(t *testing.T) {
|
||||||
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
|
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer coordinator.Close()
|
defer coordinator.Close()
|
||||||
|
agpltest.GracefulDisconnectTest(ctx, t, coordinator)
|
||||||
p1 := newTestPeer(ctx, t, coordinator, "p1")
|
|
||||||
defer p1.close(ctx)
|
|
||||||
p2 := newTestPeer(ctx, t, coordinator, "p2")
|
|
||||||
defer p2.close(ctx)
|
|
||||||
p1.addTunnel(p2.id)
|
|
||||||
p1.updateDERP(1)
|
|
||||||
p2.updateDERP(2)
|
|
||||||
|
|
||||||
p1.assertEventuallyHasDERP(p2.id, 2)
|
|
||||||
p2.assertEventuallyHasDERP(p1.id, 1)
|
|
||||||
|
|
||||||
p2.disconnect()
|
|
||||||
p1.assertEventuallyDisconnected(p2.id)
|
|
||||||
p2.assertEventuallyResponsesClosed()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPGCoordinator_Lost(t *testing.T) {
|
func TestPGCoordinator_Lost(t *testing.T) {
|
||||||
|
@ -667,20 +644,7 @@ func TestPGCoordinator_Lost(t *testing.T) {
|
||||||
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
|
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer coordinator.Close()
|
defer coordinator.Close()
|
||||||
|
agpltest.LostTest(ctx, t, coordinator)
|
||||||
p1 := newTestPeer(ctx, t, coordinator, "p1")
|
|
||||||
defer p1.close(ctx)
|
|
||||||
p2 := newTestPeer(ctx, t, coordinator, "p2")
|
|
||||||
defer p2.close(ctx)
|
|
||||||
p1.addTunnel(p2.id)
|
|
||||||
p1.updateDERP(1)
|
|
||||||
p2.updateDERP(2)
|
|
||||||
|
|
||||||
p1.assertEventuallyHasDERP(p2.id, 2)
|
|
||||||
p2.assertEventuallyHasDERP(p1.id, 1)
|
|
||||||
|
|
||||||
p2.close(ctx)
|
|
||||||
p1.assertEventuallyLost(p2.id)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type testConn struct {
|
type testConn struct {
|
||||||
|
@ -918,177 +882,6 @@ func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store
|
||||||
}, testutil.WaitShort, testutil.IntervalFast)
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
}
|
}
|
||||||
|
|
||||||
type peerStatus struct {
|
|
||||||
preferredDERP int32
|
|
||||||
status proto.CoordinateResponse_PeerUpdate_Kind
|
|
||||||
}
|
|
||||||
|
|
||||||
type testPeer struct {
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
t testing.TB
|
|
||||||
id uuid.UUID
|
|
||||||
name string
|
|
||||||
resps <-chan *proto.CoordinateResponse
|
|
||||||
reqs chan<- *proto.CoordinateRequest
|
|
||||||
peers map[uuid.UUID]peerStatus
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTestPeer(ctx context.Context, t testing.TB, coord agpl.CoordinatorV2, name string, id ...uuid.UUID) *testPeer {
|
|
||||||
p := &testPeer{t: t, name: name, peers: make(map[uuid.UUID]peerStatus)}
|
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
|
||||||
if len(id) > 1 {
|
|
||||||
t.Fatal("too many")
|
|
||||||
}
|
|
||||||
if len(id) == 1 {
|
|
||||||
p.id = id[0]
|
|
||||||
} else {
|
|
||||||
p.id = uuid.New()
|
|
||||||
}
|
|
||||||
// SingleTailnetTunnelAuth allows connections to arbitrary peers
|
|
||||||
p.reqs, p.resps = coord.Coordinate(p.ctx, p.id, name, agpl.SingleTailnetTunnelAuth{})
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *testPeer) addTunnel(other uuid.UUID) {
|
|
||||||
p.t.Helper()
|
|
||||||
req := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(other)}}
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
p.t.Errorf("timeout adding tunnel for %s", p.name)
|
|
||||||
return
|
|
||||||
case p.reqs <- req:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *testPeer) updateDERP(derp int32) {
|
|
||||||
p.t.Helper()
|
|
||||||
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: derp}}}
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
p.t.Errorf("timeout updating node for %s", p.name)
|
|
||||||
return
|
|
||||||
case p.reqs <- req:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *testPeer) disconnect() {
|
|
||||||
p.t.Helper()
|
|
||||||
req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
p.t.Errorf("timeout updating node for %s", p.name)
|
|
||||||
return
|
|
||||||
case p.reqs <- req:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *testPeer) assertEventuallyHasDERP(other uuid.UUID, derp int32) {
|
|
||||||
p.t.Helper()
|
|
||||||
for {
|
|
||||||
o, ok := p.peers[other]
|
|
||||||
if ok && o.preferredDERP == derp {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := p.handleOneResp(); err != nil {
|
|
||||||
assert.NoError(p.t, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *testPeer) assertEventuallyDisconnected(other uuid.UUID) {
|
|
||||||
p.t.Helper()
|
|
||||||
for {
|
|
||||||
_, ok := p.peers[other]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := p.handleOneResp(); err != nil {
|
|
||||||
assert.NoError(p.t, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *testPeer) assertEventuallyLost(other uuid.UUID) {
|
|
||||||
p.t.Helper()
|
|
||||||
for {
|
|
||||||
o := p.peers[other]
|
|
||||||
if o.status == proto.CoordinateResponse_PeerUpdate_LOST {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := p.handleOneResp(); err != nil {
|
|
||||||
assert.NoError(p.t, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *testPeer) assertEventuallyResponsesClosed() {
|
|
||||||
p.t.Helper()
|
|
||||||
for {
|
|
||||||
err := p.handleOneResp()
|
|
||||||
if xerrors.Is(err, responsesClosed) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !assert.NoError(p.t, err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var responsesClosed = xerrors.New("responses closed")
|
|
||||||
|
|
||||||
func (p *testPeer) handleOneResp() error {
|
|
||||||
select {
|
|
||||||
case <-p.ctx.Done():
|
|
||||||
return p.ctx.Err()
|
|
||||||
case resp, ok := <-p.resps:
|
|
||||||
if !ok {
|
|
||||||
return responsesClosed
|
|
||||||
}
|
|
||||||
for _, update := range resp.PeerUpdates {
|
|
||||||
id, err := uuid.FromBytes(update.Uuid)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
switch update.Kind {
|
|
||||||
case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
|
|
||||||
p.peers[id] = peerStatus{
|
|
||||||
preferredDERP: update.GetNode().GetPreferredDerp(),
|
|
||||||
status: update.Kind,
|
|
||||||
}
|
|
||||||
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
|
|
||||||
delete(p.peers, id)
|
|
||||||
default:
|
|
||||||
return xerrors.Errorf("unhandled update kind %s", update.Kind)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *testPeer) close(ctx context.Context) {
|
|
||||||
p.t.Helper()
|
|
||||||
p.cancel()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
p.t.Errorf("timeout waiting for responses to close for %s", p.name)
|
|
||||||
return
|
|
||||||
case _, ok := <-p.resps:
|
|
||||||
if ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeCoordinator struct {
|
type fakeCoordinator struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
t *testing.T
|
t *testing.T
|
||||||
|
|
|
@ -83,6 +83,7 @@ result
|
||||||
# Testdata shouldn't be formatted.
|
# Testdata shouldn't be formatted.
|
||||||
../scripts/apitypings/testdata/**/*.ts
|
../scripts/apitypings/testdata/**/*.ts
|
||||||
../enterprise/tailnet/testdata/*.golden.html
|
../enterprise/tailnet/testdata/*.golden.html
|
||||||
|
../tailnet/testdata/*.golden.html
|
||||||
|
|
||||||
# Generated files shouldn't be formatted.
|
# Generated files shouldn't be formatted.
|
||||||
e2e/provisionerGenerated.ts
|
e2e/provisionerGenerated.ts
|
||||||
|
|
|
@ -83,6 +83,7 @@ result
|
||||||
# Testdata shouldn't be formatted.
|
# Testdata shouldn't be formatted.
|
||||||
../scripts/apitypings/testdata/**/*.ts
|
../scripts/apitypings/testdata/**/*.ts
|
||||||
../enterprise/tailnet/testdata/*.golden.html
|
../enterprise/tailnet/testdata/*.golden.html
|
||||||
|
../tailnet/testdata/*.golden.html
|
||||||
|
|
||||||
# Generated files shouldn't be formatted.
|
# Generated files shouldn't be formatted.
|
||||||
e2e/provisionerGenerated.ts
|
e2e/provisionerGenerated.ts
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,74 @@
|
||||||
|
package tailnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"flag"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpdateGoldenFiles indicates golden files should be updated.
|
||||||
|
// To update the golden files:
|
||||||
|
// make update-golden-files
|
||||||
|
var UpdateGoldenFiles = flag.Bool("update", false, "update .golden files")
|
||||||
|
|
||||||
|
func TestDebugTemplate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("newlines screw up golden files on windows")
|
||||||
|
}
|
||||||
|
p1 := uuid.MustParse("01000000-2222-2222-2222-222222222222")
|
||||||
|
p2 := uuid.MustParse("02000000-2222-2222-2222-222222222222")
|
||||||
|
in := HTMLDebug{
|
||||||
|
Peers: []HTMLPeer{
|
||||||
|
{
|
||||||
|
Name: "Peer 1",
|
||||||
|
ID: p1,
|
||||||
|
LastWriteAge: 5 * time.Second,
|
||||||
|
Node: `id:1 preferred_derp:999 endpoints:"192.168.0.49:4449"`,
|
||||||
|
CreatedAge: 87 * time.Second,
|
||||||
|
Overwrites: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Peer 2",
|
||||||
|
ID: p2,
|
||||||
|
LastWriteAge: 7 * time.Second,
|
||||||
|
Node: `id:2 preferred_derp:999 endpoints:"192.168.0.33:4449"`,
|
||||||
|
CreatedAge: time.Hour,
|
||||||
|
Overwrites: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Tunnels: []HTMLTunnel{
|
||||||
|
{
|
||||||
|
Src: p1,
|
||||||
|
Dst: p2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
err := debugTempl.Execute(buf, in)
|
||||||
|
require.NoError(t, err)
|
||||||
|
actual := buf.Bytes()
|
||||||
|
|
||||||
|
goldenPath := filepath.Join("testdata", "debug.golden.html")
|
||||||
|
if *UpdateGoldenFiles {
|
||||||
|
t.Logf("update golden file %s", goldenPath)
|
||||||
|
err := os.WriteFile(goldenPath, actual, 0o600)
|
||||||
|
require.NoError(t, err, "update golden file")
|
||||||
|
}
|
||||||
|
|
||||||
|
expected, err := os.ReadFile(goldenPath)
|
||||||
|
require.NoError(t, err, "read golden file, run \"make update-golden-files\" and commit the changes")
|
||||||
|
|
||||||
|
require.Equal(
|
||||||
|
t, string(expected), string(actual),
|
||||||
|
"golden file mismatch: %s, run \"make update-golden-files\", verify and commit the changes",
|
||||||
|
goldenPath,
|
||||||
|
)
|
||||||
|
}
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/tailnet"
|
"github.com/coder/coder/v2/tailnet"
|
||||||
|
"github.com/coder/coder/v2/tailnet/test"
|
||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -27,7 +28,12 @@ func TestCoordinator(t *testing.T) {
|
||||||
t.Run("ClientWithoutAgent", func(t *testing.T) {
|
t.Run("ClientWithoutAgent", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
coordinator := tailnet.NewCoordinator(logger)
|
coordinator := tailnet.NewCoordinator(logger)
|
||||||
|
defer func() {
|
||||||
|
err := coordinator.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
client, server := net.Pipe()
|
client, server := net.Pipe()
|
||||||
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
|
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
|
||||||
return nil
|
return nil
|
||||||
|
@ -45,14 +51,19 @@ func TestCoordinator(t *testing.T) {
|
||||||
}, testutil.WaitShort, testutil.IntervalFast)
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
require.NoError(t, client.Close())
|
require.NoError(t, client.Close())
|
||||||
require.NoError(t, server.Close())
|
require.NoError(t, server.Close())
|
||||||
<-errChan
|
_ = testutil.RequireRecvCtx(ctx, t, errChan)
|
||||||
<-closeChan
|
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AgentWithoutClients", func(t *testing.T) {
|
t.Run("AgentWithoutClients", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
coordinator := tailnet.NewCoordinator(logger)
|
coordinator := tailnet.NewCoordinator(logger)
|
||||||
|
defer func() {
|
||||||
|
err := coordinator.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
client, server := net.Pipe()
|
client, server := net.Pipe()
|
||||||
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
|
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
|
||||||
return nil
|
return nil
|
||||||
|
@ -70,14 +81,18 @@ func TestCoordinator(t *testing.T) {
|
||||||
}, testutil.WaitShort, testutil.IntervalFast)
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
err := client.Close()
|
err := client.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
<-errChan
|
_ = testutil.RequireRecvCtx(ctx, t, errChan)
|
||||||
<-closeChan
|
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AgentWithClient", func(t *testing.T) {
|
t.Run("AgentWithClient", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
coordinator := tailnet.NewCoordinator(logger)
|
coordinator := tailnet.NewCoordinator(logger)
|
||||||
|
defer func() {
|
||||||
|
err := coordinator.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
// in this test we use real websockets to test use of deadlines
|
// in this test we use real websockets to test use of deadlines
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||||
|
@ -116,14 +131,11 @@ func TestCoordinator(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
close(closeClientChan)
|
close(closeClientChan)
|
||||||
}()
|
}()
|
||||||
select {
|
agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan)
|
||||||
case agentNodes := <-clientNodeChan:
|
require.Len(t, agentNodes, 1)
|
||||||
require.Len(t, agentNodes, 1)
|
|
||||||
case <-ctx.Done():
|
|
||||||
t.Fatal("timed out")
|
|
||||||
}
|
|
||||||
sendClientNode(&tailnet.Node{PreferredDERP: 2})
|
sendClientNode(&tailnet.Node{PreferredDERP: 2})
|
||||||
clientNodes := <-agentNodeChan
|
clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan)
|
||||||
require.Len(t, clientNodes, 1)
|
require.Len(t, clientNodes, 1)
|
||||||
|
|
||||||
// wait longer than the internal wait timeout.
|
// wait longer than the internal wait timeout.
|
||||||
|
@ -132,18 +144,14 @@ func TestCoordinator(t *testing.T) {
|
||||||
|
|
||||||
// Ensure an update to the agent node reaches the client!
|
// Ensure an update to the agent node reaches the client!
|
||||||
sendAgentNode(&tailnet.Node{PreferredDERP: 3})
|
sendAgentNode(&tailnet.Node{PreferredDERP: 3})
|
||||||
select {
|
agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan)
|
||||||
case agentNodes := <-clientNodeChan:
|
require.Len(t, agentNodes, 1)
|
||||||
require.Len(t, agentNodes, 1)
|
|
||||||
case <-ctx.Done():
|
|
||||||
t.Fatal("timed out")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the agent WebSocket so a new one can connect.
|
// Close the agent WebSocket so a new one can connect.
|
||||||
err := agentWS.Close()
|
err := agentWS.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
<-agentErrChan
|
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan)
|
||||||
<-closeAgentChan
|
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan)
|
||||||
|
|
||||||
// Create a new agent connection. This is to simulate a reconnect!
|
// Create a new agent connection. This is to simulate a reconnect!
|
||||||
agentWS, agentServerWS = net.Pipe()
|
agentWS, agentServerWS = net.Pipe()
|
||||||
|
@ -159,30 +167,32 @@ func TestCoordinator(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
close(closeAgentChan)
|
close(closeAgentChan)
|
||||||
}()
|
}()
|
||||||
// Ensure the existing listening client sends it's node immediately!
|
// Ensure the existing listening client sends its node immediately!
|
||||||
clientNodes = <-agentNodeChan
|
clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan)
|
||||||
require.Len(t, clientNodes, 1)
|
require.Len(t, clientNodes, 1)
|
||||||
|
|
||||||
err = agentWS.Close()
|
err = agentWS.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
<-agentErrChan
|
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan)
|
||||||
<-closeAgentChan
|
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan)
|
||||||
|
|
||||||
err = clientWS.Close()
|
err = clientWS.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
<-clientErrChan
|
_ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
|
||||||
<-closeClientChan
|
_ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AgentDoubleConnect", func(t *testing.T) {
|
t.Run("AgentDoubleConnect", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
coordinator := tailnet.NewCoordinator(logger)
|
coordinator := tailnet.NewCoordinator(logger)
|
||||||
|
ctx := testutil.Context(t, testutil.WaitLong)
|
||||||
|
|
||||||
agentWS1, agentServerWS1 := net.Pipe()
|
agentWS1, agentServerWS1 := net.Pipe()
|
||||||
defer agentWS1.Close()
|
defer agentWS1.Close()
|
||||||
agentNodeChan1 := make(chan []*tailnet.Node)
|
agentNodeChan1 := make(chan []*tailnet.Node)
|
||||||
sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error {
|
sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error {
|
||||||
|
t.Logf("agent1 got node update: %v", nodes)
|
||||||
agentNodeChan1 <- nodes
|
agentNodeChan1 <- nodes
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -203,6 +213,7 @@ func TestCoordinator(t *testing.T) {
|
||||||
defer clientServerWS.Close()
|
defer clientServerWS.Close()
|
||||||
clientNodeChan := make(chan []*tailnet.Node)
|
clientNodeChan := make(chan []*tailnet.Node)
|
||||||
sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
|
sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
|
||||||
|
t.Logf("client got node update: %v", nodes)
|
||||||
clientNodeChan <- nodes
|
clientNodeChan <- nodes
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -213,15 +224,15 @@ func TestCoordinator(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
close(closeClientChan)
|
close(closeClientChan)
|
||||||
}()
|
}()
|
||||||
agentNodes := <-clientNodeChan
|
agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan)
|
||||||
require.Len(t, agentNodes, 1)
|
require.Len(t, agentNodes, 1)
|
||||||
sendClientNode(&tailnet.Node{PreferredDERP: 2})
|
sendClientNode(&tailnet.Node{PreferredDERP: 2})
|
||||||
clientNodes := <-agentNodeChan1
|
clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan1)
|
||||||
require.Len(t, clientNodes, 1)
|
require.Len(t, clientNodes, 1)
|
||||||
|
|
||||||
// Ensure an update to the agent node reaches the client!
|
// Ensure an update to the agent node reaches the client!
|
||||||
sendAgentNode1(&tailnet.Node{PreferredDERP: 3})
|
sendAgentNode1(&tailnet.Node{PreferredDERP: 3})
|
||||||
agentNodes = <-clientNodeChan
|
agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan)
|
||||||
require.Len(t, agentNodes, 1)
|
require.Len(t, agentNodes, 1)
|
||||||
|
|
||||||
// Create a new agent connection without disconnecting the old one.
|
// Create a new agent connection without disconnecting the old one.
|
||||||
|
@ -229,6 +240,7 @@ func TestCoordinator(t *testing.T) {
|
||||||
defer agentWS2.Close()
|
defer agentWS2.Close()
|
||||||
agentNodeChan2 := make(chan []*tailnet.Node)
|
agentNodeChan2 := make(chan []*tailnet.Node)
|
||||||
_, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error {
|
_, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error {
|
||||||
|
t.Logf("agent2 got node update: %v", nodes)
|
||||||
agentNodeChan2 <- nodes
|
agentNodeChan2 <- nodes
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -240,33 +252,22 @@ func TestCoordinator(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Ensure the existing listening client sends it's node immediately!
|
// Ensure the existing listening client sends it's node immediately!
|
||||||
clientNodes = <-agentNodeChan2
|
clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan2)
|
||||||
require.Len(t, clientNodes, 1)
|
require.Len(t, clientNodes, 1)
|
||||||
|
|
||||||
counts, ok := coordinator.(interface {
|
// This original agent websocket should've been closed forcefully.
|
||||||
NodeCount() int
|
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan1)
|
||||||
AgentCount() int
|
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan1)
|
||||||
})
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("coordinator should have NodeCount() and AgentCount()")
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, 2, counts.NodeCount())
|
|
||||||
assert.Equal(t, 1, counts.AgentCount())
|
|
||||||
|
|
||||||
err := agentWS2.Close()
|
err := agentWS2.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
<-agentErrChan2
|
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan2)
|
||||||
<-closeAgentChan2
|
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan2)
|
||||||
|
|
||||||
err = clientWS.Close()
|
err = clientWS.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
<-clientErrChan
|
_ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
|
||||||
<-closeClientChan
|
_ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
|
||||||
|
|
||||||
// This original agent websocket should've been closed forcefully.
|
|
||||||
<-agentErrChan1
|
|
||||||
<-closeAgentChan1
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -334,9 +335,8 @@ func TestCoordinator_AgentUpdateWhileClientConnects(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
n, err = clientWS.Read(buf[1:])
|
n, err = clientWS.Read(buf[1:])
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, len(buf)-1, n)
|
|
||||||
var cNodes []*tailnet.Node
|
var cNodes []*tailnet.Node
|
||||||
err = json.Unmarshal(buf, &cNodes)
|
err = json.Unmarshal(buf[:n+1], &cNodes)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, cNodes, 1)
|
require.Len(t, cNodes, 1)
|
||||||
require.Equal(t, 0, cNodes[0].PreferredDERP)
|
require.Equal(t, 0, cNodes[0].PreferredDERP)
|
||||||
|
@ -348,13 +348,36 @@ func TestCoordinator_AgentUpdateWhileClientConnects(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
n, err = clientWS.Read(buf)
|
n, err = clientWS.Read(buf)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, len(buf), n)
|
err = json.Unmarshal(buf[:n], &cNodes)
|
||||||
err = json.Unmarshal(buf, &cNodes)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, cNodes, 1)
|
require.Len(t, cNodes, 1)
|
||||||
require.Equal(t, 1, cNodes[0].PreferredDERP)
|
require.Equal(t, 1, cNodes[0].PreferredDERP)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCoordinator_BidirectionalTunnels(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
|
coordinator := tailnet.NewCoordinatorV2(logger)
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
test.BidirectionalTunnels(ctx, t, coordinator)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoordinator_GracefulDisconnect(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
|
coordinator := tailnet.NewCoordinatorV2(logger)
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
test.GracefulDisconnectTest(ctx, t, coordinator)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoordinator_Lost(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
|
coordinator := tailnet.NewCoordinatorV2(logger)
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
test.LostTest(ctx, t, coordinator)
|
||||||
|
}
|
||||||
|
|
||||||
func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server net.Conn) {
|
func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server net.Conn) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
sc := make(chan net.Conn, 1)
|
sc := make(chan net.Conn, 1)
|
||||||
|
|
|
@ -0,0 +1,160 @@
|
||||||
|
package tailnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type peer struct {
|
||||||
|
logger slog.Logger
|
||||||
|
id uuid.UUID
|
||||||
|
node *proto.Node
|
||||||
|
resps chan<- *proto.CoordinateResponse
|
||||||
|
reqs <-chan *proto.CoordinateRequest
|
||||||
|
auth TunnelAuth
|
||||||
|
sent map[uuid.UUID]*proto.Node
|
||||||
|
|
||||||
|
name string
|
||||||
|
start time.Time
|
||||||
|
lastWrite time.Time
|
||||||
|
overwrites int
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateMappingLocked updates the mapping for another peer linked to this one by a tunnel. This method
|
||||||
|
// is NOT threadsafe and must be called while holding the core lock.
|
||||||
|
func (p *peer) updateMappingLocked(id uuid.UUID, n *proto.Node, k proto.CoordinateResponse_PeerUpdate_Kind, reason string) error {
|
||||||
|
logger := p.logger.With(slog.F("from_id", id), slog.F("kind", k), slog.F("reason", reason))
|
||||||
|
update, err := p.storeMappingLocked(id, n, k, reason)
|
||||||
|
if xerrors.Is(err, noResp) {
|
||||||
|
logger.Debug(context.Background(), "skipping update")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &proto.CoordinateResponse{PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{update}}
|
||||||
|
select {
|
||||||
|
case p.resps <- req:
|
||||||
|
p.lastWrite = time.Now()
|
||||||
|
logger.Debug(context.Background(), "wrote peer update")
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return ErrWouldBlock
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// batchUpdateMapping updates the mappings for a list of peers linked to this one by a tunnel. This
|
||||||
|
// method is NOT threadsafe and must be called while holding the core lock.
|
||||||
|
func (p *peer) batchUpdateMappingLocked(others []*peer, k proto.CoordinateResponse_PeerUpdate_Kind, reason string) error {
|
||||||
|
req := &proto.CoordinateResponse{}
|
||||||
|
for _, other := range others {
|
||||||
|
if other == nil || other.node == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
update, err := p.storeMappingLocked(other.id, other.node, k, reason)
|
||||||
|
if xerrors.Is(err, noResp) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.PeerUpdates = append(req.PeerUpdates, update)
|
||||||
|
}
|
||||||
|
if len(req.PeerUpdates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case p.resps <- req:
|
||||||
|
p.lastWrite = time.Now()
|
||||||
|
p.logger.Debug(context.Background(), "wrote batched update", slog.F("num_peer_updates", len(req.PeerUpdates)))
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return ErrWouldBlock
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var noResp = xerrors.New("no response needed")
|
||||||
|
|
||||||
|
func (p *peer) storeMappingLocked(
|
||||||
|
id uuid.UUID, n *proto.Node, k proto.CoordinateResponse_PeerUpdate_Kind, reason string,
|
||||||
|
) (
|
||||||
|
*proto.CoordinateResponse_PeerUpdate, error,
|
||||||
|
) {
|
||||||
|
p.logger.Debug(context.Background(), "got updated mapping",
|
||||||
|
slog.F("from_id", id), slog.F("kind", k), slog.F("reason", reason))
|
||||||
|
sn, ok := p.sent[id]
|
||||||
|
switch {
|
||||||
|
case !ok && (k == proto.CoordinateResponse_PeerUpdate_LOST || k == proto.CoordinateResponse_PeerUpdate_DISCONNECTED):
|
||||||
|
// we don't need to send a lost/disconnect update if we've never sent an update about this peer
|
||||||
|
return nil, noResp
|
||||||
|
case !ok && k == proto.CoordinateResponse_PeerUpdate_NODE:
|
||||||
|
p.sent[id] = n
|
||||||
|
case ok && k == proto.CoordinateResponse_PeerUpdate_LOST:
|
||||||
|
delete(p.sent, id)
|
||||||
|
case ok && k == proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
|
||||||
|
delete(p.sent, id)
|
||||||
|
case ok && k == proto.CoordinateResponse_PeerUpdate_NODE:
|
||||||
|
eq, err := sn.Equal(n)
|
||||||
|
if err != nil {
|
||||||
|
p.logger.Critical(context.Background(), "failed to compare nodes", slog.F("old", sn), slog.F("new", n))
|
||||||
|
return nil, xerrors.Errorf("failed to compare nodes: %s", sn.String())
|
||||||
|
}
|
||||||
|
if eq {
|
||||||
|
return nil, noResp
|
||||||
|
}
|
||||||
|
p.sent[id] = n
|
||||||
|
}
|
||||||
|
return &proto.CoordinateResponse_PeerUpdate{
|
||||||
|
Uuid: id[:],
|
||||||
|
Kind: k,
|
||||||
|
Node: n,
|
||||||
|
Reason: reason,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *peer) reqLoop(ctx context.Context, logger slog.Logger, handler func(*peer, *proto.CoordinateRequest) error) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
logger.Debug(ctx, "peerReadLoop context done")
|
||||||
|
return
|
||||||
|
case req, ok := <-p.reqs:
|
||||||
|
if !ok {
|
||||||
|
logger.Debug(ctx, "peerReadLoop channel closed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Debug(ctx, "peerReadLoop got request")
|
||||||
|
if err := handler(p, req); err != nil {
|
||||||
|
if xerrors.Is(err, ErrAlreadyRemoved) || xerrors.Is(err, ErrClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Error(ctx, "peerReadLoop error handling request", slog.Error(err), slog.F("request", req))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *peer) htmlDebug() HTMLPeer {
|
||||||
|
node := "<nil>"
|
||||||
|
if p.node != nil {
|
||||||
|
node = p.node.String()
|
||||||
|
}
|
||||||
|
return HTMLPeer{
|
||||||
|
ID: p.id,
|
||||||
|
Name: p.name,
|
||||||
|
CreatedAge: time.Since(p.start),
|
||||||
|
LastWriteAge: time.Since(p.lastWrite),
|
||||||
|
Overwrites: p.overwrites,
|
||||||
|
Node: node,
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,55 @@
|
||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/tailnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GracefulDisconnectTest(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
|
||||||
|
p1 := NewPeer(ctx, t, coordinator, "p1")
|
||||||
|
defer p1.Close(ctx)
|
||||||
|
p2 := NewPeer(ctx, t, coordinator, "p2")
|
||||||
|
defer p2.Close(ctx)
|
||||||
|
p1.AddTunnel(p2.ID)
|
||||||
|
p1.UpdateDERP(1)
|
||||||
|
p2.UpdateDERP(2)
|
||||||
|
|
||||||
|
p1.AssertEventuallyHasDERP(p2.ID, 2)
|
||||||
|
p2.AssertEventuallyHasDERP(p1.ID, 1)
|
||||||
|
|
||||||
|
p2.Disconnect()
|
||||||
|
p1.AssertEventuallyDisconnected(p2.ID)
|
||||||
|
p2.AssertEventuallyResponsesClosed()
|
||||||
|
}
|
||||||
|
|
||||||
|
func LostTest(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
|
||||||
|
p1 := NewPeer(ctx, t, coordinator, "p1")
|
||||||
|
defer p1.Close(ctx)
|
||||||
|
p2 := NewPeer(ctx, t, coordinator, "p2")
|
||||||
|
defer p2.Close(ctx)
|
||||||
|
p1.AddTunnel(p2.ID)
|
||||||
|
p1.UpdateDERP(1)
|
||||||
|
p2.UpdateDERP(2)
|
||||||
|
|
||||||
|
p1.AssertEventuallyHasDERP(p2.ID, 2)
|
||||||
|
p2.AssertEventuallyHasDERP(p1.ID, 1)
|
||||||
|
|
||||||
|
p2.Close(ctx)
|
||||||
|
p1.AssertEventuallyLost(p2.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BidirectionalTunnels(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
|
||||||
|
p1 := NewPeer(ctx, t, coordinator, "p1")
|
||||||
|
defer p1.Close(ctx)
|
||||||
|
p2 := NewPeer(ctx, t, coordinator, "p2")
|
||||||
|
defer p2.Close(ctx)
|
||||||
|
p1.AddTunnel(p2.ID)
|
||||||
|
p2.AddTunnel(p1.ID)
|
||||||
|
p1.UpdateDERP(1)
|
||||||
|
p2.UpdateDERP(2)
|
||||||
|
|
||||||
|
p1.AssertEventuallyHasDERP(p2.ID, 2)
|
||||||
|
p2.AssertEventuallyHasDERP(p1.ID, 1)
|
||||||
|
}
|
|
@ -0,0 +1,184 @@
|
||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/tailnet"
|
||||||
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PeerStatus struct {
|
||||||
|
preferredDERP int32
|
||||||
|
status proto.CoordinateResponse_PeerUpdate_Kind
|
||||||
|
}
|
||||||
|
|
||||||
|
type Peer struct {
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
t testing.TB
|
||||||
|
ID uuid.UUID
|
||||||
|
name string
|
||||||
|
resps <-chan *proto.CoordinateResponse
|
||||||
|
reqs chan<- *proto.CoordinateRequest
|
||||||
|
peers map[uuid.UUID]PeerStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, id ...uuid.UUID) *Peer {
|
||||||
|
p := &Peer{t: t, name: name, peers: make(map[uuid.UUID]PeerStatus)}
|
||||||
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
|
if len(id) > 1 {
|
||||||
|
t.Fatal("too many")
|
||||||
|
}
|
||||||
|
if len(id) == 1 {
|
||||||
|
p.ID = id[0]
|
||||||
|
} else {
|
||||||
|
p.ID = uuid.New()
|
||||||
|
}
|
||||||
|
// SingleTailnetTunnelAuth allows connections to arbitrary peers
|
||||||
|
p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, tailnet.SingleTailnetTunnelAuth{})
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) AddTunnel(other uuid.UUID) {
|
||||||
|
p.t.Helper()
|
||||||
|
req := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: tailnet.UUIDToByteSlice(other)}}
|
||||||
|
select {
|
||||||
|
case <-p.ctx.Done():
|
||||||
|
p.t.Errorf("timeout adding tunnel for %s", p.name)
|
||||||
|
return
|
||||||
|
case p.reqs <- req:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) UpdateDERP(derp int32) {
|
||||||
|
p.t.Helper()
|
||||||
|
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: derp}}}
|
||||||
|
select {
|
||||||
|
case <-p.ctx.Done():
|
||||||
|
p.t.Errorf("timeout updating node for %s", p.name)
|
||||||
|
return
|
||||||
|
case p.reqs <- req:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) Disconnect() {
|
||||||
|
p.t.Helper()
|
||||||
|
req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
|
||||||
|
select {
|
||||||
|
case <-p.ctx.Done():
|
||||||
|
p.t.Errorf("timeout updating node for %s", p.name)
|
||||||
|
return
|
||||||
|
case p.reqs <- req:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) AssertEventuallyHasDERP(other uuid.UUID, derp int32) {
|
||||||
|
p.t.Helper()
|
||||||
|
for {
|
||||||
|
o, ok := p.peers[other]
|
||||||
|
if ok && o.preferredDERP == derp {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := p.handleOneResp(); err != nil {
|
||||||
|
assert.NoError(p.t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) AssertEventuallyDisconnected(other uuid.UUID) {
|
||||||
|
p.t.Helper()
|
||||||
|
for {
|
||||||
|
_, ok := p.peers[other]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := p.handleOneResp(); err != nil {
|
||||||
|
assert.NoError(p.t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) AssertEventuallyLost(other uuid.UUID) {
|
||||||
|
p.t.Helper()
|
||||||
|
for {
|
||||||
|
o := p.peers[other]
|
||||||
|
if o.status == proto.CoordinateResponse_PeerUpdate_LOST {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := p.handleOneResp(); err != nil {
|
||||||
|
assert.NoError(p.t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) AssertEventuallyResponsesClosed() {
|
||||||
|
p.t.Helper()
|
||||||
|
for {
|
||||||
|
err := p.handleOneResp()
|
||||||
|
if xerrors.Is(err, responsesClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !assert.NoError(p.t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var responsesClosed = xerrors.New("responses closed")
|
||||||
|
|
||||||
|
func (p *Peer) handleOneResp() error {
|
||||||
|
select {
|
||||||
|
case <-p.ctx.Done():
|
||||||
|
return p.ctx.Err()
|
||||||
|
case resp, ok := <-p.resps:
|
||||||
|
if !ok {
|
||||||
|
return responsesClosed
|
||||||
|
}
|
||||||
|
for _, update := range resp.PeerUpdates {
|
||||||
|
id, err := uuid.FromBytes(update.Uuid)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch update.Kind {
|
||||||
|
case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
|
||||||
|
p.peers[id] = PeerStatus{
|
||||||
|
preferredDERP: update.GetNode().GetPreferredDerp(),
|
||||||
|
status: update.Kind,
|
||||||
|
}
|
||||||
|
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
|
||||||
|
delete(p.peers, id)
|
||||||
|
default:
|
||||||
|
return xerrors.Errorf("unhandled update kind %s", update.Kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) Close(ctx context.Context) {
|
||||||
|
p.t.Helper()
|
||||||
|
p.cancel()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
p.t.Errorf("timeout waiting for responses to close for %s", p.name)
|
||||||
|
return
|
||||||
|
case _, ok := <-p.resps:
|
||||||
|
if ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,62 @@
|
||||||
|
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<style>
|
||||||
|
th, td {
|
||||||
|
padding-top: 6px;
|
||||||
|
padding-bottom: 6px;
|
||||||
|
padding-left: 10px;
|
||||||
|
padding-right: 10px;
|
||||||
|
text-align: left;
|
||||||
|
}
|
||||||
|
tr {
|
||||||
|
border-bottom: 1px solid #ddd;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>in-memory wireguard coordinator debug</h1>
|
||||||
|
|
||||||
|
<h2 id=peers> <a href=#peers>#</a> peers: total 2 </h2>
|
||||||
|
<table>
|
||||||
|
<tr style="margin-top:4px">
|
||||||
|
<th>Name</th>
|
||||||
|
<th>ID</th>
|
||||||
|
<th>Created Age</th>
|
||||||
|
<th>Last Write Age</th>
|
||||||
|
<th>Overwrites</th>
|
||||||
|
<th>Node</th>
|
||||||
|
</tr>
|
||||||
|
<tr style="margin-top:4px">
|
||||||
|
<td>Peer 1</td>
|
||||||
|
<td>01000000-2222-2222-2222-222222222222</td>
|
||||||
|
<td>1m27s</td>
|
||||||
|
<td>5s ago</td>
|
||||||
|
<td>0</td>
|
||||||
|
<td style="white-space: pre;"><code>id:1 preferred_derp:999 endpoints:"192.168.0.49:4449"</code></td>
|
||||||
|
</tr>
|
||||||
|
<tr style="margin-top:4px">
|
||||||
|
<td>Peer 2</td>
|
||||||
|
<td>02000000-2222-2222-2222-222222222222</td>
|
||||||
|
<td>1h0m0s</td>
|
||||||
|
<td>7s ago</td>
|
||||||
|
<td>2</td>
|
||||||
|
<td style="white-space: pre;"><code>id:2 preferred_derp:999 endpoints:"192.168.0.33:4449"</code></td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
|
||||||
|
<h2 id=tunnels><a href=#tunnels>#</a> tunnels: total 1</h2>
|
||||||
|
<table>
|
||||||
|
<tr style="margin-top:4px">
|
||||||
|
<th>SrcID</th>
|
||||||
|
<th>DstID</th>
|
||||||
|
</tr>
|
||||||
|
<tr style="margin-top:4px">
|
||||||
|
<td>01000000-2222-2222-2222-222222222222</td>
|
||||||
|
<td>02000000-2222-2222-2222-222222222222</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
</body>
|
||||||
|
</html>
|
|
@ -20,6 +20,8 @@ const (
|
||||||
// ResponseBufferSize is the max number of responses to buffer per connection before we start
|
// ResponseBufferSize is the max number of responses to buffer per connection before we start
|
||||||
// dropping updates
|
// dropping updates
|
||||||
ResponseBufferSize = 512
|
ResponseBufferSize = 512
|
||||||
|
// RequestBufferSize is the max number of requests to buffer per connection
|
||||||
|
RequestBufferSize = 32
|
||||||
)
|
)
|
||||||
|
|
||||||
type TrackedConn struct {
|
type TrackedConn struct {
|
||||||
|
|
|
@ -28,3 +28,74 @@ type AgentTunnelAuth struct{}
|
||||||
func (AgentTunnelAuth) Authorize(uuid.UUID) bool {
|
func (AgentTunnelAuth) Authorize(uuid.UUID) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tunnelStore contains tunnel information and allows querying it. It is not threadsafe and all
|
||||||
|
// methods must be serialized by holding, e.g. the core mutex.
|
||||||
|
type tunnelStore struct {
|
||||||
|
bySrc map[uuid.UUID]map[uuid.UUID]struct{}
|
||||||
|
byDst map[uuid.UUID]map[uuid.UUID]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunnelStore() *tunnelStore {
|
||||||
|
return &tunnelStore{
|
||||||
|
bySrc: make(map[uuid.UUID]map[uuid.UUID]struct{}),
|
||||||
|
byDst: make(map[uuid.UUID]map[uuid.UUID]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tunnelStore) add(src, dst uuid.UUID) {
|
||||||
|
srcM, ok := s.bySrc[src]
|
||||||
|
if !ok {
|
||||||
|
srcM = make(map[uuid.UUID]struct{})
|
||||||
|
s.bySrc[src] = srcM
|
||||||
|
}
|
||||||
|
srcM[dst] = struct{}{}
|
||||||
|
dstM, ok := s.byDst[dst]
|
||||||
|
if !ok {
|
||||||
|
dstM = make(map[uuid.UUID]struct{})
|
||||||
|
s.byDst[dst] = dstM
|
||||||
|
}
|
||||||
|
dstM[src] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tunnelStore) remove(src, dst uuid.UUID) {
|
||||||
|
delete(s.bySrc[src], dst)
|
||||||
|
if len(s.bySrc[src]) == 0 {
|
||||||
|
delete(s.bySrc, src)
|
||||||
|
}
|
||||||
|
delete(s.byDst[dst], src)
|
||||||
|
if len(s.byDst[dst]) == 0 {
|
||||||
|
delete(s.byDst, dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tunnelStore) removeAll(src uuid.UUID) {
|
||||||
|
for dst := range s.bySrc[src] {
|
||||||
|
s.remove(src, dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tunnelStore) findTunnelPeers(id uuid.UUID) []uuid.UUID {
|
||||||
|
set := make(map[uuid.UUID]struct{})
|
||||||
|
for dst := range s.bySrc[id] {
|
||||||
|
set[dst] = struct{}{}
|
||||||
|
}
|
||||||
|
for src := range s.byDst[id] {
|
||||||
|
set[src] = struct{}{}
|
||||||
|
}
|
||||||
|
out := make([]uuid.UUID, 0, len(set))
|
||||||
|
for id := range set {
|
||||||
|
out = append(out, id)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tunnelStore) htmlDebug() []HTMLTunnel {
|
||||||
|
out := make([]HTMLTunnel, 0)
|
||||||
|
for src, dsts := range s.bySrc {
|
||||||
|
for dst := range dsts {
|
||||||
|
out = append(out, HTMLTunnel{Src: src, Dst: dst})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
package tailnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTunnelStore_Bidir(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
p1 := uuid.MustParse("00000001-1111-1111-1111-111111111111")
|
||||||
|
p2 := uuid.MustParse("00000002-1111-1111-1111-111111111111")
|
||||||
|
uut := newTunnelStore()
|
||||||
|
uut.add(p1, p2)
|
||||||
|
require.Equal(t, []uuid.UUID{p1}, uut.findTunnelPeers(p2))
|
||||||
|
require.Equal(t, []uuid.UUID{p2}, uut.findTunnelPeers(p1))
|
||||||
|
uut.remove(p1, p2)
|
||||||
|
require.Empty(t, uut.findTunnelPeers(p1))
|
||||||
|
require.Empty(t, uut.findTunnelPeers(p2))
|
||||||
|
require.Len(t, uut.byDst, 0)
|
||||||
|
require.Len(t, uut.bySrc, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTunnelStore_RemoveAll(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
p1 := uuid.MustParse("00000001-1111-1111-1111-111111111111")
|
||||||
|
p2 := uuid.MustParse("00000002-1111-1111-1111-111111111111")
|
||||||
|
p3 := uuid.MustParse("00000003-1111-1111-1111-111111111111")
|
||||||
|
uut := newTunnelStore()
|
||||||
|
uut.add(p1, p2)
|
||||||
|
uut.add(p1, p3)
|
||||||
|
uut.add(p3, p1)
|
||||||
|
require.Len(t, uut.findTunnelPeers(p1), 2)
|
||||||
|
require.Len(t, uut.findTunnelPeers(p2), 1)
|
||||||
|
require.Len(t, uut.findTunnelPeers(p3), 1)
|
||||||
|
uut.removeAll(p1)
|
||||||
|
require.Len(t, uut.findTunnelPeers(p1), 1)
|
||||||
|
require.Len(t, uut.findTunnelPeers(p2), 0)
|
||||||
|
require.Len(t, uut.findTunnelPeers(p3), 1)
|
||||||
|
uut.removeAll(p3)
|
||||||
|
require.Len(t, uut.findTunnelPeers(p1), 0)
|
||||||
|
require.Len(t, uut.findTunnelPeers(p2), 0)
|
||||||
|
require.Len(t, uut.findTunnelPeers(p3), 0)
|
||||||
|
}
|
Loading…
Reference in New Issue