mirror of https://github.com/coder/coder.git
feat(enterprise): add ready for handshake support to pgcoord (#12935)
This commit is contained in:
parent
942e90270e
commit
777dfbe965
|
@ -103,7 +103,7 @@ func (q *sqlQuerier) InTx(function func(Store) error, txOpts *sql.TxOptions) err
|
|||
// Transaction succeeded.
|
||||
return nil
|
||||
}
|
||||
if err != nil && !IsSerializedError(err) {
|
||||
if !IsSerializedError(err) {
|
||||
// We should only retry if the error is a serialization error.
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -2,6 +2,8 @@ package tailnet
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -30,10 +32,13 @@ type connIO struct {
|
|||
responses chan<- *proto.CoordinateResponse
|
||||
bindings chan<- binding
|
||||
tunnels chan<- tunnel
|
||||
rfhs chan<- readyForHandshake
|
||||
auth agpl.CoordinateeAuth
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
disconnected bool
|
||||
// latest is the most recent, unfiltered snapshot of the mappings we know about
|
||||
latest []mapping
|
||||
|
||||
name string
|
||||
start int64
|
||||
|
@ -46,6 +51,7 @@ func newConnIO(coordContext context.Context,
|
|||
logger slog.Logger,
|
||||
bindings chan<- binding,
|
||||
tunnels chan<- tunnel,
|
||||
rfhs chan<- readyForHandshake,
|
||||
requests <-chan *proto.CoordinateRequest,
|
||||
responses chan<- *proto.CoordinateResponse,
|
||||
id uuid.UUID,
|
||||
|
@ -64,6 +70,7 @@ func newConnIO(coordContext context.Context,
|
|||
responses: responses,
|
||||
bindings: bindings,
|
||||
tunnels: tunnels,
|
||||
rfhs: rfhs,
|
||||
auth: auth,
|
||||
name: name,
|
||||
start: now,
|
||||
|
@ -190,9 +197,54 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
|
|||
c.disconnected = true
|
||||
return errDisconnect
|
||||
}
|
||||
if req.ReadyForHandshake != nil {
|
||||
c.logger.Debug(c.peerCtx, "got ready for handshake ", slog.F("rfh", req.ReadyForHandshake))
|
||||
for _, rfh := range req.ReadyForHandshake {
|
||||
dst, err := uuid.FromBytes(rfh.Id)
|
||||
if err != nil {
|
||||
c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
|
||||
// this shouldn't happen unless there is a client error. Close the connection so the client
|
||||
// doesn't just happily continue thinking everything is fine.
|
||||
return err
|
||||
}
|
||||
|
||||
mappings := c.getLatestMapping()
|
||||
if !slices.ContainsFunc(mappings, func(mapping mapping) bool {
|
||||
return mapping.peer == dst
|
||||
}) {
|
||||
c.logger.Debug(c.peerCtx, "cannot process ready for handshake, src isn't peered with dst",
|
||||
slog.F("dst", dst.String()),
|
||||
)
|
||||
_ = c.Enqueue(&proto.CoordinateResponse{
|
||||
Error: fmt.Sprintf("you do not share a tunnel with %q", dst.String()),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := agpl.SendCtx(c.coordCtx, c.rfhs, readyForHandshake{
|
||||
src: c.id,
|
||||
dst: dst,
|
||||
}); err != nil {
|
||||
c.logger.Debug(c.peerCtx, "failed to send ready for handshake", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connIO) setLatestMapping(latest []mapping) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.latest = latest
|
||||
}
|
||||
|
||||
func (c *connIO) getLatestMapping() []mapping {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.latest
|
||||
}
|
||||
|
||||
func (c *connIO) UniqueID() uuid.UUID {
|
||||
return c.id
|
||||
}
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
package tailnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
type readyForHandshake struct {
|
||||
src uuid.UUID
|
||||
dst uuid.UUID
|
||||
}
|
||||
|
||||
type handshaker struct {
|
||||
ctx context.Context
|
||||
logger slog.Logger
|
||||
coordinatorID uuid.UUID
|
||||
pubsub pubsub.Pubsub
|
||||
updates <-chan readyForHandshake
|
||||
|
||||
workerWG sync.WaitGroup
|
||||
}
|
||||
|
||||
func newHandshaker(ctx context.Context,
|
||||
logger slog.Logger,
|
||||
id uuid.UUID,
|
||||
ps pubsub.Pubsub,
|
||||
updates <-chan readyForHandshake,
|
||||
startWorkers <-chan struct{},
|
||||
) *handshaker {
|
||||
s := &handshaker{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
coordinatorID: id,
|
||||
pubsub: ps,
|
||||
updates: updates,
|
||||
}
|
||||
// add to the waitgroup immediately to avoid any races waiting for it before
|
||||
// the workers start.
|
||||
s.workerWG.Add(numHandshakerWorkers)
|
||||
go func() {
|
||||
<-startWorkers
|
||||
for i := 0; i < numHandshakerWorkers; i++ {
|
||||
go s.worker()
|
||||
}
|
||||
}()
|
||||
return s
|
||||
}
|
||||
|
||||
func (t *handshaker) worker() {
|
||||
defer t.workerWG.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
t.logger.Debug(t.ctx, "handshaker worker exiting", slog.Error(t.ctx.Err()))
|
||||
return
|
||||
|
||||
case rfh := <-t.updates:
|
||||
err := t.pubsub.Publish(eventReadyForHandshake, []byte(fmt.Sprintf(
|
||||
"%s,%s", rfh.dst.String(), rfh.src.String(),
|
||||
)))
|
||||
if err != nil {
|
||||
t.logger.Error(t.ctx, "publish ready for handshake", slog.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package tailnet_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/enterprise/tailnet"
|
||||
agpltest "github.com/coder/coder/v2/tailnet/test"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestPGCoordinator_ReadyForHandshake_OK(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
|
||||
agpltest.ReadyForHandshakeTest(ctx, t, coord1)
|
||||
}
|
||||
|
||||
func TestPGCoordinator_ReadyForHandshake_NoPermission(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
|
||||
agpltest.ReadyForHandshakeNoPermissionTest(ctx, t, coord1)
|
||||
}
|
|
@ -9,8 +9,6 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
@ -22,25 +20,31 @@ import (
|
|||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
EventHeartbeats = "tailnet_coordinator_heartbeat"
|
||||
eventPeerUpdate = "tailnet_peer_update"
|
||||
eventTunnelUpdate = "tailnet_tunnel_update"
|
||||
HeartbeatPeriod = time.Second * 2
|
||||
MissedHeartbeats = 3
|
||||
numQuerierWorkers = 10
|
||||
numBinderWorkers = 10
|
||||
numTunnelerWorkers = 10
|
||||
dbMaxBackoff = 10 * time.Second
|
||||
cleanupPeriod = time.Hour
|
||||
EventHeartbeats = "tailnet_coordinator_heartbeat"
|
||||
eventPeerUpdate = "tailnet_peer_update"
|
||||
eventTunnelUpdate = "tailnet_tunnel_update"
|
||||
eventReadyForHandshake = "tailnet_ready_for_handshake"
|
||||
HeartbeatPeriod = time.Second * 2
|
||||
MissedHeartbeats = 3
|
||||
numQuerierWorkers = 10
|
||||
numBinderWorkers = 10
|
||||
numTunnelerWorkers = 10
|
||||
numHandshakerWorkers = 5
|
||||
dbMaxBackoff = 10 * time.Second
|
||||
cleanupPeriod = time.Hour
|
||||
)
|
||||
|
||||
// pgCoord is a postgres-backed coordinator
|
||||
//
|
||||
// ┌──────────┐
|
||||
// ┌────────────► tunneler ├──────────┐
|
||||
// ┌────────────┐
|
||||
// ┌────────────► handshaker ├────────┐
|
||||
// │ └────────────┘ │
|
||||
// │ ┌──────────┐ │
|
||||
// ├────────────► tunneler ├──────────┤
|
||||
// │ └──────────┘ │
|
||||
// │ │
|
||||
// ┌────────┐ ┌────────┐ ┌───▼───┐
|
||||
|
@ -78,15 +82,17 @@ type pgCoord struct {
|
|||
newConnections chan *connIO
|
||||
closeConnections chan *connIO
|
||||
tunnelerCh chan tunnel
|
||||
handshakerCh chan readyForHandshake
|
||||
id uuid.UUID
|
||||
|
||||
cancel context.CancelFunc
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
|
||||
binder *binder
|
||||
tunneler *tunneler
|
||||
querier *querier
|
||||
binder *binder
|
||||
tunneler *tunneler
|
||||
handshaker *handshaker
|
||||
querier *querier
|
||||
}
|
||||
|
||||
var pgCoordSubject = rbac.Subject{
|
||||
|
@ -126,6 +132,8 @@ func newPGCoordInternal(
|
|||
ccCh := make(chan *connIO)
|
||||
// for communicating subscriptions with the tunneler
|
||||
sCh := make(chan tunnel)
|
||||
// for communicating ready for handshakes with the handshaker
|
||||
rfhCh := make(chan readyForHandshake)
|
||||
// signals when first heartbeat has been sent, so it's safe to start binding.
|
||||
fHB := make(chan struct{})
|
||||
|
||||
|
@ -145,6 +153,8 @@ func newPGCoordInternal(
|
|||
closeConnections: ccCh,
|
||||
tunneler: newTunneler(ctx, logger, id, store, sCh, fHB),
|
||||
tunnelerCh: sCh,
|
||||
handshaker: newHandshaker(ctx, logger, id, ps, rfhCh, fHB),
|
||||
handshakerCh: rfhCh,
|
||||
id: id,
|
||||
querier: newQuerier(querierCtx, logger, id, ps, store, id, cCh, ccCh, numQuerierWorkers, fHB),
|
||||
closed: make(chan struct{}),
|
||||
|
@ -242,7 +252,7 @@ func (c *pgCoord) Coordinate(
|
|||
close(resps)
|
||||
return reqs, resps
|
||||
}
|
||||
cIO := newConnIO(c.ctx, ctx, logger, c.bindings, c.tunnelerCh, reqs, resps, id, name, a)
|
||||
cIO := newConnIO(c.ctx, ctx, logger, c.bindings, c.tunnelerCh, c.handshakerCh, reqs, resps, id, name, a)
|
||||
err := agpl.SendCtx(c.ctx, c.newConnections, cIO)
|
||||
if err != nil {
|
||||
// this can only happen if the context is canceled, no need to log
|
||||
|
@ -626,8 +636,6 @@ type mapper struct {
|
|||
|
||||
c *connIO
|
||||
|
||||
// latest is the most recent, unfiltered snapshot of the mappings we know about
|
||||
latest []mapping
|
||||
// sent is the state of mappings we have actually enqueued; used to compute diffs for updates.
|
||||
sent map[uuid.UUID]mapping
|
||||
|
||||
|
@ -660,11 +668,11 @@ func (m *mapper) run() {
|
|||
return
|
||||
case mappings := <-m.mappings:
|
||||
m.logger.Debug(m.ctx, "got new mappings")
|
||||
m.latest = mappings
|
||||
m.c.setLatestMapping(mappings)
|
||||
best = m.bestMappings(mappings)
|
||||
case <-m.update:
|
||||
m.logger.Debug(m.ctx, "triggered update")
|
||||
best = m.bestMappings(m.latest)
|
||||
best = m.bestMappings(m.c.getLatestMapping())
|
||||
}
|
||||
update := m.bestToUpdate(best)
|
||||
if update == nil {
|
||||
|
@ -1067,6 +1075,28 @@ func (q *querier) subscribe() {
|
|||
}()
|
||||
q.logger.Info(q.ctx, "subscribed to tunnel updates")
|
||||
|
||||
var cancelRFH context.CancelFunc
|
||||
err = backoff.Retry(func() error {
|
||||
cancelFn, err := q.pubsub.SubscribeWithErr(eventReadyForHandshake, q.listenReadyForHandshake)
|
||||
if err != nil {
|
||||
q.logger.Warn(q.ctx, "failed to subscribe to ready for handshakes", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
cancelRFH = cancelFn
|
||||
return nil
|
||||
}, bkoff)
|
||||
if err != nil {
|
||||
if q.ctx.Err() == nil {
|
||||
q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
q.logger.Info(q.ctx, "canceling ready for handshake subscription")
|
||||
cancelRFH()
|
||||
}()
|
||||
q.logger.Info(q.ctx, "subscribed to ready for handshakes")
|
||||
|
||||
// unblock the outer function from returning
|
||||
subscribed <- struct{}{}
|
||||
|
||||
|
@ -1112,6 +1142,7 @@ func (q *querier) listenTunnel(_ context.Context, msg []byte, err error) {
|
|||
}
|
||||
if err != nil {
|
||||
q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err))
|
||||
return
|
||||
}
|
||||
peers, err := parseTunnelUpdate(string(msg))
|
||||
if err != nil {
|
||||
|
@ -1133,6 +1164,36 @@ func (q *querier) listenTunnel(_ context.Context, msg []byte, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (q *querier) listenReadyForHandshake(_ context.Context, msg []byte, err error) {
|
||||
if err != nil && !xerrors.Is(err, pubsub.ErrDroppedMessages) {
|
||||
q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
to, from, err := parseReadyForHandshake(string(msg))
|
||||
if err != nil {
|
||||
q.logger.Error(q.ctx, "failed to parse ready for handshake", slog.F("msg", string(msg)), slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
mk := mKey(to)
|
||||
q.mu.Lock()
|
||||
mpr, ok := q.mappers[mk]
|
||||
q.mu.Unlock()
|
||||
if !ok {
|
||||
q.logger.Debug(q.ctx, "ignoring ready for handshake because we have no mapper",
|
||||
slog.F("peer_id", to))
|
||||
return
|
||||
}
|
||||
|
||||
_ = mpr.c.Enqueue(&proto.CoordinateResponse{
|
||||
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{
|
||||
Id: from[:],
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE,
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
func (q *querier) resyncPeerMappings() {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
|
@ -1225,6 +1286,21 @@ func parsePeerUpdate(msg string) (peer uuid.UUID, err error) {
|
|||
return peer, nil
|
||||
}
|
||||
|
||||
func parseReadyForHandshake(msg string) (to uuid.UUID, from uuid.UUID, err error) {
|
||||
parts := strings.Split(msg, ",")
|
||||
if len(parts) != 2 {
|
||||
return uuid.Nil, uuid.Nil, xerrors.Errorf("expected 2 parts separated by comma")
|
||||
}
|
||||
ids := make([]uuid.UUID, 2)
|
||||
for i, part := range parts {
|
||||
ids[i], err = uuid.Parse(part)
|
||||
if err != nil {
|
||||
return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse UUID: %w", err)
|
||||
}
|
||||
}
|
||||
return ids[0], ids[1], nil
|
||||
}
|
||||
|
||||
// mKey identifies a set of node mappings we want to query.
|
||||
type mKey uuid.UUID
|
||||
|
||||
|
|
|
@ -10,9 +10,6 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
agpltest "github.com/coder/coder/v2/tailnet/test"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -24,14 +21,15 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/enterprise/tailnet"
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
agpltest "github.com/coder/coder/v2/tailnet/test"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
|
|
|
@ -215,8 +215,7 @@ func NewRemoteCoordination(logger slog.Logger,
|
|||
respLoopDone: make(chan struct{}),
|
||||
}
|
||||
if tunnelTarget != uuid.Nil {
|
||||
// TODO: reenable in upstack PR
|
||||
// c.coordinatee.SetTunnelDestination(tunnelTarget)
|
||||
c.coordinatee.SetTunnelDestination(tunnelTarget)
|
||||
c.Lock()
|
||||
err := c.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tunnelTarget[:]}})
|
||||
c.Unlock()
|
||||
|
|
|
@ -419,60 +419,16 @@ func TestCoordinator(t *testing.T) {
|
|||
coordinator := tailnet.NewCoordinator(logger)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
clientID := uuid.New()
|
||||
agentID := uuid.New()
|
||||
|
||||
aReq, aRes := coordinator.Coordinate(ctx, agentID, agentID.String(), tailnet.AgentCoordinateeAuth{ID: agentID})
|
||||
cReq, cRes := coordinator.Coordinate(ctx, clientID, clientID.String(), tailnet.ClientCoordinateeAuth{AgentID: agentID})
|
||||
|
||||
{
|
||||
nk, err := key.NewNode().Public().MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
dk, err := key.NewDisco().Public().MarshalText()
|
||||
require.NoError(t, err)
|
||||
cReq <- &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
|
||||
Node: &proto.Node{
|
||||
Id: 3,
|
||||
Key: nk,
|
||||
Disco: string(dk),
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
cReq <- &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{
|
||||
Id: agentID[:],
|
||||
}}
|
||||
|
||||
testutil.RequireRecvCtx(ctx, t, aRes)
|
||||
|
||||
aReq <- &proto.CoordinateRequest{ReadyForHandshake: []*proto.CoordinateRequest_ReadyForHandshake{{
|
||||
Id: clientID[:],
|
||||
}}}
|
||||
ack := testutil.RequireRecvCtx(ctx, t, cRes)
|
||||
require.NotNil(t, ack.PeerUpdates)
|
||||
require.Len(t, ack.PeerUpdates, 1)
|
||||
require.Equal(t, proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE, ack.PeerUpdates[0].Kind)
|
||||
require.Equal(t, agentID[:], ack.PeerUpdates[0].Id)
|
||||
test.ReadyForHandshakeTest(ctx, t, coordinator)
|
||||
})
|
||||
|
||||
t.Run("AgentAck_NoPermission", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coordinator := tailnet.NewCoordinator(logger)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
clientID := uuid.New()
|
||||
agentID := uuid.New()
|
||||
|
||||
aReq, aRes := coordinator.Coordinate(ctx, agentID, agentID.String(), tailnet.AgentCoordinateeAuth{ID: agentID})
|
||||
_, _ = coordinator.Coordinate(ctx, clientID, clientID.String(), tailnet.ClientCoordinateeAuth{AgentID: agentID})
|
||||
|
||||
aReq <- &proto.CoordinateRequest{ReadyForHandshake: []*proto.CoordinateRequest_ReadyForHandshake{{
|
||||
Id: clientID[:],
|
||||
}}}
|
||||
|
||||
rfhError := testutil.RequireRecvCtx(ctx, t, aRes)
|
||||
require.NotEmpty(t, rfhError.Error)
|
||||
test.ReadyForHandshakeNoPermissionTest(ctx, t, coordinator)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
|
@ -53,3 +54,31 @@ func BidirectionalTunnels(ctx context.Context, t *testing.T, coordinator tailnet
|
|||
p1.AssertEventuallyHasDERP(p2.ID, 2)
|
||||
p2.AssertEventuallyHasDERP(p1.ID, 1)
|
||||
}
|
||||
|
||||
func ReadyForHandshakeTest(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
|
||||
p1 := NewPeer(ctx, t, coordinator, "p1")
|
||||
defer p1.Close(ctx)
|
||||
p2 := NewPeer(ctx, t, coordinator, "p2")
|
||||
defer p2.Close(ctx)
|
||||
p1.AddTunnel(p2.ID)
|
||||
p2.AddTunnel(p1.ID)
|
||||
p1.UpdateDERP(1)
|
||||
p2.UpdateDERP(2)
|
||||
|
||||
p1.AssertEventuallyHasDERP(p2.ID, 2)
|
||||
p2.AssertEventuallyHasDERP(p1.ID, 1)
|
||||
p2.ReadyForHandshake(p1.ID)
|
||||
p1.AssertEventuallyReadyForHandshake(p2.ID)
|
||||
}
|
||||
|
||||
func ReadyForHandshakeNoPermissionTest(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
|
||||
p1 := NewPeer(ctx, t, coordinator, "p1")
|
||||
defer p1.Close(ctx)
|
||||
p2 := NewPeer(ctx, t, coordinator, "p2")
|
||||
defer p2.Close(ctx)
|
||||
p1.UpdateDERP(1)
|
||||
p2.UpdateDERP(2)
|
||||
|
||||
p2.ReadyForHandshake(p1.ID)
|
||||
p2.AssertEventuallyGetsError(fmt.Sprintf("you do not share a tunnel with %q", p1.ID.String()))
|
||||
}
|
||||
|
|
|
@ -13,8 +13,9 @@ import (
|
|||
)
|
||||
|
||||
type PeerStatus struct {
|
||||
preferredDERP int32
|
||||
status proto.CoordinateResponse_PeerUpdate_Kind
|
||||
preferredDERP int32
|
||||
status proto.CoordinateResponse_PeerUpdate_Kind
|
||||
readyForHandshake bool
|
||||
}
|
||||
|
||||
type Peer struct {
|
||||
|
@ -68,6 +69,21 @@ func (p *Peer) UpdateDERP(derp int32) {
|
|||
}
|
||||
}
|
||||
|
||||
func (p *Peer) ReadyForHandshake(peer uuid.UUID) {
|
||||
p.t.Helper()
|
||||
|
||||
req := &proto.CoordinateRequest{ReadyForHandshake: []*proto.CoordinateRequest_ReadyForHandshake{{
|
||||
Id: peer[:],
|
||||
}}}
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
p.t.Errorf("timeout sending ready for handshake for %s", p.name)
|
||||
return
|
||||
case p.reqs <- req:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) Disconnect() {
|
||||
p.t.Helper()
|
||||
req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
|
||||
|
@ -135,6 +151,35 @@ func (p *Peer) AssertEventuallyResponsesClosed() {
|
|||
}
|
||||
}
|
||||
|
||||
func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) {
|
||||
p.t.Helper()
|
||||
for {
|
||||
o := p.peers[other]
|
||||
if o.readyForHandshake {
|
||||
return
|
||||
}
|
||||
|
||||
err := p.handleOneResp()
|
||||
if xerrors.Is(err, responsesClosed) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Peer) AssertEventuallyGetsError(match string) {
|
||||
p.t.Helper()
|
||||
for {
|
||||
err := p.handleOneResp()
|
||||
if xerrors.Is(err, responsesClosed) {
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil && assert.ErrorContains(p.t, err, match) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var responsesClosed = xerrors.New("responses closed")
|
||||
|
||||
func (p *Peer) handleOneResp() error {
|
||||
|
@ -145,6 +190,9 @@ func (p *Peer) handleOneResp() error {
|
|||
if !ok {
|
||||
return responsesClosed
|
||||
}
|
||||
if resp.Error != "" {
|
||||
return xerrors.New(resp.Error)
|
||||
}
|
||||
for _, update := range resp.PeerUpdates {
|
||||
id, err := uuid.FromBytes(update.Id)
|
||||
if err != nil {
|
||||
|
@ -152,12 +200,16 @@ func (p *Peer) handleOneResp() error {
|
|||
}
|
||||
switch update.Kind {
|
||||
case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
|
||||
p.peers[id] = PeerStatus{
|
||||
preferredDERP: update.GetNode().GetPreferredDerp(),
|
||||
status: update.Kind,
|
||||
}
|
||||
peer := p.peers[id]
|
||||
peer.preferredDERP = update.GetNode().GetPreferredDerp()
|
||||
peer.status = update.Kind
|
||||
p.peers[id] = peer
|
||||
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
|
||||
delete(p.peers, id)
|
||||
case proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE:
|
||||
peer := p.peers[id]
|
||||
peer.readyForHandshake = true
|
||||
p.peers[id] = peer
|
||||
default:
|
||||
return xerrors.Errorf("unhandled update kind %s", update.Kind)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue