mirror of https://github.com/coder/coder.git
fix(enterprise/coderd): check provisionerd API version on connection (#12191)
This commit is contained in:
parent
f17149c59d
commit
a2cbb0f87f
|
@ -5,7 +5,7 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
)
|
||||
|
||||
func TestAPIVersionValidate(t *testing.T) {
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"golang.org/x/mod/semver"
|
||||
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
|
@ -14,7 +15,6 @@ import (
|
|||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck/health"
|
||||
"github.com/coder/coder/v2/coderd/provisionerdserver"
|
||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/coder/coder/v2/codersdk/drpc"
|
||||
"github.com/coder/coder/v2/provisionerd/proto"
|
||||
"github.com/coder/coder/v2/provisionerd/runner"
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
)
|
||||
|
||||
type LogSource string
|
||||
|
@ -201,6 +202,8 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione
|
|||
query := serverURL.Query()
|
||||
query.Add("id", req.ID.String())
|
||||
query.Add("name", req.Name)
|
||||
query.Add("version", provisionersdk.VersionCurrent.String())
|
||||
|
||||
for _, provisioner := range req.Provisioners {
|
||||
query.Add("provisioner", string(provisioner))
|
||||
}
|
||||
|
|
|
@ -239,6 +239,16 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||
apiVersion = qv
|
||||
}
|
||||
|
||||
if err := provisionersdk.VersionCurrent.Validate(apiVersion); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Incompatible or unparsable version",
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "version", Detail: err.Error()},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create the daemon in the database.
|
||||
now := dbtime.Now()
|
||||
daemon, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{
|
||||
|
|
|
@ -3,6 +3,8 @@ package coderd_test
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
|
@ -12,6 +14,7 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
|
@ -63,6 +66,108 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("NoVersion", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// In this test, we just send a HTTP request with minimal parameters to the provisionerdaemons
|
||||
// endpoint. We do not pass the required machinery to start a websocket connection, so we expect a
|
||||
// WebSocket protocol violation. This just means the pre-flight checks have passed though.
|
||||
|
||||
// Sending a HTTP request triggers an error log, which would otherwise fail the test.
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
client, user := coderdenttest.New(t, &coderdenttest.Options{
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureExternalProvisionerDaemons: 1,
|
||||
},
|
||||
},
|
||||
ProvisionerDaemonPSK: "provisionersftw",
|
||||
Options: &coderdtest.Options{
|
||||
Logger: &logger,
|
||||
},
|
||||
})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// Formulate the correct URL for provisionerd server.
|
||||
srvURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", user.OrganizationID))
|
||||
require.NoError(t, err)
|
||||
q := srvURL.Query()
|
||||
// Set required query parameters.
|
||||
q.Add("provisioner", "echo")
|
||||
// Note: Explicitly not setting API version.
|
||||
q.Add("version", "")
|
||||
srvURL.RawQuery = q.Encode()
|
||||
|
||||
// Set PSK header for auth.
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srvURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(codersdk.ProvisionerDaemonPSK, "provisionersftw")
|
||||
|
||||
// Do the request!
|
||||
resp, err := client.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
// The below means that provisionerd tried to serve us!
|
||||
require.Contains(t, string(b), "Internal error accepting websocket connection.")
|
||||
|
||||
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
|
||||
require.NoError(t, err)
|
||||
if assert.Len(t, daemons, 1) {
|
||||
assert.Equal(t, "1.0", daemons[0].APIVersion) // The whole point of this test is here.
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OldVersion", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// In this test, we just send a HTTP request with minimal parameters to the provisionerdaemons
|
||||
// endpoint. We do not pass the required machinery to start a websocket connection, but we pass a
|
||||
// version header that should cause provisionerd to refuse to serve us, so no websocket for you!
|
||||
|
||||
// Sending a HTTP request triggers an error log, which would otherwise fail the test.
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
client, user := coderdenttest.New(t, &coderdenttest.Options{
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureExternalProvisionerDaemons: 1,
|
||||
},
|
||||
},
|
||||
ProvisionerDaemonPSK: "provisionersftw",
|
||||
Options: &coderdtest.Options{
|
||||
Logger: &logger,
|
||||
},
|
||||
})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// Formulate the correct URL for provisionerd server.
|
||||
srvURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", user.OrganizationID))
|
||||
require.NoError(t, err)
|
||||
q := srvURL.Query()
|
||||
// Set required query parameters.
|
||||
q.Add("provisioner", "echo")
|
||||
|
||||
// Set a different (newer) version than the current.
|
||||
v := apiversion.New(provisionersdk.CurrentMajor+1, provisionersdk.CurrentMinor+1)
|
||||
q.Add("version", v.String())
|
||||
srvURL.RawQuery = q.Encode()
|
||||
|
||||
// Set PSK header for auth.
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srvURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(codersdk.ProvisionerDaemonPSK, "provisionersftw")
|
||||
|
||||
// Do the request!
|
||||
resp, err := client.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
// The below means that provisionerd tried to serve us, checked our api version, and said nope.
|
||||
require.Contains(t, string(b), "server is at version 1.0, behind requested major version 2.1")
|
||||
})
|
||||
|
||||
t.Run("NoLicense", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, user := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true})
|
||||
|
|
|
@ -6,8 +6,8 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
)
|
||||
|
|
|
@ -16,8 +16,8 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
||||
"github.com/coder/coder/v2/provisionersdk/proto"
|
||||
)
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package proto
|
||||
|
||||
import (
|
||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/util/apiversion"
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
|
Loading…
Reference in New Issue