From 4d2fe2685adc6626381b906ea99e7bdbeda99885 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Fri, 5 Jan 2024 10:22:07 +0000 Subject: [PATCH] chore(coderd): extract api version validation to util package (#11407) --- coderd/util/apiversion/apiversion.go | 84 +++++++++++++++++++++ coderd/util/apiversion/apiversion_test.go | 90 +++++++++++++++++++++++ coderd/workspaceagents.go | 2 +- tailnet/service.go | 47 +----------- tailnet/service_test.go | 65 ---------------- 5 files changed, 178 insertions(+), 110 deletions(-) create mode 100644 coderd/util/apiversion/apiversion.go create mode 100644 coderd/util/apiversion/apiversion_test.go diff --git a/coderd/util/apiversion/apiversion.go b/coderd/util/apiversion/apiversion.go new file mode 100644 index 0000000000..f9a1d0d539 --- /dev/null +++ b/coderd/util/apiversion/apiversion.go @@ -0,0 +1,84 @@ +package apiversion + +import ( + "strconv" + "strings" + + "golang.org/x/xerrors" +) + +// New returns an *APIVersion with the given major.minor and +// additional supported major versions. +func New(maj, min int) *APIVersion { + v := &APIVersion{ + supportedMajor: maj, + supportedMinor: min, + additionalMajors: make([]int, 0), + } + return v +} + +type APIVersion struct { + supportedMajor int + supportedMinor int + additionalMajors []int +} + +func (v *APIVersion) WithBackwardCompat(majs ...int) *APIVersion { + v.additionalMajors = append(v.additionalMajors, majs[:]...) + return v +} + +// Validate validates the given version against the given constraints: +// A given major.minor version is valid iff: +// 1. The requested major version is contained within v.supportedMajors +// 2. If the requested major version is the 'current major', then +// the requested minor version must be less than or equal to the supported +// minor version. +// +// For example, given majors {1, 2} and minor 2, then: +// - 0.x is not supported, +// - 1.x is supported, +// - 2.0, 2.1, and 2.2 are supported, +// - 2.3+ is not supported. +func (v *APIVersion) Validate(version string) error { + major, minor, err := Parse(version) + if err != nil { + return err + } + if major > v.supportedMajor { + return xerrors.Errorf("server is at version %d.%d, behind requested major version %s", + v.supportedMajor, v.supportedMinor, version) + } + if major == v.supportedMajor { + if minor > v.supportedMinor { + return xerrors.Errorf("server is at version %d.%d, behind requested minor version %s", + v.supportedMajor, v.supportedMinor, version) + } + return nil + } + for _, mjr := range v.additionalMajors { + if major == mjr { + return nil + } + } + return xerrors.Errorf("version %s is no longer supported", version) +} + +// Parse parses a valid major.minor version string into (major, minor). +// Both major and minor must be valid integers separated by a period '.'. +func Parse(version string) (major int, minor int, err error) { + parts := strings.Split(version, ".") + if len(parts) != 2 { + return 0, 0, xerrors.Errorf("invalid version string: %s", version) + } + major, err = strconv.Atoi(parts[0]) + if err != nil { + return 0, 0, xerrors.Errorf("invalid major version: %s", version) + } + minor, err = strconv.Atoi(parts[1]) + if err != nil { + return 0, 0, xerrors.Errorf("invalid minor version: %s", version) + } + return major, minor, nil +} diff --git a/coderd/util/apiversion/apiversion_test.go b/coderd/util/apiversion/apiversion_test.go new file mode 100644 index 0000000000..0bd6fe0f6b --- /dev/null +++ b/coderd/util/apiversion/apiversion_test.go @@ -0,0 +1,90 @@ +package apiversion_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/util/apiversion" +) + +func TestAPIVersionValidate(t *testing.T) { + t.Parallel() + + // Given + v := apiversion.New(2, 1).WithBackwardCompat(1) + + for _, tc := range []struct { + name string + version string + expectedError string + }{ + { + name: "OK", + version: "2.1", + }, + { + name: "MinorOK", + version: "2.0", + }, + { + name: "MajorOK", + version: "1.0", + }, + { + name: "TooNewMinor", + version: "2.2", + expectedError: "behind requested minor version", + }, + { + name: "TooNewMajor", + version: "3.1", + expectedError: "behind requested major version", + }, + { + name: "Malformed0", + version: "cats", + expectedError: "invalid version string", + }, + { + name: "Malformed1", + version: "cats.dogs", + expectedError: "invalid major version", + }, + { + name: "Malformed2", + version: "1.dogs", + expectedError: "invalid minor version", + }, + { + name: "Malformed3", + version: "1.0.1", + expectedError: "invalid version string", + }, + { + name: "Malformed4", + version: "11", + expectedError: "invalid version string", + }, + { + name: "TooOld", + version: "0.8", + expectedError: "no longer supported", + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // When + err := v.Validate(tc.version) + + // Then + if tc.expectedError == "" { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, tc.expectedError) + } + }) + } +} diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index a59f7f297e..917e979e09 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -1180,7 +1180,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R if qv != "" { version = qv } - if err := tailnet.ValidateVersion(version); err != nil { + if err := tailnet.CurrentVersion.Validate(version); err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Unknown or unsupported API version", Validations: []codersdk.ValidationError{ diff --git a/tailnet/service.go b/tailnet/service.go index 154514c9f0..1529bf65c0 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -4,8 +4,6 @@ import ( "context" "io" "net" - "strconv" - "strings" "sync/atomic" "time" @@ -16,6 +14,7 @@ import ( "tailscale.com/tailcfg" "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/util/apiversion" "github.com/coder/coder/v2/tailnet/proto" "golang.org/x/xerrors" @@ -26,47 +25,7 @@ const ( CurrentMinor = 0 ) -var SupportedMajors = []int{2, 1} - -func ValidateVersion(version string) error { - major, minor, err := parseVersion(version) - if err != nil { - return err - } - if major > CurrentMajor { - return xerrors.Errorf("server is at version %d.%d, behind requested version %s", - CurrentMajor, CurrentMinor, version) - } - if major == CurrentMajor { - if minor > CurrentMinor { - return xerrors.Errorf("server is at version %d.%d, behind requested version %s", - CurrentMajor, CurrentMinor, version) - } - return nil - } - for _, mjr := range SupportedMajors { - if major == mjr { - return nil - } - } - return xerrors.Errorf("version %s is no longer supported", version) -} - -func parseVersion(version string) (major int, minor int, err error) { - parts := strings.Split(version, ".") - if len(parts) != 2 { - return 0, 0, xerrors.Errorf("invalid version string: %s", version) - } - major, err = strconv.Atoi(parts[0]) - if err != nil { - return 0, 0, xerrors.Errorf("invalid major version: %s", version) - } - minor, err = strconv.Atoi(parts[1]) - if err != nil { - return 0, 0, xerrors.Errorf("invalid minor version: %s", version) - } - return major, minor, nil -} +var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1) type streamIDContextKey struct{} @@ -127,7 +86,7 @@ func NewClientService( } func (s *ClientService) ServeClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID, agent uuid.UUID) error { - major, _, err := parseVersion(version) + major, _, err := apiversion.Parse(version) if err != nil { s.logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err)) return err diff --git a/tailnet/service_test.go b/tailnet/service_test.go index adedbde90f..c6a8907644 100644 --- a/tailnet/service_test.go +++ b/tailnet/service_test.go @@ -2,7 +2,6 @@ package tailnet_test import ( "context" - "fmt" "io" "net" "net/http" @@ -25,70 +24,6 @@ import ( "github.com/coder/coder/v2/tailnet" ) -func TestValidateVersion(t *testing.T) { - t.Parallel() - for _, tc := range []struct { - name string - version string - supported bool - }{ - { - name: "Current", - version: fmt.Sprintf("%d.%d", tailnet.CurrentMajor, tailnet.CurrentMinor), - supported: true, - }, - { - name: "TooNewMinor", - version: fmt.Sprintf("%d.%d", tailnet.CurrentMajor, tailnet.CurrentMinor+1), - }, - { - name: "TooNewMajor", - version: fmt.Sprintf("%d.%d", tailnet.CurrentMajor+1, tailnet.CurrentMinor), - }, - { - name: "1.0", - version: "1.0", - supported: true, - }, - { - name: "2.0", - version: "2.0", - supported: true, - }, - { - name: "Malformed0", - version: "cats", - }, - { - name: "Malformed1", - version: "cats.dogs", - }, - { - name: "Malformed2", - version: "1.0.1", - }, - { - name: "Malformed3", - version: "11", - }, - { - name: "TooOld", - version: "0.8", - }, - } { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - err := tailnet.ValidateVersion(tc.version) - if tc.supported { - require.NoError(t, err) - } else { - require.Error(t, err) - } - }) - } -} - func TestClientService_ServeClient_V2(t *testing.T) { t.Parallel() fCoord := newFakeCoordinator()