mirror of https://github.com/coder/coder.git
337 lines
9.1 KiB
Go
337 lines
9.1 KiB
Go
package wsconncache_test
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/http/httputil"
|
|
"net/netip"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/goleak"
|
|
"golang.org/x/xerrors"
|
|
"storj.io/drpc"
|
|
"storj.io/drpc/drpcmux"
|
|
"storj.io/drpc/drpcserver"
|
|
"tailscale.com/tailcfg"
|
|
|
|
"cdr.dev/slog"
|
|
"cdr.dev/slog/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/agent"
|
|
"github.com/coder/coder/v2/coderd/wsconncache"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/codersdk/agentsdk"
|
|
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
|
|
"github.com/coder/coder/v2/tailnet"
|
|
"github.com/coder/coder/v2/tailnet/proto"
|
|
"github.com/coder/coder/v2/tailnet/tailnettest"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
// TestMain wraps the package's test run with goleak so that goroutines left
// running after the tests finish cause the run to fail.
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}
|
|
|
|
// TestCache exercises wsconncache: reuse of a cached connection for the same
// ID, expiry after release, no expiry while a connection is still held, and
// concurrent use of a cached connection's HTTP transport.
func TestCache(t *testing.T) {
	t.Parallel()
	t.Run("Same", func(t *testing.T) {
		t.Parallel()
		cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
			return setupAgent(t, agentsdk.Manifest{}, 0)
		}, 0)
		defer func() {
			_ = cache.Close()
		}()
		// Two acquisitions of the same ID must yield the very same cached
		// connection object. The release funcs are discarded here.
		conn1, _, err := cache.Acquire(uuid.Nil)
		require.NoError(t, err)
		conn2, _, err := cache.Acquire(uuid.Nil)
		require.NoError(t, err)
		require.Same(t, conn1, conn2)
	})
	t.Run("Expire", func(t *testing.T) {
		t.Parallel()
		// Number of times the dialer ran. Every access goes through the
		// atomic API rather than mixing atomic writes with a plain read.
		var called atomic.Int32
		cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
			called.Add(1)
			return setupAgent(t, agentsdk.Manifest{}, 0)
		}, time.Microsecond)
		defer func() {
			_ = cache.Close()
		}()
		// After release the connection closes (1µs inactivity timeout), so the
		// second Acquire has to dial again.
		conn, release, err := cache.Acquire(uuid.Nil)
		require.NoError(t, err)
		release()
		<-conn.Closed()
		conn, release, err = cache.Acquire(uuid.Nil)
		require.NoError(t, err)
		release()
		<-conn.Closed()
		require.Equal(t, int32(2), called.Load())
	})
	t.Run("NoExpireWhenLocked", func(t *testing.T) {
		t.Parallel()
		cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
			return setupAgent(t, agentsdk.Manifest{}, 0)
		}, time.Microsecond)
		defer func() {
			_ = cache.Close()
		}()
		// Hold the connection well past the inactivity timeout before
		// releasing it; it is then expected to close.
		conn, release, err := cache.Acquire(uuid.Nil)
		require.NoError(t, err)
		time.Sleep(time.Millisecond)
		release()
		<-conn.Closed()
	})
	t.Run("HTTPTransport", func(t *testing.T) {
		t.Parallel()
		random, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err)
		defer func() {
			_ = random.Close()
		}()
		tcpAddr, valid := random.Addr().(*net.TCPAddr)
		require.True(t, valid)

		server := &http.Server{
			ReadHeaderTimeout: time.Minute,
			Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
			}),
		}
		defer func() {
			_ = server.Close()
		}()
		// Serve returns (with http.ErrServerClosed) once the deferred
		// server.Close runs, so this goroutine ends with the subtest.
		go func() {
			_ = server.Serve(random)
		}()

		cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
			return setupAgent(t, agentsdk.Manifest{}, 0)
		}, time.Microsecond)
		defer func() {
			_ = cache.Close()
		}()

		var wg sync.WaitGroup
		// Perform many requests in parallel to simulate
		// simultaneous HTTP requests.
		for i := 0; i < 50; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				proxy := httputil.NewSingleHostReverseProxy(&url.URL{
					Scheme: "http",
					Host:   fmt.Sprintf("127.0.0.1:%d", tcpAddr.Port),
					Path:   "/",
				})
				ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
				defer cancel()
				req := httptest.NewRequest(http.MethodGet, "/", nil)
				req = req.WithContext(ctx)
				conn, release, err := cache.Acquire(uuid.Nil)
				if !assert.NoError(t, err) {
					return
				}
				defer release()
				if !conn.AwaitReachable(ctx) {
					t.Error("agent not reachable")
					return
				}

				transport := conn.HTTPTransport()
				defer transport.CloseIdleConnections()
				proxy.Transport = transport
				res := httptest.NewRecorder()
				proxy.ServeHTTP(res, req)
				resp := res.Result()
				defer resp.Body.Close()
				assert.Equal(t, http.StatusOK, resp.StatusCode)
			}()
		}
		wg.Wait()
	})
}
|
|
|
|
// setupAgent starts an in-process agent backed by an in-memory tailnet
// coordinator, dials it over a fresh tailnet connection, and returns the
// resulting WorkspaceAgentConn once the agent is reachable. The caller's
// manifest is copied; its DERPMap and AgentID fields are overwritten.
// ptyTimeout becomes the agent's ReconnectingPTYTimeout. Everything created
// here is torn down through t.Cleanup.
//
// This function is used as the wsconncache dialer, so it must not end the
// test (require / t.Fatal). Ending the test closes the cache, the cache waits
// for the dialer, and the test deadlocks. Failures are therefore reported
// with assert / t.Error plus an error return.
func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) (*codersdk.WorkspaceAgentConn, error) {
	t.Helper()
	testLogger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	derpMap, _ := tailnettest.RunDERPAndSTUN(t)
	manifest.DERPMap = derpMap

	coord := tailnet.NewCoordinator(testLogger)
	t.Cleanup(func() { _ = coord.Close() })

	agentID := uuid.New()
	manifest.AgentID = agentID
	agentCloser := agent.New(agent.Options{
		Client: &client{
			t:           t,
			agentID:     agentID,
			manifest:    manifest,
			coordinator: coord,
		},
		Logger:                 testLogger.Named("agent"),
		ReconnectingPTYTimeout: ptyTimeout,
		Addresses:              []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
	})
	t.Cleanup(func() { _ = agentCloser.Close() })

	tailnetConn, err := tailnet.NewConn(&tailnet.Options{
		Addresses:           []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		DERPMap:             manifest.DERPMap,
		DERPForceWebSockets: manifest.DERPForceWebSockets,
		Logger:              slogtest.Make(t, nil).Named("tailnet").Leveled(slog.LevelDebug),
	})
	if !assert.NoError(t, err) {
		return nil, err
	}
	t.Cleanup(func() { _ = tailnetConn.Close() })

	peerID := uuid.New()
	coordCtx, stopCoord := context.WithCancel(context.Background())
	t.Cleanup(stopCoord)
	coordination := tailnet.NewInMemoryCoordination(
		coordCtx, testLogger,
		peerID, agentID,
		coord, tailnetConn,
	)
	t.Cleanup(func() { _ = coordination.Close() })

	agentConn := codersdk.NewWorkspaceAgentConn(tailnetConn, codersdk.WorkspaceAgentConnOptions{
		AgentID: agentID,
		AgentIP: codersdk.WorkspaceAgentIP,
	})
	t.Cleanup(func() { _ = agentConn.Close() })

	waitCtx, cancelWait := context.WithTimeout(context.Background(), testutil.WaitMedium)
	defer cancelWait()
	if agentConn.AwaitReachable(waitCtx) {
		return agentConn, nil
	}
	t.Error("agent not reachable")
	return nil, xerrors.New("agent not reachable")
}
|
|
|
|
type client struct {
|
|
t *testing.T
|
|
agentID uuid.UUID
|
|
manifest agentsdk.Manifest
|
|
coordinator tailnet.Coordinator
|
|
}
|
|
|
|
// Manifest hands back the manifest this client was built with. It never fails.
func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
	m := c.manifest
	return m, nil
}
|
|
|
|
// closer adapts a plain func() error into an io.Closer.
type closer struct {
	// closeFunc is what Close invokes; its result is returned unchanged.
	closeFunc func() error
}
|
|
|
|
func (c *closer) Close() error {
|
|
return c.closeFunc()
|
|
}
|
|
|
|
// DERPMapUpdates returns a channel on which no update is ever sent (it is
// never written to or closed), together with an io.Closer. The DERP map
// configured by setupAgent never changes during these tests.
//
// The returned Close has nothing to release. It is a plain no-op, which makes
// it safe to call more than once. The previous version closed an otherwise
// unused channel, so a second Close panicked.
func (*client) DERPMapUpdates(_ context.Context) (<-chan agentsdk.DERPMapUpdate, io.Closer, error) {
	return make(<-chan agentsdk.DERPMapUpdate), &closer{
		closeFunc: func() error { return nil },
	}, nil
}
|
|
|
|
func (c *client) Listen(_ context.Context) (drpc.Conn, error) {
|
|
logger := slogtest.Make(c.t, nil).Leveled(slog.LevelDebug).Named("drpc")
|
|
conn, lis := drpcsdk.MemTransportPipe()
|
|
closed := make(chan struct{})
|
|
c.t.Cleanup(func() {
|
|
_ = conn.Close()
|
|
_ = lis.Close()
|
|
<-closed
|
|
})
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&c.coordinator)
|
|
mux := drpcmux.New()
|
|
drpcService := &tailnet.DRPCService{
|
|
CoordPtr: &coordPtr,
|
|
Logger: logger,
|
|
// TODO: handle DERPMap too!
|
|
DerpMapUpdateFrequency: time.Hour,
|
|
DerpMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
|
|
}
|
|
err := proto.DRPCRegisterTailnet(mux, drpcService)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("register DRPC service: %w", err)
|
|
}
|
|
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
|
|
Log: func(err error) {
|
|
if xerrors.Is(err, io.EOF) ||
|
|
xerrors.Is(err, context.Canceled) ||
|
|
xerrors.Is(err, context.DeadlineExceeded) {
|
|
return
|
|
}
|
|
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
|
|
},
|
|
})
|
|
serveCtx, cancel := context.WithCancel(context.Background())
|
|
c.t.Cleanup(cancel)
|
|
auth := tailnet.AgentTunnelAuth{}
|
|
streamID := tailnet.StreamID{
|
|
Name: "wsconncache_test-agent",
|
|
ID: c.agentID,
|
|
Auth: auth,
|
|
}
|
|
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
|
|
go func() {
|
|
server.Serve(serveCtx, lis)
|
|
close(closed)
|
|
}()
|
|
return conn, nil
|
|
}
|
|
|
|
func (*client) ReportStats(_ context.Context, _ slog.Logger, _ <-chan *agentsdk.Stats, _ func(time.Duration)) (io.Closer, error) {
|
|
return io.NopCloser(strings.NewReader("")), nil
|
|
}
|
|
|
|
// PostLifecycle discards the lifecycle report and always succeeds.
func (*client) PostLifecycle(_ context.Context, _ agentsdk.PostLifecycleRequest) error {
	return nil
}
|
|
|
|
// PostAppHealth discards the app-health report and always succeeds.
func (*client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error {
	return nil
}
|
|
|
|
// PostMetadata discards the metadata report and always succeeds.
func (*client) PostMetadata(_ context.Context, _ agentsdk.PostMetadataRequest) error {
	return nil
}
|
|
|
|
// PostStartup discards the startup report and always succeeds.
func (*client) PostStartup(_ context.Context, _ agentsdk.PostStartupRequest) error {
	return nil
}
|
|
|
|
// PatchLogs drops the submitted log lines and always succeeds.
func (*client) PatchLogs(_ context.Context, _ agentsdk.PatchLogs) error {
	return nil
}
|
|
|
|
func (*client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) {
|
|
return codersdk.ServiceBannerConfig{}, nil
|
|
}
|