mirror of https://github.com/coder/coder.git
fix(tailnet): enforce valid agent and client addresses (#12197)
This adds the ability for `TunnelAuth` to also authorize incoming wireguard node IPs, preventing agents from reporting anything other than their static IP generated from the agent ID.
This commit is contained in:
parent
7fbca62e08
commit
e5d911462f
|
@ -108,11 +108,10 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) {
|
|||
c.t.Cleanup(c.LastWorkspaceAgent)
|
||||
serveCtx, cancel := context.WithCancel(ctx)
|
||||
c.t.Cleanup(cancel)
|
||||
auth := tailnet.AgentTunnelAuth{}
|
||||
streamID := tailnet.StreamID{
|
||||
Name: "agenttest",
|
||||
ID: c.agentID,
|
||||
Auth: auth,
|
||||
Auth: tailnet.AgentCoordinateeAuth{ID: c.agentID},
|
||||
}
|
||||
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
|
||||
go func() {
|
||||
|
|
|
@ -155,7 +155,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
|
|||
streamID := tailnet.StreamID{
|
||||
Name: fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name),
|
||||
ID: workspaceAgent.ID,
|
||||
Auth: tailnet.AgentTunnelAuth{},
|
||||
Auth: tailnet.AgentCoordinateeAuth{ID: workspaceAgent.ID},
|
||||
}
|
||||
ctx = tailnet.WithStreamID(ctx, streamID)
|
||||
ctx = agentapi.WithAPIVersion(ctx, version)
|
||||
|
|
|
@ -54,7 +54,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
|
|||
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
|
||||
Name: "client",
|
||||
ID: clientID,
|
||||
Auth: tailnet.ClientTunnelAuth{AgentID: agentID},
|
||||
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}))
|
||||
|
|
|
@ -30,7 +30,7 @@ type connIO struct {
|
|||
responses chan<- *proto.CoordinateResponse
|
||||
bindings chan<- binding
|
||||
tunnels chan<- tunnel
|
||||
auth agpl.TunnelAuth
|
||||
auth agpl.CoordinateeAuth
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
disconnected bool
|
||||
|
@ -50,7 +50,7 @@ func newConnIO(coordContext context.Context,
|
|||
responses chan<- *proto.CoordinateResponse,
|
||||
id uuid.UUID,
|
||||
name string,
|
||||
auth agpl.TunnelAuth,
|
||||
auth agpl.CoordinateeAuth,
|
||||
) *connIO {
|
||||
peerCtx, cancel := context.WithCancel(peerCtx)
|
||||
now := time.Now().Unix()
|
||||
|
@ -126,6 +126,11 @@ var errDisconnect = xerrors.New("graceful disconnect")
|
|||
|
||||
func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
|
||||
c.logger.Debug(c.peerCtx, "got request")
|
||||
err := c.auth.Authorize(req)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("authorize request: %w", err)
|
||||
}
|
||||
|
||||
if req.UpdateSelf != nil {
|
||||
c.logger.Debug(c.peerCtx, "got node update", slog.F("node", req.UpdateSelf))
|
||||
b := binding{
|
||||
|
@ -147,9 +152,6 @@ func (c *connIO) handleRequest(req *proto.CoordinateRequest) error {
|
|||
// doesn't just happily continue thinking everything is fine.
|
||||
return err
|
||||
}
|
||||
if !c.auth.Authorize(dst) {
|
||||
return xerrors.New("unauthorized tunnel")
|
||||
}
|
||||
t := tunnel{
|
||||
tKey: tKey{
|
||||
src: c.UniqueID(),
|
||||
|
|
|
@ -224,7 +224,7 @@ func (c *pgCoord) Close() error {
|
|||
}
|
||||
|
||||
func (c *pgCoord) Coordinate(
|
||||
ctx context.Context, id uuid.UUID, name string, a agpl.TunnelAuth,
|
||||
ctx context.Context, id uuid.UUID, name string, a agpl.CoordinateeAuth,
|
||||
) (
|
||||
chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse,
|
||||
) {
|
||||
|
|
|
@ -5,10 +5,12 @@ import (
|
|||
"database/sql"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
agpltest "github.com/coder/coder/v2/tailnet/test"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
@ -113,6 +115,144 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
|
|||
assertEventuallyLost(ctx, t, store, agent.id)
|
||||
}
|
||||
|
||||
func TestPGCoordinatorSingle_AgentInvalidIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
agent := newTestAgent(t, coordinator, "agent")
|
||||
defer agent.close()
|
||||
agent.sendNode(&agpl.Node{
|
||||
Addresses: []netip.Prefix{
|
||||
netip.PrefixFrom(agpl.IP(), 128),
|
||||
},
|
||||
PreferredDERP: 10,
|
||||
})
|
||||
|
||||
// The agent connection should be closed immediately after sending an invalid addr
|
||||
testutil.RequireRecvCtx(ctx, t, agent.closeChan)
|
||||
assertEventuallyLost(ctx, t, store, agent.id)
|
||||
}
|
||||
|
||||
func TestPGCoordinatorSingle_AgentInvalidIPBits(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
agent := newTestAgent(t, coordinator, "agent")
|
||||
defer agent.close()
|
||||
agent.sendNode(&agpl.Node{
|
||||
Addresses: []netip.Prefix{
|
||||
netip.PrefixFrom(agpl.IPFromUUID(agent.id), 64),
|
||||
},
|
||||
PreferredDERP: 10,
|
||||
})
|
||||
|
||||
// The agent connection should be closed immediately after sending an invalid addr
|
||||
testutil.RequireRecvCtx(ctx, t, agent.closeChan)
|
||||
assertEventuallyLost(ctx, t, store, agent.id)
|
||||
}
|
||||
|
||||
func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
agent := newTestAgent(t, coordinator, "agent")
|
||||
defer agent.close()
|
||||
agent.sendNode(&agpl.Node{
|
||||
Addresses: []netip.Prefix{
|
||||
netip.PrefixFrom(agpl.IPFromUUID(agent.id), 128),
|
||||
},
|
||||
PreferredDERP: 10,
|
||||
})
|
||||
require.Eventually(t, func() bool {
|
||||
agents, err := store.GetTailnetPeers(ctx, agent.id)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
t.Fatalf("database error: %v", err)
|
||||
}
|
||||
if len(agents) == 0 {
|
||||
return false
|
||||
}
|
||||
node := new(proto.Node)
|
||||
err = gProto.Unmarshal(agents[0].Node, node)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 10, node.PreferredDerp)
|
||||
return true
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
err = agent.close()
|
||||
require.NoError(t, err)
|
||||
<-agent.errChan
|
||||
<-agent.closeChan
|
||||
assertEventuallyLost(ctx, t, store, agent.id)
|
||||
}
|
||||
|
||||
func TestPGCoordinatorSingle_AgentValidIPLegacy(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
coordinator, err := tailnet.NewPGCoord(ctx, logger, ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
agent := newTestAgent(t, coordinator, "agent")
|
||||
defer agent.close()
|
||||
agent.sendNode(&agpl.Node{
|
||||
Addresses: []netip.Prefix{
|
||||
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
|
||||
},
|
||||
PreferredDERP: 10,
|
||||
})
|
||||
require.Eventually(t, func() bool {
|
||||
agents, err := store.GetTailnetPeers(ctx, agent.id)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
t.Fatalf("database error: %v", err)
|
||||
}
|
||||
if len(agents) == 0 {
|
||||
return false
|
||||
}
|
||||
node := new(proto.Node)
|
||||
err = gProto.Unmarshal(agents[0].Node, node)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 10, node.PreferredDerp)
|
||||
return true
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
err = agent.close()
|
||||
require.NoError(t, err)
|
||||
<-agent.errChan
|
||||
<-agent.closeChan
|
||||
assertEventuallyLost(ctx, t, store, agent.id)
|
||||
}
|
||||
|
||||
func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
|
|
|
@ -52,7 +52,7 @@ func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version strin
|
|||
sub := coord.ServeMultiAgent(id)
|
||||
return ServeWorkspaceProxy(ctx, conn, sub)
|
||||
case 2:
|
||||
auth := agpl.SingleTailnetTunnelAuth{}
|
||||
auth := agpl.SingleTailnetCoordinateeAuth{}
|
||||
streamID := agpl.StreamID{
|
||||
Name: id.String(),
|
||||
ID: id,
|
||||
|
|
|
@ -182,7 +182,7 @@ func TestDialCoordinator(t *testing.T) {
|
|||
// avoid blocking
|
||||
reqs := make(chan *proto.CoordinateRequest, 100)
|
||||
resps := make(chan *proto.CoordinateResponse, 100)
|
||||
mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetTunnelAuth{}).
|
||||
mCoord.EXPECT().Coordinate(gomock.Any(), proxyID, gomock.Any(), agpl.SingleTailnetCoordinateeAuth{}).
|
||||
Times(1).
|
||||
Return(reqs, resps)
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ type CoordinatorV2 interface {
|
|||
// Node returns a node by peer ID, if known to the coordinator. Returns nil if unknown.
|
||||
Node(id uuid.UUID) *Node
|
||||
Close() error
|
||||
Coordinate(ctx context.Context, id uuid.UUID, name string, a TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse)
|
||||
Coordinate(ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse)
|
||||
}
|
||||
|
||||
// Node represents a node in the network.
|
||||
|
@ -247,10 +247,10 @@ func NewInMemoryCoordination(
|
|||
) Coordination {
|
||||
thisID := agentID
|
||||
logger = logger.With(slog.F("agent_id", agentID))
|
||||
var auth TunnelAuth = AgentTunnelAuth{}
|
||||
var auth CoordinateeAuth = AgentCoordinateeAuth{ID: agentID}
|
||||
if clientID != uuid.Nil {
|
||||
// this is a client connection
|
||||
auth = ClientTunnelAuth{AgentID: agentID}
|
||||
auth = ClientCoordinateeAuth{AgentID: agentID}
|
||||
logger = logger.With(slog.F("client_id", clientID))
|
||||
thisID = clientID
|
||||
}
|
||||
|
@ -420,7 +420,7 @@ type coordinator struct {
|
|||
}
|
||||
|
||||
func (c *coordinator) Coordinate(
|
||||
ctx context.Context, id uuid.UUID, name string, a TunnelAuth,
|
||||
ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth,
|
||||
) (
|
||||
chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse,
|
||||
) {
|
||||
|
@ -476,7 +476,7 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn {
|
|||
func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAgentConn {
|
||||
logger = logger.With(slog.F("client_id", id)).Named("multiagent")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
reqs, resps := c.Coordinate(ctx, id, id.String(), SingleTailnetTunnelAuth{})
|
||||
reqs, resps := c.Coordinate(ctx, id, id.String(), SingleTailnetCoordinateeAuth{})
|
||||
m := (&MultiAgent{
|
||||
ID: id,
|
||||
OnSubscribe: func(enq Queue, agent uuid.UUID) error {
|
||||
|
@ -584,7 +584,7 @@ func ServeClientV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, con
|
|||
}()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
reqs, resps := c.Coordinate(ctx, id, id.String(), ClientTunnelAuth{AgentID: agent})
|
||||
reqs, resps := c.Coordinate(ctx, id, id.String(), ClientCoordinateeAuth{AgentID: agent})
|
||||
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{
|
||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)},
|
||||
})
|
||||
|
@ -611,6 +611,11 @@ func (c *core) handleRequest(p *peer, req *proto.CoordinateRequest) error {
|
|||
if !ok || pr != p {
|
||||
return ErrAlreadyRemoved
|
||||
}
|
||||
|
||||
if err := pr.auth.Authorize(req); err != nil {
|
||||
return xerrors.Errorf("authorize request: %w", err)
|
||||
}
|
||||
|
||||
if req.UpdateSelf != nil {
|
||||
err := c.nodeUpdateLocked(p, req.UpdateSelf.Node)
|
||||
if xerrors.Is(err, ErrAlreadyRemoved) || xerrors.Is(err, ErrClosed) {
|
||||
|
@ -683,9 +688,6 @@ func (c *core) updateTunnelPeersLocked(id uuid.UUID, n *proto.Node, k proto.Coor
|
|||
}
|
||||
|
||||
func (c *core) addTunnelLocked(src *peer, dstID uuid.UUID) error {
|
||||
if !src.auth.Authorize(dstID) {
|
||||
return xerrors.Errorf("src %s is not allowed to tunnel to %s", src.id, dstID)
|
||||
}
|
||||
c.tunnels.add(src.id, dstID)
|
||||
c.logger.Debug(context.Background(), "adding tunnel",
|
||||
slog.F("src_id", src.id),
|
||||
|
@ -813,7 +815,7 @@ func ServeAgentV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, conn
|
|||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
logger.Debug(ctx, "starting new agent connection")
|
||||
reqs, resps := c.Coordinate(ctx, id, name, AgentTunnelAuth{})
|
||||
reqs, resps := c.Coordinate(ctx, id, name, AgentCoordinateeAuth{ID: id})
|
||||
tc := NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, QueueKindAgent)
|
||||
go tc.SendUpdates()
|
||||
go v1RespLoop(ctx, cancel, logger, tc, resps)
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
@ -21,6 +22,7 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
|
@ -50,7 +52,12 @@ func TestCoordinator(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
close(closeChan)
|
||||
}()
|
||||
sendNode(&tailnet.Node{})
|
||||
sendNode(&tailnet.Node{
|
||||
Addresses: []netip.Prefix{
|
||||
netip.PrefixFrom(tailnet.IP(), 128),
|
||||
},
|
||||
PreferredDERP: 10,
|
||||
})
|
||||
require.Eventually(t, func() bool {
|
||||
return coordinator.Node(id) != nil
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
@ -60,6 +67,37 @@ func TestCoordinator(t *testing.T) {
|
|||
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
|
||||
})
|
||||
|
||||
t.Run("ClientWithoutAgent_InvalidIPBits", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
coordinator := tailnet.NewCoordinator(logger)
|
||||
defer func() {
|
||||
err := coordinator.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
client, server := net.Pipe()
|
||||
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
|
||||
return nil
|
||||
})
|
||||
id := uuid.New()
|
||||
closeChan := make(chan struct{})
|
||||
go func() {
|
||||
err := coordinator.ServeClient(server, id, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
close(closeChan)
|
||||
}()
|
||||
sendNode(&tailnet.Node{
|
||||
Addresses: []netip.Prefix{
|
||||
netip.PrefixFrom(tailnet.IP(), 64),
|
||||
},
|
||||
PreferredDERP: 10,
|
||||
})
|
||||
|
||||
_ = testutil.RequireRecvCtx(ctx, t, errChan)
|
||||
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
|
||||
})
|
||||
|
||||
t.Run("AgentWithoutClients", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
|
@ -80,7 +118,12 @@ func TestCoordinator(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
close(closeChan)
|
||||
}()
|
||||
sendNode(&tailnet.Node{})
|
||||
sendNode(&tailnet.Node{
|
||||
Addresses: []netip.Prefix{
|
||||
netip.PrefixFrom(tailnet.IPFromUUID(id), 128),
|
||||
},
|
||||
PreferredDERP: 10,
|
||||
})
|
||||
require.Eventually(t, func() bool {
|
||||
return coordinator.Node(id) != nil
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
@ -90,6 +133,101 @@ func TestCoordinator(t *testing.T) {
|
|||
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
|
||||
})
|
||||
|
||||
t.Run("AgentWithoutClients_ValidIPLegacy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
coordinator := tailnet.NewCoordinator(logger)
|
||||
defer func() {
|
||||
err := coordinator.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
client, server := net.Pipe()
|
||||
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
|
||||
return nil
|
||||
})
|
||||
id := uuid.New()
|
||||
closeChan := make(chan struct{})
|
||||
go func() {
|
||||
err := coordinator.ServeAgent(server, id, "")
|
||||
assert.NoError(t, err)
|
||||
close(closeChan)
|
||||
}()
|
||||
sendNode(&tailnet.Node{
|
||||
Addresses: []netip.Prefix{
|
||||
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
|
||||
},
|
||||
PreferredDERP: 10,
|
||||
})
|
||||
require.Eventually(t, func() bool {
|
||||
return coordinator.Node(id) != nil
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
err := client.Close()
|
||||
require.NoError(t, err)
|
||||
_ = testutil.RequireRecvCtx(ctx, t, errChan)
|
||||
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
|
||||
})
|
||||
|
||||
t.Run("AgentWithoutClients_InvalidIP", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
coordinator := tailnet.NewCoordinator(logger)
|
||||
defer func() {
|
||||
err := coordinator.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
client, server := net.Pipe()
|
||||
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
|
||||
return nil
|
||||
})
|
||||
id := uuid.New()
|
||||
closeChan := make(chan struct{})
|
||||
go func() {
|
||||
err := coordinator.ServeAgent(server, id, "")
|
||||
assert.NoError(t, err)
|
||||
close(closeChan)
|
||||
}()
|
||||
sendNode(&tailnet.Node{
|
||||
Addresses: []netip.Prefix{
|
||||
netip.PrefixFrom(tailnet.IP(), 128),
|
||||
},
|
||||
PreferredDERP: 10,
|
||||
})
|
||||
_ = testutil.RequireRecvCtx(ctx, t, errChan)
|
||||
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
|
||||
})
|
||||
|
||||
t.Run("AgentWithoutClients_InvalidBits", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
coordinator := tailnet.NewCoordinator(logger)
|
||||
defer func() {
|
||||
err := coordinator.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
client, server := net.Pipe()
|
||||
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
|
||||
return nil
|
||||
})
|
||||
id := uuid.New()
|
||||
closeChan := make(chan struct{})
|
||||
go func() {
|
||||
err := coordinator.ServeAgent(server, id, "")
|
||||
assert.NoError(t, err)
|
||||
close(closeChan)
|
||||
}()
|
||||
sendNode(&tailnet.Node{
|
||||
Addresses: []netip.Prefix{
|
||||
netip.PrefixFrom(tailnet.IPFromUUID(id), 64),
|
||||
},
|
||||
PreferredDERP: 10,
|
||||
})
|
||||
_ = testutil.RequireRecvCtx(ctx, t, errChan)
|
||||
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
|
||||
})
|
||||
|
||||
t.Run("AgentWithClient", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
|
@ -435,7 +573,7 @@ func TestInMemoryCoordination(t *testing.T) {
|
|||
|
||||
reqs := make(chan *proto.CoordinateRequest, 100)
|
||||
resps := make(chan *proto.CoordinateResponse, 100)
|
||||
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
|
||||
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
|
||||
Times(1).Return(reqs, resps)
|
||||
|
||||
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
|
||||
|
@ -462,7 +600,7 @@ func TestRemoteCoordination(t *testing.T) {
|
|||
|
||||
reqs := make(chan *proto.CoordinateRequest, 100)
|
||||
resps := make(chan *proto.CoordinateResponse, 100)
|
||||
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientTunnelAuth{agentID}).
|
||||
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
|
||||
Times(1).Return(reqs, resps)
|
||||
|
||||
var coord tailnet.Coordinator = mCoord
|
||||
|
|
|
@ -19,7 +19,7 @@ type peer struct {
|
|||
node *proto.Node
|
||||
resps chan<- *proto.CoordinateResponse
|
||||
reqs <-chan *proto.CoordinateRequest
|
||||
auth TunnelAuth
|
||||
auth CoordinateeAuth
|
||||
sent map[uuid.UUID]*proto.Node
|
||||
|
||||
name string
|
||||
|
|
|
@ -29,7 +29,7 @@ type streamIDContextKey struct{}
|
|||
type StreamID struct {
|
||||
Name string
|
||||
ID uuid.UUID
|
||||
Auth TunnelAuth
|
||||
Auth CoordinateeAuth
|
||||
}
|
||||
|
||||
func WithStreamID(ctx context.Context, streamID StreamID) context.Context {
|
||||
|
@ -91,7 +91,7 @@ func (s *ClientService) ServeClient(ctx context.Context, version string, conn ne
|
|||
coord := *(s.CoordPtr.Load())
|
||||
return coord.ServeClient(conn, id, agent)
|
||||
case 2:
|
||||
auth := ClientTunnelAuth{AgentID: agent}
|
||||
auth := ClientCoordinateeAuth{AgentID: agent}
|
||||
streamID := StreamID{
|
||||
Name: "client",
|
||||
ID: id,
|
||||
|
|
|
@ -64,7 +64,9 @@ func TestClientService_ServeClient_V2(t *testing.T) {
|
|||
require.NotNil(t, call)
|
||||
require.Equal(t, call.ID, clientID)
|
||||
require.Equal(t, call.Name, "client")
|
||||
require.True(t, call.Auth.Authorize(agentID))
|
||||
require.NoError(t, call.Auth.Authorize(&proto.CoordinateRequest{
|
||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
||||
}))
|
||||
req := testutil.RequireRecvCtx(ctx, t, call.Reqs)
|
||||
require.Equal(t, int32(11), req.GetUpdateSelf().GetNode().GetPreferredDerp())
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ func (mr *MockCoordinatorMockRecorder) Close() *gomock.Call {
|
|||
}
|
||||
|
||||
// Coordinate mocks base method.
|
||||
func (m *MockCoordinator) Coordinate(arg0 context.Context, arg1 uuid.UUID, arg2 string, arg3 tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
|
||||
func (m *MockCoordinator) Coordinate(arg0 context.Context, arg1 uuid.UUID, arg2 string, arg3 tailnet.CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Coordinate", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(chan<- *proto.CoordinateRequest)
|
||||
|
|
|
@ -312,7 +312,7 @@ func (*FakeCoordinator) ServeMultiAgent(uuid.UUID) tailnet.MultiAgentConn {
|
|||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (f *FakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
|
||||
func (f *FakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
|
||||
reqs := make(chan *proto.CoordinateRequest, 100)
|
||||
resps := make(chan *proto.CoordinateResponse, 100)
|
||||
f.CoordinateCalls <- &FakeCoordinate{
|
||||
|
@ -337,7 +337,7 @@ type FakeCoordinate struct {
|
|||
Ctx context.Context
|
||||
ID uuid.UUID
|
||||
Name string
|
||||
Auth tailnet.TunnelAuth
|
||||
Auth tailnet.CoordinateeAuth
|
||||
Reqs chan *proto.CoordinateRequest
|
||||
Resps chan *proto.CoordinateResponse
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, nam
|
|||
p.ID = uuid.New()
|
||||
}
|
||||
// SingleTailnetTunnelAuth allows connections to arbitrary peers
|
||||
p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, tailnet.SingleTailnetTunnelAuth{})
|
||||
p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, tailnet.SingleTailnetCoordinateeAuth{})
|
||||
return p
|
||||
}
|
||||
|
||||
|
|
|
@ -1,32 +1,89 @@
|
|||
package tailnet
|
||||
|
||||
import "github.com/google/uuid"
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
type TunnelAuth interface {
|
||||
Authorize(dst uuid.UUID) bool
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
var legacyWorkspaceAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")
|
||||
|
||||
type CoordinateeAuth interface {
|
||||
Authorize(req *proto.CoordinateRequest) error
|
||||
}
|
||||
|
||||
// SingleTailnetTunnelAuth allows all tunnels, since Coderd and wsproxy are allowed to initiate a tunnel to any agent
|
||||
type SingleTailnetTunnelAuth struct{}
|
||||
// SingleTailnetCoordinateeAuth allows all tunnels, since Coderd and wsproxy are allowed to initiate a tunnel to any agent
|
||||
type SingleTailnetCoordinateeAuth struct{}
|
||||
|
||||
func (SingleTailnetTunnelAuth) Authorize(uuid.UUID) bool {
|
||||
return true
|
||||
func (SingleTailnetCoordinateeAuth) Authorize(*proto.CoordinateRequest) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClientTunnelAuth allows connecting to a single, given agent
|
||||
type ClientTunnelAuth struct {
|
||||
// ClientCoordinateeAuth allows connecting to a single, given agent
|
||||
type ClientCoordinateeAuth struct {
|
||||
AgentID uuid.UUID
|
||||
}
|
||||
|
||||
func (c ClientTunnelAuth) Authorize(dst uuid.UUID) bool {
|
||||
return c.AgentID == dst
|
||||
func (c ClientCoordinateeAuth) Authorize(req *proto.CoordinateRequest) error {
|
||||
if tun := req.GetAddTunnel(); tun != nil {
|
||||
uid, err := uuid.FromBytes(tun.Id)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse add tunnel id: %w", err)
|
||||
}
|
||||
|
||||
if c.AgentID != uid {
|
||||
return xerrors.Errorf("invalid agent id, expected %s, got %s", c.AgentID.String(), uid.String())
|
||||
}
|
||||
}
|
||||
|
||||
if upd := req.GetUpdateSelf(); upd != nil {
|
||||
for _, addrStr := range upd.Node.Addresses {
|
||||
pre, err := netip.ParsePrefix(addrStr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse node address: %w", err)
|
||||
}
|
||||
|
||||
if pre.Bits() != 128 {
|
||||
return xerrors.Errorf("invalid address bits, expected 128, got %d", pre.Bits())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AgentTunnelAuth disallows all tunnels, since agents are not allowed to initiate their own tunnels
|
||||
type AgentTunnelAuth struct{}
|
||||
// AgentCoordinateeAuth disallows all tunnels, since agents are not allowed to initiate their own tunnels
|
||||
type AgentCoordinateeAuth struct {
|
||||
ID uuid.UUID
|
||||
}
|
||||
|
||||
func (AgentTunnelAuth) Authorize(uuid.UUID) bool {
|
||||
return false
|
||||
func (a AgentCoordinateeAuth) Authorize(req *proto.CoordinateRequest) error {
|
||||
if tun := req.GetAddTunnel(); tun != nil {
|
||||
return xerrors.New("agents cannot open tunnels")
|
||||
}
|
||||
|
||||
if upd := req.GetUpdateSelf(); upd != nil {
|
||||
for _, addrStr := range upd.Node.Addresses {
|
||||
pre, err := netip.ParsePrefix(addrStr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse node address: %w", err)
|
||||
}
|
||||
|
||||
if pre.Bits() != 128 {
|
||||
return xerrors.Errorf("invalid address bits, expected 128, got %d", pre.Bits())
|
||||
}
|
||||
|
||||
if IPFromUUID(a.ID).Compare(pre.Addr()) != 0 &&
|
||||
legacyWorkspaceAgentIP.Compare(pre.Addr()) != 0 {
|
||||
return xerrors.Errorf("invalid node address, got %s", pre.Addr().String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tunnelStore contains tunnel information and allows querying it. It is not threadsafe and all
|
||||
|
|
Loading…
Reference in New Issue