fix(enterprise/coderd): check provisionerd API version on connection (#12191)

This commit is contained in:
Cian Johnston 2024-02-16 18:43:07 +00:00 committed by GitHub
parent f17149c59d
commit a2cbb0f87f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 125 additions and 7 deletions

View File

@ -5,7 +5,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/apiversion"
)
func TestAPIVersionValidate(t *testing.T) {

View File

@ -7,6 +7,7 @@ import (
"golang.org/x/mod/semver"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
@ -14,7 +15,6 @@ import (
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/healthcheck/health"
"github.com/coder/coder/v2/coderd/provisionerdserver"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisionersdk"

View File

@ -18,6 +18,7 @@ import (
"github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionerd/runner"
"github.com/coder/coder/v2/provisionersdk"
)
type LogSource string
@ -201,6 +202,8 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione
query := serverURL.Query()
query.Add("id", req.ID.String())
query.Add("name", req.Name)
query.Add("version", provisionersdk.VersionCurrent.String())
for _, provisioner := range req.Provisioners {
query.Add("provisioner", string(provisioner))
}

View File

@ -239,6 +239,16 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
apiVersion = qv
}
if err := provisionersdk.VersionCurrent.Validate(apiVersion); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Incompatible or unparsable version",
Validations: []codersdk.ValidationError{
{Field: "version", Detail: err.Error()},
},
})
return
}
// Create the daemon in the database.
now := dbtime.Now()
daemon, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{

View File

@ -3,6 +3,8 @@ package coderd_test
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"testing"
@ -12,6 +14,7 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
@ -63,6 +66,108 @@ func TestProvisionerDaemonServe(t *testing.T) {
}
})
t.Run("NoVersion", func(t *testing.T) {
t.Parallel()
// In this test, we just send a HTTP request with minimal parameters to the provisionerdaemons
// endpoint. We do not pass the required machinery to start a websocket connection, so we expect a
// WebSocket protocol violation. This just means the pre-flight checks have passed though.
// Sending a HTTP request triggers an error log, which would otherwise fail the test.
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
client, user := coderdenttest.New(t, &coderdenttest.Options{
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureExternalProvisionerDaemons: 1,
},
},
ProvisionerDaemonPSK: "provisionersftw",
Options: &coderdtest.Options{
Logger: &logger,
},
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// Formulate the correct URL for provisionerd server.
srvURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", user.OrganizationID))
require.NoError(t, err)
q := srvURL.Query()
// Set required query parameters.
q.Add("provisioner", "echo")
// Note: Explicitly not setting API version.
q.Add("version", "")
srvURL.RawQuery = q.Encode()
// Set PSK header for auth.
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srvURL.String(), nil)
require.NoError(t, err)
req.Header.Set(codersdk.ProvisionerDaemonPSK, "provisionersftw")
// Do the request!
resp, err := client.HTTPClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
// The below means that provisionerd tried to serve us!
require.Contains(t, string(b), "Internal error accepting websocket connection.")
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
if assert.Len(t, daemons, 1) {
assert.Equal(t, "1.0", daemons[0].APIVersion) // The whole point of this test is here.
}
})
t.Run("OldVersion", func(t *testing.T) {
t.Parallel()
// In this test, we just send a HTTP request with minimal parameters to the provisionerdaemons
// endpoint. We do not pass the required machinery to start a websocket connection, but we pass a
// version header that should cause provisionerd to refuse to serve us, so no websocket for you!
// Sending a HTTP request triggers an error log, which would otherwise fail the test.
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
client, user := coderdenttest.New(t, &coderdenttest.Options{
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureExternalProvisionerDaemons: 1,
},
},
ProvisionerDaemonPSK: "provisionersftw",
Options: &coderdtest.Options{
Logger: &logger,
},
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// Formulate the correct URL for provisionerd server.
srvURL, err := client.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", user.OrganizationID))
require.NoError(t, err)
q := srvURL.Query()
// Set required query parameters.
q.Add("provisioner", "echo")
// Set a different (newer) version than the current.
v := apiversion.New(provisionersdk.CurrentMajor+1, provisionersdk.CurrentMinor+1)
q.Add("version", v.String())
srvURL.RawQuery = q.Encode()
// Set PSK header for auth.
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srvURL.String(), nil)
require.NoError(t, err)
req.Header.Set(codersdk.ProvisionerDaemonPSK, "provisionersftw")
// Do the request!
resp, err := client.HTTPClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
// The below means that provisionerd tried to serve us, checked our api version, and said nope.
require.Contains(t, string(b), "server is at version 1.0, behind requested major version 2.1")
})
t.Run("NoLicense", func(t *testing.T) {
t.Parallel()
client, user := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true})

View File

@ -6,8 +6,8 @@ import (
"github.com/google/uuid"
"nhooyr.io/websocket"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/tailnet/proto"
)

View File

@ -14,7 +14,7 @@ import (
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/v2/tailnet"
)

View File

@ -16,8 +16,8 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/provisionersdk/proto"
)

View File

@ -1,7 +1,7 @@
package proto
import (
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/apiversion"
)
const (

View File

@ -14,7 +14,7 @@ import (
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/util/apiversion"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/tailnet/proto"
"golang.org/x/xerrors"