From 3d85cdfa11e3f3b7fff119d0246c480ed4f53989 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Mon, 22 Jan 2024 15:26:20 +0400 Subject: [PATCH] feat: set peers lost when disconnected from coordinator (#11681) Adds support to Coordination to call SetAllPeersLost() when it is closed. This ensures that when we disconnect from a Coordinator, we set all peers lost. This covers CoderSDK (CLI client) and Agent. Next PR will cover MultiAgent (notably, `wsproxy`). --- tailnet/conn.go | 5 ++ tailnet/coordinator.go | 67 +++++++++----- tailnet/coordinator_test.go | 172 ++++++++++++++++++++++++++++++++++-- testutil/ctx.go | 10 +++ 4 files changed, 226 insertions(+), 28 deletions(-) diff --git a/tailnet/conn.go b/tailnet/conn.go index 0b4b942f8f..4048567946 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -356,6 +356,11 @@ func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error return nil } +// SetAllPeersLost marks all peers lost; typically used when we disconnect from a coordinator. +func (c *Conn) SetAllPeersLost() { + c.configMaps.setAllPeersLost() +} + // NodeAddresses returns the addresses of a node from the NetworkMap. func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) { return c.configMaps.nodeAddresses(publicKey) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 0fa62fc922..3c4b1aeb24 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -97,6 +97,7 @@ type Node struct { // Conn. 
type Coordinatee interface { UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error + SetAllPeersLost() SetNodeCallback(func(*Node)) } @@ -107,20 +108,28 @@ type Coordination interface { type remoteCoordination struct { sync.Mutex - closed bool - errChan chan error - coordinatee Coordinatee - logger slog.Logger - protocol proto.DRPCTailnet_CoordinateClient + closed bool + errChan chan error + coordinatee Coordinatee + logger slog.Logger + protocol proto.DRPCTailnet_CoordinateClient + respLoopDone chan struct{} } -func (c *remoteCoordination) Close() error { +func (c *remoteCoordination) Close() (retErr error) { c.Lock() defer c.Unlock() if c.closed { return nil } c.closed = true + defer func() { + protoErr := c.protocol.Close() + <-c.respLoopDone + if retErr == nil { + retErr = protoErr + } + }() err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}) if err != nil { return xerrors.Errorf("send disconnect: %w", err) @@ -140,6 +149,10 @@ func (c *remoteCoordination) sendErr(err error) { } func (c *remoteCoordination) respLoop() { + defer func() { + c.coordinatee.SetAllPeersLost() + close(c.respLoopDone) + }() for { resp, err := c.protocol.Recv() if err != nil { @@ -162,10 +175,11 @@ func NewRemoteCoordination(logger slog.Logger, tunnelTarget uuid.UUID, ) Coordination { c := &remoteCoordination{ - errChan: make(chan error, 1), - coordinatee: coordinatee, - logger: logger, - protocol: protocol, + errChan: make(chan error, 1), + coordinatee: coordinatee, + logger: logger, + protocol: protocol, + respLoopDone: make(chan struct{}), } if tunnelTarget != uuid.Nil { c.Lock() @@ -200,14 +214,15 @@ func NewRemoteCoordination(logger slog.Logger, type inMemoryCoordination struct { sync.Mutex - ctx context.Context - errChan chan error - closed bool - closedCh chan struct{} - coordinatee Coordinatee - logger slog.Logger - resps <-chan *proto.CoordinateResponse - reqs chan<- *proto.CoordinateRequest + ctx context.Context + errChan 
chan error + closed bool + closedCh chan struct{} + respLoopDone chan struct{} + coordinatee Coordinatee + logger slog.Logger + resps <-chan *proto.CoordinateResponse + reqs chan<- *proto.CoordinateRequest } func (c *inMemoryCoordination) sendErr(err error) { @@ -238,11 +253,12 @@ func NewInMemoryCoordination( thisID = clientID } c := &inMemoryCoordination{ - ctx: ctx, - errChan: make(chan error, 1), - coordinatee: coordinatee, - logger: logger, - closedCh: make(chan struct{}), + ctx: ctx, + errChan: make(chan error, 1), + coordinatee: coordinatee, + logger: logger, + closedCh: make(chan struct{}), + respLoopDone: make(chan struct{}), } // use the background context since we will depend exclusively on closing the req channel to @@ -285,6 +301,10 @@ func NewInMemoryCoordination( } func (c *inMemoryCoordination) respLoop() { + defer func() { + c.coordinatee.SetAllPeersLost() + close(c.respLoopDone) + }() for { select { case <-c.closedCh: @@ -315,6 +335,7 @@ func (c *inMemoryCoordination) Close() error { defer close(c.reqs) c.closed = true close(c.closedCh) + <-c.respLoopDone select { case <-c.ctx.Done(): return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err()) diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index 1aad59e5a2..7207f93d78 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -6,19 +6,24 @@ import ( "net" "net/http" "net/http/httptest" + "sync" + "sync/atomic" "testing" "time" - "nhooyr.io/websocket" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "nhooyr.io/websocket" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" + "github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/tailnet/test" 
"github.com/coder/coder/v2/testutil" ) @@ -400,3 +405,160 @@ func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server n require.True(t, ok) return client, server } + +func TestInMemoryCoordination(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + clientID := uuid.UUID{1} + agentID := uuid.UUID{2} + mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t)) + fConn := &fakeCoordinatee{} + + reqs := make(chan *proto.CoordinateRequest, 100) + resps := make(chan *proto.CoordinateResponse, 100) + mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}). + Times(1).Return(reqs, resps) + + uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn) + defer uut.Close() + + coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID) + + select { + case err := <-uut.Error(): + require.NoError(t, err) + default: + // OK! + } +} + +func TestRemoteCoordination(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + clientID := uuid.UUID{1} + agentID := uuid.UUID{2} + mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t)) + fConn := &fakeCoordinatee{} + + reqs := make(chan *proto.CoordinateRequest, 100) + resps := make(chan *proto.CoordinateResponse, 100) + mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}). 
+ Times(1).Return(reqs, resps) + + var coord tailnet.Coordinator = mCoord + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + svc, err := tailnet.NewClientService( + logger.Named("svc"), &coordPtr, + time.Hour, + func() *tailcfg.DERPMap { panic("not implemented") }, + ) + require.NoError(t, err) + sC, cC := net.Pipe() + + serveErr := make(chan error, 1) + go func() { + err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID) + serveErr <- err + }() + + client, err := tailnet.NewDRPCClient(cC) + require.NoError(t, err) + protocol, err := client.Coordinate(ctx) + require.NoError(t, err) + + uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID) + defer uut.Close() + + coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID) + + select { + case err := <-uut.Error(): + require.ErrorContains(t, err, "stream terminated by sending close") + default: + // OK! + } +} + +// coordinationTest tests that a coordination behaves correctly +func coordinationTest( + ctx context.Context, t *testing.T, + uut tailnet.Coordination, fConn *fakeCoordinatee, + reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse, + agentID uuid.UUID, +) { + // It should add the tunnel, since we configured as a client + req := testutil.RequireRecvCtx(ctx, t, reqs) + require.Equal(t, agentID[:], req.GetAddTunnel().GetId()) + + // when we call the callback, it should send a node update + require.NotNil(t, fConn.callback) + fConn.callback(&tailnet.Node{PreferredDERP: 1}) + + req = testutil.RequireRecvCtx(ctx, t, reqs) + require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp()) + + // When we send a peer update, it should update the coordinatee + nk, err := key.NewNode().Public().MarshalBinary() + require.NoError(t, err) + dk, err := key.NewDisco().Public().MarshalText() + require.NoError(t, err) + updates := []*proto.CoordinateResponse_PeerUpdate{ + { + Id: agentID[:], + Kind: 
proto.CoordinateResponse_PeerUpdate_NODE, + Node: &proto.Node{ + Id: 2, + Key: nk, + Disco: string(dk), + }, + }, + } + testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates}) + require.Eventually(t, func() bool { + fConn.Lock() + defer fConn.Unlock() + return len(fConn.updates) > 0 + }, testutil.WaitShort, testutil.IntervalFast) + require.Len(t, fConn.updates[0], 1) + require.Equal(t, agentID[:], fConn.updates[0][0].Id) + + err = uut.Close() + require.NoError(t, err) + uut.Error() + + // When we close, it should gracefully disconnect + req = testutil.RequireRecvCtx(ctx, t, reqs) + require.NotNil(t, req.Disconnect) + + // It should set all peers lost on the coordinatee + require.Equal(t, 1, fConn.setAllPeersLostCalls) +} + +type fakeCoordinatee struct { + sync.Mutex + callback func(*tailnet.Node) + updates [][]*proto.CoordinateResponse_PeerUpdate + setAllPeersLostCalls int +} + +func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error { + f.Lock() + defer f.Unlock() + f.updates = append(f.updates, updates) + return nil +} + +func (f *fakeCoordinatee) SetAllPeersLost() { + f.Lock() + defer f.Unlock() + f.setAllPeersLostCalls++ +} + +func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) { + f.Lock() + defer f.Unlock() + f.callback = callback +} diff --git a/testutil/ctx.go b/testutil/ctx.go index 2cc44c5bad..c8f8c1769f 100644 --- a/testutil/ctx.go +++ b/testutil/ctx.go @@ -22,3 +22,13 @@ func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A) return a } } + +func RequireSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) { + t.Helper() + select { + case <-ctx.Done(): + t.Fatal("timeout") + case c <- a: + // OK! + } +}