package tailnet import ( "net/netip" "github.com/google/uuid" "golang.org/x/xerrors" "github.com/coder/coder/v2/tailnet/proto" ) var legacyWorkspaceAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4") type CoordinateeAuth interface { Authorize(req *proto.CoordinateRequest) error } // SingleTailnetCoordinateeAuth allows all tunnels, since Coderd and wsproxy are allowed to initiate a tunnel to any agent type SingleTailnetCoordinateeAuth struct{} func (SingleTailnetCoordinateeAuth) Authorize(*proto.CoordinateRequest) error { return nil } // ClientCoordinateeAuth allows connecting to a single, given agent type ClientCoordinateeAuth struct { AgentID uuid.UUID } func (c ClientCoordinateeAuth) Authorize(req *proto.CoordinateRequest) error { if tun := req.GetAddTunnel(); tun != nil { uid, err := uuid.FromBytes(tun.Id) if err != nil { return xerrors.Errorf("parse add tunnel id: %w", err) } if c.AgentID != uid { return xerrors.Errorf("invalid agent id, expected %s, got %s", c.AgentID.String(), uid.String()) } } if upd := req.GetUpdateSelf(); upd != nil { for _, addrStr := range upd.Node.Addresses { pre, err := netip.ParsePrefix(addrStr) if err != nil { return xerrors.Errorf("parse node address: %w", err) } if pre.Bits() != 128 { return xerrors.Errorf("invalid address bits, expected 128, got %d", pre.Bits()) } } } if rfh := req.GetReadyForHandshake(); rfh != nil { return xerrors.Errorf("clients may not send ready_for_handshake") } return nil } // AgentCoordinateeAuth disallows all tunnels, since agents are not allowed to initiate their own tunnels type AgentCoordinateeAuth struct { ID uuid.UUID } func (a AgentCoordinateeAuth) Authorize(req *proto.CoordinateRequest) error { if tun := req.GetAddTunnel(); tun != nil { return xerrors.New("agents cannot open tunnels") } if upd := req.GetUpdateSelf(); upd != nil { for _, addrStr := range upd.Node.Addresses { pre, err := netip.ParsePrefix(addrStr) if err != nil { return xerrors.Errorf("parse node address: %w", err) } if pre.Bits() != 128 { return xerrors.Errorf("invalid address bits, expected 128, got %d", pre.Bits()) } if IPFromUUID(a.ID).Compare(pre.Addr()) != 0 && legacyWorkspaceAgentIP.Compare(pre.Addr()) != 0 { return xerrors.Errorf("invalid node address, got %s", pre.Addr().String()) } } } return nil } // tunnelStore contains tunnel information and allows querying it. It is not threadsafe and all // methods must be serialized by holding, e.g. the core mutex. type tunnelStore struct { bySrc map[uuid.UUID]map[uuid.UUID]struct{} byDst map[uuid.UUID]map[uuid.UUID]struct{} } func newTunnelStore() *tunnelStore { return &tunnelStore{ bySrc: make(map[uuid.UUID]map[uuid.UUID]struct{}), byDst: make(map[uuid.UUID]map[uuid.UUID]struct{}), } } func (s *tunnelStore) add(src, dst uuid.UUID) { srcM, ok := s.bySrc[src] if !ok { srcM = make(map[uuid.UUID]struct{}) s.bySrc[src] = srcM } srcM[dst] = struct{}{} dstM, ok := s.byDst[dst] if !ok { dstM = make(map[uuid.UUID]struct{}) s.byDst[dst] = dstM } dstM[src] = struct{}{} } func (s *tunnelStore) remove(src, dst uuid.UUID) { delete(s.bySrc[src], dst) if len(s.bySrc[src]) == 0 { delete(s.bySrc, src) } delete(s.byDst[dst], src) if len(s.byDst[dst]) == 0 { delete(s.byDst, dst) } } func (s *tunnelStore) removeAll(src uuid.UUID) { for dst := range s.bySrc[src] { s.remove(src, dst) } } func (s *tunnelStore) findTunnelPeers(id uuid.UUID) []uuid.UUID { set := make(map[uuid.UUID]struct{}) for dst := range s.bySrc[id] { set[dst] = struct{}{} } for src := range s.byDst[id] { set[src] = struct{}{} } out := make([]uuid.UUID, 0, len(set)) for id := range set { out = append(out, id) } return out } func (s *tunnelStore) tunnelExists(src, dst uuid.UUID) bool { _, srcOK := s.bySrc[src][dst] _, dstOK := s.byDst[src][dst] return srcOK || dstOK } func (s *tunnelStore) htmlDebug() []HTMLTunnel { out := make([]HTMLTunnel, 0) for src, dsts := range s.bySrc { for dst := range dsts { out = append(out, HTMLTunnel{Src: src, Dst: dst}) } } return out }