// coder/tailnet/conn.go

package tailnet

import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"net/http"
"net/netip"
"os"
"reflect"
"strconv"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"go4.org/netipx"
"golang.org/x/xerrors"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"tailscale.com/envknob"
"tailscale.com/ipn/ipnstate"
"tailscale.com/net/connstats"
"tailscale.com/net/dns"
"tailscale.com/net/netmon"
"tailscale.com/net/netns"
"tailscale.com/net/tsdial"
"tailscale.com/net/tstun"
"tailscale.com/tailcfg"
"tailscale.com/tsd"
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
tslogger "tailscale.com/types/logger"
"tailscale.com/types/netlogtype"
"tailscale.com/types/netmap"
"tailscale.com/wgengine"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/magicsock"
"tailscale.com/wgengine/netstack"
"tailscale.com/wgengine/router"
"tailscale.com/wgengine/wgcfg/nmcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/cryptorand"
)
const (
WorkspaceAgentSSHPort = 1
WorkspaceAgentReconnectingPTYPort = 2
WorkspaceAgentSpeedtestPort = 3
)
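// For example, a client can dial an agent's SSH service over the tunnel
// using these well-known ports (sketch; conn, ctx, and agentIP are assumed
// to exist):
//
//	c, err := conn.DialContextTCP(ctx, netip.AddrPortFrom(agentIP, WorkspaceAgentSSHPort))
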
// EnvMagicsockDebugLogging enables super-verbose logging for the magicsock
// internals. A logger must be supplied to the connection with the debug level
// enabled.
//
// With this disabled, you still get a lot of output if you have a valid logger
// with the debug level enabled.
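//
// For example, enabling it for whichever process hosts the Conn
// (hypothetical invocation; the binary name is illustrative):
//
//	CODER_MAGICSOCK_DEBUG_LOGGING=true ./your-binary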
const EnvMagicsockDebugLogging = "CODER_MAGICSOCK_DEBUG_LOGGING"
func init() {
// Globally disable network namespacing. All networking happens in
// userspace.
netns.SetEnabled(false)
// Tailscale, by default, "trims" the set of peers down to ones that we are
// "actively" communicating with in an effort to save memory. Since
// Tailscale removed keep-alives, it seems like open but idle connections
// (SSH, port-forward, etc) can get trimmed fairly easily, causing hangs for
// a few seconds while the connection is set up again.
//
// Note that Tailscale.com's use case is very different from ours: in their
// use case, users create one persistent tailnet per device, and it allows
// connections to every other thing in Tailscale that belongs to them. The
// tailnet stays up as long as your laptop or phone is turned on.
//
// Our use case is different: for clients, it's a point-to-point connection
// to a single workspace, and lasts only as long as the connection. For
// agents, it's connections to a small number of clients (CLI or Coderd)
// that are being actively used by the end user.
envknob.Setenv("TS_DEBUG_TRIM_WIREGUARD", "false")
}
type Options struct {
ID uuid.UUID
Addresses []netip.Prefix
DERPMap *tailcfg.DERPMap
DERPHeader *http.Header
// DERPForceWebSockets determines whether WebSockets are always used for
// DERP connections, rather than trying `Upgrade: derp` first and
// potentially falling back. This is useful for misbehaving proxies that
// prevent fallback due to odd behavior, like Azure App Proxy.
DERPForceWebSockets bool
// BlockEndpoints specifies whether P2P endpoints are blocked.
// If so, connections can only be established over DERP.
BlockEndpoints bool
Logger slog.Logger
ListenPort uint16
}
// NodeID creates a Tailscale NodeID from the last 8 bytes of a UUID. It ensures
// the returned NodeID is always positive.
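//
// For example (sketch; the UUID literal is illustrative):
//
//	id := NodeID(uuid.MustParse("4d2c283a-d2a9-4bd6-8b6a-d1b35e8a0a97")) // never negative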
func NodeID(uid uuid.UUID) tailcfg.NodeID {
id := int64(binary.BigEndian.Uint64(uid[8:]))
// Ensure id is positive using a branchless two's-complement abs:
// y is 0 when id >= 0 and -1 when id < 0.
y := id >> 63
id = (id ^ y) - y
return tailcfg.NodeID(id)
}
// NewConn constructs a new Wireguard server that will accept connections from the addresses provided.
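//
// A minimal usage sketch (the DERP map and logger are assumed to be
// obtained elsewhere):
//
//	conn, err := NewConn(&Options{
//		Addresses: []netip.Prefix{netip.PrefixFrom(IP(), 128)},
//		DERPMap:   derpMap,
//		Logger:    logger,
//	})
//	if err != nil {
//		// handle err
//	}
//	defer conn.Close()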
func NewConn(options *Options) (conn *Conn, err error) {
if options == nil {
options = &Options{}
}
if len(options.Addresses) == 0 {
return nil, xerrors.New("At least one IP range must be provided")
}
if options.DERPMap == nil {
return nil, xerrors.New("DERPMap must be provided")
}
nodePrivateKey := key.NewNode()
nodePublicKey := nodePrivateKey.Public()
netMap := &netmap.NetworkMap{
DERPMap: options.DERPMap,
NodeKey: nodePublicKey,
PrivateKey: nodePrivateKey,
Addresses: options.Addresses,
PacketFilter: []filter.Match{{
// Allow any protocol!
IPProto: []ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, ipproto.SCTP},
// Allow traffic sourced from anywhere.
Srcs: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0),
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0),
},
// Allow traffic to route anywhere.
Dsts: []filter.NetPortRange{
{
Net: netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0),
Ports: filter.PortRange{
First: 0,
Last: 65535,
},
},
{
Net: netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0),
Ports: filter.PortRange{
First: 0,
Last: 65535,
},
},
},
Caps: []filter.CapMatch{},
}},
}
var nodeID tailcfg.NodeID
// If we're provided with a UUID, use it to populate our node ID.
if options.ID != uuid.Nil {
nodeID = NodeID(options.ID)
} else {
uid, err := cryptorand.Int63()
if err != nil {
return nil, xerrors.Errorf("generate node id: %w", err)
}
nodeID = tailcfg.NodeID(uid)
}
// This is used by functions below to identify the node via its key.
netMap.SelfNode = &tailcfg.Node{
ID: nodeID,
Key: nodePublicKey,
Addresses: options.Addresses,
AllowedIPs: options.Addresses,
}
wireguardMonitor, err := netmon.New(Logger(options.Logger.Named("net.wgmonitor")))
if err != nil {
return nil, xerrors.Errorf("create wireguard link monitor: %w", err)
}
defer func() {
if err != nil {
wireguardMonitor.Close()
}
}()
dialer := &tsdial.Dialer{
Logf: Logger(options.Logger.Named("net.tsdial")),
}
sys := new(tsd.System)
wireguardEngine, err := wgengine.NewUserspaceEngine(Logger(options.Logger.Named("net.wgengine")), wgengine.Config{
NetMon: wireguardMonitor,
Dialer: dialer,
ListenPort: options.ListenPort,
SetSubsystem: sys.Set,
})
if err != nil {
return nil, xerrors.Errorf("create wgengine: %w", err)
}
defer func() {
if err != nil {
wireguardEngine.Close()
}
}()
dialer.UseNetstackForIP = func(ip netip.Addr) bool {
_, ok := wireguardEngine.PeerForIP(ip)
return ok
}
sys.Set(wireguardEngine)
magicConn := sys.MagicSock.Get()
magicConn.SetDERPForceWebsockets(options.DERPForceWebSockets)
if options.DERPHeader != nil {
magicConn.SetDERPHeader(options.DERPHeader.Clone())
}
if v, ok := os.LookupEnv(EnvMagicsockDebugLogging); ok {
vBool, err := strconv.ParseBool(v)
if err != nil {
options.Logger.Debug(context.Background(), fmt.Sprintf("magicsock debug logging disabled due to invalid value %s=%q, use true or false", EnvMagicsockDebugLogging, v))
} else {
magicConn.SetDebugLoggingEnabled(vBool)
options.Logger.Debug(context.Background(), fmt.Sprintf("magicsock debug logging set by %s=%t", EnvMagicsockDebugLogging, vBool))
}
} else {
options.Logger.Debug(context.Background(), fmt.Sprintf("magicsock debug logging disabled, use %s=true to enable", EnvMagicsockDebugLogging))
}
// Update the keys for the magic connection!
err = magicConn.SetPrivateKey(nodePrivateKey)
if err != nil {
return nil, xerrors.Errorf("set node private key: %w", err)
}
netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey()
netStack, err := netstack.Create(
Logger(options.Logger.Named("net.netstack")),
sys.Tun.Get(),
wireguardEngine,
magicConn,
dialer,
sys.DNSManager.Get(),
)
if err != nil {
return nil, xerrors.Errorf("create netstack: %w", err)
}
dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
return netStack.DialContextTCP(ctx, dst)
}
netStack.ProcessLocalIPs = true
wireguardEngine = wgengine.NewWatchdog(wireguardEngine)
wireguardEngine.SetDERPMap(options.DERPMap)
netMapCopy := *netMap
options.Logger.Debug(context.Background(), "updating network map")
wireguardEngine.SetNetworkMap(&netMapCopy)
localIPSet := netipx.IPSetBuilder{}
for _, addr := range netMap.Addresses {
localIPSet.AddPrefix(addr)
}
localIPs, _ := localIPSet.IPSet()
logIPSet := netipx.IPSetBuilder{}
logIPs, _ := logIPSet.IPSet()
wireguardEngine.SetFilter(filter.New(
netMap.PacketFilter,
localIPs,
logIPs,
nil,
Logger(options.Logger.Named("net.packet-filter")),
))
dialContext, dialCancel := context.WithCancel(context.Background())
server := &Conn{
blockEndpoints: options.BlockEndpoints,
derpForceWebSockets: options.DERPForceWebSockets,
dialContext: dialContext,
dialCancel: dialCancel,
closed: make(chan struct{}),
logger: options.Logger,
magicConn: magicConn,
dialer: dialer,
listeners: map[listenKey]*listener{},
peerMap: map[tailcfg.NodeID]*tailcfg.Node{},
lastDERPForcedWebSockets: map[int]string{},
tunDevice: sys.Tun.Get(),
netMap: netMap,
netStack: netStack,
wireguardMonitor: wireguardMonitor,
wireguardRouter: &router.Config{
LocalAddrs: netMap.Addresses,
},
wireguardEngine: wireguardEngine,
}
defer func() {
if err != nil {
_ = server.Close()
}
}()
wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) {
server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.Error(err))
if err != nil {
return
}
server.lastMutex.Lock()
if s.AsOf.Before(server.lastStatus) {
// Don't process outdated status!
server.lastMutex.Unlock()
return
}
server.lastStatus = s.AsOf
if endpointsEqual(s.LocalAddrs, server.lastEndpoints) {
// No need to update the node if nothing changed!
server.lastMutex.Unlock()
return
}
server.lastEndpoints = append([]tailcfg.Endpoint{}, s.LocalAddrs...)
server.lastMutex.Unlock()
server.sendNode()
})
wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) {
server.logger.Debug(context.Background(), "netinfo callback", slog.F("netinfo", ni))
server.lastMutex.Lock()
if reflect.DeepEqual(server.lastNetInfo, ni) {
server.lastMutex.Unlock()
return
}
server.lastNetInfo = ni.Clone()
server.lastMutex.Unlock()
server.sendNode()
})
magicConn.SetDERPForcedWebsocketCallback(func(region int, reason string) {
server.logger.Debug(context.Background(), "derp forced websocket", slog.F("region", region), slog.F("reason", reason))
server.lastMutex.Lock()
if server.lastDERPForcedWebSockets[region] == reason {
server.lastMutex.Unlock()
return
}
server.lastDERPForcedWebSockets[region] = reason
server.lastMutex.Unlock()
server.sendNode()
})
netStack.GetTCPHandlerForFlow = server.forwardTCP
err = netStack.Start(nil)
if err != nil {
return nil, xerrors.Errorf("start netstack: %w", err)
}
return server, nil
}
func maskUUID(uid uuid.UUID) uuid.UUID {
// This is Tailscale's ephemeral service prefix. This can be changed easily
// later on, because all of our nodes are ephemeral.
// fd7a:115c:a1e0
uid[0] = 0xfd
uid[1] = 0x7a
uid[2] = 0x11
uid[3] = 0x5c
uid[4] = 0xa1
uid[5] = 0xe0
return uid
}
// IP generates a random IP with a static service prefix.
func IP() netip.Addr {
uid := maskUUID(uuid.New())
return netip.AddrFrom16(uid)
}
// IPFromUUID generates a new IP from a UUID.
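//
// For example, deriving a stable tailnet address from an agent ID
// (sketch; agentID is illustrative):
//
//	addr := IPFromUUID(agentID) // the same UUID always yields the same address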
func IPFromUUID(uid uuid.UUID) netip.Addr {
return netip.AddrFrom16(maskUUID(uid))
}
// Conn is an actively listening Wireguard connection.
type Conn struct {
dialContext context.Context
dialCancel context.CancelFunc
mutex sync.Mutex
closed chan struct{}
logger slog.Logger
blockEndpoints bool
derpForceWebSockets bool
dialer *tsdial.Dialer
tunDevice *tstun.Wrapper
peerMap map[tailcfg.NodeID]*tailcfg.Node
netMap *netmap.NetworkMap
netStack *netstack.Impl
magicConn *magicsock.Conn
wireguardMonitor *netmon.Monitor
wireguardRouter *router.Config
wireguardEngine wgengine.Engine
listeners map[listenKey]*listener
lastMutex sync.Mutex
nodeSending bool
nodeChanged bool
// It's only possible to store these values via status functions,
// so the values must be stored for retrieval later on.
lastStatus time.Time
lastEndpoints []tailcfg.Endpoint
lastDERPForcedWebSockets map[int]string
lastNetInfo *tailcfg.NetInfo
nodeCallback func(node *Node)
trafficStats *connstats.Statistics
}
func (c *Conn) MagicsockSetDebugLoggingEnabled(enabled bool) {
c.magicConn.SetDebugLoggingEnabled(enabled)
}
func (c *Conn) SetAddresses(ips []netip.Prefix) error {
c.mutex.Lock()
defer c.mutex.Unlock()
c.netMap.Addresses = ips
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
err := c.reconfig()
if err != nil {
return xerrors.Errorf("reconfig: %w", err)
}
return nil
}
func (c *Conn) Addresses() []netip.Prefix {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.netMap.Addresses
}
func (c *Conn) SetNodeCallback(callback func(node *Node)) {
c.lastMutex.Lock()
c.nodeCallback = callback
c.lastMutex.Unlock()
c.sendNode()
}
// SetDERPMap updates the DERPMap of a connection.
func (c *Conn) SetDERPMap(derpMap *tailcfg.DERPMap) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.logger.Debug(context.Background(), "updating derp map", slog.F("derp_map", derpMap))
c.wireguardEngine.SetDERPMap(derpMap)
c.netMap.DERPMap = derpMap
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
}
func (c *Conn) SetDERPForceWebSockets(v bool) {
c.magicConn.SetDERPForceWebsockets(v)
}
// SetBlockEndpoints sets whether or not to block P2P endpoints. This setting
// will only apply to new peers.
func (c *Conn) SetBlockEndpoints(blockEndpoints bool) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.blockEndpoints = blockEndpoints
}
// SetDERPRegionDialer updates the dialer to use for connecting to DERP regions.
func (c *Conn) SetDERPRegionDialer(dialer func(ctx context.Context, region *tailcfg.DERPRegion) net.Conn) {
c.magicConn.SetDERPRegionDialer(dialer)
}
// UpdateNodes connects with a set of peers. This can be constantly updated,
// and peers will continually be reconnected as necessary. If replacePeers is
// true, all peers will be removed before adding the new ones.
//
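// A hedged sketch of the call pattern (peer nodes normally arrive from the
// coordinator):
//
//	err := conn.UpdateNodes([]*Node{peer}, false) // merge peer into the netmap
//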
//nolint:revive // Complains about replacePeers.
func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.isClosed() {
return xerrors.New("connection closed")
}
status := c.Status()
if replacePeers {
c.netMap.Peers = []*tailcfg.Node{}
c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{}
}
for _, peer := range c.netMap.Peers {
peerStatus, ok := status.Peer[peer.Key]
if !ok {
continue
}
// If this peer was added in the last 5 minutes, assume it
// could still be active.
if time.Since(peer.Created) < 5*time.Minute {
continue
}
// We double-check that it's safe to remove by ensuring no
// handshake has been sent in the past 5 minutes as well. Connections that
// are actively exchanging IP traffic will handshake every 2 minutes.
if time.Since(peerStatus.LastHandshake) < 5*time.Minute {
continue
}
c.logger.Debug(context.Background(), "removing peer, last handshake >5m ago",
slog.F("peer", peer.Key), slog.F("last_handshake", peerStatus.LastHandshake),
)
delete(c.peerMap, peer.ID)
}
for _, node := range nodes {
// If no preferred DERP is provided, we can't reach the node.
if node.PreferredDERP == 0 {
c.logger.Debug(context.Background(), "no preferred DERP, skipping node", slog.F("node", node))
continue
}
c.logger.Debug(context.Background(), "adding node", slog.F("node", node))
peerStatus, ok := status.Peer[node.Key]
peerNode := &tailcfg.Node{
ID: node.ID,
Created: time.Now(),
Key: node.Key,
DiscoKey: node.DiscoKey,
Addresses: node.Addresses,
AllowedIPs: node.AllowedIPs,
Endpoints: node.Endpoints,
DERP: fmt.Sprintf("%s:%d", tailcfg.DerpMagicIP, node.PreferredDERP),
Hostinfo: (&tailcfg.Hostinfo{}).View(),
// Starting KeepAlive messages at the initialization of a connection
// causes a race condition. If we handshake before the peer has our
// node, we'll have to wait 5 seconds before trying again. Ideally,
// the first handshake starts when the user first initiates a
// connection to the peer. After a successful connection we enable
// keep-alives to persist the connection and keep it from becoming
// idle. SSH connections don't send packets while idle, so we
// use keep-alives to avoid random hangs while we set up the
// connection again after inactivity.
KeepAlive: ok && peerStatus.Active,
}
if c.blockEndpoints {
peerNode.Endpoints = nil
}
c.peerMap[node.ID] = peerNode
}
c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
for _, peer := range c.peerMap {
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
}
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
err := c.reconfig()
if err != nil {
return xerrors.Errorf("reconfig: %w", err)
}
return nil
}
// PeerSelector is used to select a peer from within a Tailnet.
type PeerSelector struct {
ID tailcfg.NodeID
IP netip.Prefix
}
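// RemovePeer deletes a peer matching the selector by ID or address.
// For example, removing a peer by address (sketch; peerAddr is
// illustrative):
//
//	deleted, err := conn.RemovePeer(PeerSelector{
//		IP: netip.PrefixFrom(peerAddr, 128),
//	})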
func (c *Conn) RemovePeer(selector PeerSelector) (deleted bool, err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.isClosed() {
return false, xerrors.New("connection closed")
}
deleted = false
for _, peer := range c.peerMap {
if peer.ID == selector.ID {
delete(c.peerMap, peer.ID)
deleted = true
break
}
for _, peerIP := range peer.Addresses {
if peerIP.Bits() == selector.IP.Bits() && peerIP.Addr().Compare(selector.IP.Addr()) == 0 {
delete(c.peerMap, peer.ID)
deleted = true
break
}
}
}
if !deleted {
return false, nil
}
c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
for _, peer := range c.peerMap {
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
}
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
err = c.reconfig()
if err != nil {
return false, xerrors.Errorf("reconfig: %w", err)
}
return true, nil
}
func (c *Conn) reconfig() error {
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("net.wgconfig")), netmap.AllowSingleHosts, "")
if err != nil {
return xerrors.Errorf("update wireguard config: %w", err)
}
err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{})
if err != nil {
if c.isClosed() {
return nil
}
if errors.Is(err, wgengine.ErrNoChanges) {
return nil
}
return xerrors.Errorf("reconfig: %w", err)
}
return nil
}
// NodeAddresses returns the addresses of a node from the NetworkMap.
func (c *Conn) NodeAddresses(publicKey key.NodePublic) ([]netip.Prefix, bool) {
c.mutex.Lock()
defer c.mutex.Unlock()
for _, node := range c.netMap.Peers {
if node.Key == publicKey {
return node.Addresses, true
}
}
return nil, false
}
// Status returns the current ipnstate of a connection.
func (c *Conn) Status() *ipnstate.Status {
sb := &ipnstate.StatusBuilder{WantPeers: true}
c.wireguardEngine.UpdateStatus(sb)
return sb.Status()
}
// Ping sends a Disco ping to the Wireguard engine.
// The bool returned is true if the ping was performed P2P.
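//
// For example (sketch):
//
//	latency, p2p, _, err := conn.Ping(ctx, ip)
//	if err == nil && p2p {
//		// the peer is reachable directly, not via DERP
//	}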
func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, bool, *ipnstate.PingResult, error) {
errCh := make(chan error, 1)
prChan := make(chan *ipnstate.PingResult, 1)
go c.wireguardEngine.Ping(ip, tailcfg.PingDisco, func(pr *ipnstate.PingResult) {
if pr.Err != "" {
errCh <- xerrors.New(pr.Err)
return
}
prChan <- pr
})
select {
case err := <-errCh:
return 0, false, nil, err
case <-ctx.Done():
return 0, false, nil, ctx.Err()
case pr := <-prChan:
return time.Duration(pr.LatencySeconds * float64(time.Second)), pr.Endpoint != "", pr, nil
}
}
// DERPMap returns the currently set DERP mapping.
func (c *Conn) DERPMap() *tailcfg.DERPMap {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.netMap.DERPMap
}
// BlockEndpoints returns whether or not P2P is blocked.
func (c *Conn) BlockEndpoints() bool {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.blockEndpoints
}
// AwaitReachable pings the provided IP continually until the
// address is reachable. It's the caller's responsibility to provide
// a timeout, otherwise this function will block forever.
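//
// A bounded wait therefore looks like (sketch):
//
//	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
//	defer cancel()
//	reachable := conn.AwaitReachable(ctx, agentIP)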
func (c *Conn) AwaitReachable(ctx context.Context, ip netip.Addr) bool {
ctx, cancel := context.WithCancel(ctx)
defer cancel() // Cancel all pending pings on exit.
completedCtx, completed := context.WithCancel(context.Background())
defer completed()
run := func() {
// Safety timeout: initially we'll have around 10-20 goroutines
// running in parallel. The exponential backoff will converge
// around ~1 ping / 30s; this means we'll have around 10-20
// goroutines pending towards the end as well.
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
_, _, _, err := c.Ping(ctx, ip)
if err == nil {
completed()
}
}
eb := backoff.NewExponentialBackOff()
eb.MaxElapsedTime = 0
eb.InitialInterval = 50 * time.Millisecond
eb.MaxInterval = 30 * time.Second
// Consume the first interval since
// we'll fire off a ping immediately.
_ = eb.NextBackOff()
t := backoff.NewTicker(eb)
defer t.Stop()
go run()
for {
select {
case <-completedCtx.Done():
return true
case <-t.C:
// Pings can take a while, so we can run multiple
// in parallel to return ASAP.
go run()
case <-ctx.Done():
return false
}
}
}
// Closed returns a channel that is closed when the connection
// has been closed.
func (c *Conn) Closed() <-chan struct{} {
return c.closed
}
// Close shuts down the Wireguard connection.
func (c *Conn) Close() error {
c.mutex.Lock()
select {
case <-c.closed:
c.mutex.Unlock()
return nil
default:
}
close(c.closed)
c.mutex.Unlock()
var wg sync.WaitGroup
defer wg.Wait()
if c.trafficStats != nil {
wg.Add(1)
go func() {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = c.trafficStats.Shutdown(ctx)
}()
}
_ = c.netStack.Close()
c.dialCancel()
_ = c.wireguardMonitor.Close()
_ = c.dialer.Close()
// Stops internals, e.g. tunDevice, magicConn and dnsManager.
c.wireguardEngine.Close()
c.mutex.Lock()
for _, l := range c.listeners {
_ = l.closeNoLock()
}
c.listeners = nil
c.mutex.Unlock()
return nil
}
func (c *Conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
func (c *Conn) sendNode() {
c.lastMutex.Lock()
defer c.lastMutex.Unlock()
if c.nodeSending {
c.nodeChanged = true
return
}
node := c.selfNode()
// Conn.UpdateNodes will skip any nodes that don't have the PreferredDERP
// set to non-zero, since we cannot reach nodes without DERP for discovery.
// Therefore, there is no point in sending the node without this, and we can
// save ourselves from churn in the tailscale/wireguard layer.
if node.PreferredDERP == 0 {
c.logger.Debug(context.Background(), "skipped sending node; no PreferredDERP", slog.F("node", node))
return
}
nodeCallback := c.nodeCallback
if nodeCallback == nil {
return
}
c.nodeSending = true
go func() {
c.logger.Debug(context.Background(), "sending node", slog.F("node", node))
nodeCallback(node)
c.lastMutex.Lock()
c.nodeSending = false
if c.nodeChanged {
c.nodeChanged = false
c.lastMutex.Unlock()
c.sendNode()
return
}
c.lastMutex.Unlock()
}()
}
// Node returns the last node that was sent to the node callback.
func (c *Conn) Node() *Node {
c.lastMutex.Lock()
defer c.lastMutex.Unlock()
return c.selfNode()
}
func (c *Conn) selfNode() *Node {
endpoints := make([]string, 0, len(c.lastEndpoints))
for _, addr := range c.lastEndpoints {
endpoints = append(endpoints, addr.Addr.String())
}
var preferredDERP int
var derpLatency map[string]float64
derpForcedWebsocket := make(map[int]string, 0)
if c.lastNetInfo != nil {
preferredDERP = c.lastNetInfo.PreferredDERP
derpLatency = c.lastNetInfo.DERPLatency
if c.derpForceWebSockets {
// We only need to store this for a single region, since this is
// mostly used for debugging purposes and has no functional effect.
derpForcedWebsocket[preferredDERP] = "DERP is configured to always fallback to WebSockets"
} else {
for k, v := range c.lastDERPForcedWebSockets {
derpForcedWebsocket[k] = v
}
}
}
node := &Node{
ID: c.netMap.SelfNode.ID,
AsOf: dbtime.Now(),
Key: c.netMap.SelfNode.Key,
Addresses: c.netMap.SelfNode.Addresses,
AllowedIPs: c.netMap.SelfNode.AllowedIPs,
DiscoKey: c.magicConn.DiscoPublicKey(),
Endpoints: endpoints,
PreferredDERP: preferredDERP,
DERPLatency: derpLatency,
DERPForcedWebsocket: derpForcedWebsocket,
}
c.mutex.Lock()
if c.blockEndpoints {
node.Endpoints = nil
}
c.mutex.Unlock()
return node
}
// This and below is taken _mostly_ verbatim from Tailscale:
// https://github.com/tailscale/tailscale/blob/c88bd53b1b7b2fcf7ba302f2e53dd1ce8c32dad4/tsnet/tsnet.go#L459-L494
// Listen listens for connections only on the Tailscale network.
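//
// For example (sketch; the port is illustrative):
//
//	ln, err := conn.Listen("tcp", ":4242")
//	if err != nil {
//		// handle err
//	}
//	go serve(ln) // e.g. http.Serve or an SSH server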
func (c *Conn) Listen(network, addr string) (net.Listener, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, xerrors.Errorf("tailnet: split host port for listen: %w", err)
}
lk := listenKey{network, host, port}
ln := &listener{
s: c,
key: lk,
addr: addr,
closed: make(chan struct{}),
conn: make(chan net.Conn),
}
c.mutex.Lock()
if c.isClosed() {
c.mutex.Unlock()
return nil, xerrors.New("closed")
}
if c.listeners == nil {
c.listeners = map[listenKey]*listener{}
}
if _, ok := c.listeners[lk]; ok {
c.mutex.Unlock()
return nil, xerrors.Errorf("tailnet: listener already open for %s, %s", network, addr)
}
c.listeners[lk] = ln
c.mutex.Unlock()
return ln, nil
}
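// DialContextTCP dials a TCP connection to the given address over the
// tailnet's userspace network stack.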
func (c *Conn) DialContextTCP(ctx context.Context, ipp netip.AddrPort) (*gonet.TCPConn, error) {
return c.netStack.DialContextTCP(ctx, ipp)
}
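// DialContextUDP dials a UDP connection to the given address over the
// tailnet's userspace network stack.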
func (c *Conn) DialContextUDP(ctx context.Context, ipp netip.AddrPort) (*gonet.UDPConn, error) {
return c.netStack.DialContextUDP(ctx, ipp)
}
func (c *Conn) forwardTCP(_, dst netip.AddrPort) (handler func(net.Conn), opts []tcpip.SettableSocketOption, intercept bool) {
c.mutex.Lock()
ln, ok := c.listeners[listenKey{"tcp", "", fmt.Sprint(dst.Port())}]
c.mutex.Unlock()
if !ok {
return nil, nil, false
}
// See: https://github.com/tailscale/tailscale/blob/c7cea825aea39a00aca71ea02bab7266afc03e7c/wgengine/netstack/netstack.go#L888
if dst.Port() == WorkspaceAgentSSHPort || dst.Port() == 22 {
opt := tcpip.KeepaliveIdleOption(72 * time.Hour)
opts = append(opts, &opt)
}
return func(conn net.Conn) {
t := time.NewTimer(time.Second)
defer t.Stop()
select {
case ln.conn <- conn:
return
case <-ln.closed:
case <-c.closed:
case <-t.C:
}
_ = conn.Close()
}, opts, true
}
// SetConnStatsCallback sets a callback to be called after maxPeriod or
// maxConns, whichever comes first. Multiple calls overwrite the callback.
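//
// For example (sketch; the period and connection cap are illustrative):
//
//	conn.SetConnStatsCallback(time.Minute, 2048,
//		func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
//			// flush the counts to storage
//		})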
func (c *Conn) SetConnStatsCallback(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) {
connStats := connstats.NewStatistics(maxPeriod, maxConns, dump)
shutdown := func(s *connstats.Statistics) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.Shutdown(ctx)
}
c.mutex.Lock()
if c.isClosed() {
c.mutex.Unlock()
shutdown(connStats)
return
}
old := c.trafficStats
c.trafficStats = connStats
c.mutex.Unlock()
// Make sure to shut down the old callback.
if old != nil {
shutdown(old)
}
c.tunDevice.SetStatistics(connStats)
}
func (c *Conn) MagicsockServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
c.magicConn.ServeHTTPDebug(w, r)
}
type listenKey struct {
network string
host string
port string
}
type listener struct {
s *Conn
key listenKey
addr string
conn chan net.Conn
closed chan struct{}
}
func (ln *listener) Accept() (net.Conn, error) {
var c net.Conn
select {
case c = <-ln.conn:
case <-ln.closed:
return nil, xerrors.Errorf("tailnet: %w", net.ErrClosed)
}
return c, nil
}
func (ln *listener) Addr() net.Addr { return addr{ln} }
func (ln *listener) Close() error {
ln.s.mutex.Lock()
defer ln.s.mutex.Unlock()
return ln.closeNoLock()
}
func (ln *listener) closeNoLock() error {
if v, ok := ln.s.listeners[ln.key]; ok && v == ln {
delete(ln.s.listeners, ln.key)
close(ln.closed)
}
return nil
}
type addr struct{ ln *listener }
func (a addr) Network() string { return a.ln.key.network }
func (a addr) String() string { return a.ln.addr }
// Logger adapts a slog.Logger to Tailscale's Logf type, so Tailscale
// internals log through slog at the debug level.
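//
// For example (sketch):
//
//	logf := Logger(logger.Named("net.example"))
//	logf("peer %s connected", peer) // emitted at debug level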
func Logger(logger slog.Logger) tslogger.Logf {
return tslogger.Logf(func(format string, args ...any) {
slog.Helper()
logger.Debug(context.Background(), fmt.Sprintf(format, args...))
})
}
func endpointsEqual(x, y []tailcfg.Endpoint) bool {
if len(x) != len(y) {
return false
}
for i := range x {
if x[i] != y[i] {
return false
}
}
return true
}