mirror of https://github.com/coder/coder.git
chore: fix deadlock in dbfake and incorrect lock types (#7218)
I manually went through every single dbfake function and ensured it has the correct lock type depending on whether it writes or only reads. There were a surprising amount of methods that had the wrong lock type (Lock when only reading, or RLock when writing (!!!)). This also manually fixes every method that acquires a RLock and then calls a method that also acquires it's own RLock to use noLock methods instead. You cannot rely on acquiring a RLock twice in the same goroutine as RWMutex prioritizes any waiting Lock calls. I tried writing a ruleguard rule for this but because of limitations in ruleguard it doesn't seem possible.
This commit is contained in:
parent
5f5edb18b0
commit
528a0686c0
|
@ -332,7 +332,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.C
|
|||
}
|
||||
|
||||
// Get resources for build.
|
||||
resources, err := q.GetWorkspaceResourcesByJobID(ctx, workspaceBuild.JobID)
|
||||
resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, workspaceBuild.JobID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace resources: %w", err)
|
||||
}
|
||||
|
@ -345,7 +345,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.C
|
|||
resourceIDs[i] = resource.ID
|
||||
}
|
||||
|
||||
agents, err := q.GetWorkspaceAgentsByResourceIDs(ctx, resourceIDs)
|
||||
agents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace agents: %w", err)
|
||||
}
|
||||
|
@ -435,8 +435,8 @@ func (q *fakeQuerier) InsertWorkspaceAgentStat(_ context.Context, p database.Ins
|
|||
}
|
||||
|
||||
func (q *fakeQuerier) GetTemplateDAUs(_ context.Context, templateID uuid.UUID) ([]database.GetTemplateDAUsRow, error) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
seens := make(map[time.Time]map[uuid.UUID]struct{})
|
||||
|
||||
|
@ -478,8 +478,8 @@ func (q *fakeQuerier) GetTemplateDAUs(_ context.Context, templateID uuid.UUID) (
|
|||
}
|
||||
|
||||
func (q *fakeQuerier) GetDeploymentDAUs(_ context.Context) ([]database.GetDeploymentDAUsRow, error) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
seens := make(map[time.Time]map[uuid.UUID]struct{})
|
||||
|
||||
|
@ -571,8 +571,8 @@ func (q *fakeQuerier) GetTemplateAverageBuildTime(ctx context.Context, arg datab
|
|||
}
|
||||
|
||||
func (q *fakeQuerier) ParameterValue(_ context.Context, id uuid.UUID) (database.ParameterValue, error) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for _, parameterValue := range q.parameterValues {
|
||||
if parameterValue.ID != id {
|
||||
|
@ -1181,7 +1181,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
|
|||
return nil, xerrors.Errorf("get latest build: %w", err)
|
||||
}
|
||||
|
||||
job, err := q.GetProvisionerJobByID(ctx, build.JobID)
|
||||
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get provisioner job: %w", err)
|
||||
}
|
||||
|
@ -1270,12 +1270,12 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
|
|||
return nil, xerrors.Errorf("get latest build: %w", err)
|
||||
}
|
||||
|
||||
job, err := q.GetProvisionerJobByID(ctx, build.JobID)
|
||||
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get provisioner job: %w", err)
|
||||
}
|
||||
|
||||
workspaceResources, err := q.GetWorkspaceResourcesByJobID(ctx, job.ID)
|
||||
workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace resources: %w", err)
|
||||
}
|
||||
|
@ -1285,7 +1285,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
|
|||
workspaceResourceIDs = append(workspaceResourceIDs, wr.ID)
|
||||
}
|
||||
|
||||
workspaceAgents, err := q.GetWorkspaceAgentsByResourceIDs(ctx, workspaceResourceIDs)
|
||||
workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace agents: %w", err)
|
||||
}
|
||||
|
@ -1395,10 +1395,14 @@ func convertToWorkspaceRows(workspaces []database.Workspace, count int64) []data
|
|||
return rows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (database.Workspace, error) {
|
||||
func (q *fakeQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (database.Workspace, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
return q.getWorkspaceByIDNoLock(ctx, id)
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) getWorkspaceByIDNoLock(_ context.Context, id uuid.UUID) (database.Workspace, error) {
|
||||
for _, workspace := range q.workspaces {
|
||||
if workspace.ID == id {
|
||||
return workspace, nil
|
||||
|
@ -1407,10 +1411,14 @@ func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (databas
|
|||
return database.Workspace{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaceByAgentID(_ context.Context, agentID uuid.UUID) (database.Workspace, error) {
|
||||
func (q *fakeQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
return q.getWorkspaceByAgentIDNoLock(ctx, agentID)
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) getWorkspaceByAgentIDNoLock(_ context.Context, agentID uuid.UUID) (database.Workspace, error) {
|
||||
var agent database.WorkspaceAgent
|
||||
for _, _agent := range q.workspaceAgents {
|
||||
if _agent.ID == agentID {
|
||||
|
@ -1496,7 +1504,7 @@ func (q *fakeQuerier) GetWorkspaceByWorkspaceAppID(_ context.Context, workspaceA
|
|||
for _, workspaceApp := range q.workspaceApps {
|
||||
workspaceApp := workspaceApp
|
||||
if workspaceApp.ID == workspaceAppID {
|
||||
return q.GetWorkspaceByAgentID(context.Background(), workspaceApp.AgentID)
|
||||
return q.getWorkspaceByAgentIDNoLock(context.Background(), workspaceApp.AgentID)
|
||||
}
|
||||
}
|
||||
return database.Workspace{}, sql.ErrNoRows
|
||||
|
@ -1547,10 +1555,14 @@ func (q *fakeQuerier) GetWorkspaceAppsByAgentIDs(_ context.Context, ids []uuid.U
|
|||
return apps, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaceBuildByID(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) {
|
||||
func (q *fakeQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (database.WorkspaceBuild, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
return q.getWorkspaceBuildByIDNoLock(ctx, id)
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) getWorkspaceBuildByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) {
|
||||
for _, history := range q.workspaceBuilds {
|
||||
if history.ID == id {
|
||||
return history, nil
|
||||
|
@ -2359,7 +2371,7 @@ func (q *fakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]
|
|||
|
||||
groups := make([]database.TemplateGroup, 0, len(template.GroupACL))
|
||||
for k, v := range template.GroupACL {
|
||||
group, err := q.GetGroupByID(context.Background(), uuid.MustParse(k))
|
||||
group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k))
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get group by ID: %w", err)
|
||||
}
|
||||
|
@ -2490,10 +2502,14 @@ func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken
|
|||
return database.WorkspaceAgent{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaceAgentByID(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {
|
||||
func (q *fakeQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
return q.getWorkspaceAgentByIDNoLock(ctx, id)
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) getWorkspaceAgentByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {
|
||||
// The schema sorts this by created at, so we iterate the array backwards.
|
||||
for i := len(q.workspaceAgents) - 1; i >= 0; i-- {
|
||||
agent := q.workspaceAgents[i]
|
||||
|
@ -2518,10 +2534,14 @@ func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceI
|
|||
return database.WorkspaceAgent{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) {
|
||||
func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
return q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs)
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) getWorkspaceAgentsByResourceIDsNoLock(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) {
|
||||
workspaceAgents := make([]database.WorkspaceAgent, 0)
|
||||
for _, agent := range q.workspaceAgents {
|
||||
for _, resourceID := range resourceIDs {
|
||||
|
@ -2596,10 +2616,14 @@ func (q *fakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID)
|
|||
return database.WorkspaceResource{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaceResourcesByJobID(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) {
|
||||
func (q *fakeQuerier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
return q.getWorkspaceResourcesByJobIDNoLock(ctx, jobID)
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) getWorkspaceResourcesByJobIDNoLock(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) {
|
||||
resources := make([]database.WorkspaceResource, 0)
|
||||
for _, resource := range q.workspaceResources {
|
||||
if resource.JobID != jobID {
|
||||
|
@ -3674,8 +3698,8 @@ func (q *fakeQuerier) GetWorkspaceAgentStartupLogsAfter(_ context.Context, arg d
|
|||
return nil, err
|
||||
}
|
||||
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
logs := []database.WorkspaceAgentStartupLog{}
|
||||
for _, log := range q.workspaceAgentLogs {
|
||||
|
@ -4051,13 +4075,13 @@ func (q *fakeQuerier) GetWorkspaceAgentStatsAndLabels(ctx context.Context, creat
|
|||
|
||||
stat.Username = user.Username
|
||||
|
||||
workspace, err := q.GetWorkspaceByID(ctx, agentStat.WorkspaceID)
|
||||
workspace, err := q.getWorkspaceByIDNoLock(ctx, agentStat.WorkspaceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stat.WorkspaceName = workspace.Name
|
||||
|
||||
agent, err := q.GetWorkspaceAgentByID(ctx, agentStat.AgentID)
|
||||
agent, err := q.getWorkspaceAgentByIDNoLock(ctx, agentStat.AgentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -4403,7 +4427,7 @@ func (q *fakeQuerier) GetAuditLogsOffset(_ context.Context, arg database.GetAudi
|
|||
}
|
||||
}
|
||||
if arg.BuildReason != "" {
|
||||
workspaceBuild, err := q.GetWorkspaceBuildByID(context.Background(), alog.ResourceID)
|
||||
workspaceBuild, err := q.getWorkspaceBuildByIDNoLock(context.Background(), alog.ResourceID)
|
||||
if err == nil && !strings.EqualFold(arg.BuildReason, string(workspaceBuild.Reason)) {
|
||||
continue
|
||||
}
|
||||
|
@ -4497,8 +4521,8 @@ func (q *fakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) {
|
|||
}
|
||||
|
||||
func (q *fakeQuerier) UpsertLastUpdateCheck(_ context.Context, data string) error {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
q.lastUpdateCheck = []byte(data)
|
||||
return nil
|
||||
|
@ -4672,8 +4696,8 @@ func (q *fakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat
|
|||
}
|
||||
|
||||
func (q *fakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
//nolint:gosimple
|
||||
link := database.UserLink{
|
||||
|
@ -4695,8 +4719,8 @@ func (q *fakeQuerier) UpdateUserLinkedID(_ context.Context, params database.Upda
|
|||
return database.UserLink{}, err
|
||||
}
|
||||
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
for i, link := range q.userLinks {
|
||||
if link.UserID == params.UserID && link.LoginType == params.LoginType {
|
||||
|
@ -4715,8 +4739,8 @@ func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
|
|||
return database.UserLink{}, err
|
||||
}
|
||||
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
for i, link := range q.userLinks {
|
||||
if link.UserID == params.UserID && link.LoginType == params.LoginType {
|
||||
|
@ -4732,10 +4756,14 @@ func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
|
|||
return database.UserLink{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetGroupByID(_ context.Context, id uuid.UUID) (database.Group, error) {
|
||||
func (q *fakeQuerier) GetGroupByID(ctx context.Context, id uuid.UUID) (database.Group, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
return q.getGroupByIDNoLock(ctx, id)
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) getGroupByIDNoLock(_ context.Context, id uuid.UUID) (database.Group, error) {
|
||||
for _, group := range q.groups {
|
||||
if group.ID == id {
|
||||
return group, nil
|
||||
|
@ -4776,8 +4804,8 @@ func (q *fakeQuerier) InsertGroup(_ context.Context, arg database.InsertGroupPar
|
|||
return database.Group{}, err
|
||||
}
|
||||
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
for _, group := range q.groups {
|
||||
if group.OrganizationID == arg.OrganizationID &&
|
||||
|
@ -4995,8 +5023,9 @@ func (q *fakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGi
|
|||
}
|
||||
|
||||
func (q *fakeQuerier) GetQuotaAllowanceForUser(_ context.Context, userID uuid.UUID) (int64, error) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
var sum int64
|
||||
for _, member := range q.groupMembers {
|
||||
if member.UserID != userID {
|
||||
|
@ -5012,8 +5041,9 @@ func (q *fakeQuerier) GetQuotaAllowanceForUser(_ context.Context, userID uuid.UU
|
|||
}
|
||||
|
||||
func (q *fakeQuerier) GetQuotaConsumedForUser(_ context.Context, userID uuid.UUID) (int64, error) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
var sum int64
|
||||
for _, workspace := range q.workspaces {
|
||||
if workspace.OwnerID != userID {
|
||||
|
@ -5072,8 +5102,8 @@ func (q *fakeQuerier) UpdateWorkspaceAgentStartupLogOverflowByID(_ context.Conte
|
|||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaceProxies(_ context.Context) ([]database.WorkspaceProxy, error) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
cpy := make([]database.WorkspaceProxy, 0, len(q.workspaceProxies))
|
||||
|
||||
|
@ -5086,8 +5116,8 @@ func (q *fakeQuerier) GetWorkspaceProxies(_ context.Context) ([]database.Workspa
|
|||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaceProxyByID(_ context.Context, id uuid.UUID) (database.WorkspaceProxy, error) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for _, proxy := range q.workspaceProxies {
|
||||
if proxy.ID == id {
|
||||
|
@ -5098,8 +5128,8 @@ func (q *fakeQuerier) GetWorkspaceProxyByID(_ context.Context, id uuid.UUID) (da
|
|||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaceProxyByHostname(_ context.Context, hostname string) (database.WorkspaceProxy, error) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
// Return zero rows if this is called with a non-sanitized hostname. The SQL
|
||||
// version of this query does the same thing.
|
||||
|
|
|
@ -68,7 +68,7 @@ func TestFilterError(t *testing.T) {
|
|||
|
||||
auth := &MockAuthorizer{
|
||||
AuthorizeFunc: func(ctx context.Context, subject Subject, action Action, object Object) error {
|
||||
// Authorize func always returns nil, unless the context is cancelled.
|
||||
// Authorize func always returns nil, unless the context is canceled.
|
||||
return ctx.Err()
|
||||
},
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue