feat: use tailnet v2 API for coordination (#11638)

This one is huge, and I'm sorry.

The problem is that once I change `tailnet.Conn` to start doing v2 behavior, I kind of have to change it everywhere, including in CoderSDK (CLI), the agent, wsproxy, and ServerTailnet.

There is still a bit more cleanup to do, and I need to add code so that when we lose the connection to the Coordinator we mark all peers as LOST, but that will be in a separate PR since this one is big enough!
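For orientation, here's a rough sketch of the client-side v2 handshake that codersdk, the agent, and wsproxy all follow after this change. The `coordinateV2` helper and its signature are illustrative only (not code in this PR); the tailnet calls mirror what the hunks below actually do.

```go
package example

import (
	"context"
	"net"

	"github.com/google/uuid"
	"golang.org/x/xerrors"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/tailnet"
)

// coordinateV2 is an illustrative helper: given an established net.Conn that
// carries dRPC (e.g. a binary-mode websocket), it runs v2 coordination against
// a configured *tailnet.Conn until the context is canceled or the stream fails.
func coordinateV2(ctx context.Context, logger slog.Logger, nc net.Conn, tn *tailnet.Conn, agentID uuid.UUID) error {
	// Wrap the raw connection in a dRPC Tailnet client.
	client, err := tailnet.NewDRPCClient(nc)
	if err != nil {
		return xerrors.Errorf("create DRPC client: %w", err)
	}
	// Open the bidirectional Coordinate stream, which replaces the v1 JSON
	// node protocol previously served via ServeCoordinator.
	coordinate, err := client.Coordinate(ctx)
	if err != nil {
		return xerrors.Errorf("start Coordinate stream: %w", err)
	}
	// RemoteCoordination pumps CoordinateRequest/CoordinateResponse messages
	// between the stream and the tailnet.Conn (peer updates, node callbacks).
	coordination := tailnet.NewRemoteCoordination(logger, coordinate, tn, agentID)
	defer func() {
		_ = coordination.Close()
	}()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case err := <-coordination.Error():
		return err
	}
}
```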
Spike Curtis 2024-01-22 11:07:50 +04:00 committed by GitHub
parent 5a2cf7cd14
commit f01cab9894
31 changed files with 1192 additions and 1114 deletions
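Most of the test churn below is the same mechanical swap: instead of `net.Pipe` plus `ServeClient`/`ServeCoordinator`, tests now wire the client `tailnet.Conn` straight to the coordinator in process. Roughly, under the same assumptions as the diff (the helper below is illustrative, not part of this PR):

```go
package example

import (
	"context"
	"testing"

	"github.com/google/uuid"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/tailnet"
)

// wireInMemoryCoordination is an illustrative test helper: it connects a
// client tailnet.Conn to a Coordinator in process, replacing the old
// net.Pipe + ServeClient wiring, and tears everything down via t.Cleanup.
func wireInMemoryCoordination(t *testing.T, logger slog.Logger, coordinator tailnet.Coordinator, conn *tailnet.Conn, agentID uuid.UUID) {
	testCtx, testCtxCancel := context.WithCancel(context.Background())
	t.Cleanup(testCtxCancel)

	clientID := uuid.New()
	coordination := tailnet.NewInMemoryCoordination(
		testCtx, logger,
		clientID, agentID,
		coordinator, conn,
	)
	t.Cleanup(func() {
		if err := coordination.Close(); err != nil {
			t.Logf("error closing in-memory coordination: %s", err.Error())
		}
	})
}
```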

View File

@ -475,7 +475,8 @@ gen: \
site/.eslintignore \
site/e2e/provisionerGenerated.ts \
site/src/theme/icons.json \
examples/examples.gen.json
examples/examples.gen.json \
tailnet/tailnettest/coordinatormock.go
.PHONY: gen
# Mark all generated files as fresh so make thinks they're up-to-date. This is
@ -502,6 +503,7 @@ gen/mark-fresh:
site/e2e/provisionerGenerated.ts \
site/src/theme/icons.json \
examples/examples.gen.json \
tailnet/tailnettest/coordinatormock.go \
"
for file in $$files; do
echo "$$file"
@ -529,6 +531,9 @@ coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $
coderd/database/dbmock/dbmock.go: coderd/database/db.go coderd/database/querier.go
go generate ./coderd/database/dbmock/
tailnet/tailnettest/coordinatormock.go: tailnet/coordinator.go
go generate ./tailnet/tailnettest/
tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto
protoc \
--go_out=. \

View File

@ -30,6 +30,7 @@ import (
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"storj.io/drpc"
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"
"tailscale.com/types/netlogtype"
@ -47,6 +48,7 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/tailnet"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
)
const (
@ -86,7 +88,7 @@ type Options struct {
type Client interface {
Manifest(ctx context.Context) (agentsdk.Manifest, error)
Listen(ctx context.Context) (net.Conn, error)
Listen(ctx context.Context) (drpc.Conn, error)
DERPMapUpdates(ctx context.Context) (<-chan agentsdk.DERPMapUpdate, io.Closer, error)
ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
@ -1058,20 +1060,34 @@ func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error
ctx, cancel := context.WithCancel(ctx)
defer cancel()
coordinator, err := a.client.Listen(ctx)
conn, err := a.client.Listen(ctx)
if err != nil {
return err
}
defer coordinator.Close()
defer func() {
cErr := conn.Close()
if cErr != nil {
a.logger.Debug(ctx, "error closing drpc connection", slog.Error(err))
}
}()
tClient := tailnetproto.NewDRPCTailnetClient(conn)
coordinate, err := tClient.Coordinate(ctx)
if err != nil {
return xerrors.Errorf("failed to connect to the coordinate endpoint: %w", err)
}
defer func() {
cErr := coordinate.Close()
if cErr != nil {
a.logger.Debug(ctx, "error closing Coordinate client", slog.Error(err))
}
}()
a.logger.Info(ctx, "connected to coordination endpoint")
sendNodes, errChan := tailnet.ServeCoordinator(coordinator, func(nodes []*tailnet.Node) error {
return network.UpdateNodes(nodes, false)
})
network.SetNodeCallback(sendNodes)
coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil)
select {
case <-ctx.Done():
return ctx.Err()
case err := <-errChan:
case err := <-coordination.Error():
return err
}
}

View File

@ -1664,9 +1664,11 @@ func TestAgent_UpdatedDERP(t *testing.T) {
require.NotNil(t, originalDerpMap)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
// use t.Cleanup so the coordinator closing doesn't deadlock with in-memory
// coordination
t.Cleanup(func() {
_ = coordinator.Close()
}()
})
agentID := uuid.New()
statsCh := make(chan *agentsdk.Stats, 50)
fs := afero.NewMemMapFs()
@ -1681,41 +1683,42 @@ func TestAgent_UpdatedDERP(t *testing.T) {
statsCh,
coordinator,
)
closer := agent.New(agent.Options{
uut := agent.New(agent.Options{
Client: client,
Filesystem: fs,
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: time.Minute,
})
defer func() {
_ = closer.Close()
}()
t.Cleanup(func() {
_ = uut.Close()
})
// Setup a client connection.
newClientConn := func(derpMap *tailcfg.DERPMap) *codersdk.WorkspaceAgentConn {
newClientConn := func(derpMap *tailcfg.DERPMap, name string) *codersdk.WorkspaceAgentConn {
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: derpMap,
Logger: logger.Named("client"),
Logger: logger.Named(name),
})
require.NoError(t, err)
clientConn, serverConn := net.Pipe()
serveClientDone := make(chan struct{})
t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
t.Logf("closing conn %s", name)
_ = conn.Close()
<-serveClientDone
})
go func() {
defer close(serveClientDone)
err := coordinator.ServeClient(serverConn, uuid.New(), agentID)
assert.NoError(t, err)
}()
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error {
return conn.UpdateNodes(nodes, false)
testCtx, testCtxCancel := context.WithCancel(context.Background())
t.Cleanup(testCtxCancel)
clientID := uuid.New()
coordination := tailnet.NewInMemoryCoordination(
testCtx, logger,
clientID, agentID,
coordinator, conn)
t.Cleanup(func() {
t.Logf("closing coordination %s", name)
err := coordination.Close()
if err != nil {
t.Logf("error closing in-memory coordination: %s", err.Error())
}
})
conn.SetNodeCallback(sendNode)
// Force DERP.
conn.SetBlockEndpoints(true)
@ -1724,6 +1727,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
CloseFunc: func() error { return codersdk.ErrSkipClose },
})
t.Cleanup(func() {
t.Logf("closing sdkConn %s", name)
_ = sdkConn.Close()
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
@ -1734,7 +1738,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
return sdkConn
}
conn1 := newClientConn(originalDerpMap)
conn1 := newClientConn(originalDerpMap, "client1")
// Change the DERP map.
newDerpMap, _ := tailnettest.RunDERPAndSTUN(t)
@ -1753,27 +1757,34 @@ func TestAgent_UpdatedDERP(t *testing.T) {
DERPMap: newDerpMap,
})
require.NoError(t, err)
t.Logf("client Pushed DERPMap update")
require.Eventually(t, func() bool {
conn := closer.TailnetConn()
conn := uut.TailnetConn()
if conn == nil {
return false
}
regionIDs := conn.DERPMap().RegionIDs()
return len(regionIDs) == 1 && regionIDs[0] == 2 && conn.Node().PreferredDERP == 2
preferredDERP := conn.Node().PreferredDERP
t.Logf("agent Conn DERPMap with regionIDs %v, PreferredDERP %d", regionIDs, preferredDERP)
return len(regionIDs) == 1 && regionIDs[0] == 2 && preferredDERP == 2
}, testutil.WaitLong, testutil.IntervalFast)
t.Logf("agent got the new DERPMap")
// Connect from a second client and make sure it uses the new DERP map.
conn2 := newClientConn(newDerpMap)
conn2 := newClientConn(newDerpMap, "client2")
require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs())
t.Log("conn2 got the new DERPMap")
// If the first client gets a DERP map update, it should be able to
// reconnect just fine.
conn1.SetDERPMap(newDerpMap)
require.Equal(t, []int{2}, conn1.DERPMap().RegionIDs())
t.Log("set the new DERPMap on conn1")
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
require.True(t, conn1.AwaitReachable(ctx))
t.Log("conn1 reached agent with new DERP")
}
func TestAgent_Speedtest(t *testing.T) {
@ -2050,22 +2061,22 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
Logger: logger.Named("client"),
})
require.NoError(t, err)
clientConn, serverConn := net.Pipe()
serveClientDone := make(chan struct{})
t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
_ = conn.Close()
<-serveClientDone
})
go func() {
defer close(serveClientDone)
coordinator.ServeClient(serverConn, uuid.New(), metadata.AgentID)
}()
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error {
return conn.UpdateNodes(nodes, false)
testCtx, testCtxCancel := context.WithCancel(context.Background())
t.Cleanup(testCtxCancel)
clientID := uuid.New()
coordination := tailnet.NewInMemoryCoordination(
testCtx, logger,
clientID, metadata.AgentID,
coordinator, conn)
t.Cleanup(func() {
err := coordination.Close()
if err != nil {
t.Logf("error closing in-mem coordination: %s", err.Error())
}
})
conn.SetNodeCallback(sendNode)
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: metadata.AgentID,
})

View File

@ -3,19 +3,26 @@ package agenttest
import (
"context"
"io"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
"golang.org/x/xerrors"
"storj.io/drpc"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
)
@ -24,11 +31,31 @@ func NewClient(t testing.TB,
agentID uuid.UUID,
manifest agentsdk.Manifest,
statsChan chan *agentsdk.Stats,
coordinator tailnet.CoordinatorV1,
coordinator tailnet.Coordinator,
) *Client {
if manifest.AgentID == uuid.Nil {
manifest.AgentID = agentID
}
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coordinator)
mux := drpcmux.New()
drpcService := &tailnet.DRPCService{
CoordPtr: &coordPtr,
Logger: logger,
// TODO: handle DERPMap too!
DerpMapUpdateFrequency: time.Hour,
DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
}
err := proto.DRPCRegisterTailnet(mux, drpcService)
require.NoError(t, err)
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
}
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
},
})
return &Client{
t: t,
logger: logger.Named("client"),
@ -36,6 +63,7 @@ func NewClient(t testing.TB,
manifest: manifest,
statsChan: statsChan,
coordinator: coordinator,
server: server,
derpMapUpdates: make(chan agentsdk.DERPMapUpdate),
}
}
@ -47,7 +75,8 @@ type Client struct {
manifest agentsdk.Manifest
metadata map[string]agentsdk.Metadata
statsChan chan *agentsdk.Stats
coordinator tailnet.CoordinatorV1
coordinator tailnet.Coordinator
server *drpcserver.Server
LastWorkspaceAgent func()
PatchWorkspaceLogs func() error
GetServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
@ -63,20 +92,29 @@ func (c *Client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
return c.manifest, nil
}
func (c *Client) Listen(_ context.Context) (net.Conn, error) {
clientConn, serverConn := net.Pipe()
func (c *Client) Listen(_ context.Context) (drpc.Conn, error) {
conn, lis := drpcsdk.MemTransportPipe()
closed := make(chan struct{})
c.LastWorkspaceAgent = func() {
_ = serverConn.Close()
_ = clientConn.Close()
_ = conn.Close()
_ = lis.Close()
<-closed
}
c.t.Cleanup(c.LastWorkspaceAgent)
serveCtx, cancel := context.WithCancel(context.Background())
c.t.Cleanup(cancel)
auth := tailnet.AgentTunnelAuth{}
streamID := tailnet.StreamID{
Name: "agenttest",
ID: c.agentID,
Auth: auth,
}
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
go func() {
_ = c.coordinator.ServeAgent(serverConn, c.agentID, "")
_ = c.server.Serve(serveCtx, lis)
close(closed)
}()
return clientConn, nil
return conn, nil
}
func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) {

View File

@ -33,6 +33,7 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/tailnet"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
)
@ -98,14 +99,32 @@ func TestDERP(t *testing.T) {
w2Ready := make(chan struct{})
w2ReadyOnce := sync.Once{}
w1ID := uuid.New()
w1.SetNodeCallback(func(node *tailnet.Node) {
w2.UpdateNodes([]*tailnet.Node{node}, false)
pn, err := tailnet.NodeToProto(node)
if !assert.NoError(t, err) {
return
}
w2.UpdatePeers([]*tailnetproto.CoordinateResponse_PeerUpdate{{
Id: w1ID[:],
Node: pn,
Kind: tailnetproto.CoordinateResponse_PeerUpdate_NODE,
}})
w2ReadyOnce.Do(func() {
close(w2Ready)
})
})
w2ID := uuid.New()
w2.SetNodeCallback(func(node *tailnet.Node) {
w1.UpdateNodes([]*tailnet.Node{node}, false)
pn, err := tailnet.NodeToProto(node)
if !assert.NoError(t, err) {
return
}
w1.UpdatePeers([]*tailnetproto.CoordinateResponse_PeerUpdate{{
Id: w2ID[:],
Node: pn,
Kind: tailnetproto.CoordinateResponse_PeerUpdate_NODE,
}})
})
conn := make(chan struct{})
@ -199,7 +218,11 @@ func TestDERPForceWebSockets(t *testing.T) {
defer cancel()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID,
&codersdk.DialWorkspaceAgentOptions{
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug).Named("client"),
},
)
require.NoError(t, err)
defer func() {
_ = conn.Close()

View File

@ -121,12 +121,23 @@ func NewServerTailnet(
}
tn.agentConn.Store(&agentConn)
err = tn.getAgentConn().UpdateSelf(conn.Node())
pn, err := tailnet.NodeToProto(conn.Node())
if err != nil {
tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err))
tn.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
} else {
err = tn.getAgentConn().UpdateSelf(pn)
if err != nil {
tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err))
}
}
conn.SetNodeCallback(func(node *tailnet.Node) {
err := tn.getAgentConn().UpdateSelf(node)
pn, err := tailnet.NodeToProto(node)
if err != nil {
tn.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
return
}
err = tn.getAgentConn().UpdateSelf(pn)
if err != nil {
tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err))
}
@ -191,21 +202,9 @@ func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
// If no one has connected since the cutoff and there are no active
// connections, remove the agent.
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
deleted, err := s.conn.RemovePeer(tailnet.PeerSelector{
ID: tailnet.NodeID(agentID),
IP: netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128),
})
if err != nil {
s.logger.Warn(ctx, "failed to remove peer from server tailnet", slog.Error(err))
continue
}
if !deleted {
s.logger.Warn(ctx, "peer didn't exist in tailnet", slog.Error(err))
}
deletedCount++
delete(s.agentConnectionTimes, agentID)
err = agentConn.UnsubscribeAgent(agentID)
err := agentConn.UnsubscribeAgent(agentID)
if err != nil {
s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
}
@ -221,7 +220,7 @@ func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
func (s *ServerTailnet) watchAgentUpdates() {
for {
conn := s.getAgentConn()
nodes, ok := conn.NextUpdate(s.ctx)
resp, ok := conn.NextUpdate(s.ctx)
if !ok {
if conn.IsClosed() && s.ctx.Err() == nil {
s.logger.Warn(s.ctx, "multiagent closed, reinitializing")
@ -231,7 +230,7 @@ func (s *ServerTailnet) watchAgentUpdates() {
return
}
err := s.conn.UpdateNodes(nodes, false)
err := s.conn.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
if xerrors.Is(err, tailnet.ErrConnClosed) {
s.logger.Warn(context.Background(), "tailnet conn closed, exiting watchAgentUpdates", slog.Error(err))

View File

@ -3,7 +3,6 @@ package coderd_test
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/netip"
@ -204,22 +203,20 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
Logger: logger.Named("client"),
})
require.NoError(t, err)
clientConn, serverConn := net.Pipe()
serveClientDone := make(chan struct{})
t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
_ = conn.Close()
<-serveClientDone
})
go func() {
defer close(serveClientDone)
coord.ServeClient(serverConn, uuid.New(), manifest.AgentID)
}()
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
return conn.UpdateNodes(node, false)
clientID := uuid.New()
testCtx, testCtxCancel := context.WithCancel(context.Background())
t.Cleanup(testCtxCancel)
coordination := tailnet.NewInMemoryCoordination(
testCtx, logger,
clientID, manifest.AgentID,
coord, conn,
)
t.Cleanup(func() {
_ = coordination.Close()
})
conn.SetNodeCallback(sendNode)
return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: manifest.AgentID,
AgentIP: codersdk.WorkspaceAgentIP,

View File

@ -30,6 +30,10 @@ func (v *APIVersion) WithBackwardCompat(majs ...int) *APIVersion {
return v
}
func (v *APIVersion) String() string {
return fmt.Sprintf("%d.%d", v.supportedMajor, v.supportedMinor)
}
// Validate validates the given version against the given constraints:
// A given major.minor version is valid iff:
// 1. The requested major version is contained within v.supportedMajors
@ -42,10 +46,6 @@ func (v *APIVersion) WithBackwardCompat(majs ...int) *APIVersion {
// - 1.x is supported,
// - 2.0, 2.1, and 2.2 are supported,
// - 2.3+ is not supported.
func (v *APIVersion) String() string {
return fmt.Sprintf("%d.%d", v.supportedMajor, v.supportedMinor)
}
func (v *APIVersion) Validate(version string) error {
major, minor, err := Parse(version)
if err != nil {

View File

@ -857,8 +857,6 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
// Deprecated: use api.tailnet.AgentConn instead.
// See: https://github.com/coder/coder/issues/8218
func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
clientConn, serverConn := net.Pipe()
derpMap := api.DERPMap()
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
@ -868,8 +866,6 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa
BlockEndpoints: api.DeploymentValues.DERP.Config.BlockDirect.Value(),
})
if err != nil {
_ = clientConn.Close()
_ = serverConn.Close()
return nil, xerrors.Errorf("create tailnet conn: %w", err)
}
ctx, cancel := context.WithCancel(api.ctx)
@ -887,10 +883,10 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa
return left
})
sendNodes, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error {
return conn.UpdateNodes(nodes, true)
})
conn.SetNodeCallback(sendNodes)
clientID := uuid.New()
coordination := tailnet.NewInMemoryCoordination(ctx, api.Logger,
clientID, agentID,
*(api.TailnetCoordinator.Load()), conn)
// Check for updated DERP map every 5 seconds.
go func() {
@ -920,27 +916,13 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa
AgentID: agentID,
AgentIP: codersdk.WorkspaceAgentIP,
CloseFunc: func() error {
_ = coordination.Close()
cancel()
_ = clientConn.Close()
_ = serverConn.Close()
return nil
},
})
go func() {
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
if err != nil {
// Sometimes, we get benign closed pipe errors when the server is
// shutting down.
if api.ctx.Err() == nil {
api.Logger.Warn(ctx, "tailnet coordinator client error", slog.Error(err))
}
_ = agentConn.Close()
}
}()
if !agentConn.AwaitReachable(ctx) {
_ = agentConn.Close()
_ = serverConn.Close()
_ = clientConn.Close()
cancel()
return nil, xerrors.Errorf("agent not reachable")
}

View File

@ -535,7 +535,6 @@ func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) {
})
require.NoError(t, err)
defer conn.Close()
require.True(t, conn.BlockEndpoints())
require.True(t, conn.AwaitReachable(ctx))
_, p2p, _, err := conn.Ping(ctx)

View File

@ -12,14 +12,19 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak"
"golang.org/x/xerrors"
"storj.io/drpc"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
@ -27,7 +32,9 @@ import (
"github.com/coder/coder/v2/coderd/wsconncache"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
@ -41,7 +48,7 @@ func TestCache(t *testing.T) {
t.Run("Same", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Manifest{}, 0), nil
return setupAgent(t, agentsdk.Manifest{}, 0)
}, 0)
defer func() {
_ = cache.Close()
@ -54,10 +61,10 @@ func TestCache(t *testing.T) {
})
t.Run("Expire", func(t *testing.T) {
t.Parallel()
called := atomic.NewInt32(0)
called := int32(0)
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
called.Add(1)
return setupAgent(t, agentsdk.Manifest{}, 0), nil
atomic.AddInt32(&called, 1)
return setupAgent(t, agentsdk.Manifest{}, 0)
}, time.Microsecond)
defer func() {
_ = cache.Close()
@ -70,12 +77,12 @@ func TestCache(t *testing.T) {
require.NoError(t, err)
release()
<-conn.Closed()
require.Equal(t, int32(2), called.Load())
require.Equal(t, int32(2), called)
})
t.Run("NoExpireWhenLocked", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Manifest{}, 0), nil
return setupAgent(t, agentsdk.Manifest{}, 0)
}, time.Microsecond)
defer func() {
_ = cache.Close()
@ -108,7 +115,7 @@ func TestCache(t *testing.T) {
go server.Serve(random)
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Manifest{}, 0), nil
return setupAgent(t, agentsdk.Manifest{}, 0)
}, time.Microsecond)
defer func() {
_ = cache.Close()
@ -154,7 +161,7 @@ func TestCache(t *testing.T) {
})
}
func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn {
func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) (*codersdk.WorkspaceAgentConn, error) {
t.Helper()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t)
@ -184,18 +191,25 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
DERPForceWebSockets: manifest.DERPForceWebSockets,
Logger: slogtest.Make(t, nil).Named("tailnet").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
clientConn, serverConn := net.Pipe()
// setupAgent is called by wsconncache Dialer, so we can't use require here as it will end the
// test, which in turn closes the wsconncache, which in turn waits for the Dialer and deadlocks.
if !assert.NoError(t, err) {
return nil, err
}
t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
_ = conn.Close()
})
go coordinator.ServeClient(serverConn, uuid.New(), manifest.AgentID)
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error {
return conn.UpdateNodes(nodes, false)
clientID := uuid.New()
testCtx, testCtxCancel := context.WithCancel(context.Background())
t.Cleanup(testCtxCancel)
coordination := tailnet.NewInMemoryCoordination(
testCtx, logger,
clientID, manifest.AgentID,
coordinator, conn,
)
t.Cleanup(func() {
_ = coordination.Close()
})
conn.SetNodeCallback(sendNode)
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: manifest.AgentID,
AgentIP: codersdk.WorkspaceAgentIP,
@ -206,16 +220,20 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
if !agentConn.AwaitReachable(ctx) {
t.Fatal("agent not reachable")
// setupAgent is called by wsconncache Dialer, so we can't use t.Fatal here as it will end
// the test, which in turn closes the wsconncache, which in turn waits for the Dialer and
// deadlocks.
t.Error("agent not reachable")
return nil, xerrors.New("agent not reachable")
}
return agentConn
return agentConn, nil
}
type client struct {
t *testing.T
agentID uuid.UUID
manifest agentsdk.Manifest
coordinator tailnet.CoordinatorV1
coordinator tailnet.Coordinator
}
func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
@ -240,19 +258,53 @@ func (*client) DERPMapUpdates(_ context.Context) (<-chan agentsdk.DERPMapUpdate,
}, nil
}
func (c *client) Listen(_ context.Context) (net.Conn, error) {
clientConn, serverConn := net.Pipe()
func (c *client) Listen(_ context.Context) (drpc.Conn, error) {
logger := slogtest.Make(c.t, nil).Leveled(slog.LevelDebug).Named("drpc")
conn, lis := drpcsdk.MemTransportPipe()
closed := make(chan struct{})
c.t.Cleanup(func() {
_ = serverConn.Close()
_ = clientConn.Close()
_ = conn.Close()
_ = lis.Close()
<-closed
})
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&c.coordinator)
mux := drpcmux.New()
drpcService := &tailnet.DRPCService{
CoordPtr: &coordPtr,
Logger: logger,
// TODO: handle DERPMap too!
DerpMapUpdateFrequency: time.Hour,
DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
}
err := proto.DRPCRegisterTailnet(mux, drpcService)
if err != nil {
return nil, xerrors.Errorf("register DRPC service: %w", err)
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) ||
xerrors.Is(err, context.Canceled) ||
xerrors.Is(err, context.DeadlineExceeded) {
return
}
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
},
})
serveCtx, cancel := context.WithCancel(context.Background())
c.t.Cleanup(cancel)
auth := tailnet.AgentTunnelAuth{}
streamID := tailnet.StreamID{
Name: "wsconncache_test-agent",
ID: c.agentID,
Auth: auth,
}
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
go func() {
_ = c.coordinator.ServeAgent(serverConn, c.agentID, "")
_ = server.Serve(serveCtx, lis)
close(closed)
}()
return clientConn, nil
return conn, nil
}
func (*client) ReportStats(_ context.Context, _ slog.Logger, _ <-chan *agentsdk.Stats, _ func(time.Duration)) (io.Closer, error) {

View File

@ -14,12 +14,15 @@ import (
"cloud.google.com/go/compute/metadata"
"github.com/google/uuid"
"github.com/hashicorp/yamux"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"storj.io/drpc"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/codersdk"
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/retry"
)
@ -280,8 +283,8 @@ func (c *Client) DERPMapUpdates(ctx context.Context) (<-chan DERPMapUpdate, io.C
// Listen connects to the workspace agent coordinate WebSocket
// that handles connection negotiation.
func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/coordinate")
func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) {
coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc")
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
@ -312,14 +315,21 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
pingClosed := pingWebSocket(ctx, c.SDK.Logger(), conn, "coordinate")
return &closeNetConn{
netConn := &closeNetConn{
Conn: wsNetConn,
closeFunc: func() {
cancelFunc()
_ = conn.Close(websocket.StatusGoingAway, "Listen closed")
<-pingClosed
},
}, nil
}
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Client(netConn, config)
if err != nil {
return nil, xerrors.Errorf("multiplex client: %w", err)
}
return drpcsdk.MultiplexedConn(session), nil
}
type PostAppHealthsRequest struct {

View File

@ -313,6 +313,9 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
q := coordinateURL.Query()
q.Add("version", tailnet.CurrentVersion.String())
coordinateURL.RawQuery = q.Encode()
closedCoordinator := make(chan struct{})
// Must only ever be used once, send error OR close to avoid
// reassignment race. Buffered so we don't hang in goroutine.
@ -344,12 +347,22 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
options.Logger.Debug(ctx, "failed to dial", slog.Error(err))
continue
}
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(nodes []*tailnet.Node) error {
return conn.UpdateNodes(nodes, false)
})
conn.SetNodeCallback(sendNode)
client, err := tailnet.NewDRPCClient(websocket.NetConn(ctx, ws, websocket.MessageBinary))
if err != nil {
options.Logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
_ = ws.Close(websocket.StatusInternalError, "")
continue
}
coordinate, err := client.Coordinate(ctx)
if err != nil {
options.Logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err))
_ = ws.Close(websocket.StatusInternalError, "")
continue
}
coordination := tailnet.NewRemoteCoordination(options.Logger, coordinate, conn, agentID)
options.Logger.Debug(ctx, "serving coordinator")
err = <-errChan
err = <-coordination.Error()
if errors.Is(err, context.Canceled) {
_ = ws.Close(websocket.StatusGoingAway, "")
return

View File

@ -8,6 +8,7 @@ import (
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/v2/tailnet"
@ -53,6 +54,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
ctx := r.Context()
version := "1.0"
msgType := websocket.MessageText
qv := r.URL.Query().Get("version")
if qv != "" {
version = qv
@ -66,6 +68,11 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
})
return
}
maj, _, _ := apiversion.Parse(version)
if maj >= 2 {
// Versions 2+ use dRPC over a binary connection
msgType = websocket.MessageBinary
}
api.AGPL.WebsocketWaitMutex.Lock()
api.AGPL.WebsocketWaitGroup.Add(1)
@ -81,7 +88,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
return
}
ctx, nc := websocketNetConn(ctx, conn, websocket.MessageText)
ctx, nc := websocketNetConn(ctx, conn, msgType)
defer nc.Close()
id := uuid.New()

View File

@ -10,6 +10,7 @@ import (
"github.com/moby/moby/pkg/namesgenerator"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/types/key"
"cdr.dev/slog/sloggers/slogtest"
@ -20,6 +21,7 @@ import (
"github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
)
@ -27,6 +29,12 @@ import (
func Test_agentIsLegacy(t *testing.T) {
t.Parallel()
nodeKey := key.NewNode().Public()
discoKey := key.NewDisco().Public()
nkBin, err := nodeKey.MarshalBinary()
require.NoError(t, err)
dkBin, err := discoKey.MarshalText()
require.NoError(t, err)
t.Run("Legacy", func(t *testing.T) {
t.Parallel()
@ -54,18 +62,18 @@ func Test_agentIsLegacy(t *testing.T) {
nodeID := uuid.New()
ma := coordinator.ServeMultiAgent(nodeID)
defer ma.Close()
require.NoError(t, ma.UpdateSelf(&agpl.Node{
ID: 55,
AsOf: time.Unix(1689653252, 0),
Key: key.NewNode().Public(),
DiscoKey: key.NewDisco().Public(),
PreferredDERP: 0,
DERPLatency: map[string]float64{
require.NoError(t, ma.UpdateSelf(&proto.Node{
Id: 55,
AsOf: timestamppb.New(time.Unix(1689653252, 0)),
Key: nkBin,
Disco: string(dkBin),
PreferredDerp: 0,
DerpLatency: map[string]float64{
"0": 1.0,
},
DERPForcedWebsocket: map[int]string{},
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
AllowedIPs: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
DerpForcedWebsocket: map[int32]string{},
Addresses: []string{codersdk.WorkspaceAgentIP.String() + "/128"},
AllowedIps: []string{codersdk.WorkspaceAgentIP.String() + "/128"},
Endpoints: []string{"192.168.1.1:18842"},
}))
require.Eventually(t, func() bool {
@ -114,18 +122,18 @@ func Test_agentIsLegacy(t *testing.T) {
nodeID := uuid.New()
ma := coordinator.ServeMultiAgent(nodeID)
defer ma.Close()
require.NoError(t, ma.UpdateSelf(&agpl.Node{
ID: 55,
AsOf: time.Unix(1689653252, 0),
Key: key.NewNode().Public(),
DiscoKey: key.NewDisco().Public(),
PreferredDERP: 0,
DERPLatency: map[string]float64{
require.NoError(t, ma.UpdateSelf(&proto.Node{
Id: 55,
AsOf: timestamppb.New(time.Unix(1689653252, 0)),
Key: nkBin,
Disco: string(dkBin),
PreferredDerp: 0,
DerpLatency: map[string]float64{
"0": 1.0,
},
DERPForcedWebsocket: map[int]string{},
Addresses: []netip.Prefix{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128)},
AllowedIPs: []netip.Prefix{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128)},
DerpForcedWebsocket: map[int32]string{},
Addresses: []string{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128).String()},
AllowedIps: []string{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128).String()},
Endpoints: []string{"192.168.1.1:18842"},
}))
require.Eventually(t, func() bool {

View File

@ -6,12 +6,15 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"tailscale.com/types/key"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/enterprise/tailnet"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
)
@ -39,22 +42,19 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
id := uuid.New()
ma1 := coord1.ServeMultiAgent(id)
defer ma1.Close()
ma1 := newTestMultiAgent(t, coord1)
defer ma1.close()
err = ma1.SubscribeAgent(agent1.id)
require.NoError(t, err)
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
ma1.assertEventuallyHasDERPs(ctx, 1)
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
require.NoError(t, err)
ma1.sendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)
require.NoError(t, ma1.Close())
ma1.close()
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@ -86,23 +86,20 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
id := uuid.New()
ma1 := coord1.ServeMultiAgent(id)
defer ma1.Close()
ma1 := newTestMultiAgent(t, coord1)
defer ma1.close()
err = ma1.SubscribeAgent(agent1.id)
require.NoError(t, err)
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
ma1.assertEventuallyHasDERPs(ctx, 1)
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
require.NoError(t, err)
ma1.sendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)
require.NoError(t, ma1.UnsubscribeAgent(agent1.id))
require.NoError(t, ma1.Close())
ma1.unsubscribeAgent(agent1.id)
ma1.close()
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@ -134,37 +131,35 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
id := uuid.New()
ma1 := coord1.ServeMultiAgent(id)
defer ma1.Close()
ma1 := newTestMultiAgent(t, coord1)
defer ma1.close()
err = ma1.SubscribeAgent(agent1.id)
require.NoError(t, err)
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
ma1.assertEventuallyHasDERPs(ctx, 1)
require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}))
ma1.sendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)
require.NoError(t, ma1.UnsubscribeAgent(agent1.id))
ma1.unsubscribeAgent(agent1.id)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
func() {
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
defer cancel()
require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 9}))
ma1.sendNodeWithDERP(9)
assertNeverHasDERPs(ctx, t, agent1, 9)
}()
func() {
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
defer cancel()
agent1.sendNode(&agpl.Node{PreferredDERP: 8})
assertMultiAgentNeverHasDERPs(ctx, t, ma1, 8)
ma1.assertNeverHasDERPs(ctx, 8)
}()
require.NoError(t, ma1.Close())
ma1.close()
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@ -201,22 +196,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
id := uuid.New()
ma1 := coord2.ServeMultiAgent(id)
defer ma1.Close()
ma1 := newTestMultiAgent(t, coord2)
defer ma1.close()
err = ma1.SubscribeAgent(agent1.id)
require.NoError(t, err)
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
ma1.assertEventuallyHasDERPs(ctx, 1)
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
require.NoError(t, err)
ma1.sendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)
require.NoError(t, ma1.Close())
ma1.close()
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@ -254,22 +246,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
id := uuid.New()
ma1 := coord2.ServeMultiAgent(id)
defer ma1.Close()
ma1 := newTestMultiAgent(t, coord2)
defer ma1.close()
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
require.NoError(t, err)
ma1.sendNodeWithDERP(3)
err = ma1.SubscribeAgent(agent1.id)
require.NoError(t, err)
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
assertEventuallyHasDERPs(ctx, t, agent1, 3)
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
ma1.assertEventuallyHasDERPs(ctx, 1)
require.NoError(t, ma1.Close())
ma1.close()
require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@ -316,33 +305,129 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
defer agent1.close()
agent2.sendNode(&agpl.Node{PreferredDERP: 6})
id := uuid.New()
ma1 := coord3.ServeMultiAgent(id)
defer ma1.Close()
ma1 := newTestMultiAgent(t, coord3)
defer ma1.close()
err = ma1.SubscribeAgent(agent1.id)
require.NoError(t, err)
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
ma1.subscribeAgent(agent1.id)
ma1.assertEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
ma1.assertEventuallyHasDERPs(ctx, 1)
err = ma1.SubscribeAgent(agent2.id)
require.NoError(t, err)
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 6)
ma1.subscribeAgent(agent2.id)
ma1.assertEventuallyHasDERPs(ctx, 6)
agent2.sendNode(&agpl.Node{PreferredDERP: 2})
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 2)
ma1.assertEventuallyHasDERPs(ctx, 2)
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
require.NoError(t, err)
ma1.sendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3)
assertEventuallyHasDERPs(ctx, t, agent2, 3)
require.NoError(t, ma1.Close())
ma1.close()
require.NoError(t, agent1.close())
require.NoError(t, agent2.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyLost(ctx, t, store, agent1.id)
}
type testMultiAgent struct {
t testing.TB
id uuid.UUID
a agpl.MultiAgentConn
nodeKey []byte
discoKey string
}
func newTestMultiAgent(t testing.TB, coord agpl.Coordinator) *testMultiAgent {
nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
m := &testMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)}
m.a = coord.ServeMultiAgent(m.id)
return m
}
func (m *testMultiAgent) sendNodeWithDERP(derp int32) {
m.t.Helper()
err := m.a.UpdateSelf(&proto.Node{
Key: m.nodeKey,
Disco: m.discoKey,
PreferredDerp: derp,
})
require.NoError(m.t, err)
}
func (m *testMultiAgent) close() {
m.t.Helper()
err := m.a.Close()
require.NoError(m.t, err)
}
func (m *testMultiAgent) subscribeAgent(id uuid.UUID) {
m.t.Helper()
err := m.a.SubscribeAgent(id)
require.NoError(m.t, err)
}
func (m *testMultiAgent) unsubscribeAgent(id uuid.UUID) {
m.t.Helper()
err := m.a.UnsubscribeAgent(id)
require.NoError(m.t, err)
}
func (m *testMultiAgent) assertEventuallyHasDERPs(ctx context.Context, expected ...int) {
m.t.Helper()
for {
resp, ok := m.a.NextUpdate(ctx)
require.True(m.t, ok)
nodes, err := agpl.OnlyNodeUpdates(resp)
require.NoError(m.t, err)
if len(nodes) != len(expected) {
m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
m.t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
func (m *testMultiAgent) assertNeverHasDERPs(ctx context.Context, expected ...int) {
m.t.Helper()
for {
resp, ok := m.a.NextUpdate(ctx)
if !ok {
return
}
nodes, err := agpl.OnlyNodeUpdates(resp)
require.NoError(m.t, err)
if len(nodes) != len(expected) {
m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
m.t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}

View File

@ -818,56 +818,6 @@ func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expecte
}
}
func assertMultiAgentEventuallyHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) {
t.Helper()
for {
nodes, ok := ma.NextUpdate(ctx)
require.True(t, ok)
if len(nodes) != len(expected) {
t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) {
t.Helper()
for {
nodes, ok := ma.NextUpdate(ctx)
if !ok {
return
}
if len(nodes) != len(expected) {
t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
t.Helper()
assert.Eventually(t, func() bool {

View File

@ -96,7 +96,11 @@ func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentC
return xerrors.Errorf("unsubscribe agent: %w", err)
}
case wsproxysdk.CoordinateMessageTypeNodeUpdate:
err := ma.UpdateSelf(msg.Node)
pn, err := agpl.NodeToProto(msg.Node)
if err != nil {
return err
}
err = ma.UpdateSelf(pn)
if err != nil {
return xerrors.Errorf("update self: %w", err)
}
@ -110,11 +114,14 @@ func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentC
func forwardNodesToWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
var lastData []byte
for {
nodes, ok := ma.NextUpdate(ctx)
resp, ok := ma.NextUpdate(ctx)
if !ok {
return xerrors.New("multiagent is closed")
}
nodes, err := agpl.OnlyNodeUpdates(resp)
if err != nil {
return xerrors.Errorf("failed to convert response: %w", err)
}
data, err := json.Marshal(wsproxysdk.CoordinateNodes{Nodes: nodes})
if err != nil {
return err

View File

@ -158,7 +158,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
// TODO: Probably do some version checking here
info, err := client.SDKClient.BuildInfo(ctx)
if err != nil {
return nil, fmt.Errorf("buildinfo: %w", errors.Join(
return nil, xerrors.Errorf("buildinfo: %w", errors.Join(
xerrors.Errorf("unable to fetch build info from primary coderd. Are you sure %q is a coderd instance?", opts.DashboardURL),
err,
))

View File

@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"sync"
@ -23,6 +22,7 @@ import (
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/codersdk"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
)
// Client is a HTTP client for a subset of Coder API routes that external
@ -438,6 +438,9 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
cancel()
return nil, xerrors.Errorf("parse url: %w", err)
}
q := coordinateURL.Query()
q.Add("version", agpl.CurrentVersion.String())
coordinateURL.RawQuery = q.Encode()
coordinateHeaders := make(http.Header)
tokenHeader := codersdk.SessionTokenHeader
if c.SDKClient.SessionTokenHeader != "" {
@ -457,10 +460,24 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
nc := websocket.NetConn(ctx, conn, websocket.MessageText)
nc := websocket.NetConn(ctx, conn, websocket.MessageBinary)
client, err := agpl.NewDRPCClient(nc)
if err != nil {
logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, "")
return nil, xerrors.Errorf("failed to create DRPCClient: %w", err)
}
protocol, err := client.Coordinate(ctx)
if err != nil {
logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, "")
return nil, xerrors.Errorf("failed to reach the Coordinate endpoint: %w", err)
}
rma := remoteMultiAgentHandler{
sdk: c,
nc: nc,
logger: logger,
protocol: protocol,
cancel: cancel,
legacyAgentCache: map[uuid.UUID]bool{},
}
@ -471,103 +488,75 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
OnSubscribe: rma.OnSubscribe,
OnUnsubscribe: rma.OnUnsubscribe,
OnNodeUpdate: rma.OnNodeUpdate,
OnRemove: func(agpl.Queue) { conn.Close(websocket.StatusGoingAway, "closed") },
OnRemove: rma.OnRemove,
}).Init()
go func() {
<-ctx.Done()
ma.Close()
_ = conn.Close(websocket.StatusGoingAway, "closed")
}()
go func() {
defer cancel()
dec := json.NewDecoder(nc)
for {
var msg CoordinateNodes
err := dec.Decode(&msg)
if err != nil {
if xerrors.Is(err, io.EOF) {
logger.Info(ctx, "websocket connection severed", slog.Error(err))
return
}
logger.Error(ctx, "decode coordinator nodes", slog.Error(err))
return
}
err = ma.Enqueue(msg.Nodes)
if err != nil {
logger.Error(ctx, "enqueue nodes from coordinator", slog.Error(err))
continue
}
}
}()
rma.ma = ma
go rma.respLoop()
return ma, nil
}
type remoteMultiAgentHandler struct {
sdk *Client
nc net.Conn
cancel func()
sdk *Client
logger slog.Logger
protocol proto.DRPCTailnet_CoordinateClient
ma *agpl.MultiAgent
cancel func()
legacyMu sync.RWMutex
legacyAgentCache map[uuid.UUID]bool
legacySingleflight singleflight.Group[uuid.UUID, AgentIsLegacyResponse]
}
func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error {
data, err := json.Marshal(v)
if err != nil {
return xerrors.Errorf("json marshal message: %w", err)
}
func (a *remoteMultiAgentHandler) respLoop() {
{
defer a.cancel()
for {
resp, err := a.protocol.Recv()
if err != nil {
if xerrors.Is(err, io.EOF) {
a.logger.Info(context.Background(), "remote multiagent connection severed", slog.Error(err))
return
}
// Set a deadline so that hung connections don't put back pressure on the system.
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
err = a.nc.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout))
if err != nil {
a.cancel()
return xerrors.Errorf("set write deadline: %w", err)
}
_, err = a.nc.Write(data)
if err != nil {
a.cancel()
return xerrors.Errorf("write message: %w", err)
}
a.logger.Error(context.Background(), "error receiving multiagent responses", slog.Error(err))
return
}
// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are
// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write()
// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context.
// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after
// our successful write, it is important that we reset the deadline before it fires.
err = a.nc.SetWriteDeadline(time.Time{})
if err != nil {
a.cancel()
return xerrors.Errorf("clear write deadline: %w", err)
err = a.ma.Enqueue(resp)
if err != nil {
a.logger.Error(context.Background(), "enqueue response from coordinator", slog.Error(err))
continue
}
}
}
return nil
}
func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *agpl.Node) error {
return a.writeJSON(CoordinateMessage{
Type: CoordinateMessageTypeNodeUpdate,
Node: node,
})
func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *proto.Node) error {
return a.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: node}})
}
func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) {
return nil, a.writeJSON(CoordinateMessage{
Type: CoordinateMessageTypeSubscribe,
AgentID: agentID,
})
func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) error {
return a.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}})
}
func (a *remoteMultiAgentHandler) OnUnsubscribe(_ agpl.Queue, agentID uuid.UUID) error {
return a.writeJSON(CoordinateMessage{
Type: CoordinateMessageTypeUnsubscribe,
AgentID: agentID,
})
return a.protocol.Send(&proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}})
}
func (a *remoteMultiAgentHandler) OnRemove(_ agpl.Queue) {
err := a.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
if err != nil {
a.logger.Warn(context.Background(), "failed to gracefully disconnect", slog.Error(err))
}
_ = a.protocol.CloseSend()
}
func (a *remoteMultiAgentHandler) AgentIsLegacy(agentID uuid.UUID) bool {

View File

@ -18,8 +18,9 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/timestamppb"
"nhooyr.io/websocket"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"cdr.dev/slog"
@ -30,6 +31,7 @@ import (
"github.com/coder/coder/v2/enterprise/tailnet"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
@ -156,25 +158,48 @@ func TestDialCoordinator(t *testing.T) {
t.Run("OK", func(t *testing.T) {
t.Parallel()
var (
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
agentID = uuid.New()
serverMultiAgent = tailnettest.NewMockMultiAgentConn(gomock.NewController(t))
r = chi.NewRouter()
srv = httptest.NewServer(r)
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
agentID = uuid.UUID{33}
proxyID = uuid.UUID{44}
mCoord = tailnettest.NewMockCoordinator(gomock.NewController(t))
coord agpl.Coordinator = mCoord
r = chi.NewRouter()
srv = httptest.NewServer(r)
)
defer cancel()
defer srv.Close()
coordPtr := atomic.Pointer[agpl.Coordinator]{}
coordPtr.Store(&coord)
cSrv, err := tailnet.NewClientService(
logger, &coordPtr,
time.Hour,
func() *tailcfg.DERPMap { panic("not implemented") },
)
require.NoError(t, err)
// buffer the channels here, so we don't need to read and write in goroutines to
// avoid blocking
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetTunnelAuth{}).
Times(1).
Return(reqs, resps)
serveMACErr := make(chan error, 1)
r.Get("/api/v2/workspaceproxies/me/coordinate", func(w http.ResponseWriter, r *http.Request) {
conn, err := websocket.Accept(w, r, nil)
require.NoError(t, err)
nc := websocket.NetConn(r.Context(), conn, websocket.MessageText)
defer serverMultiAgent.Close()
err = tailnet.ServeWorkspaceProxy(ctx, nc, serverMultiAgent)
if !xerrors.Is(err, io.EOF) {
assert.NoError(t, err)
if !assert.NoError(t, err) {
return
}
version := r.URL.Query().Get("version")
if !assert.Equal(t, version, agpl.CurrentVersion.String()) {
return
}
nc := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
err = cSrv.ServeMultiAgentClient(ctx, version, nc, proxyID)
serveMACErr <- err
})
r.Get("/api/v2/workspaceagents/{workspaceagent}/legacy", func(w http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, w, http.StatusOK, wsproxysdk.AgentIsLegacyResponse{
@ -188,51 +213,50 @@ func TestDialCoordinator(t *testing.T) {
client := wsproxysdk.New(u)
client.SDKClient.SetLogger(logger)
expected := []*agpl.Node{{
ID: 55,
AsOf: time.Unix(1689653252, 0),
Key: key.NewNode().Public(),
DiscoKey: key.NewDisco().Public(),
PreferredDERP: 0,
DERPLatency: map[string]float64{
"0": 1.0,
peerID := uuid.UUID{55}
peerNodeKey, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
peerDiscoKey, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
expected := &proto.CoordinateResponse{PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{
Id: peerID[:],
Node: &proto.Node{
Id: 55,
AsOf: timestamppb.New(time.Unix(1689653252, 0)),
Key: peerNodeKey[:],
Disco: string(peerDiscoKey),
PreferredDerp: 0,
DerpLatency: map[string]float64{
"0": 1.0,
},
DerpForcedWebsocket: map[int32]string{},
Addresses: []string{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128).String()},
AllowedIps: []string{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128).String()},
Endpoints: []string{"192.168.1.1:18842"},
},
DERPForcedWebsocket: map[int]string{},
Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128)},
AllowedIPs: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128)},
Endpoints: []string{"192.168.1.1:18842"},
}}
sendNode := make(chan struct{})
serverMultiAgent.EXPECT().NextUpdate(gomock.Any()).AnyTimes().
DoAndReturn(func(ctx context.Context) ([]*agpl.Node, bool) {
select {
case <-sendNode:
return expected, true
case <-ctx.Done():
return nil, false
}
})
}}}
rma, err := client.DialCoordinator(ctx)
require.NoError(t, err)
// Subscribe
{
ch := make(chan struct{})
serverMultiAgent.EXPECT().SubscribeAgent(agentID).Do(func(uuid.UUID) {
close(ch)
})
require.NoError(t, rma.SubscribeAgent(agentID))
waitOrCancel(ctx, t, ch)
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
}
// Read updated agent node
{
sendNode <- struct{}{}
got, ok := rma.NextUpdate(ctx)
resps <- expected
resp, ok := rma.NextUpdate(ctx)
assert.True(t, ok)
got[0].AsOf = got[0].AsOf.In(time.Local)
assert.Equal(t, *expected[0], *got[0])
updates := resp.GetPeerUpdates()
assert.Len(t, updates, 1)
eq, err := updates[0].GetNode().Equal(expected.GetPeerUpdates()[0].GetNode())
assert.NoError(t, err)
assert.True(t, eq)
}
// Check legacy
{
@ -241,45 +265,38 @@ func TestDialCoordinator(t *testing.T) {
}
// UpdateSelf
{
ch := make(chan struct{})
serverMultiAgent.EXPECT().UpdateSelf(gomock.Any()).Do(func(node *agpl.Node) {
node.AsOf = node.AsOf.In(time.Local)
assert.Equal(t, expected[0], node)
close(ch)
})
require.NoError(t, rma.UpdateSelf(expected[0]))
waitOrCancel(ctx, t, ch)
require.NoError(t, rma.UpdateSelf(expected.PeerUpdates[0].GetNode()))
req := testutil.RequireRecvCtx(ctx, t, reqs)
eq, err := req.GetUpdateSelf().GetNode().Equal(expected.PeerUpdates[0].GetNode())
require.NoError(t, err)
require.True(t, eq)
}
// Unsubscribe
{
ch := make(chan struct{})
serverMultiAgent.EXPECT().UnsubscribeAgent(agentID).Do(func(uuid.UUID) {
close(ch)
})
require.NoError(t, rma.UnsubscribeAgent(agentID))
waitOrCancel(ctx, t, ch)
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, agentID[:], req.GetRemoveTunnel().GetId())
}
// Close
{
ch := make(chan struct{})
serverMultiAgent.EXPECT().Close().Do(func() {
close(ch)
})
require.NoError(t, rma.Close())
waitOrCancel(ctx, t, ch)
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, req.Disconnect)
close(resps)
select {
case <-ctx.Done():
t.Fatal("timeout waiting for req close")
case _, ok := <-reqs:
require.False(t, ok, "didn't close requests")
}
require.Error(t, testutil.RequireRecvCtx(ctx, t, serveMACErr))
}
})
}
func waitOrCancel(ctx context.Context, t testing.TB, ch <-chan struct{}) {
t.Helper()
select {
case <-ch:
case <-ctx.Done():
t.Fatal("timed out waiting for channel")
}
}
type ResponseRecorder struct {
rw *httptest.ResponseRecorder
wasWritten atomic.Bool

View File

@@ -490,6 +490,18 @@ func (c *configMaps) protoNodeToTailcfg(p *proto.Node) (*tailcfg.Node, error) {
}, nil
}
// nodeAddresses returns the addresses for the peer with the given publicKey, if known.
func (c *configMaps) nodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
c.L.Lock()
defer c.L.Unlock()
for _, lc := range c.peers {
if lc.node.Key == publicKey {
return lc.node.Addresses, true
}
}
return nil, false
}
type peerLifecycle struct {
peerID uuid.UUID
node *tailcfg.Node

View File

@@ -3,48 +3,40 @@ package tailnet
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"net/http"
"net/netip"
"os"
"reflect"
"strconv"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"go4.org/netipx"
"golang.org/x/xerrors"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"tailscale.com/envknob"
"tailscale.com/ipn/ipnstate"
"tailscale.com/net/connstats"
"tailscale.com/net/dns"
"tailscale.com/net/netmon"
"tailscale.com/net/netns"
"tailscale.com/net/tsdial"
"tailscale.com/net/tstun"
"tailscale.com/tailcfg"
"tailscale.com/tsd"
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
tslogger "tailscale.com/types/logger"
"tailscale.com/types/netlogtype"
"tailscale.com/types/netmap"
"tailscale.com/wgengine"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/magicsock"
"tailscale.com/wgengine/netstack"
"tailscale.com/wgengine/router"
"tailscale.com/wgengine/wgcfg/nmcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/tailnet/proto"
)
var ErrConnClosed = xerrors.New("connection closed")
@@ -128,42 +120,6 @@ func NewConn(options *Options) (conn *Conn, err error) {
}
nodePrivateKey := key.NewNode()
nodePublicKey := nodePrivateKey.Public()
netMap := &netmap.NetworkMap{
DERPMap: options.DERPMap,
NodeKey: nodePublicKey,
PrivateKey: nodePrivateKey,
Addresses: options.Addresses,
PacketFilter: []filter.Match{{
// Allow any protocol!
IPProto: []ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, ipproto.SCTP},
// Allow traffic sourced from anywhere.
Srcs: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0),
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0),
},
// Allow traffic to route anywhere.
Dsts: []filter.NetPortRange{
{
Net: netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0),
Ports: filter.PortRange{
First: 0,
Last: 65535,
},
},
{
Net: netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0),
Ports: filter.PortRange{
First: 0,
Last: 65535,
},
},
},
Caps: []filter.CapMatch{},
}},
}
var nodeID tailcfg.NodeID
// If we're provided with a UUID, use it to populate our node ID.
@@ -177,14 +133,6 @@ func NewConn(options *Options) (conn *Conn, err error) {
nodeID = tailcfg.NodeID(uid)
}
// This is used by functions below to identify the node via key
netMap.SelfNode = &tailcfg.Node{
ID: nodeID,
Key: nodePublicKey,
Addresses: options.Addresses,
AllowedIPs: options.Addresses,
}
wireguardMonitor, err := netmon.New(Logger(options.Logger.Named("net.wgmonitor")))
if err != nil {
return nil, xerrors.Errorf("create wireguard link monitor: %w", err)
@@ -243,7 +191,6 @@ func NewConn(options *Options) (conn *Conn, err error) {
if err != nil {
return nil, xerrors.Errorf("set node private key: %w", err)
}
netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey()
netStack, err := netstack.Create(
Logger(options.Logger.Named("net.netstack")),
@@ -262,44 +209,46 @@ func NewConn(options *Options) (conn *Conn, err error) {
}
netStack.ProcessLocalIPs = true
wireguardEngine = wgengine.NewWatchdog(wireguardEngine)
wireguardEngine.SetDERPMap(options.DERPMap)
netMapCopy := *netMap
options.Logger.Debug(context.Background(), "updating network map")
wireguardEngine.SetNetworkMap(&netMapCopy)
localIPSet := netipx.IPSetBuilder{}
for _, addr := range netMap.Addresses {
localIPSet.AddPrefix(addr)
}
localIPs, _ := localIPSet.IPSet()
logIPSet := netipx.IPSetBuilder{}
logIPs, _ := logIPSet.IPSet()
wireguardEngine.SetFilter(filter.New(
netMap.PacketFilter,
localIPs,
logIPs,
cfgMaps := newConfigMaps(
options.Logger,
wireguardEngine,
nodeID,
nodePrivateKey,
magicConn.DiscoPublicKey(),
)
cfgMaps.setAddresses(options.Addresses)
cfgMaps.setDERPMap(DERPMapToProto(options.DERPMap))
cfgMaps.setBlockEndpoints(options.BlockEndpoints)
nodeUp := newNodeUpdater(
options.Logger,
nil,
Logger(options.Logger.Named("net.packet-filter")),
))
nodeID,
nodePrivateKey.Public(),
magicConn.DiscoPublicKey(),
)
nodeUp.setAddresses(options.Addresses)
nodeUp.setBlockEndpoints(options.BlockEndpoints)
wireguardEngine.SetStatusCallback(nodeUp.setStatus)
wireguardEngine.SetNetInfoCallback(nodeUp.setNetInfo)
magicConn.SetDERPForcedWebsocketCallback(nodeUp.setDERPForcedWebsocket)
server := &Conn{
blockEndpoints: options.BlockEndpoints,
derpForceWebSockets: options.DERPForceWebSockets,
closed: make(chan struct{}),
logger: options.Logger,
magicConn: magicConn,
dialer: dialer,
listeners: map[listenKey]*listener{},
peerMap: map[tailcfg.NodeID]*tailcfg.Node{},
lastDERPForcedWebSockets: map[int]string{},
tunDevice: sys.Tun.Get(),
netMap: netMap,
netStack: netStack,
wireguardMonitor: wireguardMonitor,
closed: make(chan struct{}),
logger: options.Logger,
magicConn: magicConn,
dialer: dialer,
listeners: map[listenKey]*listener{},
tunDevice: sys.Tun.Get(),
netStack: netStack,
wireguardMonitor: wireguardMonitor,
wireguardRouter: &router.Config{
LocalAddrs: netMap.Addresses,
LocalAddrs: options.Addresses,
},
wireguardEngine: wireguardEngine,
configMaps: cfgMaps,
nodeUpdater: nodeUp,
}
defer func() {
if err != nil {
@@ -307,52 +256,6 @@ func NewConn(options *Options) (conn *Conn, err error) {
}
}()
wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) {
server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.Error(err))
if err != nil {
return
}
server.lastMutex.Lock()
if s.AsOf.Before(server.lastStatus) {
// Don't process outdated status!
server.lastMutex.Unlock()
return
}
server.lastStatus = s.AsOf
if endpointsEqual(s.LocalAddrs, server.lastEndpoints) {
// No need to update the node if nothing changed!
server.lastMutex.Unlock()
return
}
server.lastEndpoints = append([]tailcfg.Endpoint{}, s.LocalAddrs...)
server.lastMutex.Unlock()
server.sendNode()
})
wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) {
server.logger.Debug(context.Background(), "netinfo callback", slog.F("netinfo", ni))
server.lastMutex.Lock()
if reflect.DeepEqual(server.lastNetInfo, ni) {
server.lastMutex.Unlock()
return
}
server.lastNetInfo = ni.Clone()
server.lastMutex.Unlock()
server.sendNode()
})
magicConn.SetDERPForcedWebsocketCallback(func(region int, reason string) {
server.logger.Debug(context.Background(), "derp forced websocket", slog.F("region", region), slog.F("reason", reason))
server.lastMutex.Lock()
if server.lastDERPForcedWebSockets[region] == reason {
server.lastMutex.Unlock()
return
}
server.lastDERPForcedWebSockets[region] = reason
server.lastMutex.Unlock()
server.sendNode()
})
netStack.GetTCPHandlerForFlow = server.forwardTCP
err = netStack.Start(nil)
@@ -389,16 +292,14 @@ func IPFromUUID(uid uuid.UUID) netip.Addr {
// Conn is an actively listening Wireguard connection.
type Conn struct {
mutex sync.Mutex
closed chan struct{}
logger slog.Logger
blockEndpoints bool
derpForceWebSockets bool
mutex sync.Mutex
closed chan struct{}
logger slog.Logger
dialer *tsdial.Dialer
tunDevice *tstun.Wrapper
peerMap map[tailcfg.NodeID]*tailcfg.Node
netMap *netmap.NetworkMap
configMaps *configMaps
nodeUpdater *nodeUpdater
netStack *netstack.Impl
magicConn *magicsock.Conn
wireguardMonitor *netmon.Monitor
@@ -406,17 +307,6 @@ type Conn struct {
wireguardEngine wgengine.Engine
listeners map[listenKey]*listener
lastMutex sync.Mutex
nodeSending bool
nodeChanged bool
// It's only possible to store these values via status functions,
// so the values must be stored for retrieval later on.
lastStatus time.Time
lastEndpoints []tailcfg.Endpoint
lastDERPForcedWebSockets map[int]string
lastNetInfo *tailcfg.NetInfo
nodeCallback func(node *Node)
trafficStats *connstats.Statistics
}
@@ -425,57 +315,30 @@ func (c *Conn) MagicsockSetDebugLoggingEnabled(enabled bool) {
}
func (c *Conn) SetAddresses(ips []netip.Prefix) error {
c.mutex.Lock()
defer c.mutex.Unlock()
c.netMap.Addresses = ips
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
err := c.reconfig()
if err != nil {
return xerrors.Errorf("reconfig: %w", err)
}
c.configMaps.setAddresses(ips)
c.nodeUpdater.setAddresses(ips)
return nil
}
func (c *Conn) Addresses() []netip.Prefix {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.netMap.Addresses
}
func (c *Conn) SetNodeCallback(callback func(node *Node)) {
c.lastMutex.Lock()
c.nodeCallback = callback
c.lastMutex.Unlock()
c.sendNode()
c.nodeUpdater.setCallback(callback)
}
// SetDERPMap updates the DERPMap of a connection.
func (c *Conn) SetDERPMap(derpMap *tailcfg.DERPMap) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.logger.Debug(context.Background(), "updating derp map", slog.F("derp_map", derpMap))
c.wireguardEngine.SetDERPMap(derpMap)
c.netMap.DERPMap = derpMap
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
c.configMaps.setDERPMap(DERPMapToProto(derpMap))
}
func (c *Conn) SetDERPForceWebSockets(v bool) {
c.logger.Info(context.Background(), "setting DERP Force Websockets", slog.F("force_derp_websockets", v))
c.magicConn.SetDERPForceWebsockets(v)
}
// SetBlockEndpoints sets whether or not to block P2P endpoints. This setting
// SetBlockEndpoints sets whether to block P2P endpoints. This setting
// will only apply to new peers.
func (c *Conn) SetBlockEndpoints(blockEndpoints bool) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.blockEndpoints = blockEndpoints
c.configMaps.setBlockEndpoints(blockEndpoints)
c.nodeUpdater.setBlockEndpoints(blockEndpoints)
}
// SetDERPRegionDialer updates the dialer to use for connecting to DERP regions.
@@ -483,186 +346,24 @@ func (c *Conn) SetDERPRegionDialer(dialer func(ctx context.Context, region *tail
c.magicConn.SetDERPRegionDialer(dialer)
}
// UpdateNodes connects with a set of peers. This can be constantly updated,
// and peers will continually be reconnected as necessary. If replacePeers is
// true, all peers will be removed before adding the new ones.
//
//nolint:revive // Complains about replacePeers.
func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// UpdatePeers connects with a set of peers. This can be constantly updated,
// and peers will continually be reconnected as necessary.
func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
if c.isClosed() {
return ErrConnClosed
}
status := c.Status()
if replacePeers {
c.netMap.Peers = []*tailcfg.Node{}
c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{}
}
for _, peer := range c.netMap.Peers {
peerStatus, ok := status.Peer[peer.Key]
if !ok {
continue
}
// If this peer was added in the last 5 minutes, assume it
// could still be active.
if time.Since(peer.Created) < 5*time.Minute {
continue
}
// We double-check that it's safe to remove by ensuring no
// handshake has been sent in the past 5 minutes as well. Connections that
// are actively exchanging IP traffic will handshake every 2 minutes.
if time.Since(peerStatus.LastHandshake) < 5*time.Minute {
continue
}
c.logger.Debug(context.Background(), "removing peer, last handshake >5m ago",
slog.F("peer", peer.Key), slog.F("last_handshake", peerStatus.LastHandshake),
)
delete(c.peerMap, peer.ID)
}
for _, node := range nodes {
// If no preferred DERP is provided, we can't reach the node.
if node.PreferredDERP == 0 {
c.logger.Debug(context.Background(), "no preferred DERP, skipping node", slog.F("node", node))
continue
}
c.logger.Debug(context.Background(), "adding node", slog.F("node", node))
peerStatus, ok := status.Peer[node.Key]
peerNode := &tailcfg.Node{
ID: node.ID,
Created: time.Now(),
Key: node.Key,
DiscoKey: node.DiscoKey,
Addresses: node.Addresses,
AllowedIPs: node.AllowedIPs,
Endpoints: node.Endpoints,
DERP: fmt.Sprintf("%s:%d", tailcfg.DerpMagicIP, node.PreferredDERP),
Hostinfo: (&tailcfg.Hostinfo{}).View(),
// Starting KeepAlive messages at the initialization of a connection
// causes a race condition. If we handshake before the peer has our
// node, we'll have to wait for 5 seconds before trying again. Ideally,
// the first handshake starts when the user first initiates a
// connection to the peer. After a successful connection we enable
// keep alives to persist the connection and keep it from becoming
// idle. SSH connections don't send packets while idle, so we
// use keep alives to avoid random hangs while we set up the
// connection again after inactivity.
KeepAlive: ok && peerStatus.Active,
}
if c.blockEndpoints {
peerNode.Endpoints = nil
}
c.peerMap[node.ID] = peerNode
}
c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
for _, peer := range c.peerMap {
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
}
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
err := c.reconfig()
if err != nil {
return xerrors.Errorf("reconfig: %w", err)
}
return nil
}
// PeerSelector is used to select a peer from within a Tailnet.
type PeerSelector struct {
ID tailcfg.NodeID
IP netip.Prefix
}
func (c *Conn) RemovePeer(selector PeerSelector) (deleted bool, err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.isClosed() {
return false, ErrConnClosed
}
deleted = false
for _, peer := range c.peerMap {
if peer.ID == selector.ID {
delete(c.peerMap, peer.ID)
deleted = true
break
}
for _, peerIP := range peer.Addresses {
if peerIP.Bits() == selector.IP.Bits() && peerIP.Addr().Compare(selector.IP.Addr()) == 0 {
delete(c.peerMap, peer.ID)
deleted = true
break
}
}
}
if !deleted {
return false, nil
}
c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
for _, peer := range c.peerMap {
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
}
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
err = c.reconfig()
if err != nil {
return false, xerrors.Errorf("reconfig: %w", err)
}
return true, nil
}
func (c *Conn) reconfig() error {
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("net.wgconfig")), netmap.AllowSingleHosts, "")
if err != nil {
return xerrors.Errorf("update wireguard config: %w", err)
}
err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{})
if err != nil {
if c.isClosed() {
return nil
}
if errors.Is(err, wgengine.ErrNoChanges) {
return nil
}
return xerrors.Errorf("reconfig: %w", err)
}
c.configMaps.updatePeers(updates)
return nil
}
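For reference, a caller that has received a coordination response feeds it straight into UpdatePeers. A minimal sketch under that assumption; the helper, connection, peer ID, and node values here are hypothetical and would normally come from a coordinator:

package example

import (
	"github.com/google/uuid"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
)

// applyPeer is a hypothetical helper that tells conn about a single peer's node.
func applyPeer(conn *tailnet.Conn, peerID uuid.UUID, node *proto.Node) error {
	return conn.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
		Id:   peerID[:],
		Node: node,
		Kind: proto.CoordinateResponse_PeerUpdate_NODE,
	}})
}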
// NodeAddresses returns the addresses of a node from the NetworkMap.
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
c.mutex.Lock()
defer c.mutex.Unlock()
for _, node := range c.netMap.Peers {
if node.Key == publicKey {
return node.Addresses, true
}
}
return nil, false
return c.configMaps.nodeAddresses(publicKey)
}
// Status returns the current ipnstate of a connection.
func (c *Conn) Status() *ipnstate.Status {
sb := &ipnstate.StatusBuilder{WantPeers: true}
c.wireguardEngine.UpdateStatus(sb)
return sb.Status()
return c.configMaps.status()
}
// Ping sends a ping to the Wireguard engine.
@@ -689,16 +390,9 @@ func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, bool, *i
// DERPMap returns the currently set DERP mapping.
func (c *Conn) DERPMap() *tailcfg.DERPMap {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.netMap.DERPMap
}
// BlockEndpoints returns whether or not P2P is blocked.
func (c *Conn) BlockEndpoints() bool {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.blockEndpoints
c.configMaps.L.Lock()
defer c.configMaps.L.Unlock()
return c.configMaps.derpMapLocked()
}
// AwaitReachable pings the provided IP continually until the
@@ -759,6 +453,9 @@ func (c *Conn) Closed() <-chan struct{} {
// Close shuts down the Wireguard connection.
func (c *Conn) Close() error {
c.logger.Info(context.Background(), "closing tailnet Conn")
c.configMaps.close()
c.nodeUpdater.close()
c.mutex.Lock()
select {
case <-c.closed:
@@ -808,91 +505,11 @@ func (c *Conn) isClosed() bool {
}
}
func (c *Conn) sendNode() {
c.lastMutex.Lock()
defer c.lastMutex.Unlock()
if c.nodeSending {
c.nodeChanged = true
return
}
node := c.selfNode()
// Conn.UpdateNodes will skip any nodes that don't have the PreferredDERP
// set to non-zero, since we cannot reach nodes without DERP for discovery.
// Therefore, there is no point in sending the node without this, and we can
// save ourselves from churn in the tailscale/wireguard layer.
if node.PreferredDERP == 0 {
c.logger.Debug(context.Background(), "skipped sending node; no PreferredDERP", slog.F("node", node))
return
}
nodeCallback := c.nodeCallback
if nodeCallback == nil {
return
}
c.nodeSending = true
go func() {
c.logger.Debug(context.Background(), "sending node", slog.F("node", node))
nodeCallback(node)
c.lastMutex.Lock()
c.nodeSending = false
if c.nodeChanged {
c.nodeChanged = false
c.lastMutex.Unlock()
c.sendNode()
return
}
c.lastMutex.Unlock()
}()
}
// Node returns the last node that was sent to the node callback.
func (c *Conn) Node() *Node {
c.lastMutex.Lock()
defer c.lastMutex.Unlock()
return c.selfNode()
}
func (c *Conn) selfNode() *Node {
endpoints := make([]string, 0, len(c.lastEndpoints))
for _, addr := range c.lastEndpoints {
endpoints = append(endpoints, addr.Addr.String())
}
var preferredDERP int
var derpLatency map[string]float64
derpForcedWebsocket := make(map[int]string, 0)
if c.lastNetInfo != nil {
preferredDERP = c.lastNetInfo.PreferredDERP
derpLatency = c.lastNetInfo.DERPLatency
if c.derpForceWebSockets {
// We only need to store this for a single region, since this is
// mostly used for debugging purposes and doesn't actually have a
// code purpose.
derpForcedWebsocket[preferredDERP] = "DERP is configured to always fallback to WebSockets"
} else {
for k, v := range c.lastDERPForcedWebSockets {
derpForcedWebsocket[k] = v
}
}
}
node := &Node{
ID: c.netMap.SelfNode.ID,
AsOf: dbtime.Now(),
Key: c.netMap.SelfNode.Key,
Addresses: c.netMap.SelfNode.Addresses,
AllowedIPs: c.netMap.SelfNode.AllowedIPs,
DiscoKey: c.magicConn.DiscoPublicKey(),
Endpoints: endpoints,
PreferredDERP: preferredDERP,
DERPLatency: derpLatency,
DERPForcedWebsocket: derpForcedWebsocket,
}
c.mutex.Lock()
if c.blockEndpoints {
node.Endpoints = nil
}
c.mutex.Unlock()
return node
c.nodeUpdater.L.Lock()
defer c.nodeUpdater.L.Unlock()
return c.nodeUpdater.nodeLocked()
}
// This and below is taken _mostly_ verbatim from Tailscale:
@@ -1056,15 +673,3 @@ func Logger(logger slog.Logger) tslogger.Logf {
logger.Debug(context.Background(), fmt.Sprintf(format, args...))
})
}
func endpointsEqual(x, y []tailcfg.Endpoint) bool {
if len(x) != len(y) {
return false
}
for i := range x {
if x[i] != y[i] {
return false
}
}
return true
}

View File

@@ -5,6 +5,7 @@ import (
"net/netip"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
@@ -12,6 +13,7 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
@@ -22,10 +24,10 @@ func TestMain(m *testing.M) {
func TestTailnet(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
t.Run("InstantClose", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
Logger: logger.Named("w1"),
@@ -37,6 +39,8 @@ func TestTailnet(t *testing.T) {
})
t.Run("Connect", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitLong)
w1IP := tailnet.IP()
w1, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)},
@@ -55,14 +59,8 @@ func TestTailnet(t *testing.T) {
_ = w1.Close()
_ = w2.Close()
})
w1.SetNodeCallback(func(node *tailnet.Node) {
err := w2.UpdateNodes([]*tailnet.Node{node}, false)
assert.NoError(t, err)
})
w2.SetNodeCallback(func(node *tailnet.Node) {
err := w1.UpdateNodes([]*tailnet.Node{node}, false)
assert.NoError(t, err)
})
stitch(t, w2, w1)
stitch(t, w1, w2)
require.True(t, w2.AwaitReachable(context.Background(), w1IP))
conn := make(chan struct{}, 1)
go func() {
@@ -89,7 +87,7 @@ func TestTailnet(t *testing.T) {
default:
}
})
node := <-nodes
node := testutil.RequireRecvCtx(ctx, t, nodes)
// Ensure this connected over DERP!
require.Len(t, node.DERPForcedWebsocket, 0)
@@ -99,6 +97,7 @@ func TestTailnet(t *testing.T) {
t.Run("ForcesWebSockets", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
w1IP := tailnet.IP()
@@ -122,14 +121,8 @@ func TestTailnet(t *testing.T) {
_ = w1.Close()
_ = w2.Close()
})
w1.SetNodeCallback(func(node *tailnet.Node) {
err := w2.UpdateNodes([]*tailnet.Node{node}, false)
assert.NoError(t, err)
})
w2.SetNodeCallback(func(node *tailnet.Node) {
err := w1.UpdateNodes([]*tailnet.Node{node}, false)
assert.NoError(t, err)
})
stitch(t, w2, w1)
stitch(t, w1, w2)
require.True(t, w2.AwaitReachable(ctx, w1IP))
conn := make(chan struct{}, 1)
go func() {
@@ -243,11 +236,16 @@ func TestConn_UpdateDERP(t *testing.T) {
err := client1.Close()
assert.NoError(t, err)
}()
client1.SetNodeCallback(func(node *tailnet.Node) {
err := conn.UpdateNodes([]*tailnet.Node{node}, false)
assert.NoError(t, err)
})
client1.UpdateNodes([]*tailnet.Node{conn.Node()}, false)
stitch(t, conn, client1)
pn, err := tailnet.NodeToProto(conn.Node())
require.NoError(t, err)
connID := uuid.New()
err = client1.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
Id: connID[:],
Node: pn,
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
}})
require.NoError(t, err)
awaitReachableCtx1, awaitReachableCancel1 := context.WithTimeout(context.Background(), testutil.WaitShort)
defer awaitReachableCancel1()
@@ -288,7 +286,13 @@ parentLoop:
// ... unless the client updates its DERP map and nodes.
client1.SetDERPMap(derpMap2)
client1.UpdateNodes([]*tailnet.Node{conn.Node()}, false)
pn, err = tailnet.NodeToProto(conn.Node())
require.NoError(t, err)
client1.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
Id: connID[:],
Node: pn,
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
}})
awaitReachableCtx3, awaitReachableCancel3 := context.WithTimeout(context.Background(), testutil.WaitShort)
defer awaitReachableCancel3()
require.True(t, client1.AwaitReachable(awaitReachableCtx3, ip))
@@ -306,13 +310,34 @@ parentLoop:
err := client2.Close()
assert.NoError(t, err)
}()
client2.SetNodeCallback(func(node *tailnet.Node) {
err := conn.UpdateNodes([]*tailnet.Node{node}, false)
assert.NoError(t, err)
})
client2.UpdateNodes([]*tailnet.Node{conn.Node()}, false)
stitch(t, conn, client2)
pn, err = tailnet.NodeToProto(conn.Node())
require.NoError(t, err)
client2.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
Id: connID[:],
Node: pn,
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
}})
awaitReachableCtx4, awaitReachableCancel4 := context.WithTimeout(context.Background(), testutil.WaitShort)
defer awaitReachableCancel4()
require.True(t, client2.AwaitReachable(awaitReachableCtx4, ip))
}
// stitch sends node updates from src Conn as peer updates to dst Conn. Sort of
// like the Coordinator would, but without actually needing a Coordinator.
func stitch(t *testing.T, dst, src *tailnet.Conn) {
srcID := uuid.New()
src.SetNodeCallback(func(node *tailnet.Node) {
pn, err := tailnet.NodeToProto(node)
if !assert.NoError(t, err) {
return
}
err = dst.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
Id: srcID[:],
Node: pn,
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
}})
assert.NoError(t, err)
})
}

View File

@@ -3,6 +3,7 @@ package tailnet
import (
"context"
"encoding/json"
"fmt"
"html/template"
"io"
"net"
@@ -92,6 +93,237 @@ type Node struct {
Endpoints []string `json:"endpoints"`
}
// Coordinatee is something that can be coordinated over the Coordinate protocol. Usually this is a
// Conn.
type Coordinatee interface {
UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error
SetNodeCallback(func(*Node))
}
type Coordination interface {
io.Closer
Error() <-chan error
}
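A Coordinatee only needs the two methods above, so tests can substitute a trivial fake for a full Conn. A minimal sketch, written as if it lived in this package; fakeCoordinatee is hypothetical and not part of this change:

// fakeCoordinatee records the peer updates and node callback it is given.
type fakeCoordinatee struct {
	sync.Mutex
	updates  [][]*proto.CoordinateResponse_PeerUpdate
	callback func(*Node)
}

func (f *fakeCoordinatee) UpdatePeers(u []*proto.CoordinateResponse_PeerUpdate) error {
	f.Lock()
	defer f.Unlock()
	f.updates = append(f.updates, u)
	return nil
}

func (f *fakeCoordinatee) SetNodeCallback(cb func(*Node)) {
	f.Lock()
	defer f.Unlock()
	f.callback = cb
}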
type remoteCoordination struct {
sync.Mutex
closed bool
errChan chan error
coordinatee Coordinatee
logger slog.Logger
protocol proto.DRPCTailnet_CoordinateClient
}
func (c *remoteCoordination) Close() error {
c.Lock()
defer c.Unlock()
if c.closed {
return nil
}
c.closed = true
err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
if err != nil {
return xerrors.Errorf("send disconnect: %w", err)
}
return nil
}
func (c *remoteCoordination) Error() <-chan error {
return c.errChan
}
func (c *remoteCoordination) sendErr(err error) {
select {
case c.errChan <- err:
default:
}
}
func (c *remoteCoordination) respLoop() {
for {
resp, err := c.protocol.Recv()
if err != nil {
c.sendErr(xerrors.Errorf("read: %w", err))
return
}
err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
c.sendErr(xerrors.Errorf("update peers: %w", err))
return
}
}
}
// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
// Conn). If tunnelTarget is not uuid.Nil, we add a tunnel to that peer (i.e. we are acting as a
// client; agents should NOT set this).
func NewRemoteCoordination(logger slog.Logger,
protocol proto.DRPCTailnet_CoordinateClient, coordinatee Coordinatee,
tunnelTarget uuid.UUID,
) Coordination {
c := &remoteCoordination{
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
protocol: protocol,
}
if tunnelTarget != uuid.Nil {
c.Lock()
err := c.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tunnelTarget[:]}})
c.Unlock()
if err != nil {
c.sendErr(err)
}
}
coordinatee.SetNodeCallback(func(node *Node) {
pn, err := NodeToProto(node)
if err != nil {
c.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
c.sendErr(err)
return
}
c.Lock()
defer c.Unlock()
if c.closed {
c.logger.Debug(context.Background(), "ignored node update because coordination is closed")
return
}
err = c.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}})
if err != nil {
c.sendErr(xerrors.Errorf("write: %w", err))
}
})
go c.respLoop()
return c
}
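Putting it together, a client-side caller typically opens a Coordinate stream over dRPC, hands it to NewRemoteCoordination, and then watches the Error channel. A rough sketch under those assumptions; runCoordination and its parameters are hypothetical:

package example

import (
	"context"

	"github.com/google/uuid"
	"storj.io/drpc"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
)

// runCoordination wires an existing tailnet.Conn to a coordinator over an
// already-established dRPC connection and blocks until coordination fails or
// the context is canceled.
func runCoordination(ctx context.Context, logger slog.Logger, dc drpc.Conn, conn *tailnet.Conn, agentID uuid.UUID) error {
	client := proto.NewDRPCTailnetClient(dc)
	stream, err := client.Coordinate(ctx)
	if err != nil {
		return err
	}
	// agentID is the tunnel target because this caller acts as a client; an
	// agent would pass uuid.Nil instead.
	coordination := tailnet.NewRemoteCoordination(logger, stream, conn, agentID)
	defer coordination.Close()
	select {
	case err := <-coordination.Error():
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}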
type inMemoryCoordination struct {
sync.Mutex
ctx context.Context
errChan chan error
closed bool
closedCh chan struct{}
coordinatee Coordinatee
logger slog.Logger
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
}
func (c *inMemoryCoordination) sendErr(err error) {
select {
case c.errChan <- err:
default:
}
}
func (c *inMemoryCoordination) Error() <-chan error {
return c.errChan
}
// NewInMemoryCoordination connects a Coordinatee (usually a Conn) to an in-memory Coordinator, for testing
// or local clients. Set clientID to uuid.Nil for an agent.
func NewInMemoryCoordination(
ctx context.Context, logger slog.Logger,
clientID, agentID uuid.UUID,
coordinator Coordinator, coordinatee Coordinatee,
) Coordination {
thisID := agentID
logger = logger.With(slog.F("agent_id", agentID))
var auth TunnelAuth = AgentTunnelAuth{}
if clientID != uuid.Nil {
// this is a client connection
auth = ClientTunnelAuth{AgentID: agentID}
logger = logger.With(slog.F("client_id", clientID))
thisID = clientID
}
c := &inMemoryCoordination{
ctx: ctx,
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
closedCh: make(chan struct{}),
}
// use the background context since we will depend exclusively on closing the req channel to
// tell the coordinator we are done.
c.reqs, c.resps = coordinator.Coordinate(context.Background(),
thisID, fmt.Sprintf("inmemory%s", thisID),
auth,
)
go c.respLoop()
if agentID != uuid.Nil {
select {
case <-ctx.Done():
c.logger.Warn(ctx, "context expired before we could add tunnel", slog.Error(ctx.Err()))
return c
case c.reqs <- &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}}:
// OK!
}
}
coordinatee.SetNodeCallback(func(n *Node) {
pn, err := NodeToProto(n)
if err != nil {
c.logger.Critical(ctx, "failed to convert node", slog.Error(err))
c.sendErr(err)
return
}
c.Lock()
defer c.Unlock()
if c.closed {
return
}
select {
case <-ctx.Done():
c.logger.Info(ctx, "context expired before sending node update")
return
case c.reqs <- &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}:
c.logger.Debug(ctx, "sent node in-memory to coordinator")
}
})
return c
}
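In tests, the same Conn can be pointed at an in-memory Coordinator with no networking involved. A short sketch with the same imports as the previous sketch; connectAgentInMemory and the Coordinator it receives are hypothetical:

// connectAgentInMemory coordinates an agent-side Conn against coord until the
// context is canceled or coordination fails.
func connectAgentInMemory(ctx context.Context, logger slog.Logger, coord tailnet.Coordinator, conn *tailnet.Conn, agentID uuid.UUID) error {
	// clientID is uuid.Nil because this Conn plays the agent role.
	coordination := tailnet.NewInMemoryCoordination(ctx, logger, uuid.Nil, agentID, coord, conn)
	defer coordination.Close()
	select {
	case err := <-coordination.Error():
		return err
	case <-ctx.Done():
		return nil
	}
}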
func (c *inMemoryCoordination) respLoop() {
for {
select {
case <-c.closedCh:
c.logger.Debug(context.Background(), "in-memory coordination closed")
return
case resp, ok := <-c.resps:
if !ok {
c.logger.Debug(context.Background(), "in-memory response channel closed")
return
}
c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
return
}
}
}
}
func (c *inMemoryCoordination) Close() error {
c.Lock()
defer c.Unlock()
c.logger.Debug(context.Background(), "closing in-memory coordination")
if c.closed {
return nil
}
defer close(c.reqs)
c.closed = true
close(c.closedCh)
select {
case <-c.ctx.Done():
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}:
c.logger.Debug(context.Background(), "sent graceful disconnect in-memory")
return nil
}
}
// ServeCoordinator matches the RW structure of a coordinator to exchange node messages.
func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func(node *Node), <-chan error) {
errChan := make(chan error, 1)
@@ -237,21 +469,17 @@ func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAge
}
return false
},
OnSubscribe: func(enq Queue, agent uuid.UUID) (*Node, error) {
OnSubscribe: func(enq Queue, agent uuid.UUID) error {
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}})
return c.Node(agent), err
return err
},
OnUnsubscribe: func(enq Queue, agent uuid.UUID) error {
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}})
return err
},
OnNodeUpdate: func(id uuid.UUID, node *Node) error {
pn, err := NodeToProto(node)
if err != nil {
return err
}
OnNodeUpdate: func(id uuid.UUID, node *proto.Node) error {
return SendCtx(ctx, reqs, &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
Node: pn,
Node: node,
}})
},
OnRemove: func(_ Queue) {
@@ -285,7 +513,7 @@ const (
type Queue interface {
UniqueID() uuid.UUID
Kind() QueueKind
Enqueue(n []*Node) error
Enqueue(resp *proto.CoordinateResponse) error
Name() string
Stats() (start, lastWrite int64)
Overwrites() int64
@@ -793,18 +1021,7 @@ func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logg
return
}
logger.Debug(ctx, "v1RespLoop got response", slog.F("resp", resp))
nodes, err := OnlyNodeUpdates(resp)
if err != nil {
logger.Critical(ctx, "v1RespLoop failed to decode resp", slog.F("resp", resp), slog.Error(err))
_ = q.CoordinatorClose()
return
}
// don't send empty updates
if len(nodes) == 0 {
logger.Debug(ctx, "v1RespLoop skipping enqueueing 0-length v1 update")
continue
}
err = q.Enqueue(nodes)
err = q.Enqueue(resp)
if err != nil && !xerrors.Is(err, context.Canceled) {
logger.Error(ctx, "v1RespLoop failed to enqueue v1 update", slog.Error(err))
}

View File

@@ -8,13 +8,15 @@ import (
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/tailnet/proto"
)
type MultiAgentConn interface {
UpdateSelf(node *Node) error
UpdateSelf(node *proto.Node) error
SubscribeAgent(agentID uuid.UUID) error
UnsubscribeAgent(agentID uuid.UUID) error
NextUpdate(ctx context.Context) ([]*Node, bool)
NextUpdate(ctx context.Context) (*proto.CoordinateResponse, bool)
AgentIsLegacy(agentID uuid.UUID) bool
Close() error
IsClosed() bool
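Consumers of MultiAgentConn now receive protocol responses instead of []*Node, so a typical read loop forwards them straight into a Conn. A minimal sketch; relayMultiAgent is hypothetical and reuses the imports from the sketches above:

// relayMultiAgent pumps coordination updates from ma into conn until ma is
// closed or the context is canceled.
func relayMultiAgent(ctx context.Context, ma tailnet.MultiAgentConn, conn *tailnet.Conn) error {
	for {
		resp, ok := ma.NextUpdate(ctx)
		if !ok {
			// NextUpdate reports false when the context is done or the
			// MultiAgent has been closed.
			return ctx.Err()
		}
		if err := conn.UpdatePeers(resp.GetPeerUpdates()); err != nil {
			return err
		}
	}
}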
@@ -26,16 +28,16 @@ type MultiAgent struct {
ID uuid.UUID
AgentIsLegacyFunc func(agentID uuid.UUID) bool
OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error)
OnSubscribe func(enq Queue, agent uuid.UUID) error
OnUnsubscribe func(enq Queue, agent uuid.UUID) error
OnNodeUpdate func(id uuid.UUID, node *Node) error
OnNodeUpdate func(id uuid.UUID, node *proto.Node) error
OnRemove func(enq Queue)
ctx context.Context
ctxCancel func()
closed bool
updates chan []*Node
updates chan *proto.CoordinateResponse
closeOnce sync.Once
start int64
lastWrite int64
@@ -45,7 +47,7 @@ type MultiAgent struct {
}
func (m *MultiAgent) Init() *MultiAgent {
m.updates = make(chan []*Node, 128)
m.updates = make(chan *proto.CoordinateResponse, 128)
m.start = time.Now().Unix()
m.ctx, m.ctxCancel = context.WithCancel(context.Background())
return m
@@ -65,7 +67,7 @@ func (m *MultiAgent) AgentIsLegacy(agentID uuid.UUID) bool {
var ErrMultiAgentClosed = xerrors.New("multiagent is closed")
func (m *MultiAgent) UpdateSelf(node *Node) error {
func (m *MultiAgent) UpdateSelf(node *proto.Node) error {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
@@ -82,15 +84,11 @@ func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) error {
return ErrMultiAgentClosed
}
node, err := m.OnSubscribe(m, agentID)
err := m.OnSubscribe(m, agentID)
if err != nil {
return err
}
if node != nil {
return m.enqueueLocked([]*Node{node})
}
return nil
}
@@ -104,17 +102,17 @@ func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) error {
return m.OnUnsubscribe(m, agentID)
}
func (m *MultiAgent) NextUpdate(ctx context.Context) ([]*Node, bool) {
func (m *MultiAgent) NextUpdate(ctx context.Context) (*proto.CoordinateResponse, bool) {
select {
case <-ctx.Done():
return nil, false
case nodes, ok := <-m.updates:
return nodes, ok
case resp, ok := <-m.updates:
return resp, ok
}
}
func (m *MultiAgent) Enqueue(nodes []*Node) error {
func (m *MultiAgent) Enqueue(resp *proto.CoordinateResponse) error {
m.mu.RLock()
defer m.mu.RUnlock()
@@ -122,14 +120,14 @@ func (m *MultiAgent) Enqueue(nodes []*Node) error {
return nil
}
return m.enqueueLocked(nodes)
return m.enqueueLocked(resp)
}
func (m *MultiAgent) enqueueLocked(nodes []*Node) error {
func (m *MultiAgent) enqueueLocked(resp *proto.CoordinateResponse) error {
atomic.StoreInt64(&m.lastWrite, time.Now().Unix())
select {
case m.updates <- nodes:
case m.updates <- resp:
return nil
default:
return ErrWouldBlock

View File

@@ -75,7 +75,9 @@ func NewClientService(
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
if xerrors.Is(err, io.EOF) ||
xerrors.Is(err, context.Canceled) ||
xerrors.Is(err, context.DeadlineExceeded) {
return
}
logger.Debug(context.Background(), "drpc server error", slog.Error(err))

View File

@@ -0,0 +1,142 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/v2/tailnet (interfaces: Coordinator)
//
// Generated by this command:
//
// mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator
//
// Package tailnettest is a generated GoMock package.
package tailnettest
import (
context "context"
net "net"
http "net/http"
reflect "reflect"
tailnet "github.com/coder/coder/v2/tailnet"
proto "github.com/coder/coder/v2/tailnet/proto"
uuid "github.com/google/uuid"
gomock "go.uber.org/mock/gomock"
)
// MockCoordinator is a mock of Coordinator interface.
type MockCoordinator struct {
ctrl *gomock.Controller
recorder *MockCoordinatorMockRecorder
}
// MockCoordinatorMockRecorder is the mock recorder for MockCoordinator.
type MockCoordinatorMockRecorder struct {
mock *MockCoordinator
}
// NewMockCoordinator creates a new mock instance.
func NewMockCoordinator(ctrl *gomock.Controller) *MockCoordinator {
mock := &MockCoordinator{ctrl: ctrl}
mock.recorder = &MockCoordinatorMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockCoordinator) EXPECT() *MockCoordinatorMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockCoordinator) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockCoordinatorMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCoordinator)(nil).Close))
}
// Coordinate mocks base method.
func (m *MockCoordinator) Coordinate(arg0 context.Context, arg1 uuid.UUID, arg2 string, arg3 tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Coordinate", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(chan<- *proto.CoordinateRequest)
ret1, _ := ret[1].(<-chan *proto.CoordinateResponse)
return ret0, ret1
}
// Coordinate indicates an expected call of Coordinate.
func (mr *MockCoordinatorMockRecorder) Coordinate(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Coordinate", reflect.TypeOf((*MockCoordinator)(nil).Coordinate), arg0, arg1, arg2, arg3)
}
// Node mocks base method.
func (m *MockCoordinator) Node(arg0 uuid.UUID) *tailnet.Node {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Node", arg0)
ret0, _ := ret[0].(*tailnet.Node)
return ret0
}
// Node indicates an expected call of Node.
func (mr *MockCoordinatorMockRecorder) Node(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Node", reflect.TypeOf((*MockCoordinator)(nil).Node), arg0)
}
// ServeAgent mocks base method.
func (m *MockCoordinator) ServeAgent(arg0 net.Conn, arg1 uuid.UUID, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ServeAgent", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// ServeAgent indicates an expected call of ServeAgent.
func (mr *MockCoordinatorMockRecorder) ServeAgent(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeAgent), arg0, arg1, arg2)
}
// ServeClient mocks base method.
func (m *MockCoordinator) ServeClient(arg0 net.Conn, arg1, arg2 uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ServeClient", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// ServeClient indicates an expected call of ServeClient.
func (mr *MockCoordinatorMockRecorder) ServeClient(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeClient", reflect.TypeOf((*MockCoordinator)(nil).ServeClient), arg0, arg1, arg2)
}
// ServeHTTPDebug mocks base method.
func (m *MockCoordinator) ServeHTTPDebug(arg0 http.ResponseWriter, arg1 *http.Request) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ServeHTTPDebug", arg0, arg1)
}
// ServeHTTPDebug indicates an expected call of ServeHTTPDebug.
func (mr *MockCoordinatorMockRecorder) ServeHTTPDebug(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeHTTPDebug", reflect.TypeOf((*MockCoordinator)(nil).ServeHTTPDebug), arg0, arg1)
}
// ServeMultiAgent mocks base method.
func (m *MockCoordinator) ServeMultiAgent(arg0 uuid.UUID) tailnet.MultiAgentConn {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ServeMultiAgent", arg0)
ret0, _ := ret[0].(tailnet.MultiAgentConn)
return ret0
}
// ServeMultiAgent indicates an expected call of ServeMultiAgent.
func (mr *MockCoordinatorMockRecorder) ServeMultiAgent(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeMultiAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeMultiAgent), arg0)
}
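The generated mock returns plain channels from Coordinate, which tests can drive directly (as in the test near the top of this section). A condensed sketch of that wiring; the helper and its scaffolding are hypothetical, and gomock accepts the bidirectional channels because they are assignable to the directional return types:

package example

import (
	"testing"

	"go.uber.org/mock/gomock"

	"github.com/coder/coder/v2/tailnet/proto"
	"github.com/coder/coder/v2/tailnet/tailnettest"
)

// newMockCoordinate returns channels the test can use to play the coordinator
// for a single expected Coordinate call.
func newMockCoordinate(t *testing.T) (chan *proto.CoordinateRequest, chan *proto.CoordinateResponse, *tailnettest.MockCoordinator) {
	ctrl := gomock.NewController(t)
	mCoord := tailnettest.NewMockCoordinator(ctrl)
	reqs := make(chan *proto.CoordinateRequest, 100)
	resps := make(chan *proto.CoordinateResponse, 100)
	mCoord.EXPECT().
		Coordinate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
		Times(1).
		Return(reqs, resps)
	return reqs, resps, mCoord
}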

View File

@@ -1,141 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/v2/tailnet (interfaces: MultiAgentConn)
//
// Generated by this command:
//
// mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn
//
// Package tailnettest is a generated GoMock package.
package tailnettest
import (
context "context"
reflect "reflect"
tailnet "github.com/coder/coder/v2/tailnet"
uuid "github.com/google/uuid"
gomock "go.uber.org/mock/gomock"
)
// MockMultiAgentConn is a mock of MultiAgentConn interface.
type MockMultiAgentConn struct {
ctrl *gomock.Controller
recorder *MockMultiAgentConnMockRecorder
}
// MockMultiAgentConnMockRecorder is the mock recorder for MockMultiAgentConn.
type MockMultiAgentConnMockRecorder struct {
mock *MockMultiAgentConn
}
// NewMockMultiAgentConn creates a new mock instance.
func NewMockMultiAgentConn(ctrl *gomock.Controller) *MockMultiAgentConn {
mock := &MockMultiAgentConn{ctrl: ctrl}
mock.recorder = &MockMultiAgentConnMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockMultiAgentConn) EXPECT() *MockMultiAgentConnMockRecorder {
return m.recorder
}
// AgentIsLegacy mocks base method.
func (m *MockMultiAgentConn) AgentIsLegacy(arg0 uuid.UUID) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AgentIsLegacy", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// AgentIsLegacy indicates an expected call of AgentIsLegacy.
func (mr *MockMultiAgentConnMockRecorder) AgentIsLegacy(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AgentIsLegacy", reflect.TypeOf((*MockMultiAgentConn)(nil).AgentIsLegacy), arg0)
}
// Close mocks base method.
func (m *MockMultiAgentConn) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockMultiAgentConnMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMultiAgentConn)(nil).Close))
}
// IsClosed mocks base method.
func (m *MockMultiAgentConn) IsClosed() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClosed")
ret0, _ := ret[0].(bool)
return ret0
}
// IsClosed indicates an expected call of IsClosed.
func (mr *MockMultiAgentConnMockRecorder) IsClosed() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockMultiAgentConn)(nil).IsClosed))
}
// NextUpdate mocks base method.
func (m *MockMultiAgentConn) NextUpdate(arg0 context.Context) ([]*tailnet.Node, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NextUpdate", arg0)
ret0, _ := ret[0].([]*tailnet.Node)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// NextUpdate indicates an expected call of NextUpdate.
func (mr *MockMultiAgentConnMockRecorder) NextUpdate(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextUpdate", reflect.TypeOf((*MockMultiAgentConn)(nil).NextUpdate), arg0)
}
// SubscribeAgent mocks base method.
func (m *MockMultiAgentConn) SubscribeAgent(arg0 uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SubscribeAgent", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SubscribeAgent indicates an expected call of SubscribeAgent.
func (mr *MockMultiAgentConnMockRecorder) SubscribeAgent(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).SubscribeAgent), arg0)
}
// UnsubscribeAgent mocks base method.
func (m *MockMultiAgentConn) UnsubscribeAgent(arg0 uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnsubscribeAgent", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// UnsubscribeAgent indicates an expected call of UnsubscribeAgent.
func (mr *MockMultiAgentConnMockRecorder) UnsubscribeAgent(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).UnsubscribeAgent), arg0)
}
// UpdateSelf mocks base method.
func (m *MockMultiAgentConn) UpdateSelf(arg0 *tailnet.Node) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateSelf", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateSelf indicates an expected call of UpdateSelf.
func (mr *MockMultiAgentConnMockRecorder) UpdateSelf(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSelf", reflect.TypeOf((*MockMultiAgentConn)(nil).UpdateSelf), arg0)
}

View File

@@ -21,7 +21,7 @@ import (
"github.com/coder/coder/v2/tailnet"
)
//go:generate mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn
//go:generate mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator
// RunDERPAndSTUN creates a DERP mapping for tests.
func RunDERPAndSTUN(t *testing.T) (*tailcfg.DERPMap, *derp.Server) {

View File

@@ -11,6 +11,7 @@ import (
"github.com/google/uuid"
"cdr.dev/slog"
"github.com/coder/coder/v2/tailnet/proto"
)
const (
@@ -29,7 +30,7 @@ type TrackedConn struct {
cancel func()
kind QueueKind
conn net.Conn
updates chan []*Node
updates chan *proto.CoordinateResponse
logger slog.Logger
lastData []byte
@@ -55,7 +56,7 @@ func NewTrackedConn(ctx context.Context, cancel func(),
// coordinator mutex while queuing. Node updates don't
// come quickly, so 512 should be plenty for all but
// the most pathological cases.
updates := make(chan []*Node, ResponseBufferSize)
updates := make(chan *proto.CoordinateResponse, ResponseBufferSize)
now := time.Now().Unix()
return &TrackedConn{
ctx: ctx,
@@ -72,10 +73,10 @@ func NewTrackedConn(ctx context.Context, cancel func(),
}
}
func (t *TrackedConn) Enqueue(n []*Node) (err error) {
func (t *TrackedConn) Enqueue(resp *proto.CoordinateResponse) (err error) {
atomic.StoreInt64(&t.lastWrite, time.Now().Unix())
select {
case t.updates <- n:
case t.updates <- resp:
return nil
default:
return ErrWouldBlock
@@ -124,7 +125,16 @@ func (t *TrackedConn) SendUpdates() {
case <-t.ctx.Done():
t.logger.Debug(t.ctx, "done sending updates")
return
case nodes := <-t.updates:
case resp := <-t.updates:
nodes, err := OnlyNodeUpdates(resp)
if err != nil {
t.logger.Critical(t.ctx, "unable to parse response", slog.Error(err))
return
}
if len(nodes) == 0 {
t.logger.Debug(t.ctx, "skipping response with no nodes")
continue
}
data, err := json.Marshal(nodes)
if err != nil {
t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes))