feat: add tailnet v2 support to wsproxy coordinate endpoint (#11637)

wsproxy also needs to be updated to use tailnet v2 because the `tailnet.Conn` stores peers by ID, and the peerID was not being carried by the JSON protocol.  This adds a query param to the endpoint to conditionally switch to the new protocol.
This commit is contained in:
Spike Curtis 2024-01-18 10:10:36 +04:00 committed by GitHub
parent 07427e06f7
commit 8910ac715c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 99 additions and 19 deletions

View File

@ -128,6 +128,15 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
}
return api.fetchRegions(ctx)
}
api.tailnetService, err = tailnet.NewClientService(
api.Logger.Named("tailnetclient"),
&api.AGPL.TailnetCoordinator,
api.Options.DERPMapUpdateFrequency,
api.AGPL.DERPMap,
)
if err != nil {
api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err))
}
oauthConfigs := &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
@ -483,6 +492,7 @@ type API struct {
provisionerDaemonAuth *provisionerDaemonAuth
licenseMetricsCollector license.MetricsCollector
tailnetService *tailnet.ClientService
}
func (api *API) Close() error {

View File

@ -9,8 +9,8 @@ import (
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/tailnet"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/v2/tailnet"
)
// @Summary Agent is legacy
@ -52,6 +52,21 @@ func (api *API) agentIsLegacy(rw http.ResponseWriter, r *http.Request) {
func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
version := "1.0"
qv := r.URL.Query().Get("version")
if qv != "" {
version = qv
}
if err := agpl.CurrentVersion.Validate(version); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Unknown or unsupported API version",
Validations: []codersdk.ValidationError{
{Field: "version", Detail: err.Error()},
},
})
return
}
api.AGPL.WebsocketWaitMutex.Lock()
api.AGPL.WebsocketWaitGroup.Add(1)
api.AGPL.WebsocketWaitMutex.Unlock()
@ -66,14 +81,14 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
return
}
id := uuid.New()
sub := (*api.AGPL.TailnetCoordinator.Load()).ServeMultiAgent(id)
ctx, nc := websocketNetConn(ctx, conn, websocket.MessageText)
defer nc.Close()
err = tailnet.ServeWorkspaceProxy(ctx, nc, sub)
id := uuid.New()
err = api.tailnetService.ServeMultiAgentClient(ctx, version, nc, id)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, err.Error())
} else {
_ = conn.Close(websocket.StatusGoingAway, "")
}
}

View File

@ -6,14 +6,65 @@ import (
"encoding/json"
"errors"
"net"
"sync/atomic"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/v2/tailnet"
)
type ClientService struct {
*agpl.ClientService
}
// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is
// loaded on each processed connection.
func NewClientService(
logger slog.Logger,
coordPtr *atomic.Pointer[agpl.Coordinator],
derpMapUpdateFrequency time.Duration,
derpMapFn func() *tailcfg.DERPMap,
) (
*ClientService, error,
) {
s, err := agpl.NewClientService(logger, coordPtr, derpMapUpdateFrequency, derpMapFn)
if err != nil {
return nil, err
}
return &ClientService{ClientService: s}, nil
}
func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID) error {
major, _, err := apiversion.Parse(version)
if err != nil {
s.Logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err))
return err
}
switch major {
case 1:
coord := *(s.CoordPtr.Load())
sub := coord.ServeMultiAgent(id)
return ServeWorkspaceProxy(ctx, conn, sub)
case 2:
auth := agpl.SingleTailnetTunnelAuth{}
streamID := agpl.StreamID{
Name: id.String(),
ID: id,
Auth: auth,
}
return s.ServeConnV2(ctx, conn, streamID)
default:
s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
return xerrors.New("unsupported version")
}
}
func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
go func() {
err := forwardNodesToWorkspaceProxy(ctx, conn, ma)

View File

@ -46,8 +46,8 @@ func WithStreamID(ctx context.Context, streamID StreamID) context.Context {
// ClientService is a tailnet coordination service that accepts a connection and version from a
// tailnet client, and support versions 1.0 and 2.x of the Tailnet API protocol.
type ClientService struct {
logger slog.Logger
coordPtr *atomic.Pointer[Coordinator]
Logger slog.Logger
CoordPtr *atomic.Pointer[Coordinator]
drpc *drpcserver.Server
}
@ -61,7 +61,7 @@ func NewClientService(
) (
*ClientService, error,
) {
s := &ClientService{logger: logger, coordPtr: coordPtr}
s := &ClientService{Logger: logger, CoordPtr: coordPtr}
mux := drpcmux.New()
drpcService := &DRPCService{
CoordPtr: coordPtr,
@ -88,34 +88,38 @@ func NewClientService(
func (s *ClientService) ServeClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
major, _, err := apiversion.Parse(version)
if err != nil {
s.logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err))
s.Logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err))
return err
}
switch major {
case 1:
coord := *(s.coordPtr.Load())
coord := *(s.CoordPtr.Load())
return coord.ServeClient(conn, id, agent)
case 2:
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(conn, config)
if err != nil {
return xerrors.Errorf("yamux init failed: %w", err)
}
auth := ClientTunnelAuth{AgentID: agent}
streamID := StreamID{
Name: "client",
ID: id,
Auth: auth,
}
ctx = WithStreamID(ctx, streamID)
return s.drpc.Serve(ctx, session)
return s.ServeConnV2(ctx, conn, streamID)
default:
s.logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
return xerrors.New("unsupported version")
}
}
func (s ClientService) ServeConnV2(ctx context.Context, conn net.Conn, streamID StreamID) error {
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(conn, config)
if err != nil {
return xerrors.Errorf("yamux init failed: %w", err)
}
ctx = WithStreamID(ctx, streamID)
return s.drpc.Serve(ctx, session)
}
// DRPCService is the dRPC-based, version 2.x of the tailnet API and implements proto.DRPCClientServer
type DRPCService struct {
CoordPtr *atomic.Pointer[Coordinator]