fix: fix graceful disconnect in DialWorkspaceAgent (#11993)

I noticed in testing that the CLI wasn't correctly sending the disconnect message when it shut down, so agents were seeing this as a "lost" peer rather than a "disconnected" one.

What was happening is that we used a single context for everything from the netconn to the RPCs, so once that context was canceled, the disconnect message could no longer be sent over it.

So, this PR splits things into two contexts: the main one from the caller, and a "graceful" one that lasts up to 1 second longer, giving the disconnect message a chance to go out.
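
Roughly, the idea looks like this (a minimal sketch for illustration only; the function and parameter names here, such as newGracefulContext and done, are placeholders and not the identifiers used in the diff below):

package example

import (
	"context"
	"time"
)

// newGracefulContext is a sketch of the two-context approach: the returned
// context is canceled up to one second after the main context, leaving a
// window in which a graceful disconnect message can still be sent.
func newGracefulContext(ctx context.Context, done <-chan struct{}) (context.Context, context.CancelFunc) {
	gracefulCtx, cancel := context.WithCancel(context.Background())
	go func() {
		defer cancel()
		<-ctx.Done() // main context canceled by the caller
		select {
		case <-done: // work already finished; no need to wait
		case <-time.After(time.Second): // otherwise allow up to 1 second for the disconnect
		}
	}()
	return gracefulCtx, cancel
}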
Spike Curtis 2024-02-05 14:01:37 +04:00 committed by GitHub
parent bb99cb7d2b
commit e5ba586e30
3 changed files with 190 additions and 39 deletions

View File

@ -12,10 +12,9 @@ import (
"net/netip"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/sync/errgroup"
"github.com/google/uuid"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
@ -360,6 +359,15 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
return agentConn, nil
}
// tailnetConn is the subset of the tailnet.Conn methods that tailnetAPIConnector uses. It is
// included so that we can fake it in testing.
//
// @typescript-ignore tailnetConn
type tailnetConn interface {
tailnet.Coordinatee
SetDERPMap(derpMap *tailcfg.DERPMap)
}
// tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
//
// 1) run the Coordinate API and pass node information back and forth
@ -370,13 +378,20 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
//
// @typescript-ignore tailnetAPIConnector
type tailnetAPIConnector struct {
ctx context.Context
// We keep track of two contexts: the main context from the caller, and a "graceful" context
// that we keep open slightly longer than the main context to give a chance to send the
// Disconnect message to the coordinator. That tells the coordinator that we really meant to
// disconnect instead of just losing network connectivity.
ctx context.Context
gracefulCtx context.Context
cancelGracefulCtx context.CancelFunc
logger slog.Logger
agentID uuid.UUID
coordinateURL string
dialOptions *websocket.DialOptions
conn *tailnet.Conn
conn tailnetConn
connected chan error
isFirst bool
@ -387,7 +402,7 @@ type tailnetAPIConnector struct {
func runTailnetAPIConnector(
ctx context.Context, logger slog.Logger,
agentID uuid.UUID, coordinateURL string, dialOptions *websocket.DialOptions,
conn *tailnet.Conn,
conn tailnetConn,
) *tailnetAPIConnector {
tac := &tailnetAPIConnector{
ctx: ctx,
@ -399,10 +414,23 @@ func runTailnetAPIConnector(
connected: make(chan error, 1),
closed: make(chan struct{}),
}
tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
go tac.manageGracefulTimeout()
go tac.run()
return tac
}
// manageGracefulTimeout allows the gracefulContext to last 1 second longer than the main context
// to allow a graceful disconnect.
func (tac *tailnetAPIConnector) manageGracefulTimeout() {
defer tac.cancelGracefulCtx()
<-tac.ctx.Done()
select {
case <-tac.closed:
case <-time.After(time.Second):
}
}
func (tac *tailnetAPIConnector) run() {
tac.isFirst = true
defer close(tac.closed)
@ -437,7 +465,7 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
return nil, err
}
client, err := tailnet.NewDRPCClient(
websocket.NetConn(tac.ctx, ws, websocket.MessageBinary),
websocket.NetConn(tac.gracefulCtx, ws, websocket.MessageBinary),
tac.logger,
)
if err != nil {
@ -464,65 +492,81 @@ func (tac *tailnetAPIConnector) coordinateAndDERPMap(client proto.DRPCTailnetCli
<-conn.Closed()
}
}()
eg, egCtx := errgroup.WithContext(tac.ctx)
eg.Go(func() error {
return tac.coordinate(egCtx, client)
})
eg.Go(func() error {
return tac.derpMap(egCtx, client)
})
err := eg.Wait()
if err != nil &&
!xerrors.Is(err, io.EOF) &&
!xerrors.Is(err, context.Canceled) &&
!xerrors.Is(err, context.DeadlineExceeded) {
tac.logger.Error(tac.ctx, "error while connected to tailnet v2+ API")
}
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
tac.coordinate(client)
}()
go func() {
defer wg.Done()
dErr := tac.derpMap(client)
if dErr != nil && tac.ctx.Err() == nil {
// The main context is still active, meaning that we want the tailnet data plane to stay
// up, even though we hit some error getting DERP maps on the control plane. That means
// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
// close the underlying connection. This will trigger a retry of the control plane in
// run().
client.DRPCConn().Close()
// Note that derpMap() logs its own errors; we don't bother here.
}
}()
wg.Wait()
}
func (tac *tailnetAPIConnector) coordinate(ctx context.Context, client proto.DRPCTailnetClient) error {
coord, err := client.Coordinate(ctx)
func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
// we use the gracefulCtx here so that we'll have time to send the graceful disconnect
coord, err := client.Coordinate(tac.gracefulCtx)
if err != nil {
return xerrors.Errorf("failed to connect to Coordinate RPC: %w", err)
tac.logger.Error(tac.ctx, "failed to connect to Coordinate RPC", slog.Error(err))
return
}
defer func() {
cErr := coord.Close()
if cErr != nil {
tac.logger.Debug(ctx, "error closing Coordinate RPC", slog.Error(cErr))
tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
}
}()
coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID)
tac.logger.Debug(ctx, "serving coordinator")
err = <-coordination.Error()
if err != nil &&
!xerrors.Is(err, io.EOF) &&
!xerrors.Is(err, context.Canceled) &&
!xerrors.Is(err, context.DeadlineExceeded) {
return xerrors.Errorf("remote coordination error: %w", err)
tac.logger.Debug(tac.ctx, "serving coordinator")
select {
case <-tac.ctx.Done():
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
crdErr := coordination.Close()
if crdErr != nil {
tac.logger.Error(tac.ctx, "failed to close remote coordination", slog.Error(err))
}
case err = <-coordination.Error():
if err != nil &&
!xerrors.Is(err, io.EOF) &&
!xerrors.Is(err, context.Canceled) &&
!xerrors.Is(err, context.DeadlineExceeded) {
tac.logger.Error(tac.ctx, "remote coordination error: %w", err)
}
}
return nil
}
func (tac *tailnetAPIConnector) derpMap(ctx context.Context, client proto.DRPCTailnetClient) error {
s, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{})
func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
s, err := client.StreamDERPMaps(tac.ctx, &proto.StreamDERPMapsRequest{})
if err != nil {
return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err)
}
defer func() {
cErr := s.Close()
if cErr != nil {
tac.logger.Debug(ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
}
}()
for {
dmp, err := s.Recv()
if err != nil {
if xerrors.Is(err, io.EOF) || xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
return nil
}
return xerrors.Errorf("error receiving DERP Map: %w", err)
tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
return err
}
tac.logger.Debug(ctx, "got new DERP Map", slog.F("derp_map", dmp))
tac.logger.Debug(tac.ctx, "got new DERP Map", slog.F("derp_map", dmp))
dm := tailnet.DERPMapFromProto(dmp)
tac.conn.SetDERPMap(dm)
}

View File

@ -0,0 +1,106 @@
package codersdk
import (
"context"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nhooyr.io/websocket"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
func TestTailnetAPIConnector_Disconnects(t *testing.T) {
t.Parallel()
testCtx := testutil.Context(t, testutil.WaitShort)
ctx, cancel := context.WithCancel(testCtx)
logger := slogtest.Make(t, &slogtest.Options{
// we get EOF when we simulate a DERPMap error
IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs, io.EOF),
}).Leveled(slog.LevelDebug)
agentID := uuid.UUID{0x55}
clientID := uuid.UUID{0x66}
fCoord := tailnettest.NewFakeCoordinator()
var coord tailnet.Coordinator = fCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
derpMapCh := make(chan *tailcfg.DERPMap)
defer close(derpMapCh)
svc, err := tailnet.NewClientService(
logger, &coordPtr,
time.Millisecond, func() *tailcfg.DERPMap { return <-derpMapCh },
)
require.NoError(t, err)
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
ctx, nc := websocketNetConn(r.Context(), sws, websocket.MessageBinary)
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
Name: "client",
ID: clientID,
Auth: tailnet.ClientTunnelAuth{AgentID: agentID},
})
assert.NoError(t, err)
}))
fConn := newFakeTailnetConn()
uut := runTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{}, fConn)
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
require.NotNil(t, reqTun.AddTunnel)
_ = testutil.RequireRecvCtx(ctx, t, uut.connected)
// simulate a problem with DERPMaps by sending nil
testutil.RequireSendCtx(ctx, t, derpMapCh, nil)
// this should cause the coordinate call to hang up WITHOUT disconnecting
reqNil := testutil.RequireRecvCtx(ctx, t, call.Reqs)
require.Nil(t, reqNil)
// ...and then reconnect
call = testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
reqTun = testutil.RequireRecvCtx(ctx, t, call.Reqs)
require.NotNil(t, reqTun.AddTunnel)
// canceling the context should trigger the disconnect message
cancel()
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
require.NotNil(t, reqDisc)
require.NotNil(t, reqDisc.Disconnect)
}
type fakeTailnetConn struct{}
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
// TODO implement me
panic("implement me")
}
func (*fakeTailnetConn) SetAllPeersLost() {}
func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}
func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}
func newFakeTailnetConn() *fakeTailnetConn {
return &fakeTailnetConn{}
}

View File

@ -134,6 +134,7 @@ func (c *remoteCoordination) Close() (retErr error) {
if err != nil {
return xerrors.Errorf("send disconnect: %w", err)
}
c.logger.Debug(context.Background(), "sent disconnect")
return nil
}
@ -167,7 +168,7 @@ func (c *remoteCoordination) respLoop() {
}
}
// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinee (usually a
// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
// Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
// a client---agents should NOT set this!).
func NewRemoteCoordination(logger slog.Logger,