mirror of https://github.com/coder/coder.git
feat: use tailnet v2 API for coordination (#11638)
This one is huge, and I'm sorry. The problem is that once I change `tailnet.Conn` to start doing v2 behavior, I kind of have to change it everywhere, including in CoderSDK (CLI), the agent, wsproxy, and ServerTailnet. There is still a bit more cleanup to do, and I need to add code so that when we lose connection to the Coordinator, we mark all peers as LOST, but that will be in a separate PR since this is big enough!
This commit is contained in:
parent
5a2cf7cd14
commit
f01cab9894
7
Makefile
7
Makefile
|
@ -475,7 +475,8 @@ gen: \
|
|||
site/.eslintignore \
|
||||
site/e2e/provisionerGenerated.ts \
|
||||
site/src/theme/icons.json \
|
||||
examples/examples.gen.json
|
||||
examples/examples.gen.json \
|
||||
tailnet/tailnettest/coordinatormock.go
|
||||
.PHONY: gen
|
||||
|
||||
# Mark all generated files as fresh so make thinks they're up-to-date. This is
|
||||
|
@ -502,6 +503,7 @@ gen/mark-fresh:
|
|||
site/e2e/provisionerGenerated.ts \
|
||||
site/src/theme/icons.json \
|
||||
examples/examples.gen.json \
|
||||
tailnet/tailnettest/coordinatormock.go \
|
||||
"
|
||||
for file in $$files; do
|
||||
echo "$$file"
|
||||
|
@ -529,6 +531,9 @@ coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $
|
|||
coderd/database/dbmock/dbmock.go: coderd/database/db.go coderd/database/querier.go
|
||||
go generate ./coderd/database/dbmock/
|
||||
|
||||
tailnet/tailnettest/coordinatormock.go: tailnet/coordinator.go
|
||||
go generate ./tailnet/tailnettest/
|
||||
|
||||
tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto
|
||||
protoc \
|
||||
--go_out=. \
|
||||
|
|
|
@ -30,6 +30,7 @@ import (
|
|||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/xerrors"
|
||||
"storj.io/drpc"
|
||||
"tailscale.com/net/speedtest"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/netlogtype"
|
||||
|
@ -47,6 +48,7 @@ import (
|
|||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -86,7 +88,7 @@ type Options struct {
|
|||
|
||||
type Client interface {
|
||||
Manifest(ctx context.Context) (agentsdk.Manifest, error)
|
||||
Listen(ctx context.Context) (net.Conn, error)
|
||||
Listen(ctx context.Context) (drpc.Conn, error)
|
||||
DERPMapUpdates(ctx context.Context) (<-chan agentsdk.DERPMapUpdate, io.Closer, error)
|
||||
ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
|
||||
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
|
||||
|
@ -1058,20 +1060,34 @@ func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error
|
|||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
coordinator, err := a.client.Listen(ctx)
|
||||
conn, err := a.client.Listen(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer coordinator.Close()
|
||||
// Close the dRPC connection when runCoordinator returns. The close error is
// kept in its own variable and that variable is what gets logged: the outer
// `err` is reassigned by the Coordinate call below, so logging it here would
// report nil or an unrelated error.
defer func() {
	if cErr := conn.Close(); cErr != nil {
		a.logger.Debug(ctx, "error closing drpc connection", slog.Error(cErr))
	}
}()

// Open the Coordinate stream over the dRPC connection.
tClient := tailnetproto.NewDRPCTailnetClient(conn)
coordinate, err := tClient.Coordinate(ctx)
if err != nil {
	return xerrors.Errorf("failed to connect to the coordinate endpoint: %w", err)
}
// Close the Coordinate stream on return; deferred calls run last-in,
// first-out, so this runs before the connection close registered above.
defer func() {
	if cErr := coordinate.Close(); cErr != nil {
		a.logger.Debug(ctx, "error closing Coordinate client", slog.Error(cErr))
	}
}()
|
||||
a.logger.Info(ctx, "connected to coordination endpoint")
|
||||
sendNodes, errChan := tailnet.ServeCoordinator(coordinator, func(nodes []*tailnet.Node) error {
|
||||
return network.UpdateNodes(nodes, false)
|
||||
})
|
||||
network.SetNodeCallback(sendNodes)
|
||||
coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-errChan:
|
||||
case err := <-coordination.Error():
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1664,9 +1664,11 @@ func TestAgent_UpdatedDERP(t *testing.T) {
|
|||
require.NotNil(t, originalDerpMap)
|
||||
|
||||
coordinator := tailnet.NewCoordinator(logger)
|
||||
defer func() {
|
||||
// use t.Cleanup so the coordinator closing doesn't deadlock with in-memory
|
||||
// coordination
|
||||
t.Cleanup(func() {
|
||||
_ = coordinator.Close()
|
||||
}()
|
||||
})
|
||||
agentID := uuid.New()
|
||||
statsCh := make(chan *agentsdk.Stats, 50)
|
||||
fs := afero.NewMemMapFs()
|
||||
|
@ -1681,41 +1683,42 @@ func TestAgent_UpdatedDERP(t *testing.T) {
|
|||
statsCh,
|
||||
coordinator,
|
||||
)
|
||||
closer := agent.New(agent.Options{
|
||||
uut := agent.New(agent.Options{
|
||||
Client: client,
|
||||
Filesystem: fs,
|
||||
Logger: logger.Named("agent"),
|
||||
ReconnectingPTYTimeout: time.Minute,
|
||||
})
|
||||
defer func() {
|
||||
_ = closer.Close()
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
_ = uut.Close()
|
||||
})
|
||||
|
||||
// Setup a client connection.
|
||||
newClientConn := func(derpMap *tailcfg.DERPMap) *codersdk.WorkspaceAgentConn {
|
||||
newClientConn := func(derpMap *tailcfg.DERPMap, name string) *codersdk.WorkspaceAgentConn {
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
DERPMap: derpMap,
|
||||
Logger: logger.Named("client"),
|
||||
Logger: logger.Named(name),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
clientConn, serverConn := net.Pipe()
|
||||
serveClientDone := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
t.Logf("closing conn %s", name)
|
||||
_ = conn.Close()
|
||||
<-serveClientDone
|
||||
})
|
||||
go func() {
|
||||
defer close(serveClientDone)
|
||||
err := coordinator.ServeClient(serverConn, uuid.New(), agentID)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(nodes, false)
|
||||
testCtx, testCtxCancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(testCtxCancel)
|
||||
clientID := uuid.New()
|
||||
coordination := tailnet.NewInMemoryCoordination(
|
||||
testCtx, logger,
|
||||
clientID, agentID,
|
||||
coordinator, conn)
|
||||
t.Cleanup(func() {
|
||||
t.Logf("closing coordination %s", name)
|
||||
err := coordination.Close()
|
||||
if err != nil {
|
||||
t.Logf("error closing in-memory coordination: %s", err.Error())
|
||||
}
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
// Force DERP.
|
||||
conn.SetBlockEndpoints(true)
|
||||
|
||||
|
@ -1724,6 +1727,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
|
|||
CloseFunc: func() error { return codersdk.ErrSkipClose },
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
t.Logf("closing sdkConn %s", name)
|
||||
_ = sdkConn.Close()
|
||||
})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
|
@ -1734,7 +1738,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
|
|||
|
||||
return sdkConn
|
||||
}
|
||||
conn1 := newClientConn(originalDerpMap)
|
||||
conn1 := newClientConn(originalDerpMap, "client1")
|
||||
|
||||
// Change the DERP map.
|
||||
newDerpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||||
|
@ -1753,27 +1757,34 @@ func TestAgent_UpdatedDERP(t *testing.T) {
|
|||
DERPMap: newDerpMap,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Logf("client Pushed DERPMap update")
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
conn := closer.TailnetConn()
|
||||
conn := uut.TailnetConn()
|
||||
if conn == nil {
|
||||
return false
|
||||
}
|
||||
regionIDs := conn.DERPMap().RegionIDs()
|
||||
return len(regionIDs) == 1 && regionIDs[0] == 2 && conn.Node().PreferredDERP == 2
|
||||
preferredDERP := conn.Node().PreferredDERP
|
||||
t.Logf("agent Conn DERPMap with regionIDs %v, PreferredDERP %d", regionIDs, preferredDERP)
|
||||
return len(regionIDs) == 1 && regionIDs[0] == 2 && preferredDERP == 2
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
t.Logf("agent got the new DERPMap")
|
||||
|
||||
// Connect from a second client and make sure it uses the new DERP map.
|
||||
conn2 := newClientConn(newDerpMap)
|
||||
conn2 := newClientConn(newDerpMap, "client2")
|
||||
require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs())
|
||||
t.Log("conn2 got the new DERPMap")
|
||||
|
||||
// If the first client gets a DERP map update, it should be able to
|
||||
// reconnect just fine.
|
||||
conn1.SetDERPMap(newDerpMap)
|
||||
require.Equal(t, []int{2}, conn1.DERPMap().RegionIDs())
|
||||
t.Log("set the new DERPMap on conn1")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
require.True(t, conn1.AwaitReachable(ctx))
|
||||
t.Log("conn1 reached agent with new DERP")
|
||||
}
|
||||
|
||||
func TestAgent_Speedtest(t *testing.T) {
|
||||
|
@ -2050,22 +2061,22 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
|
|||
Logger: logger.Named("client"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
clientConn, serverConn := net.Pipe()
|
||||
serveClientDone := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
_ = conn.Close()
|
||||
<-serveClientDone
|
||||
})
|
||||
go func() {
|
||||
defer close(serveClientDone)
|
||||
coordinator.ServeClient(serverConn, uuid.New(), metadata.AgentID)
|
||||
}()
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(nodes, false)
|
||||
testCtx, testCtxCancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(testCtxCancel)
|
||||
clientID := uuid.New()
|
||||
coordination := tailnet.NewInMemoryCoordination(
|
||||
testCtx, logger,
|
||||
clientID, metadata.AgentID,
|
||||
coordinator, conn)
|
||||
t.Cleanup(func() {
|
||||
err := coordination.Close()
|
||||
if err != nil {
|
||||
t.Logf("error closing in-mem coordination: %s", err.Error())
|
||||
}
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
|
||||
AgentID: metadata.AgentID,
|
||||
})
|
||||
|
|
|
@ -3,19 +3,26 @@ package agenttest
|
|||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/xerrors"
|
||||
"storj.io/drpc"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
|
@ -24,11 +31,31 @@ func NewClient(t testing.TB,
|
|||
agentID uuid.UUID,
|
||||
manifest agentsdk.Manifest,
|
||||
statsChan chan *agentsdk.Stats,
|
||||
coordinator tailnet.CoordinatorV1,
|
||||
coordinator tailnet.Coordinator,
|
||||
) *Client {
|
||||
if manifest.AgentID == uuid.Nil {
|
||||
manifest.AgentID = agentID
|
||||
}
|
||||
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordPtr.Store(&coordinator)
|
||||
mux := drpcmux.New()
|
||||
drpcService := &tailnet.DRPCService{
|
||||
CoordPtr: &coordPtr,
|
||||
Logger: logger,
|
||||
// TODO: handle DERPMap too!
|
||||
DerpMapUpdateFrequency: time.Hour,
|
||||
DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
|
||||
}
|
||||
err := proto.DRPCRegisterTailnet(mux, drpcService)
|
||||
require.NoError(t, err)
|
||||
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
|
||||
Log: func(err error) {
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
|
||||
},
|
||||
})
|
||||
return &Client{
|
||||
t: t,
|
||||
logger: logger.Named("client"),
|
||||
|
@ -36,6 +63,7 @@ func NewClient(t testing.TB,
|
|||
manifest: manifest,
|
||||
statsChan: statsChan,
|
||||
coordinator: coordinator,
|
||||
server: server,
|
||||
derpMapUpdates: make(chan agentsdk.DERPMapUpdate),
|
||||
}
|
||||
}
|
||||
|
@ -47,7 +75,8 @@ type Client struct {
|
|||
manifest agentsdk.Manifest
|
||||
metadata map[string]agentsdk.Metadata
|
||||
statsChan chan *agentsdk.Stats
|
||||
coordinator tailnet.CoordinatorV1
|
||||
coordinator tailnet.Coordinator
|
||||
server *drpcserver.Server
|
||||
LastWorkspaceAgent func()
|
||||
PatchWorkspaceLogs func() error
|
||||
GetServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
|
||||
|
@ -63,20 +92,29 @@ func (c *Client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
|
|||
return c.manifest, nil
|
||||
}
|
||||
|
||||
func (c *Client) Listen(_ context.Context) (net.Conn, error) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
func (c *Client) Listen(_ context.Context) (drpc.Conn, error) {
|
||||
conn, lis := drpcsdk.MemTransportPipe()
|
||||
closed := make(chan struct{})
|
||||
c.LastWorkspaceAgent = func() {
|
||||
_ = serverConn.Close()
|
||||
_ = clientConn.Close()
|
||||
_ = conn.Close()
|
||||
_ = lis.Close()
|
||||
<-closed
|
||||
}
|
||||
c.t.Cleanup(c.LastWorkspaceAgent)
|
||||
serveCtx, cancel := context.WithCancel(context.Background())
|
||||
c.t.Cleanup(cancel)
|
||||
auth := tailnet.AgentTunnelAuth{}
|
||||
streamID := tailnet.StreamID{
|
||||
Name: "agenttest",
|
||||
ID: c.agentID,
|
||||
Auth: auth,
|
||||
}
|
||||
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
|
||||
go func() {
|
||||
_ = c.coordinator.ServeAgent(serverConn, c.agentID, "")
|
||||
_ = c.server.Serve(serveCtx, lis)
|
||||
close(closed)
|
||||
}()
|
||||
return clientConn, nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) {
|
||||
|
|
|
@ -33,6 +33,7 @@ import (
|
|||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/provisioner/echo"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
|
@ -98,14 +99,32 @@ func TestDERP(t *testing.T) {
|
|||
|
||||
w2Ready := make(chan struct{})
|
||||
w2ReadyOnce := sync.Once{}
|
||||
w1ID := uuid.New()
|
||||
w1.SetNodeCallback(func(node *tailnet.Node) {
|
||||
w2.UpdateNodes([]*tailnet.Node{node}, false)
|
||||
pn, err := tailnet.NodeToProto(node)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
w2.UpdatePeers([]*tailnetproto.CoordinateResponse_PeerUpdate{{
|
||||
Id: w1ID[:],
|
||||
Node: pn,
|
||||
Kind: tailnetproto.CoordinateResponse_PeerUpdate_NODE,
|
||||
}})
|
||||
w2ReadyOnce.Do(func() {
|
||||
close(w2Ready)
|
||||
})
|
||||
})
|
||||
w2ID := uuid.New()
|
||||
w2.SetNodeCallback(func(node *tailnet.Node) {
|
||||
w1.UpdateNodes([]*tailnet.Node{node}, false)
|
||||
pn, err := tailnet.NodeToProto(node)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
w1.UpdatePeers([]*tailnetproto.CoordinateResponse_PeerUpdate{{
|
||||
Id: w2ID[:],
|
||||
Node: pn,
|
||||
Kind: tailnetproto.CoordinateResponse_PeerUpdate_NODE,
|
||||
}})
|
||||
})
|
||||
|
||||
conn := make(chan struct{})
|
||||
|
@ -199,7 +218,11 @@ func TestDERPForceWebSockets(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
|
||||
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID,
|
||||
&codersdk.DialWorkspaceAgentOptions{
|
||||
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug).Named("client"),
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
|
|
|
@ -121,12 +121,23 @@ func NewServerTailnet(
|
|||
}
|
||||
tn.agentConn.Store(&agentConn)
|
||||
|
||||
err = tn.getAgentConn().UpdateSelf(conn.Node())
|
||||
pn, err := tailnet.NodeToProto(conn.Node())
|
||||
if err != nil {
|
||||
tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err))
|
||||
tn.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
|
||||
} else {
|
||||
err = tn.getAgentConn().UpdateSelf(pn)
|
||||
if err != nil {
|
||||
tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
conn.SetNodeCallback(func(node *tailnet.Node) {
|
||||
err := tn.getAgentConn().UpdateSelf(node)
|
||||
pn, err := tailnet.NodeToProto(node)
|
||||
if err != nil {
|
||||
tn.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
|
||||
return
|
||||
}
|
||||
err = tn.getAgentConn().UpdateSelf(pn)
|
||||
if err != nil {
|
||||
tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err))
|
||||
}
|
||||
|
@ -191,21 +202,9 @@ func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
|
|||
// If no one has connected since the cutoff and there are no active
|
||||
// connections, remove the agent.
|
||||
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
|
||||
deleted, err := s.conn.RemovePeer(tailnet.PeerSelector{
|
||||
ID: tailnet.NodeID(agentID),
|
||||
IP: netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128),
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to remove peer from server tailnet", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
if !deleted {
|
||||
s.logger.Warn(ctx, "peer didn't exist in tailnet", slog.Error(err))
|
||||
}
|
||||
|
||||
deletedCount++
|
||||
delete(s.agentConnectionTimes, agentID)
|
||||
err = agentConn.UnsubscribeAgent(agentID)
|
||||
err := agentConn.UnsubscribeAgent(agentID)
|
||||
if err != nil {
|
||||
s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
}
|
||||
|
@ -221,7 +220,7 @@ func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
|
|||
func (s *ServerTailnet) watchAgentUpdates() {
|
||||
for {
|
||||
conn := s.getAgentConn()
|
||||
nodes, ok := conn.NextUpdate(s.ctx)
|
||||
resp, ok := conn.NextUpdate(s.ctx)
|
||||
if !ok {
|
||||
if conn.IsClosed() && s.ctx.Err() == nil {
|
||||
s.logger.Warn(s.ctx, "multiagent closed, reinitializing")
|
||||
|
@ -231,7 +230,7 @@ func (s *ServerTailnet) watchAgentUpdates() {
|
|||
return
|
||||
}
|
||||
|
||||
err := s.conn.UpdateNodes(nodes, false)
|
||||
err := s.conn.UpdatePeers(resp.GetPeerUpdates())
|
||||
if err != nil {
|
||||
if xerrors.Is(err, tailnet.ErrConnClosed) {
|
||||
s.logger.Warn(context.Background(), "tailnet conn closed, exiting watchAgentUpdates", slog.Error(err))
|
||||
|
|
|
@ -3,7 +3,6 @@ package coderd_test
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
|
@ -204,22 +203,20 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
|
|||
Logger: logger.Named("client"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
clientConn, serverConn := net.Pipe()
|
||||
serveClientDone := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
_ = conn.Close()
|
||||
<-serveClientDone
|
||||
})
|
||||
go func() {
|
||||
defer close(serveClientDone)
|
||||
coord.ServeClient(serverConn, uuid.New(), manifest.AgentID)
|
||||
}()
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(node, false)
|
||||
clientID := uuid.New()
|
||||
testCtx, testCtxCancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(testCtxCancel)
|
||||
coordination := tailnet.NewInMemoryCoordination(
|
||||
testCtx, logger,
|
||||
clientID, manifest.AgentID,
|
||||
coord, conn,
|
||||
)
|
||||
t.Cleanup(func() {
|
||||
_ = coordination.Close()
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
|
||||
AgentID: manifest.AgentID,
|
||||
AgentIP: codersdk.WorkspaceAgentIP,
|
||||
|
|
|
@ -30,6 +30,10 @@ func (v *APIVersion) WithBackwardCompat(majs ...int) *APIVersion {
|
|||
return v
|
||||
}
|
||||
|
||||
func (v *APIVersion) String() string {
|
||||
return fmt.Sprintf("%d.%d", v.supportedMajor, v.supportedMinor)
|
||||
}
|
||||
|
||||
// Validate validates the given version against the given constraints:
|
||||
// A given major.minor version is valid iff:
|
||||
// 1. The requested major version is contained within v.supportedMajors
|
||||
|
@ -42,10 +46,6 @@ func (v *APIVersion) WithBackwardCompat(majs ...int) *APIVersion {
|
|||
// - 1.x is supported,
|
||||
// - 2.0, 2.1, and 2.2 are supported,
|
||||
// - 2.3+ is not supported.
|
||||
func (v *APIVersion) String() string {
|
||||
return fmt.Sprintf("%d.%d", v.supportedMajor, v.supportedMinor)
|
||||
}
|
||||
|
||||
func (v *APIVersion) Validate(version string) error {
|
||||
major, minor, err := Parse(version)
|
||||
if err != nil {
|
||||
|
|
|
@ -857,8 +857,6 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
|
|||
// Deprecated: use api.tailnet.AgentConn instead.
|
||||
// See: https://github.com/coder/coder/issues/8218
|
||||
func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
|
||||
derpMap := api.DERPMap()
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
|
@ -868,8 +866,6 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa
|
|||
BlockEndpoints: api.DeploymentValues.DERP.Config.BlockDirect.Value(),
|
||||
})
|
||||
if err != nil {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
return nil, xerrors.Errorf("create tailnet conn: %w", err)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(api.ctx)
|
||||
|
@ -887,10 +883,10 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa
|
|||
return left
|
||||
})
|
||||
|
||||
sendNodes, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(nodes, true)
|
||||
})
|
||||
conn.SetNodeCallback(sendNodes)
|
||||
clientID := uuid.New()
|
||||
coordination := tailnet.NewInMemoryCoordination(ctx, api.Logger,
|
||||
clientID, agentID,
|
||||
*(api.TailnetCoordinator.Load()), conn)
|
||||
|
||||
// Check for updated DERP map every 5 seconds.
|
||||
go func() {
|
||||
|
@ -920,27 +916,13 @@ func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspa
|
|||
AgentID: agentID,
|
||||
AgentIP: codersdk.WorkspaceAgentIP,
|
||||
CloseFunc: func() error {
|
||||
_ = coordination.Close()
|
||||
cancel()
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
return nil
|
||||
},
|
||||
})
|
||||
go func() {
|
||||
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
|
||||
if err != nil {
|
||||
// Sometimes, we get benign closed pipe errors when the server is
|
||||
// shutting down.
|
||||
if api.ctx.Err() == nil {
|
||||
api.Logger.Warn(ctx, "tailnet coordinator client error", slog.Error(err))
|
||||
}
|
||||
_ = agentConn.Close()
|
||||
}
|
||||
}()
|
||||
if !agentConn.AwaitReachable(ctx) {
|
||||
_ = agentConn.Close()
|
||||
_ = serverConn.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
return nil, xerrors.Errorf("agent not reachable")
|
||||
}
|
||||
|
|
|
@ -535,7 +535,6 @@ func TestWorkspaceAgentTailnetDirectDisabled(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
require.True(t, conn.BlockEndpoints())
|
||||
|
||||
require.True(t, conn.AwaitReachable(ctx))
|
||||
_, p2p, _, err := conn.Ping(ctx)
|
||||
|
|
|
@ -12,14 +12,19 @@ import (
|
|||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/goleak"
|
||||
"golang.org/x/xerrors"
|
||||
"storj.io/drpc"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
@ -27,7 +32,9 @@ import (
|
|||
"github.com/coder/coder/v2/coderd/wsconncache"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
@ -41,7 +48,7 @@ func TestCache(t *testing.T) {
|
|||
t.Run("Same", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||
return setupAgent(t, agentsdk.Manifest{}, 0), nil
|
||||
return setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
}, 0)
|
||||
defer func() {
|
||||
_ = cache.Close()
|
||||
|
@ -54,10 +61,10 @@ func TestCache(t *testing.T) {
|
|||
})
|
||||
t.Run("Expire", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
called := atomic.NewInt32(0)
|
||||
called := int32(0)
|
||||
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||
called.Add(1)
|
||||
return setupAgent(t, agentsdk.Manifest{}, 0), nil
|
||||
atomic.AddInt32(&called, 1)
|
||||
return setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
}, time.Microsecond)
|
||||
defer func() {
|
||||
_ = cache.Close()
|
||||
|
@ -70,12 +77,12 @@ func TestCache(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
release()
|
||||
<-conn.Closed()
|
||||
require.Equal(t, int32(2), called.Load())
|
||||
require.Equal(t, int32(2), called)
|
||||
})
|
||||
t.Run("NoExpireWhenLocked", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||
return setupAgent(t, agentsdk.Manifest{}, 0), nil
|
||||
return setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
}, time.Microsecond)
|
||||
defer func() {
|
||||
_ = cache.Close()
|
||||
|
@ -108,7 +115,7 @@ func TestCache(t *testing.T) {
|
|||
go server.Serve(random)
|
||||
|
||||
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||
return setupAgent(t, agentsdk.Manifest{}, 0), nil
|
||||
return setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
}, time.Microsecond)
|
||||
defer func() {
|
||||
_ = cache.Close()
|
||||
|
@ -154,7 +161,7 @@ func TestCache(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn {
|
||||
func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) (*codersdk.WorkspaceAgentConn, error) {
|
||||
t.Helper()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t)
|
||||
|
@ -184,18 +191,25 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
|
|||
DERPForceWebSockets: manifest.DERPForceWebSockets,
|
||||
Logger: slogtest.Make(t, nil).Named("tailnet").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
clientConn, serverConn := net.Pipe()
|
||||
// setupAgent is called by wsconncache Dialer, so we can't use require here as it will end the
|
||||
// test, which in turn closes the wsconncache, which in turn waits for the Dialer and deadlocks.
|
||||
if !assert.NoError(t, err) {
|
||||
return nil, err
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
_ = conn.Close()
|
||||
})
|
||||
go coordinator.ServeClient(serverConn, uuid.New(), manifest.AgentID)
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(nodes []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(nodes, false)
|
||||
clientID := uuid.New()
|
||||
testCtx, testCtxCancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(testCtxCancel)
|
||||
coordination := tailnet.NewInMemoryCoordination(
|
||||
testCtx, logger,
|
||||
clientID, manifest.AgentID,
|
||||
coordinator, conn,
|
||||
)
|
||||
t.Cleanup(func() {
|
||||
_ = coordination.Close()
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
|
||||
AgentID: manifest.AgentID,
|
||||
AgentIP: codersdk.WorkspaceAgentIP,
|
||||
|
@ -206,16 +220,20 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
if !agentConn.AwaitReachable(ctx) {
|
||||
t.Fatal("agent not reachable")
|
||||
// setupAgent is called by wsconncache Dialer, so we can't use t.Fatal here as it will end
|
||||
// the test, which in turn closes the wsconncache, which in turn waits for the Dialer and
|
||||
// deadlocks.
|
||||
t.Error("agent not reachable")
|
||||
return nil, xerrors.New("agent not reachable")
|
||||
}
|
||||
return agentConn
|
||||
return agentConn, nil
|
||||
}
|
||||
|
||||
type client struct {
|
||||
t *testing.T
|
||||
agentID uuid.UUID
|
||||
manifest agentsdk.Manifest
|
||||
coordinator tailnet.CoordinatorV1
|
||||
coordinator tailnet.Coordinator
|
||||
}
|
||||
|
||||
func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
|
||||
|
@ -240,19 +258,53 @@ func (*client) DERPMapUpdates(_ context.Context) (<-chan agentsdk.DERPMapUpdate,
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (c *client) Listen(_ context.Context) (net.Conn, error) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
func (c *client) Listen(_ context.Context) (drpc.Conn, error) {
|
||||
logger := slogtest.Make(c.t, nil).Leveled(slog.LevelDebug).Named("drpc")
|
||||
conn, lis := drpcsdk.MemTransportPipe()
|
||||
closed := make(chan struct{})
|
||||
c.t.Cleanup(func() {
|
||||
_ = serverConn.Close()
|
||||
_ = clientConn.Close()
|
||||
_ = conn.Close()
|
||||
_ = lis.Close()
|
||||
<-closed
|
||||
})
|
||||
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordPtr.Store(&c.coordinator)
|
||||
mux := drpcmux.New()
|
||||
drpcService := &tailnet.DRPCService{
|
||||
CoordPtr: &coordPtr,
|
||||
Logger: logger,
|
||||
// TODO: handle DERPMap too!
|
||||
DerpMapUpdateFrequency: time.Hour,
|
||||
DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
|
||||
}
|
||||
err := proto.DRPCRegisterTailnet(mux, drpcService)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("register DRPC service: %w", err)
|
||||
}
|
||||
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
|
||||
Log: func(err error) {
|
||||
if xerrors.Is(err, io.EOF) ||
|
||||
xerrors.Is(err, context.Canceled) ||
|
||||
xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return
|
||||
}
|
||||
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
|
||||
},
|
||||
})
|
||||
serveCtx, cancel := context.WithCancel(context.Background())
|
||||
c.t.Cleanup(cancel)
|
||||
auth := tailnet.AgentTunnelAuth{}
|
||||
streamID := tailnet.StreamID{
|
||||
Name: "wsconncache_test-agent",
|
||||
ID: c.agentID,
|
||||
Auth: auth,
|
||||
}
|
||||
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
|
||||
go func() {
|
||||
_ = c.coordinator.ServeAgent(serverConn, c.agentID, "")
|
||||
server.Serve(serveCtx, lis)
|
||||
close(closed)
|
||||
}()
|
||||
return clientConn, nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (*client) ReportStats(_ context.Context, _ slog.Logger, _ <-chan *agentsdk.Stats, _ func(time.Duration)) (io.Closer, error) {
|
||||
|
|
|
@ -14,12 +14,15 @@ import (
|
|||
|
||||
"cloud.google.com/go/compute/metadata"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/yamux"
|
||||
"golang.org/x/xerrors"
|
||||
"nhooyr.io/websocket"
|
||||
"storj.io/drpc"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
|
||||
"github.com/coder/retry"
|
||||
)
|
||||
|
||||
|
@ -280,8 +283,8 @@ func (c *Client) DERPMapUpdates(ctx context.Context) (<-chan DERPMapUpdate, io.C
|
|||
|
||||
// Listen connects to the workspace agent coordinate WebSocket
|
||||
// that handles connection negotiation.
|
||||
func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
|
||||
coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/coordinate")
|
||||
func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) {
|
||||
coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc")
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
|
@ -312,14 +315,21 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
|
|||
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
pingClosed := pingWebSocket(ctx, c.SDK.Logger(), conn, "coordinate")
|
||||
|
||||
return &closeNetConn{
|
||||
netConn := &closeNetConn{
|
||||
Conn: wsNetConn,
|
||||
closeFunc: func() {
|
||||
cancelFunc()
|
||||
_ = conn.Close(websocket.StatusGoingAway, "Listen closed")
|
||||
<-pingClosed
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
session, err := yamux.Client(netConn, config)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("multiplex client: %w", err)
|
||||
}
|
||||
return drpcsdk.MultiplexedConn(session), nil
|
||||
}
|
||||
|
||||
type PostAppHealthsRequest struct {
|
||||
|
|
|
@ -313,6 +313,9 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
q := coordinateURL.Query()
|
||||
q.Add("version", tailnet.CurrentVersion.String())
|
||||
coordinateURL.RawQuery = q.Encode()
|
||||
closedCoordinator := make(chan struct{})
|
||||
// Must only ever be used once, send error OR close to avoid
|
||||
// reassignment race. Buffered so we don't hang in goroutine.
|
||||
|
@ -344,12 +347,22 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
|
|||
options.Logger.Debug(ctx, "failed to dial", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(nodes []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(nodes, false)
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
client, err := tailnet.NewDRPCClient(websocket.NetConn(ctx, ws, websocket.MessageBinary))
|
||||
if err != nil {
|
||||
options.Logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
|
||||
_ = ws.Close(websocket.StatusInternalError, "")
|
||||
continue
|
||||
}
|
||||
coordinate, err := client.Coordinate(ctx)
|
||||
if err != nil {
|
||||
options.Logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err))
|
||||
_ = ws.Close(websocket.StatusInternalError, "")
|
||||
continue
|
||||
}
|
||||
|
||||
coordination := tailnet.NewRemoteCoordination(options.Logger, coordinate, conn, agentID)
|
||||
options.Logger.Debug(ctx, "serving coordinator")
|
||||
err = <-errChan
|
||||
err = <-coordination.Error()
|
||||
if errors.Is(err, context.Canceled) {
|
||||
_ = ws.Close(websocket.StatusGoingAway, "")
|
||||
return
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
|
@ -53,6 +54,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
|
|||
ctx := r.Context()
|
||||
|
||||
version := "1.0"
|
||||
msgType := websocket.MessageText
|
||||
qv := r.URL.Query().Get("version")
|
||||
if qv != "" {
|
||||
version = qv
|
||||
|
@ -66,6 +68,11 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
|
|||
})
|
||||
return
|
||||
}
|
||||
maj, _, _ := apiversion.Parse(version)
|
||||
if maj >= 2 {
|
||||
// Versions 2+ use dRPC over a binary connection
|
||||
msgType = websocket.MessageBinary
|
||||
}
|
||||
|
||||
api.AGPL.WebsocketWaitMutex.Lock()
|
||||
api.AGPL.WebsocketWaitGroup.Add(1)
|
||||
|
@ -81,7 +88,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
|
|||
return
|
||||
}
|
||||
|
||||
ctx, nc := websocketNetConn(ctx, conn, websocket.MessageText)
|
||||
ctx, nc := websocketNetConn(ctx, conn, msgType)
|
||||
defer nc.Close()
|
||||
|
||||
id := uuid.New()
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
@ -20,6 +21,7 @@ import (
|
|||
"github.com/coder/coder/v2/enterprise/coderd/license"
|
||||
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
|
@ -27,6 +29,12 @@ import (
|
|||
|
||||
func Test_agentIsLegacy(t *testing.T) {
|
||||
t.Parallel()
|
||||
nodeKey := key.NewNode().Public()
|
||||
discoKey := key.NewDisco().Public()
|
||||
nkBin, err := nodeKey.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
dkBin, err := discoKey.MarshalText()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Legacy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
@ -54,18 +62,18 @@ func Test_agentIsLegacy(t *testing.T) {
|
|||
nodeID := uuid.New()
|
||||
ma := coordinator.ServeMultiAgent(nodeID)
|
||||
defer ma.Close()
|
||||
require.NoError(t, ma.UpdateSelf(&agpl.Node{
|
||||
ID: 55,
|
||||
AsOf: time.Unix(1689653252, 0),
|
||||
Key: key.NewNode().Public(),
|
||||
DiscoKey: key.NewDisco().Public(),
|
||||
PreferredDERP: 0,
|
||||
DERPLatency: map[string]float64{
|
||||
require.NoError(t, ma.UpdateSelf(&proto.Node{
|
||||
Id: 55,
|
||||
AsOf: timestamppb.New(time.Unix(1689653252, 0)),
|
||||
Key: nkBin,
|
||||
Disco: string(dkBin),
|
||||
PreferredDerp: 0,
|
||||
DerpLatency: map[string]float64{
|
||||
"0": 1.0,
|
||||
},
|
||||
DERPForcedWebsocket: map[int]string{},
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
|
||||
AllowedIPs: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
|
||||
DerpForcedWebsocket: map[int32]string{},
|
||||
Addresses: []string{codersdk.WorkspaceAgentIP.String() + "/128"},
|
||||
AllowedIps: []string{codersdk.WorkspaceAgentIP.String() + "/128"},
|
||||
Endpoints: []string{"192.168.1.1:18842"},
|
||||
}))
|
||||
require.Eventually(t, func() bool {
|
||||
|
@ -114,18 +122,18 @@ func Test_agentIsLegacy(t *testing.T) {
|
|||
nodeID := uuid.New()
|
||||
ma := coordinator.ServeMultiAgent(nodeID)
|
||||
defer ma.Close()
|
||||
require.NoError(t, ma.UpdateSelf(&agpl.Node{
|
||||
ID: 55,
|
||||
AsOf: time.Unix(1689653252, 0),
|
||||
Key: key.NewNode().Public(),
|
||||
DiscoKey: key.NewDisco().Public(),
|
||||
PreferredDERP: 0,
|
||||
DERPLatency: map[string]float64{
|
||||
require.NoError(t, ma.UpdateSelf(&proto.Node{
|
||||
Id: 55,
|
||||
AsOf: timestamppb.New(time.Unix(1689653252, 0)),
|
||||
Key: nkBin,
|
||||
Disco: string(dkBin),
|
||||
PreferredDerp: 0,
|
||||
DerpLatency: map[string]float64{
|
||||
"0": 1.0,
|
||||
},
|
||||
DERPForcedWebsocket: map[int]string{},
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128)},
|
||||
AllowedIPs: []netip.Prefix{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128)},
|
||||
DerpForcedWebsocket: map[int32]string{},
|
||||
Addresses: []string{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128).String()},
|
||||
AllowedIps: []string{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128).String()},
|
||||
Endpoints: []string{"192.168.1.1:18842"},
|
||||
}))
|
||||
require.Eventually(t, func() bool {
|
||||
|
|
|
@ -6,12 +6,15 @@ import (
|
|||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/slices"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/enterprise/tailnet"
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
|
@ -39,22 +42,19 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
|
|||
defer agent1.close()
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
|
||||
|
||||
id := uuid.New()
|
||||
ma1 := coord1.ServeMultiAgent(id)
|
||||
defer ma1.Close()
|
||||
ma1 := newTestMultiAgent(t, coord1)
|
||||
defer ma1.close()
|
||||
|
||||
err = ma1.SubscribeAgent(agent1.id)
|
||||
require.NoError(t, err)
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
|
||||
ma1.subscribeAgent(agent1.id)
|
||||
ma1.assertEventuallyHasDERPs(ctx, 5)
|
||||
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
|
||||
ma1.assertEventuallyHasDERPs(ctx, 1)
|
||||
|
||||
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
|
||||
require.NoError(t, err)
|
||||
ma1.sendNodeWithDERP(3)
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 3)
|
||||
|
||||
require.NoError(t, ma1.Close())
|
||||
ma1.close()
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
|
@ -86,23 +86,20 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
|
|||
defer agent1.close()
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
|
||||
|
||||
id := uuid.New()
|
||||
ma1 := coord1.ServeMultiAgent(id)
|
||||
defer ma1.Close()
|
||||
ma1 := newTestMultiAgent(t, coord1)
|
||||
defer ma1.close()
|
||||
|
||||
err = ma1.SubscribeAgent(agent1.id)
|
||||
require.NoError(t, err)
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
|
||||
ma1.subscribeAgent(agent1.id)
|
||||
ma1.assertEventuallyHasDERPs(ctx, 5)
|
||||
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
|
||||
ma1.assertEventuallyHasDERPs(ctx, 1)
|
||||
|
||||
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
|
||||
require.NoError(t, err)
|
||||
ma1.sendNodeWithDERP(3)
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 3)
|
||||
|
||||
require.NoError(t, ma1.UnsubscribeAgent(agent1.id))
|
||||
require.NoError(t, ma1.Close())
|
||||
ma1.unsubscribeAgent(agent1.id)
|
||||
ma1.close()
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
|
@ -134,37 +131,35 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
|
|||
defer agent1.close()
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
|
||||
|
||||
id := uuid.New()
|
||||
ma1 := coord1.ServeMultiAgent(id)
|
||||
defer ma1.Close()
|
||||
ma1 := newTestMultiAgent(t, coord1)
|
||||
defer ma1.close()
|
||||
|
||||
err = ma1.SubscribeAgent(agent1.id)
|
||||
require.NoError(t, err)
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
|
||||
ma1.subscribeAgent(agent1.id)
|
||||
ma1.assertEventuallyHasDERPs(ctx, 5)
|
||||
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
|
||||
ma1.assertEventuallyHasDERPs(ctx, 1)
|
||||
|
||||
require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}))
|
||||
ma1.sendNodeWithDERP(3)
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 3)
|
||||
|
||||
require.NoError(t, ma1.UnsubscribeAgent(agent1.id))
|
||||
ma1.unsubscribeAgent(agent1.id)
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
|
||||
defer cancel()
|
||||
require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 9}))
|
||||
ma1.sendNodeWithDERP(9)
|
||||
assertNeverHasDERPs(ctx, t, agent1, 9)
|
||||
}()
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
|
||||
defer cancel()
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 8})
|
||||
assertMultiAgentNeverHasDERPs(ctx, t, ma1, 8)
|
||||
ma1.assertNeverHasDERPs(ctx, 8)
|
||||
}()
|
||||
|
||||
require.NoError(t, ma1.Close())
|
||||
ma1.close()
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
|
@ -201,22 +196,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
|
|||
defer agent1.close()
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
|
||||
|
||||
id := uuid.New()
|
||||
ma1 := coord2.ServeMultiAgent(id)
|
||||
defer ma1.Close()
|
||||
ma1 := newTestMultiAgent(t, coord2)
|
||||
defer ma1.close()
|
||||
|
||||
err = ma1.SubscribeAgent(agent1.id)
|
||||
require.NoError(t, err)
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
|
||||
ma1.subscribeAgent(agent1.id)
|
||||
ma1.assertEventuallyHasDERPs(ctx, 5)
|
||||
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
|
||||
ma1.assertEventuallyHasDERPs(ctx, 1)
|
||||
|
||||
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
|
||||
require.NoError(t, err)
|
||||
ma1.sendNodeWithDERP(3)
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 3)
|
||||
|
||||
require.NoError(t, ma1.Close())
|
||||
ma1.close()
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
|
@ -254,22 +246,19 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
|
|||
defer agent1.close()
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
|
||||
|
||||
id := uuid.New()
|
||||
ma1 := coord2.ServeMultiAgent(id)
|
||||
defer ma1.Close()
|
||||
ma1 := newTestMultiAgent(t, coord2)
|
||||
defer ma1.close()
|
||||
|
||||
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
|
||||
require.NoError(t, err)
|
||||
ma1.sendNodeWithDERP(3)
|
||||
|
||||
err = ma1.SubscribeAgent(agent1.id)
|
||||
require.NoError(t, err)
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
|
||||
ma1.subscribeAgent(agent1.id)
|
||||
ma1.assertEventuallyHasDERPs(ctx, 5)
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 3)
|
||||
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
|
||||
ma1.assertEventuallyHasDERPs(ctx, 1)
|
||||
|
||||
require.NoError(t, ma1.Close())
|
||||
ma1.close()
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
|
@ -316,33 +305,129 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
|
|||
defer agent1.close()
|
||||
agent2.sendNode(&agpl.Node{PreferredDERP: 6})
|
||||
|
||||
id := uuid.New()
|
||||
ma1 := coord3.ServeMultiAgent(id)
|
||||
defer ma1.Close()
|
||||
ma1 := newTestMultiAgent(t, coord3)
|
||||
defer ma1.close()
|
||||
|
||||
err = ma1.SubscribeAgent(agent1.id)
|
||||
require.NoError(t, err)
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
|
||||
ma1.subscribeAgent(agent1.id)
|
||||
ma1.assertEventuallyHasDERPs(ctx, 5)
|
||||
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
|
||||
ma1.assertEventuallyHasDERPs(ctx, 1)
|
||||
|
||||
err = ma1.SubscribeAgent(agent2.id)
|
||||
require.NoError(t, err)
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 6)
|
||||
ma1.subscribeAgent(agent2.id)
|
||||
ma1.assertEventuallyHasDERPs(ctx, 6)
|
||||
|
||||
agent2.sendNode(&agpl.Node{PreferredDERP: 2})
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 2)
|
||||
ma1.assertEventuallyHasDERPs(ctx, 2)
|
||||
|
||||
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
|
||||
require.NoError(t, err)
|
||||
ma1.sendNodeWithDERP(3)
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 3)
|
||||
assertEventuallyHasDERPs(ctx, t, agent2, 3)
|
||||
|
||||
require.NoError(t, ma1.Close())
|
||||
ma1.close()
|
||||
require.NoError(t, agent1.close())
|
||||
require.NoError(t, agent2.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
assertEventuallyLost(ctx, t, store, agent1.id)
|
||||
}
|
||||
|
||||
type testMultiAgent struct {
|
||||
t testing.TB
|
||||
id uuid.UUID
|
||||
a agpl.MultiAgentConn
|
||||
nodeKey []byte
|
||||
discoKey string
|
||||
}
|
||||
|
||||
func newTestMultiAgent(t testing.TB, coord agpl.Coordinator) *testMultiAgent {
|
||||
nk, err := key.NewNode().Public().MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
dk, err := key.NewDisco().Public().MarshalText()
|
||||
require.NoError(t, err)
|
||||
m := &testMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)}
|
||||
m.a = coord.ServeMultiAgent(m.id)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *testMultiAgent) sendNodeWithDERP(derp int32) {
|
||||
m.t.Helper()
|
||||
err := m.a.UpdateSelf(&proto.Node{
|
||||
Key: m.nodeKey,
|
||||
Disco: m.discoKey,
|
||||
PreferredDerp: derp,
|
||||
})
|
||||
require.NoError(m.t, err)
|
||||
}
|
||||
|
||||
func (m *testMultiAgent) close() {
|
||||
m.t.Helper()
|
||||
err := m.a.Close()
|
||||
require.NoError(m.t, err)
|
||||
}
|
||||
|
||||
func (m *testMultiAgent) subscribeAgent(id uuid.UUID) {
|
||||
m.t.Helper()
|
||||
err := m.a.SubscribeAgent(id)
|
||||
require.NoError(m.t, err)
|
||||
}
|
||||
|
||||
func (m *testMultiAgent) unsubscribeAgent(id uuid.UUID) {
|
||||
m.t.Helper()
|
||||
err := m.a.UnsubscribeAgent(id)
|
||||
require.NoError(m.t, err)
|
||||
}
|
||||
|
||||
func (m *testMultiAgent) assertEventuallyHasDERPs(ctx context.Context, expected ...int) {
|
||||
m.t.Helper()
|
||||
for {
|
||||
resp, ok := m.a.NextUpdate(ctx)
|
||||
require.True(m.t, ok)
|
||||
nodes, err := agpl.OnlyNodeUpdates(resp)
|
||||
require.NoError(m.t, err)
|
||||
if len(nodes) != len(expected) {
|
||||
m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
|
||||
continue
|
||||
}
|
||||
|
||||
derps := make([]int, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
derps = append(derps, n.PreferredDERP)
|
||||
}
|
||||
for _, e := range expected {
|
||||
if !slices.Contains(derps, e) {
|
||||
m.t.Logf("expected DERP %d to be in %v", e, derps)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *testMultiAgent) assertNeverHasDERPs(ctx context.Context, expected ...int) {
|
||||
m.t.Helper()
|
||||
for {
|
||||
resp, ok := m.a.NextUpdate(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
nodes, err := agpl.OnlyNodeUpdates(resp)
|
||||
require.NoError(m.t, err)
|
||||
if len(nodes) != len(expected) {
|
||||
m.t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
|
||||
continue
|
||||
}
|
||||
|
||||
derps := make([]int, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
derps = append(derps, n.PreferredDERP)
|
||||
}
|
||||
for _, e := range expected {
|
||||
if !slices.Contains(derps, e) {
|
||||
m.t.Logf("expected DERP %d to be in %v", e, derps)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -818,56 +818,6 @@ func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expecte
|
|||
}
|
||||
}
|
||||
|
||||
func assertMultiAgentEventuallyHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) {
|
||||
t.Helper()
|
||||
for {
|
||||
nodes, ok := ma.NextUpdate(ctx)
|
||||
require.True(t, ok)
|
||||
if len(nodes) != len(expected) {
|
||||
t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
|
||||
continue
|
||||
}
|
||||
|
||||
derps := make([]int, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
derps = append(derps, n.PreferredDERP)
|
||||
}
|
||||
for _, e := range expected {
|
||||
if !slices.Contains(derps, e) {
|
||||
t.Logf("expected DERP %d to be in %v", e, derps)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) {
|
||||
t.Helper()
|
||||
for {
|
||||
nodes, ok := ma.NextUpdate(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if len(nodes) != len(expected) {
|
||||
t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
|
||||
continue
|
||||
}
|
||||
|
||||
derps := make([]int, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
derps = append(derps, n.PreferredDERP)
|
||||
}
|
||||
for _, e := range expected {
|
||||
if !slices.Contains(derps, e) {
|
||||
t.Logf("expected DERP %d to be in %v", e, derps)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
|
||||
t.Helper()
|
||||
assert.Eventually(t, func() bool {
|
||||
|
|
|
@ -96,7 +96,11 @@ func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentC
|
|||
return xerrors.Errorf("unsubscribe agent: %w", err)
|
||||
}
|
||||
case wsproxysdk.CoordinateMessageTypeNodeUpdate:
|
||||
err := ma.UpdateSelf(msg.Node)
|
||||
pn, err := agpl.NodeToProto(msg.Node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = ma.UpdateSelf(pn)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update self: %w", err)
|
||||
}
|
||||
|
@ -110,11 +114,14 @@ func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentC
|
|||
func forwardNodesToWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
|
||||
var lastData []byte
|
||||
for {
|
||||
nodes, ok := ma.NextUpdate(ctx)
|
||||
resp, ok := ma.NextUpdate(ctx)
|
||||
if !ok {
|
||||
return xerrors.New("multiagent is closed")
|
||||
}
|
||||
|
||||
nodes, err := agpl.OnlyNodeUpdates(resp)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to convert response: %w", err)
|
||||
}
|
||||
data, err := json.Marshal(wsproxysdk.CoordinateNodes{Nodes: nodes})
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -158,7 +158,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
|
|||
// TODO: Probably do some version checking here
|
||||
info, err := client.SDKClient.BuildInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("buildinfo: %w", errors.Join(
|
||||
return nil, xerrors.Errorf("buildinfo: %w", errors.Join(
|
||||
xerrors.Errorf("unable to fetch build info from primary coderd. Are you sure %q is a coderd instance?", opts.DashboardURL),
|
||||
err,
|
||||
))
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
@ -23,6 +22,7 @@ import (
|
|||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
// Client is a HTTP client for a subset of Coder API routes that external
|
||||
|
@ -438,6 +438,9 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
|
|||
cancel()
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
q := coordinateURL.Query()
|
||||
q.Add("version", agpl.CurrentVersion.String())
|
||||
coordinateURL.RawQuery = q.Encode()
|
||||
coordinateHeaders := make(http.Header)
|
||||
tokenHeader := codersdk.SessionTokenHeader
|
||||
if c.SDKClient.SessionTokenHeader != "" {
|
||||
|
@ -457,10 +460,24 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
|
|||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
nc := websocket.NetConn(ctx, conn, websocket.MessageText)
|
||||
nc := websocket.NetConn(ctx, conn, websocket.MessageBinary)
|
||||
client, err := agpl.NewDRPCClient(nc)
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, "")
|
||||
return nil, xerrors.Errorf("failed to create DRPCClient: %w", err)
|
||||
}
|
||||
protocol, err := client.Coordinate(ctx)
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "failed to reach the Coordinate endpoint", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, "")
|
||||
return nil, xerrors.Errorf("failed to reach the Coordinate endpoint: %w", err)
|
||||
}
|
||||
|
||||
rma := remoteMultiAgentHandler{
|
||||
sdk: c,
|
||||
nc: nc,
|
||||
logger: logger,
|
||||
protocol: protocol,
|
||||
cancel: cancel,
|
||||
legacyAgentCache: map[uuid.UUID]bool{},
|
||||
}
|
||||
|
@ -471,103 +488,75 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
|
|||
OnSubscribe: rma.OnSubscribe,
|
||||
OnUnsubscribe: rma.OnUnsubscribe,
|
||||
OnNodeUpdate: rma.OnNodeUpdate,
|
||||
OnRemove: func(agpl.Queue) { conn.Close(websocket.StatusGoingAway, "closed") },
|
||||
OnRemove: rma.OnRemove,
|
||||
}).Init()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
ma.Close()
|
||||
_ = conn.Close(websocket.StatusGoingAway, "closed")
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer cancel()
|
||||
dec := json.NewDecoder(nc)
|
||||
for {
|
||||
var msg CoordinateNodes
|
||||
err := dec.Decode(&msg)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
logger.Info(ctx, "websocket connection severed", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error(ctx, "decode coordinator nodes", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
err = ma.Enqueue(msg.Nodes)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "enqueue nodes from coordinator", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
rma.ma = ma
|
||||
go rma.respLoop()
|
||||
|
||||
return ma, nil
|
||||
}
|
||||
|
||||
type remoteMultiAgentHandler struct {
|
||||
sdk *Client
|
||||
nc net.Conn
|
||||
cancel func()
|
||||
sdk *Client
|
||||
logger slog.Logger
|
||||
protocol proto.DRPCTailnet_CoordinateClient
|
||||
ma *agpl.MultiAgent
|
||||
cancel func()
|
||||
|
||||
legacyMu sync.RWMutex
|
||||
legacyAgentCache map[uuid.UUID]bool
|
||||
legacySingleflight singleflight.Group[uuid.UUID, AgentIsLegacyResponse]
|
||||
}
|
||||
|
||||
func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("json marshal message: %w", err)
|
||||
}
|
||||
func (a *remoteMultiAgentHandler) respLoop() {
|
||||
{
|
||||
defer a.cancel()
|
||||
for {
|
||||
resp, err := a.protocol.Recv()
|
||||
if err != nil {
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
a.logger.Info(context.Background(), "remote multiagent connection severed", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Set a deadline so that hung connections don't put back pressure on the system.
|
||||
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
|
||||
err = a.nc.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout))
|
||||
if err != nil {
|
||||
a.cancel()
|
||||
return xerrors.Errorf("set write deadline: %w", err)
|
||||
}
|
||||
_, err = a.nc.Write(data)
|
||||
if err != nil {
|
||||
a.cancel()
|
||||
return xerrors.Errorf("write message: %w", err)
|
||||
}
|
||||
a.logger.Error(context.Background(), "error receiving multiagent responses", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are
|
||||
// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write()
|
||||
// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context.
|
||||
// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after
|
||||
// our successful write, it is important that we reset the deadline before it fires.
|
||||
err = a.nc.SetWriteDeadline(time.Time{})
|
||||
if err != nil {
|
||||
a.cancel()
|
||||
return xerrors.Errorf("clear write deadline: %w", err)
|
||||
err = a.ma.Enqueue(resp)
|
||||
if err != nil {
|
||||
a.logger.Error(context.Background(), "enqueue response from coordinator", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *agpl.Node) error {
|
||||
return a.writeJSON(CoordinateMessage{
|
||||
Type: CoordinateMessageTypeNodeUpdate,
|
||||
Node: node,
|
||||
})
|
||||
func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *proto.Node) error {
|
||||
return a.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: node}})
|
||||
}
|
||||
|
||||
func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) {
|
||||
return nil, a.writeJSON(CoordinateMessage{
|
||||
Type: CoordinateMessageTypeSubscribe,
|
||||
AgentID: agentID,
|
||||
})
|
||||
func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) error {
|
||||
return a.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}})
|
||||
}
|
||||
|
||||
func (a *remoteMultiAgentHandler) OnUnsubscribe(_ agpl.Queue, agentID uuid.UUID) error {
|
||||
return a.writeJSON(CoordinateMessage{
|
||||
Type: CoordinateMessageTypeUnsubscribe,
|
||||
AgentID: agentID,
|
||||
})
|
||||
return a.protocol.Send(&proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}})
|
||||
}
|
||||
|
||||
func (a *remoteMultiAgentHandler) OnRemove(_ agpl.Queue) {
|
||||
err := a.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
|
||||
if err != nil {
|
||||
a.logger.Warn(context.Background(), "failed to gracefully disconnect", slog.Error(err))
|
||||
}
|
||||
_ = a.protocol.CloseSend()
|
||||
}
|
||||
|
||||
func (a *remoteMultiAgentHandler) AgentIsLegacy(agentID uuid.UUID) bool {
|
||||
|
|
|
@ -18,8 +18,9 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"nhooyr.io/websocket"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
@ -30,6 +31,7 @@ import (
|
|||
"github.com/coder/coder/v2/enterprise/tailnet"
|
||||
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
@ -156,25 +158,48 @@ func TestDialCoordinator(t *testing.T) {
|
|||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
agentID = uuid.New()
|
||||
serverMultiAgent = tailnettest.NewMockMultiAgentConn(gomock.NewController(t))
|
||||
r = chi.NewRouter()
|
||||
srv = httptest.NewServer(r)
|
||||
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
agentID = uuid.UUID{33}
|
||||
proxyID = uuid.UUID{44}
|
||||
mCoord = tailnettest.NewMockCoordinator(gomock.NewController(t))
|
||||
coord agpl.Coordinator = mCoord
|
||||
r = chi.NewRouter()
|
||||
srv = httptest.NewServer(r)
|
||||
)
|
||||
defer cancel()
|
||||
defer srv.Close()
|
||||
|
||||
coordPtr := atomic.Pointer[agpl.Coordinator]{}
|
||||
coordPtr.Store(&coord)
|
||||
cSrv, err := tailnet.NewClientService(
|
||||
logger, &coordPtr,
|
||||
time.Hour,
|
||||
func() *tailcfg.DERPMap { panic("not implemented") },
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// buffer the channels here, so we don't need to read and write in goroutines to
|
||||
// avoid blocking
|
||||
reqs := make(chan *proto.CoordinateRequest, 100)
|
||||
resps := make(chan *proto.CoordinateResponse, 100)
|
||||
mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetTunnelAuth{}).
|
||||
Times(1).
|
||||
Return(reqs, resps)
|
||||
|
||||
serveMACErr := make(chan error, 1)
|
||||
r.Get("/api/v2/workspaceproxies/me/coordinate", func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := websocket.Accept(w, r, nil)
|
||||
require.NoError(t, err)
|
||||
nc := websocket.NetConn(r.Context(), conn, websocket.MessageText)
|
||||
defer serverMultiAgent.Close()
|
||||
|
||||
err = tailnet.ServeWorkspaceProxy(ctx, nc, serverMultiAgent)
|
||||
if !xerrors.Is(err, io.EOF) {
|
||||
assert.NoError(t, err)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
version := r.URL.Query().Get("version")
|
||||
if !assert.Equal(t, version, agpl.CurrentVersion.String()) {
|
||||
return
|
||||
}
|
||||
nc := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
|
||||
err = cSrv.ServeMultiAgentClient(ctx, version, nc, proxyID)
|
||||
serveMACErr <- err
|
||||
})
|
||||
r.Get("/api/v2/workspaceagents/{workspaceagent}/legacy", func(w http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, w, http.StatusOK, wsproxysdk.AgentIsLegacyResponse{
|
||||
|
@ -188,51 +213,50 @@ func TestDialCoordinator(t *testing.T) {
|
|||
client := wsproxysdk.New(u)
|
||||
client.SDKClient.SetLogger(logger)
|
||||
|
||||
expected := []*agpl.Node{{
|
||||
ID: 55,
|
||||
AsOf: time.Unix(1689653252, 0),
|
||||
Key: key.NewNode().Public(),
|
||||
DiscoKey: key.NewDisco().Public(),
|
||||
PreferredDERP: 0,
|
||||
DERPLatency: map[string]float64{
|
||||
"0": 1.0,
|
||||
peerID := uuid.UUID{55}
|
||||
peerNodeKey, err := key.NewNode().Public().MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
peerDiscoKey, err := key.NewDisco().Public().MarshalText()
|
||||
require.NoError(t, err)
|
||||
expected := &proto.CoordinateResponse{PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{
|
||||
Id: peerID[:],
|
||||
Node: &proto.Node{
|
||||
Id: 55,
|
||||
AsOf: timestamppb.New(time.Unix(1689653252, 0)),
|
||||
Key: peerNodeKey[:],
|
||||
Disco: string(peerDiscoKey),
|
||||
PreferredDerp: 0,
|
||||
DerpLatency: map[string]float64{
|
||||
"0": 1.0,
|
||||
},
|
||||
DerpForcedWebsocket: map[int32]string{},
|
||||
Addresses: []string{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128).String()},
|
||||
AllowedIps: []string{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128).String()},
|
||||
Endpoints: []string{"192.168.1.1:18842"},
|
||||
},
|
||||
DERPForcedWebsocket: map[int]string{},
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128)},
|
||||
AllowedIPs: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128)},
|
||||
Endpoints: []string{"192.168.1.1:18842"},
|
||||
}}
|
||||
sendNode := make(chan struct{})
|
||||
|
||||
serverMultiAgent.EXPECT().NextUpdate(gomock.Any()).AnyTimes().
|
||||
DoAndReturn(func(ctx context.Context) ([]*agpl.Node, bool) {
|
||||
select {
|
||||
case <-sendNode:
|
||||
return expected, true
|
||||
case <-ctx.Done():
|
||||
return nil, false
|
||||
}
|
||||
})
|
||||
}}}
|
||||
|
||||
rma, err := client.DialCoordinator(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Subscribe
|
||||
{
|
||||
ch := make(chan struct{})
|
||||
serverMultiAgent.EXPECT().SubscribeAgent(agentID).Do(func(uuid.UUID) {
|
||||
close(ch)
|
||||
})
|
||||
require.NoError(t, rma.SubscribeAgent(agentID))
|
||||
waitOrCancel(ctx, t, ch)
|
||||
|
||||
req := testutil.RequireRecvCtx(ctx, t, reqs)
|
||||
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
|
||||
}
|
||||
// Read updated agent node
|
||||
{
|
||||
sendNode <- struct{}{}
|
||||
got, ok := rma.NextUpdate(ctx)
|
||||
resps <- expected
|
||||
|
||||
resp, ok := rma.NextUpdate(ctx)
|
||||
assert.True(t, ok)
|
||||
got[0].AsOf = got[0].AsOf.In(time.Local)
|
||||
assert.Equal(t, *expected[0], *got[0])
|
||||
updates := resp.GetPeerUpdates()
|
||||
assert.Len(t, updates, 1)
|
||||
eq, err := updates[0].GetNode().Equal(expected.GetPeerUpdates()[0].GetNode())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, eq)
|
||||
}
|
||||
// Check legacy
|
||||
{
|
||||
|
@ -241,45 +265,38 @@ func TestDialCoordinator(t *testing.T) {
|
|||
}
|
||||
// UpdateSelf
|
||||
{
|
||||
ch := make(chan struct{})
|
||||
serverMultiAgent.EXPECT().UpdateSelf(gomock.Any()).Do(func(node *agpl.Node) {
|
||||
node.AsOf = node.AsOf.In(time.Local)
|
||||
assert.Equal(t, expected[0], node)
|
||||
close(ch)
|
||||
})
|
||||
require.NoError(t, rma.UpdateSelf(expected[0]))
|
||||
waitOrCancel(ctx, t, ch)
|
||||
require.NoError(t, rma.UpdateSelf(expected.PeerUpdates[0].GetNode()))
|
||||
|
||||
req := testutil.RequireRecvCtx(ctx, t, reqs)
|
||||
eq, err := req.GetUpdateSelf().GetNode().Equal(expected.PeerUpdates[0].GetNode())
|
||||
require.NoError(t, err)
|
||||
require.True(t, eq)
|
||||
}
|
||||
// Unsubscribe
|
||||
{
|
||||
ch := make(chan struct{})
|
||||
serverMultiAgent.EXPECT().UnsubscribeAgent(agentID).Do(func(uuid.UUID) {
|
||||
close(ch)
|
||||
})
|
||||
require.NoError(t, rma.UnsubscribeAgent(agentID))
|
||||
waitOrCancel(ctx, t, ch)
|
||||
|
||||
req := testutil.RequireRecvCtx(ctx, t, reqs)
|
||||
require.Equal(t, agentID[:], req.GetRemoveTunnel().GetId())
|
||||
}
|
||||
// Close
|
||||
{
|
||||
ch := make(chan struct{})
|
||||
serverMultiAgent.EXPECT().Close().Do(func() {
|
||||
close(ch)
|
||||
})
|
||||
require.NoError(t, rma.Close())
|
||||
waitOrCancel(ctx, t, ch)
|
||||
|
||||
req := testutil.RequireRecvCtx(ctx, t, reqs)
|
||||
require.NotNil(t, req.Disconnect)
|
||||
close(resps)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for req close")
|
||||
case _, ok := <-reqs:
|
||||
require.False(t, ok, "didn't close requests")
|
||||
}
|
||||
require.Error(t, testutil.RequireRecvCtx(ctx, t, serveMACErr))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// waitOrCancel blocks until a receive on ch succeeds (a value is sent or the
// channel is closed) or ctx is done. If ctx finishes first, the test is failed
// with t.Fatal. It is marked as a test helper so failures are attributed to
// the caller's line.
func waitOrCancel(ctx context.Context, t testing.TB, ch <-chan struct{}) {
|
||||
t.Helper()
|
||||
select {
|
||||
case <-ch:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for channel")
|
||||
}
|
||||
}
|
||||
|
||||
type ResponseRecorder struct {
|
||||
rw *httptest.ResponseRecorder
|
||||
wasWritten atomic.Bool
|
||||
|
|
|
@ -490,6 +490,18 @@ func (c *configMaps) protoNodeToTailcfg(p *proto.Node) (*tailcfg.Node, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// nodeAddresses returns the addresses for the peer with the given publicKey, if known.
// It holds c.L while scanning c.peers and returns the Addresses of a peer
// whose node key equals publicKey, together with true. The slice is returned
// as stored on the peer's node, not copied. If no peer matches, it returns
// (nil, false).
|
||||
func (c *configMaps) nodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
|
||||
c.L.Lock()
|
||||
defer c.L.Unlock()
|
||||
for _, lc := range c.peers {
|
||||
if lc.node.Key == publicKey {
|
||||
return lc.node.Addresses, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
type peerLifecycle struct {
|
||||
peerID uuid.UUID
|
||||
node *tailcfg.Node
|
||||
|
|
519
tailnet/conn.go
519
tailnet/conn.go
|
@ -3,48 +3,40 @@ package tailnet
|
|||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
"go4.org/netipx"
|
||||
"golang.org/x/xerrors"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/net/connstats"
|
||||
"tailscale.com/net/dns"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/netns"
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/net/tstun"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tsd"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/types/key"
|
||||
tslogger "tailscale.com/types/logger"
|
||||
"tailscale.com/types/netlogtype"
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/wgengine"
|
||||
"tailscale.com/wgengine/filter"
|
||||
"tailscale.com/wgengine/magicsock"
|
||||
"tailscale.com/wgengine/netstack"
|
||||
"tailscale.com/wgengine/router"
|
||||
"tailscale.com/wgengine/wgcfg/nmcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/cryptorand"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
var ErrConnClosed = xerrors.New("connection closed")
|
||||
|
@ -128,42 +120,6 @@ func NewConn(options *Options) (conn *Conn, err error) {
|
|||
}
|
||||
|
||||
nodePrivateKey := key.NewNode()
|
||||
nodePublicKey := nodePrivateKey.Public()
|
||||
|
||||
netMap := &netmap.NetworkMap{
|
||||
DERPMap: options.DERPMap,
|
||||
NodeKey: nodePublicKey,
|
||||
PrivateKey: nodePrivateKey,
|
||||
Addresses: options.Addresses,
|
||||
PacketFilter: []filter.Match{{
|
||||
// Allow any protocol!
|
||||
IPProto: []ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, ipproto.SCTP},
|
||||
// Allow traffic sourced from anywhere.
|
||||
Srcs: []netip.Prefix{
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0),
|
||||
},
|
||||
// Allow traffic to route anywhere.
|
||||
Dsts: []filter.NetPortRange{
|
||||
{
|
||||
Net: netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0),
|
||||
Ports: filter.PortRange{
|
||||
First: 0,
|
||||
Last: 65535,
|
||||
},
|
||||
},
|
||||
{
|
||||
Net: netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0),
|
||||
Ports: filter.PortRange{
|
||||
First: 0,
|
||||
Last: 65535,
|
||||
},
|
||||
},
|
||||
},
|
||||
Caps: []filter.CapMatch{},
|
||||
}},
|
||||
}
|
||||
|
||||
var nodeID tailcfg.NodeID
|
||||
|
||||
// If we're provided with a UUID, use it to populate our node ID.
|
||||
|
@ -177,14 +133,6 @@ func NewConn(options *Options) (conn *Conn, err error) {
|
|||
nodeID = tailcfg.NodeID(uid)
|
||||
}
|
||||
|
||||
// This is used by functions below to identify the node via key
|
||||
netMap.SelfNode = &tailcfg.Node{
|
||||
ID: nodeID,
|
||||
Key: nodePublicKey,
|
||||
Addresses: options.Addresses,
|
||||
AllowedIPs: options.Addresses,
|
||||
}
|
||||
|
||||
wireguardMonitor, err := netmon.New(Logger(options.Logger.Named("net.wgmonitor")))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create wireguard link monitor: %w", err)
|
||||
|
@ -243,7 +191,6 @@ func NewConn(options *Options) (conn *Conn, err error) {
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("set node private key: %w", err)
|
||||
}
|
||||
netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey()
|
||||
|
||||
netStack, err := netstack.Create(
|
||||
Logger(options.Logger.Named("net.netstack")),
|
||||
|
@ -262,44 +209,46 @@ func NewConn(options *Options) (conn *Conn, err error) {
|
|||
}
|
||||
netStack.ProcessLocalIPs = true
|
||||
wireguardEngine = wgengine.NewWatchdog(wireguardEngine)
|
||||
wireguardEngine.SetDERPMap(options.DERPMap)
|
||||
netMapCopy := *netMap
|
||||
options.Logger.Debug(context.Background(), "updating network map")
|
||||
wireguardEngine.SetNetworkMap(&netMapCopy)
|
||||
|
||||
localIPSet := netipx.IPSetBuilder{}
|
||||
for _, addr := range netMap.Addresses {
|
||||
localIPSet.AddPrefix(addr)
|
||||
}
|
||||
localIPs, _ := localIPSet.IPSet()
|
||||
logIPSet := netipx.IPSetBuilder{}
|
||||
logIPs, _ := logIPSet.IPSet()
|
||||
wireguardEngine.SetFilter(filter.New(
|
||||
netMap.PacketFilter,
|
||||
localIPs,
|
||||
logIPs,
|
||||
cfgMaps := newConfigMaps(
|
||||
options.Logger,
|
||||
wireguardEngine,
|
||||
nodeID,
|
||||
nodePrivateKey,
|
||||
magicConn.DiscoPublicKey(),
|
||||
)
|
||||
cfgMaps.setAddresses(options.Addresses)
|
||||
cfgMaps.setDERPMap(DERPMapToProto(options.DERPMap))
|
||||
cfgMaps.setBlockEndpoints(options.BlockEndpoints)
|
||||
|
||||
nodeUp := newNodeUpdater(
|
||||
options.Logger,
|
||||
nil,
|
||||
Logger(options.Logger.Named("net.packet-filter")),
|
||||
))
|
||||
nodeID,
|
||||
nodePrivateKey.Public(),
|
||||
magicConn.DiscoPublicKey(),
|
||||
)
|
||||
nodeUp.setAddresses(options.Addresses)
|
||||
nodeUp.setBlockEndpoints(options.BlockEndpoints)
|
||||
wireguardEngine.SetStatusCallback(nodeUp.setStatus)
|
||||
wireguardEngine.SetNetInfoCallback(nodeUp.setNetInfo)
|
||||
magicConn.SetDERPForcedWebsocketCallback(nodeUp.setDERPForcedWebsocket)
|
||||
|
||||
server := &Conn{
|
||||
blockEndpoints: options.BlockEndpoints,
|
||||
derpForceWebSockets: options.DERPForceWebSockets,
|
||||
closed: make(chan struct{}),
|
||||
logger: options.Logger,
|
||||
magicConn: magicConn,
|
||||
dialer: dialer,
|
||||
listeners: map[listenKey]*listener{},
|
||||
peerMap: map[tailcfg.NodeID]*tailcfg.Node{},
|
||||
lastDERPForcedWebSockets: map[int]string{},
|
||||
tunDevice: sys.Tun.Get(),
|
||||
netMap: netMap,
|
||||
netStack: netStack,
|
||||
wireguardMonitor: wireguardMonitor,
|
||||
closed: make(chan struct{}),
|
||||
logger: options.Logger,
|
||||
magicConn: magicConn,
|
||||
dialer: dialer,
|
||||
listeners: map[listenKey]*listener{},
|
||||
tunDevice: sys.Tun.Get(),
|
||||
netStack: netStack,
|
||||
wireguardMonitor: wireguardMonitor,
|
||||
wireguardRouter: &router.Config{
|
||||
LocalAddrs: netMap.Addresses,
|
||||
LocalAddrs: options.Addresses,
|
||||
},
|
||||
wireguardEngine: wireguardEngine,
|
||||
configMaps: cfgMaps,
|
||||
nodeUpdater: nodeUp,
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
|
@ -307,52 +256,6 @@ func NewConn(options *Options) (conn *Conn, err error) {
|
|||
}
|
||||
}()
|
||||
|
||||
wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) {
|
||||
server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.Error(err))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
server.lastMutex.Lock()
|
||||
if s.AsOf.Before(server.lastStatus) {
|
||||
// Don't process outdated status!
|
||||
server.lastMutex.Unlock()
|
||||
return
|
||||
}
|
||||
server.lastStatus = s.AsOf
|
||||
if endpointsEqual(s.LocalAddrs, server.lastEndpoints) {
|
||||
// No need to update the node if nothing changed!
|
||||
server.lastMutex.Unlock()
|
||||
return
|
||||
}
|
||||
server.lastEndpoints = append([]tailcfg.Endpoint{}, s.LocalAddrs...)
|
||||
server.lastMutex.Unlock()
|
||||
server.sendNode()
|
||||
})
|
||||
|
||||
wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) {
|
||||
server.logger.Debug(context.Background(), "netinfo callback", slog.F("netinfo", ni))
|
||||
server.lastMutex.Lock()
|
||||
if reflect.DeepEqual(server.lastNetInfo, ni) {
|
||||
server.lastMutex.Unlock()
|
||||
return
|
||||
}
|
||||
server.lastNetInfo = ni.Clone()
|
||||
server.lastMutex.Unlock()
|
||||
server.sendNode()
|
||||
})
|
||||
|
||||
magicConn.SetDERPForcedWebsocketCallback(func(region int, reason string) {
|
||||
server.logger.Debug(context.Background(), "derp forced websocket", slog.F("region", region), slog.F("reason", reason))
|
||||
server.lastMutex.Lock()
|
||||
if server.lastDERPForcedWebSockets[region] == reason {
|
||||
server.lastMutex.Unlock()
|
||||
return
|
||||
}
|
||||
server.lastDERPForcedWebSockets[region] = reason
|
||||
server.lastMutex.Unlock()
|
||||
server.sendNode()
|
||||
})
|
||||
|
||||
netStack.GetTCPHandlerForFlow = server.forwardTCP
|
||||
|
||||
err = netStack.Start(nil)
|
||||
|
@ -389,16 +292,14 @@ func IPFromUUID(uid uuid.UUID) netip.Addr {
|
|||
|
||||
// Conn is an actively listening Wireguard connection.
|
||||
type Conn struct {
|
||||
mutex sync.Mutex
|
||||
closed chan struct{}
|
||||
logger slog.Logger
|
||||
blockEndpoints bool
|
||||
derpForceWebSockets bool
|
||||
mutex sync.Mutex
|
||||
closed chan struct{}
|
||||
logger slog.Logger
|
||||
|
||||
dialer *tsdial.Dialer
|
||||
tunDevice *tstun.Wrapper
|
||||
peerMap map[tailcfg.NodeID]*tailcfg.Node
|
||||
netMap *netmap.NetworkMap
|
||||
configMaps *configMaps
|
||||
nodeUpdater *nodeUpdater
|
||||
netStack *netstack.Impl
|
||||
magicConn *magicsock.Conn
|
||||
wireguardMonitor *netmon.Monitor
|
||||
|
@ -406,17 +307,6 @@ type Conn struct {
|
|||
wireguardEngine wgengine.Engine
|
||||
listeners map[listenKey]*listener
|
||||
|
||||
lastMutex sync.Mutex
|
||||
nodeSending bool
|
||||
nodeChanged bool
|
||||
// It's only possible to store these values via status functions,
|
||||
// so the values must be stored for retrieval later on.
|
||||
lastStatus time.Time
|
||||
lastEndpoints []tailcfg.Endpoint
|
||||
lastDERPForcedWebSockets map[int]string
|
||||
lastNetInfo *tailcfg.NetInfo
|
||||
nodeCallback func(node *Node)
|
||||
|
||||
trafficStats *connstats.Statistics
|
||||
}
|
||||
|
||||
|
@ -425,57 +315,30 @@ func (c *Conn) MagicsockSetDebugLoggingEnabled(enabled bool) {
|
|||
}
|
||||
|
||||
func (c *Conn) SetAddresses(ips []netip.Prefix) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.netMap.Addresses = ips
|
||||
|
||||
netMapCopy := *c.netMap
|
||||
c.logger.Debug(context.Background(), "updating network map")
|
||||
c.wireguardEngine.SetNetworkMap(&netMapCopy)
|
||||
err := c.reconfig()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("reconfig: %w", err)
|
||||
}
|
||||
|
||||
c.configMaps.setAddresses(ips)
|
||||
c.nodeUpdater.setAddresses(ips)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) Addresses() []netip.Prefix {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
return c.netMap.Addresses
|
||||
}
|
||||
|
||||
func (c *Conn) SetNodeCallback(callback func(node *Node)) {
|
||||
c.lastMutex.Lock()
|
||||
c.nodeCallback = callback
|
||||
c.lastMutex.Unlock()
|
||||
c.sendNode()
|
||||
c.nodeUpdater.setCallback(callback)
|
||||
}
|
||||
|
||||
// SetDERPMap updates the DERPMap of a connection.
|
||||
func (c *Conn) SetDERPMap(derpMap *tailcfg.DERPMap) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.logger.Debug(context.Background(), "updating derp map", slog.F("derp_map", derpMap))
|
||||
c.wireguardEngine.SetDERPMap(derpMap)
|
||||
c.netMap.DERPMap = derpMap
|
||||
netMapCopy := *c.netMap
|
||||
c.logger.Debug(context.Background(), "updating network map")
|
||||
c.wireguardEngine.SetNetworkMap(&netMapCopy)
|
||||
c.configMaps.setDERPMap(DERPMapToProto(derpMap))
|
||||
}
|
||||
|
||||
// SetDERPForceWebSockets logs the requested value at info level and forwards
// it to magicConn.SetDERPForceWebsockets. It does not take c.mutex.
func (c *Conn) SetDERPForceWebSockets(v bool) {
|
||||
c.logger.Info(context.Background(), "setting DERP Force Websockets", slog.F("force_derp_websockets", v))
|
||||
c.magicConn.SetDERPForceWebsockets(v)
|
||||
}
|
||||
|
||||
// SetBlockEndpoints sets whether or not to block P2P endpoints. This setting
|
||||
// SetBlockEndpoints sets whether to block P2P endpoints. This setting
|
||||
// will only apply to new peers.
|
||||
func (c *Conn) SetBlockEndpoints(blockEndpoints bool) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.blockEndpoints = blockEndpoints
|
||||
c.configMaps.setBlockEndpoints(blockEndpoints)
|
||||
c.nodeUpdater.setBlockEndpoints(blockEndpoints)
|
||||
}
|
||||
|
||||
// SetDERPRegionDialer updates the dialer to use for connecting to DERP regions.
|
||||
|
@ -483,186 +346,24 @@ func (c *Conn) SetDERPRegionDialer(dialer func(ctx context.Context, region *tail
|
|||
c.magicConn.SetDERPRegionDialer(dialer)
|
||||
}
|
||||
|
||||
// UpdateNodes connects with a set of peers. This can be constantly updated,
|
||||
// and peers will continually be reconnected as necessary. If replacePeers is
|
||||
// true, all peers will be removed before adding the new ones.
|
||||
//
|
||||
//nolint:revive // Complains about replacePeers.
|
||||
func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// UpdatePeers connects with a set of peers. This can be constantly updated,
|
||||
// and peers will continually be reconnected as necessary.
|
||||
func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
|
||||
if c.isClosed() {
|
||||
return ErrConnClosed
|
||||
}
|
||||
|
||||
status := c.Status()
|
||||
if replacePeers {
|
||||
c.netMap.Peers = []*tailcfg.Node{}
|
||||
c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{}
|
||||
}
|
||||
for _, peer := range c.netMap.Peers {
|
||||
peerStatus, ok := status.Peer[peer.Key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// If this peer was added in the last 5 minutes, assume it
|
||||
// could still be active.
|
||||
if time.Since(peer.Created) < 5*time.Minute {
|
||||
continue
|
||||
}
|
||||
// We double-check that it's safe to remove by ensuring no
|
||||
// handshake has been sent in the past 5 minutes as well. Connections that
|
||||
// are actively exchanging IP traffic will handshake every 2 minutes.
|
||||
if time.Since(peerStatus.LastHandshake) < 5*time.Minute {
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Debug(context.Background(), "removing peer, last handshake >5m ago",
|
||||
slog.F("peer", peer.Key), slog.F("last_handshake", peerStatus.LastHandshake),
|
||||
)
|
||||
delete(c.peerMap, peer.ID)
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
// If no preferred DERP is provided, we can't reach the node.
|
||||
if node.PreferredDERP == 0 {
|
||||
c.logger.Debug(context.Background(), "no preferred DERP, skipping node", slog.F("node", node))
|
||||
continue
|
||||
}
|
||||
c.logger.Debug(context.Background(), "adding node", slog.F("node", node))
|
||||
|
||||
peerStatus, ok := status.Peer[node.Key]
|
||||
peerNode := &tailcfg.Node{
|
||||
ID: node.ID,
|
||||
Created: time.Now(),
|
||||
Key: node.Key,
|
||||
DiscoKey: node.DiscoKey,
|
||||
Addresses: node.Addresses,
|
||||
AllowedIPs: node.AllowedIPs,
|
||||
Endpoints: node.Endpoints,
|
||||
DERP: fmt.Sprintf("%s:%d", tailcfg.DerpMagicIP, node.PreferredDERP),
|
||||
Hostinfo: (&tailcfg.Hostinfo{}).View(),
|
||||
// Starting KeepAlive messages at the initialization of a connection
|
||||
// causes a race condition. If we handshake before the peer has our
|
||||
// node, we'll have to wait for 5 seconds before trying again. Ideally,
|
||||
// the first handshake starts when the user first initiates a
|
||||
// connection to the peer. After a successful connection we enable
|
||||
// keep alives to persist the connection and keep it from becoming
|
||||
// idle. SSH connections don't send packets while idle, so we
|
||||
// use keep alives to avoid random hangs while we set up the
|
||||
// connection again after inactivity.
|
||||
KeepAlive: ok && peerStatus.Active,
|
||||
}
|
||||
if c.blockEndpoints {
|
||||
peerNode.Endpoints = nil
|
||||
}
|
||||
c.peerMap[node.ID] = peerNode
|
||||
}
|
||||
|
||||
c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
|
||||
for _, peer := range c.peerMap {
|
||||
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
|
||||
}
|
||||
|
||||
netMapCopy := *c.netMap
|
||||
c.logger.Debug(context.Background(), "updating network map")
|
||||
c.wireguardEngine.SetNetworkMap(&netMapCopy)
|
||||
err := c.reconfig()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("reconfig: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PeerSelector is used to select a peer from within a Tailnet.
// Conn.RemovePeer treats a peer as selected when either field matches.
|
||||
type PeerSelector struct {
|
||||
// ID is compared against the peer's tailcfg node ID.
ID tailcfg.NodeID
|
||||
// IP is compared against each of the peer's addresses; both the prefix
// length and the address must be equal.
IP netip.Prefix
|
||||
}
|
||||
|
||||
// RemovePeer deletes peers matching selector from c.peerMap. If anything was
// deleted it rebuilds c.netMap.Peers from the remaining entries, hands a copy
// of the network map to the wireguard engine and calls reconfig.
//
// A peer matches when its ID equals selector.ID, or when one of its Addresses
// has the same prefix length and address as selector.IP. c.mutex is held for
// the whole call.
//
// It returns (false, ErrConnClosed) when the Conn is closed, (false, nil) when
// no peer matched, (false, err) when reconfig fails (the peer has already been
// removed from the maps at that point), and (true, nil) otherwise.
func (c *Conn) RemovePeer(selector PeerSelector) (deleted bool, err error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if c.isClosed() {
|
||||
return false, ErrConnClosed
|
||||
}
|
||||
|
||||
deleted = false
|
||||
for _, peer := range c.peerMap {
|
||||
// An ID match ends the scan over peers entirely.
if peer.ID == selector.ID {
|
||||
delete(c.peerMap, peer.ID)
|
||||
deleted = true
|
||||
break
|
||||
}
|
||||
|
||||
for _, peerIP := range peer.Addresses {
|
||||
if peerIP.Bits() == selector.IP.Bits() && peerIP.Addr().Compare(selector.IP.Addr()) == 0 {
|
||||
delete(c.peerMap, peer.ID)
|
||||
deleted = true
|
||||
// This break only leaves the address loop; the outer loop keeps
// checking the remaining peers.
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !deleted {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Rebuild the peer list from the map, storing clones of the entries.
c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
|
||||
for _, peer := range c.peerMap {
|
||||
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
|
||||
}
|
||||
|
||||
// The engine receives a shallow copy of the network map, not c.netMap itself.
netMapCopy := *c.netMap
|
||||
c.logger.Debug(context.Background(), "updating network map")
|
||||
c.wireguardEngine.SetNetworkMap(&netMapCopy)
|
||||
err = c.reconfig()
|
||||
if err != nil {
|
||||
return false, xerrors.Errorf("reconfig: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (c *Conn) reconfig() error {
|
||||
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("net.wgconfig")), netmap.AllowSingleHosts, "")
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update wireguard config: %w", err)
|
||||
}
|
||||
|
||||
err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{})
|
||||
if err != nil {
|
||||
if c.isClosed() {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, wgengine.ErrNoChanges) {
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("reconfig: %w", err)
|
||||
}
|
||||
|
||||
c.configMaps.updatePeers(updates)
|
||||
return nil
|
||||
}
|
||||
|
||||
// NodeAddresses returns the addresses of a node from the NetworkMap.
|
||||
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
for _, node := range c.netMap.Peers {
|
||||
if node.Key == publicKey {
|
||||
return node.Addresses, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
return c.configMaps.nodeAddresses(publicKey)
|
||||
}
|
||||
|
||||
// Status returns the current ipnstate of a connection.
|
||||
func (c *Conn) Status() *ipnstate.Status {
|
||||
sb := &ipnstate.StatusBuilder{WantPeers: true}
|
||||
c.wireguardEngine.UpdateStatus(sb)
|
||||
return sb.Status()
|
||||
return c.configMaps.status()
|
||||
}
|
||||
|
||||
// Ping sends a ping to the Wireguard engine.
|
||||
|
@ -689,16 +390,9 @@ func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, bool, *i
|
|||
|
||||
// DERPMap returns the currently set DERP mapping.
|
||||
func (c *Conn) DERPMap() *tailcfg.DERPMap {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
return c.netMap.DERPMap
|
||||
}
|
||||
|
||||
// BlockEndpoints returns whether or not P2P is blocked.
|
||||
func (c *Conn) BlockEndpoints() bool {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
return c.blockEndpoints
|
||||
c.configMaps.L.Lock()
|
||||
defer c.configMaps.L.Unlock()
|
||||
return c.configMaps.derpMapLocked()
|
||||
}
|
||||
|
||||
// AwaitReachable pings the provided IP continually until the
|
||||
|
@ -759,6 +453,9 @@ func (c *Conn) Closed() <-chan struct{} {
|
|||
|
||||
// Close shuts down the Wireguard connection.
|
||||
func (c *Conn) Close() error {
|
||||
c.logger.Info(context.Background(), "closing tailnet Conn")
|
||||
c.configMaps.close()
|
||||
c.nodeUpdater.close()
|
||||
c.mutex.Lock()
|
||||
select {
|
||||
case <-c.closed:
|
||||
|
@ -808,91 +505,11 @@ func (c *Conn) isClosed() bool {
|
|||
}
|
||||
}
|
||||
|
||||
// sendNode passes the current self node to the registered node callback on a
// new goroutine. At most one delivery is in flight at a time: if a delivery is
// already running, nodeChanged is set and the running goroutine calls sendNode
// again once its callback returns. Nothing is sent when the node has no
// PreferredDERP or when no callback is registered. c.lastMutex guards
// nodeSending, nodeChanged and nodeCallback.
func (c *Conn) sendNode() {
|
||||
c.lastMutex.Lock()
|
||||
defer c.lastMutex.Unlock()
|
||||
if c.nodeSending {
|
||||
// A delivery is in progress; record that another one is needed.
c.nodeChanged = true
|
||||
return
|
||||
}
|
||||
node := c.selfNode()
|
||||
// Conn.UpdateNodes will skip any nodes that don't have the PreferredDERP
|
||||
// set to non-zero, since we cannot reach nodes without DERP for discovery.
|
||||
// Therefore, there is no point in sending the node without this, and we can
|
||||
// save ourselves from churn in the tailscale/wireguard layer.
|
||||
if node.PreferredDERP == 0 {
|
||||
c.logger.Debug(context.Background(), "skipped sending node; no PreferredDERP", slog.F("node", node))
|
||||
return
|
||||
}
|
||||
// Capture the callback under the lock so the goroutine uses this value.
nodeCallback := c.nodeCallback
|
||||
if nodeCallback == nil {
|
||||
return
|
||||
}
|
||||
c.nodeSending = true
|
||||
go func() {
|
||||
c.logger.Debug(context.Background(), "sending node", slog.F("node", node))
|
||||
// The callback runs without c.lastMutex held.
nodeCallback(node)
|
||||
c.lastMutex.Lock()
|
||||
c.nodeSending = false
|
||||
if c.nodeChanged {
|
||||
c.nodeChanged = false
|
||||
// Unlock before re-entering sendNode, which takes c.lastMutex itself.
c.lastMutex.Unlock()
|
||||
c.sendNode()
|
||||
return
|
||||
}
|
||||
c.lastMutex.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// Node returns the last node that was sent to the node callback.
|
||||
func (c *Conn) Node() *Node {
|
||||
c.lastMutex.Lock()
|
||||
defer c.lastMutex.Unlock()
|
||||
return c.selfNode()
|
||||
}
|
||||
|
||||
func (c *Conn) selfNode() *Node {
|
||||
endpoints := make([]string, 0, len(c.lastEndpoints))
|
||||
for _, addr := range c.lastEndpoints {
|
||||
endpoints = append(endpoints, addr.Addr.String())
|
||||
}
|
||||
var preferredDERP int
|
||||
var derpLatency map[string]float64
|
||||
derpForcedWebsocket := make(map[int]string, 0)
|
||||
if c.lastNetInfo != nil {
|
||||
preferredDERP = c.lastNetInfo.PreferredDERP
|
||||
derpLatency = c.lastNetInfo.DERPLatency
|
||||
|
||||
if c.derpForceWebSockets {
|
||||
// We only need to store this for a single region, since this is
|
||||
// mostly used for debugging purposes and doesn't actually have a
|
||||
// code purpose.
|
||||
derpForcedWebsocket[preferredDERP] = "DERP is configured to always fallback to WebSockets"
|
||||
} else {
|
||||
for k, v := range c.lastDERPForcedWebSockets {
|
||||
derpForcedWebsocket[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
node := &Node{
|
||||
ID: c.netMap.SelfNode.ID,
|
||||
AsOf: dbtime.Now(),
|
||||
Key: c.netMap.SelfNode.Key,
|
||||
Addresses: c.netMap.SelfNode.Addresses,
|
||||
AllowedIPs: c.netMap.SelfNode.AllowedIPs,
|
||||
DiscoKey: c.magicConn.DiscoPublicKey(),
|
||||
Endpoints: endpoints,
|
||||
PreferredDERP: preferredDERP,
|
||||
DERPLatency: derpLatency,
|
||||
DERPForcedWebsocket: derpForcedWebsocket,
|
||||
}
|
||||
c.mutex.Lock()
|
||||
if c.blockEndpoints {
|
||||
node.Endpoints = nil
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return node
|
||||
c.nodeUpdater.L.Lock()
|
||||
defer c.nodeUpdater.L.Unlock()
|
||||
return c.nodeUpdater.nodeLocked()
|
||||
}
|
||||
|
||||
// This and below is taken _mostly_ verbatim from Tailscale:
|
||||
|
@ -1056,15 +673,3 @@ func Logger(logger slog.Logger) tslogger.Logf {
|
|||
logger.Debug(context.Background(), fmt.Sprintf(format, args...))
|
||||
})
|
||||
}
|
||||
|
||||
// endpointsEqual reports whether x and y have the same length and hold equal
// elements at every index. Order matters; two nil or empty slices compare
// equal.
func endpointsEqual(x, y []tailcfg.Endpoint) bool {
|
||||
if len(x) != len(y) {
|
||||
return false
|
||||
}
|
||||
for i := range x {
|
||||
if x[i] != y[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
@ -12,6 +13,7 @@ import (
|
|||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
@ -22,10 +24,10 @@ func TestMain(m *testing.M) {
|
|||
|
||||
func TestTailnet(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||||
t.Run("InstantClose", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
Logger: logger.Named("w1"),
|
||||
|
@ -37,6 +39,8 @@ func TestTailnet(t *testing.T) {
|
|||
})
|
||||
t.Run("Connect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
w1IP := tailnet.IP()
|
||||
w1, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)},
|
||||
|
@ -55,14 +59,8 @@ func TestTailnet(t *testing.T) {
|
|||
_ = w1.Close()
|
||||
_ = w2.Close()
|
||||
})
|
||||
w1.SetNodeCallback(func(node *tailnet.Node) {
|
||||
err := w2.UpdateNodes([]*tailnet.Node{node}, false)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
w2.SetNodeCallback(func(node *tailnet.Node) {
|
||||
err := w1.UpdateNodes([]*tailnet.Node{node}, false)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
stitch(t, w2, w1)
|
||||
stitch(t, w1, w2)
|
||||
require.True(t, w2.AwaitReachable(context.Background(), w1IP))
|
||||
conn := make(chan struct{}, 1)
|
||||
go func() {
|
||||
|
@ -89,7 +87,7 @@ func TestTailnet(t *testing.T) {
|
|||
default:
|
||||
}
|
||||
})
|
||||
node := <-nodes
|
||||
node := testutil.RequireRecvCtx(ctx, t, nodes)
|
||||
// Ensure this connected over DERP!
|
||||
require.Len(t, node.DERPForcedWebsocket, 0)
|
||||
|
||||
|
@ -99,6 +97,7 @@ func TestTailnet(t *testing.T) {
|
|||
|
||||
t.Run("ForcesWebSockets", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
w1IP := tailnet.IP()
|
||||
|
@ -122,14 +121,8 @@ func TestTailnet(t *testing.T) {
|
|||
_ = w1.Close()
|
||||
_ = w2.Close()
|
||||
})
|
||||
w1.SetNodeCallback(func(node *tailnet.Node) {
|
||||
err := w2.UpdateNodes([]*tailnet.Node{node}, false)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
w2.SetNodeCallback(func(node *tailnet.Node) {
|
||||
err := w1.UpdateNodes([]*tailnet.Node{node}, false)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
stitch(t, w2, w1)
|
||||
stitch(t, w1, w2)
|
||||
require.True(t, w2.AwaitReachable(ctx, w1IP))
|
||||
conn := make(chan struct{}, 1)
|
||||
go func() {
|
||||
|
@ -243,11 +236,16 @@ func TestConn_UpdateDERP(t *testing.T) {
|
|||
err := client1.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
client1.SetNodeCallback(func(node *tailnet.Node) {
|
||||
err := conn.UpdateNodes([]*tailnet.Node{node}, false)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
client1.UpdateNodes([]*tailnet.Node{conn.Node()}, false)
|
||||
stitch(t, conn, client1)
|
||||
pn, err := tailnet.NodeToProto(conn.Node())
|
||||
require.NoError(t, err)
|
||||
connID := uuid.New()
|
||||
err = client1.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
|
||||
Id: connID[:],
|
||||
Node: pn,
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
awaitReachableCtx1, awaitReachableCancel1 := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer awaitReachableCancel1()
|
||||
|
@ -288,7 +286,13 @@ parentLoop:
|
|||
|
||||
// ... unless the client updates its derp map and nodes.
|
||||
client1.SetDERPMap(derpMap2)
|
||||
client1.UpdateNodes([]*tailnet.Node{conn.Node()}, false)
|
||||
pn, err = tailnet.NodeToProto(conn.Node())
|
||||
require.NoError(t, err)
|
||||
client1.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
|
||||
Id: connID[:],
|
||||
Node: pn,
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
||||
}})
|
||||
awaitReachableCtx3, awaitReachableCancel3 := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer awaitReachableCancel3()
|
||||
require.True(t, client1.AwaitReachable(awaitReachableCtx3, ip))
|
||||
|
@ -306,13 +310,34 @@ parentLoop:
|
|||
err := client2.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
client2.SetNodeCallback(func(node *tailnet.Node) {
|
||||
err := conn.UpdateNodes([]*tailnet.Node{node}, false)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
client2.UpdateNodes([]*tailnet.Node{conn.Node()}, false)
|
||||
stitch(t, conn, client2)
|
||||
pn, err = tailnet.NodeToProto(conn.Node())
|
||||
require.NoError(t, err)
|
||||
client2.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
|
||||
Id: connID[:],
|
||||
Node: pn,
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
||||
}})
|
||||
|
||||
awaitReachableCtx4, awaitReachableCancel4 := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer awaitReachableCancel4()
|
||||
require.True(t, client2.AwaitReachable(awaitReachableCtx4, ip))
|
||||
}
|
||||
|
||||
// stitch sends node updates from src Conn as peer updates to dst Conn. Sort of
|
||||
// like the Coordinator would, but without actually needing a Coordinator.
|
||||
func stitch(t *testing.T, dst, src *tailnet.Conn) {
|
||||
srcID := uuid.New()
|
||||
src.SetNodeCallback(func(node *tailnet.Node) {
|
||||
pn, err := tailnet.NodeToProto(node)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
err = dst.UpdatePeers([]*proto.CoordinateResponse_PeerUpdate{{
|
||||
Id: srcID[:],
|
||||
Node: pn,
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
||||
}})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package tailnet
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -92,6 +93,237 @@ type Node struct {
|
|||
Endpoints []string `json:"endpoints"`
|
||||
}
|
||||
|
||||
// Coordinatee is something that can be coordinated over the Coordinate protocol. Usually this is a
|
||||
// Conn.
|
||||
type Coordinatee interface {
|
||||
UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error
|
||||
SetNodeCallback(func(*Node))
|
||||
}
|
||||
|
||||
// Coordination is a running exchange between a Coordinatee and a coordinator.
type Coordination interface {
	// Close ends the coordination, attempting to send a graceful disconnect
	// to the coordinator first.
	io.Closer
	// Error returns a channel on which a failure of the coordination is
	// reported.
	Error() <-chan error
}
|
||||
|
||||
type remoteCoordination struct {
|
||||
sync.Mutex
|
||||
closed bool
|
||||
errChan chan error
|
||||
coordinatee Coordinatee
|
||||
logger slog.Logger
|
||||
protocol proto.DRPCTailnet_CoordinateClient
|
||||
}
|
||||
|
||||
func (c *remoteCoordination) Close() error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
c.closed = true
|
||||
err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("send disconnect: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *remoteCoordination) Error() <-chan error {
|
||||
return c.errChan
|
||||
}
|
||||
|
||||
func (c *remoteCoordination) sendErr(err error) {
|
||||
select {
|
||||
case c.errChan <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *remoteCoordination) respLoop() {
|
||||
for {
|
||||
resp, err := c.protocol.Recv()
|
||||
if err != nil {
|
||||
c.sendErr(xerrors.Errorf("read: %w", err))
|
||||
return
|
||||
}
|
||||
err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
|
||||
if err != nil {
|
||||
c.sendErr(xerrors.Errorf("update peers: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinee (usually a
|
||||
// Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
|
||||
// a client---agents should NOT set this!).
|
||||
func NewRemoteCoordination(logger slog.Logger,
|
||||
protocol proto.DRPCTailnet_CoordinateClient, coordinatee Coordinatee,
|
||||
tunnelTarget uuid.UUID,
|
||||
) Coordination {
|
||||
c := &remoteCoordination{
|
||||
errChan: make(chan error, 1),
|
||||
coordinatee: coordinatee,
|
||||
logger: logger,
|
||||
protocol: protocol,
|
||||
}
|
||||
if tunnelTarget != uuid.Nil {
|
||||
c.Lock()
|
||||
err := c.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tunnelTarget[:]}})
|
||||
c.Unlock()
|
||||
if err != nil {
|
||||
c.sendErr(err)
|
||||
}
|
||||
}
|
||||
|
||||
coordinatee.SetNodeCallback(func(node *Node) {
|
||||
pn, err := NodeToProto(node)
|
||||
if err != nil {
|
||||
c.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
|
||||
c.sendErr(err)
|
||||
return
|
||||
}
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.closed {
|
||||
c.logger.Debug(context.Background(), "ignored node update because coordination is closed")
|
||||
return
|
||||
}
|
||||
err = c.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}})
|
||||
if err != nil {
|
||||
c.sendErr(xerrors.Errorf("write: %w", err))
|
||||
}
|
||||
})
|
||||
go c.respLoop()
|
||||
return c
|
||||
}
|
||||
|
||||
type inMemoryCoordination struct {
|
||||
sync.Mutex
|
||||
ctx context.Context
|
||||
errChan chan error
|
||||
closed bool
|
||||
closedCh chan struct{}
|
||||
coordinatee Coordinatee
|
||||
logger slog.Logger
|
||||
resps <-chan *proto.CoordinateResponse
|
||||
reqs chan<- *proto.CoordinateRequest
|
||||
}
|
||||
|
||||
func (c *inMemoryCoordination) sendErr(err error) {
|
||||
select {
|
||||
case c.errChan <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *inMemoryCoordination) Error() <-chan error {
|
||||
return c.errChan
|
||||
}
|
||||
|
||||
// NewInMemoryCoordination connects a Coordinatee (usually Conn) to an in memory Coordinator, for testing
|
||||
// or local clients. Set ClientID to uuid.Nil for an agent.
|
||||
func NewInMemoryCoordination(
|
||||
ctx context.Context, logger slog.Logger,
|
||||
clientID, agentID uuid.UUID,
|
||||
coordinator Coordinator, coordinatee Coordinatee,
|
||||
) Coordination {
|
||||
thisID := agentID
|
||||
logger = logger.With(slog.F("agent_id", agentID))
|
||||
var auth TunnelAuth = AgentTunnelAuth{}
|
||||
if clientID != uuid.Nil {
|
||||
// this is a client connection
|
||||
auth = ClientTunnelAuth{AgentID: agentID}
|
||||
logger = logger.With(slog.F("client_id", clientID))
|
||||
thisID = clientID
|
||||
}
|
||||
c := &inMemoryCoordination{
|
||||
ctx: ctx,
|
||||
errChan: make(chan error, 1),
|
||||
coordinatee: coordinatee,
|
||||
logger: logger,
|
||||
closedCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// use the background context since we will depend exclusively on closing the req channel to
|
||||
// tell the coordinator we are done.
|
||||
c.reqs, c.resps = coordinator.Coordinate(context.Background(),
|
||||
thisID, fmt.Sprintf("inmemory%s", thisID),
|
||||
auth,
|
||||
)
|
||||
go c.respLoop()
|
||||
if agentID != uuid.Nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.logger.Warn(ctx, "context expired before we could add tunnel", slog.Error(ctx.Err()))
|
||||
return c
|
||||
case c.reqs <- &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}}:
|
||||
// OK!
|
||||
}
|
||||
}
|
||||
coordinatee.SetNodeCallback(func(n *Node) {
|
||||
pn, err := NodeToProto(n)
|
||||
if err != nil {
|
||||
c.logger.Critical(ctx, "failed to convert node", slog.Error(err))
|
||||
c.sendErr(err)
|
||||
return
|
||||
}
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.closed {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.logger.Info(ctx, "context expired before sending node update")
|
||||
return
|
||||
case c.reqs <- &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}:
|
||||
c.logger.Debug(ctx, "sent node in-memory to coordinator")
|
||||
}
|
||||
})
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *inMemoryCoordination) respLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-c.closedCh:
|
||||
c.logger.Debug(context.Background(), "in-memory coordination closed")
|
||||
return
|
||||
case resp, ok := <-c.resps:
|
||||
if !ok {
|
||||
c.logger.Debug(context.Background(), "in-memory response channel closed")
|
||||
return
|
||||
}
|
||||
c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
|
||||
err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
|
||||
if err != nil {
|
||||
c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *inMemoryCoordination) Close() error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.logger.Debug(context.Background(), "closing in-memory coordination")
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
defer close(c.reqs)
|
||||
c.closed = true
|
||||
close(c.closedCh)
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
|
||||
case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}:
|
||||
c.logger.Debug(context.Background(), "sent graceful disconnect in-memory")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ServeCoordinator matches the RW structure of a coordinator to exchange node messages.
|
||||
func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func(node *Node), <-chan error) {
|
||||
errChan := make(chan error, 1)
|
||||
|
@ -237,21 +469,17 @@ func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAge
|
|||
}
|
||||
return false
|
||||
},
|
||||
OnSubscribe: func(enq Queue, agent uuid.UUID) (*Node, error) {
|
||||
OnSubscribe: func(enq Queue, agent uuid.UUID) error {
|
||||
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}})
|
||||
return c.Node(agent), err
|
||||
return err
|
||||
},
|
||||
OnUnsubscribe: func(enq Queue, agent uuid.UUID) error {
|
||||
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}})
|
||||
return err
|
||||
},
|
||||
OnNodeUpdate: func(id uuid.UUID, node *Node) error {
|
||||
pn, err := NodeToProto(node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
OnNodeUpdate: func(id uuid.UUID, node *proto.Node) error {
|
||||
return SendCtx(ctx, reqs, &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
|
||||
Node: pn,
|
||||
Node: node,
|
||||
}})
|
||||
},
|
||||
OnRemove: func(_ Queue) {
|
||||
|
@ -285,7 +513,7 @@ const (
|
|||
type Queue interface {
|
||||
UniqueID() uuid.UUID
|
||||
Kind() QueueKind
|
||||
Enqueue(n []*Node) error
|
||||
Enqueue(resp *proto.CoordinateResponse) error
|
||||
Name() string
|
||||
Stats() (start, lastWrite int64)
|
||||
Overwrites() int64
|
||||
|
@ -793,18 +1021,7 @@ func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logg
|
|||
return
|
||||
}
|
||||
logger.Debug(ctx, "v1RespLoop got response", slog.F("resp", resp))
|
||||
nodes, err := OnlyNodeUpdates(resp)
|
||||
if err != nil {
|
||||
logger.Critical(ctx, "v1RespLoop failed to decode resp", slog.F("resp", resp), slog.Error(err))
|
||||
_ = q.CoordinatorClose()
|
||||
return
|
||||
}
|
||||
// don't send empty updates
|
||||
if len(nodes) == 0 {
|
||||
logger.Debug(ctx, "v1RespLoop skipping enqueueing 0-length v1 update")
|
||||
continue
|
||||
}
|
||||
err = q.Enqueue(nodes)
|
||||
err = q.Enqueue(resp)
|
||||
if err != nil && !xerrors.Is(err, context.Canceled) {
|
||||
logger.Error(ctx, "v1RespLoop failed to enqueue v1 update", slog.Error(err))
|
||||
}
|
||||
|
|
|
@ -8,13 +8,15 @@ import (
|
|||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
type MultiAgentConn interface {
|
||||
UpdateSelf(node *Node) error
|
||||
UpdateSelf(node *proto.Node) error
|
||||
SubscribeAgent(agentID uuid.UUID) error
|
||||
UnsubscribeAgent(agentID uuid.UUID) error
|
||||
NextUpdate(ctx context.Context) ([]*Node, bool)
|
||||
NextUpdate(ctx context.Context) (*proto.CoordinateResponse, bool)
|
||||
AgentIsLegacy(agentID uuid.UUID) bool
|
||||
Close() error
|
||||
IsClosed() bool
|
||||
|
@ -26,16 +28,16 @@ type MultiAgent struct {
|
|||
ID uuid.UUID
|
||||
|
||||
AgentIsLegacyFunc func(agentID uuid.UUID) bool
|
||||
OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error)
|
||||
OnSubscribe func(enq Queue, agent uuid.UUID) error
|
||||
OnUnsubscribe func(enq Queue, agent uuid.UUID) error
|
||||
OnNodeUpdate func(id uuid.UUID, node *Node) error
|
||||
OnNodeUpdate func(id uuid.UUID, node *proto.Node) error
|
||||
OnRemove func(enq Queue)
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel func()
|
||||
closed bool
|
||||
|
||||
updates chan []*Node
|
||||
updates chan *proto.CoordinateResponse
|
||||
closeOnce sync.Once
|
||||
start int64
|
||||
lastWrite int64
|
||||
|
@ -45,7 +47,7 @@ type MultiAgent struct {
|
|||
}
|
||||
|
||||
func (m *MultiAgent) Init() *MultiAgent {
|
||||
m.updates = make(chan []*Node, 128)
|
||||
m.updates = make(chan *proto.CoordinateResponse, 128)
|
||||
m.start = time.Now().Unix()
|
||||
m.ctx, m.ctxCancel = context.WithCancel(context.Background())
|
||||
return m
|
||||
|
@ -65,7 +67,7 @@ func (m *MultiAgent) AgentIsLegacy(agentID uuid.UUID) bool {
|
|||
|
||||
var ErrMultiAgentClosed = xerrors.New("multiagent is closed")
|
||||
|
||||
func (m *MultiAgent) UpdateSelf(node *Node) error {
|
||||
func (m *MultiAgent) UpdateSelf(node *proto.Node) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
if m.closed {
|
||||
|
@ -82,15 +84,11 @@ func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) error {
|
|||
return ErrMultiAgentClosed
|
||||
}
|
||||
|
||||
node, err := m.OnSubscribe(m, agentID)
|
||||
err := m.OnSubscribe(m, agentID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if node != nil {
|
||||
return m.enqueueLocked([]*Node{node})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -104,17 +102,17 @@ func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) error {
|
|||
return m.OnUnsubscribe(m, agentID)
|
||||
}
|
||||
|
||||
func (m *MultiAgent) NextUpdate(ctx context.Context) ([]*Node, bool) {
|
||||
func (m *MultiAgent) NextUpdate(ctx context.Context) (*proto.CoordinateResponse, bool) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, false
|
||||
|
||||
case nodes, ok := <-m.updates:
|
||||
return nodes, ok
|
||||
case resp, ok := <-m.updates:
|
||||
return resp, ok
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Enqueue(nodes []*Node) error {
|
||||
func (m *MultiAgent) Enqueue(resp *proto.CoordinateResponse) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
|
@ -122,14 +120,14 @@ func (m *MultiAgent) Enqueue(nodes []*Node) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
return m.enqueueLocked(nodes)
|
||||
return m.enqueueLocked(resp)
|
||||
}
|
||||
|
||||
func (m *MultiAgent) enqueueLocked(nodes []*Node) error {
|
||||
func (m *MultiAgent) enqueueLocked(resp *proto.CoordinateResponse) error {
|
||||
atomic.StoreInt64(&m.lastWrite, time.Now().Unix())
|
||||
|
||||
select {
|
||||
case m.updates <- nodes:
|
||||
case m.updates <- resp:
|
||||
return nil
|
||||
default:
|
||||
return ErrWouldBlock
|
||||
|
|
|
@ -75,7 +75,9 @@ func NewClientService(
|
|||
}
|
||||
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
|
||||
Log: func(err error) {
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
if xerrors.Is(err, io.EOF) ||
|
||||
xerrors.Is(err, context.Canceled) ||
|
||||
xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return
|
||||
}
|
||||
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/coder/coder/v2/tailnet (interfaces: Coordinator)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator
|
||||
//
|
||||
|
||||
// Package tailnettest is a generated GoMock package.
|
||||
package tailnettest
|
||||
|
||||
import (
|
||||
context "context"
|
||||
net "net"
|
||||
http "net/http"
|
||||
reflect "reflect"
|
||||
|
||||
tailnet "github.com/coder/coder/v2/tailnet"
|
||||
proto "github.com/coder/coder/v2/tailnet/proto"
|
||||
uuid "github.com/google/uuid"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockCoordinator is a mock of Coordinator interface.
|
||||
type MockCoordinator struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockCoordinatorMockRecorder
|
||||
}
|
||||
|
||||
// MockCoordinatorMockRecorder is the mock recorder for MockCoordinator.
|
||||
type MockCoordinatorMockRecorder struct {
|
||||
mock *MockCoordinator
|
||||
}
|
||||
|
||||
// NewMockCoordinator creates a new mock instance.
|
||||
func NewMockCoordinator(ctrl *gomock.Controller) *MockCoordinator {
|
||||
mock := &MockCoordinator{ctrl: ctrl}
|
||||
mock.recorder = &MockCoordinatorMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockCoordinator) EXPECT() *MockCoordinatorMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockCoordinator) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockCoordinatorMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCoordinator)(nil).Close))
|
||||
}
|
||||
|
||||
// Coordinate mocks base method.
|
||||
func (m *MockCoordinator) Coordinate(arg0 context.Context, arg1 uuid.UUID, arg2 string, arg3 tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Coordinate", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(chan<- *proto.CoordinateRequest)
|
||||
ret1, _ := ret[1].(<-chan *proto.CoordinateResponse)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Coordinate indicates an expected call of Coordinate.
|
||||
func (mr *MockCoordinatorMockRecorder) Coordinate(arg0, arg1, arg2, arg3 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Coordinate", reflect.TypeOf((*MockCoordinator)(nil).Coordinate), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// Node mocks base method.
|
||||
func (m *MockCoordinator) Node(arg0 uuid.UUID) *tailnet.Node {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Node", arg0)
|
||||
ret0, _ := ret[0].(*tailnet.Node)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Node indicates an expected call of Node.
|
||||
func (mr *MockCoordinatorMockRecorder) Node(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Node", reflect.TypeOf((*MockCoordinator)(nil).Node), arg0)
|
||||
}
|
||||
|
||||
// ServeAgent mocks base method.
|
||||
func (m *MockCoordinator) ServeAgent(arg0 net.Conn, arg1 uuid.UUID, arg2 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ServeAgent", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ServeAgent indicates an expected call of ServeAgent.
|
||||
func (mr *MockCoordinatorMockRecorder) ServeAgent(arg0, arg1, arg2 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeAgent), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// ServeClient mocks base method.
|
||||
func (m *MockCoordinator) ServeClient(arg0 net.Conn, arg1, arg2 uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ServeClient", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ServeClient indicates an expected call of ServeClient.
|
||||
func (mr *MockCoordinatorMockRecorder) ServeClient(arg0, arg1, arg2 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeClient", reflect.TypeOf((*MockCoordinator)(nil).ServeClient), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// ServeHTTPDebug mocks base method.
|
||||
func (m *MockCoordinator) ServeHTTPDebug(arg0 http.ResponseWriter, arg1 *http.Request) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "ServeHTTPDebug", arg0, arg1)
|
||||
}
|
||||
|
||||
// ServeHTTPDebug indicates an expected call of ServeHTTPDebug.
|
||||
func (mr *MockCoordinatorMockRecorder) ServeHTTPDebug(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeHTTPDebug", reflect.TypeOf((*MockCoordinator)(nil).ServeHTTPDebug), arg0, arg1)
|
||||
}
|
||||
|
||||
// ServeMultiAgent mocks base method.
|
||||
func (m *MockCoordinator) ServeMultiAgent(arg0 uuid.UUID) tailnet.MultiAgentConn {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ServeMultiAgent", arg0)
|
||||
ret0, _ := ret[0].(tailnet.MultiAgentConn)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ServeMultiAgent indicates an expected call of ServeMultiAgent.
|
||||
func (mr *MockCoordinatorMockRecorder) ServeMultiAgent(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeMultiAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeMultiAgent), arg0)
|
||||
}
|
|
@ -1,141 +0,0 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/coder/coder/v2/tailnet (interfaces: MultiAgentConn)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn
|
||||
//
|
||||
|
||||
// Package tailnettest is a generated GoMock package.
|
||||
package tailnettest
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
tailnet "github.com/coder/coder/v2/tailnet"
|
||||
uuid "github.com/google/uuid"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockMultiAgentConn is a mock of MultiAgentConn interface.
|
||||
type MockMultiAgentConn struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockMultiAgentConnMockRecorder
|
||||
}
|
||||
|
||||
// MockMultiAgentConnMockRecorder is the mock recorder for MockMultiAgentConn.
|
||||
type MockMultiAgentConnMockRecorder struct {
|
||||
mock *MockMultiAgentConn
|
||||
}
|
||||
|
||||
// NewMockMultiAgentConn creates a new mock instance.
|
||||
func NewMockMultiAgentConn(ctrl *gomock.Controller) *MockMultiAgentConn {
|
||||
mock := &MockMultiAgentConn{ctrl: ctrl}
|
||||
mock.recorder = &MockMultiAgentConnMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockMultiAgentConn) EXPECT() *MockMultiAgentConnMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AgentIsLegacy mocks base method.
|
||||
func (m *MockMultiAgentConn) AgentIsLegacy(arg0 uuid.UUID) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AgentIsLegacy", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AgentIsLegacy indicates an expected call of AgentIsLegacy.
|
||||
func (mr *MockMultiAgentConnMockRecorder) AgentIsLegacy(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AgentIsLegacy", reflect.TypeOf((*MockMultiAgentConn)(nil).AgentIsLegacy), arg0)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockMultiAgentConn) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockMultiAgentConnMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMultiAgentConn)(nil).Close))
|
||||
}
|
||||
|
||||
// IsClosed mocks base method.
|
||||
func (m *MockMultiAgentConn) IsClosed() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsClosed")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IsClosed indicates an expected call of IsClosed.
|
||||
func (mr *MockMultiAgentConnMockRecorder) IsClosed() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockMultiAgentConn)(nil).IsClosed))
|
||||
}
|
||||
|
||||
// NextUpdate mocks base method.
|
||||
func (m *MockMultiAgentConn) NextUpdate(arg0 context.Context) ([]*tailnet.Node, bool) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "NextUpdate", arg0)
|
||||
ret0, _ := ret[0].([]*tailnet.Node)
|
||||
ret1, _ := ret[1].(bool)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// NextUpdate indicates an expected call of NextUpdate.
|
||||
func (mr *MockMultiAgentConnMockRecorder) NextUpdate(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextUpdate", reflect.TypeOf((*MockMultiAgentConn)(nil).NextUpdate), arg0)
|
||||
}
|
||||
|
||||
// SubscribeAgent mocks base method.
|
||||
func (m *MockMultiAgentConn) SubscribeAgent(arg0 uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SubscribeAgent", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SubscribeAgent indicates an expected call of SubscribeAgent.
|
||||
func (mr *MockMultiAgentConnMockRecorder) SubscribeAgent(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).SubscribeAgent), arg0)
|
||||
}
|
||||
|
||||
// UnsubscribeAgent mocks base method.
|
||||
func (m *MockMultiAgentConn) UnsubscribeAgent(arg0 uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnsubscribeAgent", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UnsubscribeAgent indicates an expected call of UnsubscribeAgent.
|
||||
func (mr *MockMultiAgentConnMockRecorder) UnsubscribeAgent(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).UnsubscribeAgent), arg0)
|
||||
}
|
||||
|
||||
// UpdateSelf mocks base method.
|
||||
func (m *MockMultiAgentConn) UpdateSelf(arg0 *tailnet.Node) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateSelf", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateSelf indicates an expected call of UpdateSelf.
|
||||
func (mr *MockMultiAgentConnMockRecorder) UpdateSelf(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSelf", reflect.TypeOf((*MockMultiAgentConn)(nil).UpdateSelf), arg0)
|
||||
}
|
|
@ -21,7 +21,7 @@ import (
|
|||
"github.com/coder/coder/v2/tailnet"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn
|
||||
//go:generate mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator
|
||||
|
||||
// RunDERPAndSTUN creates a DERP mapping for tests.
|
||||
func RunDERPAndSTUN(t *testing.T) (*tailcfg.DERPMap, *derp.Server) {
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -29,7 +30,7 @@ type TrackedConn struct {
|
|||
cancel func()
|
||||
kind QueueKind
|
||||
conn net.Conn
|
||||
updates chan []*Node
|
||||
updates chan *proto.CoordinateResponse
|
||||
logger slog.Logger
|
||||
lastData []byte
|
||||
|
||||
|
@ -55,7 +56,7 @@ func NewTrackedConn(ctx context.Context, cancel func(),
|
|||
// coordinator mutex while queuing. Node updates don't
|
||||
// come quickly, so 512 should be plenty for all but
|
||||
// the most pathological cases.
|
||||
updates := make(chan []*Node, ResponseBufferSize)
|
||||
updates := make(chan *proto.CoordinateResponse, ResponseBufferSize)
|
||||
now := time.Now().Unix()
|
||||
return &TrackedConn{
|
||||
ctx: ctx,
|
||||
|
@ -72,10 +73,10 @@ func NewTrackedConn(ctx context.Context, cancel func(),
|
|||
}
|
||||
}
|
||||
|
||||
func (t *TrackedConn) Enqueue(n []*Node) (err error) {
|
||||
func (t *TrackedConn) Enqueue(resp *proto.CoordinateResponse) (err error) {
|
||||
atomic.StoreInt64(&t.lastWrite, time.Now().Unix())
|
||||
select {
|
||||
case t.updates <- n:
|
||||
case t.updates <- resp:
|
||||
return nil
|
||||
default:
|
||||
return ErrWouldBlock
|
||||
|
@ -124,7 +125,16 @@ func (t *TrackedConn) SendUpdates() {
|
|||
case <-t.ctx.Done():
|
||||
t.logger.Debug(t.ctx, "done sending updates")
|
||||
return
|
||||
case nodes := <-t.updates:
|
||||
case resp := <-t.updates:
|
||||
nodes, err := OnlyNodeUpdates(resp)
|
||||
if err != nil {
|
||||
t.logger.Critical(t.ctx, "unable to parse response", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
t.logger.Debug(t.ctx, "skipping response with no nodes")
|
||||
continue
|
||||
}
|
||||
data, err := json.Marshal(nodes)
|
||||
if err != nil {
|
||||
t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes))
|
||||
|
|
Loading…
Reference in New Issue