coder/cli/server_internal_test.go

350 lines
8.6 KiB
Go

package cli
import (
"bytes"
"context"
"crypto/tls"
"testing"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
)
// Test_configureCipherSuites exercises configureCipherSuites across
// combinations of requested cipher names, the allow-insecure switch, and the
// minimum/maximum TLS version range. Each case either expects an error whose
// message contains wantErr, or expects the returned cipher IDs to match
// expectCiphers (in any order) and every wantWarnings substring to appear in
// the captured log output.
func Test_configureCipherSuites(t *testing.T) {
	t.Parallel()

	// cipherNames projects a list of cipher suites onto their names.
	cipherNames := func(ciphers []*tls.CipherSuite) []string {
		names := make([]string, 0, len(ciphers))
		for _, c := range ciphers {
			names = append(names, c.Name)
		}
		return names
	}

	// cipherIDs projects a list of cipher suites onto their IDs.
	cipherIDs := func(ciphers []*tls.CipherSuite) []uint16 {
		ids := make([]uint16, 0, len(ciphers))
		for _, c := range ciphers {
			ids = append(ids, c.ID)
		}
		return ids
	}

	// cipherByName looks a suite up by name in both the secure and insecure
	// lists known to crypto/tls. An unknown name fails the test immediately:
	// returning nil would otherwise surface later as a nil-pointer panic
	// inside cipherIDs, hiding which name was missing.
	cipherByName := func(cipher string) *tls.CipherSuite {
		for _, c := range append(tls.CipherSuites(), tls.InsecureCipherSuites()...) {
			if c.Name == cipher {
				return c
			}
		}
		t.Fatalf("cipher %q not found in tls.CipherSuites() or tls.InsecureCipherSuites()", cipher)
		return nil
	}

	tests := []struct {
		name          string
		wantErr       string
		wantWarnings  []string
		inputCiphers  []string
		minTLS        uint16
		maxTLS        uint16
		allowInsecure bool
		expectCiphers []uint16
	}{
		{
			name:          "AllSecure",
			minTLS:        tls.VersionTLS10,
			maxTLS:        tls.VersionTLS13,
			inputCiphers:  cipherNames(tls.CipherSuites()),
			wantWarnings:  []string{},
			expectCiphers: cipherIDs(tls.CipherSuites()),
		},
		{
			name:          "AllowInsecure",
			minTLS:        tls.VersionTLS10,
			maxTLS:        tls.VersionTLS13,
			inputCiphers:  append(cipherNames(tls.CipherSuites()), tls.InsecureCipherSuites()[0].Name),
			allowInsecure: true,
			wantWarnings: []string{
				"insecure tls cipher specified",
			},
			expectCiphers: append(cipherIDs(tls.CipherSuites()), tls.InsecureCipherSuites()[0].ID),
		},
		{
			name:          "AllInsecure",
			minTLS:        tls.VersionTLS10,
			maxTLS:        tls.VersionTLS13,
			inputCiphers:  append(cipherNames(tls.CipherSuites()), cipherNames(tls.InsecureCipherSuites())...),
			allowInsecure: true,
			wantWarnings: []string{
				"insecure tls cipher specified",
			},
			expectCiphers: append(cipherIDs(tls.CipherSuites()), cipherIDs(tls.InsecureCipherSuites())...),
		},
		{
			// Providing ciphers that are not compatible with any tls version
			// enabled should generate a warning.
			name:   "ExcessiveCiphers",
			minTLS: tls.VersionTLS10,
			maxTLS: tls.VersionTLS11,
			inputCiphers: []string{
				"TLS_RSA_WITH_AES_128_CBC_SHA",
				// Only for TLS 1.3
				"TLS_AES_128_GCM_SHA256",
			},
			allowInsecure: true,
			wantWarnings: []string{
				"cipher not supported for tls versions",
			},
			expectCiphers: cipherIDs([]*tls.CipherSuite{
				cipherByName("TLS_RSA_WITH_AES_128_CBC_SHA"),
				cipherByName("TLS_AES_128_GCM_SHA256"),
			}),
		},
		// Errors
		{
			name:         "NotRealCiphers",
			minTLS:       tls.VersionTLS10,
			maxTLS:       tls.VersionTLS13,
			inputCiphers: []string{"RSA-Fake"},
			wantErr:      "unsupported tls ciphers",
		},
		{
			name:    "NoCiphers",
			minTLS:  tls.VersionTLS10,
			maxTLS:  tls.VersionTLS13,
			wantErr: "no tls ciphers supported",
		},
		{
			name:         "InsecureNotAllowed",
			minTLS:       tls.VersionTLS10,
			maxTLS:       tls.VersionTLS13,
			inputCiphers: append(cipherNames(tls.CipherSuites()), tls.InsecureCipherSuites()[0].Name),
			wantErr:      "insecure tls ciphers specified",
		},
		{
			name:         "TLS1.3",
			minTLS:       tls.VersionTLS13,
			maxTLS:       tls.VersionTLS13,
			inputCiphers: cipherNames(tls.CipherSuites()),
			wantErr:      "'--tls-ciphers' cannot be specified when using minimum tls version 1.3",
		},
		{
			name:   "TLSUnsupported",
			minTLS: tls.VersionTLS10,
			maxTLS: tls.VersionTLS13,
			// TLS_RSA_WITH_AES_128_GCM_SHA256 only supports tls 1.2
			inputCiphers: []string{"TLS_RSA_WITH_AES_128_GCM_SHA256"},
			wantErr:      "no tls ciphers supported for tls versions",
		},
		{
			name:    "Min>Max",
			minTLS:  tls.VersionTLS13,
			maxTLS:  tls.VersionTLS12,
			wantErr: "minimum tls version (TLS 1.3) cannot be greater than maximum tls version (TLS 1.2)",
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			ctx := context.Background()

			// Capture human-readable log output so warnings can be asserted on.
			var out bytes.Buffer
			logger := slog.Make(sloghuman.Sink(&out))

			found, err := configureCipherSuites(ctx, logger, tt.inputCiphers, tt.allowInsecure, tt.minTLS, tt.maxTLS)
			if tt.wantErr != "" {
				require.ErrorContains(t, err, tt.wantErr)
				return
			}

			require.NoError(t, err, "no error")
			require.ElementsMatch(t, tt.expectCiphers, found, "expected ciphers")

			// NOTE(review): an empty wantWarnings only skips this check; it
			// does not assert that nothing was logged.
			if len(tt.wantWarnings) > 0 {
				logger.Sync()
				for _, w := range tt.wantWarnings {
					assert.Contains(t, out.String(), w, "expected warning")
				}
			}
		})
	}
}
// TestRedirectHTTPToHTTPSDeprecation verifies how redirectHTTPToHTTPSDeprecation
// sets DeploymentValues.RedirectToAccessURL from the CODER_TLS_REDIRECT_HTTP and
// CODER_TLS_REDIRECT_HTTP_TO_HTTPS environment variables and from the
// --tls-redirect-http-to-https flag.
func TestRedirectHTTPToHTTPSDeprecation(t *testing.T) {
	t.Parallel()

	// env builds a single-variable environment.
	env := func(name, value string) serpent.Environ {
		return serpent.Environ{{Name: name, Value: value}}
	}

	cases := []struct {
		name string
		env  serpent.Environ
		args []string
		want bool
	}{
		{name: "AllUnset", env: serpent.Environ{}, args: []string{}, want: false},
		{name: "CODER_TLS_REDIRECT_HTTP=true", env: env("CODER_TLS_REDIRECT_HTTP", "true"), args: []string{}, want: true},
		{name: "CODER_TLS_REDIRECT_HTTP_TO_HTTPS=true", env: env("CODER_TLS_REDIRECT_HTTP_TO_HTTPS", "true"), args: []string{}, want: true},
		{name: "CODER_TLS_REDIRECT_HTTP=false", env: env("CODER_TLS_REDIRECT_HTTP", "false"), args: []string{}, want: false},
		{name: "CODER_TLS_REDIRECT_HTTP_TO_HTTPS=false", env: env("CODER_TLS_REDIRECT_HTTP_TO_HTTPS", "false"), args: []string{}, want: false},
		{name: "--tls-redirect-http-to-https", env: serpent.Environ{}, args: []string{"--tls-redirect-http-to-https"}, want: true},
	}

	for _, c := range cases {
		c := c
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()
			ctx := testutil.Context(t, testutil.WaitShort)
			logger := slogtest.Make(t, nil)

			// The flag is registered with a default of true, yet "AllUnset"
			// expects false. Presumably the function only honors the flag
			// when it was explicitly set; confirm against its implementation.
			fs := pflag.NewFlagSet("test", pflag.ContinueOnError)
			_ = fs.Bool("tls-redirect-http-to-https", true, "")
			require.NoError(t, fs.Parse(c.args))
			inv := (&serpent.Invocation{Environ: c.env}).WithTestParsedFlags(t, fs)

			values := &codersdk.DeploymentValues{}
			opts := values.Options()
			require.NoError(t, opts.SetDefaults())

			redirectHTTPToHTTPSDeprecation(ctx, logger, inv, values)
			require.Equal(t, c.want, values.RedirectToAccessURL.Value())
		})
	}
}
func TestIsDERPPath(t *testing.T) {
t.Parallel()
testcases := []struct {
path string
expected bool
}{
//{
// path: "/derp",
// expected: true,
// },
{
path: "/derp/",
expected: true,
},
{
path: "/derp/latency-check",
expected: true,
},
{
path: "/derp/latency-check/",
expected: true,
},
{
path: "",
expected: false,
},
{
path: "/",
expected: false,
},
{
path: "/derptastic",
expected: false,
},
{
path: "/api/v2/derp",
expected: false,
},
{
path: "//",
expected: false,
},
}
for _, tc := range testcases {
tc := tc
t.Run(tc.path, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.expected, isDERPPath(tc.path))
})
}
}
// TestEscapePostgresURLUserInfo checks that escapePostgresURLUserInfo
// percent-escapes characters in the password part of a postgres URL: '{' and
// space are escaped, while an embedded ':' is left as-is. A URL that cannot be
// parsed must yield an empty string and the exact wrapped parse error.
func TestEscapePostgresURLUserInfo(t *testing.T) {
	t.Parallel()

	cases := []struct {
		in      string
		want    string
		wantErr error // nil means the call must succeed
	}{
		{
			in:   "postgres://coder:coder@localhost:5432/coder",
			want: "postgres://coder:coder@localhost:5432/coder",
		},
		{
			in:   "postgres://coder:co{der@localhost:5432/coder",
			want: "postgres://coder:co%7Bder@localhost:5432/coder",
		},
		{
			in:   "postgres://coder:co:der@localhost:5432/coder",
			want: "postgres://coder:co:der@localhost:5432/coder",
		},
		{
			in:   "postgres://coder:co der@localhost:5432/coder",
			want: "postgres://coder:co%20der@localhost:5432/coder",
		},
		{
			in:      "postgres://local host:5432/coder",
			want:    "",
			wantErr: xerrors.New("parse postgres url: parse \"postgres://local host:5432/coder\": invalid character \" \" in host name"),
		},
	}

	for _, c := range cases {
		c := c
		t.Run(c.in, func(t *testing.T) {
			t.Parallel()
			got, err := escapePostgresURLUserInfo(c.in)
			require.Equal(t, c.want, got)

			if c.wantErr == nil {
				require.NoError(t, err)
				return
			}
			// Compare full messages: the expected error is only a carrier
			// for its text, not an identity to match with errors.Is.
			require.EqualError(t, err, c.wantErr.Error())
		})
	}
}