coder/coderd/wsconncache/wsconncache_test.go

178 lines
4.9 KiB
Go

package wsconncache_test
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestCache(t *testing.T) {
t.Parallel()
t.Run("Same", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, 0)
defer func() {
_ = cache.Close()
}()
conn1, _, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
require.NoError(t, err)
conn2, _, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
require.NoError(t, err)
require.True(t, conn1 == conn2)
})
t.Run("Expire", func(t *testing.T) {
t.Parallel()
called := atomic.NewInt32(0)
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
called.Add(1)
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
defer func() {
_ = cache.Close()
}()
conn, release, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
require.NoError(t, err)
release()
<-conn.Closed()
conn, release, err = cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
require.NoError(t, err)
release()
<-conn.Closed()
require.Equal(t, int32(2), called.Load())
})
t.Run("NoExpireWhenLocked", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
defer func() {
_ = cache.Close()
}()
conn, release, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
require.NoError(t, err)
time.Sleep(time.Millisecond)
release()
<-conn.Closed()
})
t.Run("HTTPTransport", func(t *testing.T) {
t.Parallel()
random, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer func() {
_ = random.Close()
}()
tcpAddr, valid := random.Addr().(*net.TCPAddr)
require.True(t, valid)
server := &http.Server{
ReadHeaderTimeout: time.Minute,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
}
defer func() {
_ = server.Close()
}()
go server.Serve(random)
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
defer func() {
_ = cache.Close()
}()
var wg sync.WaitGroup
// Perform many requests in parallel to simulate
// simultaneous HTTP requests.
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
proxy := httputil.NewSingleHostReverseProxy(&url.URL{
Scheme: "http",
Host: fmt.Sprintf("127.0.0.1:%d", tcpAddr.Port),
Path: "/",
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
conn, release, err := cache.Acquire(req, uuid.Nil)
if !assert.NoError(t, err) {
return
}
defer release()
proxy.Transport = conn.HTTPTransport()
res := httptest.NewRecorder()
proxy.ServeHTTP(res, req)
resp := res.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
}()
}
wg.Wait()
})
}
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) *agent.Conn {
client, server := provisionersdk.TransportPipe()
closer := agent.New(func(ctx context.Context, logger slog.Logger) (agent.Metadata, *peerbroker.Listener, error) {
listener, err := peerbroker.Listen(server, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) {
return nil, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
}, nil
})
return metadata, listener, err
}, &agent.Options{
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
ReconnectingPTYTimeout: ptyTimeout,
})
t.Cleanup(func() {
_ = client.Close()
_ = server.Close()
_ = closer.Close()
})
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
stream, err := api.NegotiateConnection(context.Background())
assert.NoError(t, err)
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
})
return &agent.Conn{
Negotiator: api,
Conn: conn,
}
}