From e5d911462fb87a55f811d1bfc65c26bd807542bb Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Fri, 1 Mar 2024 09:02:33 -0600 Subject: [PATCH] fix(tailnet): enforce valid agent and client addresses (#12197) This adds the ability for `TunnelAuth` to also authorize incoming wireguard node IPs, preventing agents from reporting anything other than their static IP generated from the agent ID. --- agent/agenttest/client.go | 3 +- coderd/workspaceagentsrpc.go | 2 +- codersdk/workspaceagents_internal_test.go | 2 +- enterprise/tailnet/connio.go | 12 +- enterprise/tailnet/pgcoord.go | 2 +- enterprise/tailnet/pgcoord_test.go | 140 +++++++++++++++++ enterprise/tailnet/workspaceproxy.go | 2 +- .../wsproxy/wsproxysdk/wsproxysdk_test.go | 2 +- tailnet/coordinator.go | 22 +-- tailnet/coordinator_test.go | 146 +++++++++++++++++- tailnet/peer.go | 2 +- tailnet/service.go | 4 +- tailnet/service_test.go | 4 +- tailnet/tailnettest/coordinatormock.go | 2 +- tailnet/tailnettest/tailnettest.go | 4 +- tailnet/test/peer.go | 2 +- tailnet/tunnel.go | 87 +++++++++-- 17 files changed, 389 insertions(+), 49 deletions(-) diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go index 040edddb6f..22eba14483 100644 --- a/agent/agenttest/client.go +++ b/agent/agenttest/client.go @@ -108,11 +108,10 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) { c.t.Cleanup(c.LastWorkspaceAgent) serveCtx, cancel := context.WithCancel(ctx) c.t.Cleanup(cancel) - auth := tailnet.AgentTunnelAuth{} streamID := tailnet.StreamID{ Name: "agenttest", ID: c.agentID, - Auth: auth, + Auth: tailnet.AgentCoordinateeAuth{ID: c.agentID}, } serveCtx = tailnet.WithStreamID(serveCtx, streamID) go func() { diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index a62286a9c9..2656adf374 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -155,7 +155,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { streamID := 
tailnet.StreamID{ Name: fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name), ID: workspaceAgent.ID, - Auth: tailnet.AgentTunnelAuth{}, + Auth: tailnet.AgentCoordinateeAuth{ID: workspaceAgent.ID}, } ctx = tailnet.WithStreamID(ctx, streamID) ctx = agentapi.WithAPIVersion(ctx, version) diff --git a/codersdk/workspaceagents_internal_test.go b/codersdk/workspaceagents_internal_test.go index c71f7d440c..0228cee1e2 100644 --- a/codersdk/workspaceagents_internal_test.go +++ b/codersdk/workspaceagents_internal_test.go @@ -54,7 +54,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) { err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{ Name: "client", ID: clientID, - Auth: tailnet.ClientTunnelAuth{AgentID: agentID}, + Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID}, }) assert.NoError(t, err) })) diff --git a/enterprise/tailnet/connio.go b/enterprise/tailnet/connio.go index 6e98dfec4c..2e64bb4bd6 100644 --- a/enterprise/tailnet/connio.go +++ b/enterprise/tailnet/connio.go @@ -30,7 +30,7 @@ type connIO struct { responses chan<- *proto.CoordinateResponse bindings chan<- binding tunnels chan<- tunnel - auth agpl.TunnelAuth + auth agpl.CoordinateeAuth mu sync.Mutex closed bool disconnected bool @@ -50,7 +50,7 @@ func newConnIO(coordContext context.Context, responses chan<- *proto.CoordinateResponse, id uuid.UUID, name string, - auth agpl.TunnelAuth, + auth agpl.CoordinateeAuth, ) *connIO { peerCtx, cancel := context.WithCancel(peerCtx) now := time.Now().Unix() @@ -126,6 +126,11 @@ var errDisconnect = xerrors.New("graceful disconnect") func (c *connIO) handleRequest(req *proto.CoordinateRequest) error { c.logger.Debug(c.peerCtx, "got request") + err := c.auth.Authorize(req) + if err != nil { + return xerrors.Errorf("authorize request: %w", err) + } + if req.UpdateSelf != nil { c.logger.Debug(c.peerCtx, "got node update", slog.F("node", req.UpdateSelf)) b := binding{ @@ -147,9 +152,6 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) 
error { // doesn't just happily continue thinking everything is fine. return err } - if !c.auth.Authorize(dst) { - return xerrors.New("unauthorized tunnel") - } t := tunnel{ tKey: tKey{ src: c.UniqueID(), diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index d9cd8d37b3..aabb21eef6 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -224,7 +224,7 @@ func (c *pgCoord) Close() error { } func (c *pgCoord) Coordinate( - ctx context.Context, id uuid.UUID, name string, a agpl.TunnelAuth, + ctx context.Context, id uuid.UUID, name string, a agpl.CoordinateeAuth, ) ( chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse, ) { diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index c26418ff66..c11cc9b630 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -5,10 +5,12 @@ import ( "database/sql" "io" "net" + "net/netip" "sync" "testing" "time" + "github.com/coder/coder/v2/codersdk" agpltest "github.com/coder/coder/v2/tailnet/test" "github.com/google/uuid" @@ -113,6 +115,144 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) { assertEventuallyLost(ctx, t, store, agent.id) } +func TestPGCoordinatorSingle_AgentInvalidIP(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + store, ps := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store) + require.NoError(t, err) + defer coordinator.Close() + + agent := newTestAgent(t, coordinator, "agent") + defer agent.close() + agent.sendNode(&agpl.Node{ + Addresses: []netip.Prefix{ + netip.PrefixFrom(agpl.IP(), 128), + }, + PreferredDERP: 10, + }) + + // The agent connection should be closed immediately after sending an invalid addr + 
testutil.RequireRecvCtx(ctx, t, agent.closeChan) + assertEventuallyLost(ctx, t, store, agent.id) +} + +func TestPGCoordinatorSingle_AgentInvalidIPBits(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + store, ps := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store) + require.NoError(t, err) + defer coordinator.Close() + + agent := newTestAgent(t, coordinator, "agent") + defer agent.close() + agent.sendNode(&agpl.Node{ + Addresses: []netip.Prefix{ + netip.PrefixFrom(agpl.IPFromUUID(agent.id), 64), + }, + PreferredDERP: 10, + }) + + // The agent connection should be closed immediately after sending an invalid addr + testutil.RequireRecvCtx(ctx, t, agent.closeChan) + assertEventuallyLost(ctx, t, store, agent.id) +} + +func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + store, ps := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store) + require.NoError(t, err) + defer coordinator.Close() + + agent := newTestAgent(t, coordinator, "agent") + defer agent.close() + agent.sendNode(&agpl.Node{ + Addresses: []netip.Prefix{ + netip.PrefixFrom(agpl.IPFromUUID(agent.id), 128), + }, + PreferredDERP: 10, + }) + require.Eventually(t, func() bool { + agents, err := store.GetTailnetPeers(ctx, agent.id) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + t.Fatalf("database error: %v", err) + } + if len(agents) == 0 { + return false + } + node := new(proto.Node) + err = gProto.Unmarshal(agents[0].Node, node) + assert.NoError(t, err) + 
assert.EqualValues(t, 10, node.PreferredDerp) + return true + }, testutil.WaitShort, testutil.IntervalFast) + err = agent.close() + require.NoError(t, err) + <-agent.errChan + <-agent.closeChan + assertEventuallyLost(ctx, t, store, agent.id) +} + +func TestPGCoordinatorSingle_AgentValidIPLegacy(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + store, ps := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store) + require.NoError(t, err) + defer coordinator.Close() + + agent := newTestAgent(t, coordinator, "agent") + defer agent.close() + agent.sendNode(&agpl.Node{ + Addresses: []netip.Prefix{ + netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128), + }, + PreferredDERP: 10, + }) + require.Eventually(t, func() bool { + agents, err := store.GetTailnetPeers(ctx, agent.id) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + t.Fatalf("database error: %v", err) + } + if len(agents) == 0 { + return false + } + node := new(proto.Node) + err = gProto.Unmarshal(agents[0].Node, node) + assert.NoError(t, err) + assert.EqualValues(t, 10, node.PreferredDerp) + return true + }, testutil.WaitShort, testutil.IntervalFast) + err = agent.close() + require.NoError(t, err) + <-agent.errChan + <-agent.closeChan + assertEventuallyLost(ctx, t, store, agent.id) +} + func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) { t.Parallel() if !dbtestutil.WillUsePostgres() { diff --git a/enterprise/tailnet/workspaceproxy.go b/enterprise/tailnet/workspaceproxy.go index b7daabd891..39126115b6 100644 --- a/enterprise/tailnet/workspaceproxy.go +++ b/enterprise/tailnet/workspaceproxy.go @@ -52,7 +52,7 @@ func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version strin sub := coord.ServeMultiAgent(id) return ServeWorkspaceProxy(ctx, conn, 
sub) case 2: - auth := agpl.SingleTailnetTunnelAuth{} + auth := agpl.SingleTailnetCoordinateeAuth{} streamID := agpl.StreamID{ Name: id.String(), ID: id, diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go index 11fad78c1f..870d06b71d 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go @@ -182,7 +182,7 @@ func TestDialCoordinator(t *testing.T) { // avoid blocking reqs := make(chan *proto.CoordinateRequest, 100) resps := make(chan *proto.CoordinateResponse, 100) - mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetTunnelAuth{}). + mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetCoordinateeAuth{}). Times(1). Return(reqs, resps) diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 842a6bcbfa..ce9c8e99b2 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -59,7 +59,7 @@ type CoordinatorV2 interface { // Node returns a node by peer ID, if known to the coordinator. Returns nil if unknown. Node(id uuid.UUID) *Node Close() error - Coordinate(ctx context.Context, id uuid.UUID, name string, a TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) + Coordinate(ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) } // Node represents a node in the network. 
@@ -247,10 +247,10 @@ func NewInMemoryCoordination( ) Coordination { thisID := agentID logger = logger.With(slog.F("agent_id", agentID)) - var auth TunnelAuth = AgentTunnelAuth{} + var auth CoordinateeAuth = AgentCoordinateeAuth{ID: agentID} if clientID != uuid.Nil { // this is a client connection - auth = ClientTunnelAuth{AgentID: agentID} + auth = ClientCoordinateeAuth{AgentID: agentID} logger = logger.With(slog.F("client_id", clientID)) thisID = clientID } @@ -420,7 +420,7 @@ type coordinator struct { } func (c *coordinator) Coordinate( - ctx context.Context, id uuid.UUID, name string, a TunnelAuth, + ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth, ) ( chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse, ) { @@ -476,7 +476,7 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAgentConn { logger = logger.With(slog.F("client_id", id)).Named("multiagent") ctx, cancel := context.WithCancel(context.Background()) - reqs, resps := c.Coordinate(ctx, id, id.String(), SingleTailnetTunnelAuth{}) + reqs, resps := c.Coordinate(ctx, id, id.String(), SingleTailnetCoordinateeAuth{}) m := (&MultiAgent{ ID: id, OnSubscribe: func(enq Queue, agent uuid.UUID) error { @@ -584,7 +584,7 @@ func ServeClientV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, con }() ctx, cancel := context.WithCancel(ctx) defer cancel() - reqs, resps := c.Coordinate(ctx, id, id.String(), ClientTunnelAuth{AgentID: agent}) + reqs, resps := c.Coordinate(ctx, id, id.String(), ClientCoordinateeAuth{AgentID: agent}) err := SendCtx(ctx, reqs, &proto.CoordinateRequest{ AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}, }) @@ -611,6 +611,11 @@ func (c *core) handleRequest(p *peer, req *proto.CoordinateRequest) error { if !ok || pr != p { return ErrAlreadyRemoved } + + if err := pr.auth.Authorize(req); err != nil { + return xerrors.Errorf("authorize request: 
%w", err) + } + if req.UpdateSelf != nil { err := c.nodeUpdateLocked(p, req.UpdateSelf.Node) if xerrors.Is(err, ErrAlreadyRemoved) || xerrors.Is(err, ErrClosed) { @@ -683,9 +688,6 @@ func (c *core) updateTunnelPeersLocked(id uuid.UUID, n *proto.Node, k proto.Coor } func (c *core) addTunnelLocked(src *peer, dstID uuid.UUID) error { - if !src.auth.Authorize(dstID) { - return xerrors.Errorf("src %s is not allowed to tunnel to %s", src.id, dstID) - } c.tunnels.add(src.id, dstID) c.logger.Debug(context.Background(), "adding tunnel", slog.F("src_id", src.id), @@ -813,7 +815,7 @@ func ServeAgentV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, conn ctx, cancel := context.WithCancel(ctx) defer cancel() logger.Debug(ctx, "starting new agent connection") - reqs, resps := c.Coordinate(ctx, id, name, AgentTunnelAuth{}) + reqs, resps := c.Coordinate(ctx, id, name, AgentCoordinateeAuth{ID: id}) tc := NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, QueueKindAgent) go tc.SendUpdates() go v1RespLoop(ctx, cancel, logger, tc, resps) diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index 72a6591051..1d6a265d42 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -6,6 +6,7 @@ import ( "net" "net/http" "net/http/httptest" + "net/netip" "sync" "sync/atomic" "testing" @@ -21,6 +22,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/tailnet/tailnettest" @@ -50,7 +52,12 @@ func TestCoordinator(t *testing.T) { assert.NoError(t, err) close(closeChan) }() - sendNode(&tailnet.Node{}) + sendNode(&tailnet.Node{ + Addresses: []netip.Prefix{ + netip.PrefixFrom(tailnet.IP(), 128), + }, + PreferredDERP: 10, + }) require.Eventually(t, func() bool { return coordinator.Node(id) != nil }, testutil.WaitShort, testutil.IntervalFast) @@ -60,6 +67,37 @@ func TestCoordinator(t *testing.T) { _ = 
testutil.RequireRecvCtx(ctx, t, closeChan) }) + t.Run("ClientWithoutAgent_InvalidIPBits", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctx := testutil.Context(t, testutil.WaitMedium) + coordinator := tailnet.NewCoordinator(logger) + defer func() { + err := coordinator.Close() + require.NoError(t, err) + }() + client, server := net.Pipe() + sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { + return nil + }) + id := uuid.New() + closeChan := make(chan struct{}) + go func() { + err := coordinator.ServeClient(server, id, uuid.New()) + assert.NoError(t, err) + close(closeChan) + }() + sendNode(&tailnet.Node{ + Addresses: []netip.Prefix{ + netip.PrefixFrom(tailnet.IP(), 64), + }, + PreferredDERP: 10, + }) + + _ = testutil.RequireRecvCtx(ctx, t, errChan) + _ = testutil.RequireRecvCtx(ctx, t, closeChan) + }) + t.Run("AgentWithoutClients", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) @@ -80,7 +118,12 @@ func TestCoordinator(t *testing.T) { assert.NoError(t, err) close(closeChan) }() - sendNode(&tailnet.Node{}) + sendNode(&tailnet.Node{ + Addresses: []netip.Prefix{ + netip.PrefixFrom(tailnet.IPFromUUID(id), 128), + }, + PreferredDERP: 10, + }) require.Eventually(t, func() bool { return coordinator.Node(id) != nil }, testutil.WaitShort, testutil.IntervalFast) @@ -90,6 +133,101 @@ func TestCoordinator(t *testing.T) { _ = testutil.RequireRecvCtx(ctx, t, closeChan) }) + t.Run("AgentWithoutClients_ValidIPLegacy", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx := testutil.Context(t, testutil.WaitMedium) + coordinator := tailnet.NewCoordinator(logger) + defer func() { + err := coordinator.Close() + require.NoError(t, err) + }() + client, server := net.Pipe() + sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { + return nil 
+ }) + id := uuid.New() + closeChan := make(chan struct{}) + go func() { + err := coordinator.ServeAgent(server, id, "") + assert.NoError(t, err) + close(closeChan) + }() + sendNode(&tailnet.Node{ + Addresses: []netip.Prefix{ + netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128), + }, + PreferredDERP: 10, + }) + require.Eventually(t, func() bool { + return coordinator.Node(id) != nil + }, testutil.WaitShort, testutil.IntervalFast) + err := client.Close() + require.NoError(t, err) + _ = testutil.RequireRecvCtx(ctx, t, errChan) + _ = testutil.RequireRecvCtx(ctx, t, closeChan) + }) + + t.Run("AgentWithoutClients_InvalidIP", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctx := testutil.Context(t, testutil.WaitMedium) + coordinator := tailnet.NewCoordinator(logger) + defer func() { + err := coordinator.Close() + require.NoError(t, err) + }() + client, server := net.Pipe() + sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { + return nil + }) + id := uuid.New() + closeChan := make(chan struct{}) + go func() { + err := coordinator.ServeAgent(server, id, "") + assert.NoError(t, err) + close(closeChan) + }() + sendNode(&tailnet.Node{ + Addresses: []netip.Prefix{ + netip.PrefixFrom(tailnet.IP(), 128), + }, + PreferredDERP: 10, + }) + _ = testutil.RequireRecvCtx(ctx, t, errChan) + _ = testutil.RequireRecvCtx(ctx, t, closeChan) + }) + + t.Run("AgentWithoutClients_InvalidBits", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + ctx := testutil.Context(t, testutil.WaitMedium) + coordinator := tailnet.NewCoordinator(logger) + defer func() { + err := coordinator.Close() + require.NoError(t, err) + }() + client, server := net.Pipe() + sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { + return nil + }) + id := uuid.New() + closeChan := make(chan 
struct{}) + go func() { + err := coordinator.ServeAgent(server, id, "") + assert.NoError(t, err) + close(closeChan) + }() + sendNode(&tailnet.Node{ + Addresses: []netip.Prefix{ + netip.PrefixFrom(tailnet.IPFromUUID(id), 64), + }, + PreferredDERP: 10, + }) + _ = testutil.RequireRecvCtx(ctx, t, errChan) + _ = testutil.RequireRecvCtx(ctx, t, closeChan) + }) + t.Run("AgentWithClient", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) @@ -435,7 +573,7 @@ func TestInMemoryCoordination(t *testing.T) { reqs := make(chan *proto.CoordinateRequest, 100) resps := make(chan *proto.CoordinateResponse, 100) - mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}). + mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}). Times(1).Return(reqs, resps) uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn) @@ -462,7 +600,7 @@ func TestRemoteCoordination(t *testing.T) { reqs := make(chan *proto.CoordinateRequest, 100) resps := make(chan *proto.CoordinateResponse, 100) - mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}). + mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}). 
Times(1).Return(reqs, resps) var coord tailnet.Coordinator = mCoord diff --git a/tailnet/peer.go b/tailnet/peer.go index 1b9300fa30..eadc882f5a 100644 --- a/tailnet/peer.go +++ b/tailnet/peer.go @@ -19,7 +19,7 @@ type peer struct { node *proto.Node resps chan<- *proto.CoordinateResponse reqs <-chan *proto.CoordinateRequest - auth TunnelAuth + auth CoordinateeAuth sent map[uuid.UUID]*proto.Node name string diff --git a/tailnet/service.go b/tailnet/service.go index 4af8d6913c..e465508ce9 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -29,7 +29,7 @@ type streamIDContextKey struct{} type StreamID struct { Name string ID uuid.UUID - Auth TunnelAuth + Auth CoordinateeAuth } func WithStreamID(ctx context.Context, streamID StreamID) context.Context { @@ -91,7 +91,7 @@ func (s *ClientService) ServeClient(ctx context.Context, version string, conn ne coord := *(s.CoordPtr.Load()) return coord.ServeClient(conn, id, agent) case 2: - auth := ClientTunnelAuth{AgentID: agent} + auth := ClientCoordinateeAuth{AgentID: agent} streamID := StreamID{ Name: "client", ID: id, diff --git a/tailnet/service_test.go b/tailnet/service_test.go index 4254122ba1..572d5ad2d7 100644 --- a/tailnet/service_test.go +++ b/tailnet/service_test.go @@ -64,7 +64,9 @@ func TestClientService_ServeClient_V2(t *testing.T) { require.NotNil(t, call) require.Equal(t, call.ID, clientID) require.Equal(t, call.Name, "client") - require.True(t, call.Auth.Authorize(agentID)) + require.NoError(t, call.Auth.Authorize(&proto.CoordinateRequest{ + AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}, + })) req := testutil.RequireRecvCtx(ctx, t, call.Reqs) require.Equal(t, int32(11), req.GetUpdateSelf().GetNode().GetPreferredDerp()) diff --git a/tailnet/tailnettest/coordinatormock.go b/tailnet/tailnettest/coordinatormock.go index b0ae36d3f8..6225b8c86a 100644 --- a/tailnet/tailnettest/coordinatormock.go +++ b/tailnet/tailnettest/coordinatormock.go @@ -59,7 +59,7 @@ func (mr *MockCoordinatorMockRecorder) 
Close() *gomock.Call { } // Coordinate mocks base method. -func (m *MockCoordinator) Coordinate(arg0 context.Context, arg1 uuid.UUID, arg2 string, arg3 tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) { +func (m *MockCoordinator) Coordinate(arg0 context.Context, arg1 uuid.UUID, arg2 string, arg3 tailnet.CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Coordinate", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(chan<- *proto.CoordinateRequest) diff --git a/tailnet/tailnettest/tailnettest.go b/tailnet/tailnettest/tailnettest.go index 256fd58139..9b34c9bd3d 100644 --- a/tailnet/tailnettest/tailnettest.go +++ b/tailnet/tailnettest/tailnettest.go @@ -312,7 +312,7 @@ func (*FakeCoordinator) ServeMultiAgent(uuid.UUID) tailnet.MultiAgentConn { panic("unimplemented") } -func (f *FakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) { +func (f *FakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) { reqs := make(chan *proto.CoordinateRequest, 100) resps := make(chan *proto.CoordinateResponse, 100) f.CoordinateCalls <- &FakeCoordinate{ @@ -337,7 +337,7 @@ type FakeCoordinate struct { Ctx context.Context ID uuid.UUID Name string - Auth tailnet.TunnelAuth + Auth tailnet.CoordinateeAuth Reqs chan *proto.CoordinateRequest Resps chan *proto.CoordinateResponse } diff --git a/tailnet/test/peer.go b/tailnet/test/peer.go index fb01cb4d1d..87d0b586ed 100644 --- a/tailnet/test/peer.go +++ b/tailnet/test/peer.go @@ -40,7 +40,7 @@ func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, nam p.ID = uuid.New() } // SingleTailnetTunnelAuth allows connections to arbitrary peers - p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, 
tailnet.SingleTailnetTunnelAuth{}) + p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, tailnet.SingleTailnetCoordinateeAuth{}) return p } diff --git a/tailnet/tunnel.go b/tailnet/tunnel.go index 6fe36ee419..bc5becbc94 100644 --- a/tailnet/tunnel.go +++ b/tailnet/tunnel.go @@ -1,32 +1,89 @@ package tailnet -import "github.com/google/uuid" +import ( + "net/netip" -type TunnelAuth interface { - Authorize(dst uuid.UUID) bool + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/tailnet/proto" +) + +var legacyWorkspaceAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4") + +type CoordinateeAuth interface { + Authorize(req *proto.CoordinateRequest) error } -// SingleTailnetTunnelAuth allows all tunnels, since Coderd and wsproxy are allowed to initiate a tunnel to any agent -type SingleTailnetTunnelAuth struct{} +// SingleTailnetCoordinateeAuth allows all tunnels, since Coderd and wsproxy are allowed to initiate a tunnel to any agent +type SingleTailnetCoordinateeAuth struct{} -func (SingleTailnetTunnelAuth) Authorize(uuid.UUID) bool { - return true +func (SingleTailnetCoordinateeAuth) Authorize(*proto.CoordinateRequest) error { + return nil } -// ClientTunnelAuth allows connecting to a single, given agent -type ClientTunnelAuth struct { +// ClientCoordinateeAuth allows connecting to a single, given agent +type ClientCoordinateeAuth struct { AgentID uuid.UUID } -func (c ClientTunnelAuth) Authorize(dst uuid.UUID) bool { - return c.AgentID == dst +func (c ClientCoordinateeAuth) Authorize(req *proto.CoordinateRequest) error { + if tun := req.GetAddTunnel(); tun != nil { + uid, err := uuid.FromBytes(tun.Id) + if err != nil { + return xerrors.Errorf("parse add tunnel id: %w", err) + } + + if c.AgentID != uid { + return xerrors.Errorf("invalid agent id, expected %s, got %s", c.AgentID.String(), uid.String()) + } + } + + if upd := req.GetUpdateSelf(); upd != nil { + for _, addrStr := range upd.Node.Addresses { + pre, err := 
netip.ParsePrefix(addrStr) + if err != nil { + return xerrors.Errorf("parse node address: %w", err) + } + + if pre.Bits() != 128 { + return xerrors.Errorf("invalid address bits, expected 128, got %d", pre.Bits()) + } + } + } + + return nil } -// AgentTunnelAuth disallows all tunnels, since agents are not allowed to initiate their own tunnels -type AgentTunnelAuth struct{} +// AgentCoordinateeAuth disallows all tunnels, since agents are not allowed to initiate their own tunnels, and only accepts /128 node addresses equal to the IP derived from the agent ID or the legacy workspace agent IP +type AgentCoordinateeAuth struct { + ID uuid.UUID +} -func (AgentTunnelAuth) Authorize(uuid.UUID) bool { - return false +func (a AgentCoordinateeAuth) Authorize(req *proto.CoordinateRequest) error { + if tun := req.GetAddTunnel(); tun != nil { + return xerrors.New("agents cannot open tunnels") + } + + if upd := req.GetUpdateSelf(); upd != nil { + for _, addrStr := range upd.Node.Addresses { + pre, err := netip.ParsePrefix(addrStr) + if err != nil { + return xerrors.Errorf("parse node address: %w", err) + } + + if pre.Bits() != 128 { + return xerrors.Errorf("invalid address bits, expected 128, got %d", pre.Bits()) + } + + if IPFromUUID(a.ID).Compare(pre.Addr()) != 0 && + legacyWorkspaceAgentIP.Compare(pre.Addr()) != 0 { + return xerrors.Errorf("invalid node address, got %s", pre.Addr().String()) + } + } + } + + return nil } // tunnelStore contains tunnel information and allows querying it. It is not threadsafe and all