package coderd

import (
	"bufio"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/http/httputil"
	"net/netip"
	"net/url"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"github.com/prometheus/client_golang/prometheus"
	"go.opentelemetry.io/otel/trace"
	"golang.org/x/xerrors"
	"tailscale.com/derp"
	"tailscale.com/tailcfg"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/coderd/tracing"
	"github.com/coder/coder/v2/coderd/workspaceapps"
	"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
	"github.com/coder/coder/v2/codersdk/workspacesdk"
	"github.com/coder/coder/v2/site"
	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/retry"
)

var tailnetTransport *http.Transport

func init() {
	tp, valid := http.DefaultTransport.(*http.Transport)
	if !valid {
		panic("dev error: default transport is the wrong type")
	}
	tailnetTransport = tp.Clone()
	// We do not want to respect the proxy settings from the environment, since
	// all network traffic happens over wireguard.
	tailnetTransport.Proxy = nil
}

var _ workspaceapps.AgentProvider = (*ServerTailnet)(nil)

// NewServerTailnet creates a new tailnet intended for use by coderd.
func NewServerTailnet(
	ctx context.Context,
	logger slog.Logger,
	derpServer *derp.Server,
	derpMapFn func() *tailcfg.DERPMap,
	derpForceWebSockets bool,
	getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
	blockEndpoints bool,
	traceProvider trace.TracerProvider,
) (*ServerTailnet, error) {
	logger = logger.Named("servertailnet")
	conn, err := tailnet.NewConn(&tailnet.Options{
		Addresses:           []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		DERPForceWebSockets: derpForceWebSockets,
		Logger:              logger,
		BlockEndpoints:      blockEndpoints,
	})
	if err != nil {
		return nil, xerrors.Errorf("create tailnet conn: %w", err)
	}
	serverCtx, cancel := context.WithCancel(ctx)

	// This is set to allow local DERP traffic to be proxied through memory
	// instead of needing to hit the external access URL. Don't use the ctx
	// given in this callback; it's only valid while connecting.
	if derpServer != nil {
		conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
			if !region.EmbeddedRelay {
				return nil
			}
			logger.Debug(ctx, "connecting to embedded DERP via in-memory pipe")
			left, right := net.Pipe()
			go func() {
				defer left.Close()
				defer right.Close()
				brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
				derpServer.Accept(ctx, right, brw, "internal")
			}()
			return left
		})
	}

	derpMapUpdaterClosed := make(chan struct{})
	originalDerpMap := derpMapFn()
	// It's important to set the DERPRegionDialer above _before_ we set the
	// DERP map so that if there is an embedded relay, we use the local
	// in-memory dialer.
	conn.SetDERPMap(originalDerpMap)
	go func() {
		defer close(derpMapUpdaterClosed)

		ticker := time.NewTicker(5 * time.Second)
		defer ticker.Stop()

		for {
			select {
			case <-serverCtx.Done():
				return
			case <-ticker.C:
			}

			newDerpMap := derpMapFn()
			if !tailnet.CompareDERPMaps(originalDerpMap, newDerpMap) {
				conn.SetDERPMap(newDerpMap)
				originalDerpMap = newDerpMap
			}
		}
	}()

	tn := &ServerTailnet{
		ctx:                  serverCtx,
		cancel:               cancel,
		derpMapUpdaterClosed: derpMapUpdaterClosed,
		logger:               logger,
		tracer:               traceProvider.Tracer(tracing.TracerName),
		conn:                 conn,
		coordinatee:          conn,
		getMultiAgent:        getMultiAgent,
		agentConnectionTimes: map[uuid.UUID]time.Time{},
		agentTickets:         map[uuid.UUID]map[uuid.UUID]struct{}{},
		transport:            tailnetTransport.Clone(),
		connsPerAgent: prometheus.NewGaugeVec(prometheus.GaugeOpts{
			Namespace: "coder",
			Subsystem: "servertailnet",
			Name:      "open_connections",
			Help:      "Total number of TCP connections currently open to workspace agents.",
		}, []string{"network"}),
		totalConns: prometheus.NewCounterVec(prometheus.CounterOpts{
			Namespace: "coder",
			Subsystem: "servertailnet",
			Name:      "connections_total",
			Help:      "Total number of TCP connections made to workspace agents.",
		}, []string{"network"}),
	}
	tn.transport.DialContext = tn.dialContext

	// These options are mostly just picked at random, and they can likely be
	// fine-tuned further. Generally, users are running applications in dev
	// mode, which can generate hundreds of requests per page load, so we
	// increased MaxIdleConnsPerHost from 2 to 6 and removed the limit of
	// total idle conns.
	tn.transport.MaxIdleConnsPerHost = 6
	tn.transport.MaxIdleConns = 0
	tn.transport.IdleConnTimeout = 10 * time.Minute
	// We intentionally don't verify the certificate chain here. The connection
	// to the workspace is already established and most apps are already going
	// to be accessed over plain HTTP; this config simply allows apps served
	// over HTTPS (many of which may use self-signed certs) to be accessed
	// without error.
	tn.transport.TLSClientConfig = &tls.Config{
		MinVersion: tls.VersionTLS12,
		//nolint:gosec
		InsecureSkipVerify: true,
	}

	agentConn, err := getMultiAgent(ctx)
	if err != nil {
		return nil, xerrors.Errorf("get initial multi agent: %w", err)
	}
	tn.agentConn.Store(&agentConn)
	// Registering the callback also triggers sending the initial node.
	tn.coordinatee.SetNodeCallback(tn.nodeCallback)

	go tn.watchAgentUpdates()
	go tn.expireOldAgents()
	return tn, nil
}
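
// The snippet below is an illustrative sketch, not one of this file's real
// call sites, showing how a caller might construct a ServerTailnet; derpMapFn
// and getMultiAgent stand in for whatever the caller uses to fetch the current
// DERP map and to obtain a tailnet.MultiAgentConn from its coordinator:
//
//	tn, err := NewServerTailnet(ctx, logger, derpServer,
//		derpMapFn,     // func() *tailcfg.DERPMap
//		false,         // derpForceWebSockets
//		getMultiAgent, // func(context.Context) (tailnet.MultiAgentConn, error)
//		false,         // blockEndpoints
//		traceProvider,
//	)
//	if err != nil {
//		return xerrors.Errorf("create server tailnet: %w", err)
//	}
//	defer tn.Close()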

// Conn is used to access the underlying tailnet conn of the ServerTailnet. It
// should only be used for read-only purposes.
func (s *ServerTailnet) Conn() *tailnet.Conn {
	return s.conn
}

func (s *ServerTailnet) nodeCallback(node *tailnet.Node) {
	pn, err := tailnet.NodeToProto(node)
	if err != nil {
		s.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
		return
	}
	err = s.getAgentConn().UpdateSelf(pn)
	if err != nil {
		s.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err))
	}
}

func (s *ServerTailnet) Describe(descs chan<- *prometheus.Desc) {
	s.connsPerAgent.Describe(descs)
	s.totalConns.Describe(descs)
}

func (s *ServerTailnet) Collect(metrics chan<- prometheus.Metric) {
	s.connsPerAgent.Collect(metrics)
	s.totalConns.Collect(metrics)
}

func (s *ServerTailnet) expireOldAgents() {
	const (
		tick   = 5 * time.Minute
		cutoff = 30 * time.Minute
	)

	ticker := time.NewTicker(tick)
	defer ticker.Stop()

	for {
		select {
		case <-s.ctx.Done():
			return
		case <-ticker.C:
		}

		s.doExpireOldAgents(cutoff)
	}
}

func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
	// TODO: add some attrs to this.
	ctx, span := s.tracer.Start(s.ctx, tracing.FuncName())
	defer span.End()

	start := time.Now()
	deletedCount := 0

	s.nodesMu.Lock()
	s.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(s.agentConnectionTimes)))
	agentConn := s.getAgentConn()
	for agentID, lastConnection := range s.agentConnectionTimes {
		// If no one has connected since the cutoff and there are no active
		// connections, remove the agent.
		if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
			err := agentConn.UnsubscribeAgent(agentID)
			if err != nil {
				s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
				continue
			}
			deletedCount++
			delete(s.agentConnectionTimes, agentID)
		}
	}
	s.nodesMu.Unlock()
	s.logger.Debug(s.ctx, "successfully pruned inactive agents",
		slog.F("deleted", deletedCount),
		slog.F("took", time.Since(start)),
	)
}

func (s *ServerTailnet) watchAgentUpdates() {
	for {
		conn := s.getAgentConn()
		resp, ok := conn.NextUpdate(s.ctx)
		if !ok {
			if conn.IsClosed() && s.ctx.Err() == nil {
				s.logger.Warn(s.ctx, "multiagent closed, reinitializing")
				s.coordinatee.SetAllPeersLost()
				s.reinitCoordinator()
				continue
			}
			return
		}

		err := s.coordinatee.UpdatePeers(resp.GetPeerUpdates())
		if err != nil {
			if xerrors.Is(err, tailnet.ErrConnClosed) {
				s.logger.Warn(context.Background(), "tailnet conn closed, exiting watchAgentUpdates", slog.Error(err))
				return
			}
			s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err))
			return
		}
	}
}

func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
	return *s.agentConn.Load()
}

func (s *ServerTailnet) reinitCoordinator() {
	start := time.Now()
	for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(s.ctx); {
		s.nodesMu.Lock()
		agentConn, err := s.getMultiAgent(s.ctx)
		if err != nil {
			s.nodesMu.Unlock()
			s.logger.Error(s.ctx, "reinit multi agent", slog.Error(err))
			continue
		}
		s.agentConn.Store(&agentConn)
		// Reset the Node callback, which triggers the conn to send the node
		// immediately, and also register for updates.
		s.coordinatee.SetNodeCallback(s.nodeCallback)

		// Resubscribe to all of the agents we're tracking.
		for agentID := range s.agentConnectionTimes {
			err := agentConn.SubscribeAgent(agentID)
			if err != nil {
				s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
			}
		}
		s.logger.Info(s.ctx, "successfully reinitialized multiagent",
			slog.F("agents", len(s.agentConnectionTimes)),
			slog.F("took", time.Since(start)),
		)
		s.nodesMu.Unlock()
		return
	}
}

type ServerTailnet struct {
	ctx                  context.Context
	cancel               func()
	derpMapUpdaterClosed chan struct{}
	logger               slog.Logger
	tracer               trace.Tracer

	// in prod, these are the same, but coordinatee is a subset of Conn's
	// methods which makes some tests easier.
	conn        *tailnet.Conn
	coordinatee tailnet.Coordinatee

	getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
	agentConn     atomic.Pointer[tailnet.MultiAgentConn]

	nodesMu sync.Mutex
	// agentConnectionTimes is a map of the agent tailnet nodes the server
	// wants to keep a connection to. It contains the last time the agent was
	// connected to.
	agentConnectionTimes map[uuid.UUID]time.Time
	// agentTickets holds a map of all open connections to an agent.
	agentTickets map[uuid.UUID]map[uuid.UUID]struct{}

	transport *http.Transport

	connsPerAgent *prometheus.GaugeVec
	totalConns    *prometheus.CounterVec
}

func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHostname string) *httputil.ReverseProxy {
	// Rewrite the targetURL's Host to point to the agent's IP. This is
	// necessary because, due to TCP connection caching, each agent needs to be
	// addressed individually. Otherwise, all connections get dialed as
	// "localhost:port", causing connections to be shared across agents.
	tgt := *targetURL
	_, port, _ := net.SplitHostPort(tgt.Host)
	tgt.Host = net.JoinHostPort(tailnet.IPFromUUID(agentID).String(), port)

	proxy := httputil.NewSingleHostReverseProxy(&tgt)
	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, theErr error) {
		var (
			desc                 = "Failed to proxy request to application: " + theErr.Error()
			additionalInfo       = ""
			additionalButtonLink = ""
			additionalButtonText = ""
		)

		var tlsError tls.RecordHeaderError
		if (errors.As(theErr, &tlsError) && tlsError.Msg == "first record does not look like a TLS handshake") ||
			errors.Is(theErr, http.ErrSchemeMismatch) {
			// If the error is due to an HTTP/HTTPS mismatch, we can provide a
			// more helpful error message with redirect buttons.
			switchURL := url.URL{
				Scheme: dashboardURL.Scheme,
			}
			_, protocol, isPort := app.PortInfo()
			if isPort {
				targetProtocol := "https"
				if protocol == "https" {
					targetProtocol = "http"
				}
				app = app.ChangePortProtocol(targetProtocol)

				switchURL.Host = fmt.Sprintf("%s%s", app.String(), strings.TrimPrefix(wildcardHostname, "*"))
				additionalButtonLink = switchURL.String()
				additionalButtonText = fmt.Sprintf("Switch to %s", strings.ToUpper(targetProtocol))
				additionalInfo += fmt.Sprintf("This error seems to be due to an app protocol mismatch, try switching to %s.", strings.ToUpper(targetProtocol))
			}
		}

		site.RenderStaticErrorPage(w, r, site.ErrorPageData{
			Status:               http.StatusBadGateway,
			Title:                "Bad Gateway",
			Description:          desc,
			RetryEnabled:         true,
			DashboardURL:         dashboardURL.String(),
			AdditionalInfo:       additionalInfo,
			AdditionalButtonLink: additionalButtonLink,
			AdditionalButtonText: additionalButtonText,
		})
	}
	proxy.Director = s.director(agentID, proxy.Director)
	proxy.Transport = s.transport

	return proxy
}

type agentIDKey struct{}
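
// Illustrative sketch (the variable names are assumptions, not the real
// workspaceapps call site): a caller holding a *ServerTailnet builds a proxy
// per request and serves it. ReverseProxy's Director stores the agent ID on
// the request context, and the shared transport's DialContext reads it back
// via agentIDKey to pick which agent to dial:
//
//	proxy := tn.ReverseProxy(appURL, dashboardURL, agentID, app, wildcardHost)
//	proxy.ServeHTTP(rw, r)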

// director makes sure agentIDKey is set on the context in the reverse proxy.
// This allows the transport to correctly identify which agent to dial.
func (*ServerTailnet) director(agentID uuid.UUID, prev func(req *http.Request)) func(req *http.Request) {
	return func(req *http.Request) {
		ctx := context.WithValue(req.Context(), agentIDKey{}, agentID)
		*req = *req.WithContext(ctx)
		prev(req)
	}
}

func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	agentID, ok := ctx.Value(agentIDKey{}).(uuid.UUID)
	if !ok {
		return nil, xerrors.Errorf("no agent id attached")
	}

	nc, err := s.DialAgentNetConn(ctx, agentID, network, addr)
	if err != nil {
		return nil, err
	}

	s.connsPerAgent.WithLabelValues("tcp").Inc()
	s.totalConns.WithLabelValues("tcp").Inc()
	return &instrumentedConn{
		Conn:          nc,
		agentID:       agentID,
		connsPerAgent: s.connsPerAgent,
	}, nil
}

func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
	s.nodesMu.Lock()
	defer s.nodesMu.Unlock()

	_, ok := s.agentConnectionTimes[agentID]
	// If we don't have the node, subscribe.
	if !ok {
		s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
		err := s.getAgentConn().SubscribeAgent(agentID)
		if err != nil {
			return xerrors.Errorf("subscribe agent: %w", err)
		}
		s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
	}

	s.agentConnectionTimes[agentID] = time.Now()
	return nil
}

func (s *ServerTailnet) acquireTicket(agentID uuid.UUID) (release func()) {
	id := uuid.New()
	s.nodesMu.Lock()
	s.agentTickets[agentID][id] = struct{}{}
	s.nodesMu.Unlock()

	return func() {
		s.nodesMu.Lock()
		delete(s.agentTickets[agentID], id)
		s.nodesMu.Unlock()
	}
}

func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) {
	var (
		conn *workspacesdk.AgentConn
		ret  func()
	)

	s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
	err := s.ensureAgent(agentID)
	if err != nil {
		return nil, nil, xerrors.Errorf("ensure agent: %w", err)
	}
	ret = s.acquireTicket(agentID)

	conn = workspacesdk.NewAgentConn(s.conn, workspacesdk.AgentConnOptions{
		AgentID:   agentID,
		CloseFunc: func() error { return workspacesdk.ErrSkipClose },
	})

	// Since we now have an open conn, be careful to close it if we error
	// without returning it to the user.
	reachable := conn.AwaitReachable(ctx)
	if !reachable {
		ret()
		return nil, nil, xerrors.New("agent is unreachable")
	}

	return conn, ret, nil
}

func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, network, addr string) (net.Conn, error) {
	conn, release, err := s.AgentConn(ctx, agentID)
	if err != nil {
		return nil, xerrors.Errorf("acquire agent conn: %w", err)
	}

	// Since we now have an open conn, be careful to close it if we error
	// without returning it to the user.
	nc, err := conn.DialContext(ctx, network, addr)
	if err != nil {
		release()
		return nil, xerrors.Errorf("dial context: %w", err)
	}

	return &netConnCloser{Conn: nc, close: func() {
		release()
	}}, err
}

func (s *ServerTailnet) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
	s.conn.MagicsockServeHTTPDebug(w, r)
}

type netConnCloser struct {
	net.Conn
	close func()
}

func (c *netConnCloser) Close() error {
	c.close()
	return c.Conn.Close()
}

func (s *ServerTailnet) Close() error {
	s.cancel()
	_ = s.conn.Close()
	s.transport.CloseIdleConnections()
	<-s.derpMapUpdaterClosed
	return nil
}

type instrumentedConn struct {
	net.Conn

	agentID       uuid.UUID
	closeOnce     sync.Once
	connsPerAgent *prometheus.GaugeVec
}

func (c *instrumentedConn) Close() error {
	c.closeOnce.Do(func() {
		c.connsPerAgent.WithLabelValues("tcp").Dec()
	})
	return c.Conn.Close()
}
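
// Illustrative sketch (agentID and the target address are assumptions): dialing
// a workspace agent directly over the server tailnet. The release func must be
// called once the caller is done so expireOldAgents can eventually unsubscribe
// idle agents:
//
//	conn, release, err := tn.AgentConn(ctx, agentID)
//	if err != nil {
//		return xerrors.Errorf("agent conn: %w", err)
//	}
//	defer release()
//
//	nc, err := conn.DialContext(ctx, "tcp", "localhost:8080")
//	if err != nil {
//		return xerrors.Errorf("dial: %w", err)
//	}
//	defer nc.Close()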