fix(enterprise/dbcrypt): do not skip deleted users when encrypting or deleting (#9694)

- Broadens scope of data generation in TestServerDBCrypt over all user login types, statuses, and deletion status.
- Adds support for specifying user status / user deletion status in dbgen
- Adds more comprehensive logging in TestServerDBCrypt upon test failure (to be generalized and expanded upon in a follow-up)
- Adds AllUserIDs query, updates dbcrypt to use this instead of GetUsers.
This commit is contained in:
Cian Johnston 2023-09-15 15:09:40 +01:00 committed by GitHub
parent bc97eaa41b
commit 72dff7f188
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 215 additions and 51 deletions

View File

@ -664,6 +664,15 @@ func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg uuid.UUID) erro
return update(q.log, q.auth, fetch, q.db.ActivityBumpWorkspace)(ctx, arg)
}
func (q *querier) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) {
// Although this technically only reads users, only system-related functions should be
// allowed to call this.
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.AllUserIDs(ctx)
}
func (q *querier) CleanTailnetCoordinators(ctx context.Context) error {
if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil {
return err

View File

@ -812,6 +812,16 @@ func (q *FakeQuerier) ActivityBumpWorkspace(ctx context.Context, workspaceID uui
return sql.ErrNoRows
}
func (q *FakeQuerier) AllUserIDs(_ context.Context) ([]uuid.UUID, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
userIDs := make([]uuid.UUID, 0, len(q.users))
for idx := range q.users {
userIDs[idx] = q.users[idx].ID
}
return userIDs, nil
}
func (*FakeQuerier) CleanTailnetCoordinators(_ context.Context) error {
return ErrUnimplemented
}

View File

@ -227,7 +227,7 @@ func User(t testing.TB, db database.Store, orig database.User) database.User {
user, err = db.UpdateUserStatus(genCtx, database.UpdateUserStatusParams{
ID: user.ID,
Status: database.UserStatusActive,
Status: takeFirst(orig.Status, database.UserStatusActive),
UpdatedAt: dbtime.Now(),
})
require.NoError(t, err, "insert user")
@ -240,6 +240,14 @@ func User(t testing.TB, db database.Store, orig database.User) database.User {
})
require.NoError(t, err, "user last seen")
}
if orig.Deleted {
err = db.UpdateUserDeletedByID(genCtx, database.UpdateUserDeletedByIDParams{
ID: user.ID,
Deleted: orig.Deleted,
})
require.NoError(t, err, "set user as deleted")
}
return user
}

View File

@ -100,6 +100,13 @@ func (m metricsStore) ActivityBumpWorkspace(ctx context.Context, arg uuid.UUID)
return r0
}
func (m metricsStore) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.AllUserIDs(ctx)
m.queryLatencies.WithLabelValues("AllUserIDs").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) CleanTailnetCoordinators(ctx context.Context) error {
start := time.Now()
err := m.s.CleanTailnetCoordinators(ctx)

View File

@ -82,6 +82,21 @@ func (mr *MockStoreMockRecorder) ActivityBumpWorkspace(arg0, arg1 interface{}) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActivityBumpWorkspace", reflect.TypeOf((*MockStore)(nil).ActivityBumpWorkspace), arg0, arg1)
}
// AllUserIDs mocks base method.
func (m *MockStore) AllUserIDs(arg0 context.Context) ([]uuid.UUID, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AllUserIDs", arg0)
ret0, _ := ret[0].([]uuid.UUID)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AllUserIDs indicates an expected call of AllUserIDs.
func (mr *MockStoreMockRecorder) AllUserIDs(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllUserIDs", reflect.TypeOf((*MockStore)(nil).AllUserIDs), arg0)
}
// CleanTailnetCoordinators mocks base method.
func (m *MockStore) CleanTailnetCoordinators(arg0 context.Context) error {
m.ctrl.T.Helper()

View File

@ -31,6 +31,8 @@ type sqlcQuerier interface {
// We only bump if workspace shutdown is manual.
// We only bump when 5% of the deadline has elapsed.
ActivityBumpWorkspace(ctx context.Context, workspaceID uuid.UUID) error
// AllUserIDs returns all UserIDs regardless of user status or deletion.
AllUserIDs(ctx context.Context) ([]uuid.UUID, error)
CleanTailnetCoordinators(ctx context.Context) error
DeleteAPIKeyByID(ctx context.Context, id string) error
DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error

View File

@ -5846,6 +5846,34 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke
return i, err
}
const allUserIDs = `-- name: AllUserIDs :many
SELECT DISTINCT id FROM USERS
`
// AllUserIDs returns all UserIDs regardless of user status or deletion.
func (q *sqlQuerier) AllUserIDs(ctx context.Context) ([]uuid.UUID, error) {
rows, err := q.db.QueryContext(ctx, allUserIDs)
if err != nil {
return nil, err
}
defer rows.Close()
var items []uuid.UUID
for rows.Next() {
var id uuid.UUID
if err := rows.Scan(&id); err != nil {
return nil, err
}
items = append(items, id)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getActiveUserCount = `-- name: GetActiveUserCount :one
SELECT
COUNT(*)

View File

@ -16,3 +16,4 @@ AND
INSERT INTO dbcrypt_keys
(number, active_key_digest, created_at, test)
VALUES (@number::int, @active_key_digest::text, CURRENT_TIMESTAMP, @test::text);

View File

@ -262,3 +262,8 @@ WHERE
last_seen_at < @last_seen_after :: timestamp
AND status = 'active'::user_status
RETURNING id, email, last_seen_at;
-- AllUserIDs returns all UserIDs regardless of user status or deletion.
-- name: AllUserIDs :many
SELECT DISTINCT id FROM USERS;

View File

@ -20,6 +20,9 @@ import (
"github.com/coder/coder/v2/pty/ptytest"
)
// TestServerDBCrypt tests end-to-end encryption, decryption, and deletion
// of encrypted user data.
//
// nolint: paralleltest // use of t.Setenv
func TestServerDBCrypt(t *testing.T) {
if !dbtestutil.WillUsePostgres() {
@ -41,15 +44,38 @@ func TestServerDBCrypt(t *testing.T) {
})
db := database.New(sqlDB)
// Populate the database with some unencrypted data.
users := genData(t, db, 10)
t.Cleanup(func() {
if t.Failed() {
t.Logf("Dumping data due to failed test. I hope you find what you're looking for!")
dumpUsers(t, sqlDB)
}
})
// Setup an initial cipher
// Populate the database with some unencrypted data.
t.Logf("Generating unencrypted data")
users := genData(t, db)
// Setup an initial cipher A
keyA := mustString(t, 32)
cipherA, err := dbcrypt.NewCiphers([]byte(keyA))
require.NoError(t, err)
// Create an encrypted database
cryptdb, err := dbcrypt.New(ctx, db, cipherA...)
require.NoError(t, err)
// Populate the database with some encrypted data using cipher A.
t.Logf("Generating data encrypted with cipher A")
newUsers := genData(t, cryptdb)
// Validate that newly created users were encrypted with cipher A
for _, usr := range newUsers {
requireEncryptedWithCipher(ctx, t, db, cipherA[0], usr.ID)
}
users = append(users, newUsers...)
// Encrypt all the data with the initial cipher.
t.Logf("Encrypting all data with cipher A")
inv, _ := newCLI(t, "server", "dbcrypt", "rotate",
"--postgres-url", connectionURL,
"--new-key", base64.StdEncoding.EncodeToString([]byte(keyA)),
@ -65,18 +91,12 @@ func TestServerDBCrypt(t *testing.T) {
requireEncryptedWithCipher(ctx, t, db, cipherA[0], usr.ID)
}
// Create an encrypted database
cryptdb, err := dbcrypt.New(ctx, db, cipherA...)
require.NoError(t, err)
// Populate the database with some encrypted data using cipher A.
users = append(users, genData(t, cryptdb, 10)...)
// Re-encrypt all existing data with a new cipher.
keyB := mustString(t, 32)
cipherBA, err := dbcrypt.NewCiphers([]byte(keyB), []byte(keyA))
require.NoError(t, err)
t.Logf("Enrypting all data with cipher B")
inv, _ = newCLI(t, "server", "dbcrypt", "rotate",
"--postgres-url", connectionURL,
"--new-key", base64.StdEncoding.EncodeToString([]byte(keyB)),
@ -94,6 +114,7 @@ func TestServerDBCrypt(t *testing.T) {
}
// Assert that we can revoke the old key.
t.Logf("Revoking cipher A")
err = db.RevokeDBCryptKey(ctx, cipherA[0].HexDigest())
require.NoError(t, err, "failed to revoke old key")
@ -109,6 +130,7 @@ func TestServerDBCrypt(t *testing.T) {
require.Empty(t, oldKey.ActiveKeyDigest.String, "expected the old key to not be active")
// Revoking the new key should fail.
t.Logf("Attempting to revoke cipher B should fail as it is still in use")
err = db.RevokeDBCryptKey(ctx, cipherBA[0].HexDigest())
require.Error(t, err, "expected to fail to revoke the new key")
var pgErr *pq.Error
@ -116,6 +138,7 @@ func TestServerDBCrypt(t *testing.T) {
require.EqualValues(t, "23503", pgErr.Code, "expected a foreign key constraint violation error")
// Decrypt the data using only cipher B. This should result in the key being revoked.
t.Logf("Decrypting with cipher B")
inv, _ = newCLI(t, "server", "dbcrypt", "decrypt",
"--postgres-url", connectionURL,
"--keys", base64.StdEncoding.EncodeToString([]byte(keyB)),
@ -144,6 +167,7 @@ func TestServerDBCrypt(t *testing.T) {
cipherC, err := dbcrypt.NewCiphers([]byte(keyC))
require.NoError(t, err)
t.Logf("Re-encrypting with cipher C")
inv, _ = newCLI(t, "server", "dbcrypt", "rotate",
"--postgres-url", connectionURL,
"--new-key", base64.StdEncoding.EncodeToString([]byte(keyC)),
@ -161,6 +185,7 @@ func TestServerDBCrypt(t *testing.T) {
}
// Now delete all the encrypted data.
t.Logf("Deleting all encrypted data")
inv, _ = newCLI(t, "server", "dbcrypt", "delete",
"--postgres-url", connectionURL,
"--external-token-encryption-keys", base64.StdEncoding.EncodeToString([]byte(keyC)),
@ -191,30 +216,84 @@ func TestServerDBCrypt(t *testing.T) {
}
}
func genData(t *testing.T, db database.Store, n int) []database.User {
func genData(t *testing.T, db database.Store) []database.User {
t.Helper()
var users []database.User
for i := 0; i < n; i++ {
usr := dbgen.User(t, db, database.User{
LoginType: database.LoginTypeOIDC,
})
_ = dbgen.UserLink(t, db, database.UserLink{
UserID: usr.ID,
LoginType: usr.LoginType,
OAuthAccessToken: "access-" + usr.ID.String(),
OAuthRefreshToken: "refresh-" + usr.ID.String(),
})
_ = dbgen.GitAuthLink(t, db, database.GitAuthLink{
UserID: usr.ID,
ProviderID: "fake",
OAuthAccessToken: "access-" + usr.ID.String(),
OAuthRefreshToken: "refresh-" + usr.ID.String(),
})
users = append(users, usr)
// Make some users
for _, status := range database.AllUserStatusValues() {
for _, loginType := range database.AllLoginTypeValues() {
for _, deleted := range []bool{false, true} {
usr := dbgen.User(t, db, database.User{
LoginType: loginType,
Status: status,
Deleted: deleted,
})
_ = dbgen.GitAuthLink(t, db, database.GitAuthLink{
UserID: usr.ID,
ProviderID: "fake",
OAuthAccessToken: "access-" + usr.ID.String(),
OAuthRefreshToken: "refresh-" + usr.ID.String(),
})
// Fun fact: our schema allows _all_ login types to have
// a user_link. Even though I'm not sure how it could occur
// in practice, making sure to test all combinations here.
_ = dbgen.UserLink(t, db, database.UserLink{
UserID: usr.ID,
LoginType: usr.LoginType,
OAuthAccessToken: "access-" + usr.ID.String(),
OAuthRefreshToken: "refresh-" + usr.ID.String(),
})
users = append(users, usr)
}
}
}
return users
}
func dumpUsers(t *testing.T, db *sql.DB) {
t.Helper()
rows, err := db.QueryContext(context.Background(), `SELECT
u.id,
u.login_type,
u.status,
u.deleted,
ul.oauth_access_token_key_id AS uloatkid,
ul.oauth_refresh_token_key_id AS ulortkid,
gal.oauth_access_token_key_id AS galoatkid,
gal.oauth_refresh_token_key_id AS galortkid
FROM users u
LEFT OUTER JOIN user_links ul ON u.id = ul.user_id
LEFT OUTER JOIN git_auth_links gal ON u.id = gal.user_id
ORDER BY u.created_at ASC;`)
require.NoError(t, err)
defer rows.Close()
for rows.Next() {
var (
id string
loginType string
status string
deleted bool
UlOatKid sql.NullString
UlOrtKid sql.NullString
GalOatKid sql.NullString
GalOrtKid sql.NullString
)
require.NoError(t, rows.Scan(
&id,
&loginType,
&status,
&deleted,
&UlOatKid,
&UlOrtKid,
&GalOatKid,
&GalOrtKid,
))
t.Logf("user: id:%s login_type:%-8s status:%-9s deleted:%-5t ul_kids{at:%-7s rt:%-7s} gal_kids{at:%-7s rt:%-7s}",
id, loginType, status, deleted, UlOatKid.String, UlOrtKid.String, GalOatKid.String, GalOrtKid.String,
)
}
}
func mustString(t *testing.T, n int) string {
t.Helper()
s, err := cryptorand.String(n)

View File

@ -19,45 +19,45 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
return xerrors.Errorf("create cryptdb: %w", err)
}
users, err := cryptDB.GetUsers(ctx, database.GetUsersParams{})
userIDs, err := db.AllUserIDs(ctx)
if err != nil {
return xerrors.Errorf("get users: %w", err)
}
log.Info(ctx, "encrypting user tokens", slog.F("user_count", len(users)))
for idx, usr := range users {
log.Info(ctx, "encrypting user tokens", slog.F("user_count", len(userIDs)))
for idx, uid := range userIDs {
err := cryptDB.InTx(func(tx database.Store) error {
userLinks, err := tx.GetUserLinksByUserID(ctx, usr.ID)
userLinks, err := tx.GetUserLinksByUserID(ctx, uid)
if err != nil {
return xerrors.Errorf("get user links for user: %w", err)
}
for _, userLink := range userLinks {
if userLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && userLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() {
log.Debug(ctx, "skipping user link", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
log.Debug(ctx, "skipping user link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
continue
}
if _, err := tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{
OAuthAccessToken: userLink.OAuthAccessToken,
OAuthRefreshToken: userLink.OAuthRefreshToken,
OAuthExpiry: userLink.OAuthExpiry,
UserID: usr.ID,
LoginType: usr.LoginType,
UserID: uid,
LoginType: userLink.LoginType,
}); err != nil {
return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err)
}
}
gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, usr.ID)
gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, uid)
if err != nil {
return xerrors.Errorf("get git auth links for user: %w", err)
}
for _, gitAuthLink := range gitAuthLinks {
if gitAuthLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && gitAuthLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() {
log.Debug(ctx, "skipping git auth link", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
continue
}
if _, err := tx.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
ProviderID: gitAuthLink.ProviderID,
UserID: usr.ID,
UserID: uid,
UpdatedAt: gitAuthLink.UpdatedAt,
OAuthAccessToken: gitAuthLink.OAuthAccessToken,
OAuthRefreshToken: gitAuthLink.OAuthRefreshToken,
@ -73,7 +73,7 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe
if err != nil {
return xerrors.Errorf("update user links: %w", err)
}
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
// Revoke old keys
@ -103,45 +103,45 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
}
cryptDB.primaryCipherDigest = ""
users, err := cryptDB.GetUsers(ctx, database.GetUsersParams{})
userIDs, err := db.AllUserIDs(ctx)
if err != nil {
return xerrors.Errorf("get users: %w", err)
}
log.Info(ctx, "decrypting user tokens", slog.F("user_count", len(users)))
for idx, usr := range users {
log.Info(ctx, "decrypting user tokens", slog.F("user_count", len(userIDs)))
for idx, uid := range userIDs {
err := cryptDB.InTx(func(tx database.Store) error {
userLinks, err := tx.GetUserLinksByUserID(ctx, usr.ID)
userLinks, err := tx.GetUserLinksByUserID(ctx, uid)
if err != nil {
return xerrors.Errorf("get user links for user: %w", err)
}
for _, userLink := range userLinks {
if !userLink.OAuthAccessTokenKeyID.Valid && !userLink.OAuthRefreshTokenKeyID.Valid {
log.Debug(ctx, "skipping user link", slog.F("user_id", usr.ID), slog.F("current", idx+1))
log.Debug(ctx, "skipping user link", slog.F("user_id", uid), slog.F("current", idx+1))
continue
}
if _, err := tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{
OAuthAccessToken: userLink.OAuthAccessToken,
OAuthRefreshToken: userLink.OAuthRefreshToken,
OAuthExpiry: userLink.OAuthExpiry,
UserID: usr.ID,
LoginType: usr.LoginType,
UserID: uid,
LoginType: userLink.LoginType,
}); err != nil {
return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err)
}
}
gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, usr.ID)
gitAuthLinks, err := tx.GetGitAuthLinksByUserID(ctx, uid)
if err != nil {
return xerrors.Errorf("get git auth links for user: %w", err)
}
for _, gitAuthLink := range gitAuthLinks {
if !gitAuthLink.OAuthAccessTokenKeyID.Valid && !gitAuthLink.OAuthRefreshTokenKeyID.Valid {
log.Debug(ctx, "skipping git auth link", slog.F("user_id", usr.ID), slog.F("current", idx+1))
log.Debug(ctx, "skipping git auth link", slog.F("user_id", uid), slog.F("current", idx+1))
continue
}
if _, err := tx.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
ProviderID: gitAuthLink.ProviderID,
UserID: usr.ID,
UserID: uid,
UpdatedAt: gitAuthLink.UpdatedAt,
OAuthAccessToken: gitAuthLink.OAuthAccessToken,
OAuthRefreshToken: gitAuthLink.OAuthRefreshToken,
@ -157,7 +157,7 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph
if err != nil {
return xerrors.Errorf("update user links: %w", err)
}
log.Debug(ctx, "decrypted user tokens", slog.F("user_id", usr.ID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest()))
}
// Revoke _all_ keys