mirror of https://github.com/coder/coder.git
fix(enterprise/dbcrypt): do not skip deleted users when encrypting or deleting (#9694)
- Broadens scope of data generation in TestServerDBCrypt over all user login types, statuses, and deletion status. - Adds support for specifying user status / user deletion status in dbgen - Adds more comprehensive logging in TestServerDBCrypt upon test failure (to be generalized and expanded upon in a follow-up) - Adds AllUserIDs query, updates dbcrypt to use this instead of GetUsers.
This commit is contained in:
parent
bc97eaa41b
commit
72dff7f188
|
@ -664,6 +664,15 @@ func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg uuid.UUID) erro
|
|||
return update(q.log, q.auth, fetch, q.db.ActivityBumpWorkspace)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) {
|
||||
// Although this technically only reads users, only system-related functions should be
|
||||
// allowed to call this.
|
||||
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.AllUserIDs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) CleanTailnetCoordinators(ctx context.Context) error {
|
||||
if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil {
|
||||
return err
|
||||
|
|
|
@ -812,6 +812,16 @@ func (q *FakeQuerier) ActivityBumpWorkspace(ctx context.Context, workspaceID uui
|
|||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) AllUserIDs(_ context.Context) ([]uuid.UUID, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
userIDs := make([]uuid.UUID, 0, len(q.users))
|
||||
for idx := range q.users {
|
||||
userIDs[idx] = q.users[idx].ID
|
||||
}
|
||||
return userIDs, nil
|
||||
}
|
||||
|
||||
func (*FakeQuerier) CleanTailnetCoordinators(_ context.Context) error {
|
||||
return ErrUnimplemented
|
||||
}
|
||||
|
|
|
@ -227,7 +227,7 @@ func User(t testing.TB, db database.Store, orig database.User) database.User {
|
|||
|
||||
user, err = db.UpdateUserStatus(genCtx, database.UpdateUserStatusParams{
|
||||
ID: user.ID,
|
||||
Status: database.UserStatusActive,
|
||||
Status: takeFirst(orig.Status, database.UserStatusActive),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
})
|
||||
require.NoError(t, err, "insert user")
|
||||
|
@ -240,6 +240,14 @@ func User(t testing.TB, db database.Store, orig database.User) database.User {
|
|||
})
|
||||
require.NoError(t, err, "user last seen")
|
||||
}
|
||||
|
||||
if orig.Deleted {
|
||||
err = db.UpdateUserDeletedByID(genCtx, database.UpdateUserDeletedByIDParams{
|
||||
ID: user.ID,
|
||||
Deleted: orig.Deleted,
|
||||
})
|
||||
require.NoError(t, err, "set user as deleted")
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
|
|
|
@ -100,6 +100,13 @@ func (m metricsStore) ActivityBumpWorkspace(ctx context.Context, arg uuid.UUID)
|
|||
return r0
|
||||
}
|
||||
|
||||
func (m metricsStore) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.AllUserIDs(ctx)
|
||||
m.queryLatencies.WithLabelValues("AllUserIDs").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m metricsStore) CleanTailnetCoordinators(ctx context.Context) error {
|
||||
start := time.Now()
|
||||
err := m.s.CleanTailnetCoordinators(ctx)
|
||||
|
|
|
@ -82,6 +82,21 @@ func (mr *MockStoreMockRecorder) ActivityBumpWorkspace(arg0, arg1 interface{}) *
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActivityBumpWorkspace", reflect.TypeOf((*MockStore)(nil).ActivityBumpWorkspace), arg0, arg1)
|
||||
}
|
||||
|
||||
// AllUserIDs mocks base method.
|
||||
func (m *MockStore) AllUserIDs(arg0 context.Context) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AllUserIDs", arg0)
|
||||
ret0, _ := ret[0].([]uuid.UUID)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AllUserIDs indicates an expected call of AllUserIDs.
|
||||
func (mr *MockStoreMockRecorder) AllUserIDs(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllUserIDs", reflect.TypeOf((*MockStore)(nil).AllUserIDs), arg0)
|
||||
}
|
||||
|
||||
// CleanTailnetCoordinators mocks base method.
|
||||
func (m *MockStore) CleanTailnetCoordinators(arg0 context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -31,6 +31,8 @@ type sqlcQuerier interface {
|
|||
// We only bump if workspace shutdown is manual.
|
||||
// We only bump when 5% of the deadline has elapsed.
|
||||
ActivityBumpWorkspace(ctx context.Context, workspaceID uuid.UUID) error
|
||||
// AllUserIDs returns all UserIDs regardless of user status or deletion.
|
||||
AllUserIDs(ctx context.Context) ([]uuid.UUID, error)
|
||||
CleanTailnetCoordinators(ctx context.Context) error
|
||||
DeleteAPIKeyByID(ctx context.Context, id string) error
|
||||
DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
|
||||
|
|
|
@ -5846,6 +5846,34 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke
|
|||
return i, err
|
||||
}
|
||||
|
||||
const allUserIDs = `-- name: AllUserIDs :many
|
||||
SELECT DISTINCT id FROM USERS
|
||||
`
|
||||
|
||||
// AllUserIDs returns all UserIDs regardless of user status or deletion.
|
||||
func (q *sqlQuerier) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) {
|
||||
rows, err := q.db.QueryContext(ctx, allUserIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []uuid.UUID
|
||||
for rows.Next() {
|
||||
var id uuid.UUID
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, id)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getActiveUserCount = `-- name: GetActiveUserCount :one
|
||||
SELECT
|
||||
COUNT(*)
|
||||
|
|
|
@ -16,3 +16,4 @@ AND
|
|||
INSERT INTO dbcrypt_keys
|
||||
(number, active_key_digest, created_at, test)
|
||||
VALUES (@number::int, @active_key_digest::text, CURRENT_TIMESTAMP, @test::text);
|
||||
|
||||
|
|
|
@ -262,3 +262,8 @@ WHERE
|
|||
last_seen_at < @last_seen_after :: timestamp
|
||||
AND status = 'active'::user_status
|
||||
RETURNING id, email, last_seen_at;
|
||||
|
||||
-- AllUserIDs returns all UserIDs regardless of user status or deletion.
|
||||
-- name: AllUserIDs :many
|
||||
SELECT DISTINCT id FROM USERS;
|
||||
|
||||
|
|
|
@ -20,6 +20,9 @@ import (
|
|||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
)
|
||||
|
||||
// TestServerDBCrypt tests end-to-end encryption, decryption, and deletion
|
||||
// of encrypted user data.
|
||||
//
|
||||
// nolint: paralleltest // use of t.Setenv
|
||||
func TestServerDBCrypt(t *testing.T) {
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
|
@ -41,15 +44,38 @@ func TestServerDBCrypt(t *testing.T) {
|
|||
})
|
||||
db := database.New(sqlDB)
|
||||
|
||||
// Populate the database with some unencrypted data.
|
||||
users := genData(t, db, 10)
|
||||
t.Cleanup(func() {
|
||||
if t.Failed() {
|
||||
t.Logf("Dumping data due to failed test. I hope you find what you're looking for!")
|
||||
dumpUsers(t, sqlDB)
|
||||
}
|
||||
})
|
||||
|
||||
// Setup an initial cipher
|
||||
// Populate the database with some unencrypted data.
|
||||
t.Logf("Generating unencrypted data")
|
||||
users := genData(t, db)
|
||||
|
||||
// Setup an initial cipher A
|
||||
keyA := mustString(t, 32)
|
||||
cipherA, err := dbcrypt.NewCiphers([]byte(keyA))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an encrypted database
|
||||
cryptdb, err := dbcrypt.New(ctx, db, cipherA...)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Populate the database with some encrypted data using cipher A.
|
||||
t.Logf("Generating data encrypted with cipher A")
|
||||
newUsers := genData(t, cryptdb)
|
||||
|
||||
// Validate that newly created users were encrypted with cipher A
|
||||
for _, usr := range newUsers {
|
||||
requireEncryptedWithCipher(ctx, t, db, cipherA[0], usr.ID)
|
||||
}
|
||||
users = append(users, newUsers...)
|
||||
|
||||
// Encrypt all the data with the initial cipher.
|
||||
t.Logf("Encrypting all data with cipher A")
|
||||
inv, _ := newCLI(t, "server", "dbcrypt", "rotate",
|
||||
"--postgres-url", connectionURL,
|
||||
"--new-key", base64.StdEncoding.EncodeToString([]byte(keyA)),
|
||||
|
@ -65,18 +91,12 @@ func TestServerDBCrypt(t *testing.T) {
|
|||
requireEncryptedWithCipher(ctx, t, db, cipherA[0], usr.ID)
|
||||
}
|
||||
|
||||
// Create an encrypted database
|
||||
cryptdb, err := dbcrypt.New(ctx, db, cipherA...)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Populate the database with some encrypted data using cipher A.
|
||||
users = append(users, genData(t, cryptdb, 10)...)
|
||||
|
||||
// Re-encrypt all existing data with a new cipher.
|
||||
keyB := mustString(t, 32)
|
||||
cipherBA, err := dbcrypt.NewCiphers([]byte(keyB), []byte(keyA))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Enrypting all data with cipher B")
|
||||
inv, _ = newCLI(t, "server", "dbcrypt", "rotate",
|
||||
"--postgres-url", connectionURL,
|
||||
"--new-key", base64.StdEncoding.EncodeToString([]byte(keyB)),
|
||||
|
@ -94,6 +114,7 @@ func TestServerDBCrypt(t *testing.T) {
|
|||
}
|
||||
|
||||
// Assert that we can revoke the old key.
|
||||
t.Logf("Revoking cipher A")
|
||||
err = db.RevokeDBCryptKey(ctx, cipherA[0].HexDigest())
|
||||
require.NoError(t, err, "failed to revoke old key")
|
||||
|
||||
|
@ -109,6 +130,7 @@ func TestServerDBCrypt(t *testing.T) {
|
|||
require.Empty(t, oldKey.ActiveKeyDigest.String, "expected the old key to not be active")
|
||||
|
||||
// Revoking the new key should fail.
|
||||
t.Logf("Attempting to revoke cipher B should fail as it is still in use")
|
||||
err = db.RevokeDBCryptKey(ctx, cipherBA[0].HexDigest())
|
||||
require.Error(t, err, "expected to fail to revoke the new key")
|
||||
var pgErr *pq.Error
|
||||
|
@ -116,6 +138,7 @@ func TestServerDBCrypt(t *testing.T) {
|
|||
require.EqualValues(t, "23503", pgErr.Code, "expected a foreign key constraint violation error")
|
||||
|
||||
// Decrypt the data using only cipher B. This should result in the key being revoked.
|
||||
t.Logf("Decrypting with cipher B")
|
||||
inv, _ = newCLI(t, "server", "dbcrypt", "decrypt",
|
||||
"--postgres-url", connectionURL,
|
||||
"--keys", base64.StdEncoding.EncodeToString([]byte(keyB)),
|
||||
|
@ -144,6 +167,7 @@ func TestServerDBCrypt(t *testing.T) {
|
|||
cipherC, err := dbcrypt.NewCiphers([]byte(keyC))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Logf("Re-encrypting with cipher C")
|
||||
inv, _ = newCLI(t, "server", "dbcrypt", "rotate",
|
||||
"--postgres-url", connectionURL,
|
||||
"--new-key", base64.StdEncoding.EncodeToString([]byte(keyC)),
|
||||
|
@ -161,6 +185,7 @@ func TestServerDBCrypt(t *testing.T) {
|
|||
}
|
||||
|
||||
// Now delete all the encrypted data.
|
||||
t.Logf("Deleting all encrypted data")
|
||||
inv, _ = newCLI(t, "server", "dbcrypt", "delete",
|
||||
"--postgres-url", connectionURL,
|
||||
"--external-token-encryption-keys", base64.StdEncoding.EncodeToString([]byte(keyC)),
|
||||
|
@ -191,30 +216,84 @@ func TestServerDBCrypt(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func genData(t *testing.T, db database.Store, n int) []database.User {
|
||||
func genData(t *testing.T, db database.Store) []database.User {
|
||||
t.Helper()
|
||||
var users []database.User
|
||||
for i := 0; i < n; i++ {
|
||||
usr := dbgen.User(t, db, database.User{
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
})
|
||||
_ = dbgen.UserLink(t, db, database.UserLink{
|
||||
UserID: usr.ID,
|
||||
LoginType: usr.LoginType,
|
||||
OAuthAccessToken: "access-" + usr.ID.String(),
|
||||
OAuthRefreshToken: "refresh-" + usr.ID.String(),
|
||||
})
|
||||
_ = dbgen.GitAuthLink(t, db, database.GitAuthLink{
|
||||
UserID: usr.ID,
|
||||
ProviderID: "fake",
|
||||
OAuthAccessToken: "access-" + usr.ID.String(),
|
||||
OAuthRefreshToken: "refresh-" + usr.ID.String(),
|
||||
})
|
||||
users = append(users, usr)
|
||||
// Make some users
|
||||
for _, status := range database.AllUserStatusValues() {
|
||||
for _, loginType := range database.AllLoginTypeValues() {
|
||||
for _, deleted := range []bool{false, true} {
|
||||
usr := dbgen.User(t, db, database.User{
|
||||
LoginType: loginType,
|
||||
Status: status,
|
||||
Deleted: deleted,
|
||||
})
|
||||
_ = dbgen.GitAuthLink(t, db, database.GitAuthLink{
|
||||
UserID: usr.ID,
|
||||
ProviderID: "fake",
|
||||
OAuthAccessToken: "access-" + usr.ID.String(),
|
||||
OAuthRefreshToken: "refresh-" + usr.ID.String(),
|
||||
})
|
||||
// Fun fact: our schema allows _all_ login types to have
|
||||
// a user_link. Even though I'm not sure how it could occur
|
||||
// in practice, making sure to test all combinations here.
|
||||
_ = dbgen.UserLink(t, db, database.UserLink{
|
||||
UserID: usr.ID,
|
||||
LoginType: usr.LoginType,
|
||||
OAuthAccessToken: "access-" + usr.ID.String(),
|
||||
OAuthRefreshToken: "refresh-" + usr.ID.String(),
|
||||
})
|
||||
users = append(users, usr)
|
||||
}
|
||||
}
|
||||
}
|
||||
return users
|
||||
}
|
||||
|
||||
func dumpUsers(t *testing.T, db *sql.DB) {
|
||||
t.Helper()
|
||||
rows, err := db.QueryContext(context.Background(), `SELECT
|
||||
u.id,
|
||||
u.login_type,
|
||||
u.status,
|
||||
u.deleted,
|
||||
ul.oauth_access_token_key_id AS uloatkid,
|
||||
ul.oauth_refresh_token_key_id AS ulortkid,
|
||||
gal.oauth_access_token_key_id AS galoatkid,
|
||||
gal.oauth_refresh_token_key_id AS galortkid
|
||||
FROM users u
|
||||
LEFT OUTER JOIN user_links ul ON u.id = ul.user_id
|
||||
LEFT OUTER JOIN git_auth_links gal ON u.id = gal.user_id
|
||||
ORDER BY u.created_at ASC;`)
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var (
|
||||
id string
|
||||
loginType string
|
||||
status string
|
||||
deleted bool
|
||||
UlOatKid sql.NullString
|
||||
UlOrtKid sql.NullString
|
||||
GalOatKid sql.NullString
|
||||
GalOrtKid sql.NullString
|
||||
)
|
||||
require.NoError(t, rows.Scan(
|
||||
&id,
|
||||
&loginType,
|
||||
&status,
|
||||
&deleted,
|
||||
&UlOatKid,
|
||||
&UlOrtKid,
|
||||
&GalOatKid,
|
||||
&GalOrtKid,
|
||||
))
|
||||
t.Logf("user: id:%s login_type:%-8s status:%-9s deleted:%-5t ul_kids{at:%-7s rt:%-7s} gal_kids{at:%-7s rt:%-7s}",
|
||||
id, loginType, status, deleted, UlOatKid.String, UlOrtKid.String, GalOatKid.String, GalOrtKid.String,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func mustString(t *testing.T, n int) string {
|
||||
t.Helper()
|
||||
s, err := cryptorand.String(n)
|
||||
|
|
|
@ -19,45 +19,45 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
|
|||
return xerrors.Errorf("create cryptdb: %w", err)
|
||||
}
|
||||
|
||||
users, err := cryptDB.GetUsers(ctx, database.GetUsersParams{})
|
||||
userIDs, err := db.AllUserIDs(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get users: %w", err)
|
||||
}
|
||||
log.Info(ctx, "encrypting user tokens", slog.F("user_count", len(users)))
|
||||
for idx, usr := range users {
|
||||
log.Info(ctx, "encrypting user tokens", slog.F("user_count", len(userIDs)))
|
||||
for idx, uid := range userIDs {
|
||||
err := cryptDB.InTx(func(tx database.Store) error {
|
||||
userLinks, err := tx.GetUserLinksByUserID(ctx, usr.ID)
|
||||
userLinks, err := tx.GetUserLinksByUserID(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user links for user: %w", err)
|
||||
}
|
||||
for _, userLink := range userLinks {
|
||||
if userLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && userLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() {
|
||||
log.Debug(ctx, "skipping user link", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
log.Debug(ctx, "skipping user link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
continue
|
||||
}
|
||||
if _, err := tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{
|
||||
OAuthAccessToken: userLink.OAuthAccessToken,
|
||||
OAuthRefreshToken: userLink.OAuthRefreshToken,
|
||||
OAuthExpiry: userLink.OAuthExpiry,
|
||||
UserID: usr.ID,
|
||||
LoginType: usr.LoginType,
|
||||
UserID: uid,
|
||||
LoginType: userLink.LoginType,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err)
|
||||
}
|
||||
}
|
||||
|
||||
gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, usr.ID)
|
||||
gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get git auth links for user: %w", err)
|
||||
}
|
||||
for _, gitAuthLink := range gitAuthLinks {
|
||||
if gitAuthLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && gitAuthLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() {
|
||||
log.Debug(ctx, "skipping git auth link", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
continue
|
||||
}
|
||||
if _, err := tx.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
|
||||
ProviderID: gitAuthLink.ProviderID,
|
||||
UserID: usr.ID,
|
||||
UserID: uid,
|
||||
UpdatedAt: gitAuthLink.UpdatedAt,
|
||||
OAuthAccessToken: gitAuthLink.OAuthAccessToken,
|
||||
OAuthRefreshToken: gitAuthLink.OAuthRefreshToken,
|
||||
|
@ -73,7 +73,7 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
|
|||
if err != nil {
|
||||
return xerrors.Errorf("update user links: %w", err)
|
||||
}
|
||||
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
// Revoke old keys
|
||||
|
@ -103,45 +103,45 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
|
|||
}
|
||||
cryptDB.primaryCipherDigest = ""
|
||||
|
||||
users, err := cryptDB.GetUsers(ctx, database.GetUsersParams{})
|
||||
userIDs, err := db.AllUserIDs(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get users: %w", err)
|
||||
}
|
||||
log.Info(ctx, "decrypting user tokens", slog.F("user_count", len(users)))
|
||||
for idx, usr := range users {
|
||||
log.Info(ctx, "decrypting user tokens", slog.F("user_count", len(userIDs)))
|
||||
for idx, uid := range userIDs {
|
||||
err := cryptDB.InTx(func(tx database.Store) error {
|
||||
userLinks, err := tx.GetUserLinksByUserID(ctx, usr.ID)
|
||||
userLinks, err := tx.GetUserLinksByUserID(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user links for user: %w", err)
|
||||
}
|
||||
for _, userLink := range userLinks {
|
||||
if !userLink.OAuthAccessTokenKeyID.Valid && !userLink.OAuthRefreshTokenKeyID.Valid {
|
||||
log.Debug(ctx, "skipping user link", slog.F("user_id", usr.ID), slog.F("current", idx+1))
|
||||
log.Debug(ctx, "skipping user link", slog.F("user_id", uid), slog.F("current", idx+1))
|
||||
continue
|
||||
}
|
||||
if _, err := tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{
|
||||
OAuthAccessToken: userLink.OAuthAccessToken,
|
||||
OAuthRefreshToken: userLink.OAuthRefreshToken,
|
||||
OAuthExpiry: userLink.OAuthExpiry,
|
||||
UserID: usr.ID,
|
||||
LoginType: usr.LoginType,
|
||||
UserID: uid,
|
||||
LoginType: userLink.LoginType,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err)
|
||||
}
|
||||
}
|
||||
|
||||
gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, usr.ID)
|
||||
gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, uid)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get git auth links for user: %w", err)
|
||||
}
|
||||
for _, gitAuthLink := range gitAuthLinks {
|
||||
if !gitAuthLink.OAuthAccessTokenKeyID.Valid && !gitAuthLink.OAuthRefreshTokenKeyID.Valid {
|
||||
log.Debug(ctx, "skipping git auth link", slog.F("user_id", usr.ID), slog.F("current", idx+1))
|
||||
log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1))
|
||||
continue
|
||||
}
|
||||
if _, err := tx.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
|
||||
ProviderID: gitAuthLink.ProviderID,
|
||||
UserID: usr.ID,
|
||||
UserID: uid,
|
||||
UpdatedAt: gitAuthLink.UpdatedAt,
|
||||
OAuthAccessToken: gitAuthLink.OAuthAccessToken,
|
||||
OAuthRefreshToken: gitAuthLink.OAuthRefreshToken,
|
||||
|
@ -157,7 +157,7 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
|
|||
if err != nil {
|
||||
return xerrors.Errorf("update user links: %w", err)
|
||||
}
|
||||
log.Debug(ctx, "decrypted user tokens", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
|
||||
}
|
||||
|
||||
// Revoke _all_ keys
|
||||
|
|
Loading…
Reference in New Issue