mirror of https://github.com/coder/coder.git
feat: check agent API version on connection (#11696)
fixes #10531 Adds a check for `version` on connection to the Agent API websocket endpoint. This is primarily for future-proofing, so that up-level agents get a sensible error if they connect to a back-level Coderd. It also refactors the location of the `CurrentVersion` variables, to be part of the `proto` packages, since the versions refer to the APIs defined therein.
This commit is contained in:
parent
eb12fd7d92
commit
3e0e7f8739
|
@ -0,0 +1,10 @@
|
||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CurrentVersion is the current version of the agent API. It is tied to the
|
||||||
|
// tailnet API version to avoid confusion, since agents connect to the tailnet
|
||||||
|
// API over the same websocket.
|
||||||
|
var CurrentVersion = proto.CurrentVersion
|
|
@ -43,6 +43,7 @@ import (
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||||
"github.com/coder/coder/v2/tailnet"
|
"github.com/coder/coder/v2/tailnet"
|
||||||
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
// @Summary Get workspace agent by ID
|
// @Summary Get workspace agent by ID
|
||||||
|
@ -1162,7 +1163,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
|
||||||
if qv != "" {
|
if qv != "" {
|
||||||
version = qv
|
version = qv
|
||||||
}
|
}
|
||||||
if err := tailnet.CurrentVersion.Validate(version); err != nil {
|
if err := proto.CurrentVersion.Validate(version); err != nil {
|
||||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||||
Message: "Unknown or unsupported API version",
|
Message: "Unknown or unsupported API version",
|
||||||
Validations: []codersdk.ValidationError{
|
Validations: []codersdk.ValidationError{
|
||||||
|
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"nhooyr.io/websocket"
|
"nhooyr.io/websocket"
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
|
"github.com/coder/coder/v2/agent/proto"
|
||||||
"github.com/coder/coder/v2/coderd/agentapi"
|
"github.com/coder/coder/v2/coderd/agentapi"
|
||||||
"github.com/coder/coder/v2/coderd/database"
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||||
|
@ -37,6 +38,23 @@ import (
|
||||||
func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
|
func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
|
||||||
|
version := r.URL.Query().Get("version")
|
||||||
|
if version == "" {
|
||||||
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||||
|
Message: "Missing required query parameter: version",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := proto.CurrentVersion.Validate(version); err != nil {
|
||||||
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||||
|
Message: "Unknown or unsupported API version",
|
||||||
|
Validations: []codersdk.ValidationError{
|
||||||
|
{Field: "version", Detail: err.Error()},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
api.WebsocketWaitMutex.Lock()
|
api.WebsocketWaitMutex.Lock()
|
||||||
api.WebsocketWaitGroup.Add(1)
|
api.WebsocketWaitGroup.Add(1)
|
||||||
api.WebsocketWaitMutex.Unlock()
|
api.WebsocketWaitMutex.Unlock()
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
|
"github.com/coder/coder/v2/agent/proto"
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
|
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
|
||||||
"github.com/coder/retry"
|
"github.com/coder/retry"
|
||||||
|
@ -281,18 +282,22 @@ func (c *Client) DERPMapUpdates(ctx context.Context) (<-chan DERPMapUpdate, io.C
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen connects to the workspace agent coordinate WebSocket
|
// Listen connects to the workspace agent API WebSocket
|
||||||
// that handles connection negotiation.
|
// that handles connection negotiation.
|
||||||
func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) {
|
func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) {
|
||||||
coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc")
|
rpcURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, xerrors.Errorf("parse url: %w", err)
|
return nil, xerrors.Errorf("parse url: %w", err)
|
||||||
}
|
}
|
||||||
|
q := rpcURL.Query()
|
||||||
|
q.Add("version", proto.CurrentVersion.String())
|
||||||
|
rpcURL.RawQuery = q.Encode()
|
||||||
|
|
||||||
jar, err := cookiejar.New(nil)
|
jar, err := cookiejar.New(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, xerrors.Errorf("create cookie jar: %w", err)
|
return nil, xerrors.Errorf("create cookie jar: %w", err)
|
||||||
}
|
}
|
||||||
jar.SetCookies(coordinateURL, []*http.Cookie{{
|
jar.SetCookies(rpcURL, []*http.Cookie{{
|
||||||
Name: codersdk.SessionTokenCookie,
|
Name: codersdk.SessionTokenCookie,
|
||||||
Value: c.SDK.SessionToken(),
|
Value: c.SDK.SessionToken(),
|
||||||
}})
|
}})
|
||||||
|
@ -301,7 +306,7 @@ func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) {
|
||||||
Transport: c.SDK.HTTPClient.Transport,
|
Transport: c.SDK.HTTPClient.Transport,
|
||||||
}
|
}
|
||||||
// nolint:bodyclose
|
// nolint:bodyclose
|
||||||
conn, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
|
conn, res, err := websocket.Dial(ctx, rpcURL.String(), &websocket.DialOptions{
|
||||||
HTTPClient: httpClient,
|
HTTPClient: httpClient,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
"github.com/coder/coder/v2/coderd/tracing"
|
"github.com/coder/coder/v2/coderd/tracing"
|
||||||
"github.com/coder/coder/v2/tailnet"
|
"github.com/coder/coder/v2/tailnet"
|
||||||
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
"github.com/coder/retry"
|
"github.com/coder/retry"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -314,7 +315,7 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
|
||||||
return nil, xerrors.Errorf("parse url: %w", err)
|
return nil, xerrors.Errorf("parse url: %w", err)
|
||||||
}
|
}
|
||||||
q := coordinateURL.Query()
|
q := coordinateURL.Query()
|
||||||
q.Add("version", tailnet.CurrentVersion.String())
|
q.Add("version", proto.CurrentVersion.String())
|
||||||
coordinateURL.RawQuery = q.Encode()
|
coordinateURL.RawQuery = q.Encode()
|
||||||
closedCoordinator := make(chan struct{})
|
closedCoordinator := make(chan struct{})
|
||||||
// Must only ever be used once, send error OR close to avoid
|
// Must only ever be used once, send error OR close to avoid
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
"github.com/coder/coder/v2/coderd/util/apiversion"
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
|
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
|
||||||
agpl "github.com/coder/coder/v2/tailnet"
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
// @Summary Agent is legacy
|
// @Summary Agent is legacy
|
||||||
|
@ -59,7 +59,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
|
||||||
if qv != "" {
|
if qv != "" {
|
||||||
version = qv
|
version = qv
|
||||||
}
|
}
|
||||||
if err := agpl.CurrentVersion.Validate(version); err != nil {
|
if err := proto.CurrentVersion.Validate(version); err != nil {
|
||||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||||
Message: "Unknown or unsupported API version",
|
Message: "Unknown or unsupported API version",
|
||||||
Validations: []codersdk.ValidationError{
|
Validations: []codersdk.ValidationError{
|
||||||
|
|
|
@ -439,7 +439,7 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
|
||||||
return nil, xerrors.Errorf("parse url: %w", err)
|
return nil, xerrors.Errorf("parse url: %w", err)
|
||||||
}
|
}
|
||||||
q := coordinateURL.Query()
|
q := coordinateURL.Query()
|
||||||
q.Add("version", agpl.CurrentVersion.String())
|
q.Add("version", proto.CurrentVersion.String())
|
||||||
coordinateURL.RawQuery = q.Encode()
|
coordinateURL.RawQuery = q.Encode()
|
||||||
coordinateHeaders := make(http.Header)
|
coordinateHeaders := make(http.Header)
|
||||||
tokenHeader := codersdk.SessionTokenHeader
|
tokenHeader := codersdk.SessionTokenHeader
|
||||||
|
|
|
@ -194,7 +194,7 @@ func TestDialCoordinator(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
version := r.URL.Query().Get("version")
|
version := r.URL.Query().Get("version")
|
||||||
if !assert.Equal(t, version, agpl.CurrentVersion.String()) {
|
if !assert.Equal(t, version, proto.CurrentVersion.String()) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
nc := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
|
nc := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
|
||||||
|
|
|
@ -460,7 +460,7 @@ func TestRemoteCoordination(t *testing.T) {
|
||||||
|
|
||||||
serveErr := make(chan error, 1)
|
serveErr := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID)
|
err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, clientID, agentID)
|
||||||
serveErr <- err
|
serveErr <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
package proto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/coder/coder/v2/coderd/util/apiversion"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CurrentMajor = 2
|
||||||
|
CurrentMinor = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1)
|
|
@ -20,13 +20,6 @@ import (
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
CurrentMajor = 2
|
|
||||||
CurrentMinor = 0
|
|
||||||
)
|
|
||||||
|
|
||||||
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1)
|
|
||||||
|
|
||||||
type streamIDContextKey struct{}
|
type streamIDContextKey struct{}
|
||||||
|
|
||||||
// StreamID identifies the caller of the CoordinateTailnet RPC. We store this
|
// StreamID identifies the caller of the CoordinateTailnet RPC. We store this
|
||||||
|
|
Loading…
Reference in New Issue