From 2f54f769be37b54cbeca079bdf3aa7818a059635 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Mon, 4 Dec 2023 10:01:45 -0600 Subject: [PATCH] feat: allow IDP to return single string for roles/groups claim (#10993) * feat: allow IDP to return single string instead of array for roles/groups claim This is to support ADFS --- coderd/userauth.go | 114 +++++++++++++++--------- coderd/userauth_internal_test.go | 135 +++++++++++++++++++++++++++++ enterprise/coderd/userauth_test.go | 53 +++++++++++ 3 files changed, 261 insertions(+), 41 deletions(-) create mode 100644 coderd/userauth_internal_test.go diff --git a/coderd/userauth.go b/coderd/userauth.go index a969e37269..0ffa4ca271 100644 --- a/coderd/userauth.go +++ b/coderd/userauth.go @@ -1019,31 +1019,26 @@ func (api *API) oidcGroups(ctx context.Context, mergedClaims map[string]interfac if api.OIDCConfig.GroupField != "" { usingGroups = true groupsRaw, ok := mergedClaims[api.OIDCConfig.GroupField] - if ok && api.OIDCConfig.GroupField != "" { - // Convert the []interface{} we get to a []string. - groupsInterface, ok := groupsRaw.([]interface{}) - if ok { - api.Logger.Debug(ctx, "groups returned in oidc claims", - slog.F("len", len(groupsInterface)), - slog.F("groups", groupsInterface), - ) - - for _, groupInterface := range groupsInterface { - group, ok := groupInterface.(string) - if !ok { - return false, nil, xerrors.Errorf("Invalid group type. Expected string, got: %T", groupInterface) - } - - if mappedGroup, ok := api.OIDCConfig.GroupMapping[group]; ok { - group = mappedGroup - } - - groups = append(groups, group) - } - } else { - api.Logger.Debug(ctx, "groups field was an unknown type", + if ok { + parsedGroups, err := parseStringSliceClaim(groupsRaw) + if err != nil { + api.Logger.Debug(ctx, "groups field was an unknown type in oidc claims", slog.F("type", fmt.Sprintf("%T", groupsRaw)), + slog.Error(err), ) + return false, nil, err + } + + api.Logger.Debug(ctx, "groups returned in oidc claims", + slog.F("len", len(parsedGroups)), + slog.F("groups", parsedGroups), + ) + + for _, group := range parsedGroups { + if mappedGroup, ok := api.OIDCConfig.GroupMapping[group]; ok { + group = mappedGroup + } + groups = append(groups, group) } } } @@ -1079,10 +1074,11 @@ func (api *API) oidcRoles(ctx context.Context, rw http.ResponseWriter, r *http.R rolesRow = []interface{}{} } - rolesInterface, ok := rolesRow.([]interface{}) - if !ok { - api.Logger.Error(ctx, "oidc claim user roles field was an unknown type", + parsedRoles, err := parseStringSliceClaim(rolesRow) + if err != nil { + api.Logger.Error(ctx, "oidc claims user roles field was an unknown type", slog.F("type", fmt.Sprintf("%T", rolesRow)), + slog.Error(err), ) site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ Status: http.StatusInternalServerError, @@ -1096,21 +1092,10 @@ func (api *API) oidcRoles(ctx context.Context, rw http.ResponseWriter, r *http.R } api.Logger.Debug(ctx, "roles returned in oidc claims", - slog.F("len", len(rolesInterface)), - slog.F("roles", rolesInterface), + slog.F("len", len(parsedRoles)), + slog.F("roles", parsedRoles), ) - for _, roleInterface := range rolesInterface { - role, ok := roleInterface.(string) - if !ok { - api.Logger.Error(ctx, "invalid oidc user role type", - slog.F("type", fmt.Sprintf("%T", rolesRow)), - ) - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: fmt.Sprintf("Invalid user role type. Expected string, got: %T", roleInterface), - }) - return nil, false - } - + for _, role := range parsedRoles { if mappedRoles, ok := api.OIDCConfig.UserRoleMapping[role]; ok { if len(mappedRoles) == 0 { continue @@ -1449,7 +1434,7 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C if err != nil { return httpError{ code: http.StatusBadRequest, - msg: "Invalid roles through OIDC claim", + msg: "Invalid roles through OIDC claims", detail: fmt.Sprintf("Error from role assignment attempt: %s", err.Error()), renderStaticPage: true, } @@ -1744,3 +1729,50 @@ func wrongLoginTypeHTTPError(user database.LoginType, params database.LoginType) params, user, addedMsg), } } + +// parseStringSliceClaim parses the claim for groups and roles, expected []string. +// +// Some providers like ADFS return a single string instead of an array if there +// is only 1 element. So this function handles the edge cases. +func parseStringSliceClaim(claim interface{}) ([]string, error) { + groups := make([]string, 0) + if claim == nil { + return groups, nil + } + + // The simple case is the type is exactly what we expected + asStringArray, ok := claim.([]string) + if ok { + return asStringArray, nil + } + + asArray, ok := claim.([]interface{}) + if ok { + for i, item := range asArray { + asString, ok := item.(string) + if !ok { + return nil, xerrors.Errorf("invalid claim type. Element %d expected a string, got: %T", i, item) + } + groups = append(groups, asString) + } + return groups, nil + } + + asString, ok := claim.(string) + if ok { + if asString == "" { + // Empty string should be 0 groups. + return []string{}, nil + } + // If it is a single string, first check if it is a csv. + // If a user hits this, it is likely a misconfiguration and they need + // to reconfigure their IDP to send an array instead. + if strings.Contains(asString, ",") { + return nil, xerrors.Errorf("invalid claim type. Got a csv string (%q), change this claim to return an array of strings instead.", asString) + } + return []string{asString}, nil + } + + // Not sure what the user gave us. + return nil, xerrors.Errorf("invalid claim type. Expected an array of strings, got: %T", claim) +} diff --git a/coderd/userauth_internal_test.go b/coderd/userauth_internal_test.go new file mode 100644 index 0000000000..421e654995 --- /dev/null +++ b/coderd/userauth_internal_test.go @@ -0,0 +1,135 @@ +package coderd + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseStringSliceClaim(t *testing.T) { + t.Parallel() + + cases := []struct { + Name string + GoClaim interface{} + // JSON Claim allows testing the json -> go conversion + // of some strings. + JSONClaim string + ErrorExpected bool + ExpectedSlice []string + }{ + { + Name: "Nil", + GoClaim: nil, + ExpectedSlice: []string{}, + }, + // Go Slices + { + Name: "EmptySlice", + GoClaim: []string{}, + ExpectedSlice: []string{}, + }, + { + Name: "StringSlice", + GoClaim: []string{"a", "b", "c"}, + ExpectedSlice: []string{"a", "b", "c"}, + }, + { + Name: "InterfaceSlice", + GoClaim: []interface{}{"a", "b", "c"}, + ExpectedSlice: []string{"a", "b", "c"}, + }, + { + Name: "MixedSlice", + GoClaim: []interface{}{"a", string("b"), interface{}("c")}, + ExpectedSlice: []string{"a", "b", "c"}, + }, + { + Name: "StringSliceOneElement", + GoClaim: []string{"a"}, + ExpectedSlice: []string{"a"}, + }, + // Json Slices + { + Name: "JSONEmptySlice", + JSONClaim: `[]`, + ExpectedSlice: []string{}, + }, + { + Name: "JSONStringSlice", + JSONClaim: `["a", "b", "c"]`, + ExpectedSlice: []string{"a", "b", "c"}, + }, + { + Name: "JSONStringSliceOneElement", + JSONClaim: `["a"]`, + ExpectedSlice: []string{"a"}, + }, + // Go string + { + Name: "String", + GoClaim: "a", + ExpectedSlice: []string{"a"}, + }, + { + Name: "EmptyString", + GoClaim: "", + ExpectedSlice: []string{}, + }, + { + Name: "Interface", + GoClaim: interface{}("a"), + ExpectedSlice: []string{"a"}, + }, + // JSON string + { + Name: "JSONString", + JSONClaim: `"a"`, + ExpectedSlice: []string{"a"}, + }, + { + Name: "JSONEmptyString", + JSONClaim: `""`, + ExpectedSlice: []string{}, + }, + // Go Errors + { + Name: "IntegerInSlice", + GoClaim: []interface{}{"a", "b", 1}, + ErrorExpected: true, + }, + // Json Errors + { + Name: "JSONIntegerInSlice", + JSONClaim: `["a", "b", 1]`, + ErrorExpected: true, + }, + { + Name: "JSON_CSV", + JSONClaim: `"a,b,c"`, + ErrorExpected: true, + }, + } + + for _, c := range cases { + c := c + t.Run(c.Name, func(t *testing.T) { + t.Parallel() + + if len(c.JSONClaim) > 0 { + require.Nil(t, c.GoClaim, "go claim should be nil if json set") + err := json.Unmarshal([]byte(c.JSONClaim), &c.GoClaim) + require.NoError(t, err, "unmarshal json claim") + } + + found, err := parseStringSliceClaim(c.GoClaim) + if c.ErrorExpected { + require.Error(t, err) + } else { + require.NoError(t, err) + require.ElementsMatch(t, c.ExpectedSlice, found, "expected groups") + } + }) + } +} diff --git a/enterprise/coderd/userauth_test.go b/enterprise/coderd/userauth_test.go index 9d7e2762f0..70e63f6a1e 100644 --- a/enterprise/coderd/userauth_test.go +++ b/enterprise/coderd/userauth_test.go @@ -55,6 +55,33 @@ func TestUserOIDC(t *testing.T) { runner.AssertRoles(t, "alice", []string{}) }) + // Some IDPs (ADFS) send the "string" type vs "[]string" if only + // 1 role exists. + t.Run("SingleRoleString", func(t *testing.T) { + t.Parallel() + + const oidcRoleName = "TemplateAuthor" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.UserRoleField = "roles" + cfg.UserRoleMapping = map[string][]string{ + oidcRoleName: {rbac.RoleTemplateAdmin()}, + } + }, + }) + + // User starts with the owner role + _, resp := runner.Login(t, jwt.MapClaims{ + "email": "alice@coder.com", + // This is sent as a **string** intentionally instead + // of an array. + "roles": oidcRoleName, + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertRoles(t, "alice", []string{rbac.RoleTemplateAdmin()}) + }) + // A user has some roles, then on an oauth refresh will lose said // roles from an updated claim. t.Run("NewUserAndRemoveRolesOnRefresh", func(t *testing.T) { @@ -334,6 +361,32 @@ func TestUserOIDC(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) runner.AssertGroups(t, "alice", []string{groupName}) }) + + // Some IDPs (ADFS) send the "string" type vs "[]string" if only + // 1 group exists. + t.Run("SingleRoleGroup", func(t *testing.T) { + t.Parallel() + + const groupClaim = "custom-groups" + const groupName = "bingbong" + runner := setupOIDCTest(t, oidcTestConfig{ + Config: func(cfg *coderd.OIDCConfig) { + cfg.AllowSignups = true + cfg.GroupField = groupClaim + cfg.CreateMissingGroups = true + }, + }) + + // User starts with the owner role + _, resp := runner.Login(t, jwt.MapClaims{ + "email": "alice@coder.com", + // This is sent as a **string** intentionally instead + // of an array. + groupClaim: groupName, + }) + require.Equal(t, http.StatusOK, resp.StatusCode) + runner.AssertGroups(t, "alice", []string{groupName}) + }) }) t.Run("Refresh", func(t *testing.T) {