chore: add custom querier functions to dbgen (#8496)

* chore: add custom querier functions to dbgen
* chore: parse package was missing some imports, so force them
This commit is contained in:
Steven Masley 2023-07-13 13:12:29 -04:00 committed by GitHub
parent b650ab40f0
commit 3b433181be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 602 additions and 545 deletions

View File

@ -575,11 +575,6 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID *uuid.UUID, added, r
return nil
}
func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) {
// TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier.
return q.GetTemplatesWithFilter(ctx, arg)
}
func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) error {
deleteF := func(ctx context.Context, id uuid.UUID) error {
return q.db.UpdateTemplateDeletedByID(ctx, database.UpdateTemplateDeletedByIDParams{
@ -591,34 +586,6 @@ func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) erro
return deleteQ(q.log, q.auth, q.db.GetTemplateByID, deleteF)(ctx, id)
}
func (q *querier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
// An actor is authorized to read template group roles if they are authorized to read the template.
template, err := q.db.GetTemplateByID(ctx, id)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil {
return nil, err
}
return q.db.GetTemplateGroupRoles(ctx, id)
}
func (q *querier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) {
// An actor is authorized to query template user roles if they are authorized to read the template.
template, err := q.db.GetTemplateByID(ctx, id)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil {
return nil, err
}
return q.db.GetTemplateUserRoles(ctx, id)
}
func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
return q.db.GetAuthorizedUserCount(ctx, arg, prepared)
}
func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) {
// TODO Implement this with a SQL filter. The count is incorrect without it.
rowUsers, err := q.db.GetUsers(ctx, arg)
@ -655,11 +622,6 @@ func (q *querier) SoftDeleteUserByID(ctx context.Context, id uuid.UUID) error {
return deleteQ(q.log, q.auth, q.db.GetUserByID, deleteF)(ctx, id)
}
func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, _ rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
// TODO Delete this function, all GetWorkspaces should be authorized. For now just call GetWorkspaces on the authz querier.
return q.GetWorkspaces(ctx, arg)
}
func (q *querier) SoftDeleteWorkspaceByID(ctx context.Context, id uuid.UUID) error {
return deleteQ(q.log, q.auth, q.db.GetWorkspaceByID, func(ctx context.Context, id uuid.UUID) error {
return q.db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{
@ -2642,3 +2604,41 @@ func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (d
}
return q.db.UpsertTailnetCoordinator(ctx, id)
}
func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) {
// TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier.
return q.GetTemplatesWithFilter(ctx, arg)
}
func (q *querier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
// An actor is authorized to read template group roles if they are authorized to read the template.
template, err := q.db.GetTemplateByID(ctx, id)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil {
return nil, err
}
return q.db.GetTemplateGroupRoles(ctx, id)
}
func (q *querier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) {
// An actor is authorized to query template user roles if they are authorized to read the template.
template, err := q.db.GetTemplateByID(ctx, id)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil {
return nil, err
}
return q.db.GetTemplateUserRoles(ctx, id)
}
func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, _ rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
// TODO Delete this function, all GetWorkspaces should be authorized. For now just call GetWorkspaces on the authz querier.
return q.GetWorkspaces(ctx, arg)
}
func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
return q.db.GetAuthorizedUserCount(ctx, arg, prepared)
}

View File

@ -266,80 +266,6 @@ func (q *FakeQuerier) getUserByIDNoLock(id uuid.UUID) (database.User, error) {
return database.User{}, sql.ErrNoRows
}
func (q *FakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
if err := validateDatabaseType(params); err != nil {
return 0, err
}
q.mutex.RLock()
defer q.mutex.RUnlock()
// Call this to match the same function calls as the SQL implementation.
if prepared != nil {
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
if err != nil {
return -1, err
}
}
users := make([]database.User, 0, len(q.users))
for _, user := range q.users {
// If the filter exists, ensure the object is authorized.
if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil {
continue
}
users = append(users, user)
}
// Filter out deleted since they should never be returned..
tmp := make([]database.User, 0, len(users))
for _, user := range users {
if !user.Deleted {
tmp = append(tmp, user)
}
}
users = tmp
if params.Search != "" {
tmp := make([]database.User, 0, len(users))
for i, user := range users {
if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) {
tmp = append(tmp, users[i])
} else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) {
tmp = append(tmp, users[i])
}
}
users = tmp
}
if len(params.Status) > 0 {
usersFilteredByStatus := make([]database.User, 0, len(users))
for i, user := range users {
if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool {
return strings.EqualFold(string(a), string(b))
}) {
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
}
}
users = usersFilteredByStatus
}
if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) {
usersFilteredByRole := make([]database.User, 0, len(users))
for i, user := range users {
if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) {
usersFilteredByRole = append(usersFilteredByRole, users[i])
}
}
users = usersFilteredByRole
}
return int64(len(users)), nil
}
func convertUsers(users []database.User, count int64) []database.GetUsersRow {
rows := make([]database.GetUsersRow, len(users))
for i, u := range users {
@ -363,259 +289,6 @@ func convertUsers(users []database.User, count int64) []database.GetUsersRow {
return rows
}
//nolint:gocyclo
func (q *FakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
if err := validateDatabaseType(arg); err != nil {
return nil, err
}
q.mutex.RLock()
defer q.mutex.RUnlock()
if prepared != nil {
// Call this to match the same function calls as the SQL implementation.
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
if err != nil {
return nil, err
}
}
workspaces := make([]database.Workspace, 0)
for _, workspace := range q.workspaces {
if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID {
continue
}
if arg.OwnerUsername != "" {
owner, err := q.getUserByIDNoLock(workspace.OwnerID)
if err == nil && !strings.EqualFold(arg.OwnerUsername, owner.Username) {
continue
}
}
if arg.TemplateName != "" {
template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID)
if err == nil && !strings.EqualFold(arg.TemplateName, template.Name) {
continue
}
}
if !arg.Deleted && workspace.Deleted {
continue
}
if arg.Name != "" && !strings.Contains(strings.ToLower(workspace.Name), strings.ToLower(arg.Name)) {
continue
}
if arg.Status != "" {
build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID)
if err != nil {
return nil, xerrors.Errorf("get latest build: %w", err)
}
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
if err != nil {
return nil, xerrors.Errorf("get provisioner job: %w", err)
}
// This logic should match the logic in the workspace.sql file.
var statusMatch bool
switch database.WorkspaceStatus(arg.Status) {
case database.WorkspaceStatusPending:
statusMatch = isNull(job.StartedAt)
case database.WorkspaceStatusStarting:
statusMatch = isNotNull(job.StartedAt) &&
isNull(job.CanceledAt) &&
isNull(job.CompletedAt) &&
time.Since(job.UpdatedAt) < 30*time.Second &&
build.Transition == database.WorkspaceTransitionStart
case database.WorkspaceStatusRunning:
statusMatch = isNotNull(job.CompletedAt) &&
isNull(job.CanceledAt) &&
isNull(job.Error) &&
build.Transition == database.WorkspaceTransitionStart
case database.WorkspaceStatusStopping:
statusMatch = isNotNull(job.StartedAt) &&
isNull(job.CanceledAt) &&
isNull(job.CompletedAt) &&
time.Since(job.UpdatedAt) < 30*time.Second &&
build.Transition == database.WorkspaceTransitionStop
case database.WorkspaceStatusStopped:
statusMatch = isNotNull(job.CompletedAt) &&
isNull(job.CanceledAt) &&
isNull(job.Error) &&
build.Transition == database.WorkspaceTransitionStop
case database.WorkspaceStatusFailed:
statusMatch = (isNotNull(job.CanceledAt) && isNotNull(job.Error)) ||
(isNotNull(job.CompletedAt) && isNotNull(job.Error))
case database.WorkspaceStatusCanceling:
statusMatch = isNotNull(job.CanceledAt) &&
isNull(job.CompletedAt)
case database.WorkspaceStatusCanceled:
statusMatch = isNotNull(job.CanceledAt) &&
isNotNull(job.CompletedAt)
case database.WorkspaceStatusDeleted:
statusMatch = isNotNull(job.StartedAt) &&
isNull(job.CanceledAt) &&
isNotNull(job.CompletedAt) &&
time.Since(job.UpdatedAt) < 30*time.Second &&
build.Transition == database.WorkspaceTransitionDelete &&
isNull(job.Error)
case database.WorkspaceStatusDeleting:
statusMatch = isNull(job.CompletedAt) &&
isNull(job.CanceledAt) &&
isNull(job.Error) &&
build.Transition == database.WorkspaceTransitionDelete
default:
return nil, xerrors.Errorf("unknown workspace status in filter: %q", arg.Status)
}
if !statusMatch {
continue
}
}
if arg.HasAgent != "" {
build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID)
if err != nil {
return nil, xerrors.Errorf("get latest build: %w", err)
}
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
if err != nil {
return nil, xerrors.Errorf("get provisioner job: %w", err)
}
workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID)
if err != nil {
return nil, xerrors.Errorf("get workspace resources: %w", err)
}
var workspaceResourceIDs []uuid.UUID
for _, wr := range workspaceResources {
workspaceResourceIDs = append(workspaceResourceIDs, wr.ID)
}
workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs)
if err != nil {
return nil, xerrors.Errorf("get workspace agents: %w", err)
}
var hasAgentMatched bool
for _, wa := range workspaceAgents {
if mapAgentStatus(wa, arg.AgentInactiveDisconnectTimeoutSeconds) == arg.HasAgent {
hasAgentMatched = true
}
}
if !hasAgentMatched {
continue
}
}
if len(arg.TemplateIds) > 0 {
match := false
for _, id := range arg.TemplateIds {
if workspace.TemplateID == id {
match = true
break
}
}
if !match {
continue
}
}
// If the filter exists, ensure the object is authorized.
if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil {
continue
}
workspaces = append(workspaces, workspace)
}
// Sort workspaces (ORDER BY)
isRunning := func(build database.WorkspaceBuild, job database.ProvisionerJob) bool {
return job.CompletedAt.Valid && !job.CanceledAt.Valid && !job.Error.Valid && build.Transition == database.WorkspaceTransitionStart
}
preloadedWorkspaceBuilds := map[uuid.UUID]database.WorkspaceBuild{}
preloadedProvisionerJobs := map[uuid.UUID]database.ProvisionerJob{}
preloadedUsers := map[uuid.UUID]database.User{}
for _, w := range workspaces {
build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID)
if err == nil {
preloadedWorkspaceBuilds[w.ID] = build
} else if !errors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get latest build: %w", err)
}
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
if err == nil {
preloadedProvisionerJobs[w.ID] = job
} else if !errors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get provisioner job: %w", err)
}
user, err := q.getUserByIDNoLock(w.OwnerID)
if err == nil {
preloadedUsers[w.ID] = user
} else if !errors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get user: %w", err)
}
}
sort.Slice(workspaces, func(i, j int) bool {
w1 := workspaces[i]
w2 := workspaces[j]
// Order by: running first
w1IsRunning := isRunning(preloadedWorkspaceBuilds[w1.ID], preloadedProvisionerJobs[w1.ID])
w2IsRunning := isRunning(preloadedWorkspaceBuilds[w2.ID], preloadedProvisionerJobs[w2.ID])
if w1IsRunning && !w2IsRunning {
return true
}
if !w1IsRunning && w2IsRunning {
return false
}
// Order by: usernames
if w1.ID != w2.ID {
return sort.StringsAreSorted([]string{preloadedUsers[w1.ID].Username, preloadedUsers[w2.ID].Username})
}
// Order by: workspace names
return sort.StringsAreSorted([]string{w1.Name, w2.Name})
})
beforePageCount := len(workspaces)
if arg.Offset > 0 {
if int(arg.Offset) > len(workspaces) {
return []database.GetWorkspacesRow{}, nil
}
workspaces = workspaces[arg.Offset:]
}
if arg.Limit > 0 {
if int(arg.Limit) > len(workspaces) {
return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil
}
workspaces = workspaces[:arg.Limit]
}
return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil
}
// mapAgentStatus determines the agent status based on different timestamps like created_at, last_connected_at, disconnected_at, etc.
// The function must be in sync with: coderd/workspaceagents.go:convertWorkspaceAgent.
func mapAgentStatus(dbAgent database.WorkspaceAgent, agentInactiveDisconnectTimeoutSeconds int64) string {
@ -778,66 +451,6 @@ func (q *FakeQuerier) getTemplateByIDNoLock(_ context.Context, id uuid.UUID) (da
return database.Template{}, sql.ErrNoRows
}
func (q *FakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
if err := validateDatabaseType(arg); err != nil {
return nil, err
}
q.mutex.RLock()
defer q.mutex.RUnlock()
// Call this to match the same function calls as the SQL implementation.
if prepared != nil {
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL())
if err != nil {
return nil, err
}
}
var templates []database.Template
for _, template := range q.templates {
if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil {
continue
}
if template.Deleted != arg.Deleted {
continue
}
if arg.OrganizationID != uuid.Nil && template.OrganizationID != arg.OrganizationID {
continue
}
if arg.ExactName != "" && !strings.EqualFold(template.Name, arg.ExactName) {
continue
}
if len(arg.IDs) > 0 {
match := false
for _, id := range arg.IDs {
if template.ID == id {
match = true
break
}
}
if !match {
continue
}
}
templates = append(templates, template.DeepCopy())
}
if len(templates) > 0 {
slices.SortFunc(templates, func(i, j database.Template) bool {
if i.Name != j.Name {
return i.Name < j.Name
}
return i.ID.String() < j.ID.String()
})
return templates, nil
}
return nil, sql.ErrNoRows
}
func (q *FakeQuerier) getTemplateVersionByIDNoLock(_ context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) {
for _, templateVersion := range q.templateVersions {
if templateVersion.ID != templateVersionID {
@ -848,84 +461,6 @@ func (q *FakeQuerier) getTemplateVersionByIDNoLock(_ context.Context, templateVe
return database.TemplateVersion{}, sql.ErrNoRows
}
func (q *FakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]database.TemplateUser, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
var template database.Template
for _, t := range q.templates {
if t.ID == id {
template = t
break
}
}
if template.ID == uuid.Nil {
return nil, sql.ErrNoRows
}
users := make([]database.TemplateUser, 0, len(template.UserACL))
for k, v := range template.UserACL {
user, err := q.getUserByIDNoLock(uuid.MustParse(k))
if err != nil && xerrors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get user by ID: %w", err)
}
// We don't delete users from the map if they
// get deleted so just skip.
if xerrors.Is(err, sql.ErrNoRows) {
continue
}
if user.Deleted || user.Status == database.UserStatusSuspended {
continue
}
users = append(users, database.TemplateUser{
User: user,
Actions: v,
})
}
return users, nil
}
func (q *FakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
var template database.Template
for _, t := range q.templates {
if t.ID == id {
template = t
break
}
}
if template.ID == uuid.Nil {
return nil, sql.ErrNoRows
}
groups := make([]database.TemplateGroup, 0, len(template.GroupACL))
for k, v := range template.GroupACL {
group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k))
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get group by ID: %w", err)
}
// We don't delete groups from the map if they
// get deleted so just skip.
if xerrors.Is(err, sql.ErrNoRows) {
continue
}
groups = append(groups, database.TemplateGroup{
Group: group,
Actions: v,
})
}
return groups, nil
}
func (q *FakeQuerier) getWorkspaceAgentByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {
// The schema sorts this by created at, so we iterate the array backwards.
for i := len(q.workspaceAgents) - 1; i >= 0; i-- {
@ -5438,3 +4973,468 @@ func (*FakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetC
func (*FakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) {
return database.TailnetCoordinator{}, ErrUnimplemented
}
func (q *FakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
if err := validateDatabaseType(arg); err != nil {
return nil, err
}
q.mutex.RLock()
defer q.mutex.RUnlock()
// Call this to match the same function calls as the SQL implementation.
if prepared != nil {
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL())
if err != nil {
return nil, err
}
}
var templates []database.Template
for _, template := range q.templates {
if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil {
continue
}
if template.Deleted != arg.Deleted {
continue
}
if arg.OrganizationID != uuid.Nil && template.OrganizationID != arg.OrganizationID {
continue
}
if arg.ExactName != "" && !strings.EqualFold(template.Name, arg.ExactName) {
continue
}
if len(arg.IDs) > 0 {
match := false
for _, id := range arg.IDs {
if template.ID == id {
match = true
break
}
}
if !match {
continue
}
}
templates = append(templates, template.DeepCopy())
}
if len(templates) > 0 {
slices.SortFunc(templates, func(i, j database.Template) bool {
if i.Name != j.Name {
return i.Name < j.Name
}
return i.ID.String() < j.ID.String()
})
return templates, nil
}
return nil, sql.ErrNoRows
}
func (q *FakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
var template database.Template
for _, t := range q.templates {
if t.ID == id {
template = t
break
}
}
if template.ID == uuid.Nil {
return nil, sql.ErrNoRows
}
groups := make([]database.TemplateGroup, 0, len(template.GroupACL))
for k, v := range template.GroupACL {
group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k))
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get group by ID: %w", err)
}
// We don't delete groups from the map if they
// get deleted so just skip.
if xerrors.Is(err, sql.ErrNoRows) {
continue
}
groups = append(groups, database.TemplateGroup{
Group: group,
Actions: v,
})
}
return groups, nil
}
func (q *FakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]database.TemplateUser, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
var template database.Template
for _, t := range q.templates {
if t.ID == id {
template = t
break
}
}
if template.ID == uuid.Nil {
return nil, sql.ErrNoRows
}
users := make([]database.TemplateUser, 0, len(template.UserACL))
for k, v := range template.UserACL {
user, err := q.getUserByIDNoLock(uuid.MustParse(k))
if err != nil && xerrors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get user by ID: %w", err)
}
// We don't delete users from the map if they
// get deleted so just skip.
if xerrors.Is(err, sql.ErrNoRows) {
continue
}
if user.Deleted || user.Status == database.UserStatusSuspended {
continue
}
users = append(users, database.TemplateUser{
User: user,
Actions: v,
})
}
return users, nil
}
//nolint:gocyclo
func (q *FakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
if err := validateDatabaseType(arg); err != nil {
return nil, err
}
q.mutex.RLock()
defer q.mutex.RUnlock()
if prepared != nil {
// Call this to match the same function calls as the SQL implementation.
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
if err != nil {
return nil, err
}
}
workspaces := make([]database.Workspace, 0)
for _, workspace := range q.workspaces {
if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID {
continue
}
if arg.OwnerUsername != "" {
owner, err := q.getUserByIDNoLock(workspace.OwnerID)
if err == nil && !strings.EqualFold(arg.OwnerUsername, owner.Username) {
continue
}
}
if arg.TemplateName != "" {
template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID)
if err == nil && !strings.EqualFold(arg.TemplateName, template.Name) {
continue
}
}
if !arg.Deleted && workspace.Deleted {
continue
}
if arg.Name != "" && !strings.Contains(strings.ToLower(workspace.Name), strings.ToLower(arg.Name)) {
continue
}
if arg.Status != "" {
build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID)
if err != nil {
return nil, xerrors.Errorf("get latest build: %w", err)
}
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
if err != nil {
return nil, xerrors.Errorf("get provisioner job: %w", err)
}
// This logic should match the logic in the workspace.sql file.
var statusMatch bool
switch database.WorkspaceStatus(arg.Status) {
case database.WorkspaceStatusPending:
statusMatch = isNull(job.StartedAt)
case database.WorkspaceStatusStarting:
statusMatch = isNotNull(job.StartedAt) &&
isNull(job.CanceledAt) &&
isNull(job.CompletedAt) &&
time.Since(job.UpdatedAt) < 30*time.Second &&
build.Transition == database.WorkspaceTransitionStart
case database.WorkspaceStatusRunning:
statusMatch = isNotNull(job.CompletedAt) &&
isNull(job.CanceledAt) &&
isNull(job.Error) &&
build.Transition == database.WorkspaceTransitionStart
case database.WorkspaceStatusStopping:
statusMatch = isNotNull(job.StartedAt) &&
isNull(job.CanceledAt) &&
isNull(job.CompletedAt) &&
time.Since(job.UpdatedAt) < 30*time.Second &&
build.Transition == database.WorkspaceTransitionStop
case database.WorkspaceStatusStopped:
statusMatch = isNotNull(job.CompletedAt) &&
isNull(job.CanceledAt) &&
isNull(job.Error) &&
build.Transition == database.WorkspaceTransitionStop
case database.WorkspaceStatusFailed:
statusMatch = (isNotNull(job.CanceledAt) && isNotNull(job.Error)) ||
(isNotNull(job.CompletedAt) && isNotNull(job.Error))
case database.WorkspaceStatusCanceling:
statusMatch = isNotNull(job.CanceledAt) &&
isNull(job.CompletedAt)
case database.WorkspaceStatusCanceled:
statusMatch = isNotNull(job.CanceledAt) &&
isNotNull(job.CompletedAt)
case database.WorkspaceStatusDeleted:
statusMatch = isNotNull(job.StartedAt) &&
isNull(job.CanceledAt) &&
isNotNull(job.CompletedAt) &&
time.Since(job.UpdatedAt) < 30*time.Second &&
build.Transition == database.WorkspaceTransitionDelete &&
isNull(job.Error)
case database.WorkspaceStatusDeleting:
statusMatch = isNull(job.CompletedAt) &&
isNull(job.CanceledAt) &&
isNull(job.Error) &&
build.Transition == database.WorkspaceTransitionDelete
default:
return nil, xerrors.Errorf("unknown workspace status in filter: %q", arg.Status)
}
if !statusMatch {
continue
}
}
if arg.HasAgent != "" {
build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID)
if err != nil {
return nil, xerrors.Errorf("get latest build: %w", err)
}
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
if err != nil {
return nil, xerrors.Errorf("get provisioner job: %w", err)
}
workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID)
if err != nil {
return nil, xerrors.Errorf("get workspace resources: %w", err)
}
var workspaceResourceIDs []uuid.UUID
for _, wr := range workspaceResources {
workspaceResourceIDs = append(workspaceResourceIDs, wr.ID)
}
workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs)
if err != nil {
return nil, xerrors.Errorf("get workspace agents: %w", err)
}
var hasAgentMatched bool
for _, wa := range workspaceAgents {
if mapAgentStatus(wa, arg.AgentInactiveDisconnectTimeoutSeconds) == arg.HasAgent {
hasAgentMatched = true
}
}
if !hasAgentMatched {
continue
}
}
if len(arg.TemplateIds) > 0 {
match := false
for _, id := range arg.TemplateIds {
if workspace.TemplateID == id {
match = true
break
}
}
if !match {
continue
}
}
// If the filter exists, ensure the object is authorized.
if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil {
continue
}
workspaces = append(workspaces, workspace)
}
// Sort workspaces (ORDER BY)
isRunning := func(build database.WorkspaceBuild, job database.ProvisionerJob) bool {
return job.CompletedAt.Valid && !job.CanceledAt.Valid && !job.Error.Valid && build.Transition == database.WorkspaceTransitionStart
}
preloadedWorkspaceBuilds := map[uuid.UUID]database.WorkspaceBuild{}
preloadedProvisionerJobs := map[uuid.UUID]database.ProvisionerJob{}
preloadedUsers := map[uuid.UUID]database.User{}
for _, w := range workspaces {
build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID)
if err == nil {
preloadedWorkspaceBuilds[w.ID] = build
} else if !errors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get latest build: %w", err)
}
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
if err == nil {
preloadedProvisionerJobs[w.ID] = job
} else if !errors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get provisioner job: %w", err)
}
user, err := q.getUserByIDNoLock(w.OwnerID)
if err == nil {
preloadedUsers[w.ID] = user
} else if !errors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get user: %w", err)
}
}
sort.Slice(workspaces, func(i, j int) bool {
w1 := workspaces[i]
w2 := workspaces[j]
// Order by: running first
w1IsRunning := isRunning(preloadedWorkspaceBuilds[w1.ID], preloadedProvisionerJobs[w1.ID])
w2IsRunning := isRunning(preloadedWorkspaceBuilds[w2.ID], preloadedProvisionerJobs[w2.ID])
if w1IsRunning && !w2IsRunning {
return true
}
if !w1IsRunning && w2IsRunning {
return false
}
// Order by: usernames
if w1.ID != w2.ID {
return sort.StringsAreSorted([]string{preloadedUsers[w1.ID].Username, preloadedUsers[w2.ID].Username})
}
// Order by: workspace names
return sort.StringsAreSorted([]string{w1.Name, w2.Name})
})
beforePageCount := len(workspaces)
if arg.Offset > 0 {
if int(arg.Offset) > len(workspaces) {
return []database.GetWorkspacesRow{}, nil
}
workspaces = workspaces[arg.Offset:]
}
if arg.Limit > 0 {
if int(arg.Limit) > len(workspaces) {
return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil
}
workspaces = workspaces[:arg.Limit]
}
return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil
}
func (q *FakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
if err := validateDatabaseType(params); err != nil {
return 0, err
}
q.mutex.RLock()
defer q.mutex.RUnlock()
// Call this to match the same function calls as the SQL implementation.
if prepared != nil {
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
if err != nil {
return -1, err
}
}
users := make([]database.User, 0, len(q.users))
for _, user := range q.users {
// If the filter exists, ensure the object is authorized.
if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil {
continue
}
users = append(users, user)
}
// Filter out deleted since they should never be returned..
tmp := make([]database.User, 0, len(users))
for _, user := range users {
if !user.Deleted {
tmp = append(tmp, user)
}
}
users = tmp
if params.Search != "" {
tmp := make([]database.User, 0, len(users))
for i, user := range users {
if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) {
tmp = append(tmp, users[i])
} else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) {
tmp = append(tmp, users[i])
}
}
users = tmp
}
if len(params.Status) > 0 {
usersFilteredByStatus := make([]database.User, 0, len(users))
for i, user := range users {
if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool {
return strings.EqualFold(string(a), string(b))
}) {
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
}
}
users = usersFilteredByStatus
}
if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) {
usersFilteredByRole := make([]database.User, 0, len(users))
for i, user := range users {
if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) {
usersFilteredByRole = append(usersFilteredByRole, users[i])
}
}
users = usersFilteredByRole
}
return int64(len(users)), nil
}

View File

@ -16,6 +16,12 @@ import (
"github.com/coder/coder/coderd/rbac"
)
var (
// Force these imports, for some reason the autogen does not include them.
_ uuid.UUID
_ rbac.Action
)
const wrapname = "dbmetrics.metricsStore"
// New returns a database.Store that registers metrics for all queries to reg.
@ -73,41 +79,6 @@ func (m metricsStore) InTx(f func(database.Store) error, options *sql.TxOptions)
return err
}
func (m metricsStore) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
start := time.Now()
templates, err := m.s.GetAuthorizedTemplates(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("GetAuthorizedTemplates").Observe(time.Since(start).Seconds())
return templates, err
}
func (m metricsStore) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
start := time.Now()
roles, err := m.s.GetTemplateGroupRoles(ctx, id)
m.queryLatencies.WithLabelValues("GetTemplateGroupRoles").Observe(time.Since(start).Seconds())
return roles, err
}
func (m metricsStore) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) {
start := time.Now()
roles, err := m.s.GetTemplateUserRoles(ctx, id)
m.queryLatencies.WithLabelValues("GetTemplateUserRoles").Observe(time.Since(start).Seconds())
return roles, err
}
func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
start := time.Now()
workspaces, err := m.s.GetAuthorizedWorkspaces(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("GetAuthorizedWorkspaces").Observe(time.Since(start).Seconds())
return workspaces, err
}
func (m metricsStore) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
start := time.Now()
count, err := m.s.GetAuthorizedUserCount(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("GetAuthorizedUserCount").Observe(time.Since(start).Seconds())
return count, err
}
func (m metricsStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error {
start := time.Now()
err := m.s.AcquireLock(ctx, pgAdvisoryXactLock)
@ -1639,3 +1610,38 @@ func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID
defer m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds())
return m.s.UpsertTailnetCoordinator(ctx, id)
}
func (m metricsStore) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
start := time.Now()
templates, err := m.s.GetAuthorizedTemplates(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("GetAuthorizedTemplates").Observe(time.Since(start).Seconds())
return templates, err
}
func (m metricsStore) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
start := time.Now()
roles, err := m.s.GetTemplateGroupRoles(ctx, id)
m.queryLatencies.WithLabelValues("GetTemplateGroupRoles").Observe(time.Since(start).Seconds())
return roles, err
}
func (m metricsStore) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) {
start := time.Now()
roles, err := m.s.GetTemplateUserRoles(ctx, id)
m.queryLatencies.WithLabelValues("GetTemplateUserRoles").Observe(time.Since(start).Seconds())
return roles, err
}
func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
start := time.Now()
workspaces, err := m.s.GetAuthorizedWorkspaces(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("GetAuthorizedWorkspaces").Observe(time.Since(start).Seconds())
return workspaces, err
}
func (m metricsStore) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
start := time.Now()
count, err := m.s.GetAuthorizedUserCount(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("GetAuthorizedUserCount").Observe(time.Since(start).Seconds())
return count, err
}

View File

@ -418,21 +418,44 @@ type querierFunction struct {
// readQuerierFunctions reads the functions from coderd/database/querier.go
func readQuerierFunctions() ([]querierFunction, error) {
f, err := parseDBFile("querier.go")
if err != nil {
return nil, xerrors.Errorf("parse querier.go: %w", err)
}
funcs, err := loadInterfaceFuncs(f, "sqlcQuerier")
if err != nil {
return nil, xerrors.Errorf("load interface %s funcs: %w", "sqlcQuerier", err)
}
customFile, err := parseDBFile("modelqueries.go")
if err != nil {
return nil, xerrors.Errorf("parse modelqueriers.go: %w", err)
}
// Custom funcs should be appended after the regular functions
customFuncs, err := loadInterfaceFuncs(customFile, "customQuerier")
if err != nil {
return nil, xerrors.Errorf("load interface %s funcs: %w", "customQuerier", err)
}
return append(funcs, customFuncs...), nil
}
func parseDBFile(filename string) (*dst.File, error) {
localPath, err := localFilePath()
if err != nil {
return nil, err
}
querierPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database", "querier.go")
querierPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database", filename)
querierData, err := os.ReadFile(querierPath)
if err != nil {
return nil, xerrors.Errorf("read querier: %w", err)
return nil, xerrors.Errorf("read %s: %w", filename, err)
}
f, err := decorator.Parse(querierData)
if err != nil {
return nil, err
}
return f, err
}
func loadInterfaceFuncs(f *dst.File, interfaceName string) ([]querierFunction, error) {
var querier *dst.InterfaceType
for _, decl := range f.Decls {
genDecl, ok := decl.(*dst.GenDecl)
@ -447,7 +470,7 @@ func readQuerierFunctions() ([]querierFunction, error) {
}
// This is the name of the interface. If that ever changes,
// this will need to be updated.
if typeSpec.Name.Name != "sqlcQuerier" {
if typeSpec.Name.Name != interfaceName {
continue
}
querier, ok = typeSpec.Type.(*dst.InterfaceType)
@ -461,7 +484,8 @@ func readQuerierFunctions() ([]querierFunction, error) {
return nil, xerrors.Errorf("querier not found")
}
funcs := []querierFunction{}
for _, method := range querier.Methods.List {
allMethods := interfaceMethods(querier)
for _, method := range allMethods {
funcType, ok := method.Type.(*dst.FuncType)
if !ok {
continue
@ -540,3 +564,30 @@ func nameFromSnakeCase(s string) string {
}
return ret
}
// interfaceMethods returns all embedded methods of an interface.
func interfaceMethods(i *dst.InterfaceType) []*dst.Field {
var allMethods []*dst.Field
for _, field := range i.Methods.List {
switch fieldType := field.Type.(type) {
case *dst.FuncType:
allMethods = append(allMethods, field)
case *dst.InterfaceType:
allMethods = append(allMethods, interfaceMethods(fieldType)...)
case *dst.Ident:
// Embedded interfaces are Idents -> TypeSpec -> InterfaceType
// If the embedded interface is not in the parsed file, then
// the Obj will be nil.
if fieldType.Obj != nil {
objDecl, ok := fieldType.Obj.Decl.(*dst.TypeSpec)
if ok {
isInterface, ok := objDecl.Type.(*dst.InterfaceType)
if ok {
allMethods = append(allMethods, interfaceMethods(isInterface)...)
}
}
}
}
}
return allMethods
}