coder/enterprise/dbcrypt/cipher.go

99 lines
2.5 KiB
Go

package dbcrypt
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/xerrors"
)
// cipherAES256GCM is the name of the AES-256-GCM cipher.
// It identifies the cipher used to encrypt a value and is mixed into the
// key digest computed by cipherAES256, so that if a new cipher type is
// added in the future and a key is re-used, the two ciphers produce
// different digests and we don't accidentally decrypt the wrong values.
//
// When adding a new cipher type, add a new constant here and make sure
// the new cipher's name is included in its digest in the same way.
const (
cipherAES256GCM = "aes256gcm"
)
// Cipher encrypts and decrypts byte slices and exposes a short digest that
// identifies the cipher instance. The only implementation in this file is
// the AES-256-GCM cipher created by NewCiphers.
type Cipher interface {
// Encrypt returns the encrypted form of the given plaintext, or an error
// if encryption fails.
Encrypt([]byte) ([]byte, error)
// Decrypt returns the plaintext for a value previously produced by
// Encrypt, or an error if the value cannot be decrypted.
Decrypt([]byte) ([]byte, error)
// HexDigest returns a hex string identifying the cipher. For the
// AES-256-GCM implementation in this file it is derived from the cipher
// name and the key.
HexDigest() string
}
// NewCiphers builds one Cipher per supplied key, preserving the order of
// the keys. Every cipher is currently AES-256-GCM.
//
// The first key that cannot be turned into a cipher aborts the call: the
// error from cipherAES256 is returned as-is together with a nil slice.
// Calling NewCiphers with no keys yields a nil slice and a nil error.
func NewCiphers(keys ...[]byte) ([]Cipher, error) {
	var ciphers []Cipher
	for i := range keys {
		aesCipher, err := cipherAES256(keys[i])
		if err != nil {
			return nil, err
		}
		ciphers = append(ciphers, aesCipher)
	}
	return ciphers, nil
}
// cipherAES256 constructs an AES-256-GCM cipher from key, which must be
// exactly 32 bytes long.
//
// The returned cipher carries a short digest: the first 7 characters of
// the hex-encoded SHA-256 of the cipher name followed by the key.
func cipherAES256(key []byte) (*aes256, error) {
	if len(key) != 32 {
		return nil, xerrors.Errorf("key must be 32 bytes")
	}

	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}

	// The cipher name is hashed ahead of the key so that, should another
	// cipher type ever be introduced and a key re-used with it, the two
	// digests differ and values are not decrypted with the wrong cipher.
	// Writes to a hash.Hash never fail, so the results are discarded.
	hasher := sha256.New()
	_, _ = hasher.Write([]byte(cipherAES256GCM))
	_, _ = hasher.Write(key)
	hexSum := fmt.Sprintf("%x", hasher.Sum(nil))

	return &aes256{aead: gcm, digest: hexSum[:7]}, nil
}
// aes256 is the Cipher implementation built by cipherAES256. It pairs an
// AEAD with a short digest that identifies the cipher name and key.
type aes256 struct {
// aead performs the actual sealing and opening of values.
aead cipher.AEAD
// digest is the first 7 characters of the hex-encoded SHA-256 digest of
// the cipher name followed by the key (see cipherAES256).
digest string
}
// Encrypt seals plaintext with the cipher's AEAD under a freshly generated
// random nonce and returns the nonce followed by the AEAD's sealed output.
// No additional authenticated data is used. Decrypt expects this layout.
//
// An error is returned only if reading the nonce from crypto/rand fails.
func (a *aes256) Encrypt(plaintext []byte) ([]byte, error) {
	nonceLen := a.aead.NonceSize()
	nonce := make([]byte, nonceLen)
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, err
	}

	// Reserve room for the nonce, the ciphertext and the AEAD overhead up
	// front so Seal can append in place rather than allocating a second
	// buffer and copying the nonce again.
	out := make([]byte, nonceLen, nonceLen+len(plaintext)+a.aead.Overhead())
	copy(out, nonce)
	return a.aead.Seal(out, nonce, plaintext, nil), nil
}
// Decrypt reverses Encrypt. The input is expected to start with the nonce,
// followed by the sealed payload; no additional authenticated data is used.
//
// Inputs shorter than the nonce are rejected with a plain error. If the
// AEAD fails to open the payload, the underlying error is returned wrapped
// in a *DecryptFailedError.
func (a *aes256) Decrypt(ciphertext []byte) ([]byte, error) {
	nonceLen := a.aead.NonceSize()
	if len(ciphertext) < nonceLen {
		return nil, xerrors.Errorf("ciphertext too short")
	}

	nonce, sealed := ciphertext[:nonceLen], ciphertext[nonceLen:]
	plaintext, err := a.aead.Open(nil, nonce, sealed, nil)
	if err != nil {
		return nil, &DecryptFailedError{Inner: err}
	}
	return plaintext, nil
}
// HexDigest returns the digest stored on the cipher when it was created by
// cipherAES256: the first 7 characters of the hex-encoded SHA-256 of the
// cipher name followed by the key.
func (a *aes256) HexDigest() string {
return a.digest
}