feat: set peers lost when disconnected from coordinator (#11681)

Adds support to Coordination to call SetAllPeersLost() when it is closed. This ensure that when we disconnect from a Coordinator, we set all peers lost.

This covers CoderSDK (CLI client) and Agent.  Next PR will cover MultiAgent (notably, `wsproxy`).
This commit is contained in:
Spike Curtis 2024-01-22 15:26:20 +04:00 committed by GitHub
parent 9f6b38ce9c
commit 3d85cdfa11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 226 additions and 28 deletions

View File

@ -356,6 +356,11 @@ func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error
return nil
}
// SetAllPeersLost marks all peers lost; typically used when we disconnect from a coordinator.
func (c *Conn) SetAllPeersLost() {
c.configMaps.setAllPeersLost()
}
// NodeAddresses returns the addresses of a node from the NetworkMap.
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
return c.configMaps.nodeAddresses(publicKey)

View File

@ -97,6 +97,7 @@ type Node struct {
// Conn.
type Coordinatee interface {
UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error
SetAllPeersLost()
SetNodeCallback(func(*Node))
}
@ -107,20 +108,28 @@ type Coordination interface {
type remoteCoordination struct {
sync.Mutex
closed bool
errChan chan error
coordinatee Coordinatee
logger slog.Logger
protocol proto.DRPCTailnet_CoordinateClient
closed bool
errChan chan error
coordinatee Coordinatee
logger slog.Logger
protocol proto.DRPCTailnet_CoordinateClient
respLoopDone chan struct{}
}
func (c *remoteCoordination) Close() error {
func (c *remoteCoordination) Close() (retErr error) {
c.Lock()
defer c.Unlock()
if c.closed {
return nil
}
c.closed = true
defer func() {
protoErr := c.protocol.Close()
<-c.respLoopDone
if retErr == nil {
retErr = protoErr
}
}()
err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
if err != nil {
return xerrors.Errorf("send disconnect: %w", err)
@ -140,6 +149,10 @@ func (c *remoteCoordination) sendErr(err error) {
}
func (c *remoteCoordination) respLoop() {
defer func() {
c.coordinatee.SetAllPeersLost()
close(c.respLoopDone)
}()
for {
resp, err := c.protocol.Recv()
if err != nil {
@ -162,10 +175,11 @@ func NewRemoteCoordination(logger slog.Logger,
tunnelTarget uuid.UUID,
) Coordination {
c := &remoteCoordination{
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
protocol: protocol,
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
protocol: protocol,
respLoopDone: make(chan struct{}),
}
if tunnelTarget != uuid.Nil {
c.Lock()
@ -200,14 +214,15 @@ func NewRemoteCoordination(logger slog.Logger,
type inMemoryCoordination struct {
sync.Mutex
ctx context.Context
errChan chan error
closed bool
closedCh chan struct{}
coordinatee Coordinatee
logger slog.Logger
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
ctx context.Context
errChan chan error
closed bool
closedCh chan struct{}
respLoopDone chan struct{}
coordinatee Coordinatee
logger slog.Logger
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
}
func (c *inMemoryCoordination) sendErr(err error) {
@ -238,11 +253,12 @@ func NewInMemoryCoordination(
thisID = clientID
}
c := &inMemoryCoordination{
ctx: ctx,
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
closedCh: make(chan struct{}),
ctx: ctx,
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
closedCh: make(chan struct{}),
respLoopDone: make(chan struct{}),
}
// use the background context since we will depend exclusively on closing the req channel to
@ -285,6 +301,10 @@ func NewInMemoryCoordination(
}
func (c *inMemoryCoordination) respLoop() {
defer func() {
c.coordinatee.SetAllPeersLost()
close(c.respLoopDone)
}()
for {
select {
case <-c.closedCh:
@ -315,6 +335,7 @@ func (c *inMemoryCoordination) Close() error {
defer close(c.reqs)
c.closed = true
close(c.closedCh)
<-c.respLoopDone
select {
case <-c.ctx.Done():
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())

View File

@ -6,19 +6,24 @@ import (
"net"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"nhooyr.io/websocket"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"nhooyr.io/websocket"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/tailnet/test"
"github.com/coder/coder/v2/testutil"
)
@ -400,3 +405,160 @@ func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server n
require.True(t, ok)
return client, server
}
func TestInMemoryCoordination(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
Times(1).Return(reqs, resps)
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
defer uut.Close()
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
select {
case err := <-uut.Error():
require.NoError(t, err)
default:
// OK!
}
}
func TestRemoteCoordination(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
Times(1).Return(reqs, resps)
var coord tailnet.Coordinator = mCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
svc, err := tailnet.NewClientService(
logger.Named("svc"), &coordPtr,
time.Hour,
func() *tailcfg.DERPMap { panic("not implemented") },
)
require.NoError(t, err)
sC, cC := net.Pipe()
serveErr := make(chan error, 1)
go func() {
err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID)
serveErr <- err
}()
client, err := tailnet.NewDRPCClient(cC)
require.NoError(t, err)
protocol, err := client.Coordinate(ctx)
require.NoError(t, err)
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
defer uut.Close()
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
select {
case err := <-uut.Error():
require.ErrorContains(t, err, "stream terminated by sending close")
default:
// OK!
}
}
// coordinationTest tests that a coordination behaves correctly
func coordinationTest(
ctx context.Context, t *testing.T,
uut tailnet.Coordination, fConn *fakeCoordinatee,
reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
agentID uuid.UUID,
) {
// It should add the tunnel, since we configured as a client
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
// when we call the callback, it should send a node update
require.NotNil(t, fConn.callback)
fConn.callback(&tailnet.Node{PreferredDERP: 1})
req = testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())
// When we send a peer update, it should update the coordinatee
nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
updates := []*proto.CoordinateResponse_PeerUpdate{
{
Id: agentID[:],
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
Node: &proto.Node{
Id: 2,
Key: nk,
Disco: string(dk),
},
},
}
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
require.Eventually(t, func() bool {
fConn.Lock()
defer fConn.Unlock()
return len(fConn.updates) > 0
}, testutil.WaitShort, testutil.IntervalFast)
require.Len(t, fConn.updates[0], 1)
require.Equal(t, agentID[:], fConn.updates[0][0].Id)
err = uut.Close()
require.NoError(t, err)
uut.Error()
// When we close, it should gracefully disconnect
req = testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, req.Disconnect)
// It should set all peers lost on the coordinatee
require.Equal(t, 1, fConn.setAllPeersLostCalls)
}
type fakeCoordinatee struct {
sync.Mutex
callback func(*tailnet.Node)
updates [][]*proto.CoordinateResponse_PeerUpdate
setAllPeersLostCalls int
}
func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
f.Lock()
defer f.Unlock()
f.updates = append(f.updates, updates)
return nil
}
func (f *fakeCoordinatee) SetAllPeersLost() {
f.Lock()
defer f.Unlock()
f.setAllPeersLostCalls++
}
func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
f.Lock()
defer f.Unlock()
f.callback = callback
}

View File

@ -22,3 +22,13 @@ func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A)
return a
}
}
func RequireSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) {
t.Helper()
select {
case <-ctx.Done():
t.Fatal("timeout")
case c <- a:
// OK!
}
}