mirror of https://github.com/coder/coder.git
feat: modify PG Coordinator to work with new v2 Tailnet API (#10573)
re: #10528 Refactors PG Coordinator to work with the Tailnet v2 API, including wrappers for the existing v1 API. The debug endpoint functions, but doesn't return sensible data, that will be in another stacked PR.
This commit is contained in:
parent
a8c25180db
commit
5c48cb4447
|
@ -2,136 +2,230 @@ package tailnet
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
// connIO manages the reading and writing to a connected client or agent. Agent connIOs have their client field set to
|
||||
// uuid.Nil. It reads node updates via its decoder, then pushes them onto the bindings channel. It receives mappings
|
||||
// via its updates TrackedConn, which then writes them.
|
||||
// connIO manages the reading and writing to a connected peer. It reads requests via its requests
|
||||
// channel, then pushes them onto the bindings or tunnels channel. It receives responses via calls
|
||||
// to Enqueue and pushes them onto the responses channel.
|
||||
type connIO struct {
|
||||
pCtx context.Context
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger slog.Logger
|
||||
decoder *json.Decoder
|
||||
updates *agpl.TrackedConn
|
||||
bindings chan<- binding
|
||||
id uuid.UUID
|
||||
// coordCtx is the parent context, that is, the context of the Coordinator
|
||||
coordCtx context.Context
|
||||
// peerCtx is the context of the connection to our peer
|
||||
peerCtx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger slog.Logger
|
||||
requests <-chan *proto.CoordinateRequest
|
||||
responses chan<- *proto.CoordinateResponse
|
||||
bindings chan<- binding
|
||||
tunnels chan<- tunnel
|
||||
auth agpl.TunnelAuth
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
|
||||
name string
|
||||
start int64
|
||||
lastWrite int64
|
||||
overwrites int64
|
||||
}
|
||||
|
||||
func newConnIO(pCtx context.Context,
|
||||
func newConnIO(coordContext context.Context,
|
||||
peerCtx context.Context,
|
||||
logger slog.Logger,
|
||||
bindings chan<- binding,
|
||||
conn net.Conn,
|
||||
tunnels chan<- tunnel,
|
||||
requests <-chan *proto.CoordinateRequest,
|
||||
responses chan<- *proto.CoordinateResponse,
|
||||
id uuid.UUID,
|
||||
name string,
|
||||
kind agpl.QueueKind,
|
||||
auth agpl.TunnelAuth,
|
||||
) *connIO {
|
||||
ctx, cancel := context.WithCancel(pCtx)
|
||||
peerCtx, cancel := context.WithCancel(peerCtx)
|
||||
now := time.Now().Unix()
|
||||
c := &connIO{
|
||||
pCtx: pCtx,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
decoder: json.NewDecoder(conn),
|
||||
updates: agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, kind),
|
||||
bindings: bindings,
|
||||
id: id,
|
||||
coordCtx: coordContext,
|
||||
peerCtx: peerCtx,
|
||||
cancel: cancel,
|
||||
logger: logger.With(slog.F("name", name)),
|
||||
requests: requests,
|
||||
responses: responses,
|
||||
bindings: bindings,
|
||||
tunnels: tunnels,
|
||||
auth: auth,
|
||||
name: name,
|
||||
start: now,
|
||||
lastWrite: now,
|
||||
}
|
||||
go c.recvLoop()
|
||||
go c.updates.SendUpdates()
|
||||
logger.Info(ctx, "serving connection")
|
||||
c.logger.Info(coordContext, "serving connection")
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *connIO) recvLoop() {
|
||||
defer func() {
|
||||
// withdraw bindings when we exit. We need to use the parent context here, since our own context might be
|
||||
// canceled, but we still need to withdraw bindings.
|
||||
// withdraw bindings & tunnels when we exit. We need to use the parent context here, since
|
||||
// our own context might be canceled, but we still need to withdraw.
|
||||
b := binding{
|
||||
bKey: bKey{
|
||||
id: c.UniqueID(),
|
||||
kind: c.Kind(),
|
||||
},
|
||||
bKey: bKey(c.UniqueID()),
|
||||
}
|
||||
if err := sendCtx(c.pCtx, c.bindings, b); err != nil {
|
||||
c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err))
|
||||
if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
|
||||
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing bindings", slog.Error(err))
|
||||
}
|
||||
t := tunnel{
|
||||
tKey: tKey{src: c.UniqueID()},
|
||||
active: false,
|
||||
}
|
||||
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
|
||||
c.logger.Debug(c.coordCtx, "parent context expired while withdrawing tunnels", slog.Error(err))
|
||||
}
|
||||
}()
|
||||
defer c.cancel()
|
||||
defer c.Close()
|
||||
for {
|
||||
var node agpl.Node
|
||||
err := c.decoder.Decode(&node)
|
||||
req, err := recvCtx(c.peerCtx, c.requests)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, io.EOF) ||
|
||||
xerrors.Is(err, io.ErrClosedPipe) ||
|
||||
xerrors.Is(err, context.Canceled) ||
|
||||
if xerrors.Is(err, context.Canceled) ||
|
||||
xerrors.Is(err, context.DeadlineExceeded) ||
|
||||
websocket.CloseStatus(err) > 0 {
|
||||
c.logger.Debug(c.ctx, "exiting recvLoop", slog.Error(err))
|
||||
xerrors.Is(err, io.EOF) {
|
||||
c.logger.Debug(c.coordCtx, "exiting io recvLoop", slog.Error(err))
|
||||
} else {
|
||||
c.logger.Error(c.ctx, "failed to decode Node update", slog.Error(err))
|
||||
c.logger.Error(c.coordCtx, "failed to receive request", slog.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
c.logger.Debug(c.ctx, "got node update", slog.F("node", node))
|
||||
b := binding{
|
||||
bKey: bKey{
|
||||
id: c.UniqueID(),
|
||||
kind: c.Kind(),
|
||||
},
|
||||
node: &node,
|
||||
}
|
||||
if err := sendCtx(c.ctx, c.bindings, b); err != nil {
|
||||
c.logger.Debug(c.ctx, "recvLoop ctx expired", slog.Error(err))
|
||||
if err := c.handleRequest(req); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
|
||||
c.logger.Debug(c.peerCtx, "got request")
|
||||
if req.UpdateSelf != nil {
|
||||
c.logger.Debug(c.peerCtx, "got node update", slog.F("node", req.UpdateSelf))
|
||||
b := binding{
|
||||
bKey: bKey(c.UniqueID()),
|
||||
node: req.UpdateSelf.Node,
|
||||
}
|
||||
if err := sendCtx(c.coordCtx, c.bindings, b); err != nil {
|
||||
c.logger.Debug(c.peerCtx, "failed to send binding", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
if req.AddTunnel != nil {
|
||||
c.logger.Debug(c.peerCtx, "got add tunnel", slog.F("tunnel", req.AddTunnel))
|
||||
dst, err := uuid.FromBytes(req.AddTunnel.Uuid)
|
||||
if err != nil {
|
||||
c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
|
||||
// this shouldn't happen unless there is a client error. Close the connection so the client
|
||||
// doesn't just happily continue thinking everything is fine.
|
||||
return err
|
||||
}
|
||||
if !c.auth.Authorize(dst) {
|
||||
return xerrors.New("unauthorized tunnel")
|
||||
}
|
||||
t := tunnel{
|
||||
tKey: tKey{
|
||||
src: c.UniqueID(),
|
||||
dst: dst,
|
||||
},
|
||||
active: true,
|
||||
}
|
||||
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
|
||||
c.logger.Debug(c.peerCtx, "failed to send add tunnel", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
if req.RemoveTunnel != nil {
|
||||
c.logger.Debug(c.peerCtx, "got remove tunnel", slog.F("tunnel", req.RemoveTunnel))
|
||||
dst, err := uuid.FromBytes(req.RemoveTunnel.Uuid)
|
||||
if err != nil {
|
||||
c.logger.Error(c.peerCtx, "unable to convert bytes to UUID", slog.Error(err))
|
||||
// this shouldn't happen unless there is a client error. Close the connection so the client
|
||||
// doesn't just happily continue thinking everything is fine.
|
||||
return err
|
||||
}
|
||||
t := tunnel{
|
||||
tKey: tKey{
|
||||
src: c.UniqueID(),
|
||||
dst: dst,
|
||||
},
|
||||
active: false,
|
||||
}
|
||||
if err := sendCtx(c.coordCtx, c.tunnels, t); err != nil {
|
||||
c.logger.Debug(c.peerCtx, "failed to send remove tunnel", slog.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
// TODO: (spikecurtis) support Disconnect
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connIO) UniqueID() uuid.UUID {
|
||||
return c.updates.UniqueID()
|
||||
return c.id
|
||||
}
|
||||
|
||||
func (c *connIO) Kind() agpl.QueueKind {
|
||||
return c.updates.Kind()
|
||||
}
|
||||
|
||||
func (c *connIO) Enqueue(n []*agpl.Node) error {
|
||||
return c.updates.Enqueue(n)
|
||||
func (c *connIO) Enqueue(resp *proto.CoordinateResponse) error {
|
||||
atomic.StoreInt64(&c.lastWrite, time.Now().Unix())
|
||||
c.mu.Lock()
|
||||
closed := c.closed
|
||||
c.mu.Unlock()
|
||||
if closed {
|
||||
return xerrors.New("connIO closed")
|
||||
}
|
||||
select {
|
||||
case <-c.peerCtx.Done():
|
||||
return c.peerCtx.Err()
|
||||
case c.responses <- resp:
|
||||
c.logger.Debug(c.peerCtx, "wrote response")
|
||||
return nil
|
||||
default:
|
||||
return agpl.ErrWouldBlock
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connIO) Name() string {
|
||||
return c.updates.Name()
|
||||
return c.name
|
||||
}
|
||||
|
||||
func (c *connIO) Stats() (start int64, lastWrite int64) {
|
||||
return c.updates.Stats()
|
||||
return c.start, atomic.LoadInt64(&c.lastWrite)
|
||||
}
|
||||
|
||||
func (c *connIO) Overwrites() int64 {
|
||||
return c.updates.Overwrites()
|
||||
return atomic.LoadInt64(&c.overwrites)
|
||||
}
|
||||
|
||||
// CoordinatorClose is used by the coordinator when closing a Queue. It
|
||||
// should skip removing itself from the coordinator.
|
||||
func (c *connIO) CoordinatorClose() error {
|
||||
c.cancel()
|
||||
return c.updates.CoordinatorClose()
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
func (c *connIO) Done() <-chan struct{} {
|
||||
return c.ctx.Done()
|
||||
return c.peerCtx.Done()
|
||||
}
|
||||
|
||||
func (c *connIO) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
c.cancel()
|
||||
return c.updates.Close()
|
||||
c.closed = true
|
||||
close(c.responses)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -27,11 +27,10 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
|
|||
t.Skip("test only with postgres")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
|
@ -75,11 +74,10 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
|
|||
t.Skip("test only with postgres")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
|
@ -124,11 +122,10 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
|
|||
t.Skip("test only with postgres")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
|
@ -189,11 +186,10 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
|
|||
t.Skip("test only with postgres")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
|
@ -243,11 +239,10 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
|
|||
t.Skip("test only with postgres")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
|
@ -299,11 +294,10 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
|
|||
t.Skip("test only with postgres")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -3,7 +3,6 @@ package tailnet_test
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
@ -17,6 +16,7 @@ import (
|
|||
"go.uber.org/goleak"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/xerrors"
|
||||
gProto "google.golang.org/protobuf/proto"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/enterprise/tailnet"
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
|
@ -52,17 +53,17 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
|
|||
defer client.close()
|
||||
client.sendNode(&agpl.Node{PreferredDERP: 10})
|
||||
require.Eventually(t, func() bool {
|
||||
clients, err := store.GetTailnetClientsForAgent(ctx, agentID)
|
||||
clients, err := store.GetTailnetTunnelPeerBindings(ctx, agentID)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
t.Fatalf("database error: %v", err)
|
||||
}
|
||||
if len(clients) == 0 {
|
||||
return false
|
||||
}
|
||||
var node agpl.Node
|
||||
err = json.Unmarshal(clients[0].Node, &node)
|
||||
node := new(proto.Node)
|
||||
err = gProto.Unmarshal(clients[0].Node, node)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10, node.PreferredDERP)
|
||||
assert.EqualValues(t, 10, node.PreferredDerp)
|
||||
return true
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
|
@ -90,17 +91,17 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
|
|||
defer agent.close()
|
||||
agent.sendNode(&agpl.Node{PreferredDERP: 10})
|
||||
require.Eventually(t, func() bool {
|
||||
agents, err := store.GetTailnetAgents(ctx, agent.id)
|
||||
agents, err := store.GetTailnetPeers(ctx, agent.id)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
t.Fatalf("database error: %v", err)
|
||||
}
|
||||
if len(agents) == 0 {
|
||||
return false
|
||||
}
|
||||
var node agpl.Node
|
||||
err = json.Unmarshal(agents[0].Node, &node)
|
||||
node := new(proto.Node)
|
||||
err = gProto.Unmarshal(agents[0].Node, node)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10, node.PreferredDERP)
|
||||
assert.EqualValues(t, 10, node.PreferredDerp)
|
||||
return true
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
err = agent.close()
|
||||
|
@ -342,39 +343,51 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
|
|||
|
||||
agent1 := newTestAgent(t, coord1, "agent1")
|
||||
defer agent1.close()
|
||||
t.Logf("agent1=%s", agent1.id)
|
||||
agent2 := newTestAgent(t, coord2, "agent2")
|
||||
defer agent2.close()
|
||||
t.Logf("agent2=%s", agent2.id)
|
||||
|
||||
client11 := newTestClient(t, coord1, agent1.id)
|
||||
defer client11.close()
|
||||
t.Logf("client11=%s", client11.id)
|
||||
client12 := newTestClient(t, coord1, agent2.id)
|
||||
defer client12.close()
|
||||
t.Logf("client12=%s", client12.id)
|
||||
client21 := newTestClient(t, coord2, agent1.id)
|
||||
defer client21.close()
|
||||
t.Logf("client21=%s", client21.id)
|
||||
client22 := newTestClient(t, coord2, agent2.id)
|
||||
defer client22.close()
|
||||
t.Logf("client22=%s", client22.id)
|
||||
|
||||
t.Logf("client11 -> Node 11")
|
||||
client11.sendNode(&agpl.Node{PreferredDERP: 11})
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 11)
|
||||
|
||||
t.Logf("client21 -> Node 21")
|
||||
client21.sendNode(&agpl.Node{PreferredDERP: 21})
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 21, 11)
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 21)
|
||||
|
||||
t.Logf("client22 -> Node 22")
|
||||
client22.sendNode(&agpl.Node{PreferredDERP: 22})
|
||||
assertEventuallyHasDERPs(ctx, t, agent2, 22)
|
||||
|
||||
t.Logf("agent2 -> Node 2")
|
||||
agent2.sendNode(&agpl.Node{PreferredDERP: 2})
|
||||
assertEventuallyHasDERPs(ctx, t, client22, 2)
|
||||
assertEventuallyHasDERPs(ctx, t, client12, 2)
|
||||
|
||||
t.Logf("client12 -> Node 12")
|
||||
client12.sendNode(&agpl.Node{PreferredDERP: 12})
|
||||
assertEventuallyHasDERPs(ctx, t, agent2, 12, 22)
|
||||
assertEventuallyHasDERPs(ctx, t, agent2, 12)
|
||||
|
||||
t.Logf("agent1 -> Node 1")
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
|
||||
assertEventuallyHasDERPs(ctx, t, client21, 1)
|
||||
assertEventuallyHasDERPs(ctx, t, client11, 1)
|
||||
|
||||
// let's close coord2
|
||||
t.Logf("close coord2")
|
||||
err = coord2.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -386,18 +399,9 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
|
|||
err = client21.recvErr(ctx, t)
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
|
||||
// agent1 will see an update that drops client21.
|
||||
// In this case the update is superfluous because client11's node hasn't changed, and agents don't deprogram clients
|
||||
// from the dataplane even if they are missing. Suppressing this kind of update would require the coordinator to
|
||||
// store all the data its sent to each connection, so we don't bother.
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 11)
|
||||
|
||||
// note that although agent2 is disconnected, client12 does NOT get an update because we suppress empty updates.
|
||||
// (Its easy to tell these are superfluous.)
|
||||
|
||||
assertEventuallyNoAgents(ctx, t, store, agent2.id)
|
||||
|
||||
// Close coord1
|
||||
t.Logf("close coord1")
|
||||
err = coord1.Close()
|
||||
require.NoError(t, err)
|
||||
// this closes agent1, client12, client11
|
||||
|
@ -541,9 +545,12 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
|
|||
Return(database.TailnetCoordinator{}, nil)
|
||||
// extra calls we don't particularly care about for this test
|
||||
mStore.EXPECT().CleanTailnetCoordinators(gomock.Any()).AnyTimes().Return(nil)
|
||||
mStore.EXPECT().GetTailnetClientsForAgent(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil)
|
||||
mStore.EXPECT().DeleteTailnetAgent(gomock.Any(), gomock.Any()).
|
||||
AnyTimes().Return(database.DeleteTailnetAgentRow{}, nil)
|
||||
mStore.EXPECT().GetTailnetTunnelPeerIDs(gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil)
|
||||
mStore.EXPECT().GetTailnetTunnelPeerBindings(gomock.Any(), gomock.Any()).
|
||||
AnyTimes().Return(nil, nil)
|
||||
mStore.EXPECT().DeleteTailnetPeer(gomock.Any(), gomock.Any()).
|
||||
AnyTimes().Return(database.DeleteTailnetPeerRow{}, nil)
|
||||
mStore.EXPECT().DeleteAllTailnetTunnels(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
|
||||
mStore.EXPECT().DeleteCoordinator(gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
|
||||
|
||||
uut, err := tailnet.NewPGCoord(ctx, logger, ps, mStore)
|
||||
|
@ -589,6 +596,34 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestPGCoordinator_BidirectionalTunnels tests when peers create tunnels to each other. We don't
|
||||
// do this now, but it's schematically possible, so we should make sure it doesn't break anything.
|
||||
func TestPGCoordinator_BidirectionalTunnels(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coordinator, err := tailnet.NewPGCoordV2(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
p1 := newTestPeer(ctx, t, coordinator, "p1")
|
||||
defer p1.close(ctx)
|
||||
p2 := newTestPeer(ctx, t, coordinator, "p2")
|
||||
defer p2.close(ctx)
|
||||
p1.addTunnel(p2.id)
|
||||
p2.addTunnel(p1.id)
|
||||
p1.updateDERP(1)
|
||||
p2.updateDERP(2)
|
||||
|
||||
p1.assertEventuallyHasDERP(p2.id, 2)
|
||||
p2.assertEventuallyHasDERP(p1.id, 1)
|
||||
}
|
||||
|
||||
type testConn struct {
|
||||
ws, serverWS net.Conn
|
||||
nodeChan chan []*agpl.Node
|
||||
|
@ -779,7 +814,7 @@ func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.Mu
|
|||
|
||||
func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
|
||||
assert.Eventually(t, func() bool {
|
||||
agents, err := store.GetTailnetAgents(ctx, agentID)
|
||||
agents, err := store.GetTailnetPeers(ctx, agentID)
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
return true
|
||||
}
|
||||
|
@ -793,7 +828,7 @@ func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.
|
|||
func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
|
||||
t.Helper()
|
||||
assert.Eventually(t, func() bool {
|
||||
clients, err := store.GetTailnetClientsForAgent(ctx, agentID)
|
||||
clients, err := store.GetTailnetTunnelPeerIDs(ctx, agentID)
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
return true
|
||||
}
|
||||
|
@ -804,6 +839,108 @@ func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store
|
|||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
type testPeer struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
t testing.TB
|
||||
id uuid.UUID
|
||||
name string
|
||||
resps <-chan *proto.CoordinateResponse
|
||||
reqs chan<- *proto.CoordinateRequest
|
||||
derps map[uuid.UUID]int32
|
||||
}
|
||||
|
||||
func newTestPeer(ctx context.Context, t testing.TB, coord agpl.CoordinatorV2, name string, id ...uuid.UUID) *testPeer {
|
||||
p := &testPeer{t: t, name: name, derps: make(map[uuid.UUID]int32)}
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
if len(id) > 1 {
|
||||
t.Fatal("too many")
|
||||
}
|
||||
if len(id) == 1 {
|
||||
p.id = id[0]
|
||||
} else {
|
||||
p.id = uuid.New()
|
||||
}
|
||||
// SingleTailnetTunnelAuth allows connections to arbitrary peers
|
||||
p.reqs, p.resps = coord.Coordinate(p.ctx, p.id, name, agpl.SingleTailnetTunnelAuth{})
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *testPeer) addTunnel(other uuid.UUID) {
|
||||
p.t.Helper()
|
||||
req := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Uuid: agpl.UUIDToByteSlice(other)}}
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
p.t.Errorf("timeout adding tunnel for %s", p.name)
|
||||
return
|
||||
case p.reqs <- req:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *testPeer) updateDERP(derp int32) {
|
||||
p.t.Helper()
|
||||
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: derp}}}
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
p.t.Errorf("timeout updating node for %s", p.name)
|
||||
return
|
||||
case p.reqs <- req:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *testPeer) assertEventuallyHasDERP(other uuid.UUID, derp int32) {
|
||||
p.t.Helper()
|
||||
for {
|
||||
d, ok := p.derps[other]
|
||||
if ok && d == derp {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
p.t.Errorf("timeout waiting for response for %s", p.name)
|
||||
return
|
||||
case resp, ok := <-p.resps:
|
||||
if !ok {
|
||||
p.t.Errorf("responses closed for %s", p.name)
|
||||
return
|
||||
}
|
||||
for _, update := range resp.PeerUpdates {
|
||||
id, err := uuid.FromBytes(update.Uuid)
|
||||
if !assert.NoError(p.t, err) {
|
||||
return
|
||||
}
|
||||
switch update.Kind {
|
||||
case proto.CoordinateResponse_PeerUpdate_NODE:
|
||||
p.derps[id] = update.Node.PreferredDerp
|
||||
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
|
||||
delete(p.derps, id)
|
||||
default:
|
||||
p.t.Errorf("unhandled update kind %s", update.Kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *testPeer) close(ctx context.Context) {
|
||||
p.t.Helper()
|
||||
p.cancel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
p.t.Errorf("timeout waiting for responses to close for %s", p.name)
|
||||
return
|
||||
case _, ok := <-p.resps:
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type fakeCoordinator struct {
|
||||
ctx context.Context
|
||||
t *testing.T
|
||||
|
@ -819,12 +956,15 @@ func (c *fakeCoordinator) heartbeat() {
|
|||
|
||||
func (c *fakeCoordinator) agentNode(agentID uuid.UUID, node *agpl.Node) {
|
||||
c.t.Helper()
|
||||
nodeRaw, err := json.Marshal(node)
|
||||
pNode, err := agpl.NodeToProto(node)
|
||||
require.NoError(c.t, err)
|
||||
_, err = c.store.UpsertTailnetAgent(c.ctx, database.UpsertTailnetAgentParams{
|
||||
nodeRaw, err := gProto.Marshal(pNode)
|
||||
require.NoError(c.t, err)
|
||||
_, err = c.store.UpsertTailnetPeer(c.ctx, database.UpsertTailnetPeerParams{
|
||||
ID: agentID,
|
||||
CoordinatorID: c.id,
|
||||
Node: nodeRaw,
|
||||
Status: database.TailnetStatusOk,
|
||||
})
|
||||
require.NoError(c.t, err)
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
// Coordinator exchanges nodes with agents to establish connections.
|
||||
|
@ -48,6 +49,17 @@ type Coordinator interface {
|
|||
ServeMultiAgent(id uuid.UUID) MultiAgentConn
|
||||
}
|
||||
|
||||
// CoordinatorV2 is the interface for interacting with the coordinator via the 2.0 tailnet API.
|
||||
type CoordinatorV2 interface {
|
||||
// ServeHTTPDebug serves a debug webpage that shows the internal state of
|
||||
// the coordinator.
|
||||
ServeHTTPDebug(w http.ResponseWriter, r *http.Request)
|
||||
// Node returns a node by peer ID, if known to the coordinator. Returns nil if unknown.
|
||||
Node(id uuid.UUID) *Node
|
||||
Close() error
|
||||
Coordinate(ctx context.Context, id uuid.UUID, name string, a TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse)
|
||||
}
|
||||
|
||||
// Node represents a node in the network.
|
||||
type Node struct {
|
||||
// ID is used to identify the connection.
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
package proto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
gProto "google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// Equal returns true if the nodes have the same contents
|
||||
func (s *Node) Equal(o *Node) (bool, error) {
|
||||
sBytes, err := gProto.Marshal(s)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
oBytes, err := gProto.Marshal(o)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return bytes.Equal(sBytes, oBytes), nil
|
||||
}
|
|
@ -13,9 +13,14 @@ import (
|
|||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
// WriteTimeout is the amount of time we wait to write a node update to a connection before we declare it hung.
|
||||
// It is exported so that tests can use it.
|
||||
const WriteTimeout = time.Second * 5
|
||||
const (
|
||||
// WriteTimeout is the amount of time we wait to write a node update to a connection before we
|
||||
// declare it hung. It is exported so that tests can use it.
|
||||
WriteTimeout = time.Second * 5
|
||||
// ResponseBufferSize is the max number of responses to buffer per connection before we start
|
||||
// dropping updates
|
||||
ResponseBufferSize = 512
|
||||
)
|
||||
|
||||
type TrackedConn struct {
|
||||
ctx context.Context
|
||||
|
@ -48,7 +53,7 @@ func NewTrackedConn(ctx context.Context, cancel func(),
|
|||
// coordinator mutex while queuing. Node updates don't
|
||||
// come quickly, so 512 should be plenty for all but
|
||||
// the most pathological cases.
|
||||
updates := make(chan []*Node, 512)
|
||||
updates := make(chan []*Node, ResponseBufferSize)
|
||||
now := time.Now().Unix()
|
||||
return &TrackedConn{
|
||||
ctx: ctx,
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
package tailnet
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type TunnelAuth interface {
|
||||
Authorize(dst uuid.UUID) bool
|
||||
}
|
||||
|
||||
// SingleTailnetTunnelAuth allows all tunnels, since Coderd and wsproxy are allowed to initiate a tunnel to any agent
|
||||
type SingleTailnetTunnelAuth struct{}
|
||||
|
||||
func (SingleTailnetTunnelAuth) Authorize(uuid.UUID) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// ClientTunnelAuth allows connecting to a single, given agent
|
||||
type ClientTunnelAuth struct {
|
||||
AgentID uuid.UUID
|
||||
}
|
||||
|
||||
func (c ClientTunnelAuth) Authorize(dst uuid.UUID) bool {
|
||||
return c.AgentID == dst
|
||||
}
|
||||
|
||||
// AgentTunnelAuth disallows all tunnels, since agents are not allowed to initiate their own tunnels
|
||||
type AgentTunnelAuth struct{}
|
||||
|
||||
func (AgentTunnelAuth) Authorize(uuid.UUID) bool {
|
||||
return false
|
||||
}
|
Loading…
Reference in New Issue