chore: Rbac errors should be returned, and not hidden behind 404 (#7122)

* chore: Rbac errors should be returned, and not hidden behind 404

SqlErrNoRows was hiding actual errors
* Replace sql.ErrNoRow checks
* Remove sql err no rows check from dbauthz test
* Fix to use dbauthz system user
This commit is contained in:
Steven Masley 2023-04-13 13:06:16 -05:00 committed by GitHub
parent fa64c58e56
commit 38e5b9679b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 50 additions and 72 deletions

View File

@ -3,8 +3,6 @@ package coderd
import (
"context"
"crypto/sha256"
"database/sql"
"errors"
"fmt"
"net"
"net/http"
@ -167,7 +165,7 @@ func (api *API) apiKeyByID(rw http.ResponseWriter, r *http.Request) {
keyID := chi.URLParam(r, "keyid")
key, err := api.Database.GetAPIKeyByID(ctx, keyID)
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
@ -202,7 +200,7 @@ func (api *API) apiKeyByName(rw http.ResponseWriter, r *http.Request) {
TokenName: tokenName,
UserID: user.ID,
})
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
@ -323,7 +321,7 @@ func (api *API) deleteAPIKey(rw http.ResponseWriter, r *http.Request) {
defer commitAudit()
err = api.Database.DeleteAPIKeyByID(ctx, keyID)
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}

View File

@ -34,8 +34,8 @@ func (e NotAuthorizedError) Error() string {
// Unwrap will always unwrap to a sql.ErrNoRows so the API returns a 404.
// So 'errors.Is(err, sql.ErrNoRows)' will always be true.
func (NotAuthorizedError) Unwrap() error {
return sql.ErrNoRows
func (e NotAuthorizedError) Unwrap() error {
return e.Err
}
func IsNotAuthorizedError(err error) bool {

View File

@ -2,7 +2,6 @@ package dbauthz_test
import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
@ -219,7 +218,6 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
if err != nil || !hasEmptySliceResponse(resp) {
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
s.Errorf(err, "method should an error with disallow authz")
s.ErrorIsf(err, sql.ErrNoRows, "error should match sql.ErrNoRows")
s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError")
}
})

View File

@ -126,7 +126,7 @@ func (api *API) fileByID(rw http.ResponseWriter, r *http.Request) {
}
file, err := api.Database.GetFileByID(ctx, id)
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}

View File

@ -3,6 +3,7 @@ package httpapi
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"flag"
@ -15,6 +16,8 @@ import (
"github.com/go-playground/validator/v10"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/codersdk"
)
@ -80,6 +83,16 @@ func init() {
}
}
// Is404Error returns true if the given error should return a 404 status code.
// Both actual 404s and unauthorized errors should return 404s to not leak
// information about the existence of resources.
func Is404Error(err error) bool {
if err == nil {
return false
}
return xerrors.Is(err, sql.ErrNoRows) || dbauthz.IsNotAuthorizedError(err) || rbac.IsUnauthorizedError(err)
}
// Convenience error functions don't take contexts since their responses are
// static, it doesn't make much sense to trace them.

View File

@ -2,12 +2,9 @@ package httpmw
import (
"context"
"database/sql"
"errors"
"net/http"
"github.com/go-chi/chi/v5"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
@ -45,7 +42,7 @@ func ExtractGroupByNameParam(db database.Store) func(http.Handler) http.Handler
OrganizationID: org.ID,
Name: name,
})
if xerrors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
@ -73,7 +70,7 @@ func ExtractGroupParam(db database.Store) func(http.Handler) http.Handler {
}
group, err := db.GetGroupByID(r.Context(), groupID)
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}

View File

@ -2,8 +2,6 @@ package httpmw
import (
"context"
"database/sql"
"errors"
"net/http"
"github.com/coder/coder/coderd/database"
@ -47,7 +45,7 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler
}
organization, err := db.GetOrganizationByID(ctx, orgID)
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
@ -77,7 +75,7 @@ func ExtractOrganizationMemberParam(db database.Store) func(http.Handler) http.H
OrganizationID: organization.ID,
UserID: user.ID,
})
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}

View File

@ -2,8 +2,6 @@ package httpmw
import (
"context"
"database/sql"
"errors"
"net/http"
"github.com/go-chi/chi/v5"
@ -34,7 +32,7 @@ func ExtractTemplateParam(db database.Store) func(http.Handler) http.Handler {
return
}
template, err := db.GetTemplateByID(r.Context(), templateID)
if errors.Is(err, sql.ErrNoRows) || (err == nil && template.Deleted) {
if httpapi.Is404Error(err) || (err == nil && template.Deleted) {
httpapi.ResourceNotFound(rw)
return
}

View File

@ -3,7 +3,6 @@ package httpmw
import (
"context"
"database/sql"
"errors"
"net/http"
"github.com/go-chi/chi/v5"
@ -35,7 +34,7 @@ func ExtractTemplateVersionParam(db database.Store) func(http.Handler) http.Hand
return
}
templateVersion, err := db.GetTemplateVersionByID(ctx, templateVersionID)
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}

View File

@ -2,11 +2,8 @@ package httpmw
import (
"context"
"database/sql"
"net/http"
"golang.org/x/xerrors"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
@ -71,7 +68,7 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han
}
//nolint:gocritic // System needs to be able to get user from param.
user, err = db.GetUserByID(dbauthz.AsSystemRestricted(ctx), apiKey.UserID)
if xerrors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}

View File

@ -2,8 +2,6 @@ package httpmw
import (
"context"
"database/sql"
"errors"
"net/http"
"github.com/go-chi/chi/v5"
@ -34,7 +32,7 @@ func ExtractWorkspaceBuildParam(db database.Store) func(http.Handler) http.Handl
return
}
workspaceBuild, err := db.GetWorkspaceBuildByID(ctx, workspaceBuildID)
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}

View File

@ -2,8 +2,6 @@ package httpmw
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"strings"
@ -37,7 +35,7 @@ func ExtractWorkspaceParam(db database.Store) func(http.Handler) http.Handler {
return
}
workspace, err := db.GetWorkspaceByID(ctx, workspaceID)
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
@ -74,7 +72,7 @@ func ExtractWorkspaceAndAgentParam(db database.Store) func(http.Handler) http.Ha
Name: workspaceParts[0],
})
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}

View File

@ -141,7 +141,7 @@ func (api *API) deleteParameter(rw http.ResponseWriter, r *http.Request) {
ScopeID: scopeID,
Name: name,
})
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}

View File

@ -224,7 +224,7 @@ func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request,
}
// nolint:gocritic // GetWorkspaceAppsByAgentIDs is a system function.
apps, err := api.Database.GetWorkspaceAppsByAgentIDs(ctx, resourceAgentIDs)
apps, err := api.Database.GetWorkspaceAppsByAgentIDs(dbauthz.AsSystemRestricted(ctx), resourceAgentIDs)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}

View File

@ -407,7 +407,7 @@ func (api *API) templateByOrganizationAndName(rw http.ResponseWriter, r *http.Re
Name: templateName,
})
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
@ -419,11 +419,6 @@ func (api *API) templateByOrganizationAndName(rw http.ResponseWriter, r *http.Re
return
}
if !api.Authorize(r, rbac.ActionRead, template) {
httpapi.ResourceNotFound(rw)
return
}
createdByNameMap, err := getCreatedByNamesByTemplateIDs(ctx, api.Database, []database.Template{template})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
@ -583,10 +578,6 @@ func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) {
func (api *API) templateDAUs(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
template := httpmw.TemplateParam(r)
if !api.Authorize(r, rbac.ActionRead, template) {
httpapi.ResourceNotFound(rw)
return
}
resp, _ := api.metricsCache.TemplateDAUs(template.ID)
if resp == nil || resp.Entries == nil {

View File

@ -737,7 +737,7 @@ func (api *API) fetchTemplateVersionDryRunJob(rw http.ResponseWriter, r *http.Re
}
job, err := api.Database.GetProvisionerJobByID(ctx, jobUUID)
if xerrors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Provisioner job %q not found.", jobUUID),
})
@ -905,7 +905,7 @@ func (api *API) templateVersionByName(rw http.ResponseWriter, r *http.Request) {
},
Name: templateVersionName,
})
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("No template version found by name %q.", templateVersionName),
})
@ -959,7 +959,7 @@ func (api *API) templateVersionByOrganizationTemplateAndName(rw http.ResponseWri
Name: templateName,
})
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
@ -979,7 +979,7 @@ func (api *API) templateVersionByOrganizationTemplateAndName(rw http.ResponseWri
},
Name: templateVersionName,
})
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("No template version found by name %q.", templateVersionName),
})
@ -1032,7 +1032,7 @@ func (api *API) previousTemplateVersionByOrganizationTemplateAndName(rw http.Res
Name: templateName,
})
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
@ -1053,7 +1053,7 @@ func (api *API) previousTemplateVersionByOrganizationTemplateAndName(rw http.Res
Name: templateVersionName,
})
if err != nil {
if xerrors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("No template version found by name %q.", templateVersionName),
})
@ -1073,7 +1073,7 @@ func (api *API) previousTemplateVersionByOrganizationTemplateAndName(rw http.Res
TemplateID: templateVersion.TemplateID,
})
if err != nil {
if xerrors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("No previous template version found for %q.", templateVersionName),
})
@ -1138,7 +1138,7 @@ func (api *API) patchActiveTemplateVersion(rw http.ResponseWriter, r *http.Reque
return
}
version, err := api.Database.GetTemplateVersionByID(ctx, req.ID)
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: "Template version not found.",
})
@ -1222,7 +1222,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
if req.TemplateID != uuid.Nil {
_, err := api.Database.GetTemplateByID(ctx, req.TemplateID)
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: "Template does not exist.",
})
@ -1318,7 +1318,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
if req.FileID != uuid.Nil {
file, err = api.Database.GetFileByID(ctx, req.FileID)
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: "File not found.",
})

View File

@ -314,7 +314,7 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
}
_, err = api.Database.GetOrganizationByID(ctx, req.OrganizationID)
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Organization does not exist with the provided id %q.", req.OrganizationID),
})
@ -938,7 +938,7 @@ func (api *API) organizationByUserAndName(rw http.ResponseWriter, r *http.Reques
ctx := r.Context()
organizationName := chi.URLParam(r, "organizationname")
organization, err := api.Database.GetOrganizationByName(ctx, organizationName)
if errors.Is(err, sql.ErrNoRows) || rbac.IsUnauthorizedError(err) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}

View File

@ -397,15 +397,10 @@ func (api *API) workspaceAgentStartupLogs(rw http.ResponseWriter, r *http.Reques
ctx = r.Context()
actor, _ = dbauthz.ActorFromContext(ctx)
workspaceAgent = httpmw.WorkspaceAgentParam(r)
workspace = httpmw.WorkspaceParam(r)
logger = api.Logger.With(slog.F("workspace_agent_id", workspaceAgent.ID))
follow = r.URL.Query().Has("follow")
afterRaw = r.URL.Query().Get("after")
)
if !api.Authorize(r, rbac.ActionRead, workspace) {
httpapi.ResourceNotFound(rw)
return
}
var after int64
// Only fetch logs created after the time provided.

View File

@ -227,7 +227,7 @@ func (api *API) workspaceBuildByBuildNumber(rw http.ResponseWriter, r *http.Requ
OwnerID: owner.ID,
Name: workspaceName,
})
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
@ -243,7 +243,7 @@ func (api *API) workspaceBuildByBuildNumber(rw http.ResponseWriter, r *http.Requ
WorkspaceID: workspace.ID,
BuildNumber: int32(buildNumber),
})
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Workspace %q Build %d does not exist.", workspaceName, buildNumber),
})

View File

@ -1,9 +1,7 @@
package coderd
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/http"
@ -131,7 +129,7 @@ func (api *API) handleAuthInstanceID(rw http.ResponseWriter, r *http.Request, in
ctx := r.Context()
//nolint:gocritic // needed for auth instance id
agent, err := api.Database.GetWorkspaceAgentByInstanceID(dbauthz.AsSystemRestricted(ctx), instanceID)
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Instance with id %q not found.", instanceID),
})

View File

@ -225,7 +225,7 @@ func (api *API) workspaceByOwnerAndName(rw http.ResponseWriter, r *http.Request)
Deleted: includeDeleted,
})
}
if errors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}

View File

@ -231,7 +231,7 @@ func (api *API) patchGroup(rw http.ResponseWriter, r *http.Request) {
})
return
}
if xerrors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to add or remove non-existent group member",
Detail: err.Error(),

View File

@ -235,7 +235,7 @@ func (api *API) deleteLicense(rw http.ResponseWriter, r *http.Request) {
}
_, err = api.Database.DeleteLicense(ctx, int32(id))
if xerrors.Is(err, sql.ErrNoRows) {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: "Unknown license ID",
})