coder/coderd/devtunnel/tunnel_test.go

239 lines
6.3 KiB
Go

package devtunnel_test
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"strings"
"testing"
"time"
"cdr.dev/slog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/devtunnel"
"github.com/coder/coder/testutil"
)
const (
ipByte1 = 0xfc
ipByte2 = 0xca
wgPort = 48732
)
var (
serverIP = netip.AddrFrom16([16]byte{ipByte1, ipByte2, 15: 0x1})
dnsIP = netip.AddrFrom4([4]byte{1, 1, 1, 1})
clientIP = netip.AddrFrom16([16]byte{ipByte1, ipByte2, 15: 0x2})
)
// The tunnel leaks a few goroutines that aren't impactful to production scenarios.
// func TestMain(m *testing.M) {
// goleak.VerifyTestMain(m)
// }
// TestTunnel cannot run in parallel because we hardcode the UDP port used by the wireguard server.
// nolint: paralleltest
func TestTunnel(t *testing.T) {
ctx, cancelTun := context.WithCancel(context.Background())
defer cancelTun()
server := http.Server{
ReadHeaderTimeout: time.Minute,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Log("got request for", r.URL)
// Going to use something _slightly_ exotic so that we can't accidentally get some
// default behavior creating a false positive on the test
w.WriteHeader(http.StatusAccepted)
}),
BaseContext: func(_ net.Listener) context.Context {
return ctx
},
}
fTunServer := newFakeTunnelServer(t)
cfg := fTunServer.config()
tun, errCh, err := devtunnel.NewWithConfig(ctx, slogtest.Make(t, nil).Leveled(slog.LevelDebug), cfg)
require.NoError(t, err)
t.Log(tun.URL)
go func() {
err := server.Serve(tun.Listener)
assert.Equal(t, http.ErrServerClosed, err)
}()
defer func() { _ = server.Close() }()
defer func() { tun.Listener.Close() }()
require.Eventually(t, func() bool {
res, err := fTunServer.requestHTTP()
if !assert.NoError(t, err) {
return false
}
defer res.Body.Close()
_, _ = io.Copy(io.Discard, res.Body)
return res.StatusCode == http.StatusAccepted
}, testutil.WaitShort, testutil.IntervalSlow)
assert.NoError(t, server.Close())
cancelTun()
select {
case <-errCh:
case <-time.After(testutil.WaitLong):
t.Errorf("tunnel did not close after %s", testutil.WaitLong)
}
}
// fakeTunnelServer is a fake version of the real dev tunnel server. It fakes 2 client interactions
// that we want to test:
// 1. Responding to a POST /tun from the client
// 2. Sending an HTTP request down the wireguard connection
//
// Note that for 2, we don't implement a full proxy that accepts arbitrary requests, we just send
// a test request over the Wireguard tunnel to make sure that we can listen. The proxy behavior is
// outside of the scope of the dev tunnel client, which is what we are testing here.
type fakeTunnelServer struct {
t *testing.T
pub device.NoisePublicKey
priv device.NoisePrivateKey
tnet *netstack.Net
device *device.Device
clients int
server *httptest.Server
}
func newFakeTunnelServer(t *testing.T) *fakeTunnelServer {
t.Helper()
priv, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
privBytes := [32]byte(priv)
pub := priv.PublicKey()
pubBytes := [32]byte(pub)
tun, tnet, err := netstack.CreateNetTUN(
[]netip.Addr{serverIP},
[]netip.Addr{dnsIP},
1280,
)
require.NoError(t, err)
ctx := context.Background()
slogger := slogtest.Make(t, nil).Leveled(slog.LevelDebug).Named("server")
logger := &device.Logger{
Verbosef: slog.Stdlib(ctx, slogger, slog.LevelDebug).Printf,
Errorf: slog.Stdlib(ctx, slogger, slog.LevelError).Printf,
}
dev := device.NewDevice(tun, conn.NewDefaultBind(), logger)
t.Cleanup(func() {
dev.RemoveAllPeers()
dev.Close()
slogger.Debug(ctx, "dev.Close()")
})
err = dev.IpcSet(fmt.Sprintf(`private_key=%s
listen_port=%d`,
hex.EncodeToString(privBytes[:]),
wgPort,
))
require.NoError(t, err)
err = dev.Up()
require.NoError(t, err)
server := newFakeTunnelHTTPSServer(t, pubBytes)
return &fakeTunnelServer{
t: t,
pub: device.NoisePublicKey(pub),
priv: device.NoisePrivateKey(priv),
tnet: tnet,
device: dev,
server: server,
}
}
func newFakeTunnelHTTPSServer(t *testing.T, pubBytes [32]byte) *httptest.Server {
handler := http.NewServeMux()
handler.HandleFunc("/tun", func(writer http.ResponseWriter, request *http.Request) {
assert.Equal(t, "POST", request.Method)
resp := devtunnel.ServerResponse{
Hostname: fmt.Sprintf("[%s]", serverIP.String()),
ServerIP: serverIP,
ServerPublicKey: hex.EncodeToString(pubBytes[:]),
ClientIP: clientIP,
}
b, err := json.Marshal(&resp)
assert.NoError(t, err)
writer.WriteHeader(200)
_, err = writer.Write(b)
assert.NoError(t, err)
})
server := httptest.NewTLSServer(handler)
t.Cleanup(func() {
server.Close()
})
return server
}
func (f *fakeTunnelServer) config() devtunnel.Config {
priv, err := wgtypes.GeneratePrivateKey()
require.NoError(f.t, err)
pub := priv.PublicKey()
f.clients++
assert.Equal(f.t, 1, f.clients) // only allow one client as we hardcode the address
err = f.device.IpcSet(fmt.Sprintf(`public_key=%x
allowed_ip=%s/128`,
pub[:],
clientIP.String(),
))
require.NoError(f.t, err)
return devtunnel.Config{
Version: 1,
PrivateKey: device.NoisePrivateKey(priv),
PublicKey: device.NoisePublicKey(pub),
Tunnel: devtunnel.Node{
HostnameHTTPS: strings.TrimPrefix(f.server.URL, "https://"),
HostnameWireguard: "localhost",
WireguardPort: wgPort,
},
HTTPClient: f.server.Client(),
}
}
func (f *fakeTunnelServer) requestHTTP() (*http.Response, error) {
transport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
f.t.Log("Dial", network, addr)
nc, err := f.tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(clientIP, 8090))
assert.NoError(f.t, err)
return nc, err
},
}
client := &http.Client{
Transport: transport,
Timeout: testutil.WaitLong,
}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("http://[%s]:8090", clientIP), nil)
if err != nil {
return nil, err
}
return client.Do(req)
}