chore: enable exhaustruct linter for database param structs (#9995)

This commit is contained in:
Cian Johnston 2023-10-03 09:23:45 +01:00 committed by GitHub
parent 352ec7bc4f
commit e55c25e037
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 186 additions and 116 deletions

View File

@ -10,6 +10,9 @@ linters-settings:
include:
# Gradually extend to cover more of the codebase.
- 'httpmw\.\w+'
# We want to enforce all values are specified when inserting or updating
# a database row. Ref: #9936
- 'github.com/coder/coder/v2/coderd/database\.[^G][^e][^t]\w+Params'
gocognit:
min-complexity: 300

View File

@ -76,6 +76,7 @@ func Generate(params CreateParams) (database.InsertAPIKeyParams, string, error)
return database.InsertAPIKeyParams{
ID: keyID,
UserID: params.UserID,
LastUsed: time.Time{},
LifetimeSeconds: params.LifetimeSeconds,
IPAddress: pqtype.Inet{
IPNet: net.IPNet{

View File

@ -154,6 +154,9 @@ func (api *API) generateFakeAuditLog(rw http.ResponseWriter, r *http.Request) {
Diff: diff,
StatusCode: http.StatusOK,
AdditionalFields: params.AdditionalFields,
RequestID: uuid.Nil, // no request ID to attach this to
ResourceIcon: "",
OrganizationID: uuid.New(),
})
if err != nil {
httpapi.InternalServerError(rw, err)

View File

@ -1,6 +1,7 @@
package oidctest
import (
"database/sql"
"net/http"
"testing"
"time"
@ -70,7 +71,9 @@ func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *code
// Expire the oauth link for the given user.
updated, err := db.UpdateUserLink(ctx, database.UpdateUserLinkParams{
OAuthAccessToken: link.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: time.Now().Add(time.Hour * -1),
UserID: link.UserID,
LoginType: link.LoginType,

View File

@ -4142,6 +4142,8 @@ func (q *FakeQuerier) InsertAllUsersGroup(ctx context.Context, orgID uuid.UUID)
Name: database.EveryoneGroup,
DisplayName: "",
OrganizationID: orgID,
AvatarURL: "",
QuotaAllowance: 0,
})
}

View File

@ -125,7 +125,7 @@ func APIKey(t testing.TB, db database.Store, seed database.APIKey) (key database
}
func WorkspaceAgent(t testing.TB, db database.Store, orig database.WorkspaceAgent) database.WorkspaceAgent {
workspace, err := db.InsertWorkspaceAgent(genCtx, database.InsertWorkspaceAgentParams{
agt, err := db.InsertWorkspaceAgent(genCtx, database.InsertWorkspaceAgentParams{
ID: takeFirst(orig.ID, uuid.New()),
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
@ -154,9 +154,10 @@ func WorkspaceAgent(t testing.TB, db database.Store, orig database.WorkspaceAgen
ConnectionTimeoutSeconds: takeFirst(orig.ConnectionTimeoutSeconds, 3600),
TroubleshootingURL: takeFirst(orig.TroubleshootingURL, "https://example.com"),
MOTDFile: takeFirst(orig.TroubleshootingURL, ""),
DisplayApps: append([]database.DisplayApp{}, orig.DisplayApps...),
})
require.NoError(t, err, "insert workspace agent")
return workspace
return agt
}
func Workspace(t testing.TB, db database.Store, orig database.Workspace) database.Workspace {
@ -204,6 +205,7 @@ func WorkspaceBuild(t testing.TB, db database.Store, orig database.WorkspaceBuil
JobID: takeFirst(orig.JobID, uuid.New()),
ProvisionerState: takeFirstSlice(orig.ProvisionerState, []byte{}),
Deadline: takeFirst(orig.Deadline, dbtime.Now().Add(time.Hour)),
MaxDeadline: takeFirst(orig.MaxDeadline, time.Time{}),
Reason: takeFirst(orig.Reason, database.BuildReasonInitiator),
})
if err != nil {
@ -348,6 +350,7 @@ func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig data
Type: takeFirst(orig.Type, database.ProvisionerJobTypeWorkspaceBuild),
Input: takeFirstSlice(orig.Input, []byte("{}")),
Tags: orig.Tags,
TraceMetadata: pqtype.NullRawMessage{},
})
require.NoError(t, err, "insert job")
if ps != nil {
@ -359,6 +362,7 @@ func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig data
StartedAt: orig.StartedAt,
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
Tags: must(json.Marshal(orig.Tags)),
WorkerID: uuid.NullUUID{},
})
require.NoError(t, err)
}
@ -460,6 +464,8 @@ func WorkspaceProxy(t testing.TB, db database.Store, orig database.WorkspaceProx
TokenHashedSecret: hashedSecret[:],
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
DerpEnabled: takeFirst(orig.DerpEnabled, false),
DerpOnly: takeFirst(orig.DerpEnabled, false),
})
require.NoError(t, err, "insert proxy")

View File

@ -128,7 +128,9 @@ func (api *API) postExternalAuthDeviceByID(rw http.ResponseWriter, r *http.Reque
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
OAuthAccessToken: token.AccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will set as required
OAuthRefreshToken: token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will set as required
OAuthExpiry: token.Expiry,
})
if err != nil {
@ -144,7 +146,9 @@ func (api *API) postExternalAuthDeviceByID(rw http.ResponseWriter, r *http.Reque
UserID: apiKey.UserID,
UpdatedAt: dbtime.Now(),
OAuthAccessToken: token.AccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: token.Expiry,
})
if err != nil {
@ -216,7 +220,9 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
OAuthAccessToken: state.Token.AccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will set as required
OAuthRefreshToken: state.Token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will set as required
OAuthExpiry: state.Token.Expiry,
})
if err != nil {
@ -232,7 +238,9 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht
UserID: apiKey.UserID,
UpdatedAt: dbtime.Now(),
OAuthAccessToken: state.Token.AccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: state.Token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: state.Token.Expiry,
})
if err != nil {

View File

@ -2,6 +2,7 @@ package externalauth
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
@ -120,18 +121,20 @@ validate:
}
if token.AccessToken != externalAuthLink.OAuthAccessToken {
// Update it
externalAuthLink, err = db.UpdateExternalAuthLink(ctx, database.UpdateExternalAuthLinkParams{
updatedAuthLink, err := db.UpdateExternalAuthLink(ctx, database.UpdateExternalAuthLinkParams{
ProviderID: c.ID,
UserID: externalAuthLink.UserID,
UpdatedAt: dbtime.Now(),
OAuthAccessToken: token.AccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: token.Expiry,
})
if err != nil {
return externalAuthLink, false, xerrors.Errorf("update external auth link: %w", err)
return updatedAuthLink, false, xerrors.Errorf("update external auth link: %w", err)
}
externalAuthLink = updatedAuthLink
}
return externalAuthLink, true, nil
}

View File

@ -369,12 +369,14 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
// If the API Key is associated with a user_link (e.g. Github/OIDC)
// then we want to update the relevant oauth fields.
if link.UserID != uuid.Nil {
// nolint:gocritic
//nolint:gocritic // system needs to update user link
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
UserID: link.UserID,
LoginType: link.LoginType,
OAuthAccessToken: link.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: link.OAuthExpiry,
})
if err != nil {
@ -388,7 +390,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
// We only want to update this occasionally to reduce DB write
// load. We update alongside the UserLink and APIKey since it's
// easier on the DB to colocate writes.
// nolint:gocritic
//nolint:gocritic // system needs to update user last seen at
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLastSeenAtParams{
ID: key.UserID,
LastSeenAt: dbtime.Now(),
@ -405,7 +407,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
// If the key is valid, we also fetch the user roles and status.
// The roles are used for RBAC authorize checks, and the status
// is to block 'suspended' users from accessing the platform.
// nolint:gocritic
//nolint:gocritic // system needs to update user roles
roles, err := cfg.DB.GetAuthorizationUserRoles(dbauthz.AsSystemRestricted(ctx), key.UserID)
if err != nil {
return write(http.StatusUnauthorized, codersdk.Response{

View File

@ -72,6 +72,7 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
Name: req.Name,
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
Description: "",
})
if err != nil {
return xerrors.Errorf("create organization: %w", err)

View File

@ -263,18 +263,21 @@ func (s *server) AcquireJobWithCancel(stream proto.DRPCProvisionerDaemon_Acquire
logger.Error(streamCtx, "recv error and failed to cancel acquire job", slog.Error(recvErr))
// Well, this is awkward. We hit an error receiving from the stream, but didn't cancel before we locked a job
// in the database. We need to mark this job as failed so the end user can retry if they want to.
now := dbtime.Now()
err := s.Database.UpdateProvisionerJobWithCompleteByID(
context.Background(),
database.UpdateProvisionerJobWithCompleteByIDParams{
ID: je.job.ID,
CompletedAt: sql.NullTime{
Time: dbtime.Now(),
Time: now,
Valid: true,
},
UpdatedAt: now,
Error: sql.NullString{
String: "connection to provisioner daemon broken",
Valid: true,
},
ErrorCode: sql.NullString{},
})
if err != nil {
logger.Error(streamCtx, "error updating failed job", slog.Error(err))
@ -308,6 +311,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
Valid: true,
},
ErrorCode: job.ErrorCode,
UpdatedAt: dbtime.Now(),
})
if err != nil {
return xerrors.Errorf("update provisioner job: %w", err)
@ -651,6 +655,7 @@ func (s *server) UpdateJob(ctx context.Context, request *proto.UpdateJobRequest)
}
if len(request.Logs) > 0 {
//nolint:exhaustruct // We append to the additional fields below.
insertParams := database.InsertProvisionerJobLogsParams{
JobID: parsedID,
}
@ -1062,6 +1067,7 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
Valid: true,
},
Error: completedError,
ErrorCode: sql.NullString{},
})
if err != nil {
return nil, xerrors.Errorf("update provisioner job: %w", err)
@ -1118,6 +1124,8 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
Time: dbtime.Now(),
Valid: true,
},
Error: sql.NullString{},
ErrorCode: sql.NullString{},
})
if err != nil {
return xerrors.Errorf("update provisioner job: %w", err)
@ -1275,6 +1283,8 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
Time: dbtime.Now(),
Valid: true,
},
Error: sql.NullString{},
ErrorCode: sql.NullString{},
})
if err != nil {
return nil, xerrors.Errorf("update provisioner job: %w", err)
@ -1386,6 +1396,8 @@ func InsertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
TroubleshootingURL: prAgent.GetTroubleshootingUrl(),
MOTDFile: prAgent.GetMotdFile(),
DisplayApps: convertDisplayApps(prAgent.GetDisplayApps()),
InstanceMetadata: pqtype.NullRawMessage{},
ResourceMetadata: pqtype.NullRawMessage{},
})
if err != nil {
return xerrors.Errorf("insert agent: %w", err)
@ -1631,7 +1643,9 @@ func obtainOIDCAccessToken(ctx context.Context, db database.Store, oidcConfig ht
UserID: userID,
LoginType: database.LoginTypeOIDC,
OAuthAccessToken: link.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // set by dbcrypt if required
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required
OAuthExpiry: link.OAuthExpiry,
})
if err != nil {

View File

@ -274,6 +274,11 @@ func unhangJob(ctx context.Context, log slog.Logger, db database.Store, pub pubs
// Insert the messages into the build log.
insertParams := database.InsertProvisionerJobLogsParams{
JobID: job.ID,
CreatedAt: nil,
Source: nil,
Level: nil,
Stage: nil,
Output: nil,
}
now := dbtime.Now()
for i, msg := range HungJobLogMessages {

View File

@ -1327,13 +1327,15 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
}
if link.UserID == uuid.Nil {
//nolint:gocritic
//nolint:gocritic // System needs to insert the user link (linked_id, oauth_token, oauth_expiry).
link, err = tx.InsertUserLink(dbauthz.AsSystemRestricted(ctx), database.InsertUserLinkParams{
UserID: user.ID,
LoginType: params.LoginType,
LinkedID: params.LinkedID,
OAuthAccessToken: params.State.Token.AccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // set by dbcrypt if required
OAuthRefreshToken: params.State.Token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required
OAuthExpiry: params.State.Token.Expiry,
})
if err != nil {
@ -1342,12 +1344,14 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
}
if link.UserID != uuid.Nil {
//nolint:gocritic
//nolint:gocritic // System needs to update the user link (linked_id, oauth_token, oauth_expiry).
link, err = tx.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
UserID: user.ID,
LoginType: params.LoginType,
OAuthAccessToken: params.State.Token.AccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // set by dbcrypt if required
OAuthRefreshToken: params.State.Token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // set by dbcrypt if required
OAuthExpiry: params.State.Token.Expiry,
})
if err != nil {

View File

@ -1083,6 +1083,7 @@ func (api *API) CreateUser(ctx context.Context, store database.Store, req Create
Name: req.Username,
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
Description: "",
})
if err != nil {
return xerrors.Errorf("create organization: %w", err)
@ -1106,6 +1107,7 @@ func (api *API) CreateUser(ctx context.Context, store database.Store, req Create
Username: req.Username,
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
HashedPassword: []byte{},
// All new users are defaulted to members of the site.
RBACRoles: []string{},
LoginType: req.LoginType,

View File

@ -8,6 +8,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
@ -350,6 +351,8 @@ func (b *Builder) buildTx(authFunc func(action rbac.Action, object rbac.Objecter
Transition: b.trans,
JobID: provisionerJob.ID,
Reason: b.reason,
Deadline: time.Time{}, // set by provisioner upon completion
MaxDeadline: time.Time{}, // set by provisioner upon completion
})
if err != nil {
return BuildError{http.StatusInternalServerError, "insert workspace build", err}

View File

@ -25,8 +25,8 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
}
log.Info(ctx, "encrypting user tokens", slog.F("user_count", len(userIDs)))
for idx, uid := range userIDs {
err := cryptDB.InTx(func(tx database.Store) error {
userLinks, err := tx.GetUserLinksByUserID(ctx, uid)
err := cryptDB.InTx(func(cryptTx database.Store) error {
userLinks, err := cryptTx.GetUserLinksByUserID(ctx, uid)
if err != nil {
return xerrors.Errorf("get user links for user: %w", err)
}
@ -35,9 +35,11 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
log.Debug(ctx, "skipping user link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
continue
}
if _, err := tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{
if _, err := cryptTx.UpdateUserLink(ctx, database.UpdateUserLinkParams{
OAuthAccessToken: userLink.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: userLink.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: userLink.OAuthExpiry,
UserID: uid,
LoginType: userLink.LoginType,
@ -46,7 +48,7 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
}
}
gitAuthLinks, err := tx.GetExternalAuthLinksByUserID(ctx, uid)
gitAuthLinks, err := cryptTx.GetExternalAuthLinksByUserID(ctx, uid)
if err != nil {
return xerrors.Errorf("get git auth links for user: %w", err)
}
@ -55,12 +57,14 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
continue
}
if _, err := tx.UpdateExternalAuthLink(ctx, database.UpdateExternalAuthLinkParams{
if _, err := cryptTx.UpdateExternalAuthLink(ctx, database.UpdateExternalAuthLinkParams{
ProviderID: gitAuthLink.ProviderID,
UserID: uid,
UpdatedAt: gitAuthLink.UpdatedAt,
OAuthAccessToken: gitAuthLink.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: gitAuthLink.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: gitAuthLink.OAuthExpiry,
}); err != nil {
return xerrors.Errorf("update git auth link user_id=%s provider_id=%s: %w", gitAuthLink.UserID, gitAuthLink.ProviderID, err)
@ -121,7 +125,9 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
}
if _, err := tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{
OAuthAccessToken: userLink.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // we explicitly want to clear the key id
OAuthRefreshToken: userLink.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // we explicitly want to clear the key id
OAuthExpiry: userLink.OAuthExpiry,
UserID: uid,
LoginType: userLink.LoginType,
@ -144,7 +150,9 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
UserID: uid,
UpdatedAt: gitAuthLink.UpdatedAt,
OAuthAccessToken: gitAuthLink.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // we explicitly want to clear the key id
OAuthRefreshToken: gitAuthLink.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // we explicitly want to clear the key id
OAuthExpiry: gitAuthLink.OAuthExpiry,
}); err != nil {
return xerrors.Errorf("update git auth link user_id=%s provider_id=%s: %w", gitAuthLink.UserID, gitAuthLink.ProviderID, err)

View File

@ -425,6 +425,8 @@ func (m *Manager) Close() error {
Hostname: m.self.Hostname,
Version: m.self.Version,
Error: m.self.Error,
DatabaseLatency: 0, // A stopped replica has no latency.
Primary: false, // A stopped replica cannot be primary.
})
if err != nil {
return xerrors.Errorf("update replica: %w", err)