mirror of https://github.com/coder/coder.git
feat(enterprise): add ready for handshake support to pgcoord (#12935)
This commit is contained in:
parent
942e90270e
commit
777dfbe965
|
@ -103,7 +103,7 @@ func (q *sqlQuerier) InTx(function func(Store) error, txOpts *sql.TxOptions) err
|
||||||
// Transaction succeeded.
|
// Transaction succeeded.
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if err != nil && !IsSerializedError(err) {
|
if !IsSerializedError(err) {
|
||||||
// We should only retry if the error is a serialization error.
|
// We should only retry if the error is a serialization error.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,8 @@ package tailnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
@ -30,10 +32,13 @@ type connIO struct {
|
||||||
responses chan<- *proto.CoordinateResponse
|
responses chan<- *proto.CoordinateResponse
|
||||||
bindings chan<- binding
|
bindings chan<- binding
|
||||||
tunnels chan<- tunnel
|
tunnels chan<- tunnel
|
||||||
|
rfhs chan<- readyForHandshake
|
||||||
auth agpl.CoordinateeAuth
|
auth agpl.CoordinateeAuth
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
closed bool
|
closed bool
|
||||||
disconnected bool
|
disconnected bool
|
||||||
|
// latest is the most recent, unfiltered snapshot of the mappings we know about
|
||||||
|
latest []mapping
|
||||||
|
|
||||||
name string
|
name string
|
||||||
start int64
|
start int64
|
||||||
|
@ -46,6 +51,7 @@ func newConnIO(coordContext context.Context,
|
||||||
logger slog.Logger,
|
logger slog.Logger,
|
||||||
bindings chan<- binding,
|
bindings chan<- binding,
|
||||||
tunnels chan<- tunnel,
|
tunnels chan<- tunnel,
|
||||||
|
rfhs chan<- readyForHandshake,
|
||||||
requests <-chan *proto.CoordinateRequest,
|
requests <-chan *proto.CoordinateRequest,
|
||||||
responses chan<- *proto.CoordinateResponse,
|
responses chan<- *proto.CoordinateResponse,
|
||||||
id uuid.UUID,
|
id uuid.UUID,
|
||||||
|
@ -64,6 +70,7 @@ func newConnIO(coordContext context.Context,
|
||||||
responses: responses,
|
responses: responses,
|
||||||
bindings: bindings,
|
bindings: bindings,
|
||||||
tunnels: tunnels,
|
tunnels: tunnels,
|
||||||
|
rfhs: rfhs,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
name: name,
|
name: name,
|
||||||
start: now,
|
start: now,
|
||||||
|
@ -190,9 +197,54 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
|
||||||
c.disconnected = true
|
c.disconnected = true
|
||||||
return errDisconnect
|
return errDisconnect
|
||||||
}
|
}
|
||||||
|
if req.ReadyForHandshake != nil {
|
||||||
|
c.logger.Debug(c.peerCtx, "got ready for handshake ", slog.F("rfh", req.ReadyForHandshake))
|
||||||
|
for _, rfh := range req.ReadyForHandshake {
|
||||||
|
dst, err := uuid.FromBytes(rfh.Id)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
|
||||||
|
// this shouldn't happen unless there is a client error. Close the connection so the client
|
||||||
|
// doesn't just happily continue thinking everything is fine.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
mappings := c.getLatestMapping()
|
||||||
|
if !slices.ContainsFunc(mappings, func(mapping mapping) bool {
|
||||||
|
return mapping.peer == dst
|
||||||
|
}) {
|
||||||
|
c.logger.Debug(c.peerCtx, "cannot process ready for handshake, src isn't peered with dst",
|
||||||
|
slog.F("dst", dst.String()),
|
||||||
|
)
|
||||||
|
_ = c.Enqueue(&proto.CoordinateResponse{
|
||||||
|
Error: fmt.Sprintf("you do not share a tunnel with %q", dst.String()),
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := agpl.SendCtx(c.coordCtx, c.rfhs, readyForHandshake{
|
||||||
|
src: c.id,
|
||||||
|
dst: dst,
|
||||||
|
}); err != nil {
|
||||||
|
c.logger.Debug(c.peerCtx, "failed to send ready for handshake", slog.Error(err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *connIO) setLatestMapping(latest []mapping) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.latest = latest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connIO) getLatestMapping() []mapping {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return c.latest
|
||||||
|
}
|
||||||
|
|
||||||
func (c *connIO) UniqueID() uuid.UUID {
|
func (c *connIO) UniqueID() uuid.UUID {
|
||||||
return c.id
|
return c.id
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
package tailnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||||
|
)
|
||||||
|
|
||||||
|
type readyForHandshake struct {
|
||||||
|
src uuid.UUID
|
||||||
|
dst uuid.UUID
|
||||||
|
}
|
||||||
|
|
||||||
|
type handshaker struct {
|
||||||
|
ctx context.Context
|
||||||
|
logger slog.Logger
|
||||||
|
coordinatorID uuid.UUID
|
||||||
|
pubsub pubsub.Pubsub
|
||||||
|
updates <-chan readyForHandshake
|
||||||
|
|
||||||
|
workerWG sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHandshaker(ctx context.Context,
|
||||||
|
logger slog.Logger,
|
||||||
|
id uuid.UUID,
|
||||||
|
ps pubsub.Pubsub,
|
||||||
|
updates <-chan readyForHandshake,
|
||||||
|
startWorkers <-chan struct{},
|
||||||
|
) *handshaker {
|
||||||
|
s := &handshaker{
|
||||||
|
ctx: ctx,
|
||||||
|
logger: logger,
|
||||||
|
coordinatorID: id,
|
||||||
|
pubsub: ps,
|
||||||
|
updates: updates,
|
||||||
|
}
|
||||||
|
// add to the waitgroup immediately to avoid any races waiting for it before
|
||||||
|
// the workers start.
|
||||||
|
s.workerWG.Add(numHandshakerWorkers)
|
||||||
|
go func() {
|
||||||
|
<-startWorkers
|
||||||
|
for i := 0; i < numHandshakerWorkers; i++ {
|
||||||
|
go s.worker()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *handshaker) worker() {
|
||||||
|
defer t.workerWG.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.ctx.Done():
|
||||||
|
t.logger.Debug(t.ctx, "handshaker worker exiting", slog.Error(t.ctx.Err()))
|
||||||
|
return
|
||||||
|
|
||||||
|
case rfh := <-t.updates:
|
||||||
|
err := t.pubsub.Publish(eventReadyForHandshake, []byte(fmt.Sprintf(
|
||||||
|
"%s,%s", rfh.dst.String(), rfh.src.String(),
|
||||||
|
)))
|
||||||
|
if err != nil {
|
||||||
|
t.logger.Error(t.ctx, "publish ready for handshake", slog.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
package tailnet_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||||
|
"github.com/coder/coder/v2/enterprise/tailnet"
|
||||||
|
agpltest "github.com/coder/coder/v2/tailnet/test"
|
||||||
|
"github.com/coder/coder/v2/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPGCoordinator_ReadyForHandshake_OK(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if !dbtestutil.WillUsePostgres() {
|
||||||
|
t.Skip("test only with postgres")
|
||||||
|
}
|
||||||
|
store, ps := dbtestutil.NewDB(t)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||||
|
defer cancel()
|
||||||
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
|
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer coord1.Close()
|
||||||
|
|
||||||
|
agpltest.ReadyForHandshakeTest(ctx, t, coord1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPGCoordinator_ReadyForHandshake_NoPermission(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if !dbtestutil.WillUsePostgres() {
|
||||||
|
t.Skip("test only with postgres")
|
||||||
|
}
|
||||||
|
store, ps := dbtestutil.NewDB(t)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||||
|
defer cancel()
|
||||||
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
|
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer coord1.Close()
|
||||||
|
|
||||||
|
agpltest.ReadyForHandshakeNoPermissionTest(ctx, t, coord1)
|
||||||
|
}
|
|
@ -9,8 +9,6 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/tailnet/proto"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
@ -22,25 +20,31 @@ import (
|
||||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||||
"github.com/coder/coder/v2/coderd/rbac"
|
"github.com/coder/coder/v2/coderd/rbac"
|
||||||
agpl "github.com/coder/coder/v2/tailnet"
|
agpl "github.com/coder/coder/v2/tailnet"
|
||||||
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
EventHeartbeats = "tailnet_coordinator_heartbeat"
|
EventHeartbeats = "tailnet_coordinator_heartbeat"
|
||||||
eventPeerUpdate = "tailnet_peer_update"
|
eventPeerUpdate = "tailnet_peer_update"
|
||||||
eventTunnelUpdate = "tailnet_tunnel_update"
|
eventTunnelUpdate = "tailnet_tunnel_update"
|
||||||
HeartbeatPeriod = time.Second * 2
|
eventReadyForHandshake = "tailnet_ready_for_handshake"
|
||||||
MissedHeartbeats = 3
|
HeartbeatPeriod = time.Second * 2
|
||||||
numQuerierWorkers = 10
|
MissedHeartbeats = 3
|
||||||
numBinderWorkers = 10
|
numQuerierWorkers = 10
|
||||||
numTunnelerWorkers = 10
|
numBinderWorkers = 10
|
||||||
dbMaxBackoff = 10 * time.Second
|
numTunnelerWorkers = 10
|
||||||
cleanupPeriod = time.Hour
|
numHandshakerWorkers = 5
|
||||||
|
dbMaxBackoff = 10 * time.Second
|
||||||
|
cleanupPeriod = time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
// pgCoord is a postgres-backed coordinator
|
// pgCoord is a postgres-backed coordinator
|
||||||
//
|
//
|
||||||
// ┌──────────┐
|
// ┌────────────┐
|
||||||
// ┌────────────► tunneler ├──────────┐
|
// ┌────────────► handshaker ├────────┐
|
||||||
|
// │ └────────────┘ │
|
||||||
|
// │ ┌──────────┐ │
|
||||||
|
// ├────────────► tunneler ├──────────┤
|
||||||
// │ └──────────┘ │
|
// │ └──────────┘ │
|
||||||
// │ │
|
// │ │
|
||||||
// ┌────────┐ ┌────────┐ ┌───▼───┐
|
// ┌────────┐ ┌────────┐ ┌───▼───┐
|
||||||
|
@ -78,15 +82,17 @@ type pgCoord struct {
|
||||||
newConnections chan *connIO
|
newConnections chan *connIO
|
||||||
closeConnections chan *connIO
|
closeConnections chan *connIO
|
||||||
tunnelerCh chan tunnel
|
tunnelerCh chan tunnel
|
||||||
|
handshakerCh chan readyForHandshake
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
|
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
|
|
||||||
binder *binder
|
binder *binder
|
||||||
tunneler *tunneler
|
tunneler *tunneler
|
||||||
querier *querier
|
handshaker *handshaker
|
||||||
|
querier *querier
|
||||||
}
|
}
|
||||||
|
|
||||||
var pgCoordSubject = rbac.Subject{
|
var pgCoordSubject = rbac.Subject{
|
||||||
|
@ -126,6 +132,8 @@ func newPGCoordInternal(
|
||||||
ccCh := make(chan *connIO)
|
ccCh := make(chan *connIO)
|
||||||
// for communicating subscriptions with the tunneler
|
// for communicating subscriptions with the tunneler
|
||||||
sCh := make(chan tunnel)
|
sCh := make(chan tunnel)
|
||||||
|
// for communicating ready for handshakes with the handshaker
|
||||||
|
rfhCh := make(chan readyForHandshake)
|
||||||
// signals when first heartbeat has been sent, so it's safe to start binding.
|
// signals when first heartbeat has been sent, so it's safe to start binding.
|
||||||
fHB := make(chan struct{})
|
fHB := make(chan struct{})
|
||||||
|
|
||||||
|
@ -145,6 +153,8 @@ func newPGCoordInternal(
|
||||||
closeConnections: ccCh,
|
closeConnections: ccCh,
|
||||||
tunneler: newTunneler(ctx, logger, id, store, sCh, fHB),
|
tunneler: newTunneler(ctx, logger, id, store, sCh, fHB),
|
||||||
tunnelerCh: sCh,
|
tunnelerCh: sCh,
|
||||||
|
handshaker: newHandshaker(ctx, logger, id, ps, rfhCh, fHB),
|
||||||
|
handshakerCh: rfhCh,
|
||||||
id: id,
|
id: id,
|
||||||
querier: newQuerier(querierCtx, logger, id, ps, store, id, cCh, ccCh, numQuerierWorkers, fHB),
|
querier: newQuerier(querierCtx, logger, id, ps, store, id, cCh, ccCh, numQuerierWorkers, fHB),
|
||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
|
@ -242,7 +252,7 @@ func (c *pgCoord) Coordinate(
|
||||||
close(resps)
|
close(resps)
|
||||||
return reqs, resps
|
return reqs, resps
|
||||||
}
|
}
|
||||||
cIO := newConnIO(c.ctx, ctx, logger, c.bindings, c.tunnelerCh, reqs, resps, id, name, a)
|
cIO := newConnIO(c.ctx, ctx, logger, c.bindings, c.tunnelerCh, c.handshakerCh, reqs, resps, id, name, a)
|
||||||
err := agpl.SendCtx(c.ctx, c.newConnections, cIO)
|
err := agpl.SendCtx(c.ctx, c.newConnections, cIO)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// this can only happen if the context is canceled, no need to log
|
// this can only happen if the context is canceled, no need to log
|
||||||
|
@ -626,8 +636,6 @@ type mapper struct {
|
||||||
|
|
||||||
c *connIO
|
c *connIO
|
||||||
|
|
||||||
// latest is the most recent, unfiltered snapshot of the mappings we know about
|
|
||||||
latest []mapping
|
|
||||||
// sent is the state of mappings we have actually enqueued; used to compute diffs for updates.
|
// sent is the state of mappings we have actually enqueued; used to compute diffs for updates.
|
||||||
sent map[uuid.UUID]mapping
|
sent map[uuid.UUID]mapping
|
||||||
|
|
||||||
|
@ -660,11 +668,11 @@ func (m *mapper) run() {
|
||||||
return
|
return
|
||||||
case mappings := <-m.mappings:
|
case mappings := <-m.mappings:
|
||||||
m.logger.Debug(m.ctx, "got new mappings")
|
m.logger.Debug(m.ctx, "got new mappings")
|
||||||
m.latest = mappings
|
m.c.setLatestMapping(mappings)
|
||||||
best = m.bestMappings(mappings)
|
best = m.bestMappings(mappings)
|
||||||
case <-m.update:
|
case <-m.update:
|
||||||
m.logger.Debug(m.ctx, "triggered update")
|
m.logger.Debug(m.ctx, "triggered update")
|
||||||
best = m.bestMappings(m.latest)
|
best = m.bestMappings(m.c.getLatestMapping())
|
||||||
}
|
}
|
||||||
update := m.bestToUpdate(best)
|
update := m.bestToUpdate(best)
|
||||||
if update == nil {
|
if update == nil {
|
||||||
|
@ -1067,6 +1075,28 @@ func (q *querier) subscribe() {
|
||||||
}()
|
}()
|
||||||
q.logger.Info(q.ctx, "subscribed to tunnel updates")
|
q.logger.Info(q.ctx, "subscribed to tunnel updates")
|
||||||
|
|
||||||
|
var cancelRFH context.CancelFunc
|
||||||
|
err = backoff.Retry(func() error {
|
||||||
|
cancelFn, err := q.pubsub.SubscribeWithErr(eventReadyForHandshake, q.listenReadyForHandshake)
|
||||||
|
if err != nil {
|
||||||
|
q.logger.Warn(q.ctx, "failed to subscribe to ready for handshakes", slog.Error(err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cancelRFH = cancelFn
|
||||||
|
return nil
|
||||||
|
}, bkoff)
|
||||||
|
if err != nil {
|
||||||
|
if q.ctx.Err() == nil {
|
||||||
|
q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
q.logger.Info(q.ctx, "canceling ready for handshake subscription")
|
||||||
|
cancelRFH()
|
||||||
|
}()
|
||||||
|
q.logger.Info(q.ctx, "subscribed to ready for handshakes")
|
||||||
|
|
||||||
// unblock the outer function from returning
|
// unblock the outer function from returning
|
||||||
subscribed <- struct{}{}
|
subscribed <- struct{}{}
|
||||||
|
|
||||||
|
@ -1112,6 +1142,7 @@ func (q *querier) listenTunnel(_ context.Context, msg []byte, err error) {
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err))
|
q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
peers, err := parseTunnelUpdate(string(msg))
|
peers, err := parseTunnelUpdate(string(msg))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1133,6 +1164,36 @@ func (q *querier) listenTunnel(_ context.Context, msg []byte, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *querier) listenReadyForHandshake(_ context.Context, msg []byte, err error) {
|
||||||
|
if err != nil && !xerrors.Is(err, pubsub.ErrDroppedMessages) {
|
||||||
|
q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
to, from, err := parseReadyForHandshake(string(msg))
|
||||||
|
if err != nil {
|
||||||
|
q.logger.Error(q.ctx, "failed to parse ready for handshake", slog.F("msg", string(msg)), slog.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mk := mKey(to)
|
||||||
|
q.mu.Lock()
|
||||||
|
mpr, ok := q.mappers[mk]
|
||||||
|
q.mu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
q.logger.Debug(q.ctx, "ignoring ready for handshake because we have no mapper",
|
||||||
|
slog.F("peer_id", to))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = mpr.c.Enqueue(&proto.CoordinateResponse{
|
||||||
|
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{
|
||||||
|
Id: from[:],
|
||||||
|
Kind: proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE,
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (q *querier) resyncPeerMappings() {
|
func (q *querier) resyncPeerMappings() {
|
||||||
q.mu.Lock()
|
q.mu.Lock()
|
||||||
defer q.mu.Unlock()
|
defer q.mu.Unlock()
|
||||||
|
@ -1225,6 +1286,21 @@ func parsePeerUpdate(msg string) (peer uuid.UUID, err error) {
|
||||||
return peer, nil
|
return peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseReadyForHandshake(msg string) (to uuid.UUID, from uuid.UUID, err error) {
|
||||||
|
parts := strings.Split(msg, ",")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return uuid.Nil, uuid.Nil, xerrors.Errorf("expected 2 parts separated by comma")
|
||||||
|
}
|
||||||
|
ids := make([]uuid.UUID, 2)
|
||||||
|
for i, part := range parts {
|
||||||
|
ids[i], err = uuid.Parse(part)
|
||||||
|
if err != nil {
|
||||||
|
return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse UUID: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ids[0], ids[1], nil
|
||||||
|
}
|
||||||
|
|
||||||
// mKey identifies a set of node mappings we want to query.
|
// mKey identifies a set of node mappings we want to query.
|
||||||
type mKey uuid.UUID
|
type mKey uuid.UUID
|
||||||
|
|
||||||
|
|
|
@ -10,9 +10,6 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
|
||||||
agpltest "github.com/coder/coder/v2/tailnet/test"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -24,14 +21,15 @@ import (
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
"cdr.dev/slog/sloggers/slogtest"
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/coderd/database"
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||||
|
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||||
"github.com/coder/coder/v2/enterprise/tailnet"
|
"github.com/coder/coder/v2/enterprise/tailnet"
|
||||||
agpl "github.com/coder/coder/v2/tailnet"
|
agpl "github.com/coder/coder/v2/tailnet"
|
||||||
"github.com/coder/coder/v2/tailnet/proto"
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
|
agpltest "github.com/coder/coder/v2/tailnet/test"
|
||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -215,8 +215,7 @@ func NewRemoteCoordination(logger slog.Logger,
|
||||||
respLoopDone: make(chan struct{}),
|
respLoopDone: make(chan struct{}),
|
||||||
}
|
}
|
||||||
if tunnelTarget != uuid.Nil {
|
if tunnelTarget != uuid.Nil {
|
||||||
// TODO: reenable in upstack PR
|
c.coordinatee.SetTunnelDestination(tunnelTarget)
|
||||||
// c.coordinatee.SetTunnelDestination(tunnelTarget)
|
|
||||||
c.Lock()
|
c.Lock()
|
||||||
err := c.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tunnelTarget[:]}})
|
err := c.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tunnelTarget[:]}})
|
||||||
c.Unlock()
|
c.Unlock()
|
||||||
|
|
|
@ -419,60 +419,16 @@ func TestCoordinator(t *testing.T) {
|
||||||
coordinator := tailnet.NewCoordinator(logger)
|
coordinator := tailnet.NewCoordinator(logger)
|
||||||
ctx := testutil.Context(t, testutil.WaitShort)
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
|
||||||
clientID := uuid.New()
|
test.ReadyForHandshakeTest(ctx, t, coordinator)
|
||||||
agentID := uuid.New()
|
|
||||||
|
|
||||||
aReq, aRes := coordinator.Coordinate(ctx, agentID, agentID.String(), tailnet.AgentCoordinateeAuth{ID: agentID})
|
|
||||||
cReq, cRes := coordinator.Coordinate(ctx, clientID, clientID.String(), tailnet.ClientCoordinateeAuth{AgentID: agentID})
|
|
||||||
|
|
||||||
{
|
|
||||||
nk, err := key.NewNode().Public().MarshalBinary()
|
|
||||||
require.NoError(t, err)
|
|
||||||
dk, err := key.NewDisco().Public().MarshalText()
|
|
||||||
require.NoError(t, err)
|
|
||||||
cReq <- &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
|
|
||||||
Node: &proto.Node{
|
|
||||||
Id: 3,
|
|
||||||
Key: nk,
|
|
||||||
Disco: string(dk),
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
}
|
|
||||||
|
|
||||||
cReq <- &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{
|
|
||||||
Id: agentID[:],
|
|
||||||
}}
|
|
||||||
|
|
||||||
testutil.RequireRecvCtx(ctx, t, aRes)
|
|
||||||
|
|
||||||
aReq <- &proto.CoordinateRequest{ReadyForHandshake: []*proto.CoordinateRequest_ReadyForHandshake{{
|
|
||||||
Id: clientID[:],
|
|
||||||
}}}
|
|
||||||
ack := testutil.RequireRecvCtx(ctx, t, cRes)
|
|
||||||
require.NotNil(t, ack.PeerUpdates)
|
|
||||||
require.Len(t, ack.PeerUpdates, 1)
|
|
||||||
require.Equal(t, proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE, ack.PeerUpdates[0].Kind)
|
|
||||||
require.Equal(t, agentID[:], ack.PeerUpdates[0].Id)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AgentAck_NoPermission", func(t *testing.T) {
|
t.Run("AgentAck_NoPermission", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
coordinator := tailnet.NewCoordinator(logger)
|
coordinator := tailnet.NewCoordinator(logger)
|
||||||
ctx := testutil.Context(t, testutil.WaitShort)
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
|
||||||
clientID := uuid.New()
|
test.ReadyForHandshakeNoPermissionTest(ctx, t, coordinator)
|
||||||
agentID := uuid.New()
|
|
||||||
|
|
||||||
aReq, aRes := coordinator.Coordinate(ctx, agentID, agentID.String(), tailnet.AgentCoordinateeAuth{ID: agentID})
|
|
||||||
_, _ = coordinator.Coordinate(ctx, clientID, clientID.String(), tailnet.ClientCoordinateeAuth{AgentID: agentID})
|
|
||||||
|
|
||||||
aReq <- &proto.CoordinateRequest{ReadyForHandshake: []*proto.CoordinateRequest_ReadyForHandshake{{
|
|
||||||
Id: clientID[:],
|
|
||||||
}}}
|
|
||||||
|
|
||||||
rfhError := testutil.RequireRecvCtx(ctx, t, aRes)
|
|
||||||
require.NotEmpty(t, rfhError.Error)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/tailnet"
|
"github.com/coder/coder/v2/tailnet"
|
||||||
|
@ -53,3 +54,31 @@ func BidirectionalTunnels(ctx context.Context, t *testing.T, coordinator tailnet
|
||||||
p1.AssertEventuallyHasDERP(p2.ID, 2)
|
p1.AssertEventuallyHasDERP(p2.ID, 2)
|
||||||
p2.AssertEventuallyHasDERP(p1.ID, 1)
|
p2.AssertEventuallyHasDERP(p1.ID, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ReadyForHandshakeTest(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
|
||||||
|
p1 := NewPeer(ctx, t, coordinator, "p1")
|
||||||
|
defer p1.Close(ctx)
|
||||||
|
p2 := NewPeer(ctx, t, coordinator, "p2")
|
||||||
|
defer p2.Close(ctx)
|
||||||
|
p1.AddTunnel(p2.ID)
|
||||||
|
p2.AddTunnel(p1.ID)
|
||||||
|
p1.UpdateDERP(1)
|
||||||
|
p2.UpdateDERP(2)
|
||||||
|
|
||||||
|
p1.AssertEventuallyHasDERP(p2.ID, 2)
|
||||||
|
p2.AssertEventuallyHasDERP(p1.ID, 1)
|
||||||
|
p2.ReadyForHandshake(p1.ID)
|
||||||
|
p1.AssertEventuallyReadyForHandshake(p2.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadyForHandshakeNoPermissionTest(ctx context.Context, t *testing.T, coordinator tailnet.CoordinatorV2) {
|
||||||
|
p1 := NewPeer(ctx, t, coordinator, "p1")
|
||||||
|
defer p1.Close(ctx)
|
||||||
|
p2 := NewPeer(ctx, t, coordinator, "p2")
|
||||||
|
defer p2.Close(ctx)
|
||||||
|
p1.UpdateDERP(1)
|
||||||
|
p2.UpdateDERP(2)
|
||||||
|
|
||||||
|
p2.ReadyForHandshake(p1.ID)
|
||||||
|
p2.AssertEventuallyGetsError(fmt.Sprintf("you do not share a tunnel with %q", p1.ID.String()))
|
||||||
|
}
|
||||||
|
|
|
@ -13,8 +13,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type PeerStatus struct {
|
type PeerStatus struct {
|
||||||
preferredDERP int32
|
preferredDERP int32
|
||||||
status proto.CoordinateResponse_PeerUpdate_Kind
|
status proto.CoordinateResponse_PeerUpdate_Kind
|
||||||
|
readyForHandshake bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
|
@ -68,6 +69,21 @@ func (p *Peer) UpdateDERP(derp int32) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Peer) ReadyForHandshake(peer uuid.UUID) {
|
||||||
|
p.t.Helper()
|
||||||
|
|
||||||
|
req := &proto.CoordinateRequest{ReadyForHandshake: []*proto.CoordinateRequest_ReadyForHandshake{{
|
||||||
|
Id: peer[:],
|
||||||
|
}}}
|
||||||
|
select {
|
||||||
|
case <-p.ctx.Done():
|
||||||
|
p.t.Errorf("timeout sending ready for handshake for %s", p.name)
|
||||||
|
return
|
||||||
|
case p.reqs <- req:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Peer) Disconnect() {
|
func (p *Peer) Disconnect() {
|
||||||
p.t.Helper()
|
p.t.Helper()
|
||||||
req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
|
req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
|
||||||
|
@ -135,6 +151,35 @@ func (p *Peer) AssertEventuallyResponsesClosed() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) {
|
||||||
|
p.t.Helper()
|
||||||
|
for {
|
||||||
|
o := p.peers[other]
|
||||||
|
if o.readyForHandshake {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := p.handleOneResp()
|
||||||
|
if xerrors.Is(err, responsesClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) AssertEventuallyGetsError(match string) {
|
||||||
|
p.t.Helper()
|
||||||
|
for {
|
||||||
|
err := p.handleOneResp()
|
||||||
|
if xerrors.Is(err, responsesClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil && assert.ErrorContains(p.t, err, match) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var responsesClosed = xerrors.New("responses closed")
|
var responsesClosed = xerrors.New("responses closed")
|
||||||
|
|
||||||
func (p *Peer) handleOneResp() error {
|
func (p *Peer) handleOneResp() error {
|
||||||
|
@ -145,6 +190,9 @@ func (p *Peer) handleOneResp() error {
|
||||||
if !ok {
|
if !ok {
|
||||||
return responsesClosed
|
return responsesClosed
|
||||||
}
|
}
|
||||||
|
if resp.Error != "" {
|
||||||
|
return xerrors.New(resp.Error)
|
||||||
|
}
|
||||||
for _, update := range resp.PeerUpdates {
|
for _, update := range resp.PeerUpdates {
|
||||||
id, err := uuid.FromBytes(update.Id)
|
id, err := uuid.FromBytes(update.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -152,12 +200,16 @@ func (p *Peer) handleOneResp() error {
|
||||||
}
|
}
|
||||||
switch update.Kind {
|
switch update.Kind {
|
||||||
case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
|
case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
|
||||||
p.peers[id] = PeerStatus{
|
peer := p.peers[id]
|
||||||
preferredDERP: update.GetNode().GetPreferredDerp(),
|
peer.preferredDERP = update.GetNode().GetPreferredDerp()
|
||||||
status: update.Kind,
|
peer.status = update.Kind
|
||||||
}
|
p.peers[id] = peer
|
||||||
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
|
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
|
||||||
delete(p.peers, id)
|
delete(p.peers, id)
|
||||||
|
case proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE:
|
||||||
|
peer := p.peers[id]
|
||||||
|
peer.readyForHandshake = true
|
||||||
|
p.peers[id] = peer
|
||||||
default:
|
default:
|
||||||
return xerrors.Errorf("unhandled update kind %s", update.Kind)
|
return xerrors.Errorf("unhandled update kind %s", update.Kind)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue