feat(coderd): add dbcrypt package (#9522)

- Adds package enterprise/dbcrypt to implement database encryption/decryption
- Adds table dbcrypt_keys and associated queries
- Adds columns oauth_access_token_key_id and oauth_refresh_token_key_id
  to tables git_auth_links and user_links

Co-authored-by: Kyle Carberry <kyle@coder.com>
This commit is contained in:
Cian Johnston 2023-09-06 12:06:26 +01:00 committed by GitHub
parent 3bd0fd396c
commit 7918e65510
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 1996 additions and 72 deletions

View File

@ -838,6 +838,13 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
return q.db.GetAuthorizationUserRoles(ctx, userID)
}
func (q *querier) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetDBCryptKeys(ctx)
}
func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
return "", err
@ -914,6 +921,13 @@ func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLin
return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg)
}
func (q *querier) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) {
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetGitAuthLinksByUserID(ctx, userID)
}
func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
return fetch(q.log, q.auth, q.db.GetGitSSHKey)(ctx, userID)
}
@ -1482,6 +1496,13 @@ func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database
return q.db.GetUserLinkByUserIDLoginType(ctx, arg)
}
func (q *querier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) {
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetUserLinksByUserID(ctx, userID)
}
func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) {
// This does the filtering in SQL.
prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type)
@ -1845,6 +1866,13 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo
return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg)
}
func (q *querier) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error {
if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.InsertDBCryptKey(ctx, arg)
}
func (q *querier) InsertDERPMeshKey(ctx context.Context, value string) error {
if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
return err
@ -2144,6 +2172,13 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis
return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg)
}
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.RevokeDBCryptKey(ctx, activeKeyDigest)
}
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
return q.db.TryAcquireLock(ctx, id)
}

View File

@ -31,6 +31,11 @@ import (
var validProxyByHostnameRegex = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
var errForeignKeyConstraint = &pq.Error{
Code: "23503",
Message: "update or delete on table violates foreign key constraint",
}
var errDuplicateKey = &pq.Error{
Code: "23505",
Message: "duplicate key value violates unique constraint",
@ -45,6 +50,7 @@ func New() database.Store {
organizationMembers: make([]database.OrganizationMember, 0),
organizations: make([]database.Organization, 0),
users: make([]database.User, 0),
dbcryptKeys: make([]database.DBCryptKey, 0),
gitAuthLinks: make([]database.GitAuthLink, 0),
groups: make([]database.Group, 0),
groupMembers: make([]database.GroupMember, 0),
@ -117,6 +123,7 @@ type data struct {
// New tables
workspaceAgentStats []database.WorkspaceAgentStat
auditLogs []database.AuditLog
dbcryptKeys []database.DBCryptKey
files []database.File
gitAuthLinks []database.GitAuthLink
gitSSHKey []database.GitSSHKey
@ -665,6 +672,19 @@ func (q *FakeQuerier) isEveryoneGroup(id uuid.UUID) bool {
return false
}
func (q *FakeQuerier) GetActiveDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
ks := make([]database.DBCryptKey, 0, len(q.dbcryptKeys))
for _, k := range q.dbcryptKeys {
if !k.ActiveKeyDigest.Valid {
continue
}
ks = append([]database.DBCryptKey{}, k)
}
return ks, nil
}
func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error {
return xerrors.New("AcquireLock must only be called within a transaction")
}
@ -1151,6 +1171,14 @@ func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
}, nil
}
func (q *FakeQuerier) GetDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
ks := make([]database.DBCryptKey, 0)
ks = append(ks, q.dbcryptKeys...)
return ks, nil
}
func (q *FakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -1393,6 +1421,18 @@ func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthL
return database.GitAuthLink{}, sql.ErrNoRows
}
func (q *FakeQuerier) GetGitAuthLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
gals := make([]database.GitAuthLink, 0)
for _, gal := range q.gitAuthLinks {
if gal.UserID == userID {
gals = append(gals, gal)
}
}
return gals, nil
}
func (q *FakeQuerier) GetGitSSHKey(_ context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -2833,6 +2873,18 @@ func (q *FakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat
return database.UserLink{}, sql.ErrNoRows
}
func (q *FakeQuerier) GetUserLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.UserLink, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
uls := make([]database.UserLink, 0)
for _, ul := range q.userLinks {
if ul.UserID == userID {
uls = append(uls, ul)
}
}
return uls, nil
}
func (q *FakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.GetUsersRow, error) {
if err := validateDatabaseType(params); err != nil {
return nil, err
@ -3846,6 +3898,26 @@ func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAudit
return alog, nil
}
func (q *FakeQuerier) InsertDBCryptKey(_ context.Context, arg database.InsertDBCryptKeyParams) error {
err := validateDatabaseType(arg)
if err != nil {
return err
}
for _, key := range q.dbcryptKeys {
if key.Number == arg.Number {
return errDuplicateKey
}
}
q.dbcryptKeys = append(q.dbcryptKeys, database.DBCryptKey{
Number: arg.Number,
ActiveKeyDigest: sql.NullString{String: arg.ActiveKeyDigest, Valid: true},
Test: arg.Test,
})
return nil
}
func (q *FakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error {
q.mutex.Lock()
defer q.mutex.Unlock()
@ -3892,13 +3964,15 @@ func (q *FakeQuerier) InsertGitAuthLink(_ context.Context, arg database.InsertGi
defer q.mutex.Unlock()
// nolint:gosimple
gitAuthLink := database.GitAuthLink{
ProviderID: arg.ProviderID,
UserID: arg.UserID,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
OAuthAccessToken: arg.OAuthAccessToken,
OAuthRefreshToken: arg.OAuthRefreshToken,
OAuthExpiry: arg.OAuthExpiry,
ProviderID: arg.ProviderID,
UserID: arg.UserID,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
OAuthAccessToken: arg.OAuthAccessToken,
OAuthAccessTokenKeyID: arg.OAuthAccessTokenKeyID,
OAuthRefreshToken: arg.OAuthRefreshToken,
OAuthRefreshTokenKeyID: arg.OAuthRefreshTokenKeyID,
OAuthExpiry: arg.OAuthExpiry,
}
q.gitAuthLinks = append(q.gitAuthLinks, gitAuthLink)
return gitAuthLink, nil
@ -4362,12 +4436,14 @@ func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser
//nolint:gosimple
link := database.UserLink{
UserID: args.UserID,
LoginType: args.LoginType,
LinkedID: args.LinkedID,
OAuthAccessToken: args.OAuthAccessToken,
OAuthRefreshToken: args.OAuthRefreshToken,
OAuthExpiry: args.OAuthExpiry,
UserID: args.UserID,
LoginType: args.LoginType,
LinkedID: args.LinkedID,
OAuthAccessToken: args.OAuthAccessToken,
OAuthAccessTokenKeyID: args.OAuthAccessTokenKeyID,
OAuthRefreshToken: args.OAuthRefreshToken,
OAuthRefreshTokenKeyID: args.OAuthRefreshTokenKeyID,
OAuthExpiry: args.OAuthExpiry,
}
q.userLinks = append(q.userLinks, link)
@ -4793,6 +4869,46 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg
return database.WorkspaceProxy{}, sql.ErrNoRows
}
func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error {
q.mutex.Lock()
defer q.mutex.Unlock()
for i := range q.dbcryptKeys {
key := q.dbcryptKeys[i]
// Is the key already revoked?
if !key.ActiveKeyDigest.Valid {
continue
}
if key.ActiveKeyDigest.String != activeKeyDigest {
continue
}
// Check for foreign key constraints.
for _, ul := range q.userLinks {
if (ul.OAuthAccessTokenKeyID.Valid && ul.OAuthAccessTokenKeyID.String == activeKeyDigest) ||
(ul.OAuthRefreshTokenKeyID.Valid && ul.OAuthRefreshTokenKeyID.String == activeKeyDigest) {
return errForeignKeyConstraint
}
}
for _, gal := range q.gitAuthLinks {
if (gal.OAuthAccessTokenKeyID.Valid && gal.OAuthAccessTokenKeyID.String == activeKeyDigest) ||
(gal.OAuthRefreshTokenKeyID.Valid && gal.OAuthRefreshTokenKeyID.String == activeKeyDigest) {
return errForeignKeyConstraint
}
}
// Revoke the key.
q.dbcryptKeys[i].RevokedAt = sql.NullTime{Time: dbtime.Now(), Valid: true}
q.dbcryptKeys[i].RevokedKeyDigest = sql.NullString{String: key.ActiveKeyDigest.String, Valid: true}
q.dbcryptKeys[i].ActiveKeyDigest = sql.NullString{}
return nil
}
return sql.ErrNoRows
}
func (*FakeQuerier) TryAcquireLock(_ context.Context, _ int64) (bool, error) {
return false, xerrors.New("TryAcquireLock must only be called within a transaction")
}
@ -4834,7 +4950,9 @@ func (q *FakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGi
}
gitAuthLink.UpdatedAt = arg.UpdatedAt
gitAuthLink.OAuthAccessToken = arg.OAuthAccessToken
gitAuthLink.OAuthAccessTokenKeyID = arg.OAuthAccessTokenKeyID
gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken
gitAuthLink.OAuthRefreshTokenKeyID = arg.OAuthRefreshTokenKeyID
gitAuthLink.OAuthExpiry = arg.OAuthExpiry
q.gitAuthLinks[index] = gitAuthLink
@ -5306,7 +5424,9 @@ func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
for i, link := range q.userLinks {
if link.UserID == params.UserID && link.LoginType == params.LoginType {
link.OAuthAccessToken = params.OAuthAccessToken
link.OAuthAccessTokenKeyID = params.OAuthAccessTokenKeyID
link.OAuthRefreshToken = params.OAuthRefreshToken
link.OAuthRefreshTokenKeyID = params.OAuthRefreshTokenKeyID
link.OAuthExpiry = params.OAuthExpiry
q.userLinks[i] = link

View File

@ -470,12 +470,14 @@ func File(t testing.TB, db database.Store, orig database.File) database.File {
func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.UserLink {
link, err := db.InsertUserLink(genCtx, database.InsertUserLinkParams{
UserID: takeFirst(orig.UserID, uuid.New()),
LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub),
LinkedID: takeFirst(orig.LinkedID),
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
UserID: takeFirst(orig.UserID, uuid.New()),
LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub),
LinkedID: takeFirst(orig.LinkedID),
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
OAuthAccessTokenKeyID: takeFirst(orig.OAuthAccessTokenKeyID, sql.NullString{}),
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}),
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
})
require.NoError(t, err, "insert link")
@ -484,13 +486,15 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.
func GitAuthLink(t testing.TB, db database.Store, orig database.GitAuthLink) database.GitAuthLink {
link, err := db.InsertGitAuthLink(genCtx, database.InsertGitAuthLinkParams{
ProviderID: takeFirst(orig.ProviderID, uuid.New().String()),
UserID: takeFirst(orig.UserID, uuid.New()),
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
ProviderID: takeFirst(orig.ProviderID, uuid.New().String()),
UserID: takeFirst(orig.UserID, uuid.New()),
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
OAuthAccessTokenKeyID: takeFirst(orig.OAuthAccessTokenKeyID, sql.NullString{}),
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}),
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
})
require.NoError(t, err, "insert git auth link")

View File

@ -279,6 +279,13 @@ func (m metricsStore) GetAuthorizationUserRoles(ctx context.Context, userID uuid
return row, err
}
func (m metricsStore) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
start := time.Now()
r0, r1 := m.s.GetDBCryptKeys(ctx)
m.queryLatencies.WithLabelValues("GetDBCryptKeys").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) GetDERPMeshKey(ctx context.Context) (string, error) {
start := time.Now()
key, err := m.s.GetDERPMeshKey(ctx)
@ -349,6 +356,13 @@ func (m metricsStore) GetGitAuthLink(ctx context.Context, arg database.GetGitAut
return link, err
}
func (m metricsStore) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) {
start := time.Now()
r0, r1 := m.s.GetGitAuthLinksByUserID(ctx, userID)
m.queryLatencies.WithLabelValues("GetGitAuthLinksByUserID").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
start := time.Now()
key, err := m.s.GetGitSSHKey(ctx, userID)
@ -774,6 +788,13 @@ func (m metricsStore) GetUserLinkByUserIDLoginType(ctx context.Context, arg data
return link, err
}
func (m metricsStore) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) {
start := time.Now()
r0, r1 := m.s.GetUserLinksByUserID(ctx, userID)
m.queryLatencies.WithLabelValues("GetUserLinksByUserID").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) {
start := time.Now()
users, err := m.s.GetUsers(ctx, arg)
@ -1068,6 +1089,13 @@ func (m metricsStore) InsertAuditLog(ctx context.Context, arg database.InsertAud
return log, err
}
func (m metricsStore) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error {
start := time.Now()
r0 := m.s.InsertDBCryptKey(ctx, arg)
m.queryLatencies.WithLabelValues("InsertDBCryptKey").Observe(time.Since(start).Seconds())
return r0
}
func (m metricsStore) InsertDERPMeshKey(ctx context.Context, value string) error {
start := time.Now()
err := m.s.InsertDERPMeshKey(ctx, value)
@ -1320,6 +1348,13 @@ func (m metricsStore) RegisterWorkspaceProxy(ctx context.Context, arg database.R
return proxy, err
}
func (m metricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
start := time.Now()
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)
m.queryLatencies.WithLabelValues("RevokeDBCryptKey").Observe(time.Since(start).Seconds())
return r0
}
func (m metricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
start := time.Now()
ok, err := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock)

View File

@ -506,6 +506,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedWorkspaces(arg0, arg1, arg2 interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedWorkspaces", reflect.TypeOf((*MockStore)(nil).GetAuthorizedWorkspaces), arg0, arg1, arg2)
}
// GetDBCryptKeys mocks base method.
func (m *MockStore) GetDBCryptKeys(arg0 context.Context) ([]database.DBCryptKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDBCryptKeys", arg0)
ret0, _ := ret[0].([]database.DBCryptKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetDBCryptKeys indicates an expected call of GetDBCryptKeys.
func (mr *MockStoreMockRecorder) GetDBCryptKeys(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDBCryptKeys", reflect.TypeOf((*MockStore)(nil).GetDBCryptKeys), arg0)
}
// GetDERPMeshKey mocks base method.
func (m *MockStore) GetDERPMeshKey(arg0 context.Context) (string, error) {
m.ctrl.T.Helper()
@ -656,6 +671,21 @@ func (mr *MockStoreMockRecorder) GetGitAuthLink(arg0, arg1 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGitAuthLink", reflect.TypeOf((*MockStore)(nil).GetGitAuthLink), arg0, arg1)
}
// GetGitAuthLinksByUserID mocks base method.
func (m *MockStore) GetGitAuthLinksByUserID(arg0 context.Context, arg1 uuid.UUID) ([]database.GitAuthLink, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGitAuthLinksByUserID", arg0, arg1)
ret0, _ := ret[0].([]database.GitAuthLink)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGitAuthLinksByUserID indicates an expected call of GetGitAuthLinksByUserID.
func (mr *MockStoreMockRecorder) GetGitAuthLinksByUserID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGitAuthLinksByUserID", reflect.TypeOf((*MockStore)(nil).GetGitAuthLinksByUserID), arg0, arg1)
}
// GetGitSSHKey mocks base method.
func (m *MockStore) GetGitSSHKey(arg0 context.Context, arg1 uuid.UUID) (database.GitSSHKey, error) {
m.ctrl.T.Helper()
@ -1601,6 +1631,21 @@ func (mr *MockStoreMockRecorder) GetUserLinkByUserIDLoginType(arg0, arg1 interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserLinkByUserIDLoginType", reflect.TypeOf((*MockStore)(nil).GetUserLinkByUserIDLoginType), arg0, arg1)
}
// GetUserLinksByUserID mocks base method.
func (m *MockStore) GetUserLinksByUserID(arg0 context.Context, arg1 uuid.UUID) ([]database.UserLink, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserLinksByUserID", arg0, arg1)
ret0, _ := ret[0].([]database.UserLink)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserLinksByUserID indicates an expected call of GetUserLinksByUserID.
func (mr *MockStoreMockRecorder) GetUserLinksByUserID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserLinksByUserID", reflect.TypeOf((*MockStore)(nil).GetUserLinksByUserID), arg0, arg1)
}
// GetUsers mocks base method.
func (m *MockStore) GetUsers(arg0 context.Context, arg1 database.GetUsersParams) ([]database.GetUsersRow, error) {
m.ctrl.T.Helper()
@ -2245,6 +2290,20 @@ func (mr *MockStoreMockRecorder) InsertAuditLog(arg0, arg1 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAuditLog", reflect.TypeOf((*MockStore)(nil).InsertAuditLog), arg0, arg1)
}
// InsertDBCryptKey mocks base method.
func (m *MockStore) InsertDBCryptKey(arg0 context.Context, arg1 database.InsertDBCryptKeyParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertDBCryptKey", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// InsertDBCryptKey indicates an expected call of InsertDBCryptKey.
func (mr *MockStoreMockRecorder) InsertDBCryptKey(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertDBCryptKey", reflect.TypeOf((*MockStore)(nil).InsertDBCryptKey), arg0, arg1)
}
// InsertDERPMeshKey mocks base method.
func (m *MockStore) InsertDERPMeshKey(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
@ -2789,6 +2848,20 @@ func (mr *MockStoreMockRecorder) RegisterWorkspaceProxy(arg0, arg1 interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterWorkspaceProxy", reflect.TypeOf((*MockStore)(nil).RegisterWorkspaceProxy), arg0, arg1)
}
// RevokeDBCryptKey mocks base method.
func (m *MockStore) RevokeDBCryptKey(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevokeDBCryptKey", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// RevokeDBCryptKey indicates an expected call of RevokeDBCryptKey.
func (mr *MockStoreMockRecorder) RevokeDBCryptKey(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeDBCryptKey", reflect.TypeOf((*MockStore)(nil).RevokeDBCryptKey), arg0, arg1)
}
// TryAcquireLock mocks base method.
func (m *MockStore) TryAcquireLock(arg0 context.Context, arg1 int64) (bool, error) {
m.ctrl.T.Helper()

View File

@ -275,6 +275,29 @@ CREATE TABLE audit_logs (
resource_icon text NOT NULL
);
CREATE TABLE dbcrypt_keys (
number integer NOT NULL,
active_key_digest text,
revoked_key_digest text,
created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP,
revoked_at timestamp with time zone,
test text NOT NULL
);
COMMENT ON TABLE dbcrypt_keys IS 'A table used to store the keys used to encrypt the database.';
COMMENT ON COLUMN dbcrypt_keys.number IS 'An integer used to identify the key.';
COMMENT ON COLUMN dbcrypt_keys.active_key_digest IS 'If the key is active, the digest of the active key.';
COMMENT ON COLUMN dbcrypt_keys.revoked_key_digest IS 'If the key has been revoked, the digest of the revoked key.';
COMMENT ON COLUMN dbcrypt_keys.created_at IS 'The time at which the key was created.';
COMMENT ON COLUMN dbcrypt_keys.revoked_at IS 'The time at which the key was revoked.';
COMMENT ON COLUMN dbcrypt_keys.test IS 'A column used to test the encryption.';
CREATE TABLE files (
hash character varying(64) NOT NULL,
created_at timestamp with time zone NOT NULL,
@ -291,9 +314,15 @@ CREATE TABLE git_auth_links (
updated_at timestamp with time zone NOT NULL,
oauth_access_token text NOT NULL,
oauth_refresh_token text NOT NULL,
oauth_expiry timestamp with time zone NOT NULL
oauth_expiry timestamp with time zone NOT NULL,
oauth_access_token_key_id text,
oauth_refresh_token_key_id text
);
COMMENT ON COLUMN git_auth_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted';
COMMENT ON COLUMN git_auth_links.oauth_refresh_token_key_id IS 'The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted';
CREATE TABLE gitsshkeys (
user_id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
@ -701,9 +730,15 @@ CREATE TABLE user_links (
linked_id text DEFAULT ''::text NOT NULL,
oauth_access_token text DEFAULT ''::text NOT NULL,
oauth_refresh_token text DEFAULT ''::text NOT NULL,
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
oauth_access_token_key_id text,
oauth_refresh_token_key_id text
);
COMMENT ON COLUMN user_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted';
COMMENT ON COLUMN user_links.oauth_refresh_token_key_id IS 'The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted';
CREATE TABLE workspace_agent_logs (
agent_id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
@ -1037,6 +1072,15 @@ ALTER TABLE ONLY api_keys
ALTER TABLE ONLY audit_logs
ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
ALTER TABLE ONLY dbcrypt_keys
ADD CONSTRAINT dbcrypt_keys_active_key_digest_key UNIQUE (active_key_digest);
ALTER TABLE ONLY dbcrypt_keys
ADD CONSTRAINT dbcrypt_keys_pkey PRIMARY KEY (number);
ALTER TABLE ONLY dbcrypt_keys
ADD CONSTRAINT dbcrypt_keys_revoked_key_digest_key UNIQUE (revoked_key_digest);
ALTER TABLE ONLY files
ADD CONSTRAINT files_hash_created_by_key UNIQUE (hash, created_by);
@ -1249,6 +1293,12 @@ CREATE TRIGGER trigger_update_users AFTER INSERT OR UPDATE ON users FOR EACH ROW
ALTER TABLE ONLY api_keys
ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY git_auth_links
ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY git_auth_links
ADD CONSTRAINT git_auth_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY gitsshkeys
ADD CONSTRAINT gitsshkeys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
@ -1303,6 +1353,12 @@ ALTER TABLE ONLY templates
ALTER TABLE ONLY templates
ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ALTER TABLE ONLY user_links
ADD CONSTRAINT user_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY user_links
ADD CONSTRAINT user_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY user_links
ADD CONSTRAINT user_links_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;

View File

@ -0,0 +1,43 @@
BEGIN;
-- Before dropping this table, we need to check if there exist any
-- foreign key references to it. We do this by checking the following:
-- user_links.oauth_access_token_key_id
-- user_links.oauth_refresh_token_key_id
-- git_auth_links.oauth_access_token_key_id
-- git_auth_links.oauth_refresh_token_key_id
DO $$
BEGIN
IF EXISTS (
SELECT *
FROM user_links
WHERE oauth_access_token_key_id IS NOT NULL
OR oauth_refresh_token_key_id IS NOT NULL
) THEN RAISE EXCEPTION 'Cannot drop dbcrypt_keys table as there are still foreign key references to it from user_links.';
END IF;
IF EXISTS (
SELECT *
FROM git_auth_links
WHERE oauth_access_token_key_id IS NOT NULL
OR oauth_refresh_token_key_id IS NOT NULL
) THEN RAISE EXCEPTION 'Cannot drop dbcrypt_keys table as there are still foreign key references to it from git_auth_links.';
END IF;
END
$$;
-- Drop the columns first.
ALTER TABLE git_auth_links
DROP COLUMN IF EXISTS oauth_access_token_key_id,
DROP COLUMN IF EXISTS oauth_refresh_token_key_id;
ALTER TABLE user_links
DROP COLUMN IF EXISTS oauth_access_token_key_id,
DROP COLUMN IF EXISTS oauth_refresh_token_key_id;
-- Finally, drop the table.
DROP TABLE IF EXISTS dbcrypt_keys;
COMMIT;

View File

@ -0,0 +1,30 @@
CREATE TABLE IF NOT EXISTS dbcrypt_keys (
number int NOT NULL PRIMARY KEY,
active_key_digest text UNIQUE,
revoked_key_digest text UNIQUE,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
revoked_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
test TEXT NOT NULL
);
COMMENT ON TABLE dbcrypt_keys IS 'A table used to store the keys used to encrypt the database.';
COMMENT ON COLUMN dbcrypt_keys.number IS 'An integer used to identify the key.';
COMMENT ON COLUMN dbcrypt_keys.active_key_digest IS 'If the key is active, the digest of the active key.';
COMMENT ON COLUMN dbcrypt_keys.revoked_key_digest IS 'If the key has been revoked, the digest of the revoked key.';
COMMENT ON COLUMN dbcrypt_keys.created_at IS 'The time at which the key was created.';
COMMENT ON COLUMN dbcrypt_keys.revoked_at IS 'The time at which the key was revoked.';
COMMENT ON COLUMN dbcrypt_keys.test IS 'A column used to test the encryption.';
ALTER TABLE git_auth_links
ADD COLUMN IF NOT EXISTS oauth_access_token_key_id text REFERENCES dbcrypt_keys(active_key_digest),
ADD COLUMN IF NOT EXISTS oauth_refresh_token_key_id text REFERENCES dbcrypt_keys(active_key_digest);
COMMENT ON COLUMN git_auth_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted';
COMMENT ON COLUMN git_auth_links.oauth_refresh_token_key_id IS 'The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted';
ALTER TABLE user_links
ADD COLUMN IF NOT EXISTS oauth_access_token_key_id text REFERENCES dbcrypt_keys(active_key_digest),
ADD COLUMN IF NOT EXISTS oauth_refresh_token_key_id text REFERENCES dbcrypt_keys(active_key_digest);
COMMENT ON COLUMN user_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted';
COMMENT ON COLUMN user_links.oauth_refresh_token_key_id IS 'The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted';

View File

@ -266,6 +266,7 @@ func TestMigrateUpWithFixtures(t *testing.T) {
"template_version_parameters",
"workspace_build_parameters",
"template_version_variables",
"dbcrypt_keys", // having zero rows is a valid state for this table
}
s := &tableStats{s: make(map[string]int)}

View File

@ -1591,6 +1591,22 @@ type AuditLog struct {
ResourceIcon string `db:"resource_icon" json:"resource_icon"`
}
// A table used to store the keys used to encrypt the database.
type DBCryptKey struct {
// An integer used to identify the key.
Number int32 `db:"number" json:"number"`
// If the key is active, the digest of the active key.
ActiveKeyDigest sql.NullString `db:"active_key_digest" json:"active_key_digest"`
// If the key has been revoked, the digest of the revoked key.
RevokedKeyDigest sql.NullString `db:"revoked_key_digest" json:"revoked_key_digest"`
// The time at which the key was created.
CreatedAt sql.NullTime `db:"created_at" json:"created_at"`
// The time at which the key was revoked.
RevokedAt sql.NullTime `db:"revoked_at" json:"revoked_at"`
// A column used to test the encryption.
Test string `db:"test" json:"test"`
}
type File struct {
Hash string `db:"hash" json:"hash"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
@ -1608,6 +1624,10 @@ type GitAuthLink struct {
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
// The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
// The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
}
type GitSSHKey struct {
@ -1949,6 +1969,10 @@ type UserLink struct {
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
// The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
// The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
}
// Visible fields of users are allowed to be joined with other tables for including context of other resources.

View File

@ -58,6 +58,7 @@ type sqlcQuerier interface {
// This function returns roles for authorization purposes. Implied member roles
// are included.
GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error)
GetDBCryptKeys(ctx context.Context) ([]DBCryptKey, error)
GetDERPMeshKey(ctx context.Context) (string, error)
GetDefaultProxyConfig(ctx context.Context) (GetDefaultProxyConfigRow, error)
GetDeploymentDAUs(ctx context.Context, tzOffset int32) ([]GetDeploymentDAUsRow, error)
@ -69,6 +70,7 @@ type sqlcQuerier interface {
// Get all templates that use a file.
GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]GetFileTemplatesRow, error)
GetGitAuthLink(ctx context.Context, arg GetGitAuthLinkParams) (GitAuthLink, error)
GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]GitAuthLink, error)
GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error)
GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error)
GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error)
@ -150,6 +152,7 @@ type sqlcQuerier interface {
GetUserLatencyInsights(ctx context.Context, arg GetUserLatencyInsightsParams) ([]GetUserLatencyInsightsRow, error)
GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error)
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error)
// This will never return deleted users.
GetUsers(ctx context.Context, arg GetUsersParams) ([]GetUsersRow, error)
// This shouldn't check for deleted, because it's frequently used
@ -206,6 +209,7 @@ type sqlcQuerier interface {
// every member of the org.
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
InsertDBCryptKey(ctx context.Context, arg InsertDBCryptKeyParams) error
InsertDERPMeshKey(ctx context.Context, value string) error
InsertDeploymentID(ctx context.Context, value string) error
InsertFile(ctx context.Context, arg InsertFileParams) (File, error)
@ -247,6 +251,7 @@ type sqlcQuerier interface {
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error
// Non blocking lock. Returns true if the lock was acquired, false otherwise.
//
// This must be called from within a transaction. The lock will be automatically

View File

@ -636,6 +636,74 @@ func (q *sqlQuerier) InsertAuditLog(ctx context.Context, arg InsertAuditLogParam
return i, err
}
const getDBCryptKeys = `-- name: GetDBCryptKeys :many
SELECT number, active_key_digest, revoked_key_digest, created_at, revoked_at, test FROM dbcrypt_keys ORDER BY number ASC
`
func (q *sqlQuerier) GetDBCryptKeys(ctx context.Context) ([]DBCryptKey, error) {
rows, err := q.db.QueryContext(ctx, getDBCryptKeys)
if err != nil {
return nil, err
}
defer rows.Close()
var items []DBCryptKey
for rows.Next() {
var i DBCryptKey
if err := rows.Scan(
&i.Number,
&i.ActiveKeyDigest,
&i.RevokedKeyDigest,
&i.CreatedAt,
&i.RevokedAt,
&i.Test,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertDBCryptKey = `-- name: InsertDBCryptKey :exec
INSERT INTO dbcrypt_keys
(number, active_key_digest, created_at, test)
VALUES ($1::int, $2::text, CURRENT_TIMESTAMP, $3::text)
`
type InsertDBCryptKeyParams struct {
Number int32 `db:"number" json:"number"`
ActiveKeyDigest string `db:"active_key_digest" json:"active_key_digest"`
Test string `db:"test" json:"test"`
}
func (q *sqlQuerier) InsertDBCryptKey(ctx context.Context, arg InsertDBCryptKeyParams) error {
_, err := q.db.ExecContext(ctx, insertDBCryptKey, arg.Number, arg.ActiveKeyDigest, arg.Test)
return err
}
const revokeDBCryptKey = `-- name: RevokeDBCryptKey :exec
UPDATE dbcrypt_keys
SET
revoked_key_digest = active_key_digest,
active_key_digest = revoked_key_digest,
revoked_at = CURRENT_TIMESTAMP
WHERE
active_key_digest = $1::text
AND
revoked_key_digest IS NULL
`
func (q *sqlQuerier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
_, err := q.db.ExecContext(ctx, revokeDBCryptKey, activeKeyDigest)
return err
}
const getFileByHashAndCreator = `-- name: GetFileByHashAndCreator :one
SELECT
hash, created_at, created_by, mimetype, data, id
@ -800,7 +868,7 @@ func (q *sqlQuerier) InsertFile(ctx context.Context, arg InsertFileParams) (File
}
const getGitAuthLink = `-- name: GetGitAuthLink :one
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry FROM git_auth_links WHERE provider_id = $1 AND user_id = $2
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM git_auth_links WHERE provider_id = $1 AND user_id = $2
`
type GetGitAuthLinkParams struct {
@ -819,10 +887,49 @@ func (q *sqlQuerier) GetGitAuthLink(ctx context.Context, arg GetGitAuthLinkParam
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
)
return i, err
}
const getGitAuthLinksByUserID = `-- name: GetGitAuthLinksByUserID :many
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM git_auth_links WHERE user_id = $1
`
func (q *sqlQuerier) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]GitAuthLink, error) {
rows, err := q.db.QueryContext(ctx, getGitAuthLinksByUserID, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GitAuthLink
for rows.Next() {
var i GitAuthLink
if err := rows.Scan(
&i.ProviderID,
&i.UserID,
&i.CreatedAt,
&i.UpdatedAt,
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertGitAuthLink = `-- name: InsertGitAuthLink :one
INSERT INTO git_auth_links (
provider_id,
@ -830,7 +937,9 @@ INSERT INTO git_auth_links (
created_at,
updated_at,
oauth_access_token,
oauth_access_token_key_id,
oauth_refresh_token,
oauth_refresh_token_key_id,
oauth_expiry
) VALUES (
$1,
@ -839,18 +948,22 @@ INSERT INTO git_auth_links (
$4,
$5,
$6,
$7
) RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry
$7,
$8,
$9
) RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
`
type InsertGitAuthLinkParams struct {
ProviderID string `db:"provider_id" json:"provider_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
ProviderID string `db:"provider_id" json:"provider_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
}
func (q *sqlQuerier) InsertGitAuthLink(ctx context.Context, arg InsertGitAuthLinkParams) (GitAuthLink, error) {
@ -860,7 +973,9 @@ func (q *sqlQuerier) InsertGitAuthLink(ctx context.Context, arg InsertGitAuthLin
arg.CreatedAt,
arg.UpdatedAt,
arg.OAuthAccessToken,
arg.OAuthAccessTokenKeyID,
arg.OAuthRefreshToken,
arg.OAuthRefreshTokenKeyID,
arg.OAuthExpiry,
)
var i GitAuthLink
@ -872,6 +987,8 @@ func (q *sqlQuerier) InsertGitAuthLink(ctx context.Context, arg InsertGitAuthLin
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
)
return i, err
}
@ -880,18 +997,22 @@ const updateGitAuthLink = `-- name: UpdateGitAuthLink :one
UPDATE git_auth_links SET
updated_at = $3,
oauth_access_token = $4,
oauth_refresh_token = $5,
oauth_expiry = $6
WHERE provider_id = $1 AND user_id = $2 RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry
oauth_access_token_key_id = $5,
oauth_refresh_token = $6,
oauth_refresh_token_key_id = $7,
oauth_expiry = $8
WHERE provider_id = $1 AND user_id = $2 RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
`
type UpdateGitAuthLinkParams struct {
ProviderID string `db:"provider_id" json:"provider_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
ProviderID string `db:"provider_id" json:"provider_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
}
func (q *sqlQuerier) UpdateGitAuthLink(ctx context.Context, arg UpdateGitAuthLinkParams) (GitAuthLink, error) {
@ -900,7 +1021,9 @@ func (q *sqlQuerier) UpdateGitAuthLink(ctx context.Context, arg UpdateGitAuthLin
arg.UserID,
arg.UpdatedAt,
arg.OAuthAccessToken,
arg.OAuthAccessTokenKeyID,
arg.OAuthRefreshToken,
arg.OAuthRefreshTokenKeyID,
arg.OAuthExpiry,
)
var i GitAuthLink
@ -912,6 +1035,8 @@ func (q *sqlQuerier) UpdateGitAuthLink(ctx context.Context, arg UpdateGitAuthLin
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
)
return i, err
}
@ -5450,7 +5575,7 @@ func (q *sqlQuerier) InsertTemplateVersionVariable(ctx context.Context, arg Inse
const getUserLinkByLinkedID = `-- name: GetUserLinkByLinkedID :one
SELECT
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
FROM
user_links
WHERE
@ -5467,13 +5592,15 @@ func (q *sqlQuerier) GetUserLinkByLinkedID(ctx context.Context, linkedID string)
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
)
return i, err
}
const getUserLinkByUserIDLoginType = `-- name: GetUserLinkByUserIDLoginType :one
SELECT
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
FROM
user_links
WHERE
@ -5495,10 +5622,48 @@ func (q *sqlQuerier) GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUs
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
)
return i, err
}
const getUserLinksByUserID = `-- name: GetUserLinksByUserID :many
SELECT user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM user_links WHERE user_id = $1
`
func (q *sqlQuerier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error) {
rows, err := q.db.QueryContext(ctx, getUserLinksByUserID, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []UserLink
for rows.Next() {
var i UserLink
if err := rows.Scan(
&i.UserID,
&i.LoginType,
&i.LinkedID,
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertUserLink = `-- name: InsertUserLink :one
INSERT INTO
user_links (
@ -5506,20 +5671,24 @@ INSERT INTO
login_type,
linked_id,
oauth_access_token,
oauth_access_token_key_id,
oauth_refresh_token,
oauth_refresh_token_key_id,
oauth_expiry
)
VALUES
( $1, $2, $3, $4, $5, $6 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
( $1, $2, $3, $4, $5, $6, $7, $8 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
`
type InsertUserLinkParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
LoginType LoginType `db:"login_type" json:"login_type"`
LinkedID string `db:"linked_id" json:"linked_id"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
LoginType LoginType `db:"login_type" json:"login_type"`
LinkedID string `db:"linked_id" json:"linked_id"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
}
func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) {
@ -5528,7 +5697,9 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam
arg.LoginType,
arg.LinkedID,
arg.OAuthAccessToken,
arg.OAuthAccessTokenKeyID,
arg.OAuthRefreshToken,
arg.OAuthRefreshTokenKeyID,
arg.OAuthExpiry,
)
var i UserLink
@ -5539,6 +5710,8 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
)
return i, err
}
@ -5548,24 +5721,30 @@ UPDATE
user_links
SET
oauth_access_token = $1,
oauth_refresh_token = $2,
oauth_expiry = $3
oauth_access_token_key_id = $2,
oauth_refresh_token = $3,
oauth_refresh_token_key_id = $4,
oauth_expiry = $5
WHERE
user_id = $4 AND login_type = $5 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
user_id = $6 AND login_type = $7 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
`
type UpdateUserLinkParams struct {
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
LoginType LoginType `db:"login_type" json:"login_type"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
LoginType LoginType `db:"login_type" json:"login_type"`
}
func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParams) (UserLink, error) {
row := q.db.QueryRowContext(ctx, updateUserLink,
arg.OAuthAccessToken,
arg.OAuthAccessTokenKeyID,
arg.OAuthRefreshToken,
arg.OAuthRefreshTokenKeyID,
arg.OAuthExpiry,
arg.UserID,
arg.LoginType,
@ -5578,6 +5757,8 @@ func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParam
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
)
return i, err
}
@ -5588,7 +5769,7 @@ UPDATE
SET
linked_id = $1
WHERE
user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
`
type UpdateUserLinkedIDParams struct {
@ -5607,6 +5788,8 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
)
return i, err
}

View File

@ -0,0 +1,18 @@
-- name: GetDBCryptKeys :many
SELECT * FROM dbcrypt_keys ORDER BY number ASC;
-- name: RevokeDBCryptKey :exec
UPDATE dbcrypt_keys
SET
revoked_key_digest = active_key_digest,
active_key_digest = revoked_key_digest,
revoked_at = CURRENT_TIMESTAMP
WHERE
active_key_digest = @active_key_digest::text
AND
revoked_key_digest IS NULL;
-- name: InsertDBCryptKey :exec
INSERT INTO dbcrypt_keys
(number, active_key_digest, created_at, test)
VALUES (@number::int, @active_key_digest::text, CURRENT_TIMESTAMP, @test::text);

View File

@ -1,6 +1,9 @@
-- name: GetGitAuthLink :one
SELECT * FROM git_auth_links WHERE provider_id = $1 AND user_id = $2;
-- name: GetGitAuthLinksByUserID :many
SELECT * FROM git_auth_links WHERE user_id = $1;
-- name: InsertGitAuthLink :one
INSERT INTO git_auth_links (
provider_id,
@ -8,7 +11,9 @@ INSERT INTO git_auth_links (
created_at,
updated_at,
oauth_access_token,
oauth_access_token_key_id,
oauth_refresh_token,
oauth_refresh_token_key_id,
oauth_expiry
) VALUES (
$1,
@ -17,13 +22,17 @@ INSERT INTO git_auth_links (
$4,
$5,
$6,
$7
$7,
$8,
$9
) RETURNING *;
-- name: UpdateGitAuthLink :one
UPDATE git_auth_links SET
updated_at = $3,
oauth_access_token = $4,
oauth_refresh_token = $5,
oauth_expiry = $6
oauth_access_token_key_id = $5,
oauth_refresh_token = $6,
oauth_refresh_token_key_id = $7,
oauth_expiry = $8
WHERE provider_id = $1 AND user_id = $2 RETURNING *;

View File

@ -14,6 +14,9 @@ FROM
WHERE
user_id = $1 AND login_type = $2;
-- name: GetUserLinksByUserID :many
SELECT * FROM user_links WHERE user_id = $1;
-- name: InsertUserLink :one
INSERT INTO
user_links (
@ -21,11 +24,13 @@ INSERT INTO
login_type,
linked_id,
oauth_access_token,
oauth_access_token_key_id,
oauth_refresh_token,
oauth_refresh_token_key_id,
oauth_expiry
)
VALUES
( $1, $2, $3, $4, $5, $6 ) RETURNING *;
( $1, $2, $3, $4, $5, $6, $7, $8 ) RETURNING *;
-- name: UpdateUserLinkedID :one
UPDATE
@ -40,7 +45,9 @@ UPDATE
user_links
SET
oauth_access_token = $1,
oauth_refresh_token = $2,
oauth_expiry = $3
oauth_access_token_key_id = $2,
oauth_refresh_token = $3,
oauth_refresh_token_key_id = $4,
oauth_expiry = $5
WHERE
user_id = $4 AND login_type = $5 RETURNING *;
user_id = $6 AND login_type = $7 RETURNING *;

View File

@ -40,6 +40,7 @@ overrides:
api_key_scope_application_connect: APIKeyScopeApplicationConnect
avatar_url: AvatarURL
created_by_avatar_url: CreatedByAvatarURL
dbcrypt_key: DBCryptKey
session_count_vscode: SessionCountVSCode
session_count_jetbrains: SessionCountJetBrains
session_count_reconnecting_pty: SessionCountReconnectingPTY
@ -47,9 +48,11 @@ overrides:
connection_median_latency_ms: ConnectionMedianLatencyMS
login_type_oidc: LoginTypeOIDC
oauth_access_token: OAuthAccessToken
oauth_access_token_key_id: OAuthAccessTokenKeyID
oauth_expiry: OAuthExpiry
oauth_id_token: OAuthIDToken
oauth_refresh_token: OAuthRefreshToken
oauth_refresh_token_key_id: OAuthRefreshTokenKeyID
parameter_type_system_hcl: ParameterTypeSystemHCL
userstatus: UserStatus
gitsshkey: GitSSHKey

View File

@ -6,6 +6,8 @@ type UniqueConstraint string
// UniqueConstraint enums.
const (
UniqueDbcryptKeysActiveKeyDigestKey UniqueConstraint = "dbcrypt_keys_active_key_digest_key" // ALTER TABLE ONLY dbcrypt_keys ADD CONSTRAINT dbcrypt_keys_active_key_digest_key UNIQUE (active_key_digest);
UniqueDbcryptKeysRevokedKeyDigestKey UniqueConstraint = "dbcrypt_keys_revoked_key_digest_key" // ALTER TABLE ONLY dbcrypt_keys ADD CONSTRAINT dbcrypt_keys_revoked_key_digest_key UNIQUE (revoked_key_digest);
UniqueFilesHashCreatedByKey UniqueConstraint = "files_hash_created_by_key" // ALTER TABLE ONLY files ADD CONSTRAINT files_hash_created_by_key UNIQUE (hash, created_by);
UniqueGitAuthLinksProviderIDUserIDKey UniqueConstraint = "git_auth_links_provider_id_user_id_key" // ALTER TABLE ONLY git_auth_links ADD CONSTRAINT git_auth_links_provider_id_user_id_key UNIQUE (provider_id, user_id);
UniqueGroupMembersUserIDGroupIDKey UniqueConstraint = "group_members_user_id_group_id_key" // ALTER TABLE ONLY group_members ADD CONSTRAINT group_members_user_id_group_id_key UNIQUE (user_id, group_id);

View File

@ -0,0 +1,98 @@
package dbcrypt
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/xerrors"
)
// cipherAES256GCM is the name of the AES-256 cipher.
// This is used to identify the cipher used to encrypt a value.
// It is added to the digest to ensure that if, in the future,
// we add a new cipher type, and a key is re-used, we don't
// accidentally decrypt the wrong values.
// When adding a new cipher type, add a new constant here
// and ensure to add the cipher name to the digest of the new
// cipher type.
const (
cipherAES256GCM = "aes256gcm"
)
type Cipher interface {
Encrypt([]byte) ([]byte, error)
Decrypt([]byte) ([]byte, error)
HexDigest() string
}
// NewCiphers is a convenience function for creating multiple ciphers.
// It currently only supports AES-256-GCM.
func NewCiphers(keys ...[]byte) ([]Cipher, error) {
var cs []Cipher
for _, key := range keys {
c, err := cipherAES256(key)
if err != nil {
return nil, err
}
cs = append(cs, c)
}
return cs, nil
}
// cipherAES256 returns a new AES-256 cipher.
func cipherAES256(key []byte) (*aes256, error) {
if len(key) != 32 {
return nil, xerrors.Errorf("key must be 32 bytes")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
// We add the cipher name to the digest to ensure that if, in the future,
// we add a new cipher type, and a key is re-used, we don't accidentally
// decrypt the wrong values.
toDigest := []byte(cipherAES256GCM)
toDigest = append(toDigest, key...)
digest := fmt.Sprintf("%x", sha256.Sum256(toDigest))[:7]
return &aes256{aead: aead, digest: digest}, nil
}
type aes256 struct {
aead cipher.AEAD
// digest is the first 7 bytes of the hex-encoded SHA-256 digest of aead.
digest string
}
func (a *aes256) Encrypt(plaintext []byte) ([]byte, error) {
nonce := make([]byte, a.aead.NonceSize())
_, err := io.ReadFull(rand.Reader, nonce)
if err != nil {
return nil, err
}
dst := make([]byte, len(nonce))
copy(dst, nonce)
return a.aead.Seal(dst, nonce, plaintext, nil), nil
}
func (a *aes256) Decrypt(ciphertext []byte) ([]byte, error) {
if len(ciphertext) < a.aead.NonceSize() {
return nil, xerrors.Errorf("ciphertext too short")
}
decrypted, err := a.aead.Open(nil, ciphertext[:a.aead.NonceSize()], ciphertext[a.aead.NonceSize():], nil)
if err != nil {
return nil, &DecryptFailedError{Inner: err}
}
return decrypted, nil
}
func (a *aes256) HexDigest() string {
return a.digest
}

View File

@ -0,0 +1,91 @@
package dbcrypt
import (
"bytes"
"encoding/base64"
"testing"
"github.com/stretchr/testify/require"
)
func TestCipherAES256(t *testing.T) {
t.Parallel()
t.Run("ValidInput", func(t *testing.T) {
t.Parallel()
key := bytes.Repeat([]byte{'a'}, 32)
cipher, err := cipherAES256(key)
require.NoError(t, err)
output, err := cipher.Encrypt([]byte("hello world"))
require.NoError(t, err)
response, err := cipher.Decrypt(output)
require.NoError(t, err)
require.Equal(t, "hello world", string(response))
})
t.Run("InvalidInput", func(t *testing.T) {
t.Parallel()
key := bytes.Repeat([]byte{'a'}, 32)
cipher, err := cipherAES256(key)
require.NoError(t, err)
_, err = cipher.Decrypt(bytes.Repeat([]byte{'a'}, 100))
var decryptErr *DecryptFailedError
require.ErrorAs(t, err, &decryptErr)
})
t.Run("InvalidKeySize", func(t *testing.T) {
t.Parallel()
_, err := cipherAES256(bytes.Repeat([]byte{'a'}, 31))
require.ErrorContains(t, err, "key must be 32 bytes")
})
t.Run("TestNonce", func(t *testing.T) {
t.Parallel()
key := bytes.Repeat([]byte{'a'}, 32)
cipher, err := cipherAES256(key)
require.NoError(t, err)
require.Equal(t, "864f702", cipher.HexDigest())
encrypted1, err := cipher.Encrypt([]byte("hello world"))
require.NoError(t, err)
encrypted2, err := cipher.Encrypt([]byte("hello world"))
require.NoError(t, err)
require.NotEqual(t, encrypted1, encrypted2, "nonce should be different for each encryption")
munged := make([]byte, len(encrypted1))
copy(munged, encrypted1)
munged[0] = munged[0] ^ 0xff
_, err = cipher.Decrypt(munged)
var decryptErr *DecryptFailedError
require.ErrorAs(t, err, &decryptErr, "munging the first byte of the encrypted data should cause decryption to fail")
})
}
// This test ensures backwards compatibility. If it breaks, something is very wrong.
func TestCiphersBackwardCompatibility(t *testing.T) {
t.Parallel()
var (
msg = "hello world"
key = bytes.Repeat([]byte{'a'}, 32)
//nolint: gosec // The below is the base64-encoded result of encrypting the above message with the above key.
encoded = `YhAz+lE2fFeeiVPH9voKN7UV1xSDrgcnC0LmNXmaAk1Yg0kPFO3x`
)
cipher, err := cipherAES256(key)
require.NoError(t, err)
// This is the code that was used to generate the above.
// Note that the output of this code will change every time it is run.
// encrypted, err := cipher.Encrypt([]byte(msg))
// require.NoError(t, err)
// t.Logf("encoded: %q", base64.StdEncoding.EncodeToString(encrypted))
decoded, err := base64.StdEncoding.DecodeString(encoded)
require.NoError(t, err, "the encoded string should be valid base64")
decrypted, err := cipher.Decrypt(decoded)
require.NoError(t, err, "decryption should succeed")
require.Equal(t, msg, string(decrypted), "decrypted message should match original message")
}

View File

@ -0,0 +1,374 @@
package dbcrypt
import (
"context"
"database/sql"
"encoding/base64"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/google/uuid"
"golang.org/x/xerrors"
)
// testValue is the value that is stored in dbcrypt_keys.test.
// This is used to determine if the key is valid.
const testValue = "coder"
var (
b64encode = base64.StdEncoding.EncodeToString
b64decode = base64.StdEncoding.DecodeString
)
// DecryptFailedError is returned when decryption fails.
type DecryptFailedError struct {
Inner error
}
func (e *DecryptFailedError) Error() string {
return xerrors.Errorf("decrypt failed: %w", e.Inner).Error()
}
// New creates a database.Store wrapper that encrypts/decrypts values
// stored at rest in the database.
func New(ctx context.Context, db database.Store, ciphers ...Cipher) (database.Store, error) {
cm := make(map[string]Cipher)
for _, c := range ciphers {
cm[c.HexDigest()] = c
}
dbc := &dbCrypt{
ciphers: cm,
Store: db,
}
if len(ciphers) > 0 {
dbc.primaryCipherDigest = ciphers[0].HexDigest()
}
// nolint: gocritic // This is allowed.
authCtx := dbauthz.AsSystemRestricted(ctx)
if err := dbc.ensureEncryptedWithRetry(authCtx); err != nil {
return nil, xerrors.Errorf("ensure encrypted database fields: %w", err)
}
return dbc, nil
}
type dbCrypt struct {
// primaryCipherDigest is the digest of the primary cipher used for encrypting data.
primaryCipherDigest string
// ciphers is a map of cipher digests to ciphers.
ciphers map[string]Cipher
database.Store
}
func (db *dbCrypt) InTx(function func(database.Store) error, txOpts *sql.TxOptions) error {
return db.Store.InTx(func(s database.Store) error {
return function(&dbCrypt{
primaryCipherDigest: db.primaryCipherDigest,
ciphers: db.ciphers,
Store: s,
})
}, txOpts)
}
func (db *dbCrypt) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
ks, err := db.Store.GetDBCryptKeys(ctx)
if err != nil {
return nil, err
}
// Decrypt the test field to ensure that the key is valid.
for i := range ks {
if !ks[i].ActiveKeyDigest.Valid {
// Key has been revoked. We can't decrypt the test field, but
// we need to return it so that the caller knows that the key
// has been revoked.
continue
}
if err := db.decryptField(&ks[i].Test, ks[i].ActiveKeyDigest); err != nil {
return nil, err
}
}
return ks, nil
}
// This does not need any special handling as it does not touch any encrypted fields.
// Explicitly defining this here to avoid confusion.
func (db *dbCrypt) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
return db.Store.RevokeDBCryptKey(ctx, activeKeyDigest)
}
func (db *dbCrypt) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error {
// It's nicer to be able to pass a *sql.NullString to encryptField, but we need to pass a string here.
var digest sql.NullString
err := db.encryptField(&arg.Test, &digest)
if err != nil {
return err
}
arg.ActiveKeyDigest = digest.String
return db.Store.InsertDBCryptKey(ctx, arg)
}
func (db *dbCrypt) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (database.UserLink, error) {
link, err := db.Store.GetUserLinkByLinkedID(ctx, linkedID)
if err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
return database.UserLink{}, err
}
return link, nil
}
func (db *dbCrypt) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) {
links, err := db.Store.GetUserLinksByUserID(ctx, userID)
if err != nil {
return nil, err
}
for idx := range links {
if err := db.decryptField(&links[idx].OAuthAccessToken, links[idx].OAuthAccessTokenKeyID); err != nil {
return nil, err
}
if err := db.decryptField(&links[idx].OAuthRefreshToken, links[idx].OAuthRefreshTokenKeyID); err != nil {
return nil, err
}
}
return links, nil
}
func (db *dbCrypt) GetUserLinkByUserIDLoginType(ctx context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) {
link, err := db.Store.GetUserLinkByUserIDLoginType(ctx, params)
if err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
return database.UserLink{}, err
}
return link, nil
}
func (db *dbCrypt) InsertUserLink(ctx context.Context, params database.InsertUserLinkParams) (database.UserLink, error) {
if err := db.encryptField(&params.OAuthAccessToken, &params.OAuthAccessTokenKeyID); err != nil {
return database.UserLink{}, err
}
if err := db.encryptField(&params.OAuthRefreshToken, &params.OAuthRefreshTokenKeyID); err != nil {
return database.UserLink{}, err
}
link, err := db.Store.InsertUserLink(ctx, params)
if err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
return database.UserLink{}, err
}
return link, nil
}
func (db *dbCrypt) UpdateUserLink(ctx context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) {
if err := db.encryptField(&params.OAuthAccessToken, &params.OAuthAccessTokenKeyID); err != nil {
return database.UserLink{}, err
}
if err := db.encryptField(&params.OAuthRefreshToken, &params.OAuthRefreshTokenKeyID); err != nil {
return database.UserLink{}, err
}
link, err := db.Store.UpdateUserLink(ctx, params)
if err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
return database.UserLink{}, err
}
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
return database.UserLink{}, err
}
return link, nil
}
func (db *dbCrypt) InsertGitAuthLink(ctx context.Context, params database.InsertGitAuthLinkParams) (database.GitAuthLink, error) {
if err := db.encryptField(&params.OAuthAccessToken, &params.OAuthAccessTokenKeyID); err != nil {
return database.GitAuthLink{}, err
}
if err := db.encryptField(&params.OAuthRefreshToken, &params.OAuthRefreshTokenKeyID); err != nil {
return database.GitAuthLink{}, err
}
link, err := db.Store.InsertGitAuthLink(ctx, params)
if err != nil {
return database.GitAuthLink{}, err
}
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
return database.GitAuthLink{}, err
}
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
return database.GitAuthLink{}, err
}
return link, nil
}
func (db *dbCrypt) GetGitAuthLink(ctx context.Context, params database.GetGitAuthLinkParams) (database.GitAuthLink, error) {
link, err := db.Store.GetGitAuthLink(ctx, params)
if err != nil {
return database.GitAuthLink{}, err
}
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
return database.GitAuthLink{}, err
}
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
return database.GitAuthLink{}, err
}
return link, nil
}
func (db *dbCrypt) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) {
links, err := db.Store.GetGitAuthLinksByUserID(ctx, userID)
if err != nil {
return nil, err
}
for idx := range links {
if err := db.decryptField(&links[idx].OAuthAccessToken, links[idx].OAuthAccessTokenKeyID); err != nil {
return nil, err
}
if err := db.decryptField(&links[idx].OAuthRefreshToken, links[idx].OAuthRefreshTokenKeyID); err != nil {
return nil, err
}
}
return links, nil
}
func (db *dbCrypt) UpdateGitAuthLink(ctx context.Context, params database.UpdateGitAuthLinkParams) (database.GitAuthLink, error) {
if err := db.encryptField(&params.OAuthAccessToken, &params.OAuthAccessTokenKeyID); err != nil {
return database.GitAuthLink{}, err
}
if err := db.encryptField(&params.OAuthRefreshToken, &params.OAuthRefreshTokenKeyID); err != nil {
return database.GitAuthLink{}, err
}
link, err := db.Store.UpdateGitAuthLink(ctx, params)
if err != nil {
return database.GitAuthLink{}, err
}
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
return database.GitAuthLink{}, err
}
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
return database.GitAuthLink{}, err
}
return link, nil
}
func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error {
// If no cipher is loaded, then we can't encrypt anything!
if db.ciphers == nil || db.primaryCipherDigest == "" {
return nil
}
if field == nil {
return xerrors.Errorf("developer error: encryptField called with nil field")
}
if digest == nil {
return xerrors.Errorf("developer error: encryptField called with nil digest")
}
encrypted, err := db.ciphers[db.primaryCipherDigest].Encrypt([]byte(*field))
if err != nil {
return err
}
// Base64 is used to support UTF-8 encoding in PostgreSQL.
*field = b64encode(encrypted)
*digest = sql.NullString{String: db.primaryCipherDigest, Valid: true}
return nil
}
// decryptFields decrypts the given field using the key with the given digest.
// If the value fails to decrypt, sql.ErrNoRows will be returned.
func (db *dbCrypt) decryptField(field *string, digest sql.NullString) error {
if field == nil {
return xerrors.Errorf("developer error: decryptField called with nil field")
}
if !digest.Valid || digest.String == "" {
// This field is not encrypted.
return nil
}
key, ok := db.ciphers[digest.String]
if !ok {
return &DecryptFailedError{
Inner: xerrors.Errorf("no cipher with digest %q", digest.String),
}
}
data, err := b64decode(*field)
if err != nil {
// If it's not valid base64, we should complain loudly.
return &DecryptFailedError{
Inner: xerrors.Errorf("malformed encrypted field %q: %w", *field, err),
}
}
decrypted, err := key.Decrypt(data)
if err != nil {
return &DecryptFailedError{Inner: err}
}
*field = string(decrypted)
return nil
}
func (db *dbCrypt) ensureEncryptedWithRetry(ctx context.Context) error {
var err error
for i := 0; i < 3; i++ {
err = db.ensureEncrypted(ctx)
if err == nil {
return nil
}
// If we get a serialization error, then we need to retry.
if !database.IsSerializedError(err) {
return err
}
// otherwise, retry
}
// If we get here, then we ran out of retries
return err
}
func (db *dbCrypt) ensureEncrypted(ctx context.Context) error {
return db.InTx(func(s database.Store) error {
// Attempt to read the encrypted test fields of the currently active keys.
ks, err := s.GetDBCryptKeys(ctx)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return err
}
var highestNumber int32
var activeCipherFound bool
for _, k := range ks {
// If our primary key has been revoked, then we can't do anything.
if k.RevokedKeyDigest.Valid && k.RevokedKeyDigest.String == db.primaryCipherDigest {
return xerrors.Errorf("primary encryption key %q has been revoked", db.primaryCipherDigest)
}
if k.ActiveKeyDigest.Valid && k.ActiveKeyDigest.String == db.primaryCipherDigest {
activeCipherFound = true
}
if k.Number > highestNumber {
highestNumber = k.Number
}
}
if activeCipherFound || len(db.ciphers) == 0 {
return nil
}
// If we get here, then we have a new key that we need to insert.
return db.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{
Number: highestNumber + 1,
ActiveKeyDigest: db.primaryCipherDigest,
Test: testValue,
})
}, &sql.TxOptions{Isolation: sql.LevelRepeatableRead})
}

View File

@ -0,0 +1,679 @@
package dbcrypt
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"io"
"testing"
"github.com/golang/mock/gomock"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
)
func TestUserLinks(t *testing.T) {
t.Parallel()
ctx := context.Background()
t.Run("InsertUserLink", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
user := dbgen.User(t, crypt, database.User{})
link := dbgen.UserLink(t, crypt, database.UserLink{
UserID: user.ID,
OAuthAccessToken: "access",
OAuthRefreshToken: "refresh",
})
require.Equal(t, "access", link.OAuthAccessToken)
require.Equal(t, "refresh", link.OAuthRefreshToken)
require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String)
require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String)
rawLink, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID)
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access")
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh")
})
t.Run("UpdateUserLink", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
user := dbgen.User(t, crypt, database.User{})
link := dbgen.UserLink(t, crypt, database.UserLink{
UserID: user.ID,
})
updated, err := crypt.UpdateUserLink(ctx, database.UpdateUserLinkParams{
OAuthAccessToken: "access",
OAuthRefreshToken: "refresh",
UserID: link.UserID,
LoginType: link.LoginType,
})
require.NoError(t, err)
require.Equal(t, "access", updated.OAuthAccessToken)
require.Equal(t, "refresh", updated.OAuthRefreshToken)
require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String)
require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String)
rawLink, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID)
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access")
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh")
})
t.Run("GetUserLinkByLinkedID", func(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
user := dbgen.User(t, crypt, database.User{})
link := dbgen.UserLink(t, crypt, database.UserLink{
UserID: user.ID,
OAuthAccessToken: "access",
OAuthRefreshToken: "refresh",
})
link, err := crypt.GetUserLinkByLinkedID(ctx, link.LinkedID)
require.NoError(t, err)
require.Equal(t, "access", link.OAuthAccessToken)
require.Equal(t, "refresh", link.OAuthRefreshToken)
require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String)
require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String)
rawLink, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID)
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access")
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh")
})
t.Run("DecryptErr", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
user := dbgen.User(t, db, database.User{})
link := dbgen.UserLink(t, db, database.UserLink{
UserID: user.ID,
OAuthAccessToken: fakeBase64RandomData(t, 32),
OAuthRefreshToken: fakeBase64RandomData(t, 32),
OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
})
_, err := crypt.GetUserLinkByLinkedID(ctx, link.LinkedID)
require.Error(t, err, "expected an error")
var derr *DecryptFailedError
require.ErrorAs(t, err, &derr, "expected a decrypt error")
})
})
t.Run("GetUserLinksByUserID", func(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
user := dbgen.User(t, crypt, database.User{})
link := dbgen.UserLink(t, crypt, database.UserLink{
UserID: user.ID,
OAuthAccessToken: "access",
OAuthRefreshToken: "refresh",
})
links, err := crypt.GetUserLinksByUserID(ctx, link.UserID)
require.NoError(t, err)
require.Len(t, links, 1)
require.Equal(t, "access", links[0].OAuthAccessToken)
require.Equal(t, "refresh", links[0].OAuthRefreshToken)
require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthAccessTokenKeyID.String)
require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthRefreshTokenKeyID.String)
rawLinks, err := db.GetUserLinksByUserID(ctx, link.UserID)
require.NoError(t, err)
require.Len(t, rawLinks, 1)
requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthAccessToken, "access")
requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthRefreshToken, "refresh")
})
t.Run("Empty", func(t *testing.T) {
t.Parallel()
_, crypt, _ := setup(t)
user := dbgen.User(t, crypt, database.User{})
links, err := crypt.GetUserLinksByUserID(ctx, user.ID)
require.NoError(t, err)
require.Empty(t, links)
})
t.Run("DecryptErr", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
user := dbgen.User(t, db, database.User{})
_ = dbgen.UserLink(t, db, database.UserLink{
UserID: user.ID,
OAuthAccessToken: fakeBase64RandomData(t, 32),
OAuthRefreshToken: fakeBase64RandomData(t, 32),
OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
})
_, err := crypt.GetUserLinksByUserID(ctx, user.ID)
require.Error(t, err, "expected an error")
var derr *DecryptFailedError
require.ErrorAs(t, err, &derr, "expected a decrypt error")
})
})
t.Run("GetUserLinkByUserIDLoginType", func(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
user := dbgen.User(t, crypt, database.User{})
link := dbgen.UserLink(t, crypt, database.UserLink{
UserID: user.ID,
OAuthAccessToken: "access",
OAuthRefreshToken: "refresh",
})
link, err := crypt.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
UserID: link.UserID,
LoginType: link.LoginType,
})
require.NoError(t, err)
require.Equal(t, "access", link.OAuthAccessToken)
require.Equal(t, "refresh", link.OAuthRefreshToken)
require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String)
require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String)
rawLink, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
UserID: link.UserID,
LoginType: link.LoginType,
})
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access")
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh")
})
t.Run("DecryptErr", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
user := dbgen.User(t, db, database.User{})
link := dbgen.UserLink(t, db, database.UserLink{
UserID: user.ID,
OAuthAccessToken: fakeBase64RandomData(t, 32),
OAuthRefreshToken: fakeBase64RandomData(t, 32),
OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
})
_, err := crypt.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
UserID: link.UserID,
LoginType: link.LoginType,
})
require.Error(t, err, "expected an error")
var derr *DecryptFailedError
require.ErrorAs(t, err, &derr, "expected a decrypt error")
})
})
}
func TestGitAuthLinks(t *testing.T) {
t.Parallel()
ctx := context.Background()
t.Run("InsertGitAuthLink", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{
OAuthAccessToken: "access",
OAuthRefreshToken: "refresh",
})
require.Equal(t, "access", link.OAuthAccessToken)
require.Equal(t, "refresh", link.OAuthRefreshToken)
link, err := db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
ProviderID: link.ProviderID,
UserID: link.UserID,
})
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], link.OAuthAccessToken, "access")
requireEncryptedEquals(t, ciphers[0], link.OAuthRefreshToken, "refresh")
})
t.Run("UpdateGitAuthLink", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{})
updated, err := crypt.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
ProviderID: link.ProviderID,
UserID: link.UserID,
OAuthAccessToken: "access",
OAuthRefreshToken: "refresh",
})
require.NoError(t, err)
require.Equal(t, "access", updated.OAuthAccessToken)
require.Equal(t, "refresh", updated.OAuthRefreshToken)
link, err = db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
ProviderID: link.ProviderID,
UserID: link.UserID,
})
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], link.OAuthAccessToken, "access")
requireEncryptedEquals(t, ciphers[0], link.OAuthRefreshToken, "refresh")
})
t.Run("GetGitAuthLink", func(t *testing.T) {
t.Run("OK", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{
OAuthAccessToken: "access",
OAuthRefreshToken: "refresh",
})
link, err := db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
UserID: link.UserID,
ProviderID: link.ProviderID,
})
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], link.OAuthAccessToken, "access")
requireEncryptedEquals(t, ciphers[0], link.OAuthRefreshToken, "refresh")
})
t.Run("DecryptErr", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
link := dbgen.GitAuthLink(t, db, database.GitAuthLink{
OAuthAccessToken: fakeBase64RandomData(t, 32),
OAuthRefreshToken: fakeBase64RandomData(t, 32),
OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
})
_, err := crypt.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
UserID: link.UserID,
ProviderID: link.ProviderID,
})
require.Error(t, err, "expected an error")
var derr *DecryptFailedError
require.ErrorAs(t, err, &derr, "expected a decrypt error")
})
})
t.Run("GetGitAuthLinksByUserID", func(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
user := dbgen.User(t, crypt, database.User{})
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{
UserID: user.ID,
OAuthAccessToken: "access",
OAuthRefreshToken: "refresh",
})
links, err := crypt.GetGitAuthLinksByUserID(ctx, link.UserID)
require.NoError(t, err)
require.Len(t, links, 1)
require.Equal(t, "access", links[0].OAuthAccessToken)
require.Equal(t, "refresh", links[0].OAuthRefreshToken)
require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthAccessTokenKeyID.String)
require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthRefreshTokenKeyID.String)
rawLinks, err := db.GetGitAuthLinksByUserID(ctx, link.UserID)
require.NoError(t, err)
require.Len(t, rawLinks, 1)
requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthAccessToken, "access")
requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthRefreshToken, "refresh")
})
t.Run("DecryptErr", func(t *testing.T) {
db, crypt, ciphers := setup(t)
user := dbgen.User(t, db, database.User{})
link := dbgen.GitAuthLink(t, db, database.GitAuthLink{
UserID: user.ID,
OAuthAccessToken: fakeBase64RandomData(t, 32),
OAuthRefreshToken: fakeBase64RandomData(t, 32),
OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
})
_, err := crypt.GetGitAuthLinksByUserID(ctx, link.UserID)
require.Error(t, err, "expected an error")
var derr *DecryptFailedError
require.ErrorAs(t, err, &derr, "expected a decrypt error")
})
})
}
func TestNew(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
// Given: a cipher is loaded
cipher := initCipher(t)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
rawDB, _ := dbtestutil.NewDB(t)
// Before: no keys should be present
keys, err := rawDB.GetDBCryptKeys(ctx)
require.NoError(t, err, "no error should be returned")
require.Empty(t, keys, "no keys should be present")
// When: we init the crypt db
_, err = New(ctx, rawDB, cipher)
require.NoError(t, err)
// Then: a new key is inserted
keys, err = rawDB.GetDBCryptKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 1, "one key should be present")
require.Equal(t, cipher.HexDigest(), keys[0].ActiveKeyDigest.String, "key digest mismatch")
require.Empty(t, keys[0].RevokedKeyDigest.String, "key should not be revoked")
requireEncryptedEquals(t, cipher, keys[0].Test, "coder")
})
t.Run("MissingKey", func(t *testing.T) {
t.Parallel()
// Given: there exist two valid encryption keys
cipher1 := initCipher(t)
cipher2 := initCipher(t)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
rawDB, _ := dbtestutil.NewDB(t)
// Given: key 1 is already present in the database
err := rawDB.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{
Number: 1,
ActiveKeyDigest: cipher1.HexDigest(),
Test: fakeBase64RandomData(t, 32),
})
require.NoError(t, err, "no error should be returned")
keys, err := rawDB.GetDBCryptKeys(ctx)
require.NoError(t, err, "no error should be returned")
require.Len(t, keys, 1, "one key should be present")
// When: we init the crypt db with no keys
_, err = New(ctx, rawDB)
// Then: we error because we don't know how to decrypt the existing key
require.Error(t, err, "expected an error")
var derr *DecryptFailedError
require.ErrorAs(t, err, &derr, "expected a decrypt error")
// When: we init the crypt db with key 2
_, err = New(ctx, rawDB, cipher2)
// Then: we error because the key is not revoked and we don't know how to decrypt it
require.Error(t, err, "expected an error")
require.ErrorAs(t, err, &derr, "expected a decrypt error")
// When: the existing key is marked as having been revoked
err = rawDB.RevokeDBCryptKey(ctx, cipher1.HexDigest())
require.NoError(t, err, "no error should be returned")
// And: we init the crypt db with key 2
_, err = New(ctx, rawDB, cipher2)
// Then: we succeed
require.NoError(t, err)
// And: key 2 should now be the active key
keys, err = rawDB.GetDBCryptKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 2, "two keys should be present")
require.EqualValues(t, keys[0].Number, 1, "key number mismatch")
require.Empty(t, keys[0].ActiveKeyDigest.String, "key should not be active")
require.Equal(t, cipher1.HexDigest(), keys[0].RevokedKeyDigest.String, "key should be revoked")
require.EqualValues(t, keys[1].Number, 2, "key number mismatch")
require.Equal(t, cipher2.HexDigest(), keys[1].ActiveKeyDigest.String, "key digest mismatch")
require.Empty(t, keys[1].RevokedKeyDigest.String, "key should not be revoked")
requireEncryptedEquals(t, cipher2, keys[1].Test, "coder")
})
t.Run("NoKeys", func(t *testing.T) {
t.Parallel()
// Given: no cipher is loaded
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
rawDB, _ := dbtestutil.NewDB(t)
keys, err := rawDB.GetDBCryptKeys(ctx)
require.NoError(t, err, "no error should be returned")
require.Empty(t, keys, "no keys should be present")
// When: we init the crypt db with no ciphers
_, err = New(ctx, rawDB)
// Then: it should succeed.
require.NoError(t, err, "dbcrypt.New should work with no keys against an unencrypted database")
// Assert invariant: no keys are inserted
keys, err = rawDB.GetDBCryptKeys(ctx)
require.NoError(t, err, "no error should be returned")
require.Empty(t, keys, "no keys should be present")
// Insert a key
require.NoError(t, rawDB.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{
Number: 1,
ActiveKeyDigest: "whatever",
Test: fakeBase64RandomData(t, 32),
}))
// This should fail as we do not know how to decrypt the key:
_, err = New(ctx, rawDB)
require.Error(t, err)
// Until we revoke the key:
require.NoError(t, rawDB.RevokeDBCryptKey(ctx, "whatever"))
_, err = New(ctx, rawDB)
require.NoError(t, err, "the above should still hold if the key is revoked")
})
t.Run("PrimaryRevoked", func(t *testing.T) {
t.Parallel()
// Given: a cipher is loaded
cipher := initCipher(t)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
rawDB, _ := dbtestutil.NewDB(t)
// And: the cipher is revoked before we init the crypt db
err := rawDB.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{
Number: 1,
ActiveKeyDigest: cipher.HexDigest(),
Test: fakeBase64RandomData(t, 32),
})
require.NoError(t, err, "no error should be returned")
err = rawDB.RevokeDBCryptKey(ctx, cipher.HexDigest())
require.NoError(t, err, "no error should be returned")
// Then: when we init the crypt db, we error because the key is revoked
_, err = New(ctx, rawDB, cipher)
require.Error(t, err)
require.ErrorContains(t, err, "has been revoked")
})
t.Run("Retry", func(t *testing.T) {
t.Parallel()
// Given: a cipher is loaded
cipher := initCipher(t)
ctx, cancel := context.WithCancel(context.Background())
testVal, err := cipher.Encrypt([]byte("coder"))
key := database.DBCryptKey{
Number: 1,
ActiveKeyDigest: sql.NullString{String: cipher.HexDigest(), Valid: true},
Test: b64encode(testVal),
}
require.NoError(t, err)
t.Cleanup(cancel)
// And: a database that returns an error once when we try to serialize a key
ctrl := gomock.NewController(t)
mockDB := dbmock.NewMockStore(ctrl)
gomock.InOrder(
// First try: we get a serialization error.
expectInTx(mockDB),
mockDB.EXPECT().GetDBCryptKeys(gomock.Any()).Times(1).Return([]database.DBCryptKey{}, nil),
mockDB.EXPECT().InsertDBCryptKey(gomock.Any(), gomock.Any()).Times(1).Return(&pq.Error{Code: "40001"}),
// Second try: we get the key we wanted to insert initially.
expectInTx(mockDB),
mockDB.EXPECT().GetDBCryptKeys(gomock.Any()).Times(1).Return([]database.DBCryptKey{key}, nil),
)
_, err = New(ctx, mockDB, cipher)
require.NoError(t, err)
})
}
func TestEncryptDecryptField(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
_, cryptDB, ciphers := setup(t)
field := "coder"
digest := sql.NullString{}
require.NoError(t, cryptDB.encryptField(&field, &digest))
require.Equal(t, ciphers[0].HexDigest(), digest.String)
requireEncryptedEquals(t, ciphers[0], field, "coder")
require.NoError(t, cryptDB.decryptField(&field, digest))
require.Equal(t, "coder", field)
})
t.Run("NoKeys", func(t *testing.T) {
t.Parallel()
// With no keys, encryption and decryption are both no-ops.
_, cryptDB := setupNoCiphers(t)
field := "coder"
digest := sql.NullString{}
require.NoError(t, cryptDB.encryptField(&field, &digest))
require.Empty(t, digest.String)
require.Equal(t, "coder", field)
require.NoError(t, cryptDB.decryptField(&field, digest))
require.Equal(t, "coder", field)
})
t.Run("MissingKey", func(t *testing.T) {
t.Parallel()
_, cryptDB, ciphers := setup(t)
field := "coder"
digest := sql.NullString{}
err := cryptDB.encryptField(&field, &digest)
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], field, "coder")
require.Equal(t, ciphers[0].HexDigest(), digest.String)
require.True(t, digest.Valid)
digest = sql.NullString{String: "missing", Valid: true}
var derr *DecryptFailedError
err = cryptDB.decryptField(&field, digest)
require.Error(t, err)
require.ErrorAs(t, err, &derr)
})
t.Run("CantEncryptOrDecryptNil", func(t *testing.T) {
t.Parallel()
_, cryptDB, _ := setup(t)
require.ErrorContains(t, cryptDB.encryptField(nil, nil), "developer error")
require.ErrorContains(t, cryptDB.decryptField(nil, sql.NullString{}), "developer error")
})
t.Run("EncryptEmptyString", func(t *testing.T) {
t.Parallel()
_, cryptDB, ciphers := setup(t)
field := ""
digest := sql.NullString{}
require.NoError(t, cryptDB.encryptField(&field, &digest))
requireEncryptedEquals(t, ciphers[0], field, "")
require.Equal(t, ciphers[0].HexDigest(), digest.String)
require.NoError(t, cryptDB.decryptField(&field, digest))
require.Empty(t, field)
})
t.Run("DecryptEmptyString", func(t *testing.T) {
t.Parallel()
_, cryptDB, ciphers := setup(t)
field := ""
digest := sql.NullString{String: ciphers[0].HexDigest(), Valid: true}
err := cryptDB.decryptField(&field, digest)
// Currently this has to fail because the ciphertext must at least
// have a nonce. This may need to be changed depending on future
// ciphers.
require.ErrorContains(t, err, "ciphertext too short")
})
t.Run("InvalidBase64", func(t *testing.T) {
t.Parallel()
_, cryptDB, ciphers := setup(t)
field := "not valid base64"
digest := sql.NullString{String: ciphers[0].HexDigest(), Valid: true}
err := cryptDB.decryptField(&field, digest)
require.ErrorContains(t, err, "illegal base64 data")
})
}
func expectInTx(mdb *dbmock.MockStore) *gomock.Call {
return mdb.EXPECT().InTx(gomock.Any(), gomock.Any()).Times(1).DoAndReturn(
func(f func(store database.Store) error, _ *sql.TxOptions) error {
return f(mdb)
},
)
}
func requireEncryptedEquals(t *testing.T, c Cipher, value, expected string) {
t.Helper()
data, err := base64.StdEncoding.DecodeString(value)
require.NoError(t, err, "invalid base64")
got, err := c.Decrypt(data)
require.NoError(t, err, "failed to decrypt data")
require.Equal(t, expected, string(got), "decrypted data does not match")
}
func initCipher(t *testing.T) *aes256 {
t.Helper()
key := make([]byte, 32) // AES-256 key size is 32 bytes
_, err := io.ReadFull(rand.Reader, key)
require.NoError(t, err)
c, err := cipherAES256(key)
require.NoError(t, err)
return c
}
func setup(t *testing.T) (db database.Store, cryptDB *dbCrypt, cs []Cipher) {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
rawDB, _ := dbtestutil.NewDB(t)
cs = append(cs, initCipher(t))
cdb, err := New(ctx, rawDB, cs...)
require.NoError(t, err)
cryptDB, ok := cdb.(*dbCrypt)
require.True(t, ok)
return rawDB, cryptDB, cs
}
func setupNoCiphers(t *testing.T) (db database.Store, cryptodb *dbCrypt) {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
rawDB, _ := dbtestutil.NewDB(t)
cdb, err := New(ctx, rawDB)
require.NoError(t, err)
cryptDB, ok := cdb.(*dbCrypt)
require.True(t, ok)
return rawDB, cryptDB
}
func fakeBase64RandomData(t *testing.T, n int) string {
t.Helper()
b := make([]byte, n)
_, err := io.ReadFull(rand.Reader, b)
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(b)
}

34
enterprise/dbcrypt/doc.go Normal file
View File

@ -0,0 +1,34 @@
// Package dbcrypt provides a database.Store wrapper that encrypts/decrypts
// values stored at rest in the database.
//
// Encryption is done using Ciphers, which is an abstraction over a set of
// encryption keys. Each key has a unique identifier, which is used to
// uniquely identify the key whilst maintaining secrecy.
//
// Currently, AES-256-GCM is the only implemented cipher mode.
// The Cipher is currently used to encrypt/decrypt the following fields:
// - database.UserLink.OAuthAccessToken
// - database.UserLink.OAuthRefreshToken
// - database.GitAuthLink.OAuthAccessToken
// - database.GitAuthLink.OAuthRefreshToken
// - database.DBCryptSentinelValue
//
// Multiple ciphers can be provided to support key rotation. The primary cipher
// is used to encrypt and decrypt all data. Secondary ciphers are only used
// for decryption and, as a general rule, should only be active when rotating
// keys.
//
// Encryption keys are stored in the database in the table `dbcrypt_keys`.
// The table has the following schema:
// - number: the key number. This is used to avoid conflicts when rotating keys.
// - created_at: the time the key was created.
// - active_key_digest: the SHA256 digest of the active key. If null, the key has been revoked.
// - revoked_key_digest: the SHA256 digest of the revoked key. If null, the key has not been revoked.
// - revoked_at: the time the key was revoked. If null, the key has not been revoked.
// - test: the encrypted value of the string "coder". This is used to ensure that the key is valid.
//
// Encrypted fields are stored in the database as a base64-encoded string.
// Each encrypted column MUST have a corresponding _key_id column that is a foreign key
// reference to `dbcrypt_keys.active_key_digest`. This ensures that a key cannot be
// revoked until all rows that use that key have been migrated to a new key.
package dbcrypt