mirror of https://github.com/coder/coder.git
fix(enterprise/coderd): check provisionerd API version on connection (#12191)
This commit is contained in:
parent
f17149c59d
commit
a2cbb0f87f
|
@ -5,7 +5,7 @@ import (
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
"github.com/coder/coder/v2/apiversion"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAPIVersionValidate(t *testing.T) {
|
func TestAPIVersionValidate(t *testing.T) {
|
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"golang.org/x/mod/semver"
|
"golang.org/x/mod/semver"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/apiversion"
|
||||||
"github.com/coder/coder/v2/buildinfo"
|
"github.com/coder/coder/v2/buildinfo"
|
||||||
"github.com/coder/coder/v2/coderd/database"
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||||
|
@ -14,7 +15,6 @@ import (
|
||||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||||
"github.com/coder/coder/v2/coderd/healthcheck/health"
|
"github.com/coder/coder/v2/coderd/healthcheck/health"
|
||||||
"github.com/coder/coder/v2/coderd/provisionerdserver"
|
"github.com/coder/coder/v2/coderd/provisionerdserver"
|
||||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
|
||||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
"github.com/coder/coder/v2/provisionersdk"
|
"github.com/coder/coder/v2/provisionersdk"
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/coder/coder/v2/codersdk/drpc"
|
"github.com/coder/coder/v2/codersdk/drpc"
|
||||||
"github.com/coder/coder/v2/provisionerd/proto"
|
"github.com/coder/coder/v2/provisionerd/proto"
|
||||||
"github.com/coder/coder/v2/provisionerd/runner"
|
"github.com/coder/coder/v2/provisionerd/runner"
|
||||||
|
"github.com/coder/coder/v2/provisionersdk"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LogSource string
|
type LogSource string
|
||||||
|
@ -201,6 +202,8 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione
|
||||||
query := serverURL.Query()
|
query := serverURL.Query()
|
||||||
query.Add("id", req.ID.String())
|
query.Add("id", req.ID.String())
|
||||||
query.Add("name", req.Name)
|
query.Add("name", req.Name)
|
||||||
|
query.Add("version", provisionersdk.VersionCurrent.String())
|
||||||
|
|
||||||
for _, provisioner := range req.Provisioners {
|
for _, provisioner := range req.Provisioners {
|
||||||
query.Add("provisioner", string(provisioner))
|
query.Add("provisioner", string(provisioner))
|
||||||
}
|
}
|
||||||
|
|
|
@ -239,6 +239,16 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
||||||
apiVersion = qv
|
apiVersion = qv
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := provisionersdk.VersionCurrent.Validate(apiVersion); err != nil {
|
||||||
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||||
|
Message: "Incompatible or unparsable version",
|
||||||
|
Validations: []codersdk.ValidationError{
|
||||||
|
{Field: "version", Detail: err.Error()},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Create the daemon in the database.
|
// Create the daemon in the database.
|
||||||
now := dbtime.Now()
|
now := dbtime.Now()
|
||||||
daemon, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{
|
daemon, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{
|
||||||
|
|
|
@ -3,6 +3,8 @@ package coderd_test
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -12,6 +14,7 @@ import (
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
"cdr.dev/slog/sloggers/slogtest"
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
|
"github.com/coder/coder/v2/apiversion"
|
||||||
"github.com/coder/coder/v2/buildinfo"
|
"github.com/coder/coder/v2/buildinfo"
|
||||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||||
"github.com/coder/coder/v2/coderd/database"
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
|
@ -63,6 +66,108 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("NoVersion", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// In this test, we just send a HTTP request with minimal parameters to the provisionerdaemons
|
||||||
|
// endpoint. We do not pass the required machinery to start a websocket connection, so we expect a
|
||||||
|
// WebSocket protocol violation. This just means the pre-flight checks have passed though.
|
||||||
|
|
||||||
|
// Sending a HTTP request triggers an error log, which would otherwise fail the test.
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
client, user := coderdenttest.New(t, &coderdenttest.Options{
|
||||||
|
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||||
|
Features: license.Features{
|
||||||
|
codersdk.FeatureExternalProvisionerDaemons: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ProvisionerDaemonPSK: "provisionersftw",
|
||||||
|
Options: &coderdtest.Options{
|
||||||
|
Logger: &logger,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Formulate the correct URL for provisionerd server.
|
||||||
|
srvURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", user.OrganizationID))
|
||||||
|
require.NoError(t, err)
|
||||||
|
q := srvURL.Query()
|
||||||
|
// Set required query parameters.
|
||||||
|
q.Add("provisioner", "echo")
|
||||||
|
// Note: Explicitly not setting API version.
|
||||||
|
q.Add("version", "")
|
||||||
|
srvURL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
// Set PSK header for auth.
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srvURL.String(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Header.Set(codersdk.ProvisionerDaemonPSK, "provisionersftw")
|
||||||
|
|
||||||
|
// Do the request!
|
||||||
|
resp, err := client.HTTPClient.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
b, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// The below means that provisionerd tried to serve us!
|
||||||
|
require.Contains(t, string(b), "Internal error accepting websocket connection.")
|
||||||
|
|
||||||
|
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
|
||||||
|
require.NoError(t, err)
|
||||||
|
if assert.Len(t, daemons, 1) {
|
||||||
|
assert.Equal(t, "1.0", daemons[0].APIVersion) // The whole point of this test is here.
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OldVersion", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// In this test, we just send a HTTP request with minimal parameters to the provisionerdaemons
|
||||||
|
// endpoint. We do not pass the required machinery to start a websocket connection, but we pass a
|
||||||
|
// version header that should cause provisionerd to refuse to serve us, so no websocket for you!
|
||||||
|
|
||||||
|
// Sending a HTTP request triggers an error log, which would otherwise fail the test.
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
client, user := coderdenttest.New(t, &coderdenttest.Options{
|
||||||
|
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||||
|
Features: license.Features{
|
||||||
|
codersdk.FeatureExternalProvisionerDaemons: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ProvisionerDaemonPSK: "provisionersftw",
|
||||||
|
Options: &coderdtest.Options{
|
||||||
|
Logger: &logger,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Formulate the correct URL for provisionerd server.
|
||||||
|
srvURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", user.OrganizationID))
|
||||||
|
require.NoError(t, err)
|
||||||
|
q := srvURL.Query()
|
||||||
|
// Set required query parameters.
|
||||||
|
q.Add("provisioner", "echo")
|
||||||
|
|
||||||
|
// Set a different (newer) version than the current.
|
||||||
|
v := apiversion.New(provisionersdk.CurrentMajor+1, provisionersdk.CurrentMinor+1)
|
||||||
|
q.Add("version", v.String())
|
||||||
|
srvURL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
// Set PSK header for auth.
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srvURL.String(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Header.Set(codersdk.ProvisionerDaemonPSK, "provisionersftw")
|
||||||
|
|
||||||
|
// Do the request!
|
||||||
|
resp, err := client.HTTPClient.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
b, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// The below means that provisionerd tried to serve us, checked our api version, and said nope.
|
||||||
|
require.Contains(t, string(b), "server is at version 1.0, behind requested major version 2.1")
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("NoLicense", func(t *testing.T) {
|
t.Run("NoLicense", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
client, user := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true})
|
client, user := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true})
|
||||||
|
|
|
@ -6,8 +6,8 @@ import (
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"nhooyr.io/websocket"
|
"nhooyr.io/websocket"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/apiversion"
|
||||||
"github.com/coder/coder/v2/coderd/httpapi"
|
"github.com/coder/coder/v2/coderd/httpapi"
|
||||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
"github.com/coder/coder/v2/tailnet/proto"
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,7 +14,7 @@ import (
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
"github.com/coder/coder/v2/apiversion"
|
||||||
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
|
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
|
||||||
agpl "github.com/coder/coder/v2/tailnet"
|
agpl "github.com/coder/coder/v2/tailnet"
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,8 +16,8 @@ import (
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/apiversion"
|
||||||
"github.com/coder/coder/v2/coderd/tracing"
|
"github.com/coder/coder/v2/coderd/tracing"
|
||||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
|
||||||
"github.com/coder/coder/v2/provisionersdk/proto"
|
"github.com/coder/coder/v2/provisionersdk/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package proto
|
package proto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
"github.com/coder/coder/v2/apiversion"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -14,7 +14,7 @@ import (
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
"github.com/coder/coder/v2/apiversion"
|
||||||
"github.com/coder/coder/v2/tailnet/proto"
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
|
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
|
Loading…
Reference in New Issue