mirror of https://github.com/coder/coder.git
feat: set peers lost when disconnected from coordinator (#11681)
Adds support to Coordination to call SetAllPeersLost() when it is closed. This ensure that when we disconnect from a Coordinator, we set all peers lost. This covers CoderSDK (CLI client) and Agent. Next PR will cover MultiAgent (notably, `wsproxy`).
This commit is contained in:
parent
9f6b38ce9c
commit
3d85cdfa11
|
@ -356,6 +356,11 @@ func (c *Conn) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAllPeersLost marks all peers lost; typically used when we disconnect from a coordinator.
|
||||||
|
func (c *Conn) SetAllPeersLost() {
|
||||||
|
c.configMaps.setAllPeersLost()
|
||||||
|
}
|
||||||
|
|
||||||
// NodeAddresses returns the addresses of a node from the NetworkMap.
|
// NodeAddresses returns the addresses of a node from the NetworkMap.
|
||||||
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
|
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
|
||||||
return c.configMaps.nodeAddresses(publicKey)
|
return c.configMaps.nodeAddresses(publicKey)
|
||||||
|
|
|
@ -97,6 +97,7 @@ type Node struct {
|
||||||
// Conn.
|
// Conn.
|
||||||
type Coordinatee interface {
|
type Coordinatee interface {
|
||||||
UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error
|
UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error
|
||||||
|
SetAllPeersLost()
|
||||||
SetNodeCallback(func(*Node))
|
SetNodeCallback(func(*Node))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,20 +108,28 @@ type Coordination interface {
|
||||||
|
|
||||||
type remoteCoordination struct {
|
type remoteCoordination struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
closed bool
|
closed bool
|
||||||
errChan chan error
|
errChan chan error
|
||||||
coordinatee Coordinatee
|
coordinatee Coordinatee
|
||||||
logger slog.Logger
|
logger slog.Logger
|
||||||
protocol proto.DRPCTailnet_CoordinateClient
|
protocol proto.DRPCTailnet_CoordinateClient
|
||||||
|
respLoopDone chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *remoteCoordination) Close() error {
|
func (c *remoteCoordination) Close() (retErr error) {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
if c.closed {
|
if c.closed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
c.closed = true
|
c.closed = true
|
||||||
|
defer func() {
|
||||||
|
protoErr := c.protocol.Close()
|
||||||
|
<-c.respLoopDone
|
||||||
|
if retErr == nil {
|
||||||
|
retErr = protoErr
|
||||||
|
}
|
||||||
|
}()
|
||||||
err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
|
err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("send disconnect: %w", err)
|
return xerrors.Errorf("send disconnect: %w", err)
|
||||||
|
@ -140,6 +149,10 @@ func (c *remoteCoordination) sendErr(err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *remoteCoordination) respLoop() {
|
func (c *remoteCoordination) respLoop() {
|
||||||
|
defer func() {
|
||||||
|
c.coordinatee.SetAllPeersLost()
|
||||||
|
close(c.respLoopDone)
|
||||||
|
}()
|
||||||
for {
|
for {
|
||||||
resp, err := c.protocol.Recv()
|
resp, err := c.protocol.Recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -162,10 +175,11 @@ func NewRemoteCoordination(logger slog.Logger,
|
||||||
tunnelTarget uuid.UUID,
|
tunnelTarget uuid.UUID,
|
||||||
) Coordination {
|
) Coordination {
|
||||||
c := &remoteCoordination{
|
c := &remoteCoordination{
|
||||||
errChan: make(chan error, 1),
|
errChan: make(chan error, 1),
|
||||||
coordinatee: coordinatee,
|
coordinatee: coordinatee,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
protocol: protocol,
|
protocol: protocol,
|
||||||
|
respLoopDone: make(chan struct{}),
|
||||||
}
|
}
|
||||||
if tunnelTarget != uuid.Nil {
|
if tunnelTarget != uuid.Nil {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
|
@ -200,14 +214,15 @@ func NewRemoteCoordination(logger slog.Logger,
|
||||||
|
|
||||||
type inMemoryCoordination struct {
|
type inMemoryCoordination struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
errChan chan error
|
errChan chan error
|
||||||
closed bool
|
closed bool
|
||||||
closedCh chan struct{}
|
closedCh chan struct{}
|
||||||
coordinatee Coordinatee
|
respLoopDone chan struct{}
|
||||||
logger slog.Logger
|
coordinatee Coordinatee
|
||||||
resps <-chan *proto.CoordinateResponse
|
logger slog.Logger
|
||||||
reqs chan<- *proto.CoordinateRequest
|
resps <-chan *proto.CoordinateResponse
|
||||||
|
reqs chan<- *proto.CoordinateRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *inMemoryCoordination) sendErr(err error) {
|
func (c *inMemoryCoordination) sendErr(err error) {
|
||||||
|
@ -238,11 +253,12 @@ func NewInMemoryCoordination(
|
||||||
thisID = clientID
|
thisID = clientID
|
||||||
}
|
}
|
||||||
c := &inMemoryCoordination{
|
c := &inMemoryCoordination{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
errChan: make(chan error, 1),
|
errChan: make(chan error, 1),
|
||||||
coordinatee: coordinatee,
|
coordinatee: coordinatee,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
closedCh: make(chan struct{}),
|
closedCh: make(chan struct{}),
|
||||||
|
respLoopDone: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
// use the background context since we will depend exclusively on closing the req channel to
|
// use the background context since we will depend exclusively on closing the req channel to
|
||||||
|
@ -285,6 +301,10 @@ func NewInMemoryCoordination(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *inMemoryCoordination) respLoop() {
|
func (c *inMemoryCoordination) respLoop() {
|
||||||
|
defer func() {
|
||||||
|
c.coordinatee.SetAllPeersLost()
|
||||||
|
close(c.respLoopDone)
|
||||||
|
}()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-c.closedCh:
|
case <-c.closedCh:
|
||||||
|
@ -315,6 +335,7 @@ func (c *inMemoryCoordination) Close() error {
|
||||||
defer close(c.reqs)
|
defer close(c.reqs)
|
||||||
c.closed = true
|
c.closed = true
|
||||||
close(c.closedCh)
|
close(c.closedCh)
|
||||||
|
<-c.respLoopDone
|
||||||
select {
|
select {
|
||||||
case <-c.ctx.Done():
|
case <-c.ctx.Done():
|
||||||
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
|
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
|
||||||
|
|
|
@ -6,19 +6,24 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"nhooyr.io/websocket"
|
|
||||||
|
|
||||||
"cdr.dev/slog"
|
|
||||||
"cdr.dev/slog/sloggers/slogtest"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/mock/gomock"
|
||||||
|
"nhooyr.io/websocket"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
"github.com/coder/coder/v2/tailnet"
|
"github.com/coder/coder/v2/tailnet"
|
||||||
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
|
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||||
"github.com/coder/coder/v2/tailnet/test"
|
"github.com/coder/coder/v2/tailnet/test"
|
||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
)
|
)
|
||||||
|
@ -400,3 +405,160 @@ func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server n
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
return client, server
|
return client, server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInMemoryCoordination(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
|
clientID := uuid.UUID{1}
|
||||||
|
agentID := uuid.UUID{2}
|
||||||
|
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
|
||||||
|
fConn := &fakeCoordinatee{}
|
||||||
|
|
||||||
|
reqs := make(chan *proto.CoordinateRequest, 100)
|
||||||
|
resps := make(chan *proto.CoordinateResponse, 100)
|
||||||
|
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
|
||||||
|
Times(1).Return(reqs, resps)
|
||||||
|
|
||||||
|
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
|
||||||
|
defer uut.Close()
|
||||||
|
|
||||||
|
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-uut.Error():
|
||||||
|
require.NoError(t, err)
|
||||||
|
default:
|
||||||
|
// OK!
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemoteCoordination(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
|
clientID := uuid.UUID{1}
|
||||||
|
agentID := uuid.UUID{2}
|
||||||
|
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
|
||||||
|
fConn := &fakeCoordinatee{}
|
||||||
|
|
||||||
|
reqs := make(chan *proto.CoordinateRequest, 100)
|
||||||
|
resps := make(chan *proto.CoordinateResponse, 100)
|
||||||
|
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
|
||||||
|
Times(1).Return(reqs, resps)
|
||||||
|
|
||||||
|
var coord tailnet.Coordinator = mCoord
|
||||||
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||||
|
coordPtr.Store(&coord)
|
||||||
|
svc, err := tailnet.NewClientService(
|
||||||
|
logger.Named("svc"), &coordPtr,
|
||||||
|
time.Hour,
|
||||||
|
func() *tailcfg.DERPMap { panic("not implemented") },
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
sC, cC := net.Pipe()
|
||||||
|
|
||||||
|
serveErr := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID)
|
||||||
|
serveErr <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
client, err := tailnet.NewDRPCClient(cC)
|
||||||
|
require.NoError(t, err)
|
||||||
|
protocol, err := client.Coordinate(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
|
||||||
|
defer uut.Close()
|
||||||
|
|
||||||
|
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-uut.Error():
|
||||||
|
require.ErrorContains(t, err, "stream terminated by sending close")
|
||||||
|
default:
|
||||||
|
// OK!
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// coordinationTest tests that a coordination behaves correctly
|
||||||
|
func coordinationTest(
|
||||||
|
ctx context.Context, t *testing.T,
|
||||||
|
uut tailnet.Coordination, fConn *fakeCoordinatee,
|
||||||
|
reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
|
||||||
|
agentID uuid.UUID,
|
||||||
|
) {
|
||||||
|
// It should add the tunnel, since we configured as a client
|
||||||
|
req := testutil.RequireRecvCtx(ctx, t, reqs)
|
||||||
|
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
|
||||||
|
|
||||||
|
// when we call the callback, it should send a node update
|
||||||
|
require.NotNil(t, fConn.callback)
|
||||||
|
fConn.callback(&tailnet.Node{PreferredDERP: 1})
|
||||||
|
|
||||||
|
req = testutil.RequireRecvCtx(ctx, t, reqs)
|
||||||
|
require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())
|
||||||
|
|
||||||
|
// When we send a peer update, it should update the coordinatee
|
||||||
|
nk, err := key.NewNode().Public().MarshalBinary()
|
||||||
|
require.NoError(t, err)
|
||||||
|
dk, err := key.NewDisco().Public().MarshalText()
|
||||||
|
require.NoError(t, err)
|
||||||
|
updates := []*proto.CoordinateResponse_PeerUpdate{
|
||||||
|
{
|
||||||
|
Id: agentID[:],
|
||||||
|
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
||||||
|
Node: &proto.Node{
|
||||||
|
Id: 2,
|
||||||
|
Key: nk,
|
||||||
|
Disco: string(dk),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
fConn.Lock()
|
||||||
|
defer fConn.Unlock()
|
||||||
|
return len(fConn.updates) > 0
|
||||||
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
|
require.Len(t, fConn.updates[0], 1)
|
||||||
|
require.Equal(t, agentID[:], fConn.updates[0][0].Id)
|
||||||
|
|
||||||
|
err = uut.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
uut.Error()
|
||||||
|
|
||||||
|
// When we close, it should gracefully disconnect
|
||||||
|
req = testutil.RequireRecvCtx(ctx, t, reqs)
|
||||||
|
require.NotNil(t, req.Disconnect)
|
||||||
|
|
||||||
|
// It should set all peers lost on the coordinatee
|
||||||
|
require.Equal(t, 1, fConn.setAllPeersLostCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeCoordinatee struct {
|
||||||
|
sync.Mutex
|
||||||
|
callback func(*tailnet.Node)
|
||||||
|
updates [][]*proto.CoordinateResponse_PeerUpdate
|
||||||
|
setAllPeersLostCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
|
||||||
|
f.Lock()
|
||||||
|
defer f.Unlock()
|
||||||
|
f.updates = append(f.updates, updates)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeCoordinatee) SetAllPeersLost() {
|
||||||
|
f.Lock()
|
||||||
|
defer f.Unlock()
|
||||||
|
f.setAllPeersLostCalls++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
|
||||||
|
f.Lock()
|
||||||
|
defer f.Unlock()
|
||||||
|
f.callback = callback
|
||||||
|
}
|
||||||
|
|
|
@ -22,3 +22,13 @@ func RequireRecvCtx[A any](ctx context.Context, t testing.TB, c <-chan A) (a A)
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RequireSendCtx[A any](ctx context.Context, t testing.TB, c chan<- A, a A) {
|
||||||
|
t.Helper()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Fatal("timeout")
|
||||||
|
case c <- a:
|
||||||
|
// OK!
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue