mirror of https://github.com/coder/coder.git
feat: set peers lost when disconnected from coordinator (#11681)
Adds support to Coordination to call SetAllPeersLost() when it is closed. This ensure that when we disconnect from a Coordinator, we set all peers lost. This covers CoderSDK (CLI client) and Agent. Next PR will cover MultiAgent (notably, `wsproxy`).
This commit is contained in:
parent
9f6b38ce9c
commit
3d85cdfa11
|
@ -356,6 +356,11 @@ func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error
|
|||
return nil
|
||||
}
|
||||
|
||||
// SetAllPeersLost marks all peers lost; typically used when we disconnect from a coordinator.
|
||||
func (c *Conn) SetAllPeersLost() {
|
||||
c.configMaps.setAllPeersLost()
|
||||
}
|
||||
|
||||
// NodeAddresses returns the addresses of a node from the NetworkMap.
|
||||
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
|
||||
return c.configMaps.nodeAddresses(publicKey)
|
||||
|
|
|
@ -97,6 +97,7 @@ type Node struct {
|
|||
// Conn.
|
||||
type Coordinatee interface {
|
||||
UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error
|
||||
SetAllPeersLost()
|
||||
SetNodeCallback(func(*Node))
|
||||
}
|
||||
|
||||
|
@ -107,20 +108,28 @@ type Coordination interface {
|
|||
|
||||
type remoteCoordination struct {
|
||||
sync.Mutex
|
||||
closed bool
|
||||
errChan chan error
|
||||
coordinatee Coordinatee
|
||||
logger slog.Logger
|
||||
protocol proto.DRPCTailnet_CoordinateClient
|
||||
closed bool
|
||||
errChan chan error
|
||||
coordinatee Coordinatee
|
||||
logger slog.Logger
|
||||
protocol proto.DRPCTailnet_CoordinateClient
|
||||
respLoopDone chan struct{}
|
||||
}
|
||||
|
||||
func (c *remoteCoordination) Close() error {
|
||||
func (c *remoteCoordination) Close() (retErr error) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
c.closed = true
|
||||
defer func() {
|
||||
protoErr := c.protocol.Close()
|
||||
<-c.respLoopDone
|
||||
if retErr == nil {
|
||||
retErr = protoErr
|
||||
}
|
||||
}()
|
||||
err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("send disconnect: %w", err)
|
||||
|
@ -140,6 +149,10 @@ func (c *remoteCoordination) sendErr(err error) {
|
|||
}
|
||||
|
||||
func (c *remoteCoordination) respLoop() {
|
||||
defer func() {
|
||||
c.coordinatee.SetAllPeersLost()
|
||||
close(c.respLoopDone)
|
||||
}()
|
||||
for {
|
||||
resp, err := c.protocol.Recv()
|
||||
if err != nil {
|
||||
|
@ -162,10 +175,11 @@ func NewRemoteCoordination(logger slog.Logger,
|
|||
tunnelTarget uuid.UUID,
|
||||
) Coordination {
|
||||
c := &remoteCoordination{
|
||||
errChan: make(chan error, 1),
|
||||
coordinatee: coordinatee,
|
||||
logger: logger,
|
||||
protocol: protocol,
|
||||
errChan: make(chan error, 1),
|
||||
coordinatee: coordinatee,
|
||||
logger: logger,
|
||||
protocol: protocol,
|
||||
respLoopDone: make(chan struct{}),
|
||||
}
|
||||
if tunnelTarget != uuid.Nil {
|
||||
c.Lock()
|
||||
|
@ -200,14 +214,15 @@ func NewRemoteCoordination(logger slog.Logger,
|
|||
|
||||
type inMemoryCoordination struct {
|
||||
sync.Mutex
|
||||
ctx context.Context
|
||||
errChan chan error
|
||||
closed bool
|
||||
closedCh chan struct{}
|
||||
coordinatee Coordinatee
|
||||
logger slog.Logger
|
||||
resps <-chan *proto.CoordinateResponse
|
||||
reqs chan<- *proto.CoordinateRequest
|
||||
ctx context.Context
|
||||
errChan chan error
|
||||
closed bool
|
||||
closedCh chan struct{}
|
||||
respLoopDone chan struct{}
|
||||
coordinatee Coordinatee
|
||||
logger slog.Logger
|
||||
resps <-chan *proto.CoordinateResponse
|
||||
reqs chan<- *proto.CoordinateRequest
|
||||
}
|
||||
|
||||
func (c *inMemoryCoordination) sendErr(err error) {
|
||||
|
@ -238,11 +253,12 @@ func NewInMemoryCoordination(
|
|||
thisID = clientID
|
||||
}
|
||||
c := &inMemoryCoordination{
|
||||
ctx: ctx,
|
||||
errChan: make(chan error, 1),
|
||||
coordinatee: coordinatee,
|
||||
logger: logger,
|
||||
closedCh: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
errChan: make(chan error, 1),
|
||||
coordinatee: coordinatee,
|
||||
logger: logger,
|
||||
closedCh: make(chan struct{}),
|
||||
respLoopDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
// use the background context since we will depend exclusively on closing the req channel to
|
||||
|
@ -285,6 +301,10 @@ func NewInMemoryCoordination(
|
|||
}
|
||||
|
||||
func (c *inMemoryCoordination) respLoop() {
|
||||
defer func() {
|
||||
c.coordinatee.SetAllPeersLost()
|
||||
close(c.respLoopDone)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-c.closedCh:
|
||||
|
@ -315,6 +335,7 @@ func (c *inMemoryCoordination) Close() error {
|
|||
defer close(c.reqs)
|
||||
c.closed = true
|
||||
close(c.closedCh)
|
||||
<-c.respLoopDone
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
|
||||
|
|
|
@ -6,19 +6,24 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"nhooyr.io/websocket"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
"github.com/coder/coder/v2/tailnet/test"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
@ -400,3 +405,160 @@ func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server n
|
|||
require.True(t, ok)
|
||||
return client, server
|
||||
}
|
||||
|
||||
func TestInMemoryCoordination(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
clientID := uuid.UUID{1}
|
||||
agentID := uuid.UUID{2}
|
||||
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
|
||||
fConn := &fakeCoordinatee{}
|
||||
|
||||
reqs := make(chan *proto.CoordinateRequest, 100)
|
||||
resps := make(chan *proto.CoordinateResponse, 100)
|
||||
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
|
||||
Times(1).Return(reqs, resps)
|
||||
|
||||
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
|
||||
defer uut.Close()
|
||||
|
||||
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
|
||||
|
||||
select {
|
||||
case err := <-uut.Error():
|
||||
require.NoError(t, err)
|
||||
default:
|
||||
// OK!
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteCoordination(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
clientID := uuid.UUID{1}
|
||||
agentID := uuid.UUID{2}
|
||||
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
|
||||
fConn := &fakeCoordinatee{}
|
||||
|
||||
reqs := make(chan *proto.CoordinateRequest, 100)
|
||||
resps := make(chan *proto.CoordinateResponse, 100)
|
||||
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
|
||||
Times(1).Return(reqs, resps)
|
||||
|
||||
var coord tailnet.Coordinator = mCoord
|
||||
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordPtr.Store(&coord)
|
||||
svc, err := tailnet.NewClientService(
|
||||
logger.Named("svc"), &coordPtr,
|
||||
time.Hour,
|
||||
func() *tailcfg.DERPMap { panic("not implemented") },
|
||||
)
|
||||
require.NoError(t, err)
|
||||
sC, cC := net.Pipe()
|
||||
|
||||
serveErr := make(chan error, 1)
|
||||
go func() {
|
||||
err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID)
|
||||
serveErr <- err
|
||||
}()
|
||||
|
||||
client, err := tailnet.NewDRPCClient(cC)
|
||||
require.NoError(t, err)
|
||||
protocol, err := client.Coordinate(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
|
||||
defer uut.Close()
|
||||
|
||||
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
|
||||
|
||||
select {
|
||||
case err := <-uut.Error():
|
||||
require.ErrorContains(t, err, "stream terminated by sending close")
|
||||
default:
|
||||
// OK!
|
||||
}
|
||||
}
|
||||
|
||||
// coordinationTest tests that a coordination behaves correctly
|
||||
func coordinationTest(
|
||||
ctx context.Context, t *testing.T,
|
||||
uut tailnet.Coordination, fConn *fakeCoordinatee,
|
||||
reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
|
||||
agentID uuid.UUID,
|
||||
) {
|
||||
// It should add the tunnel, since we configured as a client
|
||||
req := testutil.RequireRecvCtx(ctx, t, reqs)
|
||||
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
|
||||
|
||||
// when we call the callback, it should send a node update
|
||||
require.NotNil(t, fConn.callback)
|
||||
fConn.callback(&tailnet.Node{PreferredDERP: 1})
|
||||
|
||||
req = testutil.RequireRecvCtx(ctx, t, reqs)
|
||||
require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())
|
||||
|
||||
// When we send a peer update, it should update the coordinatee
|
||||
nk, err := key.NewNode().Public().MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
dk, err := key.NewDisco().Public().MarshalText()
|
||||
require.NoError(t, err)
|
||||
updates := []*proto.CoordinateResponse_PeerUpdate{
|
||||
{
|
||||
Id: agentID[:],
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
||||
Node: &proto.Node{
|
||||
Id: 2,
|
||||
Key: nk,
|
||||
Disco: string(dk),
|
||||
},
|
||||
},
|
||||
}
|
||||
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
|
||||
require.Eventually(t, func() bool {
|
||||
fConn.Lock()
|
||||
defer fConn.Unlock()
|
||||
return len(fConn.updates) > 0
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
require.Len(t, fConn.updates[0], 1)
|
||||
require.Equal(t, agentID[:], fConn.updates[0][0].Id)
|
||||
|
||||
err = uut.Close()
|
||||
require.NoError(t, err)
|
||||
uut.Error()
|
||||
|
||||
// When we close, it should gracefully disconnect
|
||||
req = testutil.RequireRecvCtx(ctx, t, reqs)
|
||||
require.NotNil(t, req.Disconnect)
|
||||
|
||||
// It should set all peers lost on the coordinatee
|
||||
require.Equal(t, 1, fConn.setAllPeersLostCalls)
|
||||
}
|
||||
|
||||
type fakeCoordinatee struct {
|
||||
sync.Mutex
|
||||
callback func(*tailnet.Node)
|
||||
updates [][]*proto.CoordinateResponse_PeerUpdate
|
||||
setAllPeersLostCalls int
|
||||
}
|
||||
|
||||
func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
f.updates = append(f.updates, updates)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeCoordinatee) SetAllPeersLost() {
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
f.setAllPeersLostCalls++
|
||||
}
|
||||
|
||||
func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
f.callback = callback
|
||||
}
|
||||
|
|
|
@ -22,3 +22,13 @@ func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A)
|
|||
return a
|
||||
}
|
||||
}
|
||||
|
||||
func RequireSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) {
|
||||
t.Helper()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
case c <- a:
|
||||
// OK!
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue