feat: check agent API version on connection (#11696)

fixes #10531

Adds a check for `version` on connection to the Agent API websocket endpoint.  This is primarily for future-proofing, so that up-level agents get a sensible error if they connect to a back-level Coderd.

It also refactors the location of the `CurrentVersion` variables, to be part of the `proto` packages, since the versions refer to the APIs defined therein.
This commit is contained in:
Spike Curtis 2024-01-23 14:27:49 +04:00 committed by GitHub
parent eb12fd7d92
commit 3e0e7f8739
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 58 additions and 18 deletions

10
agent/proto/version.go Normal file
View File

@ -0,0 +1,10 @@
package proto
import (
"github.com/coder/coder/v2/tailnet/proto"
)
// CurrentVersion is the current version of the agent API. It is tied to the
// tailnet API version to avoid confusion, since agents connect to the tailnet
// API over the same websocket.
var CurrentVersion = proto.CurrentVersion

View File

@ -43,6 +43,7 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
)
// @Summary Get workspace agent by ID
@ -1162,7 +1163,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
if qv != "" {
version = qv
}
if err := tailnet.CurrentVersion.Validate(version); err != nil {
if err := proto.CurrentVersion.Validate(version); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Unknown or unsupported API version",
Validations: []codersdk.ValidationError{

View File

@ -16,6 +16,7 @@ import (
"nhooyr.io/websocket"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
@ -37,6 +38,23 @@ import (
func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
version := r.URL.Query().Get("version")
if version == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing required query parameter: version",
})
return
}
if err := proto.CurrentVersion.Validate(version); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Unknown or unsupported API version",
Validations: []codersdk.ValidationError{
{Field: "version", Detail: err.Error()},
},
})
return
}
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()

View File

@ -21,6 +21,7 @@ import (
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/codersdk"
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/retry"
@ -281,18 +282,22 @@ func (c *Client) DERPMapUpdates(ctx context.Context) (<-chan DERPMapUpdate, io.C
}, nil
}
// Listen connects to the workspace agent coordinate WebSocket
// Listen connects to the workspace agent API WebSocket
// that handles connection negotiation.
func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) {
coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc")
rpcURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc")
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
q := rpcURL.Query()
q.Add("version", proto.CurrentVersion.String())
rpcURL.RawQuery = q.Encode()
jar, err := cookiejar.New(nil)
if err != nil {
return nil, xerrors.Errorf("create cookie jar: %w", err)
}
jar.SetCookies(coordinateURL, []*http.Cookie{{
jar.SetCookies(rpcURL, []*http.Cookie{{
Name: codersdk.SessionTokenCookie,
Value: c.SDK.SessionToken(),
}})
@ -301,7 +306,7 @@ func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) {
Transport: c.SDK.HTTPClient.Transport,
}
// nolint:bodyclose
conn, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
conn, res, err := websocket.Dial(ctx, rpcURL.String(), &websocket.DialOptions{
HTTPClient: httpClient,
})
if err != nil {

View File

@ -22,6 +22,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/retry"
)
@ -314,7 +315,7 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
return nil, xerrors.Errorf("parse url: %w", err)
}
q := coordinateURL.Query()
q.Add("version", tailnet.CurrentVersion.String())
q.Add("version", proto.CurrentVersion.String())
coordinateURL.RawQuery = q.Encode()
closedCoordinator := make(chan struct{})
// Must only ever be used once, send error OR close to avoid

View File

@ -11,7 +11,7 @@ import (
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
)
// @Summary Agent is legacy
@ -59,7 +59,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
if qv != "" {
version = qv
}
if err := agpl.CurrentVersion.Validate(version); err != nil {
if err := proto.CurrentVersion.Validate(version); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Unknown or unsupported API version",
Validations: []codersdk.ValidationError{

View File

@ -439,7 +439,7 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
return nil, xerrors.Errorf("parse url: %w", err)
}
q := coordinateURL.Query()
q.Add("version", agpl.CurrentVersion.String())
q.Add("version", proto.CurrentVersion.String())
coordinateURL.RawQuery = q.Encode()
coordinateHeaders := make(http.Header)
tokenHeader := codersdk.SessionTokenHeader

View File

@ -194,7 +194,7 @@ func TestDialCoordinator(t *testing.T) {
return
}
version := r.URL.Query().Get("version")
if !assert.Equal(t, version, agpl.CurrentVersion.String()) {
if !assert.Equal(t, version, proto.CurrentVersion.String()) {
return
}
nc := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)

View File

@ -460,7 +460,7 @@ func TestRemoteCoordination(t *testing.T) {
serveErr := make(chan error, 1)
go func() {
err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID)
err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, clientID, agentID)
serveErr <- err
}()

12
tailnet/proto/version.go Normal file
View File

@ -0,0 +1,12 @@
package proto
import (
"github.com/coder/coder/v2/coderd/util/apiversion"
)
const (
CurrentMajor = 2
CurrentMinor = 0
)
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1)

View File

@ -20,13 +20,6 @@ import (
"golang.org/x/xerrors"
)
const (
CurrentMajor = 2
CurrentMinor = 0
)
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1)
type streamIDContextKey struct{}
// StreamID identifies the caller of the CoordinateTailnet RPC. We store this