chore: fix deadlock in dbfake and incorrect lock types (#7218)

I manually went through every single dbfake function and ensured it has
the correct lock type depending on whether it writes or only reads.
There were a surprising amount of methods that had the wrong lock type
(Lock when only reading, or RLock when writing (!!!)).

This also manually fixes every method that acquires a RLock and then
calls a method that also acquires it's own RLock to use noLock methods
instead. You cannot rely on acquiring a RLock twice in the same
goroutine as RWMutex prioritizes any waiting Lock calls.

I tried writing a ruleguard rule for this but because of limitations in
ruleguard it doesn't seem possible.
This commit is contained in:
Dean Sheather 2023-04-20 04:53:34 -07:00 committed by GitHub
parent 5f5edb18b0
commit 528a0686c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 47 deletions

View File

@ -332,7 +332,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.C
}
// Get resources for build.
resources, err := q.GetWorkspaceResourcesByJobID(ctx, workspaceBuild.JobID)
resources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, workspaceBuild.JobID)
if err != nil {
return nil, xerrors.Errorf("get workspace resources: %w", err)
}
@ -345,7 +345,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.C
resourceIDs[i] = resource.ID
}
agents, err := q.GetWorkspaceAgentsByResourceIDs(ctx, resourceIDs)
agents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs)
if err != nil {
return nil, xerrors.Errorf("get workspace agents: %w", err)
}
@ -435,8 +435,8 @@ func (q *fakeQuerier) InsertWorkspaceAgentStat(_ context.Context, p database.Ins
}
func (q *fakeQuerier) GetTemplateDAUs(_ context.Context, templateID uuid.UUID) ([]database.GetTemplateDAUsRow, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
q.mutex.RLock()
defer q.mutex.RUnlock()
seens := make(map[time.Time]map[uuid.UUID]struct{})
@ -478,8 +478,8 @@ func (q *fakeQuerier) GetTemplateDAUs(_ context.Context, templateID uuid.UUID) (
}
func (q *fakeQuerier) GetDeploymentDAUs(_ context.Context) ([]database.GetDeploymentDAUsRow, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
q.mutex.RLock()
defer q.mutex.RUnlock()
seens := make(map[time.Time]map[uuid.UUID]struct{})
@ -571,8 +571,8 @@ func (q *fakeQuerier) GetTemplateAverageBuildTime(ctx context.Context, arg datab
}
func (q *fakeQuerier) ParameterValue(_ context.Context, id uuid.UUID) (database.ParameterValue, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
q.mutex.RLock()
defer q.mutex.RUnlock()
for _, parameterValue := range q.parameterValues {
if parameterValue.ID != id {
@ -1181,7 +1181,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
return nil, xerrors.Errorf("get latest build: %w", err)
}
job, err := q.GetProvisionerJobByID(ctx, build.JobID)
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
if err != nil {
return nil, xerrors.Errorf("get provisioner job: %w", err)
}
@ -1270,12 +1270,12 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
return nil, xerrors.Errorf("get latest build: %w", err)
}
job, err := q.GetProvisionerJobByID(ctx, build.JobID)
job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID)
if err != nil {
return nil, xerrors.Errorf("get provisioner job: %w", err)
}
workspaceResources, err := q.GetWorkspaceResourcesByJobID(ctx, job.ID)
workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID)
if err != nil {
return nil, xerrors.Errorf("get workspace resources: %w", err)
}
@ -1285,7 +1285,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
workspaceResourceIDs = append(workspaceResourceIDs, wr.ID)
}
workspaceAgents, err := q.GetWorkspaceAgentsByResourceIDs(ctx, workspaceResourceIDs)
workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs)
if err != nil {
return nil, xerrors.Errorf("get workspace agents: %w", err)
}
@ -1395,10 +1395,14 @@ func convertToWorkspaceRows(workspaces []database.Workspace, count int64) []data
return rows
}
func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (database.Workspace, error) {
func (q *fakeQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (database.Workspace, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
return q.getWorkspaceByIDNoLock(ctx, id)
}
func (q *fakeQuerier) getWorkspaceByIDNoLock(_ context.Context, id uuid.UUID) (database.Workspace, error) {
for _, workspace := range q.workspaces {
if workspace.ID == id {
return workspace, nil
@ -1407,10 +1411,14 @@ func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (databas
return database.Workspace{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetWorkspaceByAgentID(_ context.Context, agentID uuid.UUID) (database.Workspace, error) {
func (q *fakeQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
return q.getWorkspaceByAgentIDNoLock(ctx, agentID)
}
func (q *fakeQuerier) getWorkspaceByAgentIDNoLock(_ context.Context, agentID uuid.UUID) (database.Workspace, error) {
var agent database.WorkspaceAgent
for _, _agent := range q.workspaceAgents {
if _agent.ID == agentID {
@ -1496,7 +1504,7 @@ func (q *fakeQuerier) GetWorkspaceByWorkspaceAppID(_ context.Context, workspaceA
for _, workspaceApp := range q.workspaceApps {
workspaceApp := workspaceApp
if workspaceApp.ID == workspaceAppID {
return q.GetWorkspaceByAgentID(context.Background(), workspaceApp.AgentID)
return q.getWorkspaceByAgentIDNoLock(context.Background(), workspaceApp.AgentID)
}
}
return database.Workspace{}, sql.ErrNoRows
@ -1547,10 +1555,14 @@ func (q *fakeQuerier) GetWorkspaceAppsByAgentIDs(_ context.Context, ids []uuid.U
return apps, nil
}
func (q *fakeQuerier) GetWorkspaceBuildByID(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) {
func (q *fakeQuerier) GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (database.WorkspaceBuild, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
return q.getWorkspaceBuildByIDNoLock(ctx, id)
}
func (q *fakeQuerier) getWorkspaceBuildByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) {
for _, history := range q.workspaceBuilds {
if history.ID == id {
return history, nil
@ -2359,7 +2371,7 @@ func (q *fakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]
groups := make([]database.TemplateGroup, 0, len(template.GroupACL))
for k, v := range template.GroupACL {
group, err := q.GetGroupByID(context.Background(), uuid.MustParse(k))
group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k))
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get group by ID: %w", err)
}
@ -2490,10 +2502,14 @@ func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken
return database.WorkspaceAgent{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetWorkspaceAgentByID(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {
func (q *fakeQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
return q.getWorkspaceAgentByIDNoLock(ctx, id)
}
func (q *fakeQuerier) getWorkspaceAgentByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {
// The schema sorts this by created at, so we iterate the array backwards.
for i := len(q.workspaceAgents) - 1; i >= 0; i-- {
agent := q.workspaceAgents[i]
@ -2518,10 +2534,14 @@ func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceI
return database.WorkspaceAgent{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) {
func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
return q.getWorkspaceAgentsByResourceIDsNoLock(ctx, resourceIDs)
}
func (q *fakeQuerier) getWorkspaceAgentsByResourceIDsNoLock(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) {
workspaceAgents := make([]database.WorkspaceAgent, 0)
for _, agent := range q.workspaceAgents {
for _, resourceID := range resourceIDs {
@ -2596,10 +2616,14 @@ func (q *fakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID)
return database.WorkspaceResource{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetWorkspaceResourcesByJobID(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) {
func (q *fakeQuerier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
return q.getWorkspaceResourcesByJobIDNoLock(ctx, jobID)
}
func (q *fakeQuerier) getWorkspaceResourcesByJobIDNoLock(_ context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) {
resources := make([]database.WorkspaceResource, 0)
for _, resource := range q.workspaceResources {
if resource.JobID != jobID {
@ -3674,8 +3698,8 @@ func (q *fakeQuerier) GetWorkspaceAgentStartupLogsAfter(_ context.Context, arg d
return nil, err
}
q.mutex.Lock()
defer q.mutex.Unlock()
q.mutex.RLock()
defer q.mutex.RUnlock()
logs := []database.WorkspaceAgentStartupLog{}
for _, log := range q.workspaceAgentLogs {
@ -4051,13 +4075,13 @@ func (q *fakeQuerier) GetWorkspaceAgentStatsAndLabels(ctx context.Context, creat
stat.Username = user.Username
workspace, err := q.GetWorkspaceByID(ctx, agentStat.WorkspaceID)
workspace, err := q.getWorkspaceByIDNoLock(ctx, agentStat.WorkspaceID)
if err != nil {
return nil, err
}
stat.WorkspaceName = workspace.Name
agent, err := q.GetWorkspaceAgentByID(ctx, agentStat.AgentID)
agent, err := q.getWorkspaceAgentByIDNoLock(ctx, agentStat.AgentID)
if err != nil {
return nil, err
}
@ -4403,7 +4427,7 @@ func (q *fakeQuerier) GetAuditLogsOffset(_ context.Context, arg database.GetAudi
}
}
if arg.BuildReason != "" {
workspaceBuild, err := q.GetWorkspaceBuildByID(context.Background(), alog.ResourceID)
workspaceBuild, err := q.getWorkspaceBuildByIDNoLock(context.Background(), alog.ResourceID)
if err == nil && !strings.EqualFold(arg.BuildReason, string(workspaceBuild.Reason)) {
continue
}
@ -4497,8 +4521,8 @@ func (q *fakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) {
}
func (q *fakeQuerier) UpsertLastUpdateCheck(_ context.Context, data string) error {
q.mutex.RLock()
defer q.mutex.RUnlock()
q.mutex.Lock()
defer q.mutex.Unlock()
q.lastUpdateCheck = []byte(data)
return nil
@ -4672,8 +4696,8 @@ func (q *fakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat
}
func (q *fakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
q.mutex.Lock()
defer q.mutex.Unlock()
//nolint:gosimple
link := database.UserLink{
@ -4695,8 +4719,8 @@ func (q *fakeQuerier) UpdateUserLinkedID(_ context.Context, params database.Upda
return database.UserLink{}, err
}
q.mutex.RLock()
defer q.mutex.RUnlock()
q.mutex.Lock()
defer q.mutex.Unlock()
for i, link := range q.userLinks {
if link.UserID == params.UserID && link.LoginType == params.LoginType {
@ -4715,8 +4739,8 @@ func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
return database.UserLink{}, err
}
q.mutex.RLock()
defer q.mutex.RUnlock()
q.mutex.Lock()
defer q.mutex.Unlock()
for i, link := range q.userLinks {
if link.UserID == params.UserID && link.LoginType == params.LoginType {
@ -4732,10 +4756,14 @@ func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
return database.UserLink{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetGroupByID(_ context.Context, id uuid.UUID) (database.Group, error) {
func (q *fakeQuerier) GetGroupByID(ctx context.Context, id uuid.UUID) (database.Group, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
return q.getGroupByIDNoLock(ctx, id)
}
func (q *fakeQuerier) getGroupByIDNoLock(_ context.Context, id uuid.UUID) (database.Group, error) {
for _, group := range q.groups {
if group.ID == id {
return group, nil
@ -4776,8 +4804,8 @@ func (q *fakeQuerier) InsertGroup(_ context.Context, arg database.InsertGroupPar
return database.Group{}, err
}
q.mutex.RLock()
defer q.mutex.RUnlock()
q.mutex.Lock()
defer q.mutex.Unlock()
for _, group := range q.groups {
if group.OrganizationID == arg.OrganizationID &&
@ -4995,8 +5023,9 @@ func (q *fakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGi
}
func (q *fakeQuerier) GetQuotaAllowanceForUser(_ context.Context, userID uuid.UUID) (int64, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
q.mutex.RLock()
defer q.mutex.RUnlock()
var sum int64
for _, member := range q.groupMembers {
if member.UserID != userID {
@ -5012,8 +5041,9 @@ func (q *fakeQuerier) GetQuotaAllowanceForUser(_ context.Context, userID uuid.UU
}
func (q *fakeQuerier) GetQuotaConsumedForUser(_ context.Context, userID uuid.UUID) (int64, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
q.mutex.RLock()
defer q.mutex.RUnlock()
var sum int64
for _, workspace := range q.workspaces {
if workspace.OwnerID != userID {
@ -5072,8 +5102,8 @@ func (q *fakeQuerier) UpdateWorkspaceAgentStartupLogOverflowByID(_ context.Conte
}
func (q *fakeQuerier) GetWorkspaceProxies(_ context.Context) ([]database.WorkspaceProxy, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
q.mutex.RLock()
defer q.mutex.RUnlock()
cpy := make([]database.WorkspaceProxy, 0, len(q.workspaceProxies))
@ -5086,8 +5116,8 @@ func (q *fakeQuerier) GetWorkspaceProxies(_ context.Context) ([]database.Workspa
}
func (q *fakeQuerier) GetWorkspaceProxyByID(_ context.Context, id uuid.UUID) (database.WorkspaceProxy, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
q.mutex.RLock()
defer q.mutex.RUnlock()
for _, proxy := range q.workspaceProxies {
if proxy.ID == id {
@ -5098,8 +5128,8 @@ func (q *fakeQuerier) GetWorkspaceProxyByID(_ context.Context, id uuid.UUID) (da
}
func (q *fakeQuerier) GetWorkspaceProxyByHostname(_ context.Context, hostname string) (database.WorkspaceProxy, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
q.mutex.RLock()
defer q.mutex.RUnlock()
// Return zero rows if this is called with a non-sanitized hostname. The SQL
// version of this query does the same thing.

View File

@ -68,7 +68,7 @@ func TestFilterError(t *testing.T) {
auth := &MockAuthorizer{
AuthorizeFunc: func(ctx context.Context, subject Subject, action Action, object Object) error {
// Authorize func always returns nil, unless the context is cancelled.
// Authorize func always returns nil, unless the context is canceled.
return ctx.Err()
},
}