feat: modify PG Coordinator to work with new v2 Tailnet API (#10573)

re: #10528

Refactors PG Coordinator to work with the Tailnet v2 API, including wrappers for the existing v1 API.

The debug endpoint functions, but doesn't return sensible data yet; that will come in another stacked PR.
Spike Curtis 2023-11-20 14:31:04 +04:00 committed by GitHub
parent a8c25180db
commit 5c48cb4447
8 changed files with 1041 additions and 822 deletions
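For orientation, here is a minimal sketch (not part of this commit) of how a caller might drive the new channel-based v2 API, based on the CoordinatorV2 interface and the testPeer helper in the diff below. The coordinator value coord, the context ctx, and the destination agentID are assumed to exist in the caller; agpl and proto are the import aliases used in the diff, the peer name is illustrative, and error handling is omitted.

func exampleCoordinateV2(ctx context.Context, coord agpl.CoordinatorV2, agentID uuid.UUID) {
	// ClientTunnelAuth restricts this peer to tunneling only to agentID.
	reqs, resps := coord.Coordinate(ctx, uuid.New(), "example-client", agpl.ClientTunnelAuth{AgentID: agentID})

	// Ask the coordinator for a tunnel to the agent, then advertise our own node.
	reqs <- &proto.CoordinateRequest{
		AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(agentID)},
	}
	reqs <- &proto.CoordinateRequest{
		UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: 10}},
	}

	// Peer updates (node changes and disconnects) arrive on the response channel,
	// which the coordinator closes when it is done with this peer.
	for resp := range resps {
		for _, update := range resp.PeerUpdates {
			id, _ := uuid.FromBytes(update.Uuid)
			fmt.Printf("peer %s: %s\n", id, update.Kind)
		}
	}
}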


@@ -2,136 +2,230 @@ package tailnet
import (
"context"
"encoding/json"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"cdr.dev/slog"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
)
// connIO manages the reading and writing to a connected client or agent. Agent connIOs have their client field set to
// uuid.Nil. It reads node updates via its decoder, then pushes them onto the bindings channel. It receives mappings
// via its updates TrackedConn, which then writes them.
// connIO manages the reading and writing to a connected peer. It reads requests via its requests
// channel, then pushes them onto the bindings or tunnels channel. It receives responses via calls
// to Enqueue and pushes them onto the responses channel.
type connIO struct {
pCtx context.Context
ctx context.Context
cancel context.CancelFunc
logger slog.Logger
decoder *json.Decoder
updates *agpl.TrackedConn
bindings chan<- binding
id uuid.UUID
// coordCtx is the parent context, that is, the context of the Coordinator
coordCtx context.Context
// peerCtx is the context of the connection to our peer
peerCtx context.Context
cancel context.CancelFunc
logger slog.Logger
requests <-chan *proto.CoordinateRequest
responses chan<- *proto.CoordinateResponse
bindings chan<- binding
tunnels chan<- tunnel
auth agpl.TunnelAuth
mu sync.Mutex
closed bool
name string
start int64
lastWrite int64
overwrites int64
}
func newConnIO(pCtx context.Context,
func newConnIO(coordContext context.Context,
peerCtx context.Context,
logger slog.Logger,
bindings chan<- binding,
conn net.Conn,
tunnels chan<- tunnel,
requests <-chan *proto.CoordinateRequest,
responses chan<- *proto.CoordinateResponse,
id uuid.UUID,
name string,
kind agpl.QueueKind,
auth agpl.TunnelAuth,
) *connIO {
ctx, cancel := context.WithCancel(pCtx)
peerCtx, cancel := context.WithCancel(peerCtx)
now := time.Now().Unix()
c := &connIO{
pCtx: pCtx,
ctx: ctx,
cancel: cancel,
logger: logger,
decoder: json.NewDecoder(conn),
updates: agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, kind),
bindings: bindings,
id: id,
coordCtx: coordContext,
peerCtx: peerCtx,
cancel: cancel,
logger: logger.With(slog.F("name", name)),
requests: requests,
responses: responses,
bindings: bindings,
tunnels: tunnels,
auth: auth,
name: name,
start: now,
lastWrite: now,
}
go c.recvLoop()
go c.updates.SendUpdates()
logger.Info(ctx, "serving connection")
c.logger.Info(coordContext, "serving connection")
return c
}
func (c *connIO) recvLoop() {
defer func() {
// withdraw bindings when we exit. We need to use the parent context here, since our own context might be
// canceled, but we still need to withdraw bindings.
// withdraw bindings & tunnels when we exit. We need to use the parent context here, since
// our own context might be canceled, but we still need to withdraw.
b := binding{
bKey: bKey{
id: c.UniqueID(),
kind: c.Kind(),
},
bKey: bKey(c.UniqueID()),
}
if err := sendCtx(c.pCtx, c.bindings, b); err != nil {
c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err))
if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err))
}
t := tunnel{
tKey: tKey{src: c.UniqueID()},
active: false,
}
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
}
}()
defer c.cancel()
defer c.Close()
for {
var node agpl.Node
err := c.decoder.Decode(&node)
req, err := recvCtx(c.peerCtx, c.requests)
if err != nil {
if xerrors.Is(err, io.EOF) ||
xerrors.Is(err, io.ErrClosedPipe) ||
xerrors.Is(err, context.Canceled) ||
if xerrors.Is(err, context.Canceled) ||
xerrors.Is(err, context.DeadlineExceeded) ||
websocket.CloseStatus(err) > 0 {
c.logger.Debug(c.ctx, "exiting recvLoop", slog.Error(err))
xerrors.Is(err, io.EOF) {
c.logger.Debug(c.coordCtx, "exiting io recvLoop", slog.Error(err))
} else {
c.logger.Error(c.ctx, "failed to decode Node update", slog.Error(err))
c.logger.Error(c.coordCtx, "failed to receive request", slog.Error(err))
}
return
}
c.logger.Debug(c.ctx, "got node update", slog.F("node", node))
b := binding{
bKey: bKey{
id: c.UniqueID(),
kind: c.Kind(),
},
node: &node,
}
if err := sendCtx(c.ctx, c.bindings, b); err != nil {
c.logger.Debug(c.ctx, "recvLoop ctx expired", slog.Error(err))
if err := c.handleRequest(req); err != nil {
return
}
}
}
func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
c.logger.Debug(c.peerCtx, "got request")
if req.UpdateSelf != nil {
c.logger.Debug(c.peerCtx, "got node update", slog.F("node", req.UpdateSelf))
b := binding{
bKey: bKey(c.UniqueID()),
node: req.UpdateSelf.Node,
}
if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
c.logger.Debug(c.peerCtx, "failed to send binding", slog.Error(err))
return err
}
}
if req.AddTunnel != nil {
c.logger.Debug(c.peerCtx, "got add tunnel", slog.F("tunnel", req.AddTunnel))
dst, err := uuid.FromBytes(req.AddTunnel.Uuid)
if err != nil {
c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
// this shouldn't happen unless there is a client error. Close the connection so the client
// doesn't just happily continue thinking everything is fine.
return err
}
if !c.auth.Authorize(dst) {
return xerrors.New("unauthorized tunnel")
}
t := tunnel{
tKey: tKey{
src: c.UniqueID(),
dst: dst,
},
active: true,
}
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
c.logger.Debug(c.peerCtx, "failed to send add tunnel", slog.Error(err))
return err
}
}
if req.RemoveTunnel != nil {
c.logger.Debug(c.peerCtx, "got remove tunnel", slog.F("tunnel", req.RemoveTunnel))
dst, err := uuid.FromBytes(req.RemoveTunnel.Uuid)
if err != nil {
c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
// this shouldn't happen unless there is a client error. Close the connection so the client
// doesn't just happily continue thinking everything is fine.
return err
}
t := tunnel{
tKey: tKey{
src: c.UniqueID(),
dst: dst,
},
active: false,
}
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
c.logger.Debug(c.peerCtx, "failed to send remove tunnel", slog.Error(err))
return err
}
}
// TODO: (spikecurtis) support Disconnect
return nil
}
func (c *connIO) UniqueID() uuid.UUID {
return c.updates.UniqueID()
return c.id
}
func (c *connIO) Kind() agpl.QueueKind {
return c.updates.Kind()
}
func (c *connIO) Enqueue(n []*agpl.Node) error {
return c.updates.Enqueue(n)
func (c *connIO) Enqueue(resp *proto.CoordinateResponse) error {
atomic.StoreInt64(&c.lastWrite, time.Now().Unix())
c.mu.Lock()
closed := c.closed
c.mu.Unlock()
if closed {
return xerrors.New("connIO closed")
}
select {
case <-c.peerCtx.Done():
return c.peerCtx.Err()
case c.responses <- resp:
c.logger.Debug(c.peerCtx, "wrote response")
return nil
default:
return agpl.ErrWouldBlock
}
}
func (c *connIO) Name() string {
return c.updates.Name()
return c.name
}
func (c *connIO) Stats() (start int64, lastWrite int64) {
return c.updates.Stats()
return c.start, atomic.LoadInt64(&c.lastWrite)
}
func (c *connIO) Overwrites() int64 {
return c.updates.Overwrites()
return atomic.LoadInt64(&c.overwrites)
}
// CoordinatorClose is used by the coordinator when closing a Queue. It
// should skip removing itself from the coordinator.
func (c *connIO) CoordinatorClose() error {
c.cancel()
return c.updates.CoordinatorClose()
return c.Close()
}
func (c *connIO) Done() <-chan struct{} {
return c.ctx.Done()
return c.peerCtx.Done()
}
func (c *connIO) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return nil
}
c.cancel()
return c.updates.Close()
c.closed = true
close(c.responses)
return nil
}


@@ -27,11 +27,10 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
t.Skip("test only with postgres")
}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()
@@ -75,11 +74,10 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
t.Skip("test only with postgres")
}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()
@@ -124,11 +122,10 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
t.Skip("test only with postgres")
}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()
@@ -189,11 +186,10 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
t.Skip("test only with postgres")
}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()
@@ -243,11 +239,10 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
t.Skip("test only with postgres")
}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()
@@ -299,11 +294,10 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
t.Skip("test only with postgres")
}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
require.NoError(t, err)
defer coord1.Close()

File diff suppressed because it is too large.


@@ -3,7 +3,6 @@ package tailnet_test
import (
"context"
"database/sql"
"encoding/json"
"io"
"net"
"sync"
@@ -17,6 +16,7 @@ import (
"go.uber.org/goleak"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
gProto "google.golang.org/protobuf/proto"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
@@ -27,6 +27,7 @@ import (
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/enterprise/tailnet"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
)
@@ -52,17 +53,17 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
defer client.close()
client.sendNode(&agpl.Node{PreferredDERP: 10})
require.Eventually(t, func() bool {
clients, err := store.GetTailnetClientsForAgent(ctx, agentID)
clients, err := store.GetTailnetTunnelPeerBindings(ctx, agentID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
t.Fatalf("database error: %v", err)
}
if len(clients) == 0 {
return false
}
var node agpl.Node
err = json.Unmarshal(clients[0].Node, &node)
node := new(proto.Node)
err = gProto.Unmarshal(clients[0].Node, node)
assert.NoError(t, err)
assert.Equal(t, 10, node.PreferredDERP)
assert.EqualValues(t, 10, node.PreferredDerp)
return true
}, testutil.WaitShort, testutil.IntervalFast)
@@ -90,17 +91,17 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
defer agent.close()
agent.sendNode(&agpl.Node{PreferredDERP: 10})
require.Eventually(t, func() bool {
agents, err := store.GetTailnetAgents(ctx, agent.id)
agents, err := store.GetTailnetPeers(ctx, agent.id)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
t.Fatalf("database error: %v", err)
}
if len(agents) == 0 {
return false
}
var node agpl.Node
err = json.Unmarshal(agents[0].Node, &node)
node := new(proto.Node)
err = gProto.Unmarshal(agents[0].Node, node)
assert.NoError(t, err)
assert.Equal(t, 10, node.PreferredDERP)
assert.EqualValues(t, 10, node.PreferredDerp)
return true
}, testutil.WaitShort, testutil.IntervalFast)
err = agent.close()
@@ -342,39 +343,51 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
agent1 := newTestAgent(t, coord1, "agent1")
defer agent1.close()
t.Logf("agent1=%s", agent1.id)
agent2 := newTestAgent(t, coord2, "agent2")
defer agent2.close()
t.Logf("agent2=%s", agent2.id)
client11 := newTestClient(t, coord1, agent1.id)
defer client11.close()
t.Logf("client11=%s", client11.id)
client12 := newTestClient(t, coord1, agent2.id)
defer client12.close()
t.Logf("client12=%s", client12.id)
client21 := newTestClient(t, coord2, agent1.id)
defer client21.close()
t.Logf("client21=%s", client21.id)
client22 := newTestClient(t, coord2, agent2.id)
defer client22.close()
t.Logf("client22=%s", client22.id)
t.Logf("client11 -> Node 11")
client11.sendNode(&agpl.Node{PreferredDERP: 11})
assertEventuallyHasDERPs(ctx, t, agent1, 11)
t.Logf("client21 -> Node 21")
client21.sendNode(&agpl.Node{PreferredDERP: 21})
assertEventuallyHasDERPs(ctx, t, agent1, 21, 11)
assertEventuallyHasDERPs(ctx, t, agent1, 21)
t.Logf("client22 -> Node 22")
client22.sendNode(&agpl.Node{PreferredDERP: 22})
assertEventuallyHasDERPs(ctx, t, agent2, 22)
t.Logf("agent2 -> Node 2")
agent2.sendNode(&agpl.Node{PreferredDERP: 2})
assertEventuallyHasDERPs(ctx, t, client22, 2)
assertEventuallyHasDERPs(ctx, t, client12, 2)
t.Logf("client12 -> Node 12")
client12.sendNode(&agpl.Node{PreferredDERP: 12})
assertEventuallyHasDERPs(ctx, t, agent2, 12, 22)
assertEventuallyHasDERPs(ctx, t, agent2, 12)
t.Logf("agent1 -> Node 1")
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
assertEventuallyHasDERPs(ctx, t, client21, 1)
assertEventuallyHasDERPs(ctx, t, client11, 1)
// let's close coord2
t.Logf("close coord2")
err = coord2.Close()
require.NoError(t, err)
@@ -386,18 +399,9 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
err = client21.recvErr(ctx, t)
require.ErrorIs(t, err, io.EOF)
// agent1 will see an update that drops client21.
// In this case the update is superfluous because client11's node hasn't changed, and agents don't deprogram clients
// from the dataplane even if they are missing. Suppressing this kind of update would require the coordinator to
// store all the data its sent to each connection, so we don't bother.
assertEventuallyHasDERPs(ctx, t, agent1, 11)
// note that although agent2 is disconnected, client12 does NOT get an update because we suppress empty updates.
// (It's easy to tell these are superfluous.)
assertEventuallyNoAgents(ctx, t, store, agent2.id)
// Close coord1
t.Logf("close coord1")
err = coord1.Close()
require.NoError(t, err)
// this closes agent1, client12, client11
@@ -541,9 +545,12 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
Return(database.TailnetCoordinator{}, nil)
// extra calls we don't particularly care about for this test
mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
mStore.EXPECT().GetTailnetClientsForAgent(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil)
mStore.EXPECT().DeleteTailnetAgent(gomock.Any(), gomock.Any()).
AnyTimes().Return(database.DeleteTailnetAgentRow{}, nil)
mStore.EXPECT().GetTailnetTunnelPeerIDs(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil)
mStore.EXPECT().GetTailnetTunnelPeerBindings(gomock.Any(), gomock.Any()).
AnyTimes().Return(nil, nil)
mStore.EXPECT().DeleteTailnetPeer(gomock.Any(), gomock.Any()).
AnyTimes().Return(database.DeleteTailnetPeerRow{}, nil)
mStore.EXPECT().DeleteAllTailnetTunnels(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
@@ -589,6 +596,34 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
}
}
// TestPGCoordinator_BidirectionalTunnels tests when peers create tunnels to each other. We don't
// do this now, but it's schematically possible, so we should make sure it doesn't break anything.
func TestPGCoordinator_BidirectionalTunnels(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
p1 := newTestPeer(ctx, t, coordinator, "p1")
defer p1.close(ctx)
p2 := newTestPeer(ctx, t, coordinator, "p2")
defer p2.close(ctx)
p1.addTunnel(p2.id)
p2.addTunnel(p1.id)
p1.updateDERP(1)
p2.updateDERP(2)
p1.assertEventuallyHasDERP(p2.id, 2)
p2.assertEventuallyHasDERP(p1.id, 1)
}
type testConn struct {
ws, serverWS net.Conn
nodeChan chan []*agpl.Node
@@ -779,7 +814,7 @@ func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.Mu
func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
assert.Eventually(t, func() bool {
agents, err := store.GetTailnetAgents(ctx, agentID)
agents, err := store.GetTailnetPeers(ctx, agentID)
if xerrors.Is(err, sql.ErrNoRows) {
return true
}
@@ -793,7 +828,7 @@ func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.
func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
t.Helper()
assert.Eventually(t, func() bool {
clients, err := store.GetTailnetClientsForAgent(ctx, agentID)
clients, err := store.GetTailnetTunnelPeerIDs(ctx, agentID)
if xerrors.Is(err, sql.ErrNoRows) {
return true
}
@@ -804,6 +839,108 @@ func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store
}, testutil.WaitShort, testutil.IntervalFast)
}
type testPeer struct {
ctx context.Context
cancel context.CancelFunc
t testing.TB
id uuid.UUID
name string
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
derps map[uuid.UUID]int32
}
func newTestPeer(ctx context.Context, t testing.TB, coord agpl.CoordinatorV2, name string, id ...uuid.UUID) *testPeer {
p := &testPeer{t: t, name: name, derps: make(map[uuid.UUID]int32)}
p.ctx, p.cancel = context.WithCancel(ctx)
if len(id) > 1 {
t.Fatal("too many")
}
if len(id) == 1 {
p.id = id[0]
} else {
p.id = uuid.New()
}
// SingleTailnetTunnelAuth allows connections to arbitrary peers
p.reqs, p.resps = coord.Coordinate(p.ctx, p.id, name, agpl.SingleTailnetTunnelAuth{})
return p
}
func (p *testPeer) addTunnel(other uuid.UUID) {
p.t.Helper()
req := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(other)}}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout adding tunnel for %s", p.name)
return
case p.reqs <- req:
return
}
}
func (p *testPeer) updateDERP(derp int32) {
p.t.Helper()
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: derp}}}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout updating node for %s", p.name)
return
case p.reqs <- req:
return
}
}
func (p *testPeer) assertEventuallyHasDERP(other uuid.UUID, derp int32) {
p.t.Helper()
for {
d, ok := p.derps[other]
if ok && d == derp {
return
}
select {
case <-p.ctx.Done():
p.t.Errorf("timeout waiting for response for %s", p.name)
return
case resp, ok := <-p.resps:
if !ok {
p.t.Errorf("responses closed for %s", p.name)
return
}
for _, update := range resp.PeerUpdates {
id, err := uuid.FromBytes(update.Uuid)
if !assert.NoError(p.t, err) {
return
}
switch update.Kind {
case proto.CoordinateResponse_PeerUpdate_NODE:
p.derps[id] = update.Node.PreferredDerp
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
delete(p.derps, id)
default:
p.t.Errorf("unhandled update kind %s", update.Kind)
}
}
}
}
}
func (p *testPeer) close(ctx context.Context) {
p.t.Helper()
p.cancel()
for {
select {
case <-ctx.Done():
p.t.Errorf("timeout waiting for responses to close for %s", p.name)
return
case _, ok := <-p.resps:
if ok {
continue
}
return
}
}
}
type fakeCoordinator struct {
ctx context.Context
t *testing.T
@@ -819,12 +956,15 @@ func (c *fakeCoordinator) heartbeat() {
func (c *fakeCoordinator) agentNode(agentID uuid.UUID, node *agpl.Node) {
c.t.Helper()
nodeRaw, err := json.Marshal(node)
pNode, err := agpl.NodeToProto(node)
require.NoError(c.t, err)
_, err = c.store.UpsertTailnetAgent(c.ctx, database.UpsertTailnetAgentParams{
nodeRaw, err := gProto.Marshal(pNode)
require.NoError(c.t, err)
_, err = c.store.UpsertTailnetPeer(c.ctx, database.UpsertTailnetPeerParams{
ID: agentID,
CoordinatorID: c.id,
Node: nodeRaw,
Status: database.TailnetStatusOk,
})
require.NoError(c.t, err)
}


@@ -22,6 +22,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/tailnet/proto"
)
// Coordinator exchanges nodes with agents to establish connections.
@@ -48,6 +49,17 @@ type Coordinator interface {
ServeMultiAgent(id uuid.UUID) MultiAgentConn
}
// CoordinatorV2 is the interface for interacting with the coordinator via the 2.0 tailnet API.
type CoordinatorV2 interface {
// ServeHTTPDebug serves a debug webpage that shows the internal state of
// the coordinator.
ServeHTTPDebug(w http.ResponseWriter, r *http.Request)
// Node returns a node by peer ID, if known to the coordinator. Returns nil if unknown.
Node(id uuid.UUID) *Node
Close() error
Coordinate(ctx context.Context, id uuid.UUID, name string, a TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse)
}
// Node represents a node in the network.
type Node struct {
// ID is used to identify the connection.

tailnet/proto/compare.go (new file, 20 lines)

@@ -0,0 +1,20 @@
package proto
import (
"bytes"
gProto "google.golang.org/protobuf/proto"
)
// Equal returns true if the nodes have the same contents
func (s *Node) Equal(o *Node) (bool, error) {
sBytes, err := gProto.Marshal(s)
if err != nil {
return false, err
}
oBytes, err := gProto.Marshal(o)
if err != nil {
return false, err
}
return bytes.Equal(sBytes, oBytes), nil
}


@@ -13,9 +13,14 @@ import (
"cdr.dev/slog"
)
// WriteTimeout is the amount of time we wait to write a node update to a connection before we declare it hung.
// It is exported so that tests can use it.
const WriteTimeout = time.Second * 5
const (
// WriteTimeout is the amount of time we wait to write a node update to a connection before we
// declare it hung. It is exported so that tests can use it.
WriteTimeout = time.Second * 5
// ResponseBufferSize is the max number of responses to buffer per connection before we start
// dropping updates
ResponseBufferSize = 512
)
type TrackedConn struct {
ctx context.Context
@@ -48,7 +53,7 @@ func NewTrackedConn(ctx context.Context, cancel func(),
// coordinator mutex while queuing. Node updates don't
// come quickly, so 512 should be plenty for all but
// the most pathological cases.
updates := make(chan []*Node, 512)
updates := make(chan []*Node, ResponseBufferSize)
now := time.Now().Unix()
return &TrackedConn{
ctx: ctx,

tailnet/tunnel.go (new file, 30 lines)

@@ -0,0 +1,30 @@
package tailnet
import "github.com/google/uuid"
type TunnelAuth interface {
Authorize(dst uuid.UUID) bool
}
// SingleTailnetTunnelAuth allows all tunnels, since Coderd and wsproxy are allowed to initiate a tunnel to any agent
type SingleTailnetTunnelAuth struct{}
func (SingleTailnetTunnelAuth) Authorize(uuid.UUID) bool {
return true
}
// ClientTunnelAuth allows connecting to a single, given agent
type ClientTunnelAuth struct {
AgentID uuid.UUID
}
func (c ClientTunnelAuth) Authorize(dst uuid.UUID) bool {
return c.AgentID == dst
}
// AgentTunnelAuth disallows all tunnels, since agents are not allowed to initiate their own tunnels
type AgentTunnelAuth struct{}
func (AgentTunnelAuth) Authorize(uuid.UUID) bool {
return false
}
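
For context, a hedged sketch of how these TunnelAuth implementations might be chosen per connection type, based on the comments in tunnel.go above; the helper name pickTunnelAuth and its parameters are illustrative, not part of the commit.

func pickTunnelAuth(isAgent bool, agentID uuid.UUID) TunnelAuth {
	switch {
	case isAgent:
		// Agents never initiate tunnels; the coordinator rejects any AddTunnel they send.
		return AgentTunnelAuth{}
	case agentID != uuid.Nil:
		// A workspace client may only tunnel to the single agent it connected for.
		return ClientTunnelAuth{AgentID: agentID}
	default:
		// Coderd / wsproxy connections may tunnel to any agent.
		return SingleTailnetTunnelAuth{}
	}
}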