restore devtunnel test (#3050)

* Dev tunnel test uses local fake server; fixed port

Signed-off-by: Spike Curtis <spike@coder.com>

* Remove parallel for test

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix segfault
This commit is contained in:
Spike Curtis 2022-07-22 08:26:39 -07:00 committed by GitHub
parent 882ee55fd0
commit fa4361db76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 189 additions and 31 deletions

View File

@ -36,6 +36,9 @@ type Config struct {
PublicKey device.NoisePublicKey `json:"public_key"`
Tunnel Node `json:"tunnel"`
// Used in testing. Normally this is nil, indicating to use DefaultClient.
HTTPClient *http.Client `json:"-"`
}
type configExt struct {
Version int `json:"-"`
@ -43,6 +46,9 @@ type configExt struct {
PublicKey device.NoisePublicKey `json:"public_key"`
Tunnel Node `json:"-"`
// Used in testing. Normally this is nil, indicating to use DefaultClient.
HTTPClient *http.Client `json:"-"`
}
// NewWithConfig calls New with the given config. For documentation, see New.
@ -65,17 +71,23 @@ func NewWithConfig(ctx context.Context, logger slog.Logger, cfg Config) (*Tunnel
if err != nil {
return nil, nil, xerrors.Errorf("resolve endpoint: %w", err)
}
// In IPv6, we need to enclose the address to in [] before passing to wireguard's endpoint key, like
// [2001:abcd::1]:8888. We'll use netip.AddrPort to correctly handle this.
wgAddr, err := netip.ParseAddr(wgip.String())
if err != nil {
return nil, nil, xerrors.Errorf("parse address: %w", err)
}
wgEndpoint := netip.AddrPortFrom(wgAddr, cfg.Tunnel.WireguardPort)
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, ""))
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelError, "devtunnel "))
err = dev.IpcSet(fmt.Sprintf(`private_key=%s
public_key=%s
endpoint=%s:%d
endpoint=%s
persistent_keepalive_interval=21
allowed_ip=%s/128`,
hex.EncodeToString(cfg.PrivateKey[:]),
server.ServerPublicKey,
wgip.IP.String(),
cfg.Tunnel.WireguardPort,
wgEndpoint.String(),
server.ServerIP.String(),
))
if err != nil {
@ -97,6 +109,9 @@ allowed_ip=%s/128`,
select {
case <-ctx.Done():
_ = wgListen.Close()
// We need to remove peers before closing to avoid a race condition between dev.Close() and the peer
// goroutines which results in segfault.
dev.RemoveAllPeers()
dev.Close()
<-routineEnd
close(ch)
@ -174,7 +189,11 @@ func sendConfigToServer(ctx context.Context, cfg Config) (ServerResponse, error)
return ServerResponse{}, xerrors.Errorf("new request: %w", err)
}
res, err := http.DefaultClient.Do(req)
client := http.DefaultClient
if cfg.HTTPClient != nil {
client = cfg.HTTPClient
}
res, err := client.Do(req)
if err != nil {
return ServerResponse{}, xerrors.Errorf("do request: %w", err)
}

View File

@ -2,74 +2,89 @@ package devtunnel_test
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"strings"
"testing"
"time"
"cdr.dev/slog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/devtunnel"
)
const (
ipByte1 = 0xfc
ipByte2 = 0xca
wgPort = 48732
)
var (
serverIP = netip.AddrFrom16([16]byte{ipByte1, ipByte2, 15: 0x1})
dnsIP = netip.AddrFrom4([4]byte{1, 1, 1, 1})
clientIP = netip.AddrFrom16([16]byte{ipByte1, ipByte2, 15: 0x2})
)
// The tunnel leaks a few goroutines that aren't impactful to production scenarios.
// func TestMain(m *testing.M) {
// goleak.VerifyTestMain(m)
// }
// TestTunnel cannot run in parallel because we hardcode the UDP port used by the wireguard server.
// nolint: paralleltest
func TestTunnel(t *testing.T) {
t.Parallel()
// It's not super useful for us to test this constantly, it'll only cause
// flakes is the tunnel becomes unavailable for some reason.
t.Skip()
// if testing.Short() {
// t.Skip()
// return
// }
ctx, cancelTun := context.WithCancel(context.Background())
defer cancelTun()
server := http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
t.Log("got request for", r.URL)
// Going to use something _slightly_ exotic so that we can't accidentally get some
// default behavior creating a false positive on the test
w.WriteHeader(http.StatusAccepted)
}),
BaseContext: func(_ net.Listener) context.Context {
return ctx
},
}
cfg, err := devtunnel.GenerateConfig()
require.NoError(t, err)
fTunServer := newFakeTunnelServer(t)
cfg := fTunServer.config()
tun, errCh, err := devtunnel.NewWithConfig(ctx, slogtest.Make(t, nil), cfg)
tun, errCh, err := devtunnel.NewWithConfig(ctx, slogtest.Make(t, nil).Leveled(slog.LevelDebug), cfg)
require.NoError(t, err)
t.Log(tun.URL)
go server.Serve(tun.Listener)
defer tun.Listener.Close()
httpClient := &http.Client{
Timeout: 10 * time.Second,
}
go func() {
err := server.Serve(tun.Listener)
assert.Equal(t, http.ErrServerClosed, err)
}()
t.Cleanup(func() { _ = server.Close() })
t.Cleanup(func() { tun.Listener.Close() })
require.Eventually(t, func() bool {
req, err := http.NewRequestWithContext(ctx, "GET", tun.URL, nil)
require.NoError(t, err)
res, err := httpClient.Do(req)
res, err := fTunServer.requestHTTP()
require.NoError(t, err)
defer res.Body.Close()
_, _ = io.Copy(io.Discard, res.Body)
return res.StatusCode == http.StatusOK
return res.StatusCode == http.StatusAccepted
}, time.Minute, time.Second)
httpClient.CloseIdleConnections()
assert.NoError(t, server.Close())
cancelTun()
@ -79,3 +94,127 @@ func TestTunnel(t *testing.T) {
t.Error("tunnel did not close after 10 seconds")
}
}
// fakeTunnelServer is a fake version of the real dev tunnel server. It fakes 2 client interactions
// that we want to test:
// 1. Responding to a POST /tun from the client
// 2. Sending an HTTP request down the wireguard connection
//
// Note that for 2, we don't implement a full proxy that accepts arbitrary requests, we just send
// a test request over the Wireguard tunnel to make sure that we can listen. The proxy behavior is
// outside of the scope of the dev tunnel client, which is what we are testing here.
type fakeTunnelServer struct {
t *testing.T
pub device.NoisePublicKey
priv device.NoisePrivateKey
tnet *netstack.Net
device *device.Device
clients int
server *httptest.Server
}
func newFakeTunnelServer(t *testing.T) *fakeTunnelServer {
priv, err := wgtypes.GeneratePrivateKey()
privBytes := [32]byte(priv)
require.NoError(t, err)
pub := priv.PublicKey()
pubBytes := [32]byte(pub)
tun, tnet, err := netstack.CreateNetTUN(
[]netip.Addr{serverIP},
[]netip.Addr{dnsIP},
1280,
)
require.NoError(t, err)
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "server "))
err = dev.IpcSet(fmt.Sprintf(`private_key=%s
listen_port=%d`,
hex.EncodeToString(privBytes[:]),
wgPort,
))
require.NoError(t, err)
t.Cleanup(func() {
dev.RemoveAllPeers()
dev.Close()
})
err = dev.Up()
require.NoError(t, err)
server := newFakeTunnelHTTPSServer(t, pubBytes)
return &fakeTunnelServer{
t: t,
pub: device.NoisePublicKey(pub),
priv: device.NoisePrivateKey(priv),
tnet: tnet,
device: dev,
server: server,
}
}
func newFakeTunnelHTTPSServer(t *testing.T, pubBytes [32]byte) *httptest.Server {
handler := http.NewServeMux()
handler.HandleFunc("/tun", func(writer http.ResponseWriter, request *http.Request) {
assert.Equal(t, "POST", request.Method)
resp := devtunnel.ServerResponse{
Hostname: fmt.Sprintf("[%s]", serverIP.String()),
ServerIP: serverIP,
ServerPublicKey: hex.EncodeToString(pubBytes[:]),
ClientIP: clientIP,
}
b, err := json.Marshal(&resp)
assert.NoError(t, err)
writer.WriteHeader(200)
_, err = writer.Write(b)
assert.NoError(t, err)
})
server := httptest.NewTLSServer(handler)
t.Cleanup(func() {
server.Close()
})
return server
}
func (f *fakeTunnelServer) config() devtunnel.Config {
priv, err := wgtypes.GeneratePrivateKey()
require.NoError(f.t, err)
pub := priv.PublicKey()
f.clients++
assert.Equal(f.t, 1, f.clients) // only allow one client as we hardcode the address
err = f.device.IpcSet(fmt.Sprintf(`public_key=%x
allowed_ip=%s/128`,
pub[:],
clientIP.String(),
))
require.NoError(f.t, err)
return devtunnel.Config{
Version: 1,
PrivateKey: device.NoisePrivateKey(priv),
PublicKey: device.NoisePublicKey(pub),
Tunnel: devtunnel.Node{
HostnameHTTPS: strings.TrimPrefix(f.server.URL, "https://"),
HostnameWireguard: "localhost",
WireguardPort: wgPort,
},
HTTPClient: f.server.Client(),
}
}
func (f *fakeTunnelServer) requestHTTP() (*http.Response, error) {
transport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
f.t.Log("Dial", network, addr)
nc, err := f.tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(clientIP, 8090))
assert.NoError(f.t, err)
return nc, err
},
}
client := &http.Client{
Transport: transport,
Timeout: 10 * time.Second,
}
return client.Get(fmt.Sprintf("http://[%s]:8090", clientIP))
}