mirror of https://github.com/coder/coder.git
feat(coderd): add dbcrypt package (#9522)
- Adds package enterprise/dbcrypt to implement database encryption/decryption - Adds table dbcrypt_keys and associated queries - Adds columns oauth_access_token_key_id and oauth_refresh_token_key_id to tables git_auth_links and user_links Co-authored-by: Kyle Carberry <kyle@coder.com>
This commit is contained in:
parent
3bd0fd396c
commit
7918e65510
|
@ -838,6 +838,13 @@ func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUI
|
|||
return q.db.GetAuthorizationUserRoles(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
|
||||
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetDBCryptKeys(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
|
||||
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return "", err
|
||||
|
@ -914,6 +921,13 @@ func (q *querier) GetGitAuthLink(ctx context.Context, arg database.GetGitAuthLin
|
|||
return fetch(q.log, q.auth, q.db.GetGitAuthLink)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) {
|
||||
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetGitAuthLinksByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
|
||||
return fetch(q.log, q.auth, q.db.GetGitSSHKey)(ctx, userID)
|
||||
}
|
||||
|
@ -1482,6 +1496,13 @@ func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database
|
|||
return q.db.GetUserLinkByUserIDLoginType(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) {
|
||||
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetUserLinksByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) {
|
||||
// This does the filtering in SQL.
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, rbac.ActionRead, rbac.ResourceUser.Type)
|
||||
|
@ -1845,6 +1866,13 @@ func (q *querier) InsertAuditLog(ctx context.Context, arg database.InsertAuditLo
|
|||
return insert(q.log, q.auth, rbac.ResourceAuditLog, q.db.InsertAuditLog)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error {
|
||||
if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.InsertDBCryptKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertDERPMeshKey(ctx context.Context, value string) error {
|
||||
if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
|
@ -2144,6 +2172,13 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis
|
|||
return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.RevokeDBCryptKey(ctx, activeKeyDigest)
|
||||
}
|
||||
|
||||
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
|
||||
return q.db.TryAcquireLock(ctx, id)
|
||||
}
|
||||
|
|
|
@ -31,6 +31,11 @@ import (
|
|||
|
||||
var validProxyByHostnameRegex = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
|
||||
|
||||
var errForeignKeyConstraint = &pq.Error{
|
||||
Code: "23503",
|
||||
Message: "update or delete on table violates foreign key constraint",
|
||||
}
|
||||
|
||||
var errDuplicateKey = &pq.Error{
|
||||
Code: "23505",
|
||||
Message: "duplicate key value violates unique constraint",
|
||||
|
@ -45,6 +50,7 @@ func New() database.Store {
|
|||
organizationMembers: make([]database.OrganizationMember, 0),
|
||||
organizations: make([]database.Organization, 0),
|
||||
users: make([]database.User, 0),
|
||||
dbcryptKeys: make([]database.DBCryptKey, 0),
|
||||
gitAuthLinks: make([]database.GitAuthLink, 0),
|
||||
groups: make([]database.Group, 0),
|
||||
groupMembers: make([]database.GroupMember, 0),
|
||||
|
@ -117,6 +123,7 @@ type data struct {
|
|||
// New tables
|
||||
workspaceAgentStats []database.WorkspaceAgentStat
|
||||
auditLogs []database.AuditLog
|
||||
dbcryptKeys []database.DBCryptKey
|
||||
files []database.File
|
||||
gitAuthLinks []database.GitAuthLink
|
||||
gitSSHKey []database.GitSSHKey
|
||||
|
@ -665,6 +672,19 @@ func (q *FakeQuerier) isEveryoneGroup(id uuid.UUID) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetActiveDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
ks := make([]database.DBCryptKey, 0, len(q.dbcryptKeys))
|
||||
for _, k := range q.dbcryptKeys {
|
||||
if !k.ActiveKeyDigest.Valid {
|
||||
continue
|
||||
}
|
||||
ks = append([]database.DBCryptKey{}, k)
|
||||
}
|
||||
return ks, nil
|
||||
}
|
||||
|
||||
func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error {
|
||||
return xerrors.New("AcquireLock must only be called within a transaction")
|
||||
}
|
||||
|
@ -1151,6 +1171,14 @@ func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetDBCryptKeys(_ context.Context) ([]database.DBCryptKey, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
ks := make([]database.DBCryptKey, 0)
|
||||
ks = append(ks, q.dbcryptKeys...)
|
||||
return ks, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
@ -1393,6 +1421,18 @@ func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthL
|
|||
return database.GitAuthLink{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetGitAuthLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
gals := make([]database.GitAuthLink, 0)
|
||||
for _, gal := range q.gitAuthLinks {
|
||||
if gal.UserID == userID {
|
||||
gals = append(gals, gal)
|
||||
}
|
||||
}
|
||||
return gals, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetGitSSHKey(_ context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
@ -2833,6 +2873,18 @@ func (q *FakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat
|
|||
return database.UserLink{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetUserLinksByUserID(_ context.Context, userID uuid.UUID) ([]database.UserLink, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
uls := make([]database.UserLink, 0)
|
||||
for _, ul := range q.userLinks {
|
||||
if ul.UserID == userID {
|
||||
uls = append(uls, ul)
|
||||
}
|
||||
}
|
||||
return uls, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.GetUsersRow, error) {
|
||||
if err := validateDatabaseType(params); err != nil {
|
||||
return nil, err
|
||||
|
@ -3846,6 +3898,26 @@ func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAudit
|
|||
return alog, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) InsertDBCryptKey(_ context.Context, arg database.InsertDBCryptKeyParams) error {
|
||||
err := validateDatabaseType(arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range q.dbcryptKeys {
|
||||
if key.Number == arg.Number {
|
||||
return errDuplicateKey
|
||||
}
|
||||
}
|
||||
|
||||
q.dbcryptKeys = append(q.dbcryptKeys, database.DBCryptKey{
|
||||
Number: arg.Number,
|
||||
ActiveKeyDigest: sql.NullString{String: arg.ActiveKeyDigest, Valid: true},
|
||||
Test: arg.Test,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
@ -3892,13 +3964,15 @@ func (q *FakeQuerier) InsertGitAuthLink(_ context.Context, arg database.InsertGi
|
|||
defer q.mutex.Unlock()
|
||||
// nolint:gosimple
|
||||
gitAuthLink := database.GitAuthLink{
|
||||
ProviderID: arg.ProviderID,
|
||||
UserID: arg.UserID,
|
||||
CreatedAt: arg.CreatedAt,
|
||||
UpdatedAt: arg.UpdatedAt,
|
||||
OAuthAccessToken: arg.OAuthAccessToken,
|
||||
OAuthRefreshToken: arg.OAuthRefreshToken,
|
||||
OAuthExpiry: arg.OAuthExpiry,
|
||||
ProviderID: arg.ProviderID,
|
||||
UserID: arg.UserID,
|
||||
CreatedAt: arg.CreatedAt,
|
||||
UpdatedAt: arg.UpdatedAt,
|
||||
OAuthAccessToken: arg.OAuthAccessToken,
|
||||
OAuthAccessTokenKeyID: arg.OAuthAccessTokenKeyID,
|
||||
OAuthRefreshToken: arg.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: arg.OAuthRefreshTokenKeyID,
|
||||
OAuthExpiry: arg.OAuthExpiry,
|
||||
}
|
||||
q.gitAuthLinks = append(q.gitAuthLinks, gitAuthLink)
|
||||
return gitAuthLink, nil
|
||||
|
@ -4362,12 +4436,14 @@ func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser
|
|||
|
||||
//nolint:gosimple
|
||||
link := database.UserLink{
|
||||
UserID: args.UserID,
|
||||
LoginType: args.LoginType,
|
||||
LinkedID: args.LinkedID,
|
||||
OAuthAccessToken: args.OAuthAccessToken,
|
||||
OAuthRefreshToken: args.OAuthRefreshToken,
|
||||
OAuthExpiry: args.OAuthExpiry,
|
||||
UserID: args.UserID,
|
||||
LoginType: args.LoginType,
|
||||
LinkedID: args.LinkedID,
|
||||
OAuthAccessToken: args.OAuthAccessToken,
|
||||
OAuthAccessTokenKeyID: args.OAuthAccessTokenKeyID,
|
||||
OAuthRefreshToken: args.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: args.OAuthRefreshTokenKeyID,
|
||||
OAuthExpiry: args.OAuthExpiry,
|
||||
}
|
||||
|
||||
q.userLinks = append(q.userLinks, link)
|
||||
|
@ -4793,6 +4869,46 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg
|
|||
return database.WorkspaceProxy{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
for i := range q.dbcryptKeys {
|
||||
key := q.dbcryptKeys[i]
|
||||
|
||||
// Is the key already revoked?
|
||||
if !key.ActiveKeyDigest.Valid {
|
||||
continue
|
||||
}
|
||||
|
||||
if key.ActiveKeyDigest.String != activeKeyDigest {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for foreign key constraints.
|
||||
for _, ul := range q.userLinks {
|
||||
if (ul.OAuthAccessTokenKeyID.Valid && ul.OAuthAccessTokenKeyID.String == activeKeyDigest) ||
|
||||
(ul.OAuthRefreshTokenKeyID.Valid && ul.OAuthRefreshTokenKeyID.String == activeKeyDigest) {
|
||||
return errForeignKeyConstraint
|
||||
}
|
||||
}
|
||||
for _, gal := range q.gitAuthLinks {
|
||||
if (gal.OAuthAccessTokenKeyID.Valid && gal.OAuthAccessTokenKeyID.String == activeKeyDigest) ||
|
||||
(gal.OAuthRefreshTokenKeyID.Valid && gal.OAuthRefreshTokenKeyID.String == activeKeyDigest) {
|
||||
return errForeignKeyConstraint
|
||||
}
|
||||
}
|
||||
|
||||
// Revoke the key.
|
||||
q.dbcryptKeys[i].RevokedAt = sql.NullTime{Time: dbtime.Now(), Valid: true}
|
||||
q.dbcryptKeys[i].RevokedKeyDigest = sql.NullString{String: key.ActiveKeyDigest.String, Valid: true}
|
||||
q.dbcryptKeys[i].ActiveKeyDigest = sql.NullString{}
|
||||
return nil
|
||||
}
|
||||
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (*FakeQuerier) TryAcquireLock(_ context.Context, _ int64) (bool, error) {
|
||||
return false, xerrors.New("TryAcquireLock must only be called within a transaction")
|
||||
}
|
||||
|
@ -4834,7 +4950,9 @@ func (q *FakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGi
|
|||
}
|
||||
gitAuthLink.UpdatedAt = arg.UpdatedAt
|
||||
gitAuthLink.OAuthAccessToken = arg.OAuthAccessToken
|
||||
gitAuthLink.OAuthAccessTokenKeyID = arg.OAuthAccessTokenKeyID
|
||||
gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken
|
||||
gitAuthLink.OAuthRefreshTokenKeyID = arg.OAuthRefreshTokenKeyID
|
||||
gitAuthLink.OAuthExpiry = arg.OAuthExpiry
|
||||
q.gitAuthLinks[index] = gitAuthLink
|
||||
|
||||
|
@ -5306,7 +5424,9 @@ func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
|
|||
for i, link := range q.userLinks {
|
||||
if link.UserID == params.UserID && link.LoginType == params.LoginType {
|
||||
link.OAuthAccessToken = params.OAuthAccessToken
|
||||
link.OAuthAccessTokenKeyID = params.OAuthAccessTokenKeyID
|
||||
link.OAuthRefreshToken = params.OAuthRefreshToken
|
||||
link.OAuthRefreshTokenKeyID = params.OAuthRefreshTokenKeyID
|
||||
link.OAuthExpiry = params.OAuthExpiry
|
||||
|
||||
q.userLinks[i] = link
|
||||
|
|
|
@ -470,12 +470,14 @@ func File(t testing.TB, db database.Store, orig database.File) database.File {
|
|||
|
||||
func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.UserLink {
|
||||
link, err := db.InsertUserLink(genCtx, database.InsertUserLinkParams{
|
||||
UserID: takeFirst(orig.UserID, uuid.New()),
|
||||
LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub),
|
||||
LinkedID: takeFirst(orig.LinkedID),
|
||||
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
|
||||
OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
|
||||
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
|
||||
UserID: takeFirst(orig.UserID, uuid.New()),
|
||||
LoginType: takeFirst(orig.LoginType, database.LoginTypeGithub),
|
||||
LinkedID: takeFirst(orig.LinkedID),
|
||||
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
|
||||
OAuthAccessTokenKeyID: takeFirst(orig.OAuthAccessTokenKeyID, sql.NullString{}),
|
||||
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
|
||||
OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}),
|
||||
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
|
||||
})
|
||||
|
||||
require.NoError(t, err, "insert link")
|
||||
|
@ -484,13 +486,15 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.
|
|||
|
||||
func GitAuthLink(t testing.TB, db database.Store, orig database.GitAuthLink) database.GitAuthLink {
|
||||
link, err := db.InsertGitAuthLink(genCtx, database.InsertGitAuthLinkParams{
|
||||
ProviderID: takeFirst(orig.ProviderID, uuid.New().String()),
|
||||
UserID: takeFirst(orig.UserID, uuid.New()),
|
||||
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
|
||||
OAuthRefreshToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
|
||||
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
|
||||
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
|
||||
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
|
||||
ProviderID: takeFirst(orig.ProviderID, uuid.New().String()),
|
||||
UserID: takeFirst(orig.UserID, uuid.New()),
|
||||
OAuthAccessToken: takeFirst(orig.OAuthAccessToken, uuid.NewString()),
|
||||
OAuthAccessTokenKeyID: takeFirst(orig.OAuthAccessTokenKeyID, sql.NullString{}),
|
||||
OAuthRefreshToken: takeFirst(orig.OAuthRefreshToken, uuid.NewString()),
|
||||
OAuthRefreshTokenKeyID: takeFirst(orig.OAuthRefreshTokenKeyID, sql.NullString{}),
|
||||
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
|
||||
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
|
||||
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
|
||||
})
|
||||
|
||||
require.NoError(t, err, "insert git auth link")
|
||||
|
|
|
@ -279,6 +279,13 @@ func (m metricsStore) GetAuthorizationUserRoles(ctx context.Context, userID uuid
|
|||
return row, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetDBCryptKeys(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetDBCryptKeys").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m metricsStore) GetDERPMeshKey(ctx context.Context) (string, error) {
|
||||
start := time.Now()
|
||||
key, err := m.s.GetDERPMeshKey(ctx)
|
||||
|
@ -349,6 +356,13 @@ func (m metricsStore) GetGitAuthLink(ctx context.Context, arg database.GetGitAut
|
|||
return link, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetGitAuthLinksByUserID(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetGitAuthLinksByUserID").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m metricsStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
|
||||
start := time.Now()
|
||||
key, err := m.s.GetGitSSHKey(ctx, userID)
|
||||
|
@ -774,6 +788,13 @@ func (m metricsStore) GetUserLinkByUserIDLoginType(ctx context.Context, arg data
|
|||
return link, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserLinksByUserID(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetUserLinksByUserID").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m metricsStore) GetUsers(ctx context.Context, arg database.GetUsersParams) ([]database.GetUsersRow, error) {
|
||||
start := time.Now()
|
||||
users, err := m.s.GetUsers(ctx, arg)
|
||||
|
@ -1068,6 +1089,13 @@ func (m metricsStore) InsertAuditLog(ctx context.Context, arg database.InsertAud
|
|||
return log, err
|
||||
}
|
||||
|
||||
func (m metricsStore) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.InsertDBCryptKey(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertDBCryptKey").Observe(time.Since(start).Seconds())
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m metricsStore) InsertDERPMeshKey(ctx context.Context, value string) error {
|
||||
start := time.Now()
|
||||
err := m.s.InsertDERPMeshKey(ctx, value)
|
||||
|
@ -1320,6 +1348,13 @@ func (m metricsStore) RegisterWorkspaceProxy(ctx context.Context, arg database.R
|
|||
return proxy, err
|
||||
}
|
||||
|
||||
func (m metricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)
|
||||
m.queryLatencies.WithLabelValues("RevokeDBCryptKey").Observe(time.Since(start).Seconds())
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m metricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
start := time.Now()
|
||||
ok, err := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock)
|
||||
|
|
|
@ -506,6 +506,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedWorkspaces(arg0, arg1, arg2 interf
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedWorkspaces", reflect.TypeOf((*MockStore)(nil).GetAuthorizedWorkspaces), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// GetDBCryptKeys mocks base method.
|
||||
func (m *MockStore) GetDBCryptKeys(arg0 context.Context) ([]database.DBCryptKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetDBCryptKeys", arg0)
|
||||
ret0, _ := ret[0].([]database.DBCryptKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetDBCryptKeys indicates an expected call of GetDBCryptKeys.
|
||||
func (mr *MockStoreMockRecorder) GetDBCryptKeys(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDBCryptKeys", reflect.TypeOf((*MockStore)(nil).GetDBCryptKeys), arg0)
|
||||
}
|
||||
|
||||
// GetDERPMeshKey mocks base method.
|
||||
func (m *MockStore) GetDERPMeshKey(arg0 context.Context) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -656,6 +671,21 @@ func (mr *MockStoreMockRecorder) GetGitAuthLink(arg0, arg1 interface{}) *gomock.
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGitAuthLink", reflect.TypeOf((*MockStore)(nil).GetGitAuthLink), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetGitAuthLinksByUserID mocks base method.
|
||||
func (m *MockStore) GetGitAuthLinksByUserID(arg0 context.Context, arg1 uuid.UUID) ([]database.GitAuthLink, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetGitAuthLinksByUserID", arg0, arg1)
|
||||
ret0, _ := ret[0].([]database.GitAuthLink)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetGitAuthLinksByUserID indicates an expected call of GetGitAuthLinksByUserID.
|
||||
func (mr *MockStoreMockRecorder) GetGitAuthLinksByUserID(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGitAuthLinksByUserID", reflect.TypeOf((*MockStore)(nil).GetGitAuthLinksByUserID), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetGitSSHKey mocks base method.
|
||||
func (m *MockStore) GetGitSSHKey(arg0 context.Context, arg1 uuid.UUID) (database.GitSSHKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -1601,6 +1631,21 @@ func (mr *MockStoreMockRecorder) GetUserLinkByUserIDLoginType(arg0, arg1 interfa
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserLinkByUserIDLoginType", reflect.TypeOf((*MockStore)(nil).GetUserLinkByUserIDLoginType), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetUserLinksByUserID mocks base method.
|
||||
func (m *MockStore) GetUserLinksByUserID(arg0 context.Context, arg1 uuid.UUID) ([]database.UserLink, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserLinksByUserID", arg0, arg1)
|
||||
ret0, _ := ret[0].([]database.UserLink)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserLinksByUserID indicates an expected call of GetUserLinksByUserID.
|
||||
func (mr *MockStoreMockRecorder) GetUserLinksByUserID(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserLinksByUserID", reflect.TypeOf((*MockStore)(nil).GetUserLinksByUserID), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetUsers mocks base method.
|
||||
func (m *MockStore) GetUsers(arg0 context.Context, arg1 database.GetUsersParams) ([]database.GetUsersRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -2245,6 +2290,20 @@ func (mr *MockStoreMockRecorder) InsertAuditLog(arg0, arg1 interface{}) *gomock.
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAuditLog", reflect.TypeOf((*MockStore)(nil).InsertAuditLog), arg0, arg1)
|
||||
}
|
||||
|
||||
// InsertDBCryptKey mocks base method.
|
||||
func (m *MockStore) InsertDBCryptKey(arg0 context.Context, arg1 database.InsertDBCryptKeyParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertDBCryptKey", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// InsertDBCryptKey indicates an expected call of InsertDBCryptKey.
|
||||
func (mr *MockStoreMockRecorder) InsertDBCryptKey(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertDBCryptKey", reflect.TypeOf((*MockStore)(nil).InsertDBCryptKey), arg0, arg1)
|
||||
}
|
||||
|
||||
// InsertDERPMeshKey mocks base method.
|
||||
func (m *MockStore) InsertDERPMeshKey(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -2789,6 +2848,20 @@ func (mr *MockStoreMockRecorder) RegisterWorkspaceProxy(arg0, arg1 interface{})
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterWorkspaceProxy", reflect.TypeOf((*MockStore)(nil).RegisterWorkspaceProxy), arg0, arg1)
|
||||
}
|
||||
|
||||
// RevokeDBCryptKey mocks base method.
|
||||
func (m *MockStore) RevokeDBCryptKey(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RevokeDBCryptKey", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RevokeDBCryptKey indicates an expected call of RevokeDBCryptKey.
|
||||
func (mr *MockStoreMockRecorder) RevokeDBCryptKey(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeDBCryptKey", reflect.TypeOf((*MockStore)(nil).RevokeDBCryptKey), arg0, arg1)
|
||||
}
|
||||
|
||||
// TryAcquireLock mocks base method.
|
||||
func (m *MockStore) TryAcquireLock(arg0 context.Context, arg1 int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -275,6 +275,29 @@ CREATE TABLE audit_logs (
|
|||
resource_icon text NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE dbcrypt_keys (
|
||||
number integer NOT NULL,
|
||||
active_key_digest text,
|
||||
revoked_key_digest text,
|
||||
created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP,
|
||||
revoked_at timestamp with time zone,
|
||||
test text NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE dbcrypt_keys IS 'A table used to store the keys used to encrypt the database.';
|
||||
|
||||
COMMENT ON COLUMN dbcrypt_keys.number IS 'An integer used to identify the key.';
|
||||
|
||||
COMMENT ON COLUMN dbcrypt_keys.active_key_digest IS 'If the key is active, the digest of the active key.';
|
||||
|
||||
COMMENT ON COLUMN dbcrypt_keys.revoked_key_digest IS 'If the key has been revoked, the digest of the revoked key.';
|
||||
|
||||
COMMENT ON COLUMN dbcrypt_keys.created_at IS 'The time at which the key was created.';
|
||||
|
||||
COMMENT ON COLUMN dbcrypt_keys.revoked_at IS 'The time at which the key was revoked.';
|
||||
|
||||
COMMENT ON COLUMN dbcrypt_keys.test IS 'A column used to test the encryption.';
|
||||
|
||||
CREATE TABLE files (
|
||||
hash character varying(64) NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
|
@ -291,9 +314,15 @@ CREATE TABLE git_auth_links (
|
|||
updated_at timestamp with time zone NOT NULL,
|
||||
oauth_access_token text NOT NULL,
|
||||
oauth_refresh_token text NOT NULL,
|
||||
oauth_expiry timestamp with time zone NOT NULL
|
||||
oauth_expiry timestamp with time zone NOT NULL,
|
||||
oauth_access_token_key_id text,
|
||||
oauth_refresh_token_key_id text
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN git_auth_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted';
|
||||
|
||||
COMMENT ON COLUMN git_auth_links.oauth_refresh_token_key_id IS 'The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted';
|
||||
|
||||
CREATE TABLE gitsshkeys (
|
||||
user_id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
|
@ -701,9 +730,15 @@ CREATE TABLE user_links (
|
|||
linked_id text DEFAULT ''::text NOT NULL,
|
||||
oauth_access_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL
|
||||
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
|
||||
oauth_access_token_key_id text,
|
||||
oauth_refresh_token_key_id text
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN user_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted';
|
||||
|
||||
COMMENT ON COLUMN user_links.oauth_refresh_token_key_id IS 'The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted';
|
||||
|
||||
CREATE TABLE workspace_agent_logs (
|
||||
agent_id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
|
@ -1037,6 +1072,15 @@ ALTER TABLE ONLY api_keys
|
|||
ALTER TABLE ONLY audit_logs
|
||||
ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY dbcrypt_keys
|
||||
ADD CONSTRAINT dbcrypt_keys_active_key_digest_key UNIQUE (active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY dbcrypt_keys
|
||||
ADD CONSTRAINT dbcrypt_keys_pkey PRIMARY KEY (number);
|
||||
|
||||
ALTER TABLE ONLY dbcrypt_keys
|
||||
ADD CONSTRAINT dbcrypt_keys_revoked_key_digest_key UNIQUE (revoked_key_digest);
|
||||
|
||||
ALTER TABLE ONLY files
|
||||
ADD CONSTRAINT files_hash_created_by_key UNIQUE (hash, created_by);
|
||||
|
||||
|
@ -1249,6 +1293,12 @@ CREATE TRIGGER trigger_update_users AFTER INSERT OR UPDATE ON users FOR EACH ROW
|
|||
ALTER TABLE ONLY api_keys
|
||||
ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY git_auth_links
|
||||
ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY git_auth_links
|
||||
ADD CONSTRAINT git_auth_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY gitsshkeys
|
||||
ADD CONSTRAINT gitsshkeys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
|
||||
|
||||
|
@ -1303,6 +1353,12 @@ ALTER TABLE ONLY templates
|
|||
ALTER TABLE ONLY templates
|
||||
ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY user_links
|
||||
ADD CONSTRAINT user_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY user_links
|
||||
ADD CONSTRAINT user_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY user_links
|
||||
ADD CONSTRAINT user_links_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
BEGIN;
|
||||
|
||||
-- Before dropping this table, we need to check if there exist any
|
||||
-- foreign key references to it. We do this by checking the following:
|
||||
-- user_links.oauth_access_token_key_id
|
||||
-- user_links.oauth_refresh_token_key_id
|
||||
-- git_auth_links.oauth_access_token_key_id
|
||||
-- git_auth_links.oauth_refresh_token_key_id
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT *
|
||||
FROM user_links
|
||||
WHERE oauth_access_token_key_id IS NOT NULL
|
||||
OR oauth_refresh_token_key_id IS NOT NULL
|
||||
) THEN RAISE EXCEPTION 'Cannot drop dbcrypt_keys table as there are still foreign key references to it from user_links.';
|
||||
END IF;
|
||||
|
||||
IF EXISTS (
|
||||
SELECT *
|
||||
FROM git_auth_links
|
||||
WHERE oauth_access_token_key_id IS NOT NULL
|
||||
OR oauth_refresh_token_key_id IS NOT NULL
|
||||
) THEN RAISE EXCEPTION 'Cannot drop dbcrypt_keys table as there are still foreign key references to it from git_auth_links.';
|
||||
END IF;
|
||||
|
||||
END
|
||||
$$;
|
||||
|
||||
|
||||
-- Drop the columns first.
|
||||
ALTER TABLE git_auth_links
|
||||
DROP COLUMN IF EXISTS oauth_access_token_key_id,
|
||||
DROP COLUMN IF EXISTS oauth_refresh_token_key_id;
|
||||
|
||||
ALTER TABLE user_links
|
||||
DROP COLUMN IF EXISTS oauth_access_token_key_id,
|
||||
DROP COLUMN IF EXISTS oauth_refresh_token_key_id;
|
||||
|
||||
-- Finally, drop the table.
|
||||
DROP TABLE IF EXISTS dbcrypt_keys;
|
||||
|
||||
COMMIT;
|
|
@ -0,0 +1,30 @@
|
|||
CREATE TABLE IF NOT EXISTS dbcrypt_keys (
|
||||
number int NOT NULL PRIMARY KEY,
|
||||
active_key_digest text UNIQUE,
|
||||
revoked_key_digest text UNIQUE,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
||||
revoked_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||
test TEXT NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE dbcrypt_keys IS 'A table used to store the keys used to encrypt the database.';
|
||||
COMMENT ON COLUMN dbcrypt_keys.number IS 'An integer used to identify the key.';
|
||||
COMMENT ON COLUMN dbcrypt_keys.active_key_digest IS 'If the key is active, the digest of the active key.';
|
||||
COMMENT ON COLUMN dbcrypt_keys.revoked_key_digest IS 'If the key has been revoked, the digest of the revoked key.';
|
||||
COMMENT ON COLUMN dbcrypt_keys.created_at IS 'The time at which the key was created.';
|
||||
COMMENT ON COLUMN dbcrypt_keys.revoked_at IS 'The time at which the key was revoked.';
|
||||
COMMENT ON COLUMN dbcrypt_keys.test IS 'A column used to test the encryption.';
|
||||
|
||||
ALTER TABLE git_auth_links
|
||||
ADD COLUMN IF NOT EXISTS oauth_access_token_key_id text REFERENCES dbcrypt_keys(active_key_digest),
|
||||
ADD COLUMN IF NOT EXISTS oauth_refresh_token_key_id text REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
COMMENT ON COLUMN git_auth_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted';
|
||||
COMMENT ON COLUMN git_auth_links.oauth_refresh_token_key_id IS 'The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted';
|
||||
|
||||
ALTER TABLE user_links
|
||||
ADD COLUMN IF NOT EXISTS oauth_access_token_key_id text REFERENCES dbcrypt_keys(active_key_digest),
|
||||
ADD COLUMN IF NOT EXISTS oauth_refresh_token_key_id text REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
COMMENT ON COLUMN user_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted';
|
||||
COMMENT ON COLUMN user_links.oauth_refresh_token_key_id IS 'The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted';
|
|
@ -266,6 +266,7 @@ func TestMigrateUpWithFixtures(t *testing.T) {
|
|||
"template_version_parameters",
|
||||
"workspace_build_parameters",
|
||||
"template_version_variables",
|
||||
"dbcrypt_keys", // having zero rows is a valid state for this table
|
||||
}
|
||||
s := &tableStats{s: make(map[string]int)}
|
||||
|
||||
|
|
|
@ -1591,6 +1591,22 @@ type AuditLog struct {
|
|||
ResourceIcon string `db:"resource_icon" json:"resource_icon"`
|
||||
}
|
||||
|
||||
// A table used to store the keys used to encrypt the database.
|
||||
type DBCryptKey struct {
|
||||
// An integer used to identify the key.
|
||||
Number int32 `db:"number" json:"number"`
|
||||
// If the key is active, the digest of the active key.
|
||||
ActiveKeyDigest sql.NullString `db:"active_key_digest" json:"active_key_digest"`
|
||||
// If the key has been revoked, the digest of the revoked key.
|
||||
RevokedKeyDigest sql.NullString `db:"revoked_key_digest" json:"revoked_key_digest"`
|
||||
// The time at which the key was created.
|
||||
CreatedAt sql.NullTime `db:"created_at" json:"created_at"`
|
||||
// The time at which the key was revoked.
|
||||
RevokedAt sql.NullTime `db:"revoked_at" json:"revoked_at"`
|
||||
// A column used to test the encryption.
|
||||
Test string `db:"test" json:"test"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
Hash string `db:"hash" json:"hash"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
|
@ -1608,6 +1624,10 @@ type GitAuthLink struct {
|
|||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
// The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
// The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
}
|
||||
|
||||
type GitSSHKey struct {
|
||||
|
@ -1949,6 +1969,10 @@ type UserLink struct {
|
|||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
// The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
// The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
}
|
||||
|
||||
// Visible fields of users are allowed to be joined with other tables for including context of other resources.
|
||||
|
|
|
@ -58,6 +58,7 @@ type sqlcQuerier interface {
|
|||
// This function returns roles for authorization purposes. Implied member roles
|
||||
// are included.
|
||||
GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error)
|
||||
GetDBCryptKeys(ctx context.Context) ([]DBCryptKey, error)
|
||||
GetDERPMeshKey(ctx context.Context) (string, error)
|
||||
GetDefaultProxyConfig(ctx context.Context) (GetDefaultProxyConfigRow, error)
|
||||
GetDeploymentDAUs(ctx context.Context, tzOffset int32) ([]GetDeploymentDAUsRow, error)
|
||||
|
@ -69,6 +70,7 @@ type sqlcQuerier interface {
|
|||
// Get all templates that use a file.
|
||||
GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]GetFileTemplatesRow, error)
|
||||
GetGitAuthLink(ctx context.Context, arg GetGitAuthLinkParams) (GitAuthLink, error)
|
||||
GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]GitAuthLink, error)
|
||||
GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error)
|
||||
GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error)
|
||||
GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error)
|
||||
|
@ -150,6 +152,7 @@ type sqlcQuerier interface {
|
|||
GetUserLatencyInsights(ctx context.Context, arg GetUserLatencyInsightsParams) ([]GetUserLatencyInsightsRow, error)
|
||||
GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error)
|
||||
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
|
||||
GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error)
|
||||
// This will never return deleted users.
|
||||
GetUsers(ctx context.Context, arg GetUsersParams) ([]GetUsersRow, error)
|
||||
// This shouldn't check for deleted, because it's frequently used
|
||||
|
@ -206,6 +209,7 @@ type sqlcQuerier interface {
|
|||
// every member of the org.
|
||||
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
|
||||
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
|
||||
InsertDBCryptKey(ctx context.Context, arg InsertDBCryptKeyParams) error
|
||||
InsertDERPMeshKey(ctx context.Context, value string) error
|
||||
InsertDeploymentID(ctx context.Context, value string) error
|
||||
InsertFile(ctx context.Context, arg InsertFileParams) (File, error)
|
||||
|
@ -247,6 +251,7 @@ type sqlcQuerier interface {
|
|||
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
|
||||
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
|
||||
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
|
||||
RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error
|
||||
// Non blocking lock. Returns true if the lock was acquired, false otherwise.
|
||||
//
|
||||
// This must be called from within a transaction. The lock will be automatically
|
||||
|
|
|
@ -636,6 +636,74 @@ func (q *sqlQuerier) InsertAuditLog(ctx context.Context, arg InsertAuditLogParam
|
|||
return i, err
|
||||
}
|
||||
|
||||
const getDBCryptKeys = `-- name: GetDBCryptKeys :many
|
||||
SELECT number, active_key_digest, revoked_key_digest, created_at, revoked_at, test FROM dbcrypt_keys ORDER BY number ASC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetDBCryptKeys(ctx context.Context) ([]DBCryptKey, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getDBCryptKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []DBCryptKey
|
||||
for rows.Next() {
|
||||
var i DBCryptKey
|
||||
if err := rows.Scan(
|
||||
&i.Number,
|
||||
&i.ActiveKeyDigest,
|
||||
&i.RevokedKeyDigest,
|
||||
&i.CreatedAt,
|
||||
&i.RevokedAt,
|
||||
&i.Test,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertDBCryptKey = `-- name: InsertDBCryptKey :exec
|
||||
INSERT INTO dbcrypt_keys
|
||||
(number, active_key_digest, created_at, test)
|
||||
VALUES ($1::int, $2::text, CURRENT_TIMESTAMP, $3::text)
|
||||
`
|
||||
|
||||
type InsertDBCryptKeyParams struct {
|
||||
Number int32 `db:"number" json:"number"`
|
||||
ActiveKeyDigest string `db:"active_key_digest" json:"active_key_digest"`
|
||||
Test string `db:"test" json:"test"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertDBCryptKey(ctx context.Context, arg InsertDBCryptKeyParams) error {
|
||||
_, err := q.db.ExecContext(ctx, insertDBCryptKey, arg.Number, arg.ActiveKeyDigest, arg.Test)
|
||||
return err
|
||||
}
|
||||
|
||||
const revokeDBCryptKey = `-- name: RevokeDBCryptKey :exec
|
||||
UPDATE dbcrypt_keys
|
||||
SET
|
||||
revoked_key_digest = active_key_digest,
|
||||
active_key_digest = revoked_key_digest,
|
||||
revoked_at = CURRENT_TIMESTAMP
|
||||
WHERE
|
||||
active_key_digest = $1::text
|
||||
AND
|
||||
revoked_key_digest IS NULL
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
_, err := q.db.ExecContext(ctx, revokeDBCryptKey, activeKeyDigest)
|
||||
return err
|
||||
}
|
||||
|
||||
const getFileByHashAndCreator = `-- name: GetFileByHashAndCreator :one
|
||||
SELECT
|
||||
hash, created_at, created_by, mimetype, data, id
|
||||
|
@ -800,7 +868,7 @@ func (q *sqlQuerier) InsertFile(ctx context.Context, arg InsertFileParams) (File
|
|||
}
|
||||
|
||||
const getGitAuthLink = `-- name: GetGitAuthLink :one
|
||||
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry FROM git_auth_links WHERE provider_id = $1 AND user_id = $2
|
||||
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM git_auth_links WHERE provider_id = $1 AND user_id = $2
|
||||
`
|
||||
|
||||
type GetGitAuthLinkParams struct {
|
||||
|
@ -819,10 +887,49 @@ func (q *sqlQuerier) GetGitAuthLink(ctx context.Context, arg GetGitAuthLinkParam
|
|||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getGitAuthLinksByUserID = `-- name: GetGitAuthLinksByUserID :many
|
||||
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM git_auth_links WHERE user_id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]GitAuthLink, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getGitAuthLinksByUserID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GitAuthLink
|
||||
for rows.Next() {
|
||||
var i GitAuthLink
|
||||
if err := rows.Scan(
|
||||
&i.ProviderID,
|
||||
&i.UserID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertGitAuthLink = `-- name: InsertGitAuthLink :one
|
||||
INSERT INTO git_auth_links (
|
||||
provider_id,
|
||||
|
@ -830,7 +937,9 @@ INSERT INTO git_auth_links (
|
|||
created_at,
|
||||
updated_at,
|
||||
oauth_access_token,
|
||||
oauth_access_token_key_id,
|
||||
oauth_refresh_token,
|
||||
oauth_refresh_token_key_id,
|
||||
oauth_expiry
|
||||
) VALUES (
|
||||
$1,
|
||||
|
@ -839,18 +948,22 @@ INSERT INTO git_auth_links (
|
|||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7
|
||||
) RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry
|
||||
$7,
|
||||
$8,
|
||||
$9
|
||||
) RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
|
||||
`
|
||||
|
||||
type InsertGitAuthLinkParams struct {
|
||||
ProviderID string `db:"provider_id" json:"provider_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
ProviderID string `db:"provider_id" json:"provider_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertGitAuthLink(ctx context.Context, arg InsertGitAuthLinkParams) (GitAuthLink, error) {
|
||||
|
@ -860,7 +973,9 @@ func (q *sqlQuerier) InsertGitAuthLink(ctx context.Context, arg InsertGitAuthLin
|
|||
arg.CreatedAt,
|
||||
arg.UpdatedAt,
|
||||
arg.OAuthAccessToken,
|
||||
arg.OAuthAccessTokenKeyID,
|
||||
arg.OAuthRefreshToken,
|
||||
arg.OAuthRefreshTokenKeyID,
|
||||
arg.OAuthExpiry,
|
||||
)
|
||||
var i GitAuthLink
|
||||
|
@ -872,6 +987,8 @@ func (q *sqlQuerier) InsertGitAuthLink(ctx context.Context, arg InsertGitAuthLin
|
|||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
@ -880,18 +997,22 @@ const updateGitAuthLink = `-- name: UpdateGitAuthLink :one
|
|||
UPDATE git_auth_links SET
|
||||
updated_at = $3,
|
||||
oauth_access_token = $4,
|
||||
oauth_refresh_token = $5,
|
||||
oauth_expiry = $6
|
||||
WHERE provider_id = $1 AND user_id = $2 RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry
|
||||
oauth_access_token_key_id = $5,
|
||||
oauth_refresh_token = $6,
|
||||
oauth_refresh_token_key_id = $7,
|
||||
oauth_expiry = $8
|
||||
WHERE provider_id = $1 AND user_id = $2 RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
|
||||
`
|
||||
|
||||
type UpdateGitAuthLinkParams struct {
|
||||
ProviderID string `db:"provider_id" json:"provider_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
ProviderID string `db:"provider_id" json:"provider_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateGitAuthLink(ctx context.Context, arg UpdateGitAuthLinkParams) (GitAuthLink, error) {
|
||||
|
@ -900,7 +1021,9 @@ func (q *sqlQuerier) UpdateGitAuthLink(ctx context.Context, arg UpdateGitAuthLin
|
|||
arg.UserID,
|
||||
arg.UpdatedAt,
|
||||
arg.OAuthAccessToken,
|
||||
arg.OAuthAccessTokenKeyID,
|
||||
arg.OAuthRefreshToken,
|
||||
arg.OAuthRefreshTokenKeyID,
|
||||
arg.OAuthExpiry,
|
||||
)
|
||||
var i GitAuthLink
|
||||
|
@ -912,6 +1035,8 @@ func (q *sqlQuerier) UpdateGitAuthLink(ctx context.Context, arg UpdateGitAuthLin
|
|||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
@ -5450,7 +5575,7 @@ func (q *sqlQuerier) InsertTemplateVersionVariable(ctx context.Context, arg Inse
|
|||
|
||||
const getUserLinkByLinkedID = `-- name: GetUserLinkByLinkedID :one
|
||||
SELECT
|
||||
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
|
||||
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
|
||||
FROM
|
||||
user_links
|
||||
WHERE
|
||||
|
@ -5467,13 +5592,15 @@ func (q *sqlQuerier) GetUserLinkByLinkedID(ctx context.Context, linkedID string)
|
|||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserLinkByUserIDLoginType = `-- name: GetUserLinkByUserIDLoginType :one
|
||||
SELECT
|
||||
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
|
||||
user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
|
||||
FROM
|
||||
user_links
|
||||
WHERE
|
||||
|
@ -5495,10 +5622,48 @@ func (q *sqlQuerier) GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUs
|
|||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserLinksByUserID = `-- name: GetUserLinksByUserID :many
|
||||
SELECT user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM user_links WHERE user_id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getUserLinksByUserID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []UserLink
|
||||
for rows.Next() {
|
||||
var i UserLink
|
||||
if err := rows.Scan(
|
||||
&i.UserID,
|
||||
&i.LoginType,
|
||||
&i.LinkedID,
|
||||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertUserLink = `-- name: InsertUserLink :one
|
||||
INSERT INTO
|
||||
user_links (
|
||||
|
@ -5506,20 +5671,24 @@ INSERT INTO
|
|||
login_type,
|
||||
linked_id,
|
||||
oauth_access_token,
|
||||
oauth_access_token_key_id,
|
||||
oauth_refresh_token,
|
||||
oauth_refresh_token_key_id,
|
||||
oauth_expiry
|
||||
)
|
||||
VALUES
|
||||
( $1, $2, $3, $4, $5, $6 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
|
||||
( $1, $2, $3, $4, $5, $6, $7, $8 ) RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
|
||||
`
|
||||
|
||||
type InsertUserLinkParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
LinkedID string `db:"linked_id" json:"linked_id"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
LinkedID string `db:"linked_id" json:"linked_id"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) {
|
||||
|
@ -5528,7 +5697,9 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam
|
|||
arg.LoginType,
|
||||
arg.LinkedID,
|
||||
arg.OAuthAccessToken,
|
||||
arg.OAuthAccessTokenKeyID,
|
||||
arg.OAuthRefreshToken,
|
||||
arg.OAuthRefreshTokenKeyID,
|
||||
arg.OAuthExpiry,
|
||||
)
|
||||
var i UserLink
|
||||
|
@ -5539,6 +5710,8 @@ func (q *sqlQuerier) InsertUserLink(ctx context.Context, arg InsertUserLinkParam
|
|||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
@ -5548,24 +5721,30 @@ UPDATE
|
|||
user_links
|
||||
SET
|
||||
oauth_access_token = $1,
|
||||
oauth_refresh_token = $2,
|
||||
oauth_expiry = $3
|
||||
oauth_access_token_key_id = $2,
|
||||
oauth_refresh_token = $3,
|
||||
oauth_refresh_token_key_id = $4,
|
||||
oauth_expiry = $5
|
||||
WHERE
|
||||
user_id = $4 AND login_type = $5 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
|
||||
user_id = $6 AND login_type = $7 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
|
||||
`
|
||||
|
||||
type UpdateUserLinkParams struct {
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParams) (UserLink, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateUserLink,
|
||||
arg.OAuthAccessToken,
|
||||
arg.OAuthAccessTokenKeyID,
|
||||
arg.OAuthRefreshToken,
|
||||
arg.OAuthRefreshTokenKeyID,
|
||||
arg.OAuthExpiry,
|
||||
arg.UserID,
|
||||
arg.LoginType,
|
||||
|
@ -5578,6 +5757,8 @@ func (q *sqlQuerier) UpdateUserLink(ctx context.Context, arg UpdateUserLinkParam
|
|||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
@ -5588,7 +5769,7 @@ UPDATE
|
|||
SET
|
||||
linked_id = $1
|
||||
WHERE
|
||||
user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry
|
||||
user_id = $2 AND login_type = $3 RETURNING user_id, login_type, linked_id, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
|
||||
`
|
||||
|
||||
type UpdateUserLinkedIDParams struct {
|
||||
|
@ -5607,6 +5788,8 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke
|
|||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
-- name: GetDBCryptKeys :many
|
||||
SELECT * FROM dbcrypt_keys ORDER BY number ASC;
|
||||
|
||||
-- name: RevokeDBCryptKey :exec
|
||||
UPDATE dbcrypt_keys
|
||||
SET
|
||||
revoked_key_digest = active_key_digest,
|
||||
active_key_digest = revoked_key_digest,
|
||||
revoked_at = CURRENT_TIMESTAMP
|
||||
WHERE
|
||||
active_key_digest = @active_key_digest::text
|
||||
AND
|
||||
revoked_key_digest IS NULL;
|
||||
|
||||
-- name: InsertDBCryptKey :exec
|
||||
INSERT INTO dbcrypt_keys
|
||||
(number, active_key_digest, created_at, test)
|
||||
VALUES (@number::int, @active_key_digest::text, CURRENT_TIMESTAMP, @test::text);
|
|
@ -1,6 +1,9 @@
|
|||
-- name: GetGitAuthLink :one
|
||||
SELECT * FROM git_auth_links WHERE provider_id = $1 AND user_id = $2;
|
||||
|
||||
-- name: GetGitAuthLinksByUserID :many
|
||||
SELECT * FROM git_auth_links WHERE user_id = $1;
|
||||
|
||||
-- name: InsertGitAuthLink :one
|
||||
INSERT INTO git_auth_links (
|
||||
provider_id,
|
||||
|
@ -8,7 +11,9 @@ INSERT INTO git_auth_links (
|
|||
created_at,
|
||||
updated_at,
|
||||
oauth_access_token,
|
||||
oauth_access_token_key_id,
|
||||
oauth_refresh_token,
|
||||
oauth_refresh_token_key_id,
|
||||
oauth_expiry
|
||||
) VALUES (
|
||||
$1,
|
||||
|
@ -17,13 +22,17 @@ INSERT INTO git_auth_links (
|
|||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7
|
||||
$7,
|
||||
$8,
|
||||
$9
|
||||
) RETURNING *;
|
||||
|
||||
-- name: UpdateGitAuthLink :one
|
||||
UPDATE git_auth_links SET
|
||||
updated_at = $3,
|
||||
oauth_access_token = $4,
|
||||
oauth_refresh_token = $5,
|
||||
oauth_expiry = $6
|
||||
oauth_access_token_key_id = $5,
|
||||
oauth_refresh_token = $6,
|
||||
oauth_refresh_token_key_id = $7,
|
||||
oauth_expiry = $8
|
||||
WHERE provider_id = $1 AND user_id = $2 RETURNING *;
|
||||
|
|
|
@ -14,6 +14,9 @@ FROM
|
|||
WHERE
|
||||
user_id = $1 AND login_type = $2;
|
||||
|
||||
-- name: GetUserLinksByUserID :many
|
||||
SELECT * FROM user_links WHERE user_id = $1;
|
||||
|
||||
-- name: InsertUserLink :one
|
||||
INSERT INTO
|
||||
user_links (
|
||||
|
@ -21,11 +24,13 @@ INSERT INTO
|
|||
login_type,
|
||||
linked_id,
|
||||
oauth_access_token,
|
||||
oauth_access_token_key_id,
|
||||
oauth_refresh_token,
|
||||
oauth_refresh_token_key_id,
|
||||
oauth_expiry
|
||||
)
|
||||
VALUES
|
||||
( $1, $2, $3, $4, $5, $6 ) RETURNING *;
|
||||
( $1, $2, $3, $4, $5, $6, $7, $8 ) RETURNING *;
|
||||
|
||||
-- name: UpdateUserLinkedID :one
|
||||
UPDATE
|
||||
|
@ -40,7 +45,9 @@ UPDATE
|
|||
user_links
|
||||
SET
|
||||
oauth_access_token = $1,
|
||||
oauth_refresh_token = $2,
|
||||
oauth_expiry = $3
|
||||
oauth_access_token_key_id = $2,
|
||||
oauth_refresh_token = $3,
|
||||
oauth_refresh_token_key_id = $4,
|
||||
oauth_expiry = $5
|
||||
WHERE
|
||||
user_id = $4 AND login_type = $5 RETURNING *;
|
||||
user_id = $6 AND login_type = $7 RETURNING *;
|
||||
|
|
|
@ -40,6 +40,7 @@ overrides:
|
|||
api_key_scope_application_connect: APIKeyScopeApplicationConnect
|
||||
avatar_url: AvatarURL
|
||||
created_by_avatar_url: CreatedByAvatarURL
|
||||
dbcrypt_key: DBCryptKey
|
||||
session_count_vscode: SessionCountVSCode
|
||||
session_count_jetbrains: SessionCountJetBrains
|
||||
session_count_reconnecting_pty: SessionCountReconnectingPTY
|
||||
|
@ -47,9 +48,11 @@ overrides:
|
|||
connection_median_latency_ms: ConnectionMedianLatencyMS
|
||||
login_type_oidc: LoginTypeOIDC
|
||||
oauth_access_token: OAuthAccessToken
|
||||
oauth_access_token_key_id: OAuthAccessTokenKeyID
|
||||
oauth_expiry: OAuthExpiry
|
||||
oauth_id_token: OAuthIDToken
|
||||
oauth_refresh_token: OAuthRefreshToken
|
||||
oauth_refresh_token_key_id: OAuthRefreshTokenKeyID
|
||||
parameter_type_system_hcl: ParameterTypeSystemHCL
|
||||
userstatus: UserStatus
|
||||
gitsshkey: GitSSHKey
|
||||
|
|
|
@ -6,6 +6,8 @@ type UniqueConstraint string
|
|||
|
||||
// UniqueConstraint enums.
|
||||
const (
|
||||
UniqueDbcryptKeysActiveKeyDigestKey UniqueConstraint = "dbcrypt_keys_active_key_digest_key" // ALTER TABLE ONLY dbcrypt_keys ADD CONSTRAINT dbcrypt_keys_active_key_digest_key UNIQUE (active_key_digest);
|
||||
UniqueDbcryptKeysRevokedKeyDigestKey UniqueConstraint = "dbcrypt_keys_revoked_key_digest_key" // ALTER TABLE ONLY dbcrypt_keys ADD CONSTRAINT dbcrypt_keys_revoked_key_digest_key UNIQUE (revoked_key_digest);
|
||||
UniqueFilesHashCreatedByKey UniqueConstraint = "files_hash_created_by_key" // ALTER TABLE ONLY files ADD CONSTRAINT files_hash_created_by_key UNIQUE (hash, created_by);
|
||||
UniqueGitAuthLinksProviderIDUserIDKey UniqueConstraint = "git_auth_links_provider_id_user_id_key" // ALTER TABLE ONLY git_auth_links ADD CONSTRAINT git_auth_links_provider_id_user_id_key UNIQUE (provider_id, user_id);
|
||||
UniqueGroupMembersUserIDGroupIDKey UniqueConstraint = "group_members_user_id_group_id_key" // ALTER TABLE ONLY group_members ADD CONSTRAINT group_members_user_id_group_id_key UNIQUE (user_id, group_id);
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
package dbcrypt
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// cipherAES256GCM is the name of the AES-256 cipher.
|
||||
// This is used to identify the cipher used to encrypt a value.
|
||||
// It is added to the digest to ensure that if, in the future,
|
||||
// we add a new cipher type, and a key is re-used, we don't
|
||||
// accidentally decrypt the wrong values.
|
||||
// When adding a new cipher type, add a new constant here
|
||||
// and ensure to add the cipher name to the digest of the new
|
||||
// cipher type.
|
||||
const (
|
||||
cipherAES256GCM = "aes256gcm"
|
||||
)
|
||||
|
||||
type Cipher interface {
|
||||
Encrypt([]byte) ([]byte, error)
|
||||
Decrypt([]byte) ([]byte, error)
|
||||
HexDigest() string
|
||||
}
|
||||
|
||||
// NewCiphers is a convenience function for creating multiple ciphers.
|
||||
// It currently only supports AES-256-GCM.
|
||||
func NewCiphers(keys ...[]byte) ([]Cipher, error) {
|
||||
var cs []Cipher
|
||||
for _, key := range keys {
|
||||
c, err := cipherAES256(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cs = append(cs, c)
|
||||
}
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
// cipherAES256 returns a new AES-256 cipher.
|
||||
func cipherAES256(key []byte) (*aes256, error) {
|
||||
if len(key) != 32 {
|
||||
return nil, xerrors.Errorf("key must be 32 bytes")
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// We add the cipher name to the digest to ensure that if, in the future,
|
||||
// we add a new cipher type, and a key is re-used, we don't accidentally
|
||||
// decrypt the wrong values.
|
||||
toDigest := []byte(cipherAES256GCM)
|
||||
toDigest = append(toDigest, key...)
|
||||
digest := fmt.Sprintf("%x", sha256.Sum256(toDigest))[:7]
|
||||
return &aes256{aead: aead, digest: digest}, nil
|
||||
}
|
||||
|
||||
type aes256 struct {
|
||||
aead cipher.AEAD
|
||||
// digest is the first 7 bytes of the hex-encoded SHA-256 digest of aead.
|
||||
digest string
|
||||
}
|
||||
|
||||
func (a *aes256) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
nonce := make([]byte, a.aead.NonceSize())
|
||||
_, err := io.ReadFull(rand.Reader, nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dst := make([]byte, len(nonce))
|
||||
copy(dst, nonce)
|
||||
return a.aead.Seal(dst, nonce, plaintext, nil), nil
|
||||
}
|
||||
|
||||
func (a *aes256) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||
if len(ciphertext) < a.aead.NonceSize() {
|
||||
return nil, xerrors.Errorf("ciphertext too short")
|
||||
}
|
||||
decrypted, err := a.aead.Open(nil, ciphertext[:a.aead.NonceSize()], ciphertext[a.aead.NonceSize():], nil)
|
||||
if err != nil {
|
||||
return nil, &DecryptFailedError{Inner: err}
|
||||
}
|
||||
return decrypted, nil
|
||||
}
|
||||
|
||||
func (a *aes256) HexDigest() string {
|
||||
return a.digest
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
package dbcrypt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCipherAES256(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ValidInput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := bytes.Repeat([]byte{'a'}, 32)
|
||||
cipher, err := cipherAES256(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
output, err := cipher.Encrypt([]byte("hello world"))
|
||||
require.NoError(t, err)
|
||||
|
||||
response, err := cipher.Decrypt(output)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hello world", string(response))
|
||||
})
|
||||
|
||||
t.Run("InvalidInput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := bytes.Repeat([]byte{'a'}, 32)
|
||||
cipher, err := cipherAES256(key)
|
||||
require.NoError(t, err)
|
||||
_, err = cipher.Decrypt(bytes.Repeat([]byte{'a'}, 100))
|
||||
var decryptErr *DecryptFailedError
|
||||
require.ErrorAs(t, err, &decryptErr)
|
||||
})
|
||||
|
||||
t.Run("InvalidKeySize", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := cipherAES256(bytes.Repeat([]byte{'a'}, 31))
|
||||
require.ErrorContains(t, err, "key must be 32 bytes")
|
||||
})
|
||||
|
||||
t.Run("TestNonce", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := bytes.Repeat([]byte{'a'}, 32)
|
||||
cipher, err := cipherAES256(key)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "864f702", cipher.HexDigest())
|
||||
|
||||
encrypted1, err := cipher.Encrypt([]byte("hello world"))
|
||||
require.NoError(t, err)
|
||||
encrypted2, err := cipher.Encrypt([]byte("hello world"))
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, encrypted1, encrypted2, "nonce should be different for each encryption")
|
||||
|
||||
munged := make([]byte, len(encrypted1))
|
||||
copy(munged, encrypted1)
|
||||
munged[0] = munged[0] ^ 0xff
|
||||
_, err = cipher.Decrypt(munged)
|
||||
var decryptErr *DecryptFailedError
|
||||
require.ErrorAs(t, err, &decryptErr, "munging the first byte of the encrypted data should cause decryption to fail")
|
||||
})
|
||||
}
|
||||
|
||||
// This test ensures backwards compatibility. If it breaks, something is very wrong.
|
||||
func TestCiphersBackwardCompatibility(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
msg = "hello world"
|
||||
key = bytes.Repeat([]byte{'a'}, 32)
|
||||
//nolint: gosec // The below is the base64-encoded result of encrypting the above message with the above key.
|
||||
encoded = `YhAz+lE2fFeeiVPH9voKN7UV1xSDrgcnC0LmNXmaAk1Yg0kPFO3x`
|
||||
)
|
||||
|
||||
cipher, err := cipherAES256(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// This is the code that was used to generate the above.
|
||||
// Note that the output of this code will change every time it is run.
|
||||
// encrypted, err := cipher.Encrypt([]byte(msg))
|
||||
// require.NoError(t, err)
|
||||
// t.Logf("encoded: %q", base64.StdEncoding.EncodeToString(encrypted))
|
||||
|
||||
decoded, err := base64.StdEncoding.DecodeString(encoded)
|
||||
require.NoError(t, err, "the encoded string should be valid base64")
|
||||
decrypted, err := cipher.Decrypt(decoded)
|
||||
require.NoError(t, err, "decryption should succeed")
|
||||
require.Equal(t, msg, string(decrypted), "decrypted message should match original message")
|
||||
}
|
|
@ -0,0 +1,374 @@
|
|||
package dbcrypt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// testValue is the value that is stored in dbcrypt_keys.test.
|
||||
// This is used to determine if the key is valid.
|
||||
const testValue = "coder"
|
||||
|
||||
var (
|
||||
b64encode = base64.StdEncoding.EncodeToString
|
||||
b64decode = base64.StdEncoding.DecodeString
|
||||
)
|
||||
|
||||
// DecryptFailedError is returned when decryption fails.
|
||||
type DecryptFailedError struct {
|
||||
Inner error
|
||||
}
|
||||
|
||||
func (e *DecryptFailedError) Error() string {
|
||||
return xerrors.Errorf("decrypt failed: %w", e.Inner).Error()
|
||||
}
|
||||
|
||||
// New creates a database.Store wrapper that encrypts/decrypts values
|
||||
// stored at rest in the database.
|
||||
func New(ctx context.Context, db database.Store, ciphers ...Cipher) (database.Store, error) {
|
||||
cm := make(map[string]Cipher)
|
||||
for _, c := range ciphers {
|
||||
cm[c.HexDigest()] = c
|
||||
}
|
||||
dbc := &dbCrypt{
|
||||
ciphers: cm,
|
||||
Store: db,
|
||||
}
|
||||
if len(ciphers) > 0 {
|
||||
dbc.primaryCipherDigest = ciphers[0].HexDigest()
|
||||
}
|
||||
// nolint: gocritic // This is allowed.
|
||||
authCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
if err := dbc.ensureEncryptedWithRetry(authCtx); err != nil {
|
||||
return nil, xerrors.Errorf("ensure encrypted database fields: %w", err)
|
||||
}
|
||||
return dbc, nil
|
||||
}
|
||||
|
||||
type dbCrypt struct {
|
||||
// primaryCipherDigest is the digest of the primary cipher used for encrypting data.
|
||||
primaryCipherDigest string
|
||||
// ciphers is a map of cipher digests to ciphers.
|
||||
ciphers map[string]Cipher
|
||||
database.Store
|
||||
}
|
||||
|
||||
func (db *dbCrypt) InTx(function func(database.Store) error, txOpts *sql.TxOptions) error {
|
||||
return db.Store.InTx(func(s database.Store) error {
|
||||
return function(&dbCrypt{
|
||||
primaryCipherDigest: db.primaryCipherDigest,
|
||||
ciphers: db.ciphers,
|
||||
Store: s,
|
||||
})
|
||||
}, txOpts)
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetDBCryptKeys(ctx context.Context) ([]database.DBCryptKey, error) {
|
||||
ks, err := db.Store.GetDBCryptKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Decrypt the test field to ensure that the key is valid.
|
||||
for i := range ks {
|
||||
if !ks[i].ActiveKeyDigest.Valid {
|
||||
// Key has been revoked. We can't decrypt the test field, but
|
||||
// we need to return it so that the caller knows that the key
|
||||
// has been revoked.
|
||||
continue
|
||||
}
|
||||
if err := db.decryptField(&ks[i].Test, ks[i].ActiveKeyDigest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return ks, nil
|
||||
}
|
||||
|
||||
// This does not need any special handling as it does not touch any encrypted fields.
|
||||
// Explicitly defining this here to avoid confusion.
|
||||
func (db *dbCrypt) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
return db.Store.RevokeDBCryptKey(ctx, activeKeyDigest)
|
||||
}
|
||||
|
||||
func (db *dbCrypt) InsertDBCryptKey(ctx context.Context, arg database.InsertDBCryptKeyParams) error {
|
||||
// It's nicer to be able to pass a *sql.NullString to encryptField, but we need to pass a string here.
|
||||
var digest sql.NullString
|
||||
err := db.encryptField(&arg.Test, &digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
arg.ActiveKeyDigest = digest.String
|
||||
return db.Store.InsertDBCryptKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (database.UserLink, error) {
|
||||
link, err := db.Store.GetUserLinkByLinkedID(ctx, linkedID)
|
||||
if err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserLink, error) {
|
||||
links, err := db.Store.GetUserLinksByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for idx := range links {
|
||||
if err := db.decryptField(&links[idx].OAuthAccessToken, links[idx].OAuthAccessTokenKeyID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.decryptField(&links[idx].OAuthRefreshToken, links[idx].OAuthRefreshTokenKeyID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return links, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetUserLinkByUserIDLoginType(ctx context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) {
|
||||
link, err := db.Store.GetUserLinkByUserIDLoginType(ctx, params)
|
||||
if err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) InsertUserLink(ctx context.Context, params database.InsertUserLinkParams) (database.UserLink, error) {
|
||||
if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
link, err := db.Store.InsertUserLink(ctx, params)
|
||||
if err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpdateUserLink(ctx context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) {
|
||||
if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
link, err := db.Store.UpdateUserLink(ctx, params)
|
||||
if err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
||||
return database.UserLink{}, err
|
||||
}
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) InsertGitAuthLink(ctx context.Context, params database.InsertGitAuthLinkParams) (database.GitAuthLink, error) {
|
||||
if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil {
|
||||
return database.GitAuthLink{}, err
|
||||
}
|
||||
if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil {
|
||||
return database.GitAuthLink{}, err
|
||||
}
|
||||
link, err := db.Store.InsertGitAuthLink(ctx, params)
|
||||
if err != nil {
|
||||
return database.GitAuthLink{}, err
|
||||
}
|
||||
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
||||
return database.GitAuthLink{}, err
|
||||
}
|
||||
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
||||
return database.GitAuthLink{}, err
|
||||
}
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetGitAuthLink(ctx context.Context, params database.GetGitAuthLinkParams) (database.GitAuthLink, error) {
|
||||
link, err := db.Store.GetGitAuthLink(ctx, params)
|
||||
if err != nil {
|
||||
return database.GitAuthLink{}, err
|
||||
}
|
||||
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
||||
return database.GitAuthLink{}, err
|
||||
}
|
||||
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
||||
return database.GitAuthLink{}, err
|
||||
}
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) GetGitAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]database.GitAuthLink, error) {
|
||||
links, err := db.Store.GetGitAuthLinksByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for idx := range links {
|
||||
if err := db.decryptField(&links[idx].OAuthAccessToken, links[idx].OAuthAccessTokenKeyID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.decryptField(&links[idx].OAuthRefreshToken, links[idx].OAuthRefreshTokenKeyID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return links, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) UpdateGitAuthLink(ctx context.Context, params database.UpdateGitAuthLinkParams) (database.GitAuthLink, error) {
|
||||
if err := db.encryptField(¶ms.OAuthAccessToken, ¶ms.OAuthAccessTokenKeyID); err != nil {
|
||||
return database.GitAuthLink{}, err
|
||||
}
|
||||
if err := db.encryptField(¶ms.OAuthRefreshToken, ¶ms.OAuthRefreshTokenKeyID); err != nil {
|
||||
return database.GitAuthLink{}, err
|
||||
}
|
||||
link, err := db.Store.UpdateGitAuthLink(ctx, params)
|
||||
if err != nil {
|
||||
return database.GitAuthLink{}, err
|
||||
}
|
||||
if err := db.decryptField(&link.OAuthAccessToken, link.OAuthAccessTokenKeyID); err != nil {
|
||||
return database.GitAuthLink{}, err
|
||||
}
|
||||
if err := db.decryptField(&link.OAuthRefreshToken, link.OAuthRefreshTokenKeyID); err != nil {
|
||||
return database.GitAuthLink{}, err
|
||||
}
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) encryptField(field *string, digest *sql.NullString) error {
|
||||
// If no cipher is loaded, then we can't encrypt anything!
|
||||
if db.ciphers == nil || db.primaryCipherDigest == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if field == nil {
|
||||
return xerrors.Errorf("developer error: encryptField called with nil field")
|
||||
}
|
||||
if digest == nil {
|
||||
return xerrors.Errorf("developer error: encryptField called with nil digest")
|
||||
}
|
||||
|
||||
encrypted, err := db.ciphers[db.primaryCipherDigest].Encrypt([]byte(*field))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Base64 is used to support UTF-8 encoding in PostgreSQL.
|
||||
*field = b64encode(encrypted)
|
||||
*digest = sql.NullString{String: db.primaryCipherDigest, Valid: true}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decryptFields decrypts the given field using the key with the given digest.
|
||||
// If the value fails to decrypt, sql.ErrNoRows will be returned.
|
||||
func (db *dbCrypt) decryptField(field *string, digest sql.NullString) error {
|
||||
if field == nil {
|
||||
return xerrors.Errorf("developer error: decryptField called with nil field")
|
||||
}
|
||||
|
||||
if !digest.Valid || digest.String == "" {
|
||||
// This field is not encrypted.
|
||||
return nil
|
||||
}
|
||||
|
||||
key, ok := db.ciphers[digest.String]
|
||||
if !ok {
|
||||
return &DecryptFailedError{
|
||||
Inner: xerrors.Errorf("no cipher with digest %q", digest.String),
|
||||
}
|
||||
}
|
||||
|
||||
data, err := b64decode(*field)
|
||||
if err != nil {
|
||||
// If it's not valid base64, we should complain loudly.
|
||||
return &DecryptFailedError{
|
||||
Inner: xerrors.Errorf("malformed encrypted field %q: %w", *field, err),
|
||||
}
|
||||
}
|
||||
decrypted, err := key.Decrypt(data)
|
||||
if err != nil {
|
||||
return &DecryptFailedError{Inner: err}
|
||||
}
|
||||
*field = string(decrypted)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *dbCrypt) ensureEncryptedWithRetry(ctx context.Context) error {
|
||||
var err error
|
||||
for i := 0; i < 3; i++ {
|
||||
err = db.ensureEncrypted(ctx)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
// If we get a serialization error, then we need to retry.
|
||||
if !database.IsSerializedError(err) {
|
||||
return err
|
||||
}
|
||||
// otherwise, retry
|
||||
}
|
||||
// If we get here, then we ran out of retries
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *dbCrypt) ensureEncrypted(ctx context.Context) error {
|
||||
return db.InTx(func(s database.Store) error {
|
||||
// Attempt to read the encrypted test fields of the currently active keys.
|
||||
ks, err := s.GetDBCryptKeys(ctx)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
|
||||
var highestNumber int32
|
||||
var activeCipherFound bool
|
||||
for _, k := range ks {
|
||||
// If our primary key has been revoked, then we can't do anything.
|
||||
if k.RevokedKeyDigest.Valid && k.RevokedKeyDigest.String == db.primaryCipherDigest {
|
||||
return xerrors.Errorf("primary encryption key %q has been revoked", db.primaryCipherDigest)
|
||||
}
|
||||
|
||||
if k.ActiveKeyDigest.Valid && k.ActiveKeyDigest.String == db.primaryCipherDigest {
|
||||
activeCipherFound = true
|
||||
}
|
||||
|
||||
if k.Number > highestNumber {
|
||||
highestNumber = k.Number
|
||||
}
|
||||
}
|
||||
|
||||
if activeCipherFound || len(db.ciphers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we get here, then we have a new key that we need to insert.
|
||||
return db.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{
|
||||
Number: highestNumber + 1,
|
||||
ActiveKeyDigest: db.primaryCipherDigest,
|
||||
Test: testValue,
|
||||
})
|
||||
}, &sql.TxOptions{Isolation: sql.LevelRepeatableRead})
|
||||
}
|
|
@ -0,0 +1,679 @@
|
|||
package dbcrypt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
)
|
||||
|
||||
func TestUserLinks(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("InsertUserLink", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
user := dbgen.User(t, crypt, database.User{})
|
||||
link := dbgen.UserLink(t, crypt, database.UserLink{
|
||||
UserID: user.ID,
|
||||
OAuthAccessToken: "access",
|
||||
OAuthRefreshToken: "refresh",
|
||||
})
|
||||
require.Equal(t, "access", link.OAuthAccessToken)
|
||||
require.Equal(t, "refresh", link.OAuthRefreshToken)
|
||||
require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String)
|
||||
require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String)
|
||||
|
||||
rawLink, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID)
|
||||
require.NoError(t, err)
|
||||
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access")
|
||||
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh")
|
||||
})
|
||||
|
||||
t.Run("UpdateUserLink", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
user := dbgen.User(t, crypt, database.User{})
|
||||
link := dbgen.UserLink(t, crypt, database.UserLink{
|
||||
UserID: user.ID,
|
||||
})
|
||||
|
||||
updated, err := crypt.UpdateUserLink(ctx, database.UpdateUserLinkParams{
|
||||
OAuthAccessToken: "access",
|
||||
OAuthRefreshToken: "refresh",
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "access", updated.OAuthAccessToken)
|
||||
require.Equal(t, "refresh", updated.OAuthRefreshToken)
|
||||
require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String)
|
||||
require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String)
|
||||
|
||||
rawLink, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID)
|
||||
require.NoError(t, err)
|
||||
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access")
|
||||
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh")
|
||||
})
|
||||
|
||||
t.Run("GetUserLinkByLinkedID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
user := dbgen.User(t, crypt, database.User{})
|
||||
link := dbgen.UserLink(t, crypt, database.UserLink{
|
||||
UserID: user.ID,
|
||||
OAuthAccessToken: "access",
|
||||
OAuthRefreshToken: "refresh",
|
||||
})
|
||||
|
||||
link, err := crypt.GetUserLinkByLinkedID(ctx, link.LinkedID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "access", link.OAuthAccessToken)
|
||||
require.Equal(t, "refresh", link.OAuthRefreshToken)
|
||||
require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String)
|
||||
require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String)
|
||||
|
||||
rawLink, err := db.GetUserLinkByLinkedID(ctx, link.LinkedID)
|
||||
require.NoError(t, err)
|
||||
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access")
|
||||
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh")
|
||||
})
|
||||
|
||||
t.Run("DecryptErr", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
link := dbgen.UserLink(t, db, database.UserLink{
|
||||
UserID: user.ID,
|
||||
OAuthAccessToken: fakeBase64RandomData(t, 32),
|
||||
OAuthRefreshToken: fakeBase64RandomData(t, 32),
|
||||
OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
|
||||
OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
|
||||
})
|
||||
|
||||
_, err := crypt.GetUserLinkByLinkedID(ctx, link.LinkedID)
|
||||
require.Error(t, err, "expected an error")
|
||||
var derr *DecryptFailedError
|
||||
require.ErrorAs(t, err, &derr, "expected a decrypt error")
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("GetUserLinksByUserID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
user := dbgen.User(t, crypt, database.User{})
|
||||
link := dbgen.UserLink(t, crypt, database.UserLink{
|
||||
UserID: user.ID,
|
||||
OAuthAccessToken: "access",
|
||||
OAuthRefreshToken: "refresh",
|
||||
})
|
||||
links, err := crypt.GetUserLinksByUserID(ctx, link.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, links, 1)
|
||||
require.Equal(t, "access", links[0].OAuthAccessToken)
|
||||
require.Equal(t, "refresh", links[0].OAuthRefreshToken)
|
||||
require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthAccessTokenKeyID.String)
|
||||
require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthRefreshTokenKeyID.String)
|
||||
|
||||
rawLinks, err := db.GetUserLinksByUserID(ctx, link.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rawLinks, 1)
|
||||
requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthAccessToken, "access")
|
||||
requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthRefreshToken, "refresh")
|
||||
})
|
||||
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, crypt, _ := setup(t)
|
||||
user := dbgen.User(t, crypt, database.User{})
|
||||
links, err := crypt.GetUserLinksByUserID(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, links)
|
||||
})
|
||||
|
||||
t.Run("DecryptErr", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_ = dbgen.UserLink(t, db, database.UserLink{
|
||||
UserID: user.ID,
|
||||
OAuthAccessToken: fakeBase64RandomData(t, 32),
|
||||
OAuthRefreshToken: fakeBase64RandomData(t, 32),
|
||||
OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
|
||||
OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
|
||||
})
|
||||
_, err := crypt.GetUserLinksByUserID(ctx, user.ID)
|
||||
require.Error(t, err, "expected an error")
|
||||
var derr *DecryptFailedError
|
||||
require.ErrorAs(t, err, &derr, "expected a decrypt error")
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("GetUserLinkByUserIDLoginType", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
user := dbgen.User(t, crypt, database.User{})
|
||||
link := dbgen.UserLink(t, crypt, database.UserLink{
|
||||
UserID: user.ID,
|
||||
OAuthAccessToken: "access",
|
||||
OAuthRefreshToken: "refresh",
|
||||
})
|
||||
|
||||
link, err := crypt.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "access", link.OAuthAccessToken)
|
||||
require.Equal(t, "refresh", link.OAuthRefreshToken)
|
||||
require.Equal(t, ciphers[0].HexDigest(), link.OAuthAccessTokenKeyID.String)
|
||||
require.Equal(t, ciphers[0].HexDigest(), link.OAuthRefreshTokenKeyID.String)
|
||||
|
||||
rawLink, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthAccessToken, "access")
|
||||
requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "refresh")
|
||||
})
|
||||
|
||||
t.Run("DecryptErr", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
link := dbgen.UserLink(t, db, database.UserLink{
|
||||
UserID: user.ID,
|
||||
OAuthAccessToken: fakeBase64RandomData(t, 32),
|
||||
OAuthRefreshToken: fakeBase64RandomData(t, 32),
|
||||
OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
|
||||
OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
|
||||
})
|
||||
|
||||
_, err := crypt.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
})
|
||||
require.Error(t, err, "expected an error")
|
||||
var derr *DecryptFailedError
|
||||
require.ErrorAs(t, err, &derr, "expected a decrypt error")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGitAuthLinks(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("InsertGitAuthLink", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{
|
||||
OAuthAccessToken: "access",
|
||||
OAuthRefreshToken: "refresh",
|
||||
})
|
||||
require.Equal(t, "access", link.OAuthAccessToken)
|
||||
require.Equal(t, "refresh", link.OAuthRefreshToken)
|
||||
|
||||
link, err := db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
|
||||
ProviderID: link.ProviderID,
|
||||
UserID: link.UserID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requireEncryptedEquals(t, ciphers[0], link.OAuthAccessToken, "access")
|
||||
requireEncryptedEquals(t, ciphers[0], link.OAuthRefreshToken, "refresh")
|
||||
})
|
||||
|
||||
t.Run("UpdateGitAuthLink", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{})
|
||||
updated, err := crypt.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
|
||||
ProviderID: link.ProviderID,
|
||||
UserID: link.UserID,
|
||||
OAuthAccessToken: "access",
|
||||
OAuthRefreshToken: "refresh",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "access", updated.OAuthAccessToken)
|
||||
require.Equal(t, "refresh", updated.OAuthRefreshToken)
|
||||
|
||||
link, err = db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
|
||||
ProviderID: link.ProviderID,
|
||||
UserID: link.UserID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requireEncryptedEquals(t, ciphers[0], link.OAuthAccessToken, "access")
|
||||
requireEncryptedEquals(t, ciphers[0], link.OAuthRefreshToken, "refresh")
|
||||
})
|
||||
|
||||
t.Run("GetGitAuthLink", func(t *testing.T) {
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{
|
||||
OAuthAccessToken: "access",
|
||||
OAuthRefreshToken: "refresh",
|
||||
})
|
||||
link, err := db.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
|
||||
UserID: link.UserID,
|
||||
ProviderID: link.ProviderID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requireEncryptedEquals(t, ciphers[0], link.OAuthAccessToken, "access")
|
||||
requireEncryptedEquals(t, ciphers[0], link.OAuthRefreshToken, "refresh")
|
||||
})
|
||||
t.Run("DecryptErr", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
link := dbgen.GitAuthLink(t, db, database.GitAuthLink{
|
||||
OAuthAccessToken: fakeBase64RandomData(t, 32),
|
||||
OAuthRefreshToken: fakeBase64RandomData(t, 32),
|
||||
OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
|
||||
OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
|
||||
})
|
||||
|
||||
_, err := crypt.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
|
||||
UserID: link.UserID,
|
||||
ProviderID: link.ProviderID,
|
||||
})
|
||||
require.Error(t, err, "expected an error")
|
||||
var derr *DecryptFailedError
|
||||
require.ErrorAs(t, err, &derr, "expected a decrypt error")
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("GetGitAuthLinksByUserID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, crypt, ciphers := setup(t)
|
||||
user := dbgen.User(t, crypt, database.User{})
|
||||
link := dbgen.GitAuthLink(t, crypt, database.GitAuthLink{
|
||||
UserID: user.ID,
|
||||
OAuthAccessToken: "access",
|
||||
OAuthRefreshToken: "refresh",
|
||||
})
|
||||
links, err := crypt.GetGitAuthLinksByUserID(ctx, link.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, links, 1)
|
||||
require.Equal(t, "access", links[0].OAuthAccessToken)
|
||||
require.Equal(t, "refresh", links[0].OAuthRefreshToken)
|
||||
require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthAccessTokenKeyID.String)
|
||||
require.Equal(t, ciphers[0].HexDigest(), links[0].OAuthRefreshTokenKeyID.String)
|
||||
|
||||
rawLinks, err := db.GetGitAuthLinksByUserID(ctx, link.UserID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rawLinks, 1)
|
||||
requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthAccessToken, "access")
|
||||
requireEncryptedEquals(t, ciphers[0], rawLinks[0].OAuthRefreshToken, "refresh")
|
||||
})
|
||||
|
||||
t.Run("DecryptErr", func(t *testing.T) {
|
||||
db, crypt, ciphers := setup(t)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
link := dbgen.GitAuthLink(t, db, database.GitAuthLink{
|
||||
UserID: user.ID,
|
||||
OAuthAccessToken: fakeBase64RandomData(t, 32),
|
||||
OAuthRefreshToken: fakeBase64RandomData(t, 32),
|
||||
OAuthAccessTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
|
||||
OAuthRefreshTokenKeyID: sql.NullString{String: ciphers[0].HexDigest(), Valid: true},
|
||||
})
|
||||
_, err := crypt.GetGitAuthLinksByUserID(ctx, link.UserID)
|
||||
require.Error(t, err, "expected an error")
|
||||
var derr *DecryptFailedError
|
||||
require.ErrorAs(t, err, &derr, "expected a decrypt error")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Given: a cipher is loaded
|
||||
cipher := initCipher(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
rawDB, _ := dbtestutil.NewDB(t)
|
||||
|
||||
// Before: no keys should be present
|
||||
keys, err := rawDB.GetDBCryptKeys(ctx)
|
||||
require.NoError(t, err, "no error should be returned")
|
||||
require.Empty(t, keys, "no keys should be present")
|
||||
|
||||
// When: we init the crypt db
|
||||
_, err = New(ctx, rawDB, cipher)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: a new key is inserted
|
||||
keys, err = rawDB.GetDBCryptKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1, "one key should be present")
|
||||
require.Equal(t, cipher.HexDigest(), keys[0].ActiveKeyDigest.String, "key digest mismatch")
|
||||
require.Empty(t, keys[0].RevokedKeyDigest.String, "key should not be revoked")
|
||||
requireEncryptedEquals(t, cipher, keys[0].Test, "coder")
|
||||
})
|
||||
|
||||
t.Run("MissingKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: there exist two valid encryption keys
|
||||
cipher1 := initCipher(t)
|
||||
cipher2 := initCipher(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
rawDB, _ := dbtestutil.NewDB(t)
|
||||
|
||||
// Given: key 1 is already present in the database
|
||||
err := rawDB.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{
|
||||
Number: 1,
|
||||
ActiveKeyDigest: cipher1.HexDigest(),
|
||||
Test: fakeBase64RandomData(t, 32),
|
||||
})
|
||||
require.NoError(t, err, "no error should be returned")
|
||||
keys, err := rawDB.GetDBCryptKeys(ctx)
|
||||
require.NoError(t, err, "no error should be returned")
|
||||
require.Len(t, keys, 1, "one key should be present")
|
||||
|
||||
// When: we init the crypt db with no keys
|
||||
_, err = New(ctx, rawDB)
|
||||
// Then: we error because we don't know how to decrypt the existing key
|
||||
require.Error(t, err, "expected an error")
|
||||
var derr *DecryptFailedError
|
||||
require.ErrorAs(t, err, &derr, "expected a decrypt error")
|
||||
|
||||
// When: we init the crypt db with key 2
|
||||
_, err = New(ctx, rawDB, cipher2)
|
||||
|
||||
// Then: we error because the key is not revoked and we don't know how to decrypt it
|
||||
require.Error(t, err, "expected an error")
|
||||
require.ErrorAs(t, err, &derr, "expected a decrypt error")
|
||||
|
||||
// When: the existing key is marked as having been revoked
|
||||
err = rawDB.RevokeDBCryptKey(ctx, cipher1.HexDigest())
|
||||
require.NoError(t, err, "no error should be returned")
|
||||
|
||||
// And: we init the crypt db with key 2
|
||||
_, err = New(ctx, rawDB, cipher2)
|
||||
|
||||
// Then: we succeed
|
||||
require.NoError(t, err)
|
||||
|
||||
// And: key 2 should now be the active key
|
||||
keys, err = rawDB.GetDBCryptKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 2, "two keys should be present")
|
||||
require.EqualValues(t, keys[0].Number, 1, "key number mismatch")
|
||||
require.Empty(t, keys[0].ActiveKeyDigest.String, "key should not be active")
|
||||
require.Equal(t, cipher1.HexDigest(), keys[0].RevokedKeyDigest.String, "key should be revoked")
|
||||
|
||||
require.EqualValues(t, keys[1].Number, 2, "key number mismatch")
|
||||
require.Equal(t, cipher2.HexDigest(), keys[1].ActiveKeyDigest.String, "key digest mismatch")
|
||||
require.Empty(t, keys[1].RevokedKeyDigest.String, "key should not be revoked")
|
||||
requireEncryptedEquals(t, cipher2, keys[1].Test, "coder")
|
||||
})
|
||||
|
||||
t.Run("NoKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Given: no cipher is loaded
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
rawDB, _ := dbtestutil.NewDB(t)
|
||||
|
||||
keys, err := rawDB.GetDBCryptKeys(ctx)
|
||||
require.NoError(t, err, "no error should be returned")
|
||||
require.Empty(t, keys, "no keys should be present")
|
||||
|
||||
// When: we init the crypt db with no ciphers
|
||||
_, err = New(ctx, rawDB)
|
||||
|
||||
// Then: it should succeed.
|
||||
require.NoError(t, err, "dbcrypt.New should work with no keys against an unencrypted database")
|
||||
|
||||
// Assert invariant: no keys are inserted
|
||||
keys, err = rawDB.GetDBCryptKeys(ctx)
|
||||
require.NoError(t, err, "no error should be returned")
|
||||
require.Empty(t, keys, "no keys should be present")
|
||||
|
||||
// Insert a key
|
||||
require.NoError(t, rawDB.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{
|
||||
Number: 1,
|
||||
ActiveKeyDigest: "whatever",
|
||||
Test: fakeBase64RandomData(t, 32),
|
||||
}))
|
||||
|
||||
// This should fail as we do not know how to decrypt the key:
|
||||
_, err = New(ctx, rawDB)
|
||||
require.Error(t, err)
|
||||
// Until we revoke the key:
|
||||
require.NoError(t, rawDB.RevokeDBCryptKey(ctx, "whatever"))
|
||||
_, err = New(ctx, rawDB)
|
||||
require.NoError(t, err, "the above should still hold if the key is revoked")
|
||||
})
|
||||
|
||||
t.Run("PrimaryRevoked", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Given: a cipher is loaded
|
||||
cipher := initCipher(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
rawDB, _ := dbtestutil.NewDB(t)
|
||||
|
||||
// And: the cipher is revoked before we init the crypt db
|
||||
err := rawDB.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{
|
||||
Number: 1,
|
||||
ActiveKeyDigest: cipher.HexDigest(),
|
||||
Test: fakeBase64RandomData(t, 32),
|
||||
})
|
||||
require.NoError(t, err, "no error should be returned")
|
||||
err = rawDB.RevokeDBCryptKey(ctx, cipher.HexDigest())
|
||||
require.NoError(t, err, "no error should be returned")
|
||||
|
||||
// Then: when we init the crypt db, we error because the key is revoked
|
||||
_, err = New(ctx, rawDB, cipher)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "has been revoked")
|
||||
})
|
||||
|
||||
t.Run("Retry", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Given: a cipher is loaded
|
||||
cipher := initCipher(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
testVal, err := cipher.Encrypt([]byte("coder"))
|
||||
key := database.DBCryptKey{
|
||||
Number: 1,
|
||||
ActiveKeyDigest: sql.NullString{String: cipher.HexDigest(), Valid: true},
|
||||
Test: b64encode(testVal),
|
||||
}
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// And: a database that returns an error once when we try to serialize a key
|
||||
ctrl := gomock.NewController(t)
|
||||
mockDB := dbmock.NewMockStore(ctrl)
|
||||
|
||||
gomock.InOrder(
|
||||
// First try: we get a serialization error.
|
||||
expectInTx(mockDB),
|
||||
mockDB.EXPECT().GetDBCryptKeys(gomock.Any()).Times(1).Return([]database.DBCryptKey{}, nil),
|
||||
mockDB.EXPECT().InsertDBCryptKey(gomock.Any(), gomock.Any()).Times(1).Return(&pq.Error{Code: "40001"}),
|
||||
// Second try: we get the key we wanted to insert initially.
|
||||
expectInTx(mockDB),
|
||||
mockDB.EXPECT().GetDBCryptKeys(gomock.Any()).Times(1).Return([]database.DBCryptKey{key}, nil),
|
||||
)
|
||||
|
||||
_, err = New(ctx, mockDB, cipher)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEncryptDecryptField(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, cryptDB, ciphers := setup(t)
|
||||
field := "coder"
|
||||
digest := sql.NullString{}
|
||||
require.NoError(t, cryptDB.encryptField(&field, &digest))
|
||||
require.Equal(t, ciphers[0].HexDigest(), digest.String)
|
||||
requireEncryptedEquals(t, ciphers[0], field, "coder")
|
||||
require.NoError(t, cryptDB.decryptField(&field, digest))
|
||||
require.Equal(t, "coder", field)
|
||||
})
|
||||
|
||||
t.Run("NoKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// With no keys, encryption and decryption are both no-ops.
|
||||
_, cryptDB := setupNoCiphers(t)
|
||||
field := "coder"
|
||||
digest := sql.NullString{}
|
||||
require.NoError(t, cryptDB.encryptField(&field, &digest))
|
||||
require.Empty(t, digest.String)
|
||||
require.Equal(t, "coder", field)
|
||||
require.NoError(t, cryptDB.decryptField(&field, digest))
|
||||
require.Equal(t, "coder", field)
|
||||
})
|
||||
|
||||
t.Run("MissingKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, cryptDB, ciphers := setup(t)
|
||||
field := "coder"
|
||||
digest := sql.NullString{}
|
||||
err := cryptDB.encryptField(&field, &digest)
|
||||
require.NoError(t, err)
|
||||
requireEncryptedEquals(t, ciphers[0], field, "coder")
|
||||
require.Equal(t, ciphers[0].HexDigest(), digest.String)
|
||||
require.True(t, digest.Valid)
|
||||
|
||||
digest = sql.NullString{String: "missing", Valid: true}
|
||||
var derr *DecryptFailedError
|
||||
err = cryptDB.decryptField(&field, digest)
|
||||
require.Error(t, err)
|
||||
require.ErrorAs(t, err, &derr)
|
||||
})
|
||||
|
||||
t.Run("CantEncryptOrDecryptNil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, cryptDB, _ := setup(t)
|
||||
require.ErrorContains(t, cryptDB.encryptField(nil, nil), "developer error")
|
||||
require.ErrorContains(t, cryptDB.decryptField(nil, sql.NullString{}), "developer error")
|
||||
})
|
||||
|
||||
t.Run("EncryptEmptyString", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, cryptDB, ciphers := setup(t)
|
||||
field := ""
|
||||
digest := sql.NullString{}
|
||||
require.NoError(t, cryptDB.encryptField(&field, &digest))
|
||||
requireEncryptedEquals(t, ciphers[0], field, "")
|
||||
require.Equal(t, ciphers[0].HexDigest(), digest.String)
|
||||
require.NoError(t, cryptDB.decryptField(&field, digest))
|
||||
require.Empty(t, field)
|
||||
})
|
||||
|
||||
t.Run("DecryptEmptyString", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, cryptDB, ciphers := setup(t)
|
||||
field := ""
|
||||
digest := sql.NullString{String: ciphers[0].HexDigest(), Valid: true}
|
||||
err := cryptDB.decryptField(&field, digest)
|
||||
// Currently this has to fail because the ciphertext must at least
|
||||
// have a nonce. This may need to be changed depending on future
|
||||
// ciphers.
|
||||
require.ErrorContains(t, err, "ciphertext too short")
|
||||
})
|
||||
|
||||
t.Run("InvalidBase64", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, cryptDB, ciphers := setup(t)
|
||||
field := "not valid base64"
|
||||
digest := sql.NullString{String: ciphers[0].HexDigest(), Valid: true}
|
||||
err := cryptDB.decryptField(&field, digest)
|
||||
require.ErrorContains(t, err, "illegal base64 data")
|
||||
})
|
||||
}
|
||||
|
||||
func expectInTx(mdb *dbmock.MockStore) *gomock.Call {
|
||||
return mdb.EXPECT().InTx(gomock.Any(), gomock.Any()).Times(1).DoAndReturn(
|
||||
func(f func(store database.Store) error, _ *sql.TxOptions) error {
|
||||
return f(mdb)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func requireEncryptedEquals(t *testing.T, c Cipher, value, expected string) {
|
||||
t.Helper()
|
||||
data, err := base64.StdEncoding.DecodeString(value)
|
||||
require.NoError(t, err, "invalid base64")
|
||||
got, err := c.Decrypt(data)
|
||||
require.NoError(t, err, "failed to decrypt data")
|
||||
require.Equal(t, expected, string(got), "decrypted data does not match")
|
||||
}
|
||||
|
||||
func initCipher(t *testing.T) *aes256 {
|
||||
t.Helper()
|
||||
key := make([]byte, 32) // AES-256 key size is 32 bytes
|
||||
_, err := io.ReadFull(rand.Reader, key)
|
||||
require.NoError(t, err)
|
||||
c, err := cipherAES256(key)
|
||||
require.NoError(t, err)
|
||||
return c
|
||||
}
|
||||
|
||||
func setup(t *testing.T) (db database.Store, cryptDB *dbCrypt, cs []Cipher) {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
rawDB, _ := dbtestutil.NewDB(t)
|
||||
|
||||
cs = append(cs, initCipher(t))
|
||||
cdb, err := New(ctx, rawDB, cs...)
|
||||
require.NoError(t, err)
|
||||
cryptDB, ok := cdb.(*dbCrypt)
|
||||
require.True(t, ok)
|
||||
|
||||
return rawDB, cryptDB, cs
|
||||
}
|
||||
|
||||
func setupNoCiphers(t *testing.T) (db database.Store, cryptodb *dbCrypt) {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
rawDB, _ := dbtestutil.NewDB(t)
|
||||
cdb, err := New(ctx, rawDB)
|
||||
require.NoError(t, err)
|
||||
cryptDB, ok := cdb.(*dbCrypt)
|
||||
require.True(t, ok)
|
||||
return rawDB, cryptDB
|
||||
}
|
||||
|
||||
func fakeBase64RandomData(t *testing.T, n int) string {
|
||||
t.Helper()
|
||||
b := make([]byte, n)
|
||||
_, err := io.ReadFull(rand.Reader, b)
|
||||
require.NoError(t, err)
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
// Package dbcrypt provides a database.Store wrapper that encrypts/decrypts
|
||||
// values stored at rest in the database.
|
||||
//
|
||||
// Encryption is done using Ciphers, which is an abstraction over a set of
|
||||
// encryption keys. Each key has a unique identifier, which is used to
|
||||
// uniquely identify the key whilst maintaining secrecy.
|
||||
//
|
||||
// Currently, AES-256-GCM is the only implemented cipher mode.
|
||||
// The Cipher is currently used to encrypt/decrypt the following fields:
|
||||
// - database.UserLink.OAuthAccessToken
|
||||
// - database.UserLink.OAuthRefreshToken
|
||||
// - database.GitAuthLink.OAuthAccessToken
|
||||
// - database.GitAuthLink.OAuthRefreshToken
|
||||
// - database.DBCryptSentinelValue
|
||||
//
|
||||
// Multiple ciphers can be provided to support key rotation. The primary cipher
|
||||
// is used to encrypt and decrypt all data. Secondary ciphers are only used
|
||||
// for decryption and, as a general rule, should only be active when rotating
|
||||
// keys.
|
||||
//
|
||||
// Encryption keys are stored in the database in the table `dbcrypt_keys`.
|
||||
// The table has the following schema:
|
||||
// - number: the key number. This is used to avoid conflicts when rotating keys.
|
||||
// - created_at: the time the key was created.
|
||||
// - active_key_digest: the SHA256 digest of the active key. If null, the key has been revoked.
|
||||
// - revoked_key_digest: the SHA256 digest of the revoked key. If null, the key has not been revoked.
|
||||
// - revoked_at: the time the key was revoked. If null, the key has not been revoked.
|
||||
// - test: the encrypted value of the string "coder". This is used to ensure that the key is valid.
|
||||
//
|
||||
// Encrypted fields are stored in the database as a base64-encoded string.
|
||||
// Each encrypted column MUST have a corresponding _key_id column that is a foreign key
|
||||
// reference to `dbcrypt_keys.active_key_digest`. This ensures that a key cannot be
|
||||
// revoked until all rows that use that key have been migrated to a new key.
|
||||
package dbcrypt
|
Loading…
Reference in New Issue