From 3e0e7f87390d160124cfb435fda92b44c57973dc Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 23 Jan 2024 14:27:49 +0400 Subject: [PATCH] feat: check agent API version on connection (#11696) fixes #10531 Adds a check for `version` on connection to the Agent API websocket endpoint. This is primarily for future-proofing, so that up-level agents get a sensible error if they connect to a back-level Coderd. It also refactors the location of the `CurrentVersion` variables, to be part of the `proto` packages, since the versions refer to the APIs defined therein. --- agent/proto/version.go | 10 ++++++++++ coderd/workspaceagents.go | 3 ++- coderd/workspaceagentsrpc.go | 18 ++++++++++++++++++ codersdk/agentsdk/agentsdk.go | 13 +++++++++---- codersdk/workspaceagents.go | 3 ++- enterprise/coderd/workspaceproxycoordinate.go | 4 ++-- enterprise/wsproxy/wsproxysdk/wsproxysdk.go | 2 +- .../wsproxy/wsproxysdk/wsproxysdk_test.go | 2 +- tailnet/coordinator_test.go | 2 +- tailnet/proto/version.go | 12 ++++++++++++ tailnet/service.go | 7 ------- 11 files changed, 58 insertions(+), 18 deletions(-) create mode 100644 agent/proto/version.go create mode 100644 tailnet/proto/version.go diff --git a/agent/proto/version.go b/agent/proto/version.go new file mode 100644 index 0000000000..34d5c4f1bd --- /dev/null +++ b/agent/proto/version.go @@ -0,0 +1,10 @@ +package proto + +import ( + "github.com/coder/coder/v2/tailnet/proto" +) + +// CurrentVersion is the current version of the agent API. It is tied to the +// tailnet API version to avoid confusion, since agents connect to the tailnet +// API over the same websocket. +var CurrentVersion = proto.CurrentVersion diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index ad508eebed..568fb17f20 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -43,6 +43,7 @@ import ( "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" ) // @Summary Get workspace agent by ID @@ -1162,7 +1163,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R if qv != "" { version = qv } - if err := tailnet.CurrentVersion.Validate(version); err != nil { + if err := proto.CurrentVersion.Validate(version); err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Unknown or unsupported API version", Validations: []codersdk.ValidationError{ diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index 6b9438a8b8..dd89f07460 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -16,6 +16,7 @@ import ( "nhooyr.io/websocket" "cdr.dev/slog" + "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/coderd/agentapi" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" @@ -37,6 +38,23 @@ import ( func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + version := r.URL.Query().Get("version") + if version == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Missing required query parameter: version", + }) + return + } + if err := proto.CurrentVersion.Validate(version); err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Unknown or unsupported API version", + Validations: []codersdk.ValidationError{ + {Field: "version", Detail: err.Error()}, + }, + }) + return + } + api.WebsocketWaitMutex.Lock() api.WebsocketWaitGroup.Add(1) api.WebsocketWaitMutex.Unlock() diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index 2b65f3a316..f1c29c4517 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -21,6 +21,7 @@ import ( "tailscale.com/tailcfg" "cdr.dev/slog" + "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/codersdk" drpcsdk "github.com/coder/coder/v2/codersdk/drpc" "github.com/coder/retry" @@ -281,18 +282,22 @@ func (c *Client) DERPMapUpdates(ctx context.Context) (<-chan DERPMapUpdate, io.C }, nil } -// Listen connects to the workspace agent coordinate WebSocket +// Listen connects to the workspace agent API WebSocket // that handles connection negotiation. func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) { - coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc") + rpcURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc") if err != nil { return nil, xerrors.Errorf("parse url: %w", err) } + q := rpcURL.Query() + q.Add("version", proto.CurrentVersion.String()) + rpcURL.RawQuery = q.Encode() + jar, err := cookiejar.New(nil) if err != nil { return nil, xerrors.Errorf("create cookie jar: %w", err) } - jar.SetCookies(coordinateURL, []*http.Cookie{{ + jar.SetCookies(rpcURL, []*http.Cookie{{ Name: codersdk.SessionTokenCookie, Value: c.SDK.SessionToken(), }}) @@ -301,7 +306,7 @@ func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) { Transport: c.SDK.HTTPClient.Transport, } // nolint:bodyclose - conn, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{ + conn, res, err := websocket.Dial(ctx, rpcURL.String(), &websocket.DialOptions{ HTTPClient: httpClient, }) if err != nil { diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 8165b78c12..63b8de3c04 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -22,6 +22,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/retry" ) @@ -314,7 +315,7 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID, return nil, xerrors.Errorf("parse url: %w", err) } q := coordinateURL.Query() - q.Add("version", tailnet.CurrentVersion.String()) + q.Add("version", proto.CurrentVersion.String()) coordinateURL.RawQuery = q.Encode() closedCoordinator := make(chan struct{}) // Must only ever be used once, send error OR close to avoid diff --git a/enterprise/coderd/workspaceproxycoordinate.go b/enterprise/coderd/workspaceproxycoordinate.go index 4fe25827b5..02302a0a30 100644 --- a/enterprise/coderd/workspaceproxycoordinate.go +++ b/enterprise/coderd/workspaceproxycoordinate.go @@ -11,7 +11,7 @@ import ( "github.com/coder/coder/v2/coderd/util/apiversion" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" - agpl "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" ) // @Summary Agent is legacy @@ -59,7 +59,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request if qv != "" { version = qv } - if err := agpl.CurrentVersion.Validate(version); err != nil { + if err := proto.CurrentVersion.Validate(version); err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Unknown or unsupported API version", Validations: []codersdk.ValidationError{ diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index f8d8c22543..68e7ec9c90 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -439,7 +439,7 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro return nil, xerrors.Errorf("parse url: %w", err) } q := coordinateURL.Query() - q.Add("version", agpl.CurrentVersion.String()) + q.Add("version", proto.CurrentVersion.String()) coordinateURL.RawQuery = q.Encode() coordinateHeaders := make(http.Header) tokenHeader := codersdk.SessionTokenHeader diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go index 8cf8b1ee18..99a207ccdf 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go @@ -194,7 +194,7 @@ func TestDialCoordinator(t *testing.T) { return } version := r.URL.Query().Get("version") - if !assert.Equal(t, version, agpl.CurrentVersion.String()) { + if !assert.Equal(t, version, proto.CurrentVersion.String()) { return } nc := websocket.NetConn(r.Context(), conn, websocket.MessageBinary) diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index 7207f93d78..c3e1508b7d 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -460,7 +460,7 @@ func TestRemoteCoordination(t *testing.T) { serveErr := make(chan error, 1) go func() { - err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID) + err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, clientID, agentID) serveErr <- err }() diff --git a/tailnet/proto/version.go b/tailnet/proto/version.go new file mode 100644 index 0000000000..449595feb4 --- /dev/null +++ b/tailnet/proto/version.go @@ -0,0 +1,12 @@ +package proto + +import ( + "github.com/coder/coder/v2/coderd/util/apiversion" +) + +const ( + CurrentMajor = 2 + CurrentMinor = 0 +) + +var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1) diff --git a/tailnet/service.go b/tailnet/service.go index 7347afbb32..02bc50a571 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -20,13 +20,6 @@ import ( "golang.org/x/xerrors" ) -const ( - CurrentMajor = 2 - CurrentMinor = 0 -) - -var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1) - type streamIDContextKey struct{} // StreamID identifies the caller of the CoordinateTailnet RPC. We store this