diff --git a/coderd/util/apiversion/apiversion.go b/apiversion/apiversion.go similarity index 100% rename from coderd/util/apiversion/apiversion.go rename to apiversion/apiversion.go diff --git a/coderd/util/apiversion/apiversion_test.go b/apiversion/apiversion_test.go similarity index 96% rename from coderd/util/apiversion/apiversion_test.go rename to apiversion/apiversion_test.go index 0bd6fe0f6b..8a18a0bd5c 100644 --- a/coderd/util/apiversion/apiversion_test.go +++ b/apiversion/apiversion_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/coder/coder/v2/coderd/util/apiversion" + "github.com/coder/coder/v2/apiversion" ) func TestAPIVersionValidate(t *testing.T) { diff --git a/coderd/healthcheck/provisioner.go b/coderd/healthcheck/provisioner.go index 4ff961454b..a61836a3d4 100644 --- a/coderd/healthcheck/provisioner.go +++ b/coderd/healthcheck/provisioner.go @@ -7,6 +7,7 @@ import ( "golang.org/x/mod/semver" + "github.com/coder/coder/v2/apiversion" "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" @@ -14,7 +15,6 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/healthcheck/health" "github.com/coder/coder/v2/coderd/provisionerdserver" - "github.com/coder/coder/v2/coderd/util/apiversion" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisionersdk" diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index e8f8ed8eb6..ba5fd64b1a 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -18,6 +18,7 @@ import ( "github.com/coder/coder/v2/codersdk/drpc" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionerd/runner" + "github.com/coder/coder/v2/provisionersdk" ) type LogSource string @@ -201,6 +202,8 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione query := serverURL.Query() query.Add("id", req.ID.String()) query.Add("name", req.Name) + query.Add("version", provisionersdk.VersionCurrent.String()) + for _, provisioner := range req.Provisioners { query.Add("provisioner", string(provisioner)) } diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index f81f17befd..a343542914 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -239,6 +239,16 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) apiVersion = qv } + if err := provisionersdk.VersionCurrent.Validate(apiVersion); err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Incompatible or unparsable version", + Validations: []codersdk.ValidationError{ + {Field: "version", Detail: err.Error()}, + }, + }) + return + } + // Create the daemon in the database. now := dbtime.Now() daemon, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{ diff --git a/enterprise/coderd/provisionerdaemons_test.go b/enterprise/coderd/provisionerdaemons_test.go index ac48e21cdd..d326974c96 100644 --- a/enterprise/coderd/provisionerdaemons_test.go +++ b/enterprise/coderd/provisionerdaemons_test.go @@ -3,6 +3,8 @@ package coderd_test import ( "bytes" "context" + "fmt" + "io" "net/http" "testing" @@ -12,6 +14,7 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/apiversion" "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" @@ -63,6 +66,108 @@ func TestProvisionerDaemonServe(t *testing.T) { } }) + t.Run("NoVersion", func(t *testing.T) { + t.Parallel() + // In this test, we just send a HTTP request with minimal parameters to the provisionerdaemons + // endpoint. We do not pass the required machinery to start a websocket connection, so we expect a + // WebSocket protocol violation. This just means the pre-flight checks have passed though. + + // Sending a HTTP request triggers an error log, which would otherwise fail the test. + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + client, user := coderdenttest.New(t, &coderdenttest.Options{ + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + ProvisionerDaemonPSK: "provisionersftw", + Options: &coderdtest.Options{ + Logger: &logger, + }, + }) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + // Formulate the correct URL for provisionerd server. + srvURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", user.OrganizationID)) + require.NoError(t, err) + q := srvURL.Query() + // Set required query parameters. + q.Add("provisioner", "echo") + // Note: Explicitly not setting API version. + q.Add("version", "") + srvURL.RawQuery = q.Encode() + + // Set PSK header for auth. + req, err := http.NewRequestWithContext(ctx, http.MethodGet, srvURL.String(), nil) + require.NoError(t, err) + req.Header.Set(codersdk.ProvisionerDaemonPSK, "provisionersftw") + + // Do the request! + resp, err := client.HTTPClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + // The below means that provisionerd tried to serve us! + require.Contains(t, string(b), "Internal error accepting websocket connection.") + + daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion. + require.NoError(t, err) + if assert.Len(t, daemons, 1) { + assert.Equal(t, "1.0", daemons[0].APIVersion) // The whole point of this test is here. + } + }) + + t.Run("OldVersion", func(t *testing.T) { + t.Parallel() + // In this test, we just send a HTTP request with minimal parameters to the provisionerdaemons + // endpoint. We do not pass the required machinery to start a websocket connection, but we pass a + // version header that should cause provisionerd to refuse to serve us, so no websocket for you! + + // Sending a HTTP request triggers an error log, which would otherwise fail the test. + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + client, user := coderdenttest.New(t, &coderdenttest.Options{ + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + ProvisionerDaemonPSK: "provisionersftw", + Options: &coderdtest.Options{ + Logger: &logger, + }, + }) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + // Formulate the correct URL for provisionerd server. + srvURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", user.OrganizationID)) + require.NoError(t, err) + q := srvURL.Query() + // Set required query parameters. + q.Add("provisioner", "echo") + + // Set a different (newer) version than the current. + v := apiversion.New(provisionersdk.CurrentMajor+1, provisionersdk.CurrentMinor+1) + q.Add("version", v.String()) + srvURL.RawQuery = q.Encode() + + // Set PSK header for auth. + req, err := http.NewRequestWithContext(ctx, http.MethodGet, srvURL.String(), nil) + require.NoError(t, err) + req.Header.Set(codersdk.ProvisionerDaemonPSK, "provisionersftw") + + // Do the request! + resp, err := client.HTTPClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + // The below means that provisionerd tried to serve us, checked our api version, and said nope. + require.Contains(t, string(b), "server is at version 1.0, behind requested major version 2.1") + }) + t.Run("NoLicense", func(t *testing.T) { t.Parallel() client, user := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true}) diff --git a/enterprise/coderd/workspaceproxycoordinate.go b/enterprise/coderd/workspaceproxycoordinate.go index a85cc0488e..58522e59ac 100644 --- a/enterprise/coderd/workspaceproxycoordinate.go +++ b/enterprise/coderd/workspaceproxycoordinate.go @@ -6,8 +6,8 @@ import ( "github.com/google/uuid" "nhooyr.io/websocket" + "github.com/coder/coder/v2/apiversion" "github.com/coder/coder/v2/coderd/httpapi" - "github.com/coder/coder/v2/coderd/util/apiversion" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/tailnet/proto" ) diff --git a/enterprise/tailnet/workspaceproxy.go b/enterprise/tailnet/workspaceproxy.go index d8f64aa398..b7daabd891 100644 --- a/enterprise/tailnet/workspaceproxy.go +++ b/enterprise/tailnet/workspaceproxy.go @@ -14,7 +14,7 @@ import ( "tailscale.com/tailcfg" "cdr.dev/slog" - "github.com/coder/coder/v2/coderd/util/apiversion" + "github.com/coder/coder/v2/apiversion" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" agpl "github.com/coder/coder/v2/tailnet" ) diff --git a/provisionersdk/serve.go b/provisionersdk/serve.go index 0b2e10234f..1f19ca6c83 100644 --- a/provisionersdk/serve.go +++ b/provisionersdk/serve.go @@ -16,8 +16,8 @@ import ( "cdr.dev/slog" + "github.com/coder/coder/v2/apiversion" "github.com/coder/coder/v2/coderd/tracing" - "github.com/coder/coder/v2/coderd/util/apiversion" "github.com/coder/coder/v2/provisionersdk/proto" ) diff --git a/tailnet/proto/version.go b/tailnet/proto/version.go index 449595feb4..a6040a9fea 100644 --- a/tailnet/proto/version.go +++ b/tailnet/proto/version.go @@ -1,7 +1,7 @@ package proto import ( - "github.com/coder/coder/v2/coderd/util/apiversion" + "github.com/coder/coder/v2/apiversion" ) const ( diff --git a/tailnet/service.go b/tailnet/service.go index 3be0abcab6..4af8d6913c 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -14,7 +14,7 @@ import ( "tailscale.com/tailcfg" "cdr.dev/slog" - "github.com/coder/coder/v2/coderd/util/apiversion" + "github.com/coder/coder/v2/apiversion" "github.com/coder/coder/v2/tailnet/proto" "golang.org/x/xerrors"