fix(tailnet): enforce valid agent and client addresses (#12197)

This adds the ability for `TunnelAuth` to also authorize incoming wireguard node IPs, preventing agents from reporting anything other than their static IP generated from the agent ID.
This commit is contained in:
Colin Adler 2024-03-01 09:02:33 -06:00 committed by GitHub
parent 7fbca62e08
commit e5d911462f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 389 additions and 49 deletions

View File

@ -108,11 +108,10 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) {
c.t.Cleanup(c.LastWorkspaceAgent) c.t.Cleanup(c.LastWorkspaceAgent)
serveCtx, cancel := context.WithCancel(ctx) serveCtx, cancel := context.WithCancel(ctx)
c.t.Cleanup(cancel) c.t.Cleanup(cancel)
auth := tailnet.AgentTunnelAuth{}
streamID := tailnet.StreamID{ streamID := tailnet.StreamID{
Name: "agenttest", Name: "agenttest",
ID: c.agentID, ID: c.agentID,
Auth: auth, Auth: tailnet.AgentCoordinateeAuth{ID: c.agentID},
} }
serveCtx = tailnet.WithStreamID(serveCtx, streamID) serveCtx = tailnet.WithStreamID(serveCtx, streamID)
go func() { go func() {

View File

@ -155,7 +155,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
streamID := tailnet.StreamID{ streamID := tailnet.StreamID{
Name: fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name), Name: fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name),
ID: workspaceAgent.ID, ID: workspaceAgent.ID,
Auth: tailnet.AgentTunnelAuth{}, Auth: tailnet.AgentCoordinateeAuth{ID: workspaceAgent.ID},
} }
ctx = tailnet.WithStreamID(ctx, streamID) ctx = tailnet.WithStreamID(ctx, streamID)
ctx = agentapi.WithAPIVersion(ctx, version) ctx = agentapi.WithAPIVersion(ctx, version)

View File

@ -54,7 +54,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{ err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
Name: "client", Name: "client",
ID: clientID, ID: clientID,
Auth: tailnet.ClientTunnelAuth{AgentID: agentID}, Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
}) })
assert.NoError(t, err) assert.NoError(t, err)
})) }))

View File

@ -30,7 +30,7 @@ type connIO struct {
responses chan<- *proto.CoordinateResponse responses chan<- *proto.CoordinateResponse
bindings chan<- binding bindings chan<- binding
tunnels chan<- tunnel tunnels chan<- tunnel
auth agpl.TunnelAuth auth agpl.CoordinateeAuth
mu sync.Mutex mu sync.Mutex
closed bool closed bool
disconnected bool disconnected bool
@ -50,7 +50,7 @@ func newConnIO(coordContext context.Context,
responses chan<- *proto.CoordinateResponse, responses chan<- *proto.CoordinateResponse,
id uuid.UUID, id uuid.UUID,
name string, name string,
auth agpl.TunnelAuth, auth agpl.CoordinateeAuth,
) *connIO { ) *connIO {
peerCtx, cancel := context.WithCancel(peerCtx) peerCtx, cancel := context.WithCancel(peerCtx)
now := time.Now().Unix() now := time.Now().Unix()
@ -126,6 +126,11 @@ var errDisconnect = xerrors.New("graceful disconnect")
func (c *connIO) handleRequest(req *proto.CoordinateRequest) error { func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
c.logger.Debug(c.peerCtx, "got request") c.logger.Debug(c.peerCtx, "got request")
err := c.auth.Authorize(req)
if err != nil {
return xerrors.Errorf("authorize request: %w", err)
}
if req.UpdateSelf != nil { if req.UpdateSelf != nil {
c.logger.Debug(c.peerCtx, "got node update", slog.F("node", req.UpdateSelf)) c.logger.Debug(c.peerCtx, "got node update", slog.F("node", req.UpdateSelf))
b := binding{ b := binding{
@ -147,9 +152,6 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
// doesn't just happily continue thinking everything is fine. // doesn't just happily continue thinking everything is fine.
return err return err
} }
if !c.auth.Authorize(dst) {
return xerrors.New("unauthorized tunnel")
}
t := tunnel{ t := tunnel{
tKey: tKey{ tKey: tKey{
src: c.UniqueID(), src: c.UniqueID(),

View File

@ -224,7 +224,7 @@ func (c *pgCoord) Close() error {
} }
func (c *pgCoord) Coordinate( func (c *pgCoord) Coordinate(
ctx context.Context, id uuid.UUID, name string, a agpl.TunnelAuth, ctx context.Context, id uuid.UUID, name string, a agpl.CoordinateeAuth,
) ( ) (
chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse, chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse,
) { ) {

View File

@ -5,10 +5,12 @@ import (
"database/sql" "database/sql"
"io" "io"
"net" "net"
"net/netip"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/coder/coder/v2/codersdk"
agpltest "github.com/coder/coder/v2/tailnet/test" agpltest "github.com/coder/coder/v2/tailnet/test"
"github.com/google/uuid" "github.com/google/uuid"
@ -113,6 +115,144 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
assertEventuallyLost(ctx, t, store, agent.id) assertEventuallyLost(ctx, t, store, agent.id)
} }
func TestPGCoordinatorSingle_AgentInvalidIP(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent")
defer agent.close()
agent.sendNode(&agpl.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(agpl.IP(), 128),
},
PreferredDERP: 10,
})
// The agent connection should be closed immediately after sending an invalid addr
testutil.RequireRecvCtx(ctx, t, agent.closeChan)
assertEventuallyLost(ctx, t, store, agent.id)
}
func TestPGCoordinatorSingle_AgentInvalidIPBits(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent")
defer agent.close()
agent.sendNode(&agpl.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(agpl.IPFromUUID(agent.id), 64),
},
PreferredDERP: 10,
})
// The agent connection should be closed immediately after sending an invalid addr
testutil.RequireRecvCtx(ctx, t, agent.closeChan)
assertEventuallyLost(ctx, t, store, agent.id)
}
func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent")
defer agent.close()
agent.sendNode(&agpl.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(agpl.IPFromUUID(agent.id), 128),
},
PreferredDERP: 10,
})
require.Eventually(t, func() bool {
agents, err := store.GetTailnetPeers(ctx, agent.id)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
t.Fatalf("database error: %v", err)
}
if len(agents) == 0 {
return false
}
node := new(proto.Node)
err = gProto.Unmarshal(agents[0].Node, node)
assert.NoError(t, err)
assert.EqualValues(t, 10, node.PreferredDerp)
return true
}, testutil.WaitShort, testutil.IntervalFast)
err = agent.close()
require.NoError(t, err)
<-agent.errChan
<-agent.closeChan
assertEventuallyLost(ctx, t, store, agent.id)
}
func TestPGCoordinatorSingle_AgentValidIPLegacy(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent")
defer agent.close()
agent.sendNode(&agpl.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
},
PreferredDERP: 10,
})
require.Eventually(t, func() bool {
agents, err := store.GetTailnetPeers(ctx, agent.id)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
t.Fatalf("database error: %v", err)
}
if len(agents) == 0 {
return false
}
node := new(proto.Node)
err = gProto.Unmarshal(agents[0].Node, node)
assert.NoError(t, err)
assert.EqualValues(t, 10, node.PreferredDerp)
return true
}, testutil.WaitShort, testutil.IntervalFast)
err = agent.close()
require.NoError(t, err)
<-agent.errChan
<-agent.closeChan
assertEventuallyLost(ctx, t, store, agent.id)
}
func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) { func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
t.Parallel() t.Parallel()
if !dbtestutil.WillUsePostgres() { if !dbtestutil.WillUsePostgres() {

View File

@ -52,7 +52,7 @@ func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version strin
sub := coord.ServeMultiAgent(id) sub := coord.ServeMultiAgent(id)
return ServeWorkspaceProxy(ctx, conn, sub) return ServeWorkspaceProxy(ctx, conn, sub)
case 2: case 2:
auth := agpl.SingleTailnetTunnelAuth{} auth := agpl.SingleTailnetCoordinateeAuth{}
streamID := agpl.StreamID{ streamID := agpl.StreamID{
Name: id.String(), Name: id.String(),
ID: id, ID: id,

View File

@ -182,7 +182,7 @@ func TestDialCoordinator(t *testing.T) {
// avoid blocking // avoid blocking
reqs := make(chan *proto.CoordinateRequest, 100) reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100) resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetTunnelAuth{}). mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetCoordinateeAuth{}).
Times(1). Times(1).
Return(reqs, resps) Return(reqs, resps)

View File

@ -59,7 +59,7 @@ type CoordinatorV2 interface {
// Node returns a node by peer ID, if known to the coordinator. Returns nil if unknown. // Node returns a node by peer ID, if known to the coordinator. Returns nil if unknown.
Node(id uuid.UUID) *Node Node(id uuid.UUID) *Node
Close() error Close() error
Coordinate(ctx context.Context, id uuid.UUID, name string, a TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) Coordinate(ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse)
} }
// Node represents a node in the network. // Node represents a node in the network.
@ -247,10 +247,10 @@ func NewInMemoryCoordination(
) Coordination { ) Coordination {
thisID := agentID thisID := agentID
logger = logger.With(slog.F("agent_id", agentID)) logger = logger.With(slog.F("agent_id", agentID))
var auth TunnelAuth = AgentTunnelAuth{} var auth CoordinateeAuth = AgentCoordinateeAuth{ID: agentID}
if clientID != uuid.Nil { if clientID != uuid.Nil {
// this is a client connection // this is a client connection
auth = ClientTunnelAuth{AgentID: agentID} auth = ClientCoordinateeAuth{AgentID: agentID}
logger = logger.With(slog.F("client_id", clientID)) logger = logger.With(slog.F("client_id", clientID))
thisID = clientID thisID = clientID
} }
@ -420,7 +420,7 @@ type coordinator struct {
} }
func (c *coordinator) Coordinate( func (c *coordinator) Coordinate(
ctx context.Context, id uuid.UUID, name string, a TunnelAuth, ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth,
) ( ) (
chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse, chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse,
) { ) {
@ -476,7 +476,7 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn {
func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAgentConn { func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAgentConn {
logger = logger.With(slog.F("client_id", id)).Named("multiagent") logger = logger.With(slog.F("client_id", id)).Named("multiagent")
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
reqs, resps := c.Coordinate(ctx, id, id.String(), SingleTailnetTunnelAuth{}) reqs, resps := c.Coordinate(ctx, id, id.String(), SingleTailnetCoordinateeAuth{})
m := (&MultiAgent{ m := (&MultiAgent{
ID: id, ID: id,
OnSubscribe: func(enq Queue, agent uuid.UUID) error { OnSubscribe: func(enq Queue, agent uuid.UUID) error {
@ -584,7 +584,7 @@ func ServeClientV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, con
}() }()
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
reqs, resps := c.Coordinate(ctx, id, id.String(), ClientTunnelAuth{AgentID: agent}) reqs, resps := c.Coordinate(ctx, id, id.String(), ClientCoordinateeAuth{AgentID: agent})
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{ err := SendCtx(ctx, reqs, &proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}, AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)},
}) })
@ -611,6 +611,11 @@ func (c *core) handleRequest(p *peer, req *proto.CoordinateRequest) error {
if !ok || pr != p { if !ok || pr != p {
return ErrAlreadyRemoved return ErrAlreadyRemoved
} }
if err := pr.auth.Authorize(req); err != nil {
return xerrors.Errorf("authorize request: %w", err)
}
if req.UpdateSelf != nil { if req.UpdateSelf != nil {
err := c.nodeUpdateLocked(p, req.UpdateSelf.Node) err := c.nodeUpdateLocked(p, req.UpdateSelf.Node)
if xerrors.Is(err, ErrAlreadyRemoved) || xerrors.Is(err, ErrClosed) { if xerrors.Is(err, ErrAlreadyRemoved) || xerrors.Is(err, ErrClosed) {
@ -683,9 +688,6 @@ func (c *core) updateTunnelPeersLocked(id uuid.UUID, n *proto.Node, k proto.Coor
} }
func (c *core) addTunnelLocked(src *peer, dstID uuid.UUID) error { func (c *core) addTunnelLocked(src *peer, dstID uuid.UUID) error {
if !src.auth.Authorize(dstID) {
return xerrors.Errorf("src %s is not allowed to tunnel to %s", src.id, dstID)
}
c.tunnels.add(src.id, dstID) c.tunnels.add(src.id, dstID)
c.logger.Debug(context.Background(), "adding tunnel", c.logger.Debug(context.Background(), "adding tunnel",
slog.F("src_id", src.id), slog.F("src_id", src.id),
@ -813,7 +815,7 @@ func ServeAgentV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, conn
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
logger.Debug(ctx, "starting new agent connection") logger.Debug(ctx, "starting new agent connection")
reqs, resps := c.Coordinate(ctx, id, name, AgentTunnelAuth{}) reqs, resps := c.Coordinate(ctx, id, name, AgentCoordinateeAuth{ID: id})
tc := NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, QueueKindAgent) tc := NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, QueueKindAgent)
go tc.SendUpdates() go tc.SendUpdates()
go v1RespLoop(ctx, cancel, logger, tc, resps) go v1RespLoop(ctx, cancel, logger, tc, resps)

View File

@ -6,6 +6,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
@ -21,6 +22,7 @@ import (
"cdr.dev/slog" "cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/tailnet/tailnettest"
@ -50,7 +52,12 @@ func TestCoordinator(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
close(closeChan) close(closeChan)
}() }()
sendNode(&tailnet.Node{}) sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IP(), 128),
},
PreferredDERP: 10,
})
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
return coordinator.Node(id) != nil return coordinator.Node(id) != nil
}, testutil.WaitShort, testutil.IntervalFast) }, testutil.WaitShort, testutil.IntervalFast)
@ -60,6 +67,37 @@ func TestCoordinator(t *testing.T) {
_ = testutil.RequireRecvCtx(ctx, t, closeChan) _ = testutil.RequireRecvCtx(ctx, t, closeChan)
}) })
t.Run("ClientWithoutAgent_InvalidIPBits", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
err := coordinator.Close()
require.NoError(t, err)
}()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
})
id := uuid.New()
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeClient(server, id, uuid.New())
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IP(), 64),
},
PreferredDERP: 10,
})
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
t.Run("AgentWithoutClients", func(t *testing.T) { t.Run("AgentWithoutClients", func(t *testing.T) {
t.Parallel() t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
@ -80,7 +118,12 @@ func TestCoordinator(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
close(closeChan) close(closeChan)
}() }()
sendNode(&tailnet.Node{}) sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IPFromUUID(id), 128),
},
PreferredDERP: 10,
})
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
return coordinator.Node(id) != nil return coordinator.Node(id) != nil
}, testutil.WaitShort, testutil.IntervalFast) }, testutil.WaitShort, testutil.IntervalFast)
@ -90,6 +133,101 @@ func TestCoordinator(t *testing.T) {
_ = testutil.RequireRecvCtx(ctx, t, closeChan) _ = testutil.RequireRecvCtx(ctx, t, closeChan)
}) })
t.Run("AgentWithoutClients_ValidIPLegacy", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
err := coordinator.Close()
require.NoError(t, err)
}()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
})
id := uuid.New()
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeAgent(server, id, "")
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
},
PreferredDERP: 10,
})
require.Eventually(t, func() bool {
return coordinator.Node(id) != nil
}, testutil.WaitShort, testutil.IntervalFast)
err := client.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
t.Run("AgentWithoutClients_InvalidIP", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
err := coordinator.Close()
require.NoError(t, err)
}()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
})
id := uuid.New()
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeAgent(server, id, "")
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IP(), 128),
},
PreferredDERP: 10,
})
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
t.Run("AgentWithoutClients_InvalidBits", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
err := coordinator.Close()
require.NoError(t, err)
}()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
})
id := uuid.New()
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeAgent(server, id, "")
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IPFromUUID(id), 64),
},
PreferredDERP: 10,
})
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
t.Run("AgentWithClient", func(t *testing.T) { t.Run("AgentWithClient", func(t *testing.T) {
t.Parallel() t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
@ -435,7 +573,7 @@ func TestInMemoryCoordination(t *testing.T) {
reqs := make(chan *proto.CoordinateRequest, 100) reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100) resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}). mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
Times(1).Return(reqs, resps) Times(1).Return(reqs, resps)
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn) uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
@ -462,7 +600,7 @@ func TestRemoteCoordination(t *testing.T) {
reqs := make(chan *proto.CoordinateRequest, 100) reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100) resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}). mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
Times(1).Return(reqs, resps) Times(1).Return(reqs, resps)
var coord tailnet.Coordinator = mCoord var coord tailnet.Coordinator = mCoord

View File

@ -19,7 +19,7 @@ type peer struct {
node *proto.Node node *proto.Node
resps chan<- *proto.CoordinateResponse resps chan<- *proto.CoordinateResponse
reqs <-chan *proto.CoordinateRequest reqs <-chan *proto.CoordinateRequest
auth TunnelAuth auth CoordinateeAuth
sent map[uuid.UUID]*proto.Node sent map[uuid.UUID]*proto.Node
name string name string

View File

@ -29,7 +29,7 @@ type streamIDContextKey struct{}
type StreamID struct { type StreamID struct {
Name string Name string
ID uuid.UUID ID uuid.UUID
Auth TunnelAuth Auth CoordinateeAuth
} }
func WithStreamID(ctx context.Context, streamID StreamID) context.Context { func WithStreamID(ctx context.Context, streamID StreamID) context.Context {
@ -91,7 +91,7 @@ func (s *ClientService) ServeClient(ctx context.Context, version string, conn ne
coord := *(s.CoordPtr.Load()) coord := *(s.CoordPtr.Load())
return coord.ServeClient(conn, id, agent) return coord.ServeClient(conn, id, agent)
case 2: case 2:
auth := ClientTunnelAuth{AgentID: agent} auth := ClientCoordinateeAuth{AgentID: agent}
streamID := StreamID{ streamID := StreamID{
Name: "client", Name: "client",
ID: id, ID: id,

View File

@ -64,7 +64,9 @@ func TestClientService_ServeClient_V2(t *testing.T) {
require.NotNil(t, call) require.NotNil(t, call)
require.Equal(t, call.ID, clientID) require.Equal(t, call.ID, clientID)
require.Equal(t, call.Name, "client") require.Equal(t, call.Name, "client")
require.True(t, call.Auth.Authorize(agentID)) require.NoError(t, call.Auth.Authorize(&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
}))
req := testutil.RequireRecvCtx(ctx, t, call.Reqs) req := testutil.RequireRecvCtx(ctx, t, call.Reqs)
require.Equal(t, int32(11), req.GetUpdateSelf().GetNode().GetPreferredDerp()) require.Equal(t, int32(11), req.GetUpdateSelf().GetNode().GetPreferredDerp())

View File

@ -59,7 +59,7 @@ func (mr *MockCoordinatorMockRecorder) Close() *gomock.Call {
} }
// Coordinate mocks base method. // Coordinate mocks base method.
func (m *MockCoordinator) Coordinate(arg0 context.Context, arg1 uuid.UUID, arg2 string, arg3 tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) { func (m *MockCoordinator) Coordinate(arg0 context.Context, arg1 uuid.UUID, arg2 string, arg3 tailnet.CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Coordinate", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "Coordinate", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(chan<- *proto.CoordinateRequest) ret0, _ := ret[0].(chan<- *proto.CoordinateRequest)

View File

@ -312,7 +312,7 @@ func (*FakeCoordinator) ServeMultiAgent(uuid.UUID) tailnet.MultiAgentConn {
panic("unimplemented") panic("unimplemented")
} }
func (f *FakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) { func (f *FakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
reqs := make(chan *proto.CoordinateRequest, 100) reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100) resps := make(chan *proto.CoordinateResponse, 100)
f.CoordinateCalls <- &FakeCoordinate{ f.CoordinateCalls <- &FakeCoordinate{
@ -337,7 +337,7 @@ type FakeCoordinate struct {
Ctx context.Context Ctx context.Context
ID uuid.UUID ID uuid.UUID
Name string Name string
Auth tailnet.TunnelAuth Auth tailnet.CoordinateeAuth
Reqs chan *proto.CoordinateRequest Reqs chan *proto.CoordinateRequest
Resps chan *proto.CoordinateResponse Resps chan *proto.CoordinateResponse
} }

View File

@ -40,7 +40,7 @@ func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, nam
p.ID = uuid.New() p.ID = uuid.New()
} }
// SingleTailnetTunnelAuth allows connections to arbitrary peers // SingleTailnetTunnelAuth allows connections to arbitrary peers
p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, tailnet.SingleTailnetTunnelAuth{}) p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, tailnet.SingleTailnetCoordinateeAuth{})
return p return p
} }

View File

@ -1,32 +1,89 @@
package tailnet package tailnet
import "github.com/google/uuid" import (
"net/netip"
type TunnelAuth interface { "github.com/google/uuid"
Authorize(dst uuid.UUID) bool "golang.org/x/xerrors"
"github.com/coder/coder/v2/tailnet/proto"
)
var legacyWorkspaceAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")
type CoordinateeAuth interface {
Authorize(req *proto.CoordinateRequest) error
} }
// SingleTailnetTunnelAuth allows all tunnels, since Coderd and wsproxy are allowed to initiate a tunnel to any agent // SingleTailnetCoordinateeAuth allows all tunnels, since Coderd and wsproxy are allowed to initiate a tunnel to any agent
type SingleTailnetTunnelAuth struct{} type SingleTailnetCoordinateeAuth struct{}
func (SingleTailnetTunnelAuth) Authorize(uuid.UUID) bool { func (SingleTailnetCoordinateeAuth) Authorize(*proto.CoordinateRequest) error {
return true return nil
} }
// ClientTunnelAuth allows connecting to a single, given agent // ClientCoordinateeAuth allows connecting to a single, given agent
type ClientTunnelAuth struct { type ClientCoordinateeAuth struct {
AgentID uuid.UUID AgentID uuid.UUID
} }
func (c ClientTunnelAuth) Authorize(dst uuid.UUID) bool { func (c ClientCoordinateeAuth) Authorize(req *proto.CoordinateRequest) error {
return c.AgentID == dst if tun := req.GetAddTunnel(); tun != nil {
uid, err := uuid.FromBytes(tun.Id)
if err != nil {
return xerrors.Errorf("parse add tunnel id: %w", err)
}
if c.AgentID != uid {
return xerrors.Errorf("invalid agent id, expected %s, got %s", c.AgentID.String(), uid.String())
}
}
if upd := req.GetUpdateSelf(); upd != nil {
for _, addrStr := range upd.Node.Addresses {
pre, err := netip.ParsePrefix(addrStr)
if err != nil {
return xerrors.Errorf("parse node address: %w", err)
}
if pre.Bits() != 128 {
return xerrors.Errorf("invalid address bits, expected 128, got %d", pre.Bits())
}
}
}
return nil
} }
// AgentTunnelAuth disallows all tunnels, since agents are not allowed to initiate their own tunnels // AgentCoordinateeAuth disallows all tunnels, since agents are not allowed to initiate their own tunnels
type AgentTunnelAuth struct{} type AgentCoordinateeAuth struct {
ID uuid.UUID
}
func (AgentTunnelAuth) Authorize(uuid.UUID) bool { func (a AgentCoordinateeAuth) Authorize(req *proto.CoordinateRequest) error {
return false if tun := req.GetAddTunnel(); tun != nil {
return xerrors.New("agents cannot open tunnels")
}
if upd := req.GetUpdateSelf(); upd != nil {
for _, addrStr := range upd.Node.Addresses {
pre, err := netip.ParsePrefix(addrStr)
if err != nil {
return xerrors.Errorf("parse node address: %w", err)
}
if pre.Bits() != 128 {
return xerrors.Errorf("invalid address bits, expected 128, got %d", pre.Bits())
}
if IPFromUUID(a.ID).Compare(pre.Addr()) != 0 &&
legacyWorkspaceAgentIP.Compare(pre.Addr()) != 0 {
return xerrors.Errorf("invalid node address, got %s", pre.Addr().String())
}
}
}
return nil
} }
// tunnelStore contains tunnel information and allows querying it. It is not threadsafe and all // tunnelStore contains tunnel information and allows querying it. It is not threadsafe and all