// coder/coderd/devtunnel/tunnel_test.go (280 lines, 7.3 KiB, Go)

package devtunnel_test
import (
"context"
"crypto/tls"
"encoding/base32"
"encoding/hex"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"testing"
"time"
"cdr.dev/slog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/devtunnel"
"github.com/coder/coder/v2/testutil"
"github.com/coder/wgtunnel/tunneld"
"github.com/coder/wgtunnel/tunnelsdk"
)
// goleak verification is disabled because the tunnel leaks a few goroutines
// that don't matter in production scenarios.
// func TestMain(m *testing.M) {
// goleak.VerifyTestMain(m)
// }
// TestTunnel brings up an in-process tunnel server, connects a devtunnel
// client to it for each supported tunnel version, checks that the generated
// hostnames match the server's own derivation, and then verifies that an HTTP
// request sent to the tunnel URL reaches a local http.Server.
func TestTunnel(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name    string
		version tunnelsdk.TunnelVersion
	}{
		{
			name:    "V1",
			version: tunnelsdk.TunnelVersion1,
		},
		{
			name:    "V2",
			version: tunnelsdk.TunnelVersion2,
		},
	}

	for _, c := range cases {
		c := c
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()

			ctx, cancelTun := context.WithCancel(context.Background())
			defer cancelTun()

			server := http.Server{
				ReadHeaderTimeout: time.Minute,
				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					t.Log("got request for", r.URL)
					// Going to use something _slightly_ exotic so that we can't
					// accidentally get some default behavior creating a false
					// positive on the test
					w.WriteHeader(http.StatusAccepted)
				}),
				BaseContext: func(_ net.Listener) context.Context {
					return ctx
				},
			}

			tunServer := newTunnelServer(t)
			cfg := tunServer.config(t, c.version)

			tun, err := devtunnel.NewWithConfig(ctx, slogtest.Make(t, nil).Leveled(slog.LevelDebug), cfg)
			require.NoError(t, err)
			require.Len(t, tun.OtherURLs, 1)
			t.Log(tun.URL, tun.OtherURLs[0])

			hostSplit := strings.SplitN(tun.URL.Host, ".", 2)
			require.Len(t, hostSplit, 2)
			require.Equal(t, hostSplit[1], tunServer.api.BaseURL.Host)

			// Verify the hostname using the same logic as the tunnel server.
			ip1, urls := tunServer.api.WireguardPublicKeyToIPAndURLs(cfg.PublicKey, c.version)
			require.Len(t, urls, 2)
			require.Equal(t, urls[0].String(), tun.URL.String())
			require.Equal(t, urls[1].String(), tun.OtherURLs[0].String())

			ip2, err := tunServer.api.HostnameToWireguardIP(hostSplit[0])
			require.NoError(t, err)
			require.Equal(t, ip1, ip2)

			// Manually verify the hostname.
			switch c.version {
			case tunnelsdk.TunnelVersion1:
				// The subdomain should be a 32 character hex string
				// (16 bytes once decoded).
				require.Len(t, hostSplit[0], 32)
				_, err := hex.DecodeString(hostSplit[0])
				require.NoError(t, err)
			case tunnelsdk.TunnelVersion2:
				// The subdomain should be an unpadded base32hex string
				// containing 8 bytes once decoded.
				dec, err := base32.HexEncoding.WithPadding(base32.NoPadding).DecodeString(strings.ToUpper(hostSplit[0]))
				require.NoError(t, err)
				require.Len(t, dec, 8)
			}

			// serveDone is closed when Serve returns. This lets the subtest wait
			// for the goroutine, so its assertion never runs after the test has
			// completed.
			serveDone := make(chan struct{})
			go func() {
				defer close(serveDone)
				err := server.Serve(tun.Listener)
				assert.Equal(t, http.ErrServerClosed, err)
			}()
			// Deferred calls run in reverse order. The server is closed first,
			// so Serve reports ErrServerClosed rather than a listener accept
			// error. Then the listener is closed, and finally we wait for
			// Serve to return.
			defer func() {
				select {
				case <-serveDone:
				case <-time.After(testutil.WaitShort):
					t.Errorf("server did not stop serving after %s", testutil.WaitShort)
				}
			}()
			defer func() { _ = tun.Listener.Close() }()
			defer func() { _ = server.Close() }()

			require.Eventually(t, func() bool {
				req, err := http.NewRequestWithContext(ctx, http.MethodGet, tun.URL.String(), nil)
				if !assert.NoError(t, err) {
					return false
				}
				// Eventually retries failed attempts. Log the error instead of
				// asserting, so that one transient failure does not fail the
				// whole test.
				res, err := tunServer.requestTunnel(tun, req)
				if err != nil {
					t.Logf("request tunnel: %v", err)
					return false
				}
				defer res.Body.Close()
				_, _ = io.Copy(io.Discard, res.Body)
				return res.StatusCode == http.StatusAccepted
			}, testutil.WaitShort, testutil.IntervalSlow)

			assert.NoError(t, server.Close())
			cancelTun()

			select {
			case <-tun.Wait():
			case <-time.After(testutil.WaitLong):
				t.Errorf("tunnel did not close after %s", testutil.WaitLong)
			}
		})
	}
}
// freeUDPPort binds a UDP socket to 127.0.0.1 on port 0 so the kernel picks an
// unused port. It reads the assigned port back, closes the socket, and returns
// the port number.
//
// This is prone to races: something else may claim the port after it is
// released. It can't be avoided, because wireguard must create the listener
// itself and cannot be handed a net.Listener.
func freeUDPPort(t *testing.T) uint16 {
	t.Helper()

	conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
	require.NoError(t, err, "listen on random UDP port")

	localAddr := conn.LocalAddr().String()
	_, portText, err := net.SplitHostPort(localAddr)
	require.NoError(t, err, "split host port")

	port, err := strconv.ParseUint(portText, 10, 16)
	require.NoError(t, err, "parse port")

	require.NoError(t, conn.Close(), "close UDP listener")
	return uint16(port)
}
// tunnelServer pairs an in-process tunneld API with the HTTPS test server
// whose handler routes requests to that API's router.
type tunnelServer struct {
	// api is the tunneld instance built in newTunnelServer.
	api *tunneld.API
	// server is the TLS httptest server that all test HTTP traffic is dialed to.
	server *httptest.Server
}
// newTunnelServer starts an httptest TLS server and a tunneld API behind it,
// and returns both. The API's BaseURL must contain the test server's port,
// so the API can only be built after the server has started. Requests are
// therefore dispatched through a handler variable that is assigned once the
// API exists. Both the server and the API are closed with t.Cleanup.
func newTunnelServer(t *testing.T) *tunnelServer {
	t.Helper()

	// handler is assigned below, before this function returns. Nothing in
	// this function sends a request to srv before then.
	var handler http.Handler
	srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if handler == nil {
			w.WriteHeader(http.StatusBadGateway)
			return
		}
		handler.ServeHTTP(w, r)
	}))
	t.Cleanup(srv.Close)

	baseURLParsed, err := url.Parse(srv.URL)
	require.NoError(t, err)
	require.Equal(t, "https", baseURLParsed.Scheme)
	baseURLParsed.Host = net.JoinHostPort("tunnel.coder.com", baseURLParsed.Port())

	key, err := tunnelsdk.GeneratePrivateKey()
	require.NoError(t, err)

	// Sadly the tunnel server needs to be passed a port number and can't be
	// passed an active listener (because wireguard needs to make the listener),
	// so we may need to try a few times to get a free port.
	const attempts = 10
	var td *tunneld.API
	for i := 0; i < attempts; i++ {
		wireguardPort := freeUDPPort(t)
		options := &tunneld.Options{
			BaseURL:                baseURLParsed,
			WireguardEndpoint:      fmt.Sprintf("127.0.0.1:%d", wireguardPort),
			WireguardPort:          wireguardPort,
			WireguardKey:           key,
			WireguardMTU:           tunneld.DefaultWireguardMTU,
			WireguardServerIP:      tunneld.DefaultWireguardServerIP,
			WireguardNetworkPrefix: tunneld.DefaultWireguardNetworkPrefix,
		}
		td, err = tunneld.New(options)
		if err == nil {
			break
		}
		t.Logf("failed to create tunnel server on port %d: %s", wireguardPort, err)
	}
	if err != nil || td == nil {
		t.Fatalf("failed to create tunnel server in %d attempts: %v", attempts, err)
	}

	handler = td.Router()
	t.Cleanup(func() {
		_ = td.Close()
	})

	return &tunnelServer{
		api:    td,
		server: srv,
	}
}
// client returns an HTTP client that dials the test server for every
// connection, whatever address was requested. This lets hostnames under
// tunnel.coder.com be used without DNS. TLS verification is skipped because
// the test server's self-signed certificate is not issued for those hostnames.
//
// A new Transport is built on each call and is never closed. Keep-alives are
// disabled so idle pooled connections do not outlive the requests that
// opened them.
func (s *tunnelServer) client() *http.Client {
	transport := &http.Transport{
		// The requested network and address are ignored on purpose: all
		// traffic goes to the httptest server.
		DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
			return (&net.Dialer{}).DialContext(ctx, "tcp", s.server.Listener.Addr().String())
		},
		TLSClientConfig: &tls.Config{
			//nolint:gosec
			InsecureSkipVerify: true,
		},
		DisableKeepAlives: true,
	}
	return &http.Client{
		Transport: transport,
		Timeout:   testutil.WaitLong,
	}
}
// config returns a devtunnel.Config that targets this tunnel server. It uses a
// freshly generated wireguard key pair and the client returned by s.client().
// A zero version is replaced with tunnelsdk.TunnelVersionLatest.
func (s *tunnelServer) config(t *testing.T, version tunnelsdk.TunnelVersion) devtunnel.Config {
	t.Helper()

	priv, err := tunnelsdk.GeneratePrivateKey()
	require.NoError(t, err)
	privNoise, err := priv.NoisePrivateKey()
	require.NoError(t, err)
	pubNoise := priv.NoisePublicKey()

	if version == 0 {
		version = tunnelsdk.TunnelVersionLatest
	}

	return devtunnel.Config{
		Version:    version,
		PrivateKey: privNoise,
		PublicKey:  pubNoise,
		Tunnel: devtunnel.Node{
			RegionID:      0,
			ID:            1,
			HostnameHTTPS: s.api.BaseURL.Host,
		},
		HTTPClient: s.client(),
	}
}
// requestTunnel sends req to the given tunnel using a client from s.client().
// It modifies req in place: the URL scheme becomes "https", and both the URL
// host and the Host header are set to the tunnel's hostname.
func (s *tunnelServer) requestTunnel(tunnel *tunnelsdk.Tunnel, req *http.Request) (*http.Response, error) {
	tunnelHost := tunnel.URL.Host
	req.Host = tunnelHost
	req.URL.Host = tunnelHost
	req.URL.Scheme = "https"

	httpClient := s.client()
	return httpClient.Do(req)
}