// coder/coderd/tailnet.go (source listing header: 556 lines, 16 KiB, Go)

package coderd
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/netip"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"tailscale.com/derp"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/site"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/retry"
)
// tailnetTransport is the base HTTP transport for traffic sent over the
// server tailnet. It is a clone of http.DefaultTransport with its Proxy hook
// cleared: all of this traffic travels over wireguard, so proxy settings
// taken from the environment must not be applied to it.
var tailnetTransport *http.Transport

func init() {
	base, ok := http.DefaultTransport.(*http.Transport)
	if !ok {
		panic("dev error: default transport is the wrong type")
	}
	clone := base.Clone()
	// Never route through an environment-configured proxy; see the comment on
	// tailnetTransport.
	clone.Proxy = nil
	tailnetTransport = clone
}
// Compile-time assertion that *ServerTailnet implements
// workspaceapps.AgentProvider.
var _ workspaceapps.AgentProvider = (*ServerTailnet)(nil)
// NewServerTailnet creates a new tailnet intended for use by coderd.
//
// It performs the following setup:
//   - creates the underlying tailnet.Conn;
//   - when derpServer is non-nil, routes connections to an embedded DERP
//     relay through an in-memory pipe;
//   - starts a goroutine that re-reads derpMapFn every 5 seconds and applies
//     any change;
//   - configures the HTTP transport used to proxy requests to agents;
//   - obtains the initial MultiAgentConn from getMultiAgent.
//
// The returned ServerTailnet must be released with Close.
//
// If the initial MultiAgentConn cannot be obtained, the DERP map updater
// goroutine is stopped (and waited for) and the tailnet conn is closed before
// the error is returned, so a failed call does not leak them.
func NewServerTailnet(
	ctx context.Context,
	logger slog.Logger,
	derpServer *derp.Server,
	derpMapFn func() *tailcfg.DERPMap,
	derpForceWebSockets bool,
	getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
	blockEndpoints bool,
	traceProvider trace.TracerProvider,
) (*ServerTailnet, error) {
	logger = logger.Named("servertailnet")
	conn, err := tailnet.NewConn(&tailnet.Options{
		Addresses:           []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		DERPForceWebSockets: derpForceWebSockets,
		Logger:              logger,
		BlockEndpoints:      blockEndpoints,
	})
	if err != nil {
		return nil, xerrors.Errorf("create tailnet conn: %w", err)
	}

	serverCtx, cancel := context.WithCancel(ctx)

	// This is set to allow local DERP traffic to be proxied through memory
	// instead of needing to hit the external access URL. Don't use the ctx
	// given in this callback, it's only valid while connecting.
	if derpServer != nil {
		conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
			if !region.EmbeddedRelay {
				return nil
			}
			logger.Debug(ctx, "connecting to embedded DERP via in-memory pipe")
			left, right := net.Pipe()
			go func() {
				defer left.Close()
				defer right.Close()
				brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
				derpServer.Accept(ctx, right, brw, "internal")
			}()
			return left
		})
	}

	derpMapUpdaterClosed := make(chan struct{})
	originalDerpMap := derpMapFn()
	// It's important to set the DERPRegionDialer above _before_ we set the
	// DERP map so that if there is an embedded relay, we use the local
	// in-memory dialer.
	conn.SetDERPMap(originalDerpMap)
	go func() {
		defer close(derpMapUpdaterClosed)
		ticker := time.NewTicker(5 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-serverCtx.Done():
				return
			case <-ticker.C:
			}
			newDerpMap := derpMapFn()
			if !tailnet.CompareDERPMaps(originalDerpMap, newDerpMap) {
				conn.SetDERPMap(newDerpMap)
				originalDerpMap = newDerpMap
			}
		}
	}()

	tn := &ServerTailnet{
		ctx:                  serverCtx,
		cancel:               cancel,
		derpMapUpdaterClosed: derpMapUpdaterClosed,
		logger:               logger,
		tracer:               traceProvider.Tracer(tracing.TracerName),
		conn:                 conn,
		coordinatee:          conn,
		getMultiAgent:        getMultiAgent,
		agentConnectionTimes: map[uuid.UUID]time.Time{},
		agentTickets:         map[uuid.UUID]map[uuid.UUID]struct{}{},
		transport:            tailnetTransport.Clone(),
		connsPerAgent: prometheus.NewGaugeVec(prometheus.GaugeOpts{
			Namespace: "coder",
			Subsystem: "servertailnet",
			Name:      "open_connections",
			Help:      "Total number of TCP connections currently open to workspace agents.",
		}, []string{"network"}),
		totalConns: prometheus.NewCounterVec(prometheus.CounterOpts{
			Namespace: "coder",
			Subsystem: "servertailnet",
			Name:      "connections_total",
			Help:      "Total number of TCP connections made to workspace agents.",
		}, []string{"network"}),
	}
	tn.transport.DialContext = tn.dialContext
	// These options are mostly just picked at random, and they can likely be
	// fine tuned further. Generally, users are running applications in dev mode
	// which can generate hundreds of requests per page load, so we increased
	// MaxIdleConnsPerHost from 2 to 6 and removed the limit of total idle
	// conns.
	tn.transport.MaxIdleConnsPerHost = 6
	tn.transport.MaxIdleConns = 0
	tn.transport.IdleConnTimeout = 10 * time.Minute
	// We intentionally don't verify the certificate chain here.
	// The connection to the workspace is already established and most
	// apps are already going to be accessed over plain HTTP, this config
	// simply allows apps being run over HTTPS to be accessed without error --
	// many of which may be using self-signed certs.
	tn.transport.TLSClientConfig = &tls.Config{
		MinVersion: tls.VersionTLS12,
		//nolint:gosec
		InsecureSkipVerify: true,
	}

	agentConn, err := getMultiAgent(ctx)
	if err != nil {
		// Undo everything started above so a failed constructor leaks
		// nothing. Stop the DERP map updater and wait for it, so that it can
		// no longer touch conn. Then close the conn.
		cancel()
		<-derpMapUpdaterClosed
		_ = conn.Close()
		return nil, xerrors.Errorf("get initial multi agent: %w", err)
	}
	tn.agentConn.Store(&agentConn)
	// registering the callback also triggers send of the initial node
	tn.coordinatee.SetNodeCallback(tn.nodeCallback)

	go tn.watchAgentUpdates()
	go tn.expireOldAgents()
	return tn, nil
}
// Conn returns the tailnet.Conn that backs this ServerTailnet. Callers must
// treat it as read-only: the ServerTailnet configures it and closes it in
// Close.
func (s *ServerTailnet) Conn() *tailnet.Conn {
	underlying := s.conn
	return underlying
}
// nodeCallback is registered with the coordinatee via SetNodeCallback and is
// handed this server's own tailnet node. Registration itself triggers an
// initial call; later calls presumably follow node changes.
//
// The node is converted to its protobuf form and sent through the current
// MultiAgentConn with UpdateSelf. A conversion failure is logged as critical
// and aborts the update. A send failure is only logged as a warning.
func (s *ServerTailnet) nodeCallback(node *tailnet.Node) {
	protoNode, convErr := tailnet.NodeToProto(node)
	if convErr != nil {
		s.logger.Critical(context.Background(), "failed to convert node", slog.Error(convErr))
		return
	}
	if updErr := s.getAgentConn().UpdateSelf(protoNode); updErr != nil {
		s.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(updErr))
	}
}
// Describe forwards the descriptors of the open-connection gauge and the
// total-connection counter (the Describe half of prometheus.Collector).
func (s *ServerTailnet) Describe(descs chan<- *prometheus.Desc) {
	for _, collector := range []prometheus.Collector{s.connsPerAgent, s.totalConns} {
		collector.Describe(descs)
	}
}
// Collect forwards the current values of the open-connection gauge and the
// total-connection counter (the Collect half of prometheus.Collector).
func (s *ServerTailnet) Collect(metrics chan<- prometheus.Metric) {
	for _, collector := range []prometheus.Collector{s.connsPerAgent, s.totalConns} {
		collector.Collect(metrics)
	}
}
// expireOldAgents runs until s.ctx is done. Every 5 minutes it calls
// doExpireOldAgents with a 30 minute cutoff, so agents with no connection in
// the last 30 minutes and no open tickets get unsubscribed.
func (s *ServerTailnet) expireOldAgents() {
	const (
		pruneInterval = 5 * time.Minute
		idleCutoff    = 30 * time.Minute
	)

	ticker := time.NewTicker(pruneInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			s.doExpireOldAgents(idleCutoff)
		case <-s.ctx.Done():
			return
		}
	}
}
// doExpireOldAgents unsubscribes from every tracked agent that meets both
// conditions:
//   - its last connection time is older than cutoff;
//   - it has no outstanding tickets (open connections).
//
// Expired agents are removed from both agentConnectionTimes and agentTickets.
// If UnsubscribeAgent fails, the agent is kept so a later pass can retry it.
// nodesMu is held for the whole scan.
func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
	// TODO: add some attrs to this.
	ctx, span := s.tracer.Start(s.ctx, tracing.FuncName())
	defer span.End()

	start := time.Now()
	deletedCount := 0

	s.nodesMu.Lock()
	s.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(s.agentConnectionTimes)))
	agentConn := s.getAgentConn()
	for agentID, lastConnection := range s.agentConnectionTimes {
		// Keep agents that were connected to recently or that still have
		// active connections.
		if time.Since(lastConnection) <= cutoff || len(s.agentTickets[agentID]) != 0 {
			continue
		}
		if err := agentConn.UnsubscribeAgent(agentID); err != nil {
			s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
			continue
		}
		deletedCount++
		delete(s.agentConnectionTimes, agentID)
		// Also drop the (empty) ticket set. ensureAgent recreates it the next
		// time the agent is requested. Without this, every agent that was
		// ever expired would leave an empty map behind for the lifetime of
		// the server.
		delete(s.agentTickets, agentID)
	}
	s.nodesMu.Unlock()

	s.logger.Debug(ctx, "successfully pruned inactive agents",
		slog.F("deleted", deletedCount),
		slog.F("took", time.Since(start)),
	)
}
// watchAgentUpdates feeds peer updates from the current MultiAgentConn into
// the coordinatee until s.ctx is done.
//
// If the MultiAgentConn reports closed while the server is still running,
// all peers are marked lost and reinitCoordinator obtains a replacement. Any
// error from UpdatePeers ends the loop.
func (s *ServerTailnet) watchAgentUpdates() {
	for {
		agentConn := s.getAgentConn()
		update, ok := agentConn.NextUpdate(s.ctx)
		if ok {
			err := s.coordinatee.UpdatePeers(update.GetPeerUpdates())
			if err == nil {
				continue
			}
			if xerrors.Is(err, tailnet.ErrConnClosed) {
				s.logger.Warn(context.Background(), "tailnet conn closed, exiting watchAgentUpdates", slog.Error(err))
			} else {
				s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err))
			}
			return
		}

		// No update was delivered. There are two cases:
		//   - we are shutting down: stop;
		//   - the multiagent connection dropped: replace it and keep going.
		needsReinit := agentConn.IsClosed() && s.ctx.Err() == nil
		if !needsReinit {
			return
		}
		s.logger.Warn(s.ctx, "multiagent closed, reinitializing")
		s.coordinatee.SetAllPeersLost()
		s.reinitCoordinator()
	}
}
// getAgentConn returns the MultiAgentConn currently in use. reinitCoordinator
// swaps it atomically, so callers should re-fetch it rather than hold on to it
// across reconnects.
func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
	current := s.agentConn.Load()
	return *current
}
// reinitCoordinator replaces the MultiAgentConn after the previous one closed.
// It retries getMultiAgent, spacing attempts with retry.New(25ms, 5s), until
// an attempt succeeds or s.ctx is done. On success it does three things:
//   - stores the new conn;
//   - re-registers the node callback, which makes the conn send this
//     server's node right away;
//   - resubscribes to every agent in agentConnectionTimes.
//
// nodesMu is held for each attempt, so the tracked agent set cannot change
// mid-resubscription.
func (s *ServerTailnet) reinitCoordinator() {
	began := time.Now()
	backoff := retry.New(25*time.Millisecond, 5*time.Second)
	for backoff.Wait(s.ctx) {
		s.nodesMu.Lock()
		newConn, err := s.getMultiAgent(s.ctx)
		if err != nil {
			s.nodesMu.Unlock()
			s.logger.Error(s.ctx, "reinit multi agent", slog.Error(err))
			continue
		}

		s.agentConn.Store(&newConn)
		// Re-registering the callback sends our node immediately and keeps us
		// registered for future node updates.
		s.coordinatee.SetNodeCallback(s.nodeCallback)

		// The new conn knows nothing about our subscriptions; replay them.
		for id := range s.agentConnectionTimes {
			if subErr := newConn.SubscribeAgent(id); subErr != nil {
				s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(subErr), slog.F("agent_id", id))
			}
		}

		s.logger.Info(s.ctx, "successfully reinitialized multiagent",
			slog.F("agents", len(s.agentConnectionTimes)),
			slog.F("took", time.Since(began)),
		)
		s.nodesMu.Unlock()
		return
	}
}
// ServerTailnet is the tailnet coderd uses to reach workspace agents. It owns
// a single tailnet.Conn through which it proxies HTTP requests and dials TCP
// connections to agents. It subscribes to agents on demand through a
// MultiAgentConn and unsubscribes from agents that have been idle too long.
type ServerTailnet struct {
	ctx                  context.Context
	cancel               func()
	derpMapUpdaterClosed chan struct{}

	logger slog.Logger
	tracer trace.Tracer

	// in prod, these are the same, but coordinatee is a subset of Conn's
	// methods which makes some tests easier.
	conn        *tailnet.Conn
	coordinatee tailnet.Coordinatee

	getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
	// agentConn is the MultiAgentConn currently in use; it is swapped
	// atomically when the coordinator is reinitialized.
	agentConn atomic.Pointer[tailnet.MultiAgentConn]

	// nodesMu guards agentConnectionTimes and agentTickets. It is also held
	// while reinitCoordinator installs a new MultiAgentConn and resubscribes.
	nodesMu sync.Mutex
	// agentConnectionTimes is a map of agent tailnetNodes the server wants to
	// keep a connection to. It contains the last time the agent was connected
	// to.
	agentConnectionTimes map[uuid.UUID]time.Time
	// agentTickets holds, per agent, the set of tickets for all open
	// connections to that agent. An agent with tickets is never expired.
	agentTickets map[uuid.UUID]map[uuid.UUID]struct{}

	transport     *http.Transport
	connsPerAgent *prometheus.GaugeVec
	totalConns    *prometheus.CounterVec
}
// ReverseProxy returns a reverse proxy that forwards requests for targetURL to
// agentID over the server tailnet.
//
// The target host is rewritten to the agent's tailnet IP, keeping the port.
// Proxy errors are rendered as a 502 "Bad Gateway" page. When the failure
// looks like an HTTP/HTTPS mismatch and app is a port-based app URL, the page
// also offers a link to the same app with the opposite protocol. That link
// uses dashboardURL's scheme and wildcardHostname.
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHostname string) *httputil.ReverseProxy {
	// Rewrite the targetURL's Host to point to the agent's IP. This is
	// necessary because due to TCP connection caching, each agent needs to be
	// addressed individually. Otherwise, all connections get dialed as
	// "localhost:port", causing connections to be shared across agents.
	// NOTE(review): if targetURL.Host carries no port, SplitHostPort fails and
	// port is left empty.
	tgt := *targetURL
	_, port, _ := net.SplitHostPort(tgt.Host)
	tgt.Host = net.JoinHostPort(tailnet.IPFromUUID(agentID).String(), port)

	proxy := httputil.NewSingleHostReverseProxy(&tgt)
	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, theErr error) {
		var (
			desc                 = "Failed to proxy request to application: " + theErr.Error()
			additionalInfo       = ""
			additionalButtonLink = ""
			additionalButtonText = ""
		)

		var tlsError tls.RecordHeaderError
		if (errors.As(theErr, &tlsError) && tlsError.Msg == "first record does not look like a TLS handshake") ||
			errors.Is(theErr, http.ErrSchemeMismatch) {
			// If the error is due to an HTTP/HTTPS mismatch, we can provide a
			// more helpful error message with redirect buttons.
			switchURL := url.URL{
				Scheme: dashboardURL.Scheme,
			}
			_, protocol, isPort := app.PortInfo()
			if isPort {
				targetProtocol := "https"
				if protocol == "https" {
					targetProtocol = "http"
				}
				// Use a local value: reassigning the captured app would leak
				// the protocol switch into every later invocation of this
				// handler (flipping it back and forth). It would also be a
				// data race if the handler ran concurrently.
				switchedApp := app.ChangePortProtocol(targetProtocol)
				switchURL.Host = fmt.Sprintf("%s%s", switchedApp.String(), strings.TrimPrefix(wildcardHostname, "*"))
				additionalButtonLink = switchURL.String()
				additionalButtonText = fmt.Sprintf("Switch to %s", strings.ToUpper(targetProtocol))
				additionalInfo += fmt.Sprintf("This error seems to be due to an app protocol mismatch, try switching to %s.", strings.ToUpper(targetProtocol))
			}
		}

		site.RenderStaticErrorPage(w, r, site.ErrorPageData{
			Status:               http.StatusBadGateway,
			Title:                "Bad Gateway",
			Description:          desc,
			RetryEnabled:         true,
			DashboardURL:         dashboardURL.String(),
			AdditionalInfo:       additionalInfo,
			AdditionalButtonLink: additionalButtonLink,
			AdditionalButtonText: additionalButtonText,
		})
	}
	// The director attaches agentID to the request context so the shared
	// transport's dialer knows which agent to dial.
	proxy.Director = s.director(agentID, proxy.Director)
	proxy.Transport = s.transport

	return proxy
}
// agentIDKey is the context key under which director stores the target
// agent's ID for dialContext to read back.
type agentIDKey struct{}

// director wraps prev so that the request's context carries agentID under
// agentIDKey before prev rewrites the request. The transport's dialer uses
// that value to pick which agent to dial.
func (*ServerTailnet) director(agentID uuid.UUID, prev func(req *http.Request)) func(req *http.Request) {
	return func(req *http.Request) {
		withAgent := req.WithContext(context.WithValue(req.Context(), agentIDKey{}, agentID))
		*req = *withAgent
		prev(req)
	}
}
// dialContext is installed as the HTTP transport's DialContext. It proceeds
// as follows:
//   - reads the agent ID that director attached to ctx;
//   - dials addr on that agent via DialAgentNetConn;
//   - increments the open and total connection metrics (label "tcp");
//   - wraps the conn so that the open-connection gauge is decremented on
//     Close.
func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	agentID, found := ctx.Value(agentIDKey{}).(uuid.UUID)
	if !found {
		return nil, xerrors.Errorf("no agent id attached")
	}

	agentNetConn, err := s.DialAgentNetConn(ctx, agentID, network, addr)
	if err != nil {
		return nil, err
	}

	const metricLabel = "tcp"
	s.connsPerAgent.WithLabelValues(metricLabel).Inc()
	s.totalConns.WithLabelValues(metricLabel).Inc()

	wrapped := &instrumentedConn{
		Conn:          agentNetConn,
		agentID:       agentID,
		connsPerAgent: s.connsPerAgent,
	}
	return wrapped, nil
}
// ensureAgent records time.Now() as agentID's last connection time, which
// postpones its expiry.
//
// If the agent is not yet tracked, it is first subscribed through the current
// MultiAgentConn and given an empty ticket set. If that subscription fails,
// nothing is recorded and the error is returned.
func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
	s.nodesMu.Lock()
	defer s.nodesMu.Unlock()

	if _, tracked := s.agentConnectionTimes[agentID]; tracked {
		s.agentConnectionTimes[agentID] = time.Now()
		return nil
	}

	s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
	if err := s.getAgentConn().SubscribeAgent(agentID); err != nil {
		return xerrors.Errorf("subscribe agent: %w", err)
	}
	s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
	s.agentConnectionTimes[agentID] = time.Now()
	return nil
}
// acquireTicket registers a new ticket (an open connection) for agentID and
// returns a function that removes it again. While an agent holds at least one
// ticket, doExpireOldAgents does not expire it. Calling release more than once
// is harmless, because deleting a missing key is a no-op.
//
// The agent's ticket set is created on demand. Writing into a missing set
// (for example one removed by expiry between ensureAgent and this call) would
// otherwise panic on a nil-map assignment.
func (s *ServerTailnet) acquireTicket(agentID uuid.UUID) (release func()) {
	id := uuid.New()

	s.nodesMu.Lock()
	tickets := s.agentTickets[agentID]
	if tickets == nil {
		tickets = map[uuid.UUID]struct{}{}
		s.agentTickets[agentID] = tickets
	}
	tickets[id] = struct{}{}
	s.nodesMu.Unlock()

	return func() {
		s.nodesMu.Lock()
		delete(s.agentTickets[agentID], id)
		s.nodesMu.Unlock()
	}
}
// AgentConn returns a connection to agentID over the server tailnet, plus a
// release function the caller must invoke when finished with it.
//
// The call goes through these steps:
//   - the agent is subscribed, or its last-connection time refreshed, via
//     ensureAgent;
//   - a ticket is taken so the agent is not expired while in use;
//   - the returned AgentConn is built on the server's shared tailnet.Conn,
//     with a CloseFunc returning workspacesdk.ErrSkipClose (presumably so
//     closing it leaves the shared conn open; confirm in workspacesdk).
//
// If AwaitReachable reports the agent unreachable, the ticket is released and
// an error is returned.
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) {
	s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
	if err := s.ensureAgent(agentID); err != nil {
		return nil, nil, xerrors.Errorf("ensure agent: %w", err)
	}

	release := s.acquireTicket(agentID)
	agentConn := workspacesdk.NewAgentConn(s.conn, workspacesdk.AgentConnOptions{
		AgentID:   agentID,
		CloseFunc: func() error { return workspacesdk.ErrSkipClose },
	})

	// From here on we hold a ticket; hand it back if we fail before the
	// caller receives the connection.
	if !agentConn.AwaitReachable(ctx) {
		release()
		return nil, nil, xerrors.New("agent is unreachable")
	}
	return agentConn, release, nil
}
// DialAgentNetConn dials addr over network on agentID through the server
// tailnet. The agent's ticket stays held until the returned net.Conn is
// closed. If the dial fails, the ticket is released immediately.
func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, network, addr string) (net.Conn, error) {
	agentConn, release, err := s.AgentConn(ctx, agentID)
	if err != nil {
		return nil, xerrors.Errorf("acquire agent conn: %w", err)
	}

	nc, err := agentConn.DialContext(ctx, network, addr)
	if err != nil {
		// No conn will ever be closed for this attempt, so give the ticket
		// back now.
		release()
		return nil, xerrors.Errorf("dial context: %w", err)
	}
	return &netConnCloser{Conn: nc, close: release}, nil
}
// ServeHTTPDebug serves the magicsock debug handler of the underlying
// tailnet conn.
func (s *ServerTailnet) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
	underlying := s.conn
	underlying.MagicsockServeHTTPDebug(w, r)
}
// netConnCloser wraps a net.Conn and runs an extra hook whenever it is
// closed.
type netConnCloser struct {
	net.Conn
	// close runs on every Close call, before the wrapped conn is closed.
	close func()
}

// Close runs the close hook, then closes the wrapped connection and returns
// that connection's error.
func (c *netConnCloser) Close() error {
	hook := c.close
	hook()
	return c.Conn.Close()
}
// Close shuts the ServerTailnet down, in this order:
//   - cancels the server context, which stops the background goroutines;
//   - closes the tailnet conn (best effort; its error is ignored);
//   - drops idle proxied HTTP connections;
//   - waits for the DERP map updater goroutine to exit.
//
// It always returns nil.
func (s *ServerTailnet) Close() error {
	s.cancel()
	_ = s.conn.Close()
	s.transport.CloseIdleConnections()

	// Block until the DERP map updater has observed the cancellation.
	<-s.derpMapUpdaterClosed
	return nil
}
// instrumentedConn wraps a net.Conn that was counted in connsPerAgent
// (label "tcp"). It decrements that gauge exactly once, on the first Close.
type instrumentedConn struct {
	net.Conn

	agentID       uuid.UUID
	closeOnce     sync.Once
	connsPerAgent *prometheus.GaugeVec
}

// Close decrements the open-connection gauge on the first call only, then
// closes the wrapped connection (on every call) and returns its error.
func (c *instrumentedConn) Close() error {
	decrement := func() {
		c.connsPerAgent.WithLabelValues("tcp").Dec()
	}
	c.closeOnce.Do(decrement)
	return c.Conn.Close()
}