mirror of https://github.com/coder/coder.git
chore: add custom querier functions to dbgen (#8496)
* chore: add custom querier functions to dbgen * chore: parse package was missing some imports, so force them
This commit is contained in:
parent
b650ab40f0
commit
3b433181be
|
@ -575,11 +575,6 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID *uuid.UUID, added, r
|
|||
return nil
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) {
|
||||
// TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier.
|
||||
return q.GetTemplatesWithFilter(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) error {
|
||||
deleteF := func(ctx context.Context, id uuid.UUID) error {
|
||||
return q.db.UpdateTemplateDeletedByID(ctx, database.UpdateTemplateDeletedByIDParams{
|
||||
|
@ -591,34 +586,6 @@ func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) erro
|
|||
return deleteQ(q.log, q.auth, q.db.GetTemplateByID, deleteF)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
|
||||
// An actor is authorized to read template group roles if they are authorized to read the template.
|
||||
template, err := q.db.GetTemplateByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetTemplateGroupRoles(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) {
|
||||
// An actor is authorized to query template user roles if they are authorized to read the template.
|
||||
template, err := q.db.GetTemplateByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetTemplateUserRoles(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
return q.db.GetAuthorizedUserCount(ctx, arg, prepared)
|
||||
}
|
||||
|
||||
func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) {
|
||||
// TODO Implement this with a SQL filter. The count is incorrect without it.
|
||||
rowUsers, err := q.db.GetUsers(ctx, arg)
|
||||
|
@ -655,11 +622,6 @@ func (q *querier) SoftDeleteUserByID(ctx context.Context, id uuid.UUID) error {
|
|||
return deleteQ(q.log, q.auth, q.db.GetUserByID, deleteF)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, _ rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
|
||||
// TODO Delete this function, all GetWorkspaces should be authorized. For now just call GetWorkspaces on the authz querier.
|
||||
return q.GetWorkspaces(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) SoftDeleteWorkspaceByID(ctx context.Context, id uuid.UUID) error {
|
||||
return deleteQ(q.log, q.auth, q.db.GetWorkspaceByID, func(ctx context.Context, id uuid.UUID) error {
|
||||
return q.db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{
|
||||
|
@ -2642,3 +2604,41 @@ func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (d
|
|||
}
|
||||
return q.db.UpsertTailnetCoordinator(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) {
|
||||
// TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier.
|
||||
return q.GetTemplatesWithFilter(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
|
||||
// An actor is authorized to read template group roles if they are authorized to read the template.
|
||||
template, err := q.db.GetTemplateByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetTemplateGroupRoles(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) {
|
||||
// An actor is authorized to query template user roles if they are authorized to read the template.
|
||||
template, err := q.db.GetTemplateByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetTemplateUserRoles(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, _ rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
|
||||
// TODO Delete this function, all GetWorkspaces should be authorized. For now just call GetWorkspaces on the authz querier.
|
||||
return q.GetWorkspaces(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
return q.db.GetAuthorizedUserCount(ctx, arg, prepared)
|
||||
}
|
||||
|
|
|
@ -266,80 +266,6 @@ func (q *FakeQuerier) getUserByIDNoLock(id uuid.UUID) (database.User, error) {
|
|||
return database.User{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
if err := validateDatabaseType(params); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
// Call this to match the same function calls as the SQL implementation.
|
||||
if prepared != nil {
|
||||
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
|
||||
users := make([]database.User, 0, len(q.users))
|
||||
|
||||
for _, user := range q.users {
|
||||
// If the filter exists, ensure the object is authorized.
|
||||
if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
// Filter out deleted since they should never be returned..
|
||||
tmp := make([]database.User, 0, len(users))
|
||||
for _, user := range users {
|
||||
if !user.Deleted {
|
||||
tmp = append(tmp, user)
|
||||
}
|
||||
}
|
||||
users = tmp
|
||||
|
||||
if params.Search != "" {
|
||||
tmp := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) {
|
||||
tmp = append(tmp, users[i])
|
||||
} else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) {
|
||||
tmp = append(tmp, users[i])
|
||||
}
|
||||
}
|
||||
users = tmp
|
||||
}
|
||||
|
||||
if len(params.Status) > 0 {
|
||||
usersFilteredByStatus := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool {
|
||||
return strings.EqualFold(string(a), string(b))
|
||||
}) {
|
||||
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
|
||||
}
|
||||
}
|
||||
users = usersFilteredByStatus
|
||||
}
|
||||
|
||||
if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) {
|
||||
usersFilteredByRole := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) {
|
||||
usersFilteredByRole = append(usersFilteredByRole, users[i])
|
||||
}
|
||||
}
|
||||
|
||||
users = usersFilteredByRole
|
||||
}
|
||||
|
||||
return int64(len(users)), nil
|
||||
}
|
||||
|
||||
func convertUsers(users []database.User, count int64) []database.GetUsersRow {
|
||||
rows := make([]database.GetUsersRow, len(users))
|
||||
for i, u := range users {
|
||||
|
@ -363,259 +289,6 @@ func convertUsers(users []database.User, count int64) []database.GetUsersRow {
|
|||
return rows
|
||||
}
|
||||
|
||||
//nolint:gocyclo
|
||||
func (q *FakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
|
||||
if err := validateDatabaseType(arg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
if prepared != nil {
|
||||
// Call this to match the same function calls as the SQL implementation.
|
||||
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
workspaces := make([]database.Workspace, 0)
|
||||
for _, workspace := range q.workspaces {
|
||||
if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID {
|
||||
continue
|
||||
}
|
||||
|
||||
if arg.OwnerUsername != "" {
|
||||
owner, err := q.getUserByIDNoLock(workspace.OwnerID)
|
||||
if err == nil && !strings.EqualFold(arg.OwnerUsername, owner.Username) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if arg.TemplateName != "" {
|
||||
template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID)
|
||||
if err == nil && !strings.EqualFold(arg.TemplateName, template.Name) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if !arg.Deleted && workspace.Deleted {
|
||||
continue
|
||||
}
|
||||
|
||||
if arg.Name != "" && !strings.Contains(strings.ToLower(workspace.Name), strings.ToLower(arg.Name)) {
|
||||
continue
|
||||
}
|
||||
|
||||
if arg.Status != "" {
|
||||
build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get latest build: %w", err)
|
||||
}
|
||||
|
||||
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get provisioner job: %w", err)
|
||||
}
|
||||
|
||||
// This logic should match the logic in the workspace.sql file.
|
||||
var statusMatch bool
|
||||
switch database.WorkspaceStatus(arg.Status) {
|
||||
case database.WorkspaceStatusPending:
|
||||
statusMatch = isNull(job.StartedAt)
|
||||
case database.WorkspaceStatusStarting:
|
||||
statusMatch = isNotNull(job.StartedAt) &&
|
||||
isNull(job.CanceledAt) &&
|
||||
isNull(job.CompletedAt) &&
|
||||
time.Since(job.UpdatedAt) < 30*time.Second &&
|
||||
build.Transition == database.WorkspaceTransitionStart
|
||||
|
||||
case database.WorkspaceStatusRunning:
|
||||
statusMatch = isNotNull(job.CompletedAt) &&
|
||||
isNull(job.CanceledAt) &&
|
||||
isNull(job.Error) &&
|
||||
build.Transition == database.WorkspaceTransitionStart
|
||||
|
||||
case database.WorkspaceStatusStopping:
|
||||
statusMatch = isNotNull(job.StartedAt) &&
|
||||
isNull(job.CanceledAt) &&
|
||||
isNull(job.CompletedAt) &&
|
||||
time.Since(job.UpdatedAt) < 30*time.Second &&
|
||||
build.Transition == database.WorkspaceTransitionStop
|
||||
|
||||
case database.WorkspaceStatusStopped:
|
||||
statusMatch = isNotNull(job.CompletedAt) &&
|
||||
isNull(job.CanceledAt) &&
|
||||
isNull(job.Error) &&
|
||||
build.Transition == database.WorkspaceTransitionStop
|
||||
case database.WorkspaceStatusFailed:
|
||||
statusMatch = (isNotNull(job.CanceledAt) && isNotNull(job.Error)) ||
|
||||
(isNotNull(job.CompletedAt) && isNotNull(job.Error))
|
||||
|
||||
case database.WorkspaceStatusCanceling:
|
||||
statusMatch = isNotNull(job.CanceledAt) &&
|
||||
isNull(job.CompletedAt)
|
||||
|
||||
case database.WorkspaceStatusCanceled:
|
||||
statusMatch = isNotNull(job.CanceledAt) &&
|
||||
isNotNull(job.CompletedAt)
|
||||
|
||||
case database.WorkspaceStatusDeleted:
|
||||
statusMatch = isNotNull(job.StartedAt) &&
|
||||
isNull(job.CanceledAt) &&
|
||||
isNotNull(job.CompletedAt) &&
|
||||
time.Since(job.UpdatedAt) < 30*time.Second &&
|
||||
build.Transition == database.WorkspaceTransitionDelete &&
|
||||
isNull(job.Error)
|
||||
|
||||
case database.WorkspaceStatusDeleting:
|
||||
statusMatch = isNull(job.CompletedAt) &&
|
||||
isNull(job.CanceledAt) &&
|
||||
isNull(job.Error) &&
|
||||
build.Transition == database.WorkspaceTransitionDelete
|
||||
|
||||
default:
|
||||
return nil, xerrors.Errorf("unknown workspace status in filter: %q", arg.Status)
|
||||
}
|
||||
if !statusMatch {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if arg.HasAgent != "" {
|
||||
build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get latest build: %w", err)
|
||||
}
|
||||
|
||||
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get provisioner job: %w", err)
|
||||
}
|
||||
|
||||
workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace resources: %w", err)
|
||||
}
|
||||
|
||||
var workspaceResourceIDs []uuid.UUID
|
||||
for _, wr := range workspaceResources {
|
||||
workspaceResourceIDs = append(workspaceResourceIDs, wr.ID)
|
||||
}
|
||||
|
||||
workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace agents: %w", err)
|
||||
}
|
||||
|
||||
var hasAgentMatched bool
|
||||
for _, wa := range workspaceAgents {
|
||||
if mapAgentStatus(wa, arg.AgentInactiveDisconnectTimeoutSeconds) == arg.HasAgent {
|
||||
hasAgentMatched = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAgentMatched {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if len(arg.TemplateIds) > 0 {
|
||||
match := false
|
||||
for _, id := range arg.TemplateIds {
|
||||
if workspace.TemplateID == id {
|
||||
match = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !match {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// If the filter exists, ensure the object is authorized.
|
||||
if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil {
|
||||
continue
|
||||
}
|
||||
workspaces = append(workspaces, workspace)
|
||||
}
|
||||
|
||||
// Sort workspaces (ORDER BY)
|
||||
isRunning := func(build database.WorkspaceBuild, job database.ProvisionerJob) bool {
|
||||
return job.CompletedAt.Valid && !job.CanceledAt.Valid && !job.Error.Valid && build.Transition == database.WorkspaceTransitionStart
|
||||
}
|
||||
|
||||
preloadedWorkspaceBuilds := map[uuid.UUID]database.WorkspaceBuild{}
|
||||
preloadedProvisionerJobs := map[uuid.UUID]database.ProvisionerJob{}
|
||||
preloadedUsers := map[uuid.UUID]database.User{}
|
||||
|
||||
for _, w := range workspaces {
|
||||
build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID)
|
||||
if err == nil {
|
||||
preloadedWorkspaceBuilds[w.ID] = build
|
||||
} else if !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get latest build: %w", err)
|
||||
}
|
||||
|
||||
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
|
||||
if err == nil {
|
||||
preloadedProvisionerJobs[w.ID] = job
|
||||
} else if !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get provisioner job: %w", err)
|
||||
}
|
||||
|
||||
user, err := q.getUserByIDNoLock(w.OwnerID)
|
||||
if err == nil {
|
||||
preloadedUsers[w.ID] = user
|
||||
} else if !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(workspaces, func(i, j int) bool {
|
||||
w1 := workspaces[i]
|
||||
w2 := workspaces[j]
|
||||
|
||||
// Order by: running first
|
||||
w1IsRunning := isRunning(preloadedWorkspaceBuilds[w1.ID], preloadedProvisionerJobs[w1.ID])
|
||||
w2IsRunning := isRunning(preloadedWorkspaceBuilds[w2.ID], preloadedProvisionerJobs[w2.ID])
|
||||
|
||||
if w1IsRunning && !w2IsRunning {
|
||||
return true
|
||||
}
|
||||
|
||||
if !w1IsRunning && w2IsRunning {
|
||||
return false
|
||||
}
|
||||
|
||||
// Order by: usernames
|
||||
if w1.ID != w2.ID {
|
||||
return sort.StringsAreSorted([]string{preloadedUsers[w1.ID].Username, preloadedUsers[w2.ID].Username})
|
||||
}
|
||||
|
||||
// Order by: workspace names
|
||||
return sort.StringsAreSorted([]string{w1.Name, w2.Name})
|
||||
})
|
||||
|
||||
beforePageCount := len(workspaces)
|
||||
|
||||
if arg.Offset > 0 {
|
||||
if int(arg.Offset) > len(workspaces) {
|
||||
return []database.GetWorkspacesRow{}, nil
|
||||
}
|
||||
workspaces = workspaces[arg.Offset:]
|
||||
}
|
||||
if arg.Limit > 0 {
|
||||
if int(arg.Limit) > len(workspaces) {
|
||||
return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil
|
||||
}
|
||||
workspaces = workspaces[:arg.Limit]
|
||||
}
|
||||
|
||||
return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil
|
||||
}
|
||||
|
||||
// mapAgentStatus determines the agent status based on different timestamps like created_at, last_connected_at, disconnected_at, etc.
|
||||
// The function must be in sync with: coderd/workspaceagents.go:convertWorkspaceAgent.
|
||||
func mapAgentStatus(dbAgent database.WorkspaceAgent, agentInactiveDisconnectTimeoutSeconds int64) string {
|
||||
|
@ -778,66 +451,6 @@ func (q *FakeQuerier) getTemplateByIDNoLock(_ context.Context, id uuid.UUID) (da
|
|||
return database.Template{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
|
||||
if err := validateDatabaseType(arg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
// Call this to match the same function calls as the SQL implementation.
|
||||
if prepared != nil {
|
||||
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var templates []database.Template
|
||||
for _, template := range q.templates {
|
||||
if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if template.Deleted != arg.Deleted {
|
||||
continue
|
||||
}
|
||||
if arg.OrganizationID != uuid.Nil && template.OrganizationID != arg.OrganizationID {
|
||||
continue
|
||||
}
|
||||
|
||||
if arg.ExactName != "" && !strings.EqualFold(template.Name, arg.ExactName) {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(arg.IDs) > 0 {
|
||||
match := false
|
||||
for _, id := range arg.IDs {
|
||||
if template.ID == id {
|
||||
match = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !match {
|
||||
continue
|
||||
}
|
||||
}
|
||||
templates = append(templates, template.DeepCopy())
|
||||
}
|
||||
if len(templates) > 0 {
|
||||
slices.SortFunc(templates, func(i, j database.Template) bool {
|
||||
if i.Name != j.Name {
|
||||
return i.Name < j.Name
|
||||
}
|
||||
return i.ID.String() < j.ID.String()
|
||||
})
|
||||
return templates, nil
|
||||
}
|
||||
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) getTemplateVersionByIDNoLock(_ context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) {
|
||||
for _, templateVersion := range q.templateVersions {
|
||||
if templateVersion.ID != templateVersionID {
|
||||
|
@ -848,84 +461,6 @@ func (q *FakeQuerier) getTemplateVersionByIDNoLock(_ context.Context, templateVe
|
|||
return database.TemplateVersion{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]database.TemplateUser, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
var template database.Template
|
||||
for _, t := range q.templates {
|
||||
if t.ID == id {
|
||||
template = t
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if template.ID == uuid.Nil {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
|
||||
users := make([]database.TemplateUser, 0, len(template.UserACL))
|
||||
for k, v := range template.UserACL {
|
||||
user, err := q.getUserByIDNoLock(uuid.MustParse(k))
|
||||
if err != nil && xerrors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get user by ID: %w", err)
|
||||
}
|
||||
// We don't delete users from the map if they
|
||||
// get deleted so just skip.
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
continue
|
||||
}
|
||||
|
||||
if user.Deleted || user.Status == database.UserStatusSuspended {
|
||||
continue
|
||||
}
|
||||
|
||||
users = append(users, database.TemplateUser{
|
||||
User: user,
|
||||
Actions: v,
|
||||
})
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
var template database.Template
|
||||
for _, t := range q.templates {
|
||||
if t.ID == id {
|
||||
template = t
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if template.ID == uuid.Nil {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
|
||||
groups := make([]database.TemplateGroup, 0, len(template.GroupACL))
|
||||
for k, v := range template.GroupACL {
|
||||
group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k))
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get group by ID: %w", err)
|
||||
}
|
||||
// We don't delete groups from the map if they
|
||||
// get deleted so just skip.
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
continue
|
||||
}
|
||||
|
||||
groups = append(groups, database.TemplateGroup{
|
||||
Group: group,
|
||||
Actions: v,
|
||||
})
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) getWorkspaceAgentByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {
|
||||
// The schema sorts this by created at, so we iterate the array backwards.
|
||||
for i := len(q.workspaceAgents) - 1; i >= 0; i-- {
|
||||
|
@ -5438,3 +4973,468 @@ func (*FakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetC
|
|||
func (*FakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) {
|
||||
return database.TailnetCoordinator{}, ErrUnimplemented
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
|
||||
if err := validateDatabaseType(arg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
// Call this to match the same function calls as the SQL implementation.
|
||||
if prepared != nil {
|
||||
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var templates []database.Template
|
||||
for _, template := range q.templates {
|
||||
if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if template.Deleted != arg.Deleted {
|
||||
continue
|
||||
}
|
||||
if arg.OrganizationID != uuid.Nil && template.OrganizationID != arg.OrganizationID {
|
||||
continue
|
||||
}
|
||||
|
||||
if arg.ExactName != "" && !strings.EqualFold(template.Name, arg.ExactName) {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(arg.IDs) > 0 {
|
||||
match := false
|
||||
for _, id := range arg.IDs {
|
||||
if template.ID == id {
|
||||
match = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !match {
|
||||
continue
|
||||
}
|
||||
}
|
||||
templates = append(templates, template.DeepCopy())
|
||||
}
|
||||
if len(templates) > 0 {
|
||||
slices.SortFunc(templates, func(i, j database.Template) bool {
|
||||
if i.Name != j.Name {
|
||||
return i.Name < j.Name
|
||||
}
|
||||
return i.ID.String() < j.ID.String()
|
||||
})
|
||||
return templates, nil
|
||||
}
|
||||
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
var template database.Template
|
||||
for _, t := range q.templates {
|
||||
if t.ID == id {
|
||||
template = t
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if template.ID == uuid.Nil {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
|
||||
groups := make([]database.TemplateGroup, 0, len(template.GroupACL))
|
||||
for k, v := range template.GroupACL {
|
||||
group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k))
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get group by ID: %w", err)
|
||||
}
|
||||
// We don't delete groups from the map if they
|
||||
// get deleted so just skip.
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
continue
|
||||
}
|
||||
|
||||
groups = append(groups, database.TemplateGroup{
|
||||
Group: group,
|
||||
Actions: v,
|
||||
})
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]database.TemplateUser, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
var template database.Template
|
||||
for _, t := range q.templates {
|
||||
if t.ID == id {
|
||||
template = t
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if template.ID == uuid.Nil {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
|
||||
users := make([]database.TemplateUser, 0, len(template.UserACL))
|
||||
for k, v := range template.UserACL {
|
||||
user, err := q.getUserByIDNoLock(uuid.MustParse(k))
|
||||
if err != nil && xerrors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get user by ID: %w", err)
|
||||
}
|
||||
// We don't delete users from the map if they
|
||||
// get deleted so just skip.
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
continue
|
||||
}
|
||||
|
||||
if user.Deleted || user.Status == database.UserStatusSuspended {
|
||||
continue
|
||||
}
|
||||
|
||||
users = append(users, database.TemplateUser{
|
||||
User: user,
|
||||
Actions: v,
|
||||
})
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
//nolint:gocyclo
|
||||
func (q *FakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
|
||||
if err := validateDatabaseType(arg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
if prepared != nil {
|
||||
// Call this to match the same function calls as the SQL implementation.
|
||||
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
workspaces := make([]database.Workspace, 0)
|
||||
for _, workspace := range q.workspaces {
|
||||
if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID {
|
||||
continue
|
||||
}
|
||||
|
||||
if arg.OwnerUsername != "" {
|
||||
owner, err := q.getUserByIDNoLock(workspace.OwnerID)
|
||||
if err == nil && !strings.EqualFold(arg.OwnerUsername, owner.Username) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if arg.TemplateName != "" {
|
||||
template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID)
|
||||
if err == nil && !strings.EqualFold(arg.TemplateName, template.Name) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if !arg.Deleted && workspace.Deleted {
|
||||
continue
|
||||
}
|
||||
|
||||
if arg.Name != "" && !strings.Contains(strings.ToLower(workspace.Name), strings.ToLower(arg.Name)) {
|
||||
continue
|
||||
}
|
||||
|
||||
if arg.Status != "" {
|
||||
build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get latest build: %w", err)
|
||||
}
|
||||
|
||||
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get provisioner job: %w", err)
|
||||
}
|
||||
|
||||
// This logic should match the logic in the workspace.sql file.
|
||||
var statusMatch bool
|
||||
switch database.WorkspaceStatus(arg.Status) {
|
||||
case database.WorkspaceStatusPending:
|
||||
statusMatch = isNull(job.StartedAt)
|
||||
case database.WorkspaceStatusStarting:
|
||||
statusMatch = isNotNull(job.StartedAt) &&
|
||||
isNull(job.CanceledAt) &&
|
||||
isNull(job.CompletedAt) &&
|
||||
time.Since(job.UpdatedAt) < 30*time.Second &&
|
||||
build.Transition == database.WorkspaceTransitionStart
|
||||
|
||||
case database.WorkspaceStatusRunning:
|
||||
statusMatch = isNotNull(job.CompletedAt) &&
|
||||
isNull(job.CanceledAt) &&
|
||||
isNull(job.Error) &&
|
||||
build.Transition == database.WorkspaceTransitionStart
|
||||
|
||||
case database.WorkspaceStatusStopping:
|
||||
statusMatch = isNotNull(job.StartedAt) &&
|
||||
isNull(job.CanceledAt) &&
|
||||
isNull(job.CompletedAt) &&
|
||||
time.Since(job.UpdatedAt) < 30*time.Second &&
|
||||
build.Transition == database.WorkspaceTransitionStop
|
||||
|
||||
case database.WorkspaceStatusStopped:
|
||||
statusMatch = isNotNull(job.CompletedAt) &&
|
||||
isNull(job.CanceledAt) &&
|
||||
isNull(job.Error) &&
|
||||
build.Transition == database.WorkspaceTransitionStop
|
||||
case database.WorkspaceStatusFailed:
|
||||
statusMatch = (isNotNull(job.CanceledAt) && isNotNull(job.Error)) ||
|
||||
(isNotNull(job.CompletedAt) && isNotNull(job.Error))
|
||||
|
||||
case database.WorkspaceStatusCanceling:
|
||||
statusMatch = isNotNull(job.CanceledAt) &&
|
||||
isNull(job.CompletedAt)
|
||||
|
||||
case database.WorkspaceStatusCanceled:
|
||||
statusMatch = isNotNull(job.CanceledAt) &&
|
||||
isNotNull(job.CompletedAt)
|
||||
|
||||
case database.WorkspaceStatusDeleted:
|
||||
statusMatch = isNotNull(job.StartedAt) &&
|
||||
isNull(job.CanceledAt) &&
|
||||
isNotNull(job.CompletedAt) &&
|
||||
time.Since(job.UpdatedAt) < 30*time.Second &&
|
||||
build.Transition == database.WorkspaceTransitionDelete &&
|
||||
isNull(job.Error)
|
||||
|
||||
case database.WorkspaceStatusDeleting:
|
||||
statusMatch = isNull(job.CompletedAt) &&
|
||||
isNull(job.CanceledAt) &&
|
||||
isNull(job.Error) &&
|
||||
build.Transition == database.WorkspaceTransitionDelete
|
||||
|
||||
default:
|
||||
return nil, xerrors.Errorf("unknown workspace status in filter: %q", arg.Status)
|
||||
}
|
||||
if !statusMatch {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if arg.HasAgent != "" {
|
||||
build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get latest build: %w", err)
|
||||
}
|
||||
|
||||
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get provisioner job: %w", err)
|
||||
}
|
||||
|
||||
workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace resources: %w", err)
|
||||
}
|
||||
|
||||
var workspaceResourceIDs []uuid.UUID
|
||||
for _, wr := range workspaceResources {
|
||||
workspaceResourceIDs = append(workspaceResourceIDs, wr.ID)
|
||||
}
|
||||
|
||||
workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace agents: %w", err)
|
||||
}
|
||||
|
||||
var hasAgentMatched bool
|
||||
for _, wa := range workspaceAgents {
|
||||
if mapAgentStatus(wa, arg.AgentInactiveDisconnectTimeoutSeconds) == arg.HasAgent {
|
||||
hasAgentMatched = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAgentMatched {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if len(arg.TemplateIds) > 0 {
|
||||
match := false
|
||||
for _, id := range arg.TemplateIds {
|
||||
if workspace.TemplateID == id {
|
||||
match = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !match {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// If the filter exists, ensure the object is authorized.
|
||||
if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil {
|
||||
continue
|
||||
}
|
||||
workspaces = append(workspaces, workspace)
|
||||
}
|
||||
|
||||
// Sort workspaces (ORDER BY)
|
||||
isRunning := func(build database.WorkspaceBuild, job database.ProvisionerJob) bool {
|
||||
return job.CompletedAt.Valid && !job.CanceledAt.Valid && !job.Error.Valid && build.Transition == database.WorkspaceTransitionStart
|
||||
}
|
||||
|
||||
preloadedWorkspaceBuilds := map[uuid.UUID]database.WorkspaceBuild{}
|
||||
preloadedProvisionerJobs := map[uuid.UUID]database.ProvisionerJob{}
|
||||
preloadedUsers := map[uuid.UUID]database.User{}
|
||||
|
||||
for _, w := range workspaces {
|
||||
build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID)
|
||||
if err == nil {
|
||||
preloadedWorkspaceBuilds[w.ID] = build
|
||||
} else if !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get latest build: %w", err)
|
||||
}
|
||||
|
||||
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
|
||||
if err == nil {
|
||||
preloadedProvisionerJobs[w.ID] = job
|
||||
} else if !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get provisioner job: %w", err)
|
||||
}
|
||||
|
||||
user, err := q.getUserByIDNoLock(w.OwnerID)
|
||||
if err == nil {
|
||||
preloadedUsers[w.ID] = user
|
||||
} else if !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(workspaces, func(i, j int) bool {
|
||||
w1 := workspaces[i]
|
||||
w2 := workspaces[j]
|
||||
|
||||
// Order by: running first
|
||||
w1IsRunning := isRunning(preloadedWorkspaceBuilds[w1.ID], preloadedProvisionerJobs[w1.ID])
|
||||
w2IsRunning := isRunning(preloadedWorkspaceBuilds[w2.ID], preloadedProvisionerJobs[w2.ID])
|
||||
|
||||
if w1IsRunning && !w2IsRunning {
|
||||
return true
|
||||
}
|
||||
|
||||
if !w1IsRunning && w2IsRunning {
|
||||
return false
|
||||
}
|
||||
|
||||
// Order by: usernames
|
||||
if w1.ID != w2.ID {
|
||||
return sort.StringsAreSorted([]string{preloadedUsers[w1.ID].Username, preloadedUsers[w2.ID].Username})
|
||||
}
|
||||
|
||||
// Order by: workspace names
|
||||
return sort.StringsAreSorted([]string{w1.Name, w2.Name})
|
||||
})
|
||||
|
||||
beforePageCount := len(workspaces)
|
||||
|
||||
if arg.Offset > 0 {
|
||||
if int(arg.Offset) > len(workspaces) {
|
||||
return []database.GetWorkspacesRow{}, nil
|
||||
}
|
||||
workspaces = workspaces[arg.Offset:]
|
||||
}
|
||||
if arg.Limit > 0 {
|
||||
if int(arg.Limit) > len(workspaces) {
|
||||
return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil
|
||||
}
|
||||
workspaces = workspaces[:arg.Limit]
|
||||
}
|
||||
|
||||
return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
if err := validateDatabaseType(params); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
// Call this to match the same function calls as the SQL implementation.
|
||||
if prepared != nil {
|
||||
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
|
||||
users := make([]database.User, 0, len(q.users))
|
||||
|
||||
for _, user := range q.users {
|
||||
// If the filter exists, ensure the object is authorized.
|
||||
if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
// Filter out deleted since they should never be returned..
|
||||
tmp := make([]database.User, 0, len(users))
|
||||
for _, user := range users {
|
||||
if !user.Deleted {
|
||||
tmp = append(tmp, user)
|
||||
}
|
||||
}
|
||||
users = tmp
|
||||
|
||||
if params.Search != "" {
|
||||
tmp := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) {
|
||||
tmp = append(tmp, users[i])
|
||||
} else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) {
|
||||
tmp = append(tmp, users[i])
|
||||
}
|
||||
}
|
||||
users = tmp
|
||||
}
|
||||
|
||||
if len(params.Status) > 0 {
|
||||
usersFilteredByStatus := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool {
|
||||
return strings.EqualFold(string(a), string(b))
|
||||
}) {
|
||||
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
|
||||
}
|
||||
}
|
||||
users = usersFilteredByStatus
|
||||
}
|
||||
|
||||
if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) {
|
||||
usersFilteredByRole := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) {
|
||||
usersFilteredByRole = append(usersFilteredByRole, users[i])
|
||||
}
|
||||
}
|
||||
|
||||
users = usersFilteredByRole
|
||||
}
|
||||
|
||||
return int64(len(users)), nil
|
||||
}
|
||||
|
|
|
@ -16,6 +16,12 @@ import (
|
|||
"github.com/coder/coder/coderd/rbac"
|
||||
)
|
||||
|
||||
var (
|
||||
// Force these imports, for some reason the autogen does not include them.
|
||||
_ uuid.UUID
|
||||
_ rbac.Action
|
||||
)
|
||||
|
||||
const wrapname = "dbmetrics.metricsStore"
|
||||
|
||||
// New returns a database.Store that registers metrics for all queries to reg.
|
||||
|
@ -73,41 +79,6 @@ func (m metricsStore) InTx(f func(database.Store) error, options *sql.TxOptions)
|
|||
return err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
|
||||
start := time.Now()
|
||||
templates, err := m.s.GetAuthorizedTemplates(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("GetAuthorizedTemplates").Observe(time.Since(start).Seconds())
|
||||
return templates, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
|
||||
start := time.Now()
|
||||
roles, err := m.s.GetTemplateGroupRoles(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetTemplateGroupRoles").Observe(time.Since(start).Seconds())
|
||||
return roles, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) {
|
||||
start := time.Now()
|
||||
roles, err := m.s.GetTemplateUserRoles(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetTemplateUserRoles").Observe(time.Since(start).Seconds())
|
||||
return roles, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
|
||||
start := time.Now()
|
||||
workspaces, err := m.s.GetAuthorizedWorkspaces(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("GetAuthorizedWorkspaces").Observe(time.Since(start).Seconds())
|
||||
return workspaces, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
start := time.Now()
|
||||
count, err := m.s.GetAuthorizedUserCount(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("GetAuthorizedUserCount").Observe(time.Since(start).Seconds())
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (m metricsStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error {
|
||||
start := time.Now()
|
||||
err := m.s.AcquireLock(ctx, pgAdvisoryXactLock)
|
||||
|
@ -1639,3 +1610,38 @@ func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID
|
|||
defer m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds())
|
||||
return m.s.UpsertTailnetCoordinator(ctx, id)
|
||||
}
|
||||
|
||||
func (m metricsStore) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
|
||||
start := time.Now()
|
||||
templates, err := m.s.GetAuthorizedTemplates(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("GetAuthorizedTemplates").Observe(time.Since(start).Seconds())
|
||||
return templates, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
|
||||
start := time.Now()
|
||||
roles, err := m.s.GetTemplateGroupRoles(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetTemplateGroupRoles").Observe(time.Since(start).Seconds())
|
||||
return roles, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) {
|
||||
start := time.Now()
|
||||
roles, err := m.s.GetTemplateUserRoles(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetTemplateUserRoles").Observe(time.Since(start).Seconds())
|
||||
return roles, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
|
||||
start := time.Now()
|
||||
workspaces, err := m.s.GetAuthorizedWorkspaces(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("GetAuthorizedWorkspaces").Observe(time.Since(start).Seconds())
|
||||
return workspaces, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
start := time.Now()
|
||||
count, err := m.s.GetAuthorizedUserCount(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("GetAuthorizedUserCount").Observe(time.Since(start).Seconds())
|
||||
return count, err
|
||||
}
|
||||
|
|
|
@ -418,21 +418,44 @@ type querierFunction struct {
|
|||
|
||||
// readQuerierFunctions reads the functions from coderd/database/querier.go
|
||||
func readQuerierFunctions() ([]querierFunction, error) {
|
||||
f, err := parseDBFile("querier.go")
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse querier.go: %w", err)
|
||||
}
|
||||
funcs, err := loadInterfaceFuncs(f, "sqlcQuerier")
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("load interface %s funcs: %w", "sqlcQuerier", err)
|
||||
}
|
||||
|
||||
customFile, err := parseDBFile("modelqueries.go")
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse modelqueriers.go: %w", err)
|
||||
}
|
||||
// Custom funcs should be appended after the regular functions
|
||||
customFuncs, err := loadInterfaceFuncs(customFile, "customQuerier")
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("load interface %s funcs: %w", "customQuerier", err)
|
||||
}
|
||||
|
||||
return append(funcs, customFuncs...), nil
|
||||
}
|
||||
|
||||
func parseDBFile(filename string) (*dst.File, error) {
|
||||
localPath, err := localFilePath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
querierPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database", "querier.go")
|
||||
|
||||
querierPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database", filename)
|
||||
querierData, err := os.ReadFile(querierPath)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("read querier: %w", err)
|
||||
return nil, xerrors.Errorf("read %s: %w", filename, err)
|
||||
}
|
||||
f, err := decorator.Parse(querierData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f, err
|
||||
}
|
||||
|
||||
func loadInterfaceFuncs(f *dst.File, interfaceName string) ([]querierFunction, error) {
|
||||
var querier *dst.InterfaceType
|
||||
for _, decl := range f.Decls {
|
||||
genDecl, ok := decl.(*dst.GenDecl)
|
||||
|
@ -447,7 +470,7 @@ func readQuerierFunctions() ([]querierFunction, error) {
|
|||
}
|
||||
// This is the name of the interface. If that ever changes,
|
||||
// this will need to be updated.
|
||||
if typeSpec.Name.Name != "sqlcQuerier" {
|
||||
if typeSpec.Name.Name != interfaceName {
|
||||
continue
|
||||
}
|
||||
querier, ok = typeSpec.Type.(*dst.InterfaceType)
|
||||
|
@ -461,7 +484,8 @@ func readQuerierFunctions() ([]querierFunction, error) {
|
|||
return nil, xerrors.Errorf("querier not found")
|
||||
}
|
||||
funcs := []querierFunction{}
|
||||
for _, method := range querier.Methods.List {
|
||||
allMethods := interfaceMethods(querier)
|
||||
for _, method := range allMethods {
|
||||
funcType, ok := method.Type.(*dst.FuncType)
|
||||
if !ok {
|
||||
continue
|
||||
|
@ -540,3 +564,30 @@ func nameFromSnakeCase(s string) string {
|
|||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// interfaceMethods returns all embedded methods of an interface.
|
||||
func interfaceMethods(i *dst.InterfaceType) []*dst.Field {
|
||||
var allMethods []*dst.Field
|
||||
for _, field := range i.Methods.List {
|
||||
switch fieldType := field.Type.(type) {
|
||||
case *dst.FuncType:
|
||||
allMethods = append(allMethods, field)
|
||||
case *dst.InterfaceType:
|
||||
allMethods = append(allMethods, interfaceMethods(fieldType)...)
|
||||
case *dst.Ident:
|
||||
// Embedded interfaces are Idents -> TypeSpec -> InterfaceType
|
||||
// If the embedded interface is not in the parsed file, then
|
||||
// the Obj will be nil.
|
||||
if fieldType.Obj != nil {
|
||||
objDecl, ok := fieldType.Obj.Decl.(*dst.TypeSpec)
|
||||
if ok {
|
||||
isInterface, ok := objDecl.Type.(*dst.InterfaceType)
|
||||
if ok {
|
||||
allMethods = append(allMethods, interfaceMethods(isInterface)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return allMethods
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue