mirror of https://github.com/coder/coder.git
feat: check agent API version on connection (#11696)
fixes #10531 Adds a check for `version` on connection to the Agent API websocket endpoint. This is primarily for future-proofing, so that up-level agents get a sensible error if they connect to a back-level Coderd. It also refactors the location of the `CurrentVersion` variables, to be part of the `proto` packages, since the versions refer to the APIs defined therein.
This commit is contained in:
parent
eb12fd7d92
commit
3e0e7f8739
|
@ -0,0 +1,10 @@
|
|||
package proto
|
||||
|
||||
import (
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
// CurrentVersion is the current version of the agent API. It is tied to the
|
||||
// tailnet API version to avoid confusion, since agents connect to the tailnet
|
||||
// API over the same websocket.
|
||||
var CurrentVersion = proto.CurrentVersion
|
|
@ -43,6 +43,7 @@ import (
|
|||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
// @Summary Get workspace agent by ID
|
||||
|
@ -1162,7 +1163,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
|
|||
if qv != "" {
|
||||
version = qv
|
||||
}
|
||||
if err := tailnet.CurrentVersion.Validate(version); err != nil {
|
||||
if err := proto.CurrentVersion.Validate(version); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Unknown or unsupported API version",
|
||||
Validations: []codersdk.ValidationError{
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"nhooyr.io/websocket"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
|
@ -37,6 +38,23 @@ import (
|
|||
func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
version := r.URL.Query().Get("version")
|
||||
if version == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing required query parameter: version",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := proto.CurrentVersion.Validate(version); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Unknown or unsupported API version",
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "version", Detail: err.Error()},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
api.WebsocketWaitMutex.Lock()
|
||||
api.WebsocketWaitGroup.Add(1)
|
||||
api.WebsocketWaitMutex.Unlock()
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
|
||||
"github.com/coder/retry"
|
||||
|
@ -281,18 +282,22 @@ func (c *Client) DERPMapUpdates(ctx context.Context) (<-chan DERPMapUpdate, io.C
|
|||
}, nil
|
||||
}
|
||||
|
||||
// Listen connects to the workspace agent coordinate WebSocket
|
||||
// Listen connects to the workspace agent API WebSocket
|
||||
// that handles connection negotiation.
|
||||
func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) {
|
||||
coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc")
|
||||
rpcURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc")
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
q := rpcURL.Query()
|
||||
q.Add("version", proto.CurrentVersion.String())
|
||||
rpcURL.RawQuery = q.Encode()
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create cookie jar: %w", err)
|
||||
}
|
||||
jar.SetCookies(coordinateURL, []*http.Cookie{{
|
||||
jar.SetCookies(rpcURL, []*http.Cookie{{
|
||||
Name: codersdk.SessionTokenCookie,
|
||||
Value: c.SDK.SessionToken(),
|
||||
}})
|
||||
|
@ -301,7 +306,7 @@ func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) {
|
|||
Transport: c.SDK.HTTPClient.Transport,
|
||||
}
|
||||
// nolint:bodyclose
|
||||
conn, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
|
||||
conn, res, err := websocket.Dial(ctx, rpcURL.String(), &websocket.DialOptions{
|
||||
HTTPClient: httpClient,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/retry"
|
||||
)
|
||||
|
||||
|
@ -314,7 +315,7 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
|
|||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
q := coordinateURL.Query()
|
||||
q.Add("version", tailnet.CurrentVersion.String())
|
||||
q.Add("version", proto.CurrentVersion.String())
|
||||
coordinateURL.RawQuery = q.Encode()
|
||||
closedCoordinator := make(chan struct{})
|
||||
// Must only ever be used once, send error OR close to avoid
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
// @Summary Agent is legacy
|
||||
|
@ -59,7 +59,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
|
|||
if qv != "" {
|
||||
version = qv
|
||||
}
|
||||
if err := agpl.CurrentVersion.Validate(version); err != nil {
|
||||
if err := proto.CurrentVersion.Validate(version); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Unknown or unsupported API version",
|
||||
Validations: []codersdk.ValidationError{
|
||||
|
|
|
@ -439,7 +439,7 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
|
|||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
q := coordinateURL.Query()
|
||||
q.Add("version", agpl.CurrentVersion.String())
|
||||
q.Add("version", proto.CurrentVersion.String())
|
||||
coordinateURL.RawQuery = q.Encode()
|
||||
coordinateHeaders := make(http.Header)
|
||||
tokenHeader := codersdk.SessionTokenHeader
|
||||
|
|
|
@ -194,7 +194,7 @@ func TestDialCoordinator(t *testing.T) {
|
|||
return
|
||||
}
|
||||
version := r.URL.Query().Get("version")
|
||||
if !assert.Equal(t, version, agpl.CurrentVersion.String()) {
|
||||
if !assert.Equal(t, version, proto.CurrentVersion.String()) {
|
||||
return
|
||||
}
|
||||
nc := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
|
||||
|
|
|
@ -460,7 +460,7 @@ func TestRemoteCoordination(t *testing.T) {
|
|||
|
||||
serveErr := make(chan error, 1)
|
||||
go func() {
|
||||
err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID)
|
||||
err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, clientID, agentID)
|
||||
serveErr <- err
|
||||
}()
|
||||
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
package proto
|
||||
|
||||
import (
|
||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
||||
)
|
||||
|
||||
const (
|
||||
CurrentMajor = 2
|
||||
CurrentMinor = 0
|
||||
)
|
||||
|
||||
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1)
|
|
@ -20,13 +20,6 @@ import (
|
|||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
const (
|
||||
CurrentMajor = 2
|
||||
CurrentMinor = 0
|
||||
)
|
||||
|
||||
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1)
|
||||
|
||||
type streamIDContextKey struct{}
|
||||
|
||||
// StreamID identifies the caller of the CoordinateTailnet RPC. We store this
|
||||
|
|
Loading…
Reference in New Issue