feat(enterprise): add ready for handshake support to pgcoord (#12935)

This commit is contained in:
Colin Adler 2024-04-16 15:01:10 -05:00 committed by GitHub
parent 942e90270e
commit 777dfbe965
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 364 additions and 82 deletions

View File

@ -103,7 +103,7 @@ func (q *sqlQuerier) InTx(function func(Store) error, txOpts *sql.TxOptions) err
// Transaction succeeded.
return nil
}
if err != nil && !IsSerializedError(err) {
if !IsSerializedError(err) {
// We should only retry if the error is a serialization error.
return err
}

View File

@ -2,6 +2,8 @@ package tailnet
import (
"context"
"fmt"
"slices"
"sync"
"sync/atomic"
"time"
@ -30,10 +32,13 @@ type connIO struct {
responses chan<- *proto.CoordinateResponse
bindings chan<- binding
tunnels chan<- tunnel
rfhs chan<- readyForHandshake
auth agpl.CoordinateeAuth
mu sync.Mutex
closed bool
disconnected bool
// latest is the most recent, unfiltered snapshot of the mappings we know about
latest []mapping
name string
start int64
@ -46,6 +51,7 @@ func newConnIO(coordContext context.Context,
logger slog.Logger,
bindings chan<- binding,
tunnels chan<- tunnel,
rfhs chan<- readyForHandshake,
requests <-chan *proto.CoordinateRequest,
responses chan<- *proto.CoordinateResponse,
id uuid.UUID,
@ -64,6 +70,7 @@ func newConnIO(coordContext context.Context,
responses: responses,
bindings: bindings,
tunnels: tunnels,
rfhs: rfhs,
auth: auth,
name: name,
start: now,
@ -190,9 +197,54 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
c.disconnected = true
return errDisconnect
}
if req.ReadyForHandshake != nil {
c.logger.Debug(c.peerCtx, "got ready for handshake ", slog.F("rfh", req.ReadyForHandshake))
for _, rfh := range req.ReadyForHandshake {
dst, err := uuid.FromBytes(rfh.Id)
if err != nil {
c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
// this shouldn't happen unless there is a client error. Close the connection so the client
// doesn't just happily continue thinking everything is fine.
return err
}
mappings := c.getLatestMapping()
if !slices.ContainsFunc(mappings, func(mapping mapping) bool {
return mapping.peer == dst
}) {
c.logger.Debug(c.peerCtx, "cannot process ready for handshake, src isn't peered with dst",
slog.F("dst", dst.String()),
)
_ = c.Enqueue(&proto.CoordinateResponse{
Error: fmt.Sprintf("you do not share a tunnel with %q", dst.String()),
})
return nil
}
if err := agpl.SendCtx(c.coordCtx, c.rfhs, readyForHandshake{
src: c.id,
dst: dst,
}); err != nil {
c.logger.Debug(c.peerCtx, "failed to send ready for handshake", slog.Error(err))
return err
}
}
}
return nil
}
func (c *connIO) setLatestMapping(latest []mapping) {
c.mu.Lock()
defer c.mu.Unlock()
c.latest = latest
}
func (c *connIO) getLatestMapping() []mapping {
c.mu.Lock()
defer c.mu.Unlock()
return c.latest
}
func (c *connIO) UniqueID() uuid.UUID {
return c.id
}

View File

@ -0,0 +1,73 @@
package tailnet
import (
"context"
"fmt"
"sync"
"github.com/google/uuid"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database/pubsub"
)
type readyForHandshake struct {
src uuid.UUID
dst uuid.UUID
}
type handshaker struct {
ctx context.Context
logger slog.Logger
coordinatorID uuid.UUID
pubsub pubsub.Pubsub
updates <-chan readyForHandshake
workerWG sync.WaitGroup
}
func newHandshaker(ctx context.Context,
logger slog.Logger,
id uuid.UUID,
ps pubsub.Pubsub,
updates <-chan readyForHandshake,
startWorkers <-chan struct{},
) *handshaker {
s := &handshaker{
ctx: ctx,
logger: logger,
coordinatorID: id,
pubsub: ps,
updates: updates,
}
// add to the waitgroup immediately to avoid any races waiting for it before
// the workers start.
s.workerWG.Add(numHandshakerWorkers)
go func() {
<-startWorkers
for i := 0; i < numHandshakerWorkers; i++ {
go s.worker()
}
}()
return s
}
func (t *handshaker) worker() {
defer t.workerWG.Done()
for {
select {
case <-t.ctx.Done():
t.logger.Debug(t.ctx, "handshaker worker exiting", slog.Error(t.ctx.Err()))
return
case rfh := <-t.updates:
err := t.pubsub.Publish(eventReadyForHandshake, []byte(fmt.Sprintf(
"%s,%s", rfh.dst.String(), rfh.src.String(),
)))
if err != nil {
t.logger.Error(t.ctx, "publish ready for handshake", slog.Error(err))
}
}
}
}

View File

@ -0,0 +1,47 @@
package tailnet_test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/enterprise/tailnet"
agpltest "github.com/coder/coder/v2/tailnet/test"
"github.com/coder/coder/v2/testutil"
)
func TestPGCoordinator_ReadyForHandshake_OK(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()
agpltest.ReadyForHandshakeTest(ctx, t, coord1)
}
func TestPGCoordinator_ReadyForHandshake_NoPermission(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()
agpltest.ReadyForHandshakeNoPermissionTest(ctx, t, coord1)
}

View File

@ -9,8 +9,6 @@ import (
"sync/atomic"
"time"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"golang.org/x/xerrors"
@ -22,25 +20,31 @@ import (
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/rbac"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
)
const (
EventHeartbeats = "tailnet_coordinator_heartbeat"
eventPeerUpdate = "tailnet_peer_update"
eventTunnelUpdate = "tailnet_tunnel_update"
HeartbeatPeriod = time.Second * 2
MissedHeartbeats = 3
numQuerierWorkers = 10
numBinderWorkers = 10
numTunnelerWorkers = 10
dbMaxBackoff = 10 * time.Second
cleanupPeriod = time.Hour
EventHeartbeats = "tailnet_coordinator_heartbeat"
eventPeerUpdate = "tailnet_peer_update"
eventTunnelUpdate = "tailnet_tunnel_update"
eventReadyForHandshake = "tailnet_ready_for_handshake"
HeartbeatPeriod = time.Second * 2
MissedHeartbeats = 3
numQuerierWorkers = 10
numBinderWorkers = 10
numTunnelerWorkers = 10
numHandshakerWorkers = 5
dbMaxBackoff = 10 * time.Second
cleanupPeriod = time.Hour
)
// pgCoord is a postgres-backed coordinator
//
// ┌──────────┐
// ┌────────────► tunneler ├──────────┐
// ┌────────────┐
// ┌────────────► handshaker ├────────┐
// │ └────────────┘ │
// │ ┌──────────┐ │
// ├────────────► tunneler ├──────────┤
// │ └──────────┘ │
// │ │
// ┌────────┐ ┌────────┐ ┌───▼───┐
@ -78,15 +82,17 @@ type pgCoord struct {
newConnections chan *connIO
closeConnections chan *connIO
tunnelerCh chan tunnel
handshakerCh chan readyForHandshake
id uuid.UUID
cancel context.CancelFunc
closeOnce sync.Once
closed chan struct{}
binder *binder
tunneler *tunneler
querier *querier
binder *binder
tunneler *tunneler
handshaker *handshaker
querier *querier
}
var pgCoordSubject = rbac.Subject{
@ -126,6 +132,8 @@ func newPGCoordInternal(
ccCh := make(chan *connIO)
// for communicating subscriptions with the tunneler
sCh := make(chan tunnel)
// for communicating ready for handshakes with the handshaker
rfhCh := make(chan readyForHandshake)
// signals when first heartbeat has been sent, so it's safe to start binding.
fHB := make(chan struct{})
@ -145,6 +153,8 @@ func newPGCoordInternal(
closeConnections: ccCh,
tunneler: newTunneler(ctx, logger, id, store, sCh, fHB),
tunnelerCh: sCh,
handshaker: newHandshaker(ctx, logger, id, ps, rfhCh, fHB),
handshakerCh: rfhCh,
id: id,
querier: newQuerier(querierCtx, logger, id, ps, store, id, cCh, ccCh, numQuerierWorkers, fHB),
closed: make(chan struct{}),
@ -242,7 +252,7 @@ func (c *pgCoord) Coordinate(
close(resps)
return reqs, resps
}
cIO := newConnIO(c.ctx, ctx, logger, c.bindings, c.tunnelerCh, reqs, resps, id, name, a)
cIO := newConnIO(c.ctx, ctx, logger, c.bindings, c.tunnelerCh, c.handshakerCh, reqs, resps, id, name, a)
err := agpl.SendCtx(c.ctx, c.newConnections, cIO)
if err != nil {
// this can only happen if the context is canceled, no need to log
@ -626,8 +636,6 @@ type mapper struct {
c *connIO
// latest is the most recent, unfiltered snapshot of the mappings we know about
latest []mapping
// sent is the state of mappings we have actually enqueued; used to compute diffs for updates.
sent map[uuid.UUID]mapping
@ -660,11 +668,11 @@ func (m *mapper) run() {
return
case mappings := <-m.mappings:
m.logger.Debug(m.ctx, "got new mappings")
m.latest = mappings
m.c.setLatestMapping(mappings)
best = m.bestMappings(mappings)
case <-m.update:
m.logger.Debug(m.ctx, "triggered update")
best = m.bestMappings(m.latest)
best = m.bestMappings(m.c.getLatestMapping())
}
update := m.bestToUpdate(best)
if update == nil {
@ -1067,6 +1075,28 @@ func (q *querier) subscribe() {
}()
q.logger.Info(q.ctx, "subscribed to tunnel updates")
var cancelRFH context.CancelFunc
err = backoff.Retry(func() error {
cancelFn, err := q.pubsub.SubscribeWithErr(eventReadyForHandshake, q.listenReadyForHandshake)
if err != nil {
q.logger.Warn(q.ctx, "failed to subscribe to ready for handshakes", slog.Error(err))
return err
}
cancelRFH = cancelFn
return nil
}, bkoff)
if err != nil {
if q.ctx.Err() == nil {
q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
}
return
}
defer func() {
q.logger.Info(q.ctx, "canceling ready for handshake subscription")
cancelRFH()
}()
q.logger.Info(q.ctx, "subscribed to ready for handshakes")
// unblock the outer function from returning
subscribed <- struct{}{}
@ -1112,6 +1142,7 @@ func (q *querier) listenTunnel(_ context.Context, msg []byte, err error) {
}
if err != nil {
q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err))
return
}
peers, err := parseTunnelUpdate(string(msg))
if err != nil {
@ -1133,6 +1164,36 @@ func (q *querier) listenTunnel(_ context.Context, msg []byte, err error) {
}
}
func (q *querier) listenReadyForHandshake(_ context.Context, msg []byte, err error) {
if err != nil && !xerrors.Is(err, pubsub.ErrDroppedMessages) {
q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err))
return
}
to, from, err := parseReadyForHandshake(string(msg))
if err != nil {
q.logger.Error(q.ctx, "failed to parse ready for handshake", slog.F("msg", string(msg)), slog.Error(err))
return
}
mk := mKey(to)
q.mu.Lock()
mpr, ok := q.mappers[mk]
q.mu.Unlock()
if !ok {
q.logger.Debug(q.ctx, "ignoring ready for handshake because we have no mapper",
slog.F("peer_id", to))
return
}
_ = mpr.c.Enqueue(&proto.CoordinateResponse{
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{
Id: from[:],
Kind: proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE,
}},
})
}
func (q *querier) resyncPeerMappings() {
q.mu.Lock()
defer q.mu.Unlock()
@ -1225,6 +1286,21 @@ func parsePeerUpdate(msg string) (peer uuid.UUID, err error) {
return peer, nil
}
func parseReadyForHandshake(msg string) (to uuid.UUID, from uuid.UUID, err error) {
parts := strings.Split(msg, ",")
if len(parts) != 2 {
return uuid.Nil, uuid.Nil, xerrors.Errorf("expected 2 parts separated by comma")
}
ids := make([]uuid.UUID, 2)
for i, part := range parts {
ids[i], err = uuid.Parse(part)
if err != nil {
return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse UUID: %w", err)
}
}
return ids[0], ids[1], nil
}
// mKey identifies a set of node mappings we want to query.
type mKey uuid.UUID

View File

@ -10,9 +10,6 @@ import (
"testing"
"time"
"github.com/coder/coder/v2/codersdk/workspacesdk"
agpltest "github.com/coder/coder/v2/tailnet/test"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -24,14 +21,15 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/enterprise/tailnet"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
agpltest "github.com/coder/coder/v2/tailnet/test"
"github.com/coder/coder/v2/testutil"
)

View File

@ -215,8 +215,7 @@ func NewRemoteCoordination(logger slog.Logger,
respLoopDone: make(chan struct{}),
}
if tunnelTarget != uuid.Nil {
// TODO: reenable in upstack PR
// c.coordinatee.SetTunnelDestination(tunnelTarget)
c.coordinatee.SetTunnelDestination(tunnelTarget)
c.Lock()
err := c.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tunnelTarget[:]}})
c.Unlock()

View File

@ -419,60 +419,16 @@ func TestCoordinator(t *testing.T) {
coordinator := tailnet.NewCoordinator(logger)
ctx := testutil.Context(t, testutil.WaitShort)
clientID := uuid.New()
agentID := uuid.New()
aReq, aRes := coordinator.Coordinate(ctx, agentID, agentID.String(), tailnet.AgentCoordinateeAuth{ID: agentID})
cReq, cRes := coordinator.Coordinate(ctx, clientID, clientID.String(), tailnet.ClientCoordinateeAuth{AgentID: agentID})
{
nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
cReq <- &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
Node: &proto.Node{
Id: 3,
Key: nk,
Disco: string(dk),
},
}}
}
cReq <- &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{
Id: agentID[:],
}}
testutil.RequireRecvCtx(ctx, t, aRes)
aReq <- &proto.CoordinateRequest{ReadyForHandshake: []*proto.CoordinateRequest_ReadyForHandshake{{
Id: clientID[:],
}}}
ack := testutil.RequireRecvCtx(ctx, t, cRes)
require.NotNil(t, ack.PeerUpdates)
require.Len(t, ack.PeerUpdates, 1)
require.Equal(t, proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE, ack.PeerUpdates[0].Kind)
require.Equal(t, agentID[:], ack.PeerUpdates[0].Id)
test.ReadyForHandshakeTest(ctx, t, coordinator)
})
t.Run("AgentAck_NoPermission", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
ctx := testutil.Context(t, testutil.WaitShort)
clientID := uuid.New()
agentID := uuid.New()
aReq, aRes := coordinator.Coordinate(ctx, agentID, agentID.String(), tailnet.AgentCoordinateeAuth{ID: agentID})
_, _ = coordinator.Coordinate(ctx, clientID, clientID.String(), tailnet.ClientCoordinateeAuth{AgentID: agentID})
aReq <- &proto.CoordinateRequest{ReadyForHandshake: []*proto.CoordinateRequest_ReadyForHandshake{{
Id: clientID[:],
}}}
rfhError := testutil.RequireRecvCtx(ctx, t, aRes)
require.NotEmpty(t, rfhError.Error)
test.ReadyForHandshakeNoPermissionTest(ctx, t, coordinator)
})
}

View File

@ -2,6 +2,7 @@ package test
import (
"context"
"fmt"
"testing"
"github.com/coder/coder/v2/tailnet"
@ -53,3 +54,31 @@ func BidirectionalTunnels(ctx context.Context, t *testing.T, coordinator tailnet
p1.AssertEventuallyHasDERP(p2.ID, 2)
p2.AssertEventuallyHasDERP(p1.ID, 1)
}
func ReadyForHandshakeTest(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
p1 := NewPeer(ctx, t, coordinator, "p1")
defer p1.Close(ctx)
p2 := NewPeer(ctx, t, coordinator, "p2")
defer p2.Close(ctx)
p1.AddTunnel(p2.ID)
p2.AddTunnel(p1.ID)
p1.UpdateDERP(1)
p2.UpdateDERP(2)
p1.AssertEventuallyHasDERP(p2.ID, 2)
p2.AssertEventuallyHasDERP(p1.ID, 1)
p2.ReadyForHandshake(p1.ID)
p1.AssertEventuallyReadyForHandshake(p2.ID)
}
func ReadyForHandshakeNoPermissionTest(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
p1 := NewPeer(ctx, t, coordinator, "p1")
defer p1.Close(ctx)
p2 := NewPeer(ctx, t, coordinator, "p2")
defer p2.Close(ctx)
p1.UpdateDERP(1)
p2.UpdateDERP(2)
p2.ReadyForHandshake(p1.ID)
p2.AssertEventuallyGetsError(fmt.Sprintf("you do not share a tunnel with %q", p1.ID.String()))
}

View File

@ -13,8 +13,9 @@ import (
)
type PeerStatus struct {
preferredDERP int32
status proto.CoordinateResponse_PeerUpdate_Kind
preferredDERP int32
status proto.CoordinateResponse_PeerUpdate_Kind
readyForHandshake bool
}
type Peer struct {
@ -68,6 +69,21 @@ func (p *Peer) UpdateDERP(derp int32) {
}
}
func (p *Peer) ReadyForHandshake(peer uuid.UUID) {
p.t.Helper()
req := &proto.CoordinateRequest{ReadyForHandshake: []*proto.CoordinateRequest_ReadyForHandshake{{
Id: peer[:],
}}}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout sending ready for handshake for %s", p.name)
return
case p.reqs <- req:
return
}
}
func (p *Peer) Disconnect() {
p.t.Helper()
req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
@ -135,6 +151,35 @@ func (p *Peer) AssertEventuallyResponsesClosed() {
}
}
func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) {
p.t.Helper()
for {
o := p.peers[other]
if o.readyForHandshake {
return
}
err := p.handleOneResp()
if xerrors.Is(err, responsesClosed) {
return
}
}
}
func (p *Peer) AssertEventuallyGetsError(match string) {
p.t.Helper()
for {
err := p.handleOneResp()
if xerrors.Is(err, responsesClosed) {
return
}
if err != nil && assert.ErrorContains(p.t, err, match) {
return
}
}
}
var responsesClosed = xerrors.New("responses closed")
func (p *Peer) handleOneResp() error {
@ -145,6 +190,9 @@ func (p *Peer) handleOneResp() error {
if !ok {
return responsesClosed
}
if resp.Error != "" {
return xerrors.New(resp.Error)
}
for _, update := range resp.PeerUpdates {
id, err := uuid.FromBytes(update.Id)
if err != nil {
@ -152,12 +200,16 @@ func (p *Peer) handleOneResp() error {
}
switch update.Kind {
case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
p.peers[id] = PeerStatus{
preferredDERP: update.GetNode().GetPreferredDerp(),
status: update.Kind,
}
peer := p.peers[id]
peer.preferredDERP = update.GetNode().GetPreferredDerp()
peer.status = update.Kind
p.peers[id] = peer
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
delete(p.peers, id)
case proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE:
peer := p.peers[id]
peer.readyForHandshake = true
p.peers[id] = peer
default:
return xerrors.Errorf("unhandled update kind %s", update.Kind)
}