fix(tailnet): enforce valid agent and client addresses (#12197)

This adds the ability for `TunnelAuth` to also authorize incoming wireguard node IPs, preventing agents from reporting anything other than their static IP generated from the agent ID.
This commit is contained in:
Colin Adler 2024-03-01 09:02:33 -06:00 committed by GitHub
parent 7fbca62e08
commit e5d911462f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 389 additions and 49 deletions

View File

@ -108,11 +108,10 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) {
c.t.Cleanup(c.LastWorkspaceAgent)
serveCtx, cancel := context.WithCancel(ctx)
c.t.Cleanup(cancel)
auth := tailnet.AgentTunnelAuth{}
streamID := tailnet.StreamID{
Name: "agenttest",
ID: c.agentID,
Auth: auth,
Auth: tailnet.AgentCoordinateeAuth{ID: c.agentID},
}
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
go func() {

View File

@ -155,7 +155,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
streamID := tailnet.StreamID{
Name: fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name),
ID: workspaceAgent.ID,
Auth: tailnet.AgentTunnelAuth{},
Auth: tailnet.AgentCoordinateeAuth{ID: workspaceAgent.ID},
}
ctx = tailnet.WithStreamID(ctx, streamID)
ctx = agentapi.WithAPIVersion(ctx, version)

View File

@ -54,7 +54,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
Name: "client",
ID: clientID,
Auth: tailnet.ClientTunnelAuth{AgentID: agentID},
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
})
assert.NoError(t, err)
}))

View File

@ -30,7 +30,7 @@ type connIO struct {
responses chan<- *proto.CoordinateResponse
bindings chan<- binding
tunnels chan<- tunnel
auth agpl.TunnelAuth
auth agpl.CoordinateeAuth
mu sync.Mutex
closed bool
disconnected bool
@ -50,7 +50,7 @@ func newConnIO(coordContext context.Context,
responses chan<- *proto.CoordinateResponse,
id uuid.UUID,
name string,
auth agpl.TunnelAuth,
auth agpl.CoordinateeAuth,
) *connIO {
peerCtx, cancel := context.WithCancel(peerCtx)
now := time.Now().Unix()
@ -126,6 +126,11 @@ var errDisconnect = xerrors.New("graceful disconnect")
func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
c.logger.Debug(c.peerCtx, "got request")
err := c.auth.Authorize(req)
if err != nil {
return xerrors.Errorf("authorize request: %w", err)
}
if req.UpdateSelf != nil {
c.logger.Debug(c.peerCtx, "got node update", slog.F("node", req.UpdateSelf))
b := binding{
@ -147,9 +152,6 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
// doesn't just happily continue thinking everything is fine.
return err
}
if !c.auth.Authorize(dst) {
return xerrors.New("unauthorized tunnel")
}
t := tunnel{
tKey: tKey{
src: c.UniqueID(),

View File

@ -224,7 +224,7 @@ func (c *pgCoord) Close() error {
}
func (c *pgCoord) Coordinate(
ctx context.Context, id uuid.UUID, name string, a agpl.TunnelAuth,
ctx context.Context, id uuid.UUID, name string, a agpl.CoordinateeAuth,
) (
chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse,
) {

View File

@ -5,10 +5,12 @@ import (
"database/sql"
"io"
"net"
"net/netip"
"sync"
"testing"
"time"
"github.com/coder/coder/v2/codersdk"
agpltest "github.com/coder/coder/v2/tailnet/test"
"github.com/google/uuid"
@ -113,6 +115,144 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
assertEventuallyLost(ctx, t, store, agent.id)
}
func TestPGCoordinatorSingle_AgentInvalidIP(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent")
defer agent.close()
agent.sendNode(&agpl.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(agpl.IP(), 128),
},
PreferredDERP: 10,
})
// The agent connection should be closed immediately after sending an invalid addr
testutil.RequireRecvCtx(ctx, t, agent.closeChan)
assertEventuallyLost(ctx, t, store, agent.id)
}
func TestPGCoordinatorSingle_AgentInvalidIPBits(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent")
defer agent.close()
agent.sendNode(&agpl.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(agpl.IPFromUUID(agent.id), 64),
},
PreferredDERP: 10,
})
// The agent connection should be closed immediately after sending an invalid addr
testutil.RequireRecvCtx(ctx, t, agent.closeChan)
assertEventuallyLost(ctx, t, store, agent.id)
}
func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent")
defer agent.close()
agent.sendNode(&agpl.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(agpl.IPFromUUID(agent.id), 128),
},
PreferredDERP: 10,
})
require.Eventually(t, func() bool {
agents, err := store.GetTailnetPeers(ctx, agent.id)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
t.Fatalf("database error: %v", err)
}
if len(agents) == 0 {
return false
}
node := new(proto.Node)
err = gProto.Unmarshal(agents[0].Node, node)
assert.NoError(t, err)
assert.EqualValues(t, 10, node.PreferredDerp)
return true
}, testutil.WaitShort, testutil.IntervalFast)
err = agent.close()
require.NoError(t, err)
<-agent.errChan
<-agent.closeChan
assertEventuallyLost(ctx, t, store, agent.id)
}
func TestPGCoordinatorSingle_AgentValidIPLegacy(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
}
store, ps := dbtestutil.NewDB(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err)
defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent")
defer agent.close()
agent.sendNode(&agpl.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
},
PreferredDERP: 10,
})
require.Eventually(t, func() bool {
agents, err := store.GetTailnetPeers(ctx, agent.id)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
t.Fatalf("database error: %v", err)
}
if len(agents) == 0 {
return false
}
node := new(proto.Node)
err = gProto.Unmarshal(agents[0].Node, node)
assert.NoError(t, err)
assert.EqualValues(t, 10, node.PreferredDerp)
return true
}, testutil.WaitShort, testutil.IntervalFast)
err = agent.close()
require.NoError(t, err)
<-agent.errChan
<-agent.closeChan
assertEventuallyLost(ctx, t, store, agent.id)
}
func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {

View File

@ -52,7 +52,7 @@ func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version strin
sub := coord.ServeMultiAgent(id)
return ServeWorkspaceProxy(ctx, conn, sub)
case 2:
auth := agpl.SingleTailnetTunnelAuth{}
auth := agpl.SingleTailnetCoordinateeAuth{}
streamID := agpl.StreamID{
Name: id.String(),
ID: id,

View File

@ -182,7 +182,7 @@ func TestDialCoordinator(t *testing.T) {
// avoid blocking
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetTunnelAuth{}).
mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetCoordinateeAuth{}).
Times(1).
Return(reqs, resps)

View File

@ -59,7 +59,7 @@ type CoordinatorV2 interface {
// Node returns a node by peer ID, if known to the coordinator. Returns nil if unknown.
Node(id uuid.UUID) *Node
Close() error
Coordinate(ctx context.Context, id uuid.UUID, name string, a TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse)
Coordinate(ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse)
}
// Node represents a node in the network.
@ -247,10 +247,10 @@ func NewInMemoryCoordination(
) Coordination {
thisID := agentID
logger = logger.With(slog.F("agent_id", agentID))
var auth TunnelAuth = AgentTunnelAuth{}
var auth CoordinateeAuth = AgentCoordinateeAuth{ID: agentID}
if clientID != uuid.Nil {
// this is a client connection
auth = ClientTunnelAuth{AgentID: agentID}
auth = ClientCoordinateeAuth{AgentID: agentID}
logger = logger.With(slog.F("client_id", clientID))
thisID = clientID
}
@ -420,7 +420,7 @@ type coordinator struct {
}
func (c *coordinator) Coordinate(
ctx context.Context, id uuid.UUID, name string, a TunnelAuth,
ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth,
) (
chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse,
) {
@ -476,7 +476,7 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn {
func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAgentConn {
logger = logger.With(slog.F("client_id", id)).Named("multiagent")
ctx, cancel := context.WithCancel(context.Background())
reqs, resps := c.Coordinate(ctx, id, id.String(), SingleTailnetTunnelAuth{})
reqs, resps := c.Coordinate(ctx, id, id.String(), SingleTailnetCoordinateeAuth{})
m := (&MultiAgent{
ID: id,
OnSubscribe: func(enq Queue, agent uuid.UUID) error {
@ -584,7 +584,7 @@ func ServeClientV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, con
}()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
reqs, resps := c.Coordinate(ctx, id, id.String(), ClientTunnelAuth{AgentID: agent})
reqs, resps := c.Coordinate(ctx, id, id.String(), ClientCoordinateeAuth{AgentID: agent})
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)},
})
@ -611,6 +611,11 @@ func (c *core) handleRequest(p *peer, req *proto.CoordinateRequest) error {
if !ok || pr != p {
return ErrAlreadyRemoved
}
if err := pr.auth.Authorize(req); err != nil {
return xerrors.Errorf("authorize request: %w", err)
}
if req.UpdateSelf != nil {
err := c.nodeUpdateLocked(p, req.UpdateSelf.Node)
if xerrors.Is(err, ErrAlreadyRemoved) || xerrors.Is(err, ErrClosed) {
@ -683,9 +688,6 @@ func (c *core) updateTunnelPeersLocked(id uuid.UUID, n *proto.Node, k proto.Coor
}
func (c *core) addTunnelLocked(src *peer, dstID uuid.UUID) error {
if !src.auth.Authorize(dstID) {
return xerrors.Errorf("src %s is not allowed to tunnel to %s", src.id, dstID)
}
c.tunnels.add(src.id, dstID)
c.logger.Debug(context.Background(), "adding tunnel",
slog.F("src_id", src.id),
@ -813,7 +815,7 @@ func ServeAgentV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, conn
ctx, cancel := context.WithCancel(ctx)
defer cancel()
logger.Debug(ctx, "starting new agent connection")
reqs, resps := c.Coordinate(ctx, id, name, AgentTunnelAuth{})
reqs, resps := c.Coordinate(ctx, id, name, AgentCoordinateeAuth{ID: id})
tc := NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, QueueKindAgent)
go tc.SendUpdates()
go v1RespLoop(ctx, cancel, logger, tc, resps)

View File

@ -6,6 +6,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"net/netip"
"sync"
"sync/atomic"
"testing"
@ -21,6 +22,7 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
@ -50,7 +52,12 @@ func TestCoordinator(t *testing.T) {
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{})
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IP(), 128),
},
PreferredDERP: 10,
})
require.Eventually(t, func() bool {
return coordinator.Node(id) != nil
}, testutil.WaitShort, testutil.IntervalFast)
@ -60,6 +67,37 @@ func TestCoordinator(t *testing.T) {
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
t.Run("ClientWithoutAgent_InvalidIPBits", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
err := coordinator.Close()
require.NoError(t, err)
}()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
})
id := uuid.New()
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeClient(server, id, uuid.New())
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IP(), 64),
},
PreferredDERP: 10,
})
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
t.Run("AgentWithoutClients", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
@ -80,7 +118,12 @@ func TestCoordinator(t *testing.T) {
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{})
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IPFromUUID(id), 128),
},
PreferredDERP: 10,
})
require.Eventually(t, func() bool {
return coordinator.Node(id) != nil
}, testutil.WaitShort, testutil.IntervalFast)
@ -90,6 +133,101 @@ func TestCoordinator(t *testing.T) {
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
t.Run("AgentWithoutClients_ValidIPLegacy", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
err := coordinator.Close()
require.NoError(t, err)
}()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
})
id := uuid.New()
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeAgent(server, id, "")
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
},
PreferredDERP: 10,
})
require.Eventually(t, func() bool {
return coordinator.Node(id) != nil
}, testutil.WaitShort, testutil.IntervalFast)
err := client.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
t.Run("AgentWithoutClients_InvalidIP", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
err := coordinator.Close()
require.NoError(t, err)
}()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
})
id := uuid.New()
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeAgent(server, id, "")
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IP(), 128),
},
PreferredDERP: 10,
})
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
t.Run("AgentWithoutClients_InvalidBits", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium)
coordinator := tailnet.NewCoordinator(logger)
defer func() {
err := coordinator.Close()
require.NoError(t, err)
}()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
})
id := uuid.New()
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeAgent(server, id, "")
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IPFromUUID(id), 64),
},
PreferredDERP: 10,
})
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
})
t.Run("AgentWithClient", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
@ -435,7 +573,7 @@ func TestInMemoryCoordination(t *testing.T) {
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
Times(1).Return(reqs, resps)
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
@ -462,7 +600,7 @@ func TestRemoteCoordination(t *testing.T) {
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
Times(1).Return(reqs, resps)
var coord tailnet.Coordinator = mCoord

View File

@ -19,7 +19,7 @@ type peer struct {
node *proto.Node
resps chan<- *proto.CoordinateResponse
reqs <-chan *proto.CoordinateRequest
auth TunnelAuth
auth CoordinateeAuth
sent map[uuid.UUID]*proto.Node
name string

View File

@ -29,7 +29,7 @@ type streamIDContextKey struct{}
type StreamID struct {
Name string
ID uuid.UUID
Auth TunnelAuth
Auth CoordinateeAuth
}
func WithStreamID(ctx context.Context, streamID StreamID) context.Context {
@ -91,7 +91,7 @@ func (s *ClientService) ServeClient(ctx context.Context, version string, conn ne
coord := *(s.CoordPtr.Load())
return coord.ServeClient(conn, id, agent)
case 2:
auth := ClientTunnelAuth{AgentID: agent}
auth := ClientCoordinateeAuth{AgentID: agent}
streamID := StreamID{
Name: "client",
ID: id,

View File

@ -64,7 +64,9 @@ func TestClientService_ServeClient_V2(t *testing.T) {
require.NotNil(t, call)
require.Equal(t, call.ID, clientID)
require.Equal(t, call.Name, "client")
require.True(t, call.Auth.Authorize(agentID))
require.NoError(t, call.Auth.Authorize(&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
}))
req := testutil.RequireRecvCtx(ctx, t, call.Reqs)
require.Equal(t, int32(11), req.GetUpdateSelf().GetNode().GetPreferredDerp())

View File

@ -59,7 +59,7 @@ func (mr *MockCoordinatorMockRecorder) Close() *gomock.Call {
}
// Coordinate mocks base method.
func (m *MockCoordinator) Coordinate(arg0 context.Context, arg1 uuid.UUID, arg2 string, arg3 tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
func (m *MockCoordinator) Coordinate(arg0 context.Context, arg1 uuid.UUID, arg2 string, arg3 tailnet.CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Coordinate", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(chan<- *proto.CoordinateRequest)

View File

@ -312,7 +312,7 @@ func (*FakeCoordinator) ServeMultiAgent(uuid.UUID) tailnet.MultiAgentConn {
panic("unimplemented")
}
func (f *FakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
func (f *FakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
f.CoordinateCalls <- &FakeCoordinate{
@ -337,7 +337,7 @@ type FakeCoordinate struct {
Ctx context.Context
ID uuid.UUID
Name string
Auth tailnet.TunnelAuth
Auth tailnet.CoordinateeAuth
Reqs chan *proto.CoordinateRequest
Resps chan *proto.CoordinateResponse
}

View File

@ -40,7 +40,7 @@ func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, nam
p.ID = uuid.New()
}
// SingleTailnetTunnelAuth allows connections to arbitrary peers
p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, tailnet.SingleTailnetTunnelAuth{})
p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, tailnet.SingleTailnetCoordinateeAuth{})
return p
}

View File

@ -1,32 +1,89 @@
package tailnet
import "github.com/google/uuid"
import (
"net/netip"
type TunnelAuth interface {
Authorize(dst uuid.UUID) bool
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/tailnet/proto"
)
var legacyWorkspaceAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")
type CoordinateeAuth interface {
Authorize(req *proto.CoordinateRequest) error
}
// SingleTailnetTunnelAuth allows all tunnels, since Coderd and wsproxy are allowed to initiate a tunnel to any agent
type SingleTailnetTunnelAuth struct{}
// SingleTailnetCoordinateeAuth allows all tunnels, since Coderd and wsproxy are allowed to initiate a tunnel to any agent
type SingleTailnetCoordinateeAuth struct{}
func (SingleTailnetTunnelAuth) Authorize(uuid.UUID) bool {
return true
func (SingleTailnetCoordinateeAuth) Authorize(*proto.CoordinateRequest) error {
return nil
}
// ClientTunnelAuth allows connecting to a single, given agent
type ClientTunnelAuth struct {
// ClientCoordinateeAuth allows connecting to a single, given agent
type ClientCoordinateeAuth struct {
AgentID uuid.UUID
}
func (c ClientTunnelAuth) Authorize(dst uuid.UUID) bool {
return c.AgentID == dst
func (c ClientCoordinateeAuth) Authorize(req *proto.CoordinateRequest) error {
if tun := req.GetAddTunnel(); tun != nil {
uid, err := uuid.FromBytes(tun.Id)
if err != nil {
return xerrors.Errorf("parse add tunnel id: %w", err)
}
if c.AgentID != uid {
return xerrors.Errorf("invalid agent id, expected %s, got %s", c.AgentID.String(), uid.String())
}
}
if upd := req.GetUpdateSelf(); upd != nil {
for _, addrStr := range upd.Node.Addresses {
pre, err := netip.ParsePrefix(addrStr)
if err != nil {
return xerrors.Errorf("parse node address: %w", err)
}
if pre.Bits() != 128 {
return xerrors.Errorf("invalid address bits, expected 128, got %d", pre.Bits())
}
}
}
return nil
}
// AgentTunnelAuth disallows all tunnels, since agents are not allowed to initiate their own tunnels
type AgentTunnelAuth struct{}
// AgentCoordinateeAuth disallows all tunnels, since agents are not allowed to initiate their own tunnels
type AgentCoordinateeAuth struct {
ID uuid.UUID
}
func (AgentTunnelAuth) Authorize(uuid.UUID) bool {
return false
func (a AgentCoordinateeAuth) Authorize(req *proto.CoordinateRequest) error {
if tun := req.GetAddTunnel(); tun != nil {
return xerrors.New("agents cannot open tunnels")
}
if upd := req.GetUpdateSelf(); upd != nil {
for _, addrStr := range upd.Node.Addresses {
pre, err := netip.ParsePrefix(addrStr)
if err != nil {
return xerrors.Errorf("parse node address: %w", err)
}
if pre.Bits() != 128 {
return xerrors.Errorf("invalid address bits, expected 128, got %d", pre.Bits())
}
if IPFromUUID(a.ID).Compare(pre.Addr()) != 0 &&
legacyWorkspaceAgentIP.Compare(pre.Addr()) != 0 {
return xerrors.Errorf("invalid node address, got %s", pre.Addr().String())
}
}
}
return nil
}
// tunnelStore contains tunnel information and allows querying it. It is not threadsafe and all