// Mirror of https://github.com/coder/coder.git (350 lines, 8.6 KiB, Go).

package cli
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"testing"
|
|
|
|
"github.com/spf13/pflag"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog"
|
|
"cdr.dev/slog/sloggers/sloghuman"
|
|
"cdr.dev/slog/sloggers/slogtest"
|
|
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/serpent"
|
|
)
|
|
|
|
func Test_configureCipherSuites(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cipherNames := func(ciphers []*tls.CipherSuite) []string {
|
|
var names []string
|
|
for _, c := range ciphers {
|
|
names = append(names, c.Name)
|
|
}
|
|
return names
|
|
}
|
|
|
|
cipherIDs := func(ciphers []*tls.CipherSuite) []uint16 {
|
|
var ids []uint16
|
|
for _, c := range ciphers {
|
|
ids = append(ids, c.ID)
|
|
}
|
|
return ids
|
|
}
|
|
|
|
cipherByName := func(cipher string) *tls.CipherSuite {
|
|
for _, c := range append(tls.CipherSuites(), tls.InsecureCipherSuites()...) {
|
|
if cipher == c.Name {
|
|
c := c
|
|
return c
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
wantErr string
|
|
wantWarnings []string
|
|
inputCiphers []string
|
|
minTLS uint16
|
|
maxTLS uint16
|
|
allowInsecure bool
|
|
expectCiphers []uint16
|
|
}{
|
|
{
|
|
name: "AllSecure",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS13,
|
|
inputCiphers: cipherNames(tls.CipherSuites()),
|
|
wantWarnings: []string{},
|
|
expectCiphers: cipherIDs(tls.CipherSuites()),
|
|
},
|
|
{
|
|
name: "AllowInsecure",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS13,
|
|
inputCiphers: append(cipherNames(tls.CipherSuites()), tls.InsecureCipherSuites()[0].Name),
|
|
allowInsecure: true,
|
|
wantWarnings: []string{
|
|
"insecure tls cipher specified",
|
|
},
|
|
expectCiphers: append(cipherIDs(tls.CipherSuites()), tls.InsecureCipherSuites()[0].ID),
|
|
},
|
|
{
|
|
name: "AllInsecure",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS13,
|
|
inputCiphers: append(cipherNames(tls.CipherSuites()), cipherNames(tls.InsecureCipherSuites())...),
|
|
allowInsecure: true,
|
|
wantWarnings: []string{
|
|
"insecure tls cipher specified",
|
|
},
|
|
expectCiphers: append(cipherIDs(tls.CipherSuites()), cipherIDs(tls.InsecureCipherSuites())...),
|
|
},
|
|
{
|
|
// Providing ciphers that are not compatible with any tls version
|
|
// enabled should generate a warning.
|
|
name: "ExcessiveCiphers",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS11,
|
|
inputCiphers: []string{
|
|
"TLS_RSA_WITH_AES_128_CBC_SHA",
|
|
// Only for TLS 1.3
|
|
"TLS_AES_128_GCM_SHA256",
|
|
},
|
|
allowInsecure: true,
|
|
wantWarnings: []string{
|
|
"cipher not supported for tls versions",
|
|
},
|
|
expectCiphers: cipherIDs([]*tls.CipherSuite{
|
|
cipherByName("TLS_RSA_WITH_AES_128_CBC_SHA"),
|
|
cipherByName("TLS_AES_128_GCM_SHA256"),
|
|
}),
|
|
},
|
|
// Errors
|
|
{
|
|
name: "NotRealCiphers",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS13,
|
|
inputCiphers: []string{"RSA-Fake"},
|
|
wantErr: "unsupported tls ciphers",
|
|
},
|
|
{
|
|
name: "NoCiphers",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS13,
|
|
wantErr: "no tls ciphers supported",
|
|
},
|
|
{
|
|
name: "InsecureNotAllowed",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS13,
|
|
inputCiphers: append(cipherNames(tls.CipherSuites()), tls.InsecureCipherSuites()[0].Name),
|
|
wantErr: "insecure tls ciphers specified",
|
|
},
|
|
{
|
|
name: "TLS1.3",
|
|
minTLS: tls.VersionTLS13,
|
|
maxTLS: tls.VersionTLS13,
|
|
inputCiphers: cipherNames(tls.CipherSuites()),
|
|
wantErr: "'--tls-ciphers' cannot be specified when using minimum tls version 1.3",
|
|
},
|
|
{
|
|
name: "TLSUnsupported",
|
|
minTLS: tls.VersionTLS10,
|
|
maxTLS: tls.VersionTLS13,
|
|
// TLS_RSA_WITH_AES_128_GCM_SHA256 only supports tls 1.2
|
|
inputCiphers: []string{"TLS_RSA_WITH_AES_128_GCM_SHA256"},
|
|
wantErr: "no tls ciphers supported for tls versions",
|
|
},
|
|
{
|
|
name: "Min>Max",
|
|
minTLS: tls.VersionTLS13,
|
|
maxTLS: tls.VersionTLS12,
|
|
wantErr: "minimum tls version (TLS 1.3) cannot be greater than maximum tls version (TLS 1.2)",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := context.Background()
|
|
var out bytes.Buffer
|
|
logger := slog.Make(sloghuman.Sink(&out))
|
|
|
|
found, err := configureCipherSuites(ctx, logger, tt.inputCiphers, tt.allowInsecure, tt.minTLS, tt.maxTLS)
|
|
if tt.wantErr != "" {
|
|
require.ErrorContains(t, err, tt.wantErr)
|
|
} else {
|
|
require.NoError(t, err, "no error")
|
|
require.ElementsMatch(t, tt.expectCiphers, found, "expected ciphers")
|
|
if len(tt.wantWarnings) > 0 {
|
|
logger.Sync()
|
|
for _, w := range tt.wantWarnings {
|
|
assert.Contains(t, out.String(), w, "expected warning")
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRedirectHTTPToHTTPSDeprecation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testcases := []struct {
|
|
name string
|
|
environ serpent.Environ
|
|
flags []string
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "AllUnset",
|
|
environ: serpent.Environ{},
|
|
flags: []string{},
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "CODER_TLS_REDIRECT_HTTP=true",
|
|
environ: serpent.Environ{{Name: "CODER_TLS_REDIRECT_HTTP", Value: "true"}},
|
|
flags: []string{},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "CODER_TLS_REDIRECT_HTTP_TO_HTTPS=true",
|
|
environ: serpent.Environ{{Name: "CODER_TLS_REDIRECT_HTTP_TO_HTTPS", Value: "true"}},
|
|
flags: []string{},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "CODER_TLS_REDIRECT_HTTP=false",
|
|
environ: serpent.Environ{{Name: "CODER_TLS_REDIRECT_HTTP", Value: "false"}},
|
|
flags: []string{},
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "CODER_TLS_REDIRECT_HTTP_TO_HTTPS=false",
|
|
environ: serpent.Environ{{Name: "CODER_TLS_REDIRECT_HTTP_TO_HTTPS", Value: "false"}},
|
|
flags: []string{},
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "--tls-redirect-http-to-https",
|
|
environ: serpent.Environ{},
|
|
flags: []string{"--tls-redirect-http-to-https"},
|
|
expected: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testcases {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, nil)
|
|
flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
|
|
_ = flags.Bool("tls-redirect-http-to-https", true, "")
|
|
err := flags.Parse(tc.flags)
|
|
require.NoError(t, err)
|
|
inv := (&serpent.Invocation{Environ: tc.environ}).WithTestParsedFlags(t, flags)
|
|
cfg := &codersdk.DeploymentValues{}
|
|
opts := cfg.Options()
|
|
err = opts.SetDefaults()
|
|
require.NoError(t, err)
|
|
redirectHTTPToHTTPSDeprecation(ctx, logger, inv, cfg)
|
|
require.Equal(t, tc.expected, cfg.RedirectToAccessURL.Value())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsDERPPath(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testcases := []struct {
|
|
path string
|
|
expected bool
|
|
}{
|
|
//{
|
|
// path: "/derp",
|
|
// expected: true,
|
|
// },
|
|
{
|
|
path: "/derp/",
|
|
expected: true,
|
|
},
|
|
{
|
|
path: "/derp/latency-check",
|
|
expected: true,
|
|
},
|
|
{
|
|
path: "/derp/latency-check/",
|
|
expected: true,
|
|
},
|
|
{
|
|
path: "",
|
|
expected: false,
|
|
},
|
|
{
|
|
path: "/",
|
|
expected: false,
|
|
},
|
|
{
|
|
path: "/derptastic",
|
|
expected: false,
|
|
},
|
|
{
|
|
path: "/api/v2/derp",
|
|
expected: false,
|
|
},
|
|
{
|
|
path: "//",
|
|
expected: false,
|
|
},
|
|
}
|
|
for _, tc := range testcases {
|
|
tc := tc
|
|
t.Run(tc.path, func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, tc.expected, isDERPPath(tc.path))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestEscapePostgresURLUserInfo(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testcases := []struct {
|
|
input string
|
|
output string
|
|
err error
|
|
}{
|
|
{
|
|
input: "postgres://coder:coder@localhost:5432/coder",
|
|
output: "postgres://coder:coder@localhost:5432/coder",
|
|
err: nil,
|
|
},
|
|
{
|
|
input: "postgres://coder:co{der@localhost:5432/coder",
|
|
output: "postgres://coder:co%7Bder@localhost:5432/coder",
|
|
err: nil,
|
|
},
|
|
{
|
|
input: "postgres://coder:co:der@localhost:5432/coder",
|
|
output: "postgres://coder:co:der@localhost:5432/coder",
|
|
err: nil,
|
|
},
|
|
{
|
|
input: "postgres://coder:co der@localhost:5432/coder",
|
|
output: "postgres://coder:co%20der@localhost:5432/coder",
|
|
err: nil,
|
|
},
|
|
{
|
|
input: "postgres://local host:5432/coder",
|
|
output: "",
|
|
err: xerrors.New("parse postgres url: parse \"postgres://local host:5432/coder\": invalid character \" \" in host name"),
|
|
},
|
|
}
|
|
for _, tc := range testcases {
|
|
tc := tc
|
|
t.Run(tc.input, func(t *testing.T) {
|
|
t.Parallel()
|
|
o, err := escapePostgresURLUserInfo(tc.input)
|
|
require.Equal(t, tc.output, o)
|
|
if tc.err != nil {
|
|
require.Error(t, err)
|
|
require.EqualValues(t, tc.err.Error(), err.Error())
|
|
} else {
|
|
require.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|