mirror of https://github.com/coder/coder.git
chore: use robust RNG in `cryptorand` (#8040)
This commit is contained in:
parent
c8e67833f5
commit
ca6b9e9368
|
@ -30,31 +30,6 @@ func TestRandError(t *testing.T) {
|
|||
require.ErrorIs(t, err, io.ErrShortBuffer, "expected Int63 error")
|
||||
})
|
||||
|
||||
t.Run("Uint64", func(t *testing.T) {
|
||||
_, err := cryptorand.Uint64()
|
||||
require.ErrorIs(t, err, io.ErrShortBuffer, "expected Uint64 error")
|
||||
})
|
||||
|
||||
t.Run("Int31", func(t *testing.T) {
|
||||
_, err := cryptorand.Int31()
|
||||
require.ErrorIs(t, err, io.ErrShortBuffer, "expected Int31 error")
|
||||
})
|
||||
|
||||
t.Run("Int31n", func(t *testing.T) {
|
||||
_, err := cryptorand.Int31n(100)
|
||||
require.ErrorIs(t, err, io.ErrShortBuffer, "expected Int31n error")
|
||||
})
|
||||
|
||||
t.Run("Uint32", func(t *testing.T) {
|
||||
_, err := cryptorand.Uint32()
|
||||
require.ErrorIs(t, err, io.ErrShortBuffer, "expected Uint32 error")
|
||||
})
|
||||
|
||||
t.Run("Int", func(t *testing.T) {
|
||||
_, err := cryptorand.Int()
|
||||
require.ErrorIs(t, err, io.ErrShortBuffer, "expected Int error")
|
||||
})
|
||||
|
||||
t.Run("Intn_32bit", func(t *testing.T) {
|
||||
_, err := cryptorand.Intn(100)
|
||||
require.ErrorIs(t, err, io.ErrShortBuffer, "expected Intn error")
|
||||
|
@ -70,11 +45,6 @@ func TestRandError(t *testing.T) {
|
|||
require.ErrorIs(t, err, io.ErrShortBuffer, "expected Float64 error")
|
||||
})
|
||||
|
||||
t.Run("Float32", func(t *testing.T) {
|
||||
_, err := cryptorand.Float32()
|
||||
require.ErrorIs(t, err, io.ErrShortBuffer, "expected Float32 error")
|
||||
})
|
||||
|
||||
t.Run("StringCharset", func(t *testing.T) {
|
||||
_, err := cryptorand.HexString(10)
|
||||
require.ErrorIs(t, err, io.ErrShortBuffer, "expected HexString error")
|
||||
|
|
|
@ -3,194 +3,58 @@ package cryptorand
|
|||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
insecurerand "math/rand"
|
||||
)
|
||||
|
||||
// Most of this code is inspired by math/rand, so shares similar
|
||||
// functions and implementations, but uses crypto/rand to generate
|
||||
// random Int63 data.
|
||||
type cryptoSource struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (*cryptoSource) Seed(_ int64) {
|
||||
// Intentionally disregard seed
|
||||
}
|
||||
|
||||
func (c *cryptoSource) Int63() int64 {
|
||||
var n int64
|
||||
err := binary.Read(rand.Reader, binary.BigEndian, &n)
|
||||
if err != nil {
|
||||
c.err = err
|
||||
}
|
||||
// The sign bit must be cleared to ensure the final value is non-negative.
|
||||
n &= 0x7fffffffffffffff
|
||||
return n
|
||||
}
|
||||
|
||||
func (c *cryptoSource) Uint64() uint64 {
|
||||
var n uint64
|
||||
err := binary.Read(rand.Reader, binary.BigEndian, &n)
|
||||
if err != nil {
|
||||
c.err = err
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// secureRand returns a cryptographically secure random number generator.
|
||||
func secureRand() (*insecurerand.Rand, *cryptoSource) {
|
||||
var cs cryptoSource
|
||||
//nolint:gosec
|
||||
return insecurerand.New(&cs), &cs
|
||||
}
|
||||
|
||||
// Int64 returns a non-negative random 63-bit integer as a int64.
|
||||
func Int63() (int64, error) {
|
||||
var i int64
|
||||
err := binary.Read(rand.Reader, binary.BigEndian, &i)
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("read binary: %w", err)
|
||||
}
|
||||
|
||||
if i < 0 {
|
||||
return -i, nil
|
||||
}
|
||||
return i, nil
|
||||
rng, cs := secureRand()
|
||||
return rng.Int63(), cs.err
|
||||
}
|
||||
|
||||
// Uint64 returns a random 64-bit integer as a uint64.
|
||||
func Uint64() (uint64, error) {
|
||||
upper, err := Int63()
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("read upper: %w", err)
|
||||
}
|
||||
|
||||
lower, err := Int63()
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("read lower: %w", err)
|
||||
}
|
||||
|
||||
return uint64(lower)>>31 | uint64(upper)<<32, nil
|
||||
}
|
||||
|
||||
// Int31 returns a non-negative random 31-bit integer as a int32.
|
||||
func Int31() (int32, error) {
|
||||
i, err := Int63()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int32(i >> 32), nil
|
||||
}
|
||||
|
||||
// Uint32 returns a 32-bit value as a uint32.
|
||||
func Uint32() (uint32, error) {
|
||||
i, err := Int63()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return uint32(i >> 31), nil
|
||||
}
|
||||
|
||||
// Int returns a non-negative random integer as a int.
|
||||
func Int() (int, error) {
|
||||
i, err := Int63()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if i < 0 {
|
||||
return int(-i), nil
|
||||
}
|
||||
return int(i), nil
|
||||
}
|
||||
|
||||
// Int63n returns a non-negative random integer in [0,max) as a int64.
|
||||
func Int63n(max int64) (int64, error) {
|
||||
if max <= 0 {
|
||||
panic("invalid argument to Int63n")
|
||||
}
|
||||
|
||||
trueMax := int64((1 << 63) - 1 - (1<<63)%uint64(max))
|
||||
i, err := Int63()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for i > trueMax {
|
||||
i, err = Int63()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return i % max, nil
|
||||
}
|
||||
|
||||
// Int31n returns a non-negative integer in [0,max) as a int32.
|
||||
func Int31n(max int32) (int32, error) {
|
||||
i, err := Uint32()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return UnbiasedModulo32(i, max)
|
||||
}
|
||||
|
||||
// UnbiasedModulo32 uniformly modulos v by n over a sufficiently large data
|
||||
// set, regenerating v if necessary. n must be > 0. All input bits in v must be
|
||||
// fully random, you cannot cast a random uint8/uint16 for input into this
|
||||
// function.
|
||||
//
|
||||
//nolint:varnamelen
|
||||
func UnbiasedModulo32(v uint32, n int32) (int32, error) {
|
||||
prod := uint64(v) * uint64(n)
|
||||
low := uint32(prod)
|
||||
if low < uint32(n) {
|
||||
thresh := uint32(-n) % uint32(n)
|
||||
for low < thresh {
|
||||
var err error
|
||||
v, err = Uint32()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
prod = uint64(v) * uint64(n)
|
||||
low = uint32(prod)
|
||||
}
|
||||
}
|
||||
return int32(prod >> 32), nil
|
||||
}
|
||||
|
||||
// Intn returns a non-negative integer in [0,max) as a int.
|
||||
// Intn returns a non-negative integer in [0,max) as an int.
|
||||
func Intn(max int) (int, error) {
|
||||
if max <= 0 {
|
||||
panic("n must be a positive nonzero number")
|
||||
}
|
||||
|
||||
if max <= 1<<31-1 {
|
||||
i, err := Int31n(int32(max))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int(i), nil
|
||||
}
|
||||
|
||||
i, err := Int63n(int64(max))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int(i), nil
|
||||
rng, cs := secureRand()
|
||||
return rng.Intn(max), cs.err
|
||||
}
|
||||
|
||||
// Float64 returns a random number in [0.0,1.0) as a float64.
|
||||
func Float64() (float64, error) {
|
||||
again:
|
||||
i, err := Int63n(1 << 53)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
f := (float64(i) / (1 << 53))
|
||||
if f == 1 {
|
||||
goto again
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Float32 returns a random number in [0.0,1.0) as a float32.
|
||||
func Float32() (float32, error) {
|
||||
again:
|
||||
i, err := Float64()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
f := float32(i)
|
||||
if f == 1 {
|
||||
goto again
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Duration returns a random time.Duration value
|
||||
func Duration() (time.Duration, error) {
|
||||
i, err := Int63()
|
||||
if err != nil {
|
||||
return time.Duration(0), err
|
||||
}
|
||||
|
||||
return time.Duration(i), nil
|
||||
rng, cs := secureRand()
|
||||
return rng.Float64(), cs.err
|
||||
}
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package cryptorand_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -21,96 +19,6 @@ func TestInt63(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestUint64(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
v, err := cryptorand.Uint64()
|
||||
require.NoError(t, err, "unexpected error from Uint64")
|
||||
t.Logf("value: %v <- random?", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInt31(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
v, err := cryptorand.Int31()
|
||||
require.NoError(t, err, "unexpected error from Int31")
|
||||
t.Logf("value: %v <- random?", v)
|
||||
require.True(t, v >= 0, "values must be positive")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnbiasedModulo32(t *testing.T) {
|
||||
t.Parallel()
|
||||
const mod = 7
|
||||
dist := [mod]uint32{}
|
||||
|
||||
_, err := cryptorand.UnbiasedModulo32(0, mod)
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
b := [4]byte{}
|
||||
_, _ = rand.Read(b[:])
|
||||
v, err := cryptorand.UnbiasedModulo32(binary.BigEndian.Uint32(b[:]), mod)
|
||||
require.NoError(t, err, "unexpected error from UnbiasedModulo32")
|
||||
dist[v]++
|
||||
}
|
||||
|
||||
t.Logf("dist: %+v <- evenly distributed?", dist)
|
||||
}
|
||||
|
||||
func TestUint32(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
v, err := cryptorand.Uint32()
|
||||
require.NoError(t, err, "unexpected error from Uint32")
|
||||
t.Logf("value: %v <- random?", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
v, err := cryptorand.Int()
|
||||
require.NoError(t, err, "unexpected error from Int")
|
||||
t.Logf("value: %v <- random?", v)
|
||||
require.True(t, v >= 0, "values must be positive")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInt63n(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
v, err := cryptorand.Int63n(1 << 35)
|
||||
require.NoError(t, err, "unexpected error from Int63n")
|
||||
t.Logf("value: %v <- random?", v)
|
||||
require.True(t, v >= 0, "values must be positive")
|
||||
require.True(t, v < 1<<35, "values must be less than 1<<35")
|
||||
}
|
||||
|
||||
// Expect a panic if max is negative
|
||||
require.PanicsWithValue(t, "invalid argument to Int63n", func() {
|
||||
cryptorand.Int63n(0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInt31n(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
v, err := cryptorand.Int31n(100)
|
||||
require.NoError(t, err, "unexpected error from Int31n")
|
||||
t.Logf("value: %v <- random?", v)
|
||||
require.True(t, v >= 0, "values must be positive")
|
||||
require.True(t, v < 100, "values must be less than 100")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -118,7 +26,7 @@ func TestIntn(t *testing.T) {
|
|||
v, err := cryptorand.Intn(100)
|
||||
require.NoError(t, err, "unexpected error from Intn")
|
||||
t.Logf("value: %v <- random?", v)
|
||||
require.True(t, v >= 0, "values must be positive")
|
||||
require.GreaterOrEqual(t, v, 0, "values must be positive")
|
||||
require.True(t, v < 100, "values must be less than 100")
|
||||
}
|
||||
|
||||
|
@ -127,7 +35,7 @@ func TestIntn(t *testing.T) {
|
|||
require.NoError(t, err, "expected Intn to work for 64-bit int")
|
||||
|
||||
// Expect a panic if max is negative
|
||||
require.PanicsWithValue(t, "n must be a positive nonzero number", func() {
|
||||
require.PanicsWithValue(t, "invalid argument to Intn", func() {
|
||||
cryptorand.Intn(0)
|
||||
})
|
||||
}
|
||||
|
@ -143,26 +51,3 @@ func TestFloat64(t *testing.T) {
|
|||
require.True(t, v < 1.0, "values must be less than 1.0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFloat32(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
v, err := cryptorand.Float32()
|
||||
require.NoError(t, err, "unexpected error from Float32")
|
||||
t.Logf("value: %v <- random?", v)
|
||||
require.True(t, v >= 0.0, "values must be positive")
|
||||
require.True(t, v < 1.0, "values must be less than 1.0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
v, err := cryptorand.Duration()
|
||||
require.NoError(t, err, "unexpected error from Duration")
|
||||
t.Logf("value: %v <- random?", v)
|
||||
require.True(t, v >= 0.0, "values must be positive")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Charsets
|
||||
|
@ -32,19 +34,48 @@ const (
|
|||
Human = "23456789abcdefghjkmnpqrstuvwxyz"
|
||||
)
|
||||
|
||||
// StringCharset generates a random string using the provided charset and size
|
||||
func StringCharset(charSetStr string, size int) (string, error) {
|
||||
charSet := []rune(charSetStr)
|
||||
// unbiasedModulo32 uniformly modulos v by n over a sufficiently large data
|
||||
// set, regenerating v if necessary. n must be > 0. All input bits in v must be
|
||||
// fully random, you cannot cast a random uint8/uint16 for input into this
|
||||
// function.
|
||||
//
|
||||
// See more details on this algorithm here:
|
||||
// https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
|
||||
//
|
||||
//nolint:varnamelen
|
||||
func unbiasedModulo32(v uint32, n int32) (int32, error) {
|
||||
prod := uint64(v) * uint64(n)
|
||||
low := uint32(prod)
|
||||
if low < uint32(n) {
|
||||
thresh := uint32(-n) % uint32(n)
|
||||
for low < thresh {
|
||||
err := binary.Read(rand.Reader, binary.BigEndian, &v)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
prod = uint64(v) * uint64(n)
|
||||
low = uint32(prod)
|
||||
}
|
||||
}
|
||||
return int32(prod >> 32), nil
|
||||
}
|
||||
|
||||
if len(charSet) == 0 || size == 0 {
|
||||
// StringCharset generates a random string using the provided charset and size.
|
||||
func StringCharset(charSetStr string, size int) (string, error) {
|
||||
if size == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// This buffer facilitates pre-emptively creation of random uint32s
|
||||
// to reduce syscall overhead.
|
||||
ibuf := make([]byte, 4*size)
|
||||
if len(charSetStr) == 0 {
|
||||
return "", xerrors.Errorf("charSetStr must not be empty")
|
||||
}
|
||||
|
||||
_, err := rand.Read(ibuf)
|
||||
charSet := []rune(charSetStr)
|
||||
|
||||
// We pre-allocate the entropy to amortize the crypto/rand syscall overhead.
|
||||
entropy := make([]byte, 4*size)
|
||||
|
||||
_, err := rand.Read(entropy)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -53,15 +84,18 @@ func StringCharset(charSetStr string, size int) (string, error) {
|
|||
buf.Grow(size)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
count, err := UnbiasedModulo32(
|
||||
binary.BigEndian.Uint32(ibuf[i*4:(i+1)*4]),
|
||||
r := binary.BigEndian.Uint32(entropy[:4])
|
||||
entropy = entropy[4:]
|
||||
|
||||
ci, err := unbiasedModulo32(
|
||||
r,
|
||||
int32(len(charSet)),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_, _ = buf.WriteRune(charSet[count])
|
||||
_, _ = buf.WriteRune(charSet[ci])
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
|
|
Loading…
Reference in New Issue