From eb12fd7d92940d2cc1011c475ae14e9c177187c1 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 23 Jan 2024 13:17:56 +0400 Subject: [PATCH] feat: make ServerTailnet set peers lost when it reconnects to the coordinator (#11682) Adds support to `ServerTailnet` to set all peers lost before attempting to reconnect to the coordinator. In practice, this only really affects `wsproxy` since coderd has a local connection to the coordinator that only goes down if we're shutting down or change licenses. --- Makefile | 8 +- coderd/tailnet.go | 15 ++- coderd/tailnet_internal_test.go | 75 +++++++++++++ tailnet/service_test.go | 2 +- tailnet/tailnettest/coordinateemock.go | 79 ++++++++++++++ tailnet/tailnettest/multiagentmock.go | 141 +++++++++++++++++++++++++ tailnet/tailnettest/tailnettest.go | 2 + 7 files changed, 315 insertions(+), 7 deletions(-) create mode 100644 coderd/tailnet_internal_test.go create mode 100644 tailnet/tailnettest/coordinateemock.go create mode 100644 tailnet/tailnettest/multiagentmock.go diff --git a/Makefile b/Makefile index d771fb0233..55f73f672f 100644 --- a/Makefile +++ b/Makefile @@ -476,7 +476,9 @@ gen: \ site/e2e/provisionerGenerated.ts \ site/src/theme/icons.json \ examples/examples.gen.json \ - tailnet/tailnettest/coordinatormock.go + tailnet/tailnettest/coordinatormock.go \ + tailnet/tailnettest/coordinateemock.go \ + tailnet/tailnettest/multiagentmock.go .PHONY: gen # Mark all generated files as fresh so make thinks they're up-to-date. This is @@ -504,6 +506,8 @@ gen/mark-fresh: site/src/theme/icons.json \ examples/examples.gen.json \ tailnet/tailnettest/coordinatormock.go \ + tailnet/tailnettest/coordinateemock.go \ + tailnet/tailnettest/multiagentmock.go \ " for file in $$files; do echo "$$file" @@ -531,7 +535,7 @@ coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $ coderd/database/dbmock/dbmock.go: coderd/database/db.go coderd/database/querier.go go generate ./coderd/database/dbmock/ -tailnet/tailnettest/coordinatormock.go: tailnet/coordinator.go +tailnet/tailnettest/coordinatormock.go tailnet/tailnettest/multiagentmock.go tailnet/tailnettest/coordinateemock.go: tailnet/coordinator.go tailnet/multiagent.go go generate ./tailnet/tailnettest/ tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto diff --git a/coderd/tailnet.go b/coderd/tailnet.go index 5f3300711a..086cd76866 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -95,6 +95,7 @@ func NewServerTailnet( logger: logger, tracer: traceProvider.Tracer(tracing.TracerName), conn: conn, + coordinatee: conn, getMultiAgent: getMultiAgent, cache: cache, agentConnectionTimes: map[uuid.UUID]time.Time{}, @@ -224,13 +225,14 @@ func (s *ServerTailnet) watchAgentUpdates() { if !ok { if conn.IsClosed() && s.ctx.Err() == nil { s.logger.Warn(s.ctx, "multiagent closed, reinitializing") + s.coordinatee.SetAllPeersLost() s.reinitCoordinator() continue } return } - err := s.conn.UpdatePeers(resp.GetPeerUpdates()) + err := s.coordinatee.UpdatePeers(resp.GetPeerUpdates()) if err != nil { if xerrors.Is(err, tailnet.ErrConnClosed) { s.logger.Warn(context.Background(), "tailnet conn closed, exiting watchAgentUpdates", slog.Error(err)) @@ -280,9 +282,14 @@ type ServerTailnet struct { cancel func() derpMapUpdaterClosed chan struct{} - logger slog.Logger - tracer trace.Tracer - conn *tailnet.Conn + logger slog.Logger + tracer trace.Tracer + + // in prod, these are the same, but coordinatee is a subset of Conn's + // methods which makes some tests easier. + conn *tailnet.Conn + coordinatee tailnet.Coordinatee + getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error) agentConn atomic.Pointer[tailnet.MultiAgentConn] cache *wsconncache.Cache diff --git a/coderd/tailnet_internal_test.go b/coderd/tailnet_internal_test.go new file mode 100644 index 0000000000..f09ac1d28b --- /dev/null +++ b/coderd/tailnet_internal_test.go @@ -0,0 +1,75 @@ +package coderd + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "go.uber.org/mock/gomock" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/tailnettest" + "github.com/coder/coder/v2/testutil" +) + +// TestServerTailnet_Reconnect tests that ServerTailnet calls SetAllPeersLost on the Coordinatee +// (tailnet.Conn in production) when it disconnects from the Coordinator (via MultiAgentConn) and +// reconnects. +func TestServerTailnet_Reconnect(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctrl := gomock.NewController(t) + ctx := testutil.Context(t, testutil.WaitShort) + + mMultiAgent0 := tailnettest.NewMockMultiAgentConn(ctrl) + mMultiAgent1 := tailnettest.NewMockMultiAgentConn(ctrl) + mac := make(chan tailnet.MultiAgentConn, 2) + mac <- mMultiAgent0 + mac <- mMultiAgent1 + mCoord := tailnettest.NewMockCoordinatee(ctrl) + + uut := &ServerTailnet{ + ctx: ctx, + logger: logger, + coordinatee: mCoord, + getMultiAgent: func(ctx context.Context) (tailnet.MultiAgentConn, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case m := <-mac: + return m, nil + } + }, + agentConn: atomic.Pointer[tailnet.MultiAgentConn]{}, + agentConnectionTimes: make(map[uuid.UUID]time.Time), + } + // reinit the Coordinator once, to load mMultiAgent0 + uut.reinitCoordinator() + + mMultiAgent0.EXPECT().NextUpdate(gomock.Any()). + Times(1). + Return(nil, false) // this indicates there are no more updates + closed0 := mMultiAgent0.EXPECT().IsClosed(). + Times(1). + Return(true) // this triggers reconnect + setLost := mCoord.EXPECT().SetAllPeersLost().Times(1).After(closed0) + mMultiAgent1.EXPECT().NextUpdate(gomock.Any()). + Times(1). + After(setLost). + Return(nil, false) + mMultiAgent1.EXPECT().IsClosed(). + Times(1). + Return(false) // this causes us to exit and not reconnect + + done := make(chan struct{}) + go func() { + uut.watchAgentUpdates() + close(done) + }() + + testutil.RequireRecvCtx(ctx, t, done) +} diff --git a/tailnet/service_test.go b/tailnet/service_test.go index c6a8907644..bb5683afa0 100644 --- a/tailnet/service_test.go +++ b/tailnet/service_test.go @@ -102,7 +102,7 @@ func TestClientService_ServeClient_V2(t *testing.T) { err = c.Close() require.NoError(t, err) err = testutil.RequireRecvCtx(ctx, t, errCh) - require.ErrorIs(t, err, io.EOF) + require.True(t, xerrors.Is(err, io.EOF) || xerrors.Is(err, io.ErrClosedPipe)) } func TestClientService_ServeClient_V1(t *testing.T) { diff --git a/tailnet/tailnettest/coordinateemock.go b/tailnet/tailnettest/coordinateemock.go new file mode 100644 index 0000000000..51f2dd2bce --- /dev/null +++ b/tailnet/tailnettest/coordinateemock.go @@ -0,0 +1,79 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coder/coder/v2/tailnet (interfaces: Coordinatee) +// +// Generated by this command: +// +// mockgen -destination ./coordinateemock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinatee +// + +// Package tailnettest is a generated GoMock package. +package tailnettest + +import ( + reflect "reflect" + + tailnet "github.com/coder/coder/v2/tailnet" + proto "github.com/coder/coder/v2/tailnet/proto" + gomock "go.uber.org/mock/gomock" +) + +// MockCoordinatee is a mock of Coordinatee interface. +type MockCoordinatee struct { + ctrl *gomock.Controller + recorder *MockCoordinateeMockRecorder +} + +// MockCoordinateeMockRecorder is the mock recorder for MockCoordinatee. +type MockCoordinateeMockRecorder struct { + mock *MockCoordinatee +} + +// NewMockCoordinatee creates a new mock instance. +func NewMockCoordinatee(ctrl *gomock.Controller) *MockCoordinatee { + mock := &MockCoordinatee{ctrl: ctrl} + mock.recorder = &MockCoordinateeMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCoordinatee) EXPECT() *MockCoordinateeMockRecorder { + return m.recorder +} + +// SetAllPeersLost mocks base method. +func (m *MockCoordinatee) SetAllPeersLost() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAllPeersLost") +} + +// SetAllPeersLost indicates an expected call of SetAllPeersLost. +func (mr *MockCoordinateeMockRecorder) SetAllPeersLost() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAllPeersLost", reflect.TypeOf((*MockCoordinatee)(nil).SetAllPeersLost)) +} + +// SetNodeCallback mocks base method. +func (m *MockCoordinatee) SetNodeCallback(arg0 func(*tailnet.Node)) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetNodeCallback", arg0) +} + +// SetNodeCallback indicates an expected call of SetNodeCallback. +func (mr *MockCoordinateeMockRecorder) SetNodeCallback(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNodeCallback", reflect.TypeOf((*MockCoordinatee)(nil).SetNodeCallback), arg0) +} + +// UpdatePeers mocks base method. +func (m *MockCoordinatee) UpdatePeers(arg0 []*proto.CoordinateResponse_PeerUpdate) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdatePeers", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdatePeers indicates an expected call of UpdatePeers. +func (mr *MockCoordinateeMockRecorder) UpdatePeers(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePeers", reflect.TypeOf((*MockCoordinatee)(nil).UpdatePeers), arg0) +} diff --git a/tailnet/tailnettest/multiagentmock.go b/tailnet/tailnettest/multiagentmock.go new file mode 100644 index 0000000000..e72233ed38 --- /dev/null +++ b/tailnet/tailnettest/multiagentmock.go @@ -0,0 +1,141 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coder/coder/v2/tailnet (interfaces: MultiAgentConn) +// +// Generated by this command: +// +// mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn +// + +// Package tailnettest is a generated GoMock package. +package tailnettest + +import ( + context "context" + reflect "reflect" + + proto "github.com/coder/coder/v2/tailnet/proto" + uuid "github.com/google/uuid" + gomock "go.uber.org/mock/gomock" +) + +// MockMultiAgentConn is a mock of MultiAgentConn interface. +type MockMultiAgentConn struct { + ctrl *gomock.Controller + recorder *MockMultiAgentConnMockRecorder +} + +// MockMultiAgentConnMockRecorder is the mock recorder for MockMultiAgentConn. +type MockMultiAgentConnMockRecorder struct { + mock *MockMultiAgentConn +} + +// NewMockMultiAgentConn creates a new mock instance. +func NewMockMultiAgentConn(ctrl *gomock.Controller) *MockMultiAgentConn { + mock := &MockMultiAgentConn{ctrl: ctrl} + mock.recorder = &MockMultiAgentConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMultiAgentConn) EXPECT() *MockMultiAgentConnMockRecorder { + return m.recorder +} + +// AgentIsLegacy mocks base method. +func (m *MockMultiAgentConn) AgentIsLegacy(arg0 uuid.UUID) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AgentIsLegacy", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// AgentIsLegacy indicates an expected call of AgentIsLegacy. +func (mr *MockMultiAgentConnMockRecorder) AgentIsLegacy(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AgentIsLegacy", reflect.TypeOf((*MockMultiAgentConn)(nil).AgentIsLegacy), arg0) +} + +// Close mocks base method. +func (m *MockMultiAgentConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockMultiAgentConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMultiAgentConn)(nil).Close)) +} + +// IsClosed mocks base method. +func (m *MockMultiAgentConn) IsClosed() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsClosed") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsClosed indicates an expected call of IsClosed. +func (mr *MockMultiAgentConnMockRecorder) IsClosed() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockMultiAgentConn)(nil).IsClosed)) +} + +// NextUpdate mocks base method. +func (m *MockMultiAgentConn) NextUpdate(arg0 context.Context) (*proto.CoordinateResponse, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NextUpdate", arg0) + ret0, _ := ret[0].(*proto.CoordinateResponse) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// NextUpdate indicates an expected call of NextUpdate. +func (mr *MockMultiAgentConnMockRecorder) NextUpdate(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextUpdate", reflect.TypeOf((*MockMultiAgentConn)(nil).NextUpdate), arg0) +} + +// SubscribeAgent mocks base method. +func (m *MockMultiAgentConn) SubscribeAgent(arg0 uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SubscribeAgent", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SubscribeAgent indicates an expected call of SubscribeAgent. +func (mr *MockMultiAgentConnMockRecorder) SubscribeAgent(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).SubscribeAgent), arg0) +} + +// UnsubscribeAgent mocks base method. +func (m *MockMultiAgentConn) UnsubscribeAgent(arg0 uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnsubscribeAgent", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnsubscribeAgent indicates an expected call of UnsubscribeAgent. +func (mr *MockMultiAgentConnMockRecorder) UnsubscribeAgent(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).UnsubscribeAgent), arg0) +} + +// UpdateSelf mocks base method. +func (m *MockMultiAgentConn) UpdateSelf(arg0 *proto.Node) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateSelf", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateSelf indicates an expected call of UpdateSelf. +func (mr *MockMultiAgentConnMockRecorder) UpdateSelf(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSelf", reflect.TypeOf((*MockMultiAgentConn)(nil).UpdateSelf), arg0) +} diff --git a/tailnet/tailnettest/tailnettest.go b/tailnet/tailnettest/tailnettest.go index e7ed6361a1..794aee549c 100644 --- a/tailnet/tailnettest/tailnettest.go +++ b/tailnet/tailnettest/tailnettest.go @@ -21,7 +21,9 @@ import ( "github.com/coder/coder/v2/tailnet" ) +//go:generate mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/v2/tailnet MultiAgentConn //go:generate mockgen -destination ./coordinatormock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinator +//go:generate mockgen -destination ./coordinateemock.go -package tailnettest github.com/coder/coder/v2/tailnet Coordinatee // RunDERPAndSTUN creates a DERP mapping for tests. func RunDERPAndSTUN(t *testing.T) (*tailcfg.DERPMap, *derp.Server) {