feat: use tailnet v2 API for coordination (#11638)

This one is huge, and I'm sorry.

The problem is that once I change `tailnet.Conn` to start doing v2 behavior, I have to change it pretty much everywhere at once, including in CoderSDK (the CLI), the agent, wsproxy, and ServerTailnet.

There is still a bit more cleanup to do, and I still need to add code so that when we lose the connection to the Coordinator we mark all peers as LOST, but that will come in a separate PR since this one is big enough!
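
For reference, the client/agent side of the new flow boils down to the sketch below. It is condensed from `runCoordinator` in the agent and `DialWorkspaceAgent` in codersdk as changed in this diff; the helper name and its exact shape are illustrative, not part of the change itself.

```go
package example

import (
	"context"

	"github.com/google/uuid"
	"golang.org/x/xerrors"
	"storj.io/drpc"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/tailnet"
	tailnetproto "github.com/coder/coder/v2/tailnet/proto"
)

// runCoordinationV2 (illustrative name) opens the Tailnet dRPC Coordinate
// stream over an existing drpc.Conn, wires it into a tailnet.Conn, and then
// blocks until the context is canceled or the stream fails.
func runCoordinationV2(ctx context.Context, logger slog.Logger, conn drpc.Conn, network *tailnet.Conn, agentID uuid.UUID) error {
	tClient := tailnetproto.NewDRPCTailnetClient(conn)
	coordinate, err := tClient.Coordinate(ctx)
	if err != nil {
		return xerrors.Errorf("connect to the coordinate endpoint: %w", err)
	}
	defer func() { _ = coordinate.Close() }()

	// RemoteCoordination replaces the old ServeCoordinator/UpdateNodes
	// plumbing: local node updates go up the stream, peer updates come down.
	// Agents pass uuid.Nil; clients pass the agent ID they are tunneling to.
	coordination := tailnet.NewRemoteCoordination(logger, coordinate, network, agentID)
	defer func() { _ = coordination.Close() }()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case err := <-coordination.Error():
		return err
	}
}
```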
Spike Curtis 2024-01-22 11:07:50 +04:00 committed by GitHub
parent 5a2cf7cd14
commit f01cab9894
31 changed files with 1192 additions and 1114 deletions

View File

@ -475,7 +475,8 @@ gen: \
site/.eslintignore \ site/.eslintignore \
site/e2e/provisionerGenerated.ts \ site/e2e/provisionerGenerated.ts \
site/src/theme/icons.json \ site/src/theme/icons.json \
examples/examples.gen.json examples/examples.gen.json \
tailnet/tailnettest/coordinatormock.go
.PHONY: gen .PHONY: gen
# Mark all generated files as fresh so make thinks they're up-to-date. This is # Mark all generated files as fresh so make thinks they're up-to-date. This is
@ -502,6 +503,7 @@ gen/mark-fresh:
site/e2e/provisionerGenerated.ts \ site/e2e/provisionerGenerated.ts \
site/src/theme/icons.json \ site/src/theme/icons.json \
examples/examples.gen.json \ examples/examples.gen.json \
tailnet/tailnettest/coordinatormock.go \
" "
for file in $$files; do for file in $$files; do
echo "$$file" echo "$$file"
@ -529,6 +531,9 @@ coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $
coderd/database/dbmock/dbmock.go: coderd/database/db.go coderd/database/querier.go coderd/database/dbmock/dbmock.go: coderd/database/db.go coderd/database/querier.go
go generate ./coderd/database/dbmock/ go generate ./coderd/database/dbmock/
tailnet/tailnettest/coordinatormock.go: tailnet/coordinator.go
go generate ./tailnet/tailnettest/
tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto
protoc \ protoc \
--go_out=. \ --go_out=. \

View File

@ -30,6 +30,7 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"storj.io/drpc"
"tailscale.com/net/speedtest" "tailscale.com/net/speedtest"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/netlogtype" "tailscale.com/types/netlogtype"
@ -47,6 +48,7 @@ import (
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
) )
const ( const (
@ -86,7 +88,7 @@ type Options struct {
type Client interface { type Client interface {
Manifest(ctx context.Context) (agentsdk.Manifest, error) Manifest(ctx context.Context) (agentsdk.Manifest, error)
Listen(ctx context.Context) (net.Conn, error) Listen(ctx context.Context) (drpc.Conn, error)
DERPMapUpdates(ctx context.Context) (<-chan agentsdk.DERPMapUpdate, io.Closer, error) DERPMapUpdates(ctx context.Context) (<-chan agentsdk.DERPMapUpdate, io.Closer, error)
ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
@ -1058,20 +1060,34 @@ func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
coordinator, err := a.client.Listen(ctx) conn, err := a.client.Listen(ctx)
if err != nil { if err != nil {
return err return err
} }
defer coordinator.Close() defer func() {
cErr := conn.Close()
if cErr != nil {
a.logger.Debug(ctx, "error closing drpc connection", slog.Error(err))
}
}()
tClient := tailnetproto.NewDRPCTailnetClient(conn)
coordinate, err := tClient.Coordinate(ctx)
if err != nil {
return xerrors.Errorf("failed to connect to the coordinate endpoint: %w", err)
}
defer func() {
cErr := coordinate.Close()
if cErr != nil {
a.logger.Debug(ctx, "error closing Coordinate client", slog.Error(err))
}
}()
a.logger.Info(ctx, "connected to coordination endpoint") a.logger.Info(ctx, "connected to coordination endpoint")
sendNodes, errChan := tailnet.ServeCoordinator(coordinator, func(nodes []*tailnet.Node) error { coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil)
return network.UpdateNodes(nodes, false)
})
network.SetNodeCallback(sendNodes)
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case err := <-errChan: case err := <-coordination.Error():
return err return err
} }
} }

View File

@ -1664,9 +1664,11 @@ func TestAgent_UpdatedDERP(t *testing.T) {
require.NotNil(t, originalDerpMap) require.NotNil(t, originalDerpMap)
coordinator := tailnet.NewCoordinator(logger) coordinator := tailnet.NewCoordinator(logger)
defer func() { // use t.Cleanup so the coordinator closing doesn't deadlock with in-memory
// coordination
t.Cleanup(func() {
_ = coordinator.Close() _ = coordinator.Close()
}() })
agentID := uuid.New() agentID := uuid.New()
statsCh := make(chan *agentsdk.Stats, 50) statsCh := make(chan *agentsdk.Stats, 50)
fs := afero.NewMemMapFs() fs := afero.NewMemMapFs()
@ -1681,41 +1683,42 @@ func TestAgent_UpdatedDERP(t *testing.T) {
statsCh, statsCh,
coordinator, coordinator,
) )
closer := agent.New(agent.Options{ uut := agent.New(agent.Options{
Client: client, Client: client,
Filesystem: fs, Filesystem: fs,
Logger: logger.Named("agent"), Logger: logger.Named("agent"),
ReconnectingPTYTimeout: time.Minute, ReconnectingPTYTimeout: time.Minute,
}) })
defer func() { t.Cleanup(func() {
_ = closer.Close() _ = uut.Close()
}() })
// Setup a client connection. // Setup a client connection.
newClientConn := func(derpMap *tailcfg.DERPMap) *codersdk.WorkspaceAgentConn { newClientConn := func(derpMap *tailcfg.DERPMap, name string) *codersdk.WorkspaceAgentConn {
conn, err := tailnet.NewConn(&tailnet.Options{ conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: derpMap, DERPMap: derpMap,
Logger: logger.Named("client"), Logger: logger.Named(name),
}) })
require.NoError(t, err) require.NoError(t, err)
clientConn, serverConn := net.Pipe()
serveClientDone := make(chan struct{})
t.Cleanup(func() { t.Cleanup(func() {
_ = clientConn.Close() t.Logf("closing conn %s", name)
_ = serverConn.Close()
_ = conn.Close() _ = conn.Close()
<-serveClientDone
}) })
go func() { testCtx, testCtxCancel := context.WithCancel(context.Background())
defer close(serveClientDone) t.Cleanup(testCtxCancel)
err := coordinator.ServeClient(serverConn, uuid.New(), agentID) clientID := uuid.New()
assert.NoError(t, err) coordination := tailnet.NewInMemoryCoordination(
}() testCtx, logger,
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error { clientID, agentID,
return conn.UpdateNodes(nodes, false) coordinator, conn)
t.Cleanup(func() {
t.Logf("closing coordination %s", name)
err := coordination.Close()
if err != nil {
t.Logf("error closing in-memory coordination: %s", err.Error())
}
}) })
conn.SetNodeCallback(sendNode)
// Force DERP. // Force DERP.
conn.SetBlockEndpoints(true) conn.SetBlockEndpoints(true)
@ -1724,6 +1727,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
CloseFunc: func() error { return codersdk.ErrSkipClose }, CloseFunc: func() error { return codersdk.ErrSkipClose },
}) })
t.Cleanup(func() { t.Cleanup(func() {
t.Logf("closing sdkConn %s", name)
_ = sdkConn.Close() _ = sdkConn.Close()
}) })
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
@ -1734,7 +1738,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
return sdkConn return sdkConn
} }
conn1 := newClientConn(originalDerpMap) conn1 := newClientConn(originalDerpMap, "client1")
// Change the DERP map. // Change the DERP map.
newDerpMap, _ := tailnettest.RunDERPAndSTUN(t) newDerpMap, _ := tailnettest.RunDERPAndSTUN(t)
@ -1753,27 +1757,34 @@ func TestAgent_UpdatedDERP(t *testing.T) {
DERPMap: newDerpMap, DERPMap: newDerpMap,
}) })
require.NoError(t, err) require.NoError(t, err)
t.Logf("client Pushed DERPMap update")
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
conn := closer.TailnetConn() conn := uut.TailnetConn()
if conn == nil { if conn == nil {
return false return false
} }
regionIDs := conn.DERPMap().RegionIDs() regionIDs := conn.DERPMap().RegionIDs()
return len(regionIDs) == 1 && regionIDs[0] == 2 && conn.Node().PreferredDERP == 2 preferredDERP := conn.Node().PreferredDERP
t.Logf("agent Conn DERPMap with regionIDs %v, PreferredDERP %d", regionIDs, preferredDERP)
return len(regionIDs) == 1 && regionIDs[0] == 2 && preferredDERP == 2
}, testutil.WaitLong, testutil.IntervalFast) }, testutil.WaitLong, testutil.IntervalFast)
t.Logf("agent got the new DERPMap")
// Connect from a second client and make sure it uses the new DERP map. // Connect from a second client and make sure it uses the new DERP map.
conn2 := newClientConn(newDerpMap) conn2 := newClientConn(newDerpMap, "client2")
require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs()) require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs())
t.Log("conn2 got the new DERPMap")
// If the first client gets a DERP map update, it should be able to // If the first client gets a DERP map update, it should be able to
// reconnect just fine. // reconnect just fine.
conn1.SetDERPMap(newDerpMap) conn1.SetDERPMap(newDerpMap)
require.Equal(t, []int{2}, conn1.DERPMap().RegionIDs()) require.Equal(t, []int{2}, conn1.DERPMap().RegionIDs())
t.Log("set the new DERPMap on conn1")
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
require.True(t, conn1.AwaitReachable(ctx)) require.True(t, conn1.AwaitReachable(ctx))
t.Log("conn1 reached agent with new DERP")
} }
func TestAgent_Speedtest(t *testing.T) { func TestAgent_Speedtest(t *testing.T) {
@ -2050,22 +2061,22 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
Logger: logger.Named("client"), Logger: logger.Named("client"),
}) })
require.NoError(t, err) require.NoError(t, err)
clientConn, serverConn := net.Pipe()
serveClientDone := make(chan struct{})
t.Cleanup(func() { t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
_ = conn.Close() _ = conn.Close()
<-serveClientDone
}) })
go func() { testCtx, testCtxCancel := context.WithCancel(context.Background())
defer close(serveClientDone) t.Cleanup(testCtxCancel)
coordinator.ServeClient(serverConn, uuid.New(), metadata.AgentID) clientID := uuid.New()
}() coordination := tailnet.NewInMemoryCoordination(
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error { testCtx, logger,
return conn.UpdateNodes(nodes, false) clientID, metadata.AgentID,
coordinator, conn)
t.Cleanup(func() {
err := coordination.Close()
if err != nil {
t.Logf("error closing in-mem coordination: %s", err.Error())
}
}) })
conn.SetNodeCallback(sendNode)
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: metadata.AgentID, AgentID: metadata.AgentID,
}) })
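
The test changes in this file (and in the other tests below) replace the old `net.Pipe` + `coordinator.ServeClient` + `tailnet.ServeCoordinator` plumbing with `tailnet.NewInMemoryCoordination`. A condensed sketch of that wiring, assuming the test already has a `tailnet.Coordinator` and a `*tailnet.Conn`; the helper name is illustrative:

```go
package example

import (
	"context"
	"testing"

	"github.com/google/uuid"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/tailnet"
)

// startInMemoryCoordination (illustrative name) connects a test tailnet.Conn
// directly to an in-process Coordinator. Cleanup runs via t.Cleanup so the
// coordination is closed before the coordinator, avoiding the deadlock the
// test comments above warn about.
func startInMemoryCoordination(t *testing.T, logger slog.Logger, agentID uuid.UUID, coordinator tailnet.Coordinator, conn *tailnet.Conn) {
	t.Helper()
	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)

	clientID := uuid.New()
	coordination := tailnet.NewInMemoryCoordination(
		ctx, logger,
		clientID, agentID,
		coordinator, conn,
	)
	t.Cleanup(func() {
		if err := coordination.Close(); err != nil {
			t.Logf("error closing in-memory coordination: %s", err)
		}
	})
}
```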

View File

@ -3,19 +3,26 @@ package agenttest
import ( import (
"context" "context"
"io" "io"
"net"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"storj.io/drpc"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"tailscale.com/tailcfg"
"cdr.dev/slog" "cdr.dev/slog"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/agentsdk"
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@ -24,11 +31,31 @@ func NewClient(t testing.TB,
agentID uuid.UUID, agentID uuid.UUID,
manifest agentsdk.Manifest, manifest agentsdk.Manifest,
statsChan chan *agentsdk.Stats, statsChan chan *agentsdk.Stats,
coordinator tailnet.CoordinatorV1, coordinator tailnet.Coordinator,
) *Client { ) *Client {
if manifest.AgentID == uuid.Nil { if manifest.AgentID == uuid.Nil {
manifest.AgentID = agentID manifest.AgentID = agentID
} }
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coordinator)
mux := drpcmux.New()
drpcService := &tailnet.DRPCService{
CoordPtr: &coordPtr,
Logger: logger,
// TODO: handle DERPMap too!
DerpMapUpdateFrequency: time.Hour,
DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
}
err := proto.DRPCRegisterTailnet(mux, drpcService)
require.NoError(t, err)
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
}
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
},
})
return &Client{ return &Client{
t: t, t: t,
logger: logger.Named("client"), logger: logger.Named("client"),
@ -36,6 +63,7 @@ func NewClient(t testing.TB,
manifest: manifest, manifest: manifest,
statsChan: statsChan, statsChan: statsChan,
coordinator: coordinator, coordinator: coordinator,
server: server,
derpMapUpdates: make(chan agentsdk.DERPMapUpdate), derpMapUpdates: make(chan agentsdk.DERPMapUpdate),
} }
} }
@ -47,7 +75,8 @@ type Client struct {
manifest agentsdk.Manifest manifest agentsdk.Manifest
metadata map[string]agentsdk.Metadata metadata map[string]agentsdk.Metadata
statsChan chan *agentsdk.Stats statsChan chan *agentsdk.Stats
coordinator tailnet.CoordinatorV1 coordinator tailnet.Coordinator
server *drpcserver.Server
LastWorkspaceAgent func() LastWorkspaceAgent func()
PatchWorkspaceLogs func() error PatchWorkspaceLogs func() error
GetServiceBannerFunc func() (codersdk.ServiceBannerConfig, error) GetServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
@ -63,20 +92,29 @@ func (c *Client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
return c.manifest, nil return c.manifest, nil
} }
func (c *Client) Listen(_ context.Context) (net.Conn, error) { func (c *Client) Listen(_ context.Context) (drpc.Conn, error) {
clientConn, serverConn := net.Pipe() conn, lis := drpcsdk.MemTransportPipe()
closed := make(chan struct{}) closed := make(chan struct{})
c.LastWorkspaceAgent = func() { c.LastWorkspaceAgent = func() {
_ = serverConn.Close() _ = conn.Close()
_ = clientConn.Close() _ = lis.Close()
<-closed <-closed
} }
c.t.Cleanup(c.LastWorkspaceAgent) c.t.Cleanup(c.LastWorkspaceAgent)
serveCtx, cancel := context.WithCancel(context.Background())
c.t.Cleanup(cancel)
auth := tailnet.AgentTunnelAuth{}
streamID := tailnet.StreamID{
Name: "agenttest",
ID: c.agentID,
Auth: auth,
}
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
go func() { go func() {
_ = c.coordinator.ServeAgent(serverConn, c.agentID, "") _ = c.server.Serve(serveCtx, lis)
close(closed) close(closed)
}() }()
return clientConn, nil return conn, nil
} }
func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) { func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) {

View File

@ -33,6 +33,7 @@ import (
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisioner/echo" "github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@ -98,14 +99,32 @@ func TestDERP(t *testing.T) {
w2Ready := make(chan struct{}) w2Ready := make(chan struct{})
w2ReadyOnce := sync.Once{} w2ReadyOnce := sync.Once{}
w1ID := uuid.New()
w1.SetNodeCallback(func(node *tailnet.Node) { w1.SetNodeCallback(func(node *tailnet.Node) {
w2.UpdateNodes([]*tailnet.Node{node}, false) pn, err := tailnet.NodeToProto(node)
if !assert.NoError(t, err) {
return
}
w2.UpdatePeers([]*tailnetproto.CoordinateResponse_PeerUpdate{{
Id: w1ID[:],
Node: pn,
Kind: tailnetproto.CoordinateResponse_PeerUpdate_NODE,
}})
w2ReadyOnce.Do(func() { w2ReadyOnce.Do(func() {
close(w2Ready) close(w2Ready)
}) })
}) })
w2ID := uuid.New()
w2.SetNodeCallback(func(node *tailnet.Node) { w2.SetNodeCallback(func(node *tailnet.Node) {
w1.UpdateNodes([]*tailnet.Node{node}, false) pn, err := tailnet.NodeToProto(node)
if !assert.NoError(t, err) {
return
}
w1.UpdatePeers([]*tailnetproto.CoordinateResponse_PeerUpdate{{
Id: w2ID[:],
Node: pn,
Kind: tailnetproto.CoordinateResponse_PeerUpdate_NODE,
}})
}) })
conn := make(chan struct{}) conn := make(chan struct{})
@ -199,7 +218,11 @@ func TestDERPForceWebSockets(t *testing.T) {
defer cancel() defer cancel()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID,
&codersdk.DialWorkspaceAgentOptions{
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug).Named("client"),
},
)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
_ = conn.Close() _ = conn.Close()

View File

@ -121,12 +121,23 @@ func NewServerTailnet(
} }
tn.agentConn.Store(&agentConn) tn.agentConn.Store(&agentConn)
err = tn.getAgentConn().UpdateSelf(conn.Node()) pn, err := tailnet.NodeToProto(conn.Node())
if err != nil { if err != nil {
tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err)) tn.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
} else {
err = tn.getAgentConn().UpdateSelf(pn)
if err != nil {
tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err))
}
} }
conn.SetNodeCallback(func(node *tailnet.Node) { conn.SetNodeCallback(func(node *tailnet.Node) {
err := tn.getAgentConn().UpdateSelf(node) pn, err := tailnet.NodeToProto(node)
if err != nil {
tn.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
return
}
err = tn.getAgentConn().UpdateSelf(pn)
if err != nil { if err != nil {
tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err)) tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err))
} }
@ -191,21 +202,9 @@ func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
// If no one has connected since the cutoff and there are no active // If no one has connected since the cutoff and there are no active
// connections, remove the agent. // connections, remove the agent.
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 { if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
deleted, err := s.conn.RemovePeer(tailnet.PeerSelector{
ID: tailnet.NodeID(agentID),
IP: netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128),
})
if err != nil {
s.logger.Warn(ctx, "failed to remove peer from server tailnet", slog.Error(err))
continue
}
if !deleted {
s.logger.Warn(ctx, "peer didn't exist in tailnet", slog.Error(err))
}
deletedCount++ deletedCount++
delete(s.agentConnectionTimes, agentID) delete(s.agentConnectionTimes, agentID)
err = agentConn.UnsubscribeAgent(agentID) err := agentConn.UnsubscribeAgent(agentID)
if err != nil { if err != nil {
s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID)) s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
} }
@ -221,7 +220,7 @@ func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
func (s *ServerTailnet) watchAgentUpdates() { func (s *ServerTailnet) watchAgentUpdates() {
for { for {
conn := s.getAgentConn() conn := s.getAgentConn()
nodes, ok := conn.NextUpdate(s.ctx) resp, ok := conn.NextUpdate(s.ctx)
if !ok { if !ok {
if conn.IsClosed() && s.ctx.Err() == nil { if conn.IsClosed() && s.ctx.Err() == nil {
s.logger.Warn(s.ctx, "multiagent closed, reinitializing") s.logger.Warn(s.ctx, "multiagent closed, reinitializing")
@ -231,7 +230,7 @@ func (s *ServerTailnet) watchAgentUpdates() {
return return
} }
err := s.conn.UpdateNodes(nodes, false) err := s.conn.UpdatePeers(resp.GetPeerUpdates())
if err != nil { if err != nil {
if xerrors.Is(err, tailnet.ErrConnClosed) { if xerrors.Is(err, tailnet.ErrConnClosed) {
s.logger.Warn(context.Background(), "tailnet conn closed, exiting watchAgentUpdates", slog.Error(err)) s.logger.Warn(context.Background(), "tailnet conn closed, exiting watchAgentUpdates", slog.Error(err))

View File

@ -3,7 +3,6 @@ package coderd_test
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip" "net/netip"
@ -204,22 +203,20 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
Logger: logger.Named("client"), Logger: logger.Named("client"),
}) })
require.NoError(t, err) require.NoError(t, err)
clientConn, serverConn := net.Pipe()
serveClientDone := make(chan struct{})
t.Cleanup(func() { t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
_ = conn.Close() _ = conn.Close()
<-serveClientDone
}) })
go func() { clientID := uuid.New()
defer close(serveClientDone) testCtx, testCtxCancel := context.WithCancel(context.Background())
coord.ServeClient(serverConn, uuid.New(), manifest.AgentID) t.Cleanup(testCtxCancel)
}() coordination := tailnet.NewInMemoryCoordination(
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { testCtx, logger,
return conn.UpdateNodes(node, false) clientID, manifest.AgentID,
coord, conn,
)
t.Cleanup(func() {
_ = coordination.Close()
}) })
conn.SetNodeCallback(sendNode)
return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: manifest.AgentID, AgentID: manifest.AgentID,
AgentIP: codersdk.WorkspaceAgentIP, AgentIP: codersdk.WorkspaceAgentIP,

View File

@ -30,6 +30,10 @@ func (v *APIVersion) WithBackwardCompat(majs ...int) *APIVersion {
return v return v
} }
func (v *APIVersion) String() string {
return fmt.Sprintf("%d.%d", v.supportedMajor, v.supportedMinor)
}
// Validate validates the given version against the given constraints: // Validate validates the given version against the given constraints:
// A given major.minor version is valid iff: // A given major.minor version is valid iff:
// 1. The requested major version is contained within v.supportedMajors // 1. The requested major version is contained within v.supportedMajors
@ -42,10 +46,6 @@ func (v *APIVersion) WithBackwardCompat(majs ...int) *APIVersion {
// - 1.x is supported, // - 1.x is supported,
// - 2.0, 2.1, and 2.2 are supported, // - 2.0, 2.1, and 2.2 are supported,
// - 2.3+ is not supported. // - 2.3+ is not supported.
func (v *APIVersion) String() string {
return fmt.Sprintf("%d.%d", v.supportedMajor, v.supportedMinor)
}
func (v *APIVersion) Validate(version string) error { func (v *APIVersion) Validate(version string) error {
major, minor, err := Parse(version) major, minor, err := Parse(version)
if err != nil { if err != nil {

View File

@ -857,8 +857,6 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
// Deprecated: use api.tailnet.AgentConn instead. // Deprecated: use api.tailnet.AgentConn instead.
// See: https://github.com/coder/coder/issues/8218 // See: https://github.com/coder/coder/issues/8218
func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
clientConn, serverConn := net.Pipe()
derpMap := api.DERPMap() derpMap := api.DERPMap()
conn, err := tailnet.NewConn(&tailnet.Options{ conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
@ -868,8 +866,6 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa
BlockEndpoints: api.DeploymentValues.DERP.Config.BlockDirect.Value(), BlockEndpoints: api.DeploymentValues.DERP.Config.BlockDirect.Value(),
}) })
if err != nil { if err != nil {
_ = clientConn.Close()
_ = serverConn.Close()
return nil, xerrors.Errorf("create tailnet conn: %w", err) return nil, xerrors.Errorf("create tailnet conn: %w", err)
} }
ctx, cancel := context.WithCancel(api.ctx) ctx, cancel := context.WithCancel(api.ctx)
@ -887,10 +883,10 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa
return left return left
}) })
sendNodes, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error { clientID := uuid.New()
return conn.UpdateNodes(nodes, true) coordination := tailnet.NewInMemoryCoordination(ctx, api.Logger,
}) clientID, agentID,
conn.SetNodeCallback(sendNodes) *(api.TailnetCoordinator.Load()), conn)
// Check for updated DERP map every 5 seconds. // Check for updated DERP map every 5 seconds.
go func() { go func() {
@ -920,27 +916,13 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa
AgentID: agentID, AgentID: agentID,
AgentIP: codersdk.WorkspaceAgentIP, AgentIP: codersdk.WorkspaceAgentIP,
CloseFunc: func() error { CloseFunc: func() error {
_ = coordination.Close()
cancel() cancel()
_ = clientConn.Close()
_ = serverConn.Close()
return nil return nil
}, },
}) })
go func() {
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
if err != nil {
// Sometimes, we get benign closed pipe errors when the server is
// shutting down.
if api.ctx.Err() == nil {
api.Logger.Warn(ctx, "tailnet coordinator client error", slog.Error(err))
}
_ = agentConn.Close()
}
}()
if !agentConn.AwaitReachable(ctx) { if !agentConn.AwaitReachable(ctx) {
_ = agentConn.Close() _ = agentConn.Close()
_ = serverConn.Close()
_ = clientConn.Close()
cancel() cancel()
return nil, xerrors.Errorf("agent not reachable") return nil, xerrors.Errorf("agent not reachable")
} }

View File

@ -535,7 +535,6 @@ func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
defer conn.Close() defer conn.Close()
require.True(t, conn.BlockEndpoints())
require.True(t, conn.AwaitReachable(ctx)) require.True(t, conn.AwaitReachable(ctx))
_, p2p, _, err := conn.Ping(ctx) _, p2p, _, err := conn.Ping(ctx)

View File

@ -12,14 +12,19 @@ import (
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak" "go.uber.org/goleak"
"golang.org/x/xerrors"
"storj.io/drpc"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"tailscale.com/tailcfg"
"cdr.dev/slog" "cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
@ -27,7 +32,9 @@ import (
"github.com/coder/coder/v2/coderd/wsconncache" "github.com/coder/coder/v2/coderd/wsconncache"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/agentsdk"
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@ -41,7 +48,7 @@ func TestCache(t *testing.T) {
t.Run("Same", func(t *testing.T) { t.Run("Same", func(t *testing.T) {
t.Parallel() t.Parallel()
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Manifest{}, 0), nil return setupAgent(t, agentsdk.Manifest{}, 0)
}, 0) }, 0)
defer func() { defer func() {
_ = cache.Close() _ = cache.Close()
@ -54,10 +61,10 @@ func TestCache(t *testing.T) {
}) })
t.Run("Expire", func(t *testing.T) { t.Run("Expire", func(t *testing.T) {
t.Parallel() t.Parallel()
called := atomic.NewInt32(0) called := int32(0)
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
called.Add(1) atomic.AddInt32(&called, 1)
return setupAgent(t, agentsdk.Manifest{}, 0), nil return setupAgent(t, agentsdk.Manifest{}, 0)
}, time.Microsecond) }, time.Microsecond)
defer func() { defer func() {
_ = cache.Close() _ = cache.Close()
@ -70,12 +77,12 @@ func TestCache(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
release() release()
<-conn.Closed() <-conn.Closed()
require.Equal(t, int32(2), called.Load()) require.Equal(t, int32(2), called)
}) })
t.Run("NoExpireWhenLocked", func(t *testing.T) { t.Run("NoExpireWhenLocked", func(t *testing.T) {
t.Parallel() t.Parallel()
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Manifest{}, 0), nil return setupAgent(t, agentsdk.Manifest{}, 0)
}, time.Microsecond) }, time.Microsecond)
defer func() { defer func() {
_ = cache.Close() _ = cache.Close()
@ -108,7 +115,7 @@ func TestCache(t *testing.T) {
go server.Serve(random) go server.Serve(random)
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Manifest{}, 0), nil return setupAgent(t, agentsdk.Manifest{}, 0)
}, time.Microsecond) }, time.Microsecond)
defer func() { defer func() {
_ = cache.Close() _ = cache.Close()
@ -154,7 +161,7 @@ func TestCache(t *testing.T) {
}) })
} }
func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn { func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) (*codersdk.WorkspaceAgentConn, error) {
t.Helper() t.Helper()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t) manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t)
@ -184,18 +191,25 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
DERPForceWebSockets: manifest.DERPForceWebSockets, DERPForceWebSockets: manifest.DERPForceWebSockets,
Logger: slogtest.Make(t, nil).Named("tailnet").Leveled(slog.LevelDebug), Logger: slogtest.Make(t, nil).Named("tailnet").Leveled(slog.LevelDebug),
}) })
require.NoError(t, err) // setupAgent is called by wsconncache Dialer, so we can't use require here as it will end the
clientConn, serverConn := net.Pipe() // test, which in turn closes the wsconncache, which in turn waits for the Dialer and deadlocks.
if !assert.NoError(t, err) {
return nil, err
}
t.Cleanup(func() { t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
_ = conn.Close() _ = conn.Close()
}) })
go coordinator.ServeClient(serverConn, uuid.New(), manifest.AgentID) clientID := uuid.New()
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error { testCtx, testCtxCancel := context.WithCancel(context.Background())
return conn.UpdateNodes(nodes, false) t.Cleanup(testCtxCancel)
coordination := tailnet.NewInMemoryCoordination(
testCtx, logger,
clientID, manifest.AgentID,
coordinator, conn,
)
t.Cleanup(func() {
_ = coordination.Close()
}) })
conn.SetNodeCallback(sendNode)
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{ agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: manifest.AgentID, AgentID: manifest.AgentID,
AgentIP: codersdk.WorkspaceAgentIP, AgentIP: codersdk.WorkspaceAgentIP,
@ -206,16 +220,20 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel() defer cancel()
if !agentConn.AwaitReachable(ctx) { if !agentConn.AwaitReachable(ctx) {
t.Fatal("agent not reachable") // setupAgent is called by wsconncache Dialer, so we can't use t.Fatal here as it will end
// the test, which in turn closes the wsconncache, which in turn waits for the Dialer and
// deadlocks.
t.Error("agent not reachable")
return nil, xerrors.New("agent not reachable")
} }
return agentConn return agentConn, nil
} }
type client struct { type client struct {
t *testing.T t *testing.T
agentID uuid.UUID agentID uuid.UUID
manifest agentsdk.Manifest manifest agentsdk.Manifest
coordinator tailnet.CoordinatorV1 coordinator tailnet.Coordinator
} }
func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) { func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
@ -240,19 +258,53 @@ func (*client) DERPMapUpdates(_ context.Context) (<-chan agentsdk.DERPMapUpdate,
}, nil }, nil
} }
func (c *client) Listen(_ context.Context) (net.Conn, error) { func (c *client) Listen(_ context.Context) (drpc.Conn, error) {
clientConn, serverConn := net.Pipe() logger := slogtest.Make(c.t, nil).Leveled(slog.LevelDebug).Named("drpc")
conn, lis := drpcsdk.MemTransportPipe()
closed := make(chan struct{}) closed := make(chan struct{})
c.t.Cleanup(func() { c.t.Cleanup(func() {
_ = serverConn.Close() _ = conn.Close()
_ = clientConn.Close() _ = lis.Close()
<-closed <-closed
}) })
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&c.coordinator)
mux := drpcmux.New()
drpcService := &tailnet.DRPCService{
CoordPtr: &coordPtr,
Logger: logger,
// TODO: handle DERPMap too!
DerpMapUpdateFrequency: time.Hour,
DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
}
err := proto.DRPCRegisterTailnet(mux, drpcService)
if err != nil {
return nil, xerrors.Errorf("register DRPC service: %w", err)
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) ||
xerrors.Is(err, context.Canceled) ||
xerrors.Is(err, context.DeadlineExceeded) {
return
}
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
},
})
serveCtx, cancel := context.WithCancel(context.Background())
c.t.Cleanup(cancel)
auth := tailnet.AgentTunnelAuth{}
streamID := tailnet.StreamID{
Name: "wsconncache_test-agent",
ID: c.agentID,
Auth: auth,
}
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
go func() { go func() {
_ = c.coordinator.ServeAgent(serverConn, c.agentID, "") server.Serve(serveCtx, lis)
close(closed) close(closed)
}() }()
return clientConn, nil return conn, nil
} }
func (*client) ReportStats(_ context.Context, _ slog.Logger, _ <-chan *agentsdk.Stats, _ func(time.Duration)) (io.Closer, error) { func (*client) ReportStats(_ context.Context, _ slog.Logger, _ <-chan *agentsdk.Stats, _ func(time.Duration)) (io.Closer, error) {

View File

@ -14,12 +14,15 @@ import (
"cloud.google.com/go/compute/metadata" "cloud.google.com/go/compute/metadata"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/hashicorp/yamux"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"nhooyr.io/websocket" "nhooyr.io/websocket"
"storj.io/drpc"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"cdr.dev/slog" "cdr.dev/slog"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/retry" "github.com/coder/retry"
) )
@ -280,8 +283,8 @@ func (c *Client) DERPMapUpdates(ctx context.Context) (<-chan DERPMapUpdate, io.C
// Listen connects to the workspace agent coordinate WebSocket // Listen connects to the workspace agent coordinate WebSocket
// that handles connection negotiation. // that handles connection negotiation.
func (c *Client) Listen(ctx context.Context) (net.Conn, error) { func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) {
coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/coordinate") coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc")
if err != nil { if err != nil {
return nil, xerrors.Errorf("parse url: %w", err) return nil, xerrors.Errorf("parse url: %w", err)
} }
@ -312,14 +315,21 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
pingClosed := pingWebSocket(ctx, c.SDK.Logger(), conn, "coordinate") pingClosed := pingWebSocket(ctx, c.SDK.Logger(), conn, "coordinate")
return &closeNetConn{ netConn := &closeNetConn{
Conn: wsNetConn, Conn: wsNetConn,
closeFunc: func() { closeFunc: func() {
cancelFunc() cancelFunc()
_ = conn.Close(websocket.StatusGoingAway, "Listen closed") _ = conn.Close(websocket.StatusGoingAway, "Listen closed")
<-pingClosed <-pingClosed
}, },
}, nil }
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Client(netConn, config)
if err != nil {
return nil, xerrors.Errorf("multiplex client: %w", err)
}
return drpcsdk.MultiplexedConn(session), nil
} }
type PostAppHealthsRequest struct { type PostAppHealthsRequest struct {

View File

@ -313,6 +313,9 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
if err != nil { if err != nil {
return nil, xerrors.Errorf("parse url: %w", err) return nil, xerrors.Errorf("parse url: %w", err)
} }
q := coordinateURL.Query()
q.Add("version", tailnet.CurrentVersion.String())
coordinateURL.RawQuery = q.Encode()
closedCoordinator := make(chan struct{}) closedCoordinator := make(chan struct{})
// Must only ever be used once, send error OR close to avoid // Must only ever be used once, send error OR close to avoid
// reassignment race. Buffered so we don't hang in goroutine. // reassignment race. Buffered so we don't hang in goroutine.
@ -344,12 +347,22 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
options.Logger.Debug(ctx, "failed to dial", slog.Error(err)) options.Logger.Debug(ctx, "failed to dial", slog.Error(err))
continue continue
} }
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(nodes []*tailnet.Node) error { client, err := tailnet.NewDRPCClient(websocket.NetConn(ctx, ws, websocket.MessageBinary))
return conn.UpdateNodes(nodes, false) if err != nil {
}) options.Logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
conn.SetNodeCallback(sendNode) _ = ws.Close(websocket.StatusInternalError, "")
continue
}
coordinate, err := client.Coordinate(ctx)
if err != nil {
options.Logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err))
_ = ws.Close(websocket.StatusInternalError, "")
continue
}
coordination := tailnet.NewRemoteCoordination(options.Logger, coordinate, conn, agentID)
options.Logger.Debug(ctx, "serving coordinator") options.Logger.Debug(ctx, "serving coordinator")
err = <-errChan err = <-coordination.Error()
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
_ = ws.Close(websocket.StatusGoingAway, "") _ = ws.Close(websocket.StatusGoingAway, "")
return return

View File

@ -8,6 +8,7 @@ import (
"github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/v2/tailnet" agpl "github.com/coder/coder/v2/tailnet"
@ -53,6 +54,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
ctx := r.Context() ctx := r.Context()
version := "1.0" version := "1.0"
msgType := websocket.MessageText
qv := r.URL.Query().Get("version") qv := r.URL.Query().Get("version")
if qv != "" { if qv != "" {
version = qv version = qv
@ -66,6 +68,11 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
}) })
return return
} }
maj, _, _ := apiversion.Parse(version)
if maj >= 2 {
// Versions 2+ use dRPC over a binary connection
msgType = websocket.MessageBinary
}
api.AGPL.WebsocketWaitMutex.Lock() api.AGPL.WebsocketWaitMutex.Lock()
api.AGPL.WebsocketWaitGroup.Add(1) api.AGPL.WebsocketWaitGroup.Add(1)
@ -81,7 +88,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
return return
} }
ctx, nc := websocketNetConn(ctx, conn, websocket.MessageText) ctx, nc := websocketNetConn(ctx, conn, msgType)
defer nc.Close() defer nc.Close()
id := uuid.New() id := uuid.New()

View File

@ -10,6 +10,7 @@ import (
"github.com/moby/moby/pkg/namesgenerator" "github.com/moby/moby/pkg/namesgenerator"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/types/key" "tailscale.com/types/key"
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
@ -20,6 +21,7 @@ import (
"github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/v2/tailnet" agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@ -27,6 +29,12 @@ import (
func Test_agentIsLegacy(t *testing.T) { func Test_agentIsLegacy(t *testing.T) {
t.Parallel() t.Parallel()
nodeKey := key.NewNode().Public()
discoKey := key.NewDisco().Public()
nkBin, err := nodeKey.MarshalBinary()
require.NoError(t, err)
dkBin, err := discoKey.MarshalText()
require.NoError(t, err)
t.Run("Legacy", func(t *testing.T) { t.Run("Legacy", func(t *testing.T) {
t.Parallel() t.Parallel()
@ -54,18 +62,18 @@ func Test_agentIsLegacy(t *testing.T) {
nodeID := uuid.New() nodeID := uuid.New()
ma := coordinator.ServeMultiAgent(nodeID) ma := coordinator.ServeMultiAgent(nodeID)
defer ma.Close() defer ma.Close()
require.NoError(t, ma.UpdateSelf(&agpl.Node{ require.NoError(t, ma.UpdateSelf(&proto.Node{
ID: 55, Id: 55,
AsOf: time.Unix(1689653252, 0), AsOf: timestamppb.New(time.Unix(1689653252, 0)),
Key: key.NewNode().Public(), Key: nkBin,
DiscoKey: key.NewDisco().Public(), Disco: string(dkBin),
PreferredDERP: 0, PreferredDerp: 0,
DERPLatency: map[string]float64{ DerpLatency: map[string]float64{
"0": 1.0, "0": 1.0,
}, },
DERPForcedWebsocket: map[int]string{}, DerpForcedWebsocket: map[int32]string{},
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, Addresses: []string{codersdk.WorkspaceAgentIP.String() + "/128"},
AllowedIPs: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, AllowedIps: []string{codersdk.WorkspaceAgentIP.String() + "/128"},
Endpoints: []string{"192.168.1.1:18842"}, Endpoints: []string{"192.168.1.1:18842"},
})) }))
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
@ -114,18 +122,18 @@ func Test_agentIsLegacy(t *testing.T) {
nodeID := uuid.New() nodeID := uuid.New()
ma := coordinator.ServeMultiAgent(nodeID) ma := coordinator.ServeMultiAgent(nodeID)
defer ma.Close() defer ma.Close()
require.NoError(t, ma.UpdateSelf(&agpl.Node{ require.NoError(t, ma.UpdateSelf(&proto.Node{
ID: 55, Id: 55,
AsOf: time.Unix(1689653252, 0), AsOf: timestamppb.New(time.Unix(1689653252, 0)),
Key: key.NewNode().Public(), Key: nkBin,
DiscoKey: key.NewDisco().Public(), Disco: string(dkBin),
PreferredDERP: 0, PreferredDerp: 0,
DERPLatency: map[string]float64{ DerpLatency: map[string]float64{
"0": 1.0, "0": 1.0,
}, },
DERPForcedWebsocket: map[int]string{}, DerpForcedWebsocket: map[int32]string{},
Addresses: []netip.Prefix{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128)}, Addresses: []string{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128).String()},
AllowedIPs: []netip.Prefix{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128)}, AllowedIps: []string{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128).String()},
Endpoints: []string{"192.168.1.1:18842"}, Endpoints: []string{"192.168.1.1:18842"},
})) }))
require.Eventually(t, func() bool { require.Eventually(t, func() bool {

View File

@ -6,12 +6,15 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"tailscale.com/types/key"
"cdr.dev/slog" "cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/enterprise/tailnet" "github.com/coder/coder/v2/enterprise/tailnet"
agpl "github.com/coder/coder/v2/tailnet" agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@ -39,22 +42,19 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
defer agent1.close() defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5}) agent1.sendNode(&agpl.Node{PreferredDERP: 5})
id := uuid.New() ma1 := newTestMultiAgent(t, coord1)
ma1 := coord1.ServeMultiAgent(id) defer ma1.close()
defer ma1.Close()
err = ma1.SubscribeAgent(agent1.id) ma1.subscribeAgent(agent1.id)
require.NoError(t, err) ma1.assertEventuallyHasDERPs(ctx, 5)
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) ma1.assertEventuallyHasDERPs(ctx, 1)
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) ma1.sendNodeWithDERP(3)
require.NoError(t, err)
assertEventuallyHasDERPs(ctx, t, agent1, 3) assertEventuallyHasDERPs(ctx, t, agent1, 3)
require.NoError(t, ma1.Close()) ma1.close()
require.NoError(t, agent1.close()) require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@ -86,23 +86,20 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
defer agent1.close() defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5}) agent1.sendNode(&agpl.Node{PreferredDERP: 5})
id := uuid.New() ma1 := newTestMultiAgent(t, coord1)
ma1 := coord1.ServeMultiAgent(id) defer ma1.close()
defer ma1.Close()
err = ma1.SubscribeAgent(agent1.id) ma1.subscribeAgent(agent1.id)
require.NoError(t, err) ma1.assertEventuallyHasDERPs(ctx, 5)
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) ma1.assertEventuallyHasDERPs(ctx, 1)
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) ma1.sendNodeWithDERP(3)
require.NoError(t, err)
assertEventuallyHasDERPs(ctx, t, agent1, 3) assertEventuallyHasDERPs(ctx, t, agent1, 3)
require.NoError(t, ma1.UnsubscribeAgent(agent1.id)) ma1.unsubscribeAgent(agent1.id)
require.NoError(t, ma1.Close()) ma1.close()
require.NoError(t, agent1.close()) require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@ -134,37 +131,35 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
defer agent1.close() defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5}) agent1.sendNode(&agpl.Node{PreferredDERP: 5})
id := uuid.New() ma1 := newTestMultiAgent(t, coord1)
ma1 := coord1.ServeMultiAgent(id) defer ma1.close()
defer ma1.Close()
err = ma1.SubscribeAgent(agent1.id) ma1.subscribeAgent(agent1.id)
require.NoError(t, err) ma1.assertEventuallyHasDERPs(ctx, 5)
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) ma1.assertEventuallyHasDERPs(ctx, 1)
require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})) ma1.sendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3) assertEventuallyHasDERPs(ctx, t, agent1, 3)
require.NoError(t, ma1.UnsubscribeAgent(agent1.id)) ma1.unsubscribeAgent(agent1.id)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
func() { func() {
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
defer cancel() defer cancel()
require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 9})) ma1.sendNodeWithDERP(9)
assertNeverHasDERPs(ctx, t, agent1, 9) assertNeverHasDERPs(ctx, t, agent1, 9)
}() }()
func() { func() {
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
defer cancel() defer cancel()
agent1.sendNode(&agpl.Node{PreferredDERP: 8}) agent1.sendNode(&agpl.Node{PreferredDERP: 8})
assertMultiAgentNeverHasDERPs(ctx, t, ma1, 8) ma1.assertNeverHasDERPs(ctx, 8)
}() }()
require.NoError(t, ma1.Close()) ma1.close()
require.NoError(t, agent1.close()) require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@ -201,22 +196,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
defer agent1.close() defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5}) agent1.sendNode(&agpl.Node{PreferredDERP: 5})
id := uuid.New() ma1 := newTestMultiAgent(t, coord2)
ma1 := coord2.ServeMultiAgent(id) defer ma1.close()
defer ma1.Close()
err = ma1.SubscribeAgent(agent1.id) ma1.subscribeAgent(agent1.id)
require.NoError(t, err) ma1.assertEventuallyHasDERPs(ctx, 5)
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) ma1.assertEventuallyHasDERPs(ctx, 1)
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) ma1.sendNodeWithDERP(3)
require.NoError(t, err)
assertEventuallyHasDERPs(ctx, t, agent1, 3) assertEventuallyHasDERPs(ctx, t, agent1, 3)
require.NoError(t, ma1.Close()) ma1.close()
require.NoError(t, agent1.close()) require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@ -254,22 +246,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
defer agent1.close() defer agent1.close()
agent1.sendNode(&agpl.Node{PreferredDERP: 5}) agent1.sendNode(&agpl.Node{PreferredDERP: 5})
id := uuid.New() ma1 := newTestMultiAgent(t, coord2)
ma1 := coord2.ServeMultiAgent(id) defer ma1.close()
defer ma1.Close()
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) ma1.sendNodeWithDERP(3)
require.NoError(t, err)
err = ma1.SubscribeAgent(agent1.id) ma1.subscribeAgent(agent1.id)
require.NoError(t, err) ma1.assertEventuallyHasDERPs(ctx, 5)
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
assertEventuallyHasDERPs(ctx, t, agent1, 3) assertEventuallyHasDERPs(ctx, t, agent1, 3)
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) ma1.assertEventuallyHasDERPs(ctx, 1)
require.NoError(t, ma1.Close()) ma1.close()
require.NoError(t, agent1.close()) require.NoError(t, agent1.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
@ -316,33 +305,129 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
defer agent1.close() defer agent1.close()
agent2.sendNode(&agpl.Node{PreferredDERP: 6}) agent2.sendNode(&agpl.Node{PreferredDERP: 6})
id := uuid.New() ma1 := newTestMultiAgent(t, coord3)
ma1 := coord3.ServeMultiAgent(id) defer ma1.close()
defer ma1.Close()
err = ma1.SubscribeAgent(agent1.id) ma1.subscribeAgent(agent1.id)
require.NoError(t, err) ma1.assertEventuallyHasDERPs(ctx, 5)
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1) ma1.assertEventuallyHasDERPs(ctx, 1)
err = ma1.SubscribeAgent(agent2.id) ma1.subscribeAgent(agent2.id)
require.NoError(t, err) ma1.assertEventuallyHasDERPs(ctx, 6)
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 6)
agent2.sendNode(&agpl.Node{PreferredDERP: 2}) agent2.sendNode(&agpl.Node{PreferredDERP: 2})
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 2) ma1.assertEventuallyHasDERPs(ctx, 2)
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}) ma1.sendNodeWithDERP(3)
require.NoError(t, err)
assertEventuallyHasDERPs(ctx, t, agent1, 3) assertEventuallyHasDERPs(ctx, t, agent1, 3)
assertEventuallyHasDERPs(ctx, t, agent2, 3) assertEventuallyHasDERPs(ctx, t, agent2, 3)
require.NoError(t, ma1.Close()) ma1.close()
require.NoError(t, agent1.close()) require.NoError(t, agent1.close())
require.NoError(t, agent2.close()) require.NoError(t, agent2.close())
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
assertEventuallyLost(ctx, t, store, agent1.id) assertEventuallyLost(ctx, t, store, agent1.id)
} }
type testMultiAgent struct {
t testing.TB
id uuid.UUID
a agpl.MultiAgentConn
nodeKey []byte
discoKey string
}
func newTestMultiAgent(t testing.TB, coord agpl.Coordinator) *testMultiAgent {
nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
m := &testMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)}
m.a = coord.ServeMultiAgent(m.id)
return m
}
func (m *testMultiAgent) sendNodeWithDERP(derp int32) {
m.t.Helper()
err := m.a.UpdateSelf(&proto.Node{
Key: m.nodeKey,
Disco: m.discoKey,
PreferredDerp: derp,
})
require.NoError(m.t, err)
}
func (m *testMultiAgent) close() {
m.t.Helper()
err := m.a.Close()
require.NoError(m.t, err)
}
func (m *testMultiAgent) subscribeAgent(id uuid.UUID) {
m.t.Helper()
err := m.a.SubscribeAgent(id)
require.NoError(m.t, err)
}
func (m *testMultiAgent) unsubscribeAgent(id uuid.UUID) {
m.t.Helper()
err := m.a.UnsubscribeAgent(id)
require.NoError(m.t, err)
}
func (m *testMultiAgent) assertEventuallyHasDERPs(ctx context.Context, expected ...int) {
m.t.Helper()
for {
resp, ok := m.a.NextUpdate(ctx)
require.True(m.t, ok)
nodes, err := agpl.OnlyNodeUpdates(resp)
require.NoError(m.t, err)
if len(nodes) != len(expected) {
m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
m.t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
func (m *testMultiAgent) assertNeverHasDERPs(ctx context.Context, expected ...int) {
m.t.Helper()
for {
resp, ok := m.a.NextUpdate(ctx)
if !ok {
return
}
nodes, err := agpl.OnlyNodeUpdates(resp)
require.NoError(m.t, err)
if len(nodes) != len(expected) {
m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
m.t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
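For reference, a minimal sketch of how these helpers compose in a coordinator test. It assumes the testConn agent helper and the package-level assertEventuallyHasDERPs used elsewhere in this file; the DERP numbers are illustrative only.

func testMultiAgentRoundTrip(ctx context.Context, t *testing.T, coord agpl.Coordinator, agent *testConn) {
	ma := newTestMultiAgent(t, coord)
	defer ma.close()

	// Subscribing starts fan-out of the agent's node updates to the multi-agent.
	ma.subscribeAgent(agent.id)
	agent.sendNode(&agpl.Node{PreferredDERP: 7})
	ma.assertEventuallyHasDERPs(ctx, 7)

	// The multi-agent's own node flows back to the agent.
	ma.sendNodeWithDERP(9)
	assertEventuallyHasDERPs(ctx, t, agent, 9)
}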


@ -818,56 +818,6 @@ func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expecte
} }
} }
func assertMultiAgentEventuallyHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) {
t.Helper()
for {
nodes, ok := ma.NextUpdate(ctx)
require.True(t, ok)
if len(nodes) != len(expected) {
t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) {
t.Helper()
for {
nodes, ok := ma.NextUpdate(ctx)
if !ok {
return
}
if len(nodes) != len(expected) {
t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
t.Helper() t.Helper()
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {


@ -96,7 +96,11 @@ func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentC
return xerrors.Errorf("unsubscribe agent: %w", err) return xerrors.Errorf("unsubscribe agent: %w", err)
} }
case wsproxysdk.CoordinateMessageTypeNodeUpdate: case wsproxysdk.CoordinateMessageTypeNodeUpdate:
err := ma.UpdateSelf(msg.Node) pn, err := agpl.NodeToProto(msg.Node)
if err != nil {
return err
}
err = ma.UpdateSelf(pn)
if err != nil { if err != nil {
return xerrors.Errorf("update self: %w", err) return xerrors.Errorf("update self: %w", err)
} }
@ -110,11 +114,14 @@ func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentC
func forwardNodesToWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error { func forwardNodesToWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
var lastData []byte var lastData []byte
for { for {
nodes, ok := ma.NextUpdate(ctx) resp, ok := ma.NextUpdate(ctx)
if !ok { if !ok {
return xerrors.New("multiagent is closed") return xerrors.New("multiagent is closed")
} }
nodes, err := agpl.OnlyNodeUpdates(resp)
if err != nil {
return xerrors.Errorf("failed to convert response: %w", err)
}
data, err := json.Marshal(wsproxysdk.CoordinateNodes{Nodes: nodes}) data, err := json.Marshal(wsproxysdk.CoordinateNodes{Nodes: nodes})
if err != nil { if err != nil {
return err return err
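The hunk above bridges the legacy JSON node types and the v2 protobuf types with agpl.NodeToProto and agpl.OnlyNodeUpdates. A small sketch of those two conversion steps in isolation (hypothetical helpers; error wrapping mirrors the hunk):

func protoNodeFromLegacy(n *agpl.Node) (*proto.Node, error) {
	// Used when forwarding a v1 node update up to the v2 coordinator.
	pn, err := agpl.NodeToProto(n)
	if err != nil {
		return nil, xerrors.Errorf("convert node: %w", err)
	}
	return pn, nil
}

func legacyNodesFromResponse(resp *proto.CoordinateResponse) ([]*agpl.Node, error) {
	// Extracts just the node updates from a v2 response as legacy *agpl.Node
	// values, which is all the v1 wsproxy wire format can carry.
	nodes, err := agpl.OnlyNodeUpdates(resp)
	if err != nil {
		return nil, xerrors.Errorf("failed to convert response: %w", err)
	}
	return nodes, nil
}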


@ -158,7 +158,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
// TODO: Probably do some version checking here // TODO: Probably do some version checking here
info, err := client.SDKClient.BuildInfo(ctx) info, err := client.SDKClient.BuildInfo(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("buildinfo: %w", errors.Join( return nil, xerrors.Errorf("buildinfo: %w", errors.Join(
xerrors.Errorf("unable to fetch build info from primary coderd. Are you sure %q is a coderd instance?", opts.DashboardURL), xerrors.Errorf("unable to fetch build info from primary coderd. Are you sure %q is a coderd instance?", opts.DashboardURL),
err, err,
)) ))


@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"sync" "sync"
@ -23,6 +22,7 @@ import (
"github.com/coder/coder/v2/coderd/workspaceapps" "github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
agpl "github.com/coder/coder/v2/tailnet" agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
) )
// Client is a HTTP client for a subset of Coder API routes that external // Client is a HTTP client for a subset of Coder API routes that external
@ -438,6 +438,9 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
cancel() cancel()
return nil, xerrors.Errorf("parse url: %w", err) return nil, xerrors.Errorf("parse url: %w", err)
} }
q := coordinateURL.Query()
q.Add("version", agpl.CurrentVersion.String())
coordinateURL.RawQuery = q.Encode()
coordinateHeaders := make(http.Header) coordinateHeaders := make(http.Header)
tokenHeader := codersdk.SessionTokenHeader tokenHeader := codersdk.SessionTokenHeader
if c.SDKClient.SessionTokenHeader != "" { if c.SDKClient.SessionTokenHeader != "" {
@ -457,10 +460,24 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
go httpapi.HeartbeatClose(ctx, logger, cancel, conn) go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
nc := websocket.NetConn(ctx, conn, websocket.MessageText) nc := websocket.NetConn(ctx, conn, websocket.MessageBinary)
client, err := agpl.NewDRPCClient(nc)
if err != nil {
logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, "")
return nil, xerrors.Errorf("failed to create DRPCClient: %w", err)
}
protocol, err := client.Coordinate(ctx)
if err != nil {
logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, "")
return nil, xerrors.Errorf("failed to reach the Coordinate endpoint: %w", err)
}
rma := remoteMultiAgentHandler{ rma := remoteMultiAgentHandler{
sdk: c, sdk: c,
nc: nc, logger: logger,
protocol: protocol,
cancel: cancel, cancel: cancel,
legacyAgentCache: map[uuid.UUID]bool{}, legacyAgentCache: map[uuid.UUID]bool{},
} }
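The hunk above replaces the v1 JSON text protocol with protobuf over DRPC: the client advertises its tailnet version as a query parameter, switches the websocket to binary messages, builds a DRPC client over the resulting net.Conn, and opens the Coordinate stream before constructing the MultiAgent. A condensed sketch of that setup, assuming conn is the dialed websocket and imports mirror the surrounding file:

func dialCoordinateRPC(ctx context.Context, conn *websocket.Conn) (proto.DRPCTailnet_CoordinateClient, error) {
	// The v2 protocol is protobuf over DRPC, so the websocket must carry
	// binary frames rather than the JSON text frames of the v1 protocol.
	nc := websocket.NetConn(ctx, conn, websocket.MessageBinary)
	client, err := agpl.NewDRPCClient(nc)
	if err != nil {
		return nil, xerrors.Errorf("create DRPC client: %w", err)
	}
	protocol, err := client.Coordinate(ctx)
	if err != nil {
		return nil, xerrors.Errorf("open Coordinate stream: %w", err)
	}
	return protocol, nil
}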
@ -471,103 +488,75 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
OnSubscribe: rma.OnSubscribe, OnSubscribe: rma.OnSubscribe,
OnUnsubscribe: rma.OnUnsubscribe, OnUnsubscribe: rma.OnUnsubscribe,
OnNodeUpdate: rma.OnNodeUpdate, OnNodeUpdate: rma.OnNodeUpdate,
OnRemove: func(agpl.Queue) { conn.Close(websocket.StatusGoingAway, "closed") }, OnRemove: rma.OnRemove,
}).Init() }).Init()
go func() { go func() {
<-ctx.Done() <-ctx.Done()
ma.Close() ma.Close()
_ = conn.Close(websocket.StatusGoingAway, "closed")
}() }()
go func() { rma.ma = ma
defer cancel() go rma.respLoop()
dec := json.NewDecoder(nc)
for {
var msg CoordinateNodes
err := dec.Decode(&msg)
if err != nil {
if xerrors.Is(err, io.EOF) {
logger.Info(ctx, "websocket connection severed", slog.Error(err))
return
}
logger.Error(ctx, "decode coordinator nodes", slog.Error(err))
return
}
err = ma.Enqueue(msg.Nodes)
if err != nil {
logger.Error(ctx, "enqueue nodes from coordinator", slog.Error(err))
continue
}
}
}()
return ma, nil return ma, nil
} }
type remoteMultiAgentHandler struct { type remoteMultiAgentHandler struct {
sdk *Client sdk *Client
nc net.Conn logger slog.Logger
cancel func() protocol proto.DRPCTailnet_CoordinateClient
ma *agpl.MultiAgent
cancel func()
legacyMu sync.RWMutex legacyMu sync.RWMutex
legacyAgentCache map[uuid.UUID]bool legacyAgentCache map[uuid.UUID]bool
legacySingleflight singleflight.Group[uuid.UUID, AgentIsLegacyResponse] legacySingleflight singleflight.Group[uuid.UUID, AgentIsLegacyResponse]
} }
func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error { func (a *remoteMultiAgentHandler) respLoop() {
data, err := json.Marshal(v) {
if err != nil { defer a.cancel()
return xerrors.Errorf("json marshal message: %w", err) for {
} resp, err := a.protocol.Recv()
if err != nil {
if xerrors.Is(err, io.EOF) {
a.logger.Info(context.Background(), "remote multiagent connection severed", slog.Error(err))
return
}
// Set a deadline so that hung connections don't put back pressure on the system. a.logger.Error(context.Background(), "error receiving multiagent responses", slog.Error(err))
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. return
err = a.nc.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout)) }
if err != nil {
a.cancel()
return xerrors.Errorf("set write deadline: %w", err)
}
_, err = a.nc.Write(data)
if err != nil {
a.cancel()
return xerrors.Errorf("write message: %w", err)
}
// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are err = a.ma.Enqueue(resp)
// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write() if err != nil {
// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context. a.logger.Error(context.Background(), "enqueue response from coordinator", slog.Error(err))
// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after continue
// our successful write, it is important that we reset the deadline before it fires. }
err = a.nc.SetWriteDeadline(time.Time{}) }
if err != nil {
a.cancel()
return xerrors.Errorf("clear write deadline: %w", err)
} }
return nil
} }
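The deleted comments above document a subtlety worth keeping in mind: nhooyr.io/websocket's net.Conn arms a timer rather than a true per-write deadline, so the deadline has to be cleared after every successful write or later writes fail. A standalone sketch of that legacy pattern, assuming nc is the websocket-backed net.Conn:

func writeWithDeadline(nc net.Conn, data []byte, timeout time.Duration) error {
	// Bound the write so a hung peer cannot apply back pressure.
	if err := nc.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
		return xerrors.Errorf("set write deadline: %w", err)
	}
	if _, err := nc.Write(data); err != nil {
		return xerrors.Errorf("write message: %w", err)
	}
	// Clear the deadline before the timer fires, otherwise the next write
	// fails even if a fresh deadline is set.
	if err := nc.SetWriteDeadline(time.Time{}); err != nil {
		return xerrors.Errorf("clear write deadline: %w", err)
	}
	return nil
}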
func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *agpl.Node) error { func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *proto.Node) error {
return a.writeJSON(CoordinateMessage{ return a.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: node}})
Type: CoordinateMessageTypeNodeUpdate,
Node: node,
})
} }
func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) { func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) error {
return nil, a.writeJSON(CoordinateMessage{ return a.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}})
Type: CoordinateMessageTypeSubscribe,
AgentID: agentID,
})
} }
func (a *remoteMultiAgentHandler) OnUnsubscribe(_ agpl.Queue, agentID uuid.UUID) error { func (a *remoteMultiAgentHandler) OnUnsubscribe(_ agpl.Queue, agentID uuid.UUID) error {
return a.writeJSON(CoordinateMessage{ return a.protocol.Send(&proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}})
Type: CoordinateMessageTypeUnsubscribe, }
AgentID: agentID,
}) func (a *remoteMultiAgentHandler) OnRemove(_ agpl.Queue) {
err := a.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
if err != nil {
a.logger.Warn(context.Background(), "failed to gracefully disconnect", slog.Error(err))
}
_ = a.protocol.CloseSend()
} }
func (a *remoteMultiAgentHandler) AgentIsLegacy(agentID uuid.UUID) bool { func (a *remoteMultiAgentHandler) AgentIsLegacy(agentID uuid.UUID) bool {


@ -18,8 +18,9 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"golang.org/x/xerrors" "google.golang.org/protobuf/types/known/timestamppb"
"nhooyr.io/websocket" "nhooyr.io/websocket"
"tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
"cdr.dev/slog" "cdr.dev/slog"
@ -30,6 +31,7 @@ import (
"github.com/coder/coder/v2/enterprise/tailnet" "github.com/coder/coder/v2/enterprise/tailnet"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/v2/tailnet" agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@ -156,25 +158,48 @@ func TestDialCoordinator(t *testing.T) {
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
t.Parallel() t.Parallel()
var ( var (
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort) ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
agentID = uuid.New() agentID = uuid.UUID{33}
serverMultiAgent = tailnettest.NewMockMultiAgentConn(gomock.NewController(t)) proxyID = uuid.UUID{44}
r = chi.NewRouter() mCoord = tailnettest.NewMockCoordinator(gomock.NewController(t))
srv = httptest.NewServer(r) coord agpl.Coordinator = mCoord
r = chi.NewRouter()
srv = httptest.NewServer(r)
) )
defer cancel() defer cancel()
defer srv.Close()
coordPtr := atomic.Pointer[agpl.Coordinator]{}
coordPtr.Store(&coord)
cSrv, err := tailnet.NewClientService(
logger, &coordPtr,
time.Hour,
func() *tailcfg.DERPMap { panic("not implemented") },
)
require.NoError(t, err)
// buffer the channels here, so we don't need to read and write in goroutines to
// avoid blocking
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetTunnelAuth{}).
Times(1).
Return(reqs, resps)
serveMACErr := make(chan error, 1)
r.Get("/api/v2/workspaceproxies/me/coordinate", func(w http.ResponseWriter, r *http.Request) { r.Get("/api/v2/workspaceproxies/me/coordinate", func(w http.ResponseWriter, r *http.Request) {
conn, err := websocket.Accept(w, r, nil) conn, err := websocket.Accept(w, r, nil)
require.NoError(t, err) if !assert.NoError(t, err) {
nc := websocket.NetConn(r.Context(), conn, websocket.MessageText) return
defer serverMultiAgent.Close()
err = tailnet.ServeWorkspaceProxy(ctx, nc, serverMultiAgent)
if !xerrors.Is(err, io.EOF) {
assert.NoError(t, err)
} }
version := r.URL.Query().Get("version")
if !assert.Equal(t, version, agpl.CurrentVersion.String()) {
return
}
nc := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
err = cSrv.ServeMultiAgentClient(ctx, version, nc, proxyID)
serveMACErr <- err
}) })
r.Get("/api/v2/workspaceagents/{workspaceagent}/legacy", func(w http.ResponseWriter, r *http.Request) { r.Get("/api/v2/workspaceagents/{workspaceagent}/legacy", func(w http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, w, http.StatusOK, wsproxysdk.AgentIsLegacyResponse{ httpapi.Write(ctx, w, http.StatusOK, wsproxysdk.AgentIsLegacyResponse{
@ -188,51 +213,50 @@ func TestDialCoordinator(t *testing.T) {
client := wsproxysdk.New(u) client := wsproxysdk.New(u)
client.SDKClient.SetLogger(logger) client.SDKClient.SetLogger(logger)
expected := []*agpl.Node{{ peerID := uuid.UUID{55}
ID: 55, peerNodeKey, err := key.NewNode().Public().MarshalBinary()
AsOf: time.Unix(1689653252, 0), require.NoError(t, err)
Key: key.NewNode().Public(), peerDiscoKey, err := key.NewDisco().Public().MarshalText()
DiscoKey: key.NewDisco().Public(), require.NoError(t, err)
PreferredDERP: 0, expected := &proto.CoordinateResponse{PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{
DERPLatency: map[string]float64{ Id: peerID[:],
"0": 1.0, Node: &proto.Node{
Id: 55,
AsOf: timestamppb.New(time.Unix(1689653252, 0)),
Key: peerNodeKey[:],
Disco: string(peerDiscoKey),
PreferredDerp: 0,
DerpLatency: map[string]float64{
"0": 1.0,
},
DerpForcedWebsocket: map[int32]string{},
Addresses: []string{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128).String()},
AllowedIps: []string{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128).String()},
Endpoints: []string{"192.168.1.1:18842"},
}, },
DERPForcedWebsocket: map[int]string{}, }}}
Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128)},
AllowedIPs: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128)},
Endpoints: []string{"192.168.1.1:18842"},
}}
sendNode := make(chan struct{})
serverMultiAgent.EXPECT().NextUpdate(gomock.Any()).AnyTimes().
DoAndReturn(func(ctx context.Context) ([]*agpl.Node, bool) {
select {
case <-sendNode:
return expected, true
case <-ctx.Done():
return nil, false
}
})
rma, err := client.DialCoordinator(ctx) rma, err := client.DialCoordinator(ctx)
require.NoError(t, err) require.NoError(t, err)
// Subscribe // Subscribe
{ {
ch := make(chan struct{})
serverMultiAgent.EXPECT().SubscribeAgent(agentID).Do(func(uuid.UUID) {
close(ch)
})
require.NoError(t, rma.SubscribeAgent(agentID)) require.NoError(t, rma.SubscribeAgent(agentID))
waitOrCancel(ctx, t, ch)
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
} }
// Read updated agent node // Read updated agent node
{ {
sendNode <- struct{}{} resps <- expected
got, ok := rma.NextUpdate(ctx)
resp, ok := rma.NextUpdate(ctx)
assert.True(t, ok) assert.True(t, ok)
got[0].AsOf = got[0].AsOf.In(time.Local) updates := resp.GetPeerUpdates()
assert.Equal(t, *expected[0], *got[0]) assert.Len(t, updates, 1)
eq, err := updates[0].GetNode().Equal(expected.GetPeerUpdates()[0].GetNode())
assert.NoError(t, err)
assert.True(t, eq)
} }
// Check legacy // Check legacy
{ {
@ -241,45 +265,38 @@ func TestDialCoordinator(t *testing.T) {
} }
// UpdateSelf // UpdateSelf
{ {
ch := make(chan struct{}) require.NoError(t, rma.UpdateSelf(expected.PeerUpdates[0].GetNode()))
serverMultiAgent.EXPECT().UpdateSelf(gomock.Any()).Do(func(node *agpl.Node) {
node.AsOf = node.AsOf.In(time.Local) req := testutil.RequireRecvCtx(ctx, t, reqs)
assert.Equal(t, expected[0], node) eq, err := req.GetUpdateSelf().GetNode().Equal(expected.PeerUpdates[0].GetNode())
close(ch) require.NoError(t, err)
}) require.True(t, eq)
require.NoError(t, rma.UpdateSelf(expected[0]))
waitOrCancel(ctx, t, ch)
} }
// Unsubscribe // Unsubscribe
{ {
ch := make(chan struct{})
serverMultiAgent.EXPECT().UnsubscribeAgent(agentID).Do(func(uuid.UUID) {
close(ch)
})
require.NoError(t, rma.UnsubscribeAgent(agentID)) require.NoError(t, rma.UnsubscribeAgent(agentID))
waitOrCancel(ctx, t, ch)
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, agentID[:], req.GetRemoveTunnel().GetId())
} }
// Close // Close
{ {
ch := make(chan struct{})
serverMultiAgent.EXPECT().Close().Do(func() {
close(ch)
})
require.NoError(t, rma.Close()) require.NoError(t, rma.Close())
waitOrCancel(ctx, t, ch)
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, req.Disconnect)
close(resps)
select {
case <-ctx.Done():
t.Fatal("timeout waiting for req close")
case _, ok := <-reqs:
require.False(t, ok, "didn't close requests")
}
require.Error(t, testutil.RequireRecvCtx(ctx, t, serveMACErr))
} }
}) })
} }
func waitOrCancel(ctx context.Context, t testing.TB, ch <-chan struct{}) {
t.Helper()
select {
case <-ch:
case <-ctx.Done():
t.Fatal("timed out waiting for channel")
}
}
type ResponseRecorder struct { type ResponseRecorder struct {
rw *httptest.ResponseRecorder rw *httptest.ResponseRecorder
wasWritten atomic.Bool wasWritten atomic.Bool


@ -490,6 +490,18 @@ func (c *configMaps) protoNodeToTailcfg(p *proto.Node) (*tailcfg.Node, error) {
}, nil }, nil
} }
// nodeAddresses returns the addresses for the peer with the given publicKey, if known.
func (c *configMaps) nodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
c.L.Lock()
defer c.L.Unlock()
for _, lc := range c.peers {
if lc.node.Key == publicKey {
return lc.node.Addresses, true
}
}
return nil, false
}
type peerLifecycle struct { type peerLifecycle struct {
peerID uuid.UUID peerID uuid.UUID
node *tailcfg.Node node *tailcfg.Node


@ -3,48 +3,40 @@ package tailnet
import ( import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
"os" "os"
"reflect"
"strconv" "strconv"
"sync" "sync"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/google/uuid" "github.com/google/uuid"
"go4.org/netipx"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"tailscale.com/net/connstats" "tailscale.com/net/connstats"
"tailscale.com/net/dns"
"tailscale.com/net/netmon" "tailscale.com/net/netmon"
"tailscale.com/net/netns" "tailscale.com/net/netns"
"tailscale.com/net/tsdial" "tailscale.com/net/tsdial"
"tailscale.com/net/tstun" "tailscale.com/net/tstun"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tsd" "tailscale.com/tsd"
"tailscale.com/types/ipproto"
"tailscale.com/types/key" "tailscale.com/types/key"
tslogger "tailscale.com/types/logger" tslogger "tailscale.com/types/logger"
"tailscale.com/types/netlogtype" "tailscale.com/types/netlogtype"
"tailscale.com/types/netmap"
"tailscale.com/wgengine" "tailscale.com/wgengine"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/magicsock" "tailscale.com/wgengine/magicsock"
"tailscale.com/wgengine/netstack" "tailscale.com/wgengine/netstack"
"tailscale.com/wgengine/router" "tailscale.com/wgengine/router"
"tailscale.com/wgengine/wgcfg/nmcfg"
"cdr.dev/slog" "cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/cryptorand" "github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/tailnet/proto"
) )
var ErrConnClosed = xerrors.New("connection closed") var ErrConnClosed = xerrors.New("connection closed")
@ -128,42 +120,6 @@ func NewConn(options *Options) (conn *Conn, err error) {
} }
nodePrivateKey := key.NewNode() nodePrivateKey := key.NewNode()
nodePublicKey := nodePrivateKey.Public()
netMap := &netmap.NetworkMap{
DERPMap: options.DERPMap,
NodeKey: nodePublicKey,
PrivateKey: nodePrivateKey,
Addresses: options.Addresses,
PacketFilter: []filter.Match{{
// Allow any protocol!
IPProto: []ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, ipproto.SCTP},
// Allow traffic sourced from anywhere.
Srcs: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0),
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0),
},
// Allow traffic to route anywhere.
Dsts: []filter.NetPortRange{
{
Net: netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0),
Ports: filter.PortRange{
First: 0,
Last: 65535,
},
},
{
Net: netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0),
Ports: filter.PortRange{
First: 0,
Last: 65535,
},
},
},
Caps: []filter.CapMatch{},
}},
}
var nodeID tailcfg.NodeID var nodeID tailcfg.NodeID
// If we're provided with a UUID, use it to populate our node ID. // If we're provided with a UUID, use it to populate our node ID.
@ -177,14 +133,6 @@ func NewConn(options *Options) (conn *Conn, err error) {
nodeID = tailcfg.NodeID(uid) nodeID = tailcfg.NodeID(uid)
} }
// This is used by functions below to identify the node via key
netMap.SelfNode = &tailcfg.Node{
ID: nodeID,
Key: nodePublicKey,
Addresses: options.Addresses,
AllowedIPs: options.Addresses,
}
wireguardMonitor, err := netmon.New(Logger(options.Logger.Named("net.wgmonitor"))) wireguardMonitor, err := netmon.New(Logger(options.Logger.Named("net.wgmonitor")))
if err != nil { if err != nil {
return nil, xerrors.Errorf("create wireguard link monitor: %w", err) return nil, xerrors.Errorf("create wireguard link monitor: %w", err)
@ -243,7 +191,6 @@ func NewConn(options *Options) (conn *Conn, err error) {
if err != nil { if err != nil {
return nil, xerrors.Errorf("set node private key: %w", err) return nil, xerrors.Errorf("set node private key: %w", err)
} }
netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey()
netStack, err := netstack.Create( netStack, err := netstack.Create(
Logger(options.Logger.Named("net.netstack")), Logger(options.Logger.Named("net.netstack")),
@ -262,44 +209,46 @@ func NewConn(options *Options) (conn *Conn, err error) {
} }
netStack.ProcessLocalIPs = true netStack.ProcessLocalIPs = true
wireguardEngine = wgengine.NewWatchdog(wireguardEngine) wireguardEngine = wgengine.NewWatchdog(wireguardEngine)
wireguardEngine.SetDERPMap(options.DERPMap)
netMapCopy := *netMap
options.Logger.Debug(context.Background(), "updating network map")
wireguardEngine.SetNetworkMap(&netMapCopy)
localIPSet := netipx.IPSetBuilder{} cfgMaps := newConfigMaps(
for _, addr := range netMap.Addresses { options.Logger,
localIPSet.AddPrefix(addr) wireguardEngine,
} nodeID,
localIPs, _ := localIPSet.IPSet() nodePrivateKey,
logIPSet := netipx.IPSetBuilder{} magicConn.DiscoPublicKey(),
logIPs, _ := logIPSet.IPSet() )
wireguardEngine.SetFilter(filter.New( cfgMaps.setAddresses(options.Addresses)
netMap.PacketFilter, cfgMaps.setDERPMap(DERPMapToProto(options.DERPMap))
localIPs, cfgMaps.setBlockEndpoints(options.BlockEndpoints)
logIPs,
nodeUp := newNodeUpdater(
options.Logger,
nil, nil,
Logger(options.Logger.Named("net.packet-filter")), nodeID,
)) nodePrivateKey.Public(),
magicConn.DiscoPublicKey(),
)
nodeUp.setAddresses(options.Addresses)
nodeUp.setBlockEndpoints(options.BlockEndpoints)
wireguardEngine.SetStatusCallback(nodeUp.setStatus)
wireguardEngine.SetNetInfoCallback(nodeUp.setNetInfo)
magicConn.SetDERPForcedWebsocketCallback(nodeUp.setDERPForcedWebsocket)
server := &Conn{ server := &Conn{
blockEndpoints: options.BlockEndpoints, closed: make(chan struct{}),
derpForceWebSockets: options.DERPForceWebSockets, logger: options.Logger,
closed: make(chan struct{}), magicConn: magicConn,
logger: options.Logger, dialer: dialer,
magicConn: magicConn, listeners: map[listenKey]*listener{},
dialer: dialer, tunDevice: sys.Tun.Get(),
listeners: map[listenKey]*listener{}, netStack: netStack,
peerMap: map[tailcfg.NodeID]*tailcfg.Node{}, wireguardMonitor: wireguardMonitor,
lastDERPForcedWebSockets: map[int]string{},
tunDevice: sys.Tun.Get(),
netMap: netMap,
netStack: netStack,
wireguardMonitor: wireguardMonitor,
wireguardRouter: &router.Config{ wireguardRouter: &router.Config{
LocalAddrs: netMap.Addresses, LocalAddrs: options.Addresses,
}, },
wireguardEngine: wireguardEngine, wireguardEngine: wireguardEngine,
configMaps: cfgMaps,
nodeUpdater: nodeUp,
} }
defer func() { defer func() {
if err != nil { if err != nil {
@ -307,52 +256,6 @@ func NewConn(options *Options) (conn *Conn, err error) {
} }
}() }()
wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) {
server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.Error(err))
if err != nil {
return
}
server.lastMutex.Lock()
if s.AsOf.Before(server.lastStatus) {
// Don't process outdated status!
server.lastMutex.Unlock()
return
}
server.lastStatus = s.AsOf
if endpointsEqual(s.LocalAddrs, server.lastEndpoints) {
// No need to update the node if nothing changed!
server.lastMutex.Unlock()
return
}
server.lastEndpoints = append([]tailcfg.Endpoint{}, s.LocalAddrs...)
server.lastMutex.Unlock()
server.sendNode()
})
wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) {
server.logger.Debug(context.Background(), "netinfo callback", slog.F("netinfo", ni))
server.lastMutex.Lock()
if reflect.DeepEqual(server.lastNetInfo, ni) {
server.lastMutex.Unlock()
return
}
server.lastNetInfo = ni.Clone()
server.lastMutex.Unlock()
server.sendNode()
})
magicConn.SetDERPForcedWebsocketCallback(func(region int, reason string) {
server.logger.Debug(context.Background(), "derp forced websocket", slog.F("region", region), slog.F("reason", reason))
server.lastMutex.Lock()
if server.lastDERPForcedWebSockets[region] == reason {
server.lastMutex.Unlock()
return
}
server.lastDERPForcedWebSockets[region] = reason
server.lastMutex.Unlock()
server.sendNode()
})
netStack.GetTCPHandlerForFlow = server.forwardTCP netStack.GetTCPHandlerForFlow = server.forwardTCP
err = netStack.Start(nil) err = netStack.Start(nil)
@ -389,16 +292,14 @@ func IPFromUUID(uid uuid.UUID) netip.Addr {
// Conn is an actively listening Wireguard connection. // Conn is an actively listening Wireguard connection.
type Conn struct { type Conn struct {
mutex sync.Mutex mutex sync.Mutex
closed chan struct{} closed chan struct{}
logger slog.Logger logger slog.Logger
blockEndpoints bool
derpForceWebSockets bool
dialer *tsdial.Dialer dialer *tsdial.Dialer
tunDevice *tstun.Wrapper tunDevice *tstun.Wrapper
peerMap map[tailcfg.NodeID]*tailcfg.Node configMaps *configMaps
netMap *netmap.NetworkMap nodeUpdater *nodeUpdater
netStack *netstack.Impl netStack *netstack.Impl
magicConn *magicsock.Conn magicConn *magicsock.Conn
wireguardMonitor *netmon.Monitor wireguardMonitor *netmon.Monitor
@ -406,17 +307,6 @@ type Conn struct {
wireguardEngine wgengine.Engine wireguardEngine wgengine.Engine
listeners map[listenKey]*listener listeners map[listenKey]*listener
lastMutex sync.Mutex
nodeSending bool
nodeChanged bool
// It's only possible to store these values via status functions,
// so the values must be stored for retrieval later on.
lastStatus time.Time
lastEndpoints []tailcfg.Endpoint
lastDERPForcedWebSockets map[int]string
lastNetInfo *tailcfg.NetInfo
nodeCallback func(node *Node)
trafficStats *connstats.Statistics trafficStats *connstats.Statistics
} }
@ -425,57 +315,30 @@ func (c *Conn) MagicsockSetDebugLoggingEnabled(enabled bool) {
} }
func (c *Conn) SetAddresses(ips []netip.Prefix) error { func (c *Conn) SetAddresses(ips []netip.Prefix) error {
c.mutex.Lock() c.configMaps.setAddresses(ips)
defer c.mutex.Unlock() c.nodeUpdater.setAddresses(ips)
c.netMap.Addresses = ips
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
err := c.reconfig()
if err != nil {
return xerrors.Errorf("reconfig: %w", err)
}
return nil return nil
} }
func (c *Conn) Addresses() []netip.Prefix {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.netMap.Addresses
}
func (c *Conn) SetNodeCallback(callback func(node *Node)) { func (c *Conn) SetNodeCallback(callback func(node *Node)) {
c.lastMutex.Lock() c.nodeUpdater.setCallback(callback)
c.nodeCallback = callback
c.lastMutex.Unlock()
c.sendNode()
} }
// SetDERPMap updates the DERPMap of a connection. // SetDERPMap updates the DERPMap of a connection.
func (c *Conn) SetDERPMap(derpMap *tailcfg.DERPMap) { func (c *Conn) SetDERPMap(derpMap *tailcfg.DERPMap) {
c.mutex.Lock() c.configMaps.setDERPMap(DERPMapToProto(derpMap))
defer c.mutex.Unlock()
c.logger.Debug(context.Background(), "updating derp map", slog.F("derp_map", derpMap))
c.wireguardEngine.SetDERPMap(derpMap)
c.netMap.DERPMap = derpMap
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
} }
func (c *Conn) SetDERPForceWebSockets(v bool) { func (c *Conn) SetDERPForceWebSockets(v bool) {
c.logger.Info(context.Background(), "setting DERP Force Websockets", slog.F("force_derp_websockets", v))
c.magicConn.SetDERPForceWebsockets(v) c.magicConn.SetDERPForceWebsockets(v)
} }
// SetBlockEndpoints sets whether or not to block P2P endpoints. This setting // SetBlockEndpoints sets whether to block P2P endpoints. This setting
// will only apply to new peers. // will only apply to new peers.
func (c *Conn) SetBlockEndpoints(blockEndpoints bool) { func (c *Conn) SetBlockEndpoints(blockEndpoints bool) {
c.mutex.Lock() c.configMaps.setBlockEndpoints(blockEndpoints)
defer c.mutex.Unlock() c.nodeUpdater.setBlockEndpoints(blockEndpoints)
c.blockEndpoints = blockEndpoints
} }
// SetDERPRegionDialer updates the dialer to use for connecting to DERP regions. // SetDERPRegionDialer updates the dialer to use for connecting to DERP regions.
@ -483,186 +346,24 @@ func (c *Conn) SetDERPRegionDialer(dialer func(ctx context.Context, region *tail
c.magicConn.SetDERPRegionDialer(dialer) c.magicConn.SetDERPRegionDialer(dialer)
} }
// UpdateNodes connects with a set of peers. This can be constantly updated, // UpdatePeers connects with a set of peers. This can be constantly updated,
// and peers will continually be reconnected as necessary. If replacePeers is // and peers will continually be reconnected as necessary.
// true, all peers will be removed before adding the new ones. func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
//
//nolint:revive // Complains about replacePeers.
func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.isClosed() { if c.isClosed() {
return ErrConnClosed return ErrConnClosed
} }
c.configMaps.updatePeers(updates)
status := c.Status()
if replacePeers {
c.netMap.Peers = []*tailcfg.Node{}
c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{}
}
for _, peer := range c.netMap.Peers {
peerStatus, ok := status.Peer[peer.Key]
if !ok {
continue
}
// If this peer was added in the last 5 minutes, assume it
// could still be active.
if time.Since(peer.Created) < 5*time.Minute {
continue
}
// We double-check that it's safe to remove by ensuring no
// handshake has been sent in the past 5 minutes as well. Connections that
// are actively exchanging IP traffic will handshake every 2 minutes.
if time.Since(peerStatus.LastHandshake) < 5*time.Minute {
continue
}
c.logger.Debug(context.Background(), "removing peer, last handshake >5m ago",
slog.F("peer", peer.Key), slog.F("last_handshake", peerStatus.LastHandshake),
)
delete(c.peerMap, peer.ID)
}
for _, node := range nodes {
// If no preferred DERP is provided, we can't reach the node.
if node.PreferredDERP == 0 {
c.logger.Debug(context.Background(), "no preferred DERP, skipping node", slog.F("node", node))
continue
}
c.logger.Debug(context.Background(), "adding node", slog.F("node", node))
peerStatus, ok := status.Peer[node.Key]
peerNode := &tailcfg.Node{
ID: node.ID,
Created: time.Now(),
Key: node.Key,
DiscoKey: node.DiscoKey,
Addresses: node.Addresses,
AllowedIPs: node.AllowedIPs,
Endpoints: node.Endpoints,
DERP: fmt.Sprintf("%s:%d", tailcfg.DerpMagicIP, node.PreferredDERP),
Hostinfo: (&tailcfg.Hostinfo{}).View(),
// Starting KeepAlive messages at the initialization of a connection
// causes a race condition. If we handshake before the peer has our
// node, we'll have to wait for 5 seconds before trying again. Ideally,
// the first handshake starts when the user first initiates a
// connection to the peer. After a successful connection we enable
// keep alives to persist the connection and keep it from becoming
// idle. SSH connections don't send packets while idle, so we
// use keep alives to avoid random hangs while we set up the
// connection again after inactivity.
KeepAlive: ok && peerStatus.Active,
}
if c.blockEndpoints {
peerNode.Endpoints = nil
}
c.peerMap[node.ID] = peerNode
}
c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
for _, peer := range c.peerMap {
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
}
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
err := c.reconfig()
if err != nil {
return xerrors.Errorf("reconfig: %w", err)
}
return nil
}
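The removed UpdateNodes encoded a stale-peer heuristic in its comments: a peer is only pruned when it is more than five minutes old and has not completed a handshake in the last five minutes, because peers actively exchanging traffic handshake roughly every two minutes. As a standalone sketch (illustrative helper, not part of the diff):

func shouldRemovePeer(created, lastHandshake time.Time) bool {
	if time.Since(created) < 5*time.Minute {
		// Recently added peers may still be coming up.
		return false
	}
	if time.Since(lastHandshake) < 5*time.Minute {
		// Still exchanging traffic.
		return false
	}
	return true
}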
// PeerSelector is used to select a peer from within a Tailnet.
type PeerSelector struct {
ID tailcfg.NodeID
IP netip.Prefix
}
func (c *Conn) RemovePeer(selector PeerSelector) (deleted bool, err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.isClosed() {
return false, ErrConnClosed
}
deleted = false
for _, peer := range c.peerMap {
if peer.ID == selector.ID {
delete(c.peerMap, peer.ID)
deleted = true
break
}
for _, peerIP := range peer.Addresses {
if peerIP.Bits() == selector.IP.Bits() && peerIP.Addr().Compare(selector.IP.Addr()) == 0 {
delete(c.peerMap, peer.ID)
deleted = true
break
}
}
}
if !deleted {
return false, nil
}
c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
for _, peer := range c.peerMap {
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
}
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
err = c.reconfig()
if err != nil {
return false, xerrors.Errorf("reconfig: %w", err)
}
return true, nil
}
func (c *Conn) reconfig() error {
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("net.wgconfig")), netmap.AllowSingleHosts, "")
if err != nil {
return xerrors.Errorf("update wireguard config: %w", err)
}
err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{})
if err != nil {
if c.isClosed() {
return nil
}
if errors.Is(err, wgengine.ErrNoChanges) {
return nil
}
return xerrors.Errorf("reconfig: %w", err)
}
return nil return nil
} }
// NodeAddresses returns the addresses of a node from the NetworkMap. // NodeAddresses returns the addresses of a node from the NetworkMap.
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) { func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
c.mutex.Lock() return c.configMaps.nodeAddresses(publicKey)
defer c.mutex.Unlock()
for _, node := range c.netMap.Peers {
if node.Key == publicKey {
return node.Addresses, true
}
}
return nil, false
} }
// Status returns the current ipnstate of a connection. // Status returns the current ipnstate of a connection.
func (c *Conn) Status() *ipnstate.Status { func (c *Conn) Status() *ipnstate.Status {
sb := &ipnstate.StatusBuilder{WantPeers: true} return c.configMaps.status()
c.wireguardEngine.UpdateStatus(sb)
return sb.Status()
} }
// Ping sends a ping to the Wireguard engine. // Ping sends a ping to the Wireguard engine.
@ -689,16 +390,9 @@ func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, bool, *i
// DERPMap returns the currently set DERP mapping. // DERPMap returns the currently set DERP mapping.
func (c *Conn) DERPMap() *tailcfg.DERPMap { func (c *Conn) DERPMap() *tailcfg.DERPMap {
c.mutex.Lock() c.configMaps.L.Lock()
defer c.mutex.Unlock() defer c.configMaps.L.Unlock()
return c.netMap.DERPMap return c.configMaps.derpMapLocked()
}
// BlockEndpoints returns whether or not P2P is blocked.
func (c *Conn) BlockEndpoints() bool {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.blockEndpoints
} }
// AwaitReachable pings the provided IP continually until the // AwaitReachable pings the provided IP continually until the
@ -759,6 +453,9 @@ func (c *Conn) Closed() <-chan struct{} {
// Close shuts down the Wireguard connection. // Close shuts down the Wireguard connection.
func (c *Conn) Close() error { func (c *Conn) Close() error {
c.logger.Info(context.Background(), "closing tailnet Conn")
c.configMaps.close()
c.nodeUpdater.close()
c.mutex.Lock() c.mutex.Lock()
select { select {
case <-c.closed: case <-c.closed:
@ -808,91 +505,11 @@ func (c *Conn) isClosed() bool {
} }
} }
func (c *Conn) sendNode() {
c.lastMutex.Lock()
defer c.lastMutex.Unlock()
if c.nodeSending {
c.nodeChanged = true
return
}
node := c.selfNode()
// Conn.UpdateNodes will skip any nodes that don't have the PreferredDERP
// set to non-zero, since we cannot reach nodes without DERP for discovery.
// Therefore, there is no point in sending the node without this, and we can
// save ourselves from churn in the tailscale/wireguard layer.
if node.PreferredDERP == 0 {
c.logger.Debug(context.Background(), "skipped sending node; no PreferredDERP", slog.F("node", node))
return
}
nodeCallback := c.nodeCallback
if nodeCallback == nil {
return
}
c.nodeSending = true
go func() {
c.logger.Debug(context.Background(), "sending node", slog.F("node", node))
nodeCallback(node)
c.lastMutex.Lock()
c.nodeSending = false
if c.nodeChanged {
c.nodeChanged = false
c.lastMutex.Unlock()
c.sendNode()
return
}
c.lastMutex.Unlock()
}()
}
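The deleted sendNode implemented a coalescing send: at most one callback runs at a time, and any updates arriving in the meantime collapse into a single follow-up send. A condensed sketch of the pattern, with illustrative names:

type nodeSender struct {
	mu          sync.Mutex
	sending     bool
	changed     bool
	callback    func(*Node)
	currentNode func() *Node
}

func (s *nodeSender) send() {
	s.mu.Lock()
	if s.sending {
		// Coalesce: remember that another send is needed once the current one finishes.
		s.changed = true
		s.mu.Unlock()
		return
	}
	s.sending = true
	s.mu.Unlock()
	go func() {
		s.callback(s.currentNode())
		s.mu.Lock()
		s.sending = false
		resend := s.changed
		s.changed = false
		s.mu.Unlock()
		if resend {
			s.send()
		}
	}()
}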
// Node returns the last node that was sent to the node callback. // Node returns the last node that was sent to the node callback.
func (c *Conn) Node() *Node { func (c *Conn) Node() *Node {
c.lastMutex.Lock() c.nodeUpdater.L.Lock()
defer c.lastMutex.Unlock() defer c.nodeUpdater.L.Unlock()
return c.selfNode() return c.nodeUpdater.nodeLocked()
}
func (c *Conn) selfNode() *Node {
endpoints := make([]string, 0, len(c.lastEndpoints))
for _, addr := range c.lastEndpoints {
endpoints = append(endpoints, addr.Addr.String())
}
var preferredDERP int
var derpLatency map[string]float64
derpForcedWebsocket := make(map[int]string, 0)
if c.lastNetInfo != nil {
preferredDERP = c.lastNetInfo.PreferredDERP
derpLatency = c.lastNetInfo.DERPLatency
if c.derpForceWebSockets {
// We only need to store this for a single region, since this is
// mostly used for debugging purposes and doesn't actually have a
// code purpose.
derpForcedWebsocket[preferredDERP] = "DERP is configured to always fallback to WebSockets"
} else {
for k, v := range c.lastDERPForcedWebSockets {
derpForcedWebsocket[k] = v
}
}
}
node := &Node{
ID: c.netMap.SelfNode.ID,
AsOf: dbtime.Now(),
Key: c.netMap.SelfNode.Key,
Addresses: c.netMap.SelfNode.Addresses,
AllowedIPs: c.netMap.SelfNode.AllowedIPs,
DiscoKey: c.magicConn.DiscoPublicKey(),
Endpoints: endpoints,
PreferredDERP: preferredDERP,
DERPLatency: derpLatency,
DERPForcedWebsocket: derpForcedWebsocket,
}
c.mutex.Lock()
if c.blockEndpoints {
node.Endpoints = nil
}
c.mutex.Unlock()
return node
} }
// This and below is taken _mostly_ verbatim from Tailscale: // This and below is taken _mostly_ verbatim from Tailscale:
@ -1056,15 +673,3 @@ func Logger(logger slog.Logger) tslogger.Logf {
logger.Debug(context.Background(), fmt.Sprintf(format, args...)) logger.Debug(context.Background(), fmt.Sprintf(format, args...))
}) })
} }
func endpointsEqual(x, y []tailcfg.Endpoint) bool {
if len(x) != len(y) {
return false
}
for i := range x {
if x[i] != y[i] {
return false
}
}
return true
}


@ -5,6 +5,7 @@ import (
"net/netip" "net/netip"
"testing" "testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/goleak" "go.uber.org/goleak"
@ -12,6 +13,7 @@ import (
"cdr.dev/slog" "cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@ -22,10 +24,10 @@ func TestMain(m *testing.M) {
func TestTailnet(t *testing.T) { func TestTailnet(t *testing.T) {
t.Parallel() t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
derpMap, _ := tailnettest.RunDERPAndSTUN(t) derpMap, _ := tailnettest.RunDERPAndSTUN(t)
t.Run("InstantClose", func(t *testing.T) { t.Run("InstantClose", func(t *testing.T) {
t.Parallel() t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
conn, err := tailnet.NewConn(&tailnet.Options{ conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
Logger: logger.Named("w1"), Logger: logger.Named("w1"),
@ -37,6 +39,8 @@ func TestTailnet(t *testing.T) {
}) })
t.Run("Connect", func(t *testing.T) { t.Run("Connect", func(t *testing.T) {
t.Parallel() t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitLong)
w1IP := tailnet.IP() w1IP := tailnet.IP()
w1, err := tailnet.NewConn(&tailnet.Options{ w1, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)}, Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)},
@ -55,14 +59,8 @@ func TestTailnet(t *testing.T) {
_ = w1.Close() _ = w1.Close()
_ = w2.Close() _ = w2.Close()
}) })
w1.SetNodeCallback(func(node *tailnet.Node) { stitch(t, w2, w1)
err := w2.UpdateNodes([]*tailnet.Node{node}, false) stitch(t, w1, w2)
assert.NoError(t, err)
})
w2.SetNodeCallback(func(node *tailnet.Node) {
err := w1.UpdateNodes([]*tailnet.Node{node}, false)
assert.NoError(t, err)
})
require.True(t, w2.AwaitReachable(context.Background(), w1IP)) require.True(t, w2.AwaitReachable(context.Background(), w1IP))
conn := make(chan struct{}, 1) conn := make(chan struct{}, 1)
go func() { go func() {
@ -89,7 +87,7 @@ func TestTailnet(t *testing.T) {
default: default:
} }
}) })
node := <-nodes node := testutil.RequireRecvCtx(ctx, t, nodes)
// Ensure this connected over DERP! // Ensure this connected over DERP!
require.Len(t, node.DERPForcedWebsocket, 0) require.Len(t, node.DERPForcedWebsocket, 0)
@ -99,6 +97,7 @@ func TestTailnet(t *testing.T) {
t.Run("ForcesWebSockets", func(t *testing.T) { t.Run("ForcesWebSockets", func(t *testing.T) {
t.Parallel() t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium) ctx := testutil.Context(t, testutil.WaitMedium)
w1IP := tailnet.IP() w1IP := tailnet.IP()
@ -122,14 +121,8 @@ func TestTailnet(t *testing.T) {
_ = w1.Close() _ = w1.Close()
_ = w2.Close() _ = w2.Close()
}) })
w1.SetNodeCallback(func(node *tailnet.Node) { stitch(t, w2, w1)
err := w2.UpdateNodes([]*tailnet.Node{node}, false) stitch(t, w1, w2)
assert.NoError(t, err)
})
w2.SetNodeCallback(func(node *tailnet.Node) {
err := w1.UpdateNodes([]*tailnet.Node{node}, false)
assert.NoError(t, err)
})
require.True(t, w2.AwaitReachable(ctx, w1IP)) require.True(t, w2.AwaitReachable(ctx, w1IP))
conn := make(chan struct{}, 1) conn := make(chan struct{}, 1)
go func() { go func() {
@ -243,11 +236,16 @@ func TestConn_UpdateDERP(t *testing.T) {
err := client1.Close() err := client1.Close()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
client1.SetNodeCallback(func(node *tailnet.Node) { stitch(t, conn, client1)
err := conn.UpdateNodes([]*tailnet.Node{node}, false) pn, err := tailnet.NodeToProto(conn.Node())
assert.NoError(t, err) require.NoError(t, err)
}) connID := uuid.New()
client1.UpdateNodes([]*tailnet.Node{conn.Node()}, false) err = client1.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
Id: connID[:],
Node: pn,
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
}})
require.NoError(t, err)
awaitReachableCtx1, awaitReachableCancel1 := context.WithTimeout(context.Background(), testutil.WaitShort) awaitReachableCtx1, awaitReachableCancel1 := context.WithTimeout(context.Background(), testutil.WaitShort)
defer awaitReachableCancel1() defer awaitReachableCancel1()
@ -288,7 +286,13 @@ parentLoop:
// ... unless the client updates its derp map and nodes. // ... unless the client updates its derp map and nodes.
client1.SetDERPMap(derpMap2) client1.SetDERPMap(derpMap2)
client1.UpdateNodes([]*tailnet.Node{conn.Node()}, false) pn, err = tailnet.NodeToProto(conn.Node())
require.NoError(t, err)
client1.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
Id: connID[:],
Node: pn,
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
}})
awaitReachableCtx3, awaitReachableCancel3 := context.WithTimeout(context.Background(), testutil.WaitShort) awaitReachableCtx3, awaitReachableCancel3 := context.WithTimeout(context.Background(), testutil.WaitShort)
defer awaitReachableCancel3() defer awaitReachableCancel3()
require.True(t, client1.AwaitReachable(awaitReachableCtx3, ip)) require.True(t, client1.AwaitReachable(awaitReachableCtx3, ip))
@ -306,13 +310,34 @@ parentLoop:
err := client2.Close() err := client2.Close()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
client2.SetNodeCallback(func(node *tailnet.Node) { stitch(t, conn, client2)
err := conn.UpdateNodes([]*tailnet.Node{node}, false) pn, err = tailnet.NodeToProto(conn.Node())
assert.NoError(t, err) require.NoError(t, err)
}) client2.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
client2.UpdateNodes([]*tailnet.Node{conn.Node()}, false) Id: connID[:],
Node: pn,
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
}})
awaitReachableCtx4, awaitReachableCancel4 := context.WithTimeout(context.Background(), testutil.WaitShort) awaitReachableCtx4, awaitReachableCancel4 := context.WithTimeout(context.Background(), testutil.WaitShort)
defer awaitReachableCancel4() defer awaitReachableCancel4()
require.True(t, client2.AwaitReachable(awaitReachableCtx4, ip)) require.True(t, client2.AwaitReachable(awaitReachableCtx4, ip))
} }
// stitch sends node updates from src Conn as peer updates to dst Conn. Sort of
// like the Coordinator would, but without actually needing a Coordinator.
func stitch(t *testing.T, dst, src *tailnet.Conn) {
srcID := uuid.New()
src.SetNodeCallback(func(node *tailnet.Node) {
pn, err := tailnet.NodeToProto(node)
if !assert.NoError(t, err) {
return
}
err = dst.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
Id: srcID[:],
Node: pn,
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
}})
assert.NoError(t, err)
})
}
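Typical use wires both directions, as the tests above do, so updates flow as they would through a real Coordinator:

func stitchPair(t *testing.T, w1, w2 *tailnet.Conn) {
	stitch(t, w2, w1) // w1's node updates become peer updates on w2
	stitch(t, w1, w2) // and vice versa
}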


@ -3,6 +3,7 @@ package tailnet
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"html/template" "html/template"
"io" "io"
"net" "net"
@ -92,6 +93,237 @@ type Node struct {
Endpoints []string `json:"endpoints"` Endpoints []string `json:"endpoints"`
} }
// Coordinatee is something that can be coordinated over the Coordinate protocol. Usually this is a
// Conn.
type Coordinatee interface {
UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error
SetNodeCallback(func(*Node))
}
type Coordination interface {
io.Closer
Error() <-chan error
}
type remoteCoordination struct {
sync.Mutex
closed bool
errChan chan error
coordinatee Coordinatee
logger slog.Logger
protocol proto.DRPCTailnet_CoordinateClient
}
func (c *remoteCoordination) Close() error {
c.Lock()
defer c.Unlock()
if c.closed {
return nil
}
c.closed = true
err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
if err != nil {
return xerrors.Errorf("send disconnect: %w", err)
}
return nil
}
func (c *remoteCoordination) Error() <-chan error {
return c.errChan
}
func (c *remoteCoordination) sendErr(err error) {
select {
case c.errChan <- err:
default:
}
}
func (c *remoteCoordination) respLoop() {
for {
resp, err := c.protocol.Recv()
if err != nil {
c.sendErr(xerrors.Errorf("read: %w", err))
return
}
err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
c.sendErr(xerrors.Errorf("update peers: %w", err))
return
}
}
}
// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
// Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
// a client---agents should NOT set this!).
func NewRemoteCoordination(logger slog.Logger,
protocol proto.DRPCTailnet_CoordinateClient, coordinatee Coordinatee,
tunnelTarget uuid.UUID,
) Coordination {
c := &remoteCoordination{
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
protocol: protocol,
}
if tunnelTarget != uuid.Nil {
c.Lock()
err := c.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tunnelTarget[:]}})
c.Unlock()
if err != nil {
c.sendErr(err)
}
}
coordinatee.SetNodeCallback(func(node *Node) {
pn, err := NodeToProto(node)
if err != nil {
c.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
c.sendErr(err)
return
}
c.Lock()
defer c.Unlock()
if c.closed {
c.logger.Debug(context.Background(), "ignored node update because coordination is closed")
return
}
err = c.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}})
if err != nil {
c.sendErr(xerrors.Errorf("write: %w", err))
}
})
go c.respLoop()
return c
}
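A sketch of client-side use once a Coordinate stream is open (assumes conn is a *tailnet.Conn dialed elsewhere; agentID is the tunnel target, and an agent would pass uuid.Nil instead):

func runRemoteCoordination(ctx context.Context, logger slog.Logger, protocol proto.DRPCTailnet_CoordinateClient, conn *tailnet.Conn, agentID uuid.UUID) error {
	coordination := tailnet.NewRemoteCoordination(logger, protocol, conn, agentID)
	defer coordination.Close()
	select {
	case err := <-coordination.Error():
		// Fatal read/write errors on the stream surface here.
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}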
type inMemoryCoordination struct {
sync.Mutex
ctx context.Context
errChan chan error
closed bool
closedCh chan struct{}
coordinatee Coordinatee
logger slog.Logger
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
}
func (c *inMemoryCoordination) sendErr(err error) {
select {
case c.errChan <- err:
default:
}
}
func (c *inMemoryCoordination) Error() <-chan error {
return c.errChan
}
// NewInMemoryCoordination connects a Coordinatee (usually a Conn) to an in-memory Coordinator, for testing
// or local clients. Set ClientID to uuid.Nil for an agent.
func NewInMemoryCoordination(
ctx context.Context, logger slog.Logger,
clientID, agentID uuid.UUID,
coordinator Coordinator, coordinatee Coordinatee,
) Coordination {
thisID := agentID
logger = logger.With(slog.F("agent_id", agentID))
var auth TunnelAuth = AgentTunnelAuth{}
if clientID != uuid.Nil {
// this is a client connection
auth = ClientTunnelAuth{AgentID: agentID}
logger = logger.With(slog.F("client_id", clientID))
thisID = clientID
}
c := &inMemoryCoordination{
ctx: ctx,
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
closedCh: make(chan struct{}),
}
// use the background context since we will depend exclusively on closing the req channel to
// tell the coordinator we are done.
c.reqs, c.resps = coordinator.Coordinate(context.Background(),
thisID, fmt.Sprintf("inmemory%s", thisID),
auth,
)
go c.respLoop()
if agentID != uuid.Nil {
select {
case <-ctx.Done():
c.logger.Warn(ctx, "context expired before we could add tunnel", slog.Error(ctx.Err()))
return c
case c.reqs <- &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}}:
// OK!
}
}
coordinatee.SetNodeCallback(func(n *Node) {
pn, err := NodeToProto(n)
if err != nil {
c.logger.Critical(ctx, "failed to convert node", slog.Error(err))
c.sendErr(err)
return
}
c.Lock()
defer c.Unlock()
if c.closed {
return
}
select {
case <-ctx.Done():
c.logger.Info(ctx, "context expired before sending node update")
return
case c.reqs <- &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}:
c.logger.Debug(ctx, "sent node in-memory to coordinator")
}
})
return c
}
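And the in-memory counterpart, roughly as a test might use it. It assumes an in-memory Coordinator such as the one returned by NewCoordinator; the client side passes its own ID, while the agent side passes uuid.Nil, which selects AgentTunnelAuth:

func coordinatePairInMemory(ctx context.Context, logger slog.Logger, coord Coordinator, client, agent Coordinatee, clientID, agentID uuid.UUID) (clientCoord, agentCoord Coordination) {
	// Client side: tunnels to agentID.
	clientCoord = NewInMemoryCoordination(ctx, logger, clientID, agentID, coord, client)
	// Agent side: clientID is uuid.Nil.
	agentCoord = NewInMemoryCoordination(ctx, logger, uuid.Nil, agentID, coord, agent)
	return clientCoord, agentCoord
}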
func (c *inMemoryCoordination) respLoop() {
for {
select {
case <-c.closedCh:
c.logger.Debug(context.Background(), "in-memory coordination closed")
return
case resp, ok := <-c.resps:
if !ok {
c.logger.Debug(context.Background(), "in-memory response channel closed")
return
}
c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
return
}
}
}
}
func (c *inMemoryCoordination) Close() error {
c.Lock()
defer c.Unlock()
c.logger.Debug(context.Background(), "closing in-memory coordination")
if c.closed {
return nil
}
defer close(c.reqs)
c.closed = true
close(c.closedCh)
select {
case <-c.ctx.Done():
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}:
c.logger.Debug(context.Background(), "sent graceful disconnect in-memory")
return nil
}
}
// ServeCoordinator matches the RW structure of a coordinator to exchange node messages.
func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func(node *Node), <-chan error) {
errChan := make(chan error, 1)
@ -237,21 +469,17 @@ func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAge
}
return false
},
-OnSubscribe: func(enq Queue, agent uuid.UUID) (*Node, error) {
+OnSubscribe: func(enq Queue, agent uuid.UUID) error {
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}})
-return c.Node(agent), err
+return err
},
OnUnsubscribe: func(enq Queue, agent uuid.UUID) error {
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}})
return err
},
-OnNodeUpdate: func(id uuid.UUID, node *Node) error {
-pn, err := NodeToProto(node)
-if err != nil {
-return err
-}
+OnNodeUpdate: func(id uuid.UUID, node *proto.Node) error {
return SendCtx(ctx, reqs, &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
-Node: pn,
+Node: node,
}})
},
OnRemove: func(_ Queue) {
@ -285,7 +513,7 @@ const (
type Queue interface {
UniqueID() uuid.UUID
Kind() QueueKind
-Enqueue(n []*Node) error
+Enqueue(resp *proto.CoordinateResponse) error
Name() string
Stats() (start, lastWrite int64)
Overwrites() int64
@ -793,18 +1021,7 @@ func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logg
return
}
logger.Debug(ctx, "v1RespLoop got response", slog.F("resp", resp))
-nodes, err := OnlyNodeUpdates(resp)
-if err != nil {
-logger.Critical(ctx, "v1RespLoop failed to decode resp", slog.F("resp", resp), slog.Error(err))
-_ = q.CoordinatorClose()
-return
-}
-// don't send empty updates
-if len(nodes) == 0 {
-logger.Debug(ctx, "v1RespLoop skipping enqueueing 0-length v1 update")
-continue
-}
-err = q.Enqueue(nodes)
+err = q.Enqueue(resp)
if err != nil && !xerrors.Is(err, context.Canceled) {
logger.Error(ctx, "v1RespLoop failed to enqueue v1 update", slog.Error(err))
}

View File

@ -8,13 +8,15 @@ import (
"github.com/google/uuid"
"golang.org/x/xerrors"
+"github.com/coder/coder/v2/tailnet/proto"
)
type MultiAgentConn interface {
-UpdateSelf(node *Node) error
+UpdateSelf(node *proto.Node) error
SubscribeAgent(agentID uuid.UUID) error
UnsubscribeAgent(agentID uuid.UUID) error
-NextUpdate(ctx context.Context) ([]*Node, bool)
+NextUpdate(ctx context.Context) (*proto.CoordinateResponse, bool)
AgentIsLegacy(agentID uuid.UUID) bool
Close() error
IsClosed() bool
@ -26,16 +28,16 @@ type MultiAgent struct {
ID uuid.UUID
AgentIsLegacyFunc func(agentID uuid.UUID) bool
-OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error)
+OnSubscribe func(enq Queue, agent uuid.UUID) error
OnUnsubscribe func(enq Queue, agent uuid.UUID) error
-OnNodeUpdate func(id uuid.UUID, node *Node) error
+OnNodeUpdate func(id uuid.UUID, node *proto.Node) error
OnRemove func(enq Queue)
ctx context.Context
ctxCancel func()
closed bool
-updates chan []*Node
+updates chan *proto.CoordinateResponse
closeOnce sync.Once
start int64
lastWrite int64
@ -45,7 +47,7 @@ type MultiAgent struct {
}
func (m *MultiAgent) Init() *MultiAgent {
-m.updates = make(chan []*Node, 128)
+m.updates = make(chan *proto.CoordinateResponse, 128)
m.start = time.Now().Unix()
m.ctx, m.ctxCancel = context.WithCancel(context.Background())
return m
@ -65,7 +67,7 @@ func (m *MultiAgent) AgentIsLegacy(agentID uuid.UUID) bool {
var ErrMultiAgentClosed = xerrors.New("multiagent is closed")
-func (m *MultiAgent) UpdateSelf(node *Node) error {
+func (m *MultiAgent) UpdateSelf(node *proto.Node) error {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
@ -82,15 +84,11 @@ func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) error {
return ErrMultiAgentClosed
}
-node, err := m.OnSubscribe(m, agentID)
+err := m.OnSubscribe(m, agentID)
if err != nil {
return err
}
-if node != nil {
-return m.enqueueLocked([]*Node{node})
-}
return nil
}
@ -104,17 +102,17 @@ func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) error {
return m.OnUnsubscribe(m, agentID)
}
-func (m *MultiAgent) NextUpdate(ctx context.Context) ([]*Node, bool) {
+func (m *MultiAgent) NextUpdate(ctx context.Context) (*proto.CoordinateResponse, bool) {
select {
case <-ctx.Done():
return nil, false
-case nodes, ok := <-m.updates:
-return nodes, ok
+case resp, ok := <-m.updates:
+return resp, ok
}
}
-func (m *MultiAgent) Enqueue(nodes []*Node) error {
+func (m *MultiAgent) Enqueue(resp *proto.CoordinateResponse) error {
m.mu.RLock()
defer m.mu.RUnlock()
@ -122,14 +120,14 @@ func (m *MultiAgent) Enqueue(nodes []*Node) error {
return nil
}
-return m.enqueueLocked(nodes)
+return m.enqueueLocked(resp)
}
-func (m *MultiAgent) enqueueLocked(nodes []*Node) error {
+func (m *MultiAgent) enqueueLocked(resp *proto.CoordinateResponse) error {
atomic.StoreInt64(&m.lastWrite, time.Now().Unix())
select {
-case m.updates <- nodes:
+case m.updates <- resp:
return nil
default:
return ErrWouldBlock
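Because MultiAgentConn now speaks protocol types end to end, a caller that still holds legacy *Node values converts once at the boundary with NodeToProto. The following is a hedged usage sketch, again written as if inside package tailnet; the function name multiAgentSketch is illustrative, not part of this change.

// multiAgentSketch shows the updated surface: UpdateSelf takes *proto.Node and
// NextUpdate yields whole CoordinateResponses instead of []*Node.
func multiAgentSketch(ctx context.Context, c Coordinator, legacy *Node, agent uuid.UUID) error {
	ma := c.ServeMultiAgent(uuid.New())
	defer func() { _ = ma.Close() }()

	pn, err := NodeToProto(legacy)
	if err != nil {
		return err
	}
	if err := ma.UpdateSelf(pn); err != nil {
		return err
	}
	if err := ma.SubscribeAgent(agent); err != nil {
		return err
	}
	for {
		resp, ok := ma.NextUpdate(ctx)
		if !ok {
			return nil // connection closed or context done
		}
		_ = resp.GetPeerUpdates() // apply peer updates here
	}
}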

View File

@ -75,7 +75,9 @@ func NewClientService(
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
-if xerrors.Is(err, io.EOF) {
+if xerrors.Is(err, io.EOF) ||
+xerrors.Is(err, context.Canceled) ||
+xerrors.Is(err, context.DeadlineExceeded) {
return
}
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
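The Log filter above now swallows three benign teardown errors. If that list keeps growing, one option (a refactoring sketch, not something this commit does) is to pull the check into a small predicate:

// isBenignDRPCErr reports whether a dRPC server error is expected during
// normal connection shutdown and can be dropped from the logs.
func isBenignDRPCErr(err error) bool {
	return xerrors.Is(err, io.EOF) ||
		xerrors.Is(err, context.Canceled) ||
		xerrors.Is(err, context.DeadlineExceeded)
}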

View File

@ -0,0 +1,142 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/v2/tailnet (interfaces: Coordinator)
//
// Generated by this command:
//
// mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator
//
// Package tailnettest is a generated GoMock package.
package tailnettest
import (
context "context"
net "net"
http "net/http"
reflect "reflect"
tailnet "github.com/coder/coder/v2/tailnet"
proto "github.com/coder/coder/v2/tailnet/proto"
uuid "github.com/google/uuid"
gomock "go.uber.org/mock/gomock"
)
// MockCoordinator is a mock of Coordinator interface.
type MockCoordinator struct {
ctrl *gomock.Controller
recorder *MockCoordinatorMockRecorder
}
// MockCoordinatorMockRecorder is the mock recorder for MockCoordinator.
type MockCoordinatorMockRecorder struct {
mock *MockCoordinator
}
// NewMockCoordinator creates a new mock instance.
func NewMockCoordinator(ctrl *gomock.Controller) *MockCoordinator {
mock := &MockCoordinator{ctrl: ctrl}
mock.recorder = &MockCoordinatorMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockCoordinator) EXPECT() *MockCoordinatorMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockCoordinator) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockCoordinatorMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCoordinator)(nil).Close))
}
// Coordinate mocks base method.
func (m *MockCoordinator) Coordinate(arg0 context.Context, arg1 uuid.UUID, arg2 string, arg3 tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Coordinate", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(chan<- *proto.CoordinateRequest)
ret1, _ := ret[1].(<-chan *proto.CoordinateResponse)
return ret0, ret1
}
// Coordinate indicates an expected call of Coordinate.
func (mr *MockCoordinatorMockRecorder) Coordinate(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Coordinate", reflect.TypeOf((*MockCoordinator)(nil).Coordinate), arg0, arg1, arg2, arg3)
}
// Node mocks base method.
func (m *MockCoordinator) Node(arg0 uuid.UUID) *tailnet.Node {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Node", arg0)
ret0, _ := ret[0].(*tailnet.Node)
return ret0
}
// Node indicates an expected call of Node.
func (mr *MockCoordinatorMockRecorder) Node(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Node", reflect.TypeOf((*MockCoordinator)(nil).Node), arg0)
}
// ServeAgent mocks base method.
func (m *MockCoordinator) ServeAgent(arg0 net.Conn, arg1 uuid.UUID, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ServeAgent", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// ServeAgent indicates an expected call of ServeAgent.
func (mr *MockCoordinatorMockRecorder) ServeAgent(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeAgent), arg0, arg1, arg2)
}
// ServeClient mocks base method.
func (m *MockCoordinator) ServeClient(arg0 net.Conn, arg1, arg2 uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ServeClient", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// ServeClient indicates an expected call of ServeClient.
func (mr *MockCoordinatorMockRecorder) ServeClient(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeClient", reflect.TypeOf((*MockCoordinator)(nil).ServeClient), arg0, arg1, arg2)
}
// ServeHTTPDebug mocks base method.
func (m *MockCoordinator) ServeHTTPDebug(arg0 http.ResponseWriter, arg1 *http.Request) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ServeHTTPDebug", arg0, arg1)
}
// ServeHTTPDebug indicates an expected call of ServeHTTPDebug.
func (mr *MockCoordinatorMockRecorder) ServeHTTPDebug(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeHTTPDebug", reflect.TypeOf((*MockCoordinator)(nil).ServeHTTPDebug), arg0, arg1)
}
// ServeMultiAgent mocks base method.
func (m *MockCoordinator) ServeMultiAgent(arg0 uuid.UUID) tailnet.MultiAgentConn {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ServeMultiAgent", arg0)
ret0, _ := ret[0].(tailnet.MultiAgentConn)
return ret0
}
// ServeMultiAgent indicates an expected call of ServeMultiAgent.
func (mr *MockCoordinatorMockRecorder) ServeMultiAgent(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeMultiAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeMultiAgent), arg0)
}
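A quick sketch of wiring the generated mock into a test with go.uber.org/mock; the test name and expectation are illustrative, and it assumes an external test package that imports tailnettest.

func TestCoordinatorMockSketch(t *testing.T) {
	ctrl := gomock.NewController(t)
	coord := tailnettest.NewMockCoordinator(ctrl)

	// Expect exactly one Close call and have it succeed.
	coord.EXPECT().Close().Return(nil).Times(1)

	if err := coord.Close(); err != nil {
		t.Fatalf("unexpected error closing mock coordinator: %v", err)
	}
}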

View File

@ -1,141 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/v2/tailnet (interfaces: MultiAgentConn)
//
// Generated by this command:
//
// mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn
//
// Package tailnettest is a generated GoMock package.
package tailnettest
import (
context "context"
reflect "reflect"
tailnet "github.com/coder/coder/v2/tailnet"
uuid "github.com/google/uuid"
gomock "go.uber.org/mock/gomock"
)
// MockMultiAgentConn is a mock of MultiAgentConn interface.
type MockMultiAgentConn struct {
ctrl *gomock.Controller
recorder *MockMultiAgentConnMockRecorder
}
// MockMultiAgentConnMockRecorder is the mock recorder for MockMultiAgentConn.
type MockMultiAgentConnMockRecorder struct {
mock *MockMultiAgentConn
}
// NewMockMultiAgentConn creates a new mock instance.
func NewMockMultiAgentConn(ctrl *gomock.Controller) *MockMultiAgentConn {
mock := &MockMultiAgentConn{ctrl: ctrl}
mock.recorder = &MockMultiAgentConnMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockMultiAgentConn) EXPECT() *MockMultiAgentConnMockRecorder {
return m.recorder
}
// AgentIsLegacy mocks base method.
func (m *MockMultiAgentConn) AgentIsLegacy(arg0 uuid.UUID) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AgentIsLegacy", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// AgentIsLegacy indicates an expected call of AgentIsLegacy.
func (mr *MockMultiAgentConnMockRecorder) AgentIsLegacy(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AgentIsLegacy", reflect.TypeOf((*MockMultiAgentConn)(nil).AgentIsLegacy), arg0)
}
// Close mocks base method.
func (m *MockMultiAgentConn) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockMultiAgentConnMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMultiAgentConn)(nil).Close))
}
// IsClosed mocks base method.
func (m *MockMultiAgentConn) IsClosed() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClosed")
ret0, _ := ret[0].(bool)
return ret0
}
// IsClosed indicates an expected call of IsClosed.
func (mr *MockMultiAgentConnMockRecorder) IsClosed() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockMultiAgentConn)(nil).IsClosed))
}
// NextUpdate mocks base method.
func (m *MockMultiAgentConn) NextUpdate(arg0 context.Context) ([]*tailnet.Node, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NextUpdate", arg0)
ret0, _ := ret[0].([]*tailnet.Node)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// NextUpdate indicates an expected call of NextUpdate.
func (mr *MockMultiAgentConnMockRecorder) NextUpdate(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextUpdate", reflect.TypeOf((*MockMultiAgentConn)(nil).NextUpdate), arg0)
}
// SubscribeAgent mocks base method.
func (m *MockMultiAgentConn) SubscribeAgent(arg0 uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SubscribeAgent", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SubscribeAgent indicates an expected call of SubscribeAgent.
func (mr *MockMultiAgentConnMockRecorder) SubscribeAgent(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).SubscribeAgent), arg0)
}
// UnsubscribeAgent mocks base method.
func (m *MockMultiAgentConn) UnsubscribeAgent(arg0 uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnsubscribeAgent", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// UnsubscribeAgent indicates an expected call of UnsubscribeAgent.
func (mr *MockMultiAgentConnMockRecorder) UnsubscribeAgent(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).UnsubscribeAgent), arg0)
}
// UpdateSelf mocks base method.
func (m *MockMultiAgentConn) UpdateSelf(arg0 *tailnet.Node) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateSelf", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateSelf indicates an expected call of UpdateSelf.
func (mr *MockMultiAgentConnMockRecorder) UpdateSelf(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSelf", reflect.TypeOf((*MockMultiAgentConn)(nil).UpdateSelf), arg0)
}

View File

@ -21,7 +21,7 @@ import (
"github.com/coder/coder/v2/tailnet"
)
-//go:generate mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn
+//go:generate mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator
// RunDERPAndSTUN creates a DERP mapping for tests.
func RunDERPAndSTUN(t *testing.T) (*tailcfg.DERPMap, *derp.Server) {

View File

@ -11,6 +11,7 @@ import (
"github.com/google/uuid"
"cdr.dev/slog"
+"github.com/coder/coder/v2/tailnet/proto"
)
const (
@ -29,7 +30,7 @@ type TrackedConn struct {
cancel func()
kind QueueKind
conn net.Conn
-updates chan []*Node
+updates chan *proto.CoordinateResponse
logger slog.Logger
lastData []byte
@ -55,7 +56,7 @@ func NewTrackedConn(ctx context.Context, cancel func(),
// coordinator mutex while queuing. Node updates don't
// come quickly, so 512 should be plenty for all but
// the most pathological cases.
-updates := make(chan []*Node, ResponseBufferSize)
+updates := make(chan *proto.CoordinateResponse, ResponseBufferSize)
now := time.Now().Unix()
return &TrackedConn{
ctx: ctx,
@ -72,10 +73,10 @@ func NewTrackedConn(ctx context.Context, cancel func(),
}
}
-func (t *TrackedConn) Enqueue(n []*Node) (err error) {
+func (t *TrackedConn) Enqueue(resp *proto.CoordinateResponse) (err error) {
atomic.StoreInt64(&t.lastWrite, time.Now().Unix())
select {
-case t.updates <- n:
+case t.updates <- resp:
return nil
default:
return ErrWouldBlock
@ -124,7 +125,16 @@ func (t *TrackedConn) SendUpdates() {
case <-t.ctx.Done():
t.logger.Debug(t.ctx, "done sending updates")
return
-case nodes := <-t.updates:
+case resp := <-t.updates:
+nodes, err := OnlyNodeUpdates(resp)
+if err != nil {
+t.logger.Critical(t.ctx, "unable to parse response", slog.Error(err))
+return
+}
+if len(nodes) == 0 {
+t.logger.Debug(t.ctx, "skipping response with no nodes")
+continue
+}
data, err := json.Marshal(nodes)
if err != nil {
t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes))