diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 4134e591dd..af56626a8d 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -128,6 +128,15 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { } return api.fetchRegions(ctx) } + api.tailnetService, err = tailnet.NewClientService( + api.Logger.Named("tailnetclient"), + &api.AGPL.TailnetCoordinator, + api.Options.DERPMapUpdateFrequency, + api.AGPL.DERPMap, + ) + if err != nil { + api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err)) + } oauthConfigs := &httpmw.OAuth2Configs{ Github: options.GithubOAuth2Config, @@ -483,6 +492,7 @@ type API struct { provisionerDaemonAuth *provisionerDaemonAuth licenseMetricsCollector license.MetricsCollector + tailnetService *tailnet.ClientService } func (api *API) Close() error { diff --git a/enterprise/coderd/workspaceproxycoordinate.go b/enterprise/coderd/workspaceproxycoordinate.go index 501095d444..bf291e45ce 100644 --- a/enterprise/coderd/workspaceproxycoordinate.go +++ b/enterprise/coderd/workspaceproxycoordinate.go @@ -9,8 +9,8 @@ import ( "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/enterprise/tailnet" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" + agpl "github.com/coder/coder/v2/tailnet" ) // @Summary Agent is legacy @@ -52,6 +52,21 @@ func (api *API) agentIsLegacy(rw http.ResponseWriter, r *http.Request) { func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + version := "1.0" + qv := r.URL.Query().Get("version") + if qv != "" { + version = qv + } + if err := agpl.CurrentVersion.Validate(version); err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Unknown or unsupported API version", + Validations: []codersdk.ValidationError{ + {Field: "version", Detail: err.Error()}, + }, + }) + return + } + api.AGPL.WebsocketWaitMutex.Lock() api.AGPL.WebsocketWaitGroup.Add(1) api.AGPL.WebsocketWaitMutex.Unlock() @@ -66,14 +81,14 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request return } - id := uuid.New() - sub := (*api.AGPL.TailnetCoordinator.Load()).ServeMultiAgent(id) - ctx, nc := websocketNetConn(ctx, conn, websocket.MessageText) defer nc.Close() - err = tailnet.ServeWorkspaceProxy(ctx, nc, sub) + id := uuid.New() + err = api.tailnetService.ServeMultiAgentClient(ctx, version, nc, id) if err != nil { _ = conn.Close(websocket.StatusInternalError, err.Error()) + } else { + _ = conn.Close(websocket.StatusGoingAway, "") } } diff --git a/enterprise/tailnet/workspaceproxy.go b/enterprise/tailnet/workspaceproxy.go index 3150890c13..0471c076b0 100644 --- a/enterprise/tailnet/workspaceproxy.go +++ b/enterprise/tailnet/workspaceproxy.go @@ -6,14 +6,65 @@ import ( "encoding/json" "errors" "net" + "sync/atomic" "time" + "github.com/google/uuid" "golang.org/x/xerrors" + "tailscale.com/tailcfg" + "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/util/apiversion" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" agpl "github.com/coder/coder/v2/tailnet" ) +type ClientService struct { + *agpl.ClientService +} + +// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is +// loaded on each processed connection. +func NewClientService( + logger slog.Logger, + coordPtr *atomic.Pointer[agpl.Coordinator], + derpMapUpdateFrequency time.Duration, + derpMapFn func() *tailcfg.DERPMap, +) ( + *ClientService, error, +) { + s, err := agpl.NewClientService(logger, coordPtr, derpMapUpdateFrequency, derpMapFn) + if err != nil { + return nil, err + } + return &ClientService{ClientService: s}, nil +} + +func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID) error { + major, _, err := apiversion.Parse(version) + if err != nil { + s.Logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err)) + return err + } + switch major { + case 1: + coord := *(s.CoordPtr.Load()) + sub := coord.ServeMultiAgent(id) + return ServeWorkspaceProxy(ctx, conn, sub) + case 2: + auth := agpl.SingleTailnetTunnelAuth{} + streamID := agpl.StreamID{ + Name: id.String(), + ID: id, + Auth: auth, + } + return s.ServeConnV2(ctx, conn, streamID) + default: + s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version)) + return xerrors.New("unsupported version") + } +} + func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error { go func() { err := forwardNodesToWorkspaceProxy(ctx, conn, ma) diff --git a/tailnet/service.go b/tailnet/service.go index 1529bf65c0..191319d16c 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -46,8 +46,8 @@ func WithStreamID(ctx context.Context, streamID StreamID) context.Context { // ClientService is a tailnet coordination service that accepts a connection and version from a // tailnet client, and support versions 1.0 and 2.x of the Tailnet API protocol. type ClientService struct { - logger slog.Logger - coordPtr *atomic.Pointer[Coordinator] + Logger slog.Logger + CoordPtr *atomic.Pointer[Coordinator] drpc *drpcserver.Server } @@ -61,7 +61,7 @@ func NewClientService( ) ( *ClientService, error, ) { - s := &ClientService{logger: logger, coordPtr: coordPtr} + s := &ClientService{Logger: logger, CoordPtr: coordPtr} mux := drpcmux.New() drpcService := &DRPCService{ CoordPtr: coordPtr, @@ -88,34 +88,38 @@ func NewClientService( func (s *ClientService) ServeClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID, agent uuid.UUID) error { major, _, err := apiversion.Parse(version) if err != nil { - s.logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err)) + s.Logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err)) return err } switch major { case 1: - coord := *(s.coordPtr.Load()) + coord := *(s.CoordPtr.Load()) return coord.ServeClient(conn, id, agent) case 2: - config := yamux.DefaultConfig() - config.LogOutput = io.Discard - session, err := yamux.Server(conn, config) - if err != nil { - return xerrors.Errorf("yamux init failed: %w", err) - } auth := ClientTunnelAuth{AgentID: agent} streamID := StreamID{ Name: "client", ID: id, Auth: auth, } - ctx = WithStreamID(ctx, streamID) - return s.drpc.Serve(ctx, session) + return s.ServeConnV2(ctx, conn, streamID) default: - s.logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version)) + s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version)) return xerrors.New("unsupported version") } } +func (s ClientService) ServeConnV2(ctx context.Context, conn net.Conn, streamID StreamID) error { + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + session, err := yamux.Server(conn, config) + if err != nil { + return xerrors.Errorf("yamux init failed: %w", err) + } + ctx = WithStreamID(ctx, streamID) + return s.drpc.Serve(ctx, session) +} + // DRPCService is the dRPC-based, version 2.x of the tailnet API and implements proto.DRPCClientServer type DRPCService struct { CoordPtr *atomic.Pointer[Coordinator]