mirror of https://github.com/coder/coder.git
434 lines
11 KiB
Go
434 lines
11 KiB
Go
package peer_test
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pion/logging"
|
|
"github.com/pion/transport/vnet"
|
|
"github.com/pion/webrtc/v3"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/goleak"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog"
|
|
"cdr.dev/slog/sloggers/slogtest"
|
|
"github.com/coder/coder/peer"
|
|
"github.com/coder/coder/testutil"
|
|
)
|
|
|
|
var (
|
|
disconnectedTimeout = func() time.Duration {
|
|
// Connection state is unfortunately time-based. When resources are
|
|
// contended, a connection can take greater than this timeout to
|
|
// handshake, which results in a test flake.
|
|
//
|
|
// During local testing resources are rarely contended. Reducing this
|
|
// timeout leads to faster local development.
|
|
//
|
|
// In CI resources are frequently contended, so increasing this value
|
|
// results in less flakes.
|
|
if os.Getenv("CI") == "true" {
|
|
return time.Second
|
|
}
|
|
return 100 * time.Millisecond
|
|
}()
|
|
failedTimeout = disconnectedTimeout * 3
|
|
keepAliveInterval = time.Millisecond * 2
|
|
|
|
// There's a global race in the vnet library allocation code.
|
|
// This mutex locks around the creation of the vnet.
|
|
vnetMutex = sync.Mutex{}
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
// pion/ice doesn't properly close immediately. The solution for this isn't yet known. See:
|
|
// https://github.com/pion/ice/pull/413
|
|
goleak.VerifyTestMain(m,
|
|
goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).startOnConnectionStateChangeRoutine.func1"),
|
|
goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).startOnConnectionStateChangeRoutine.func2"),
|
|
goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).taskLoop"),
|
|
goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"),
|
|
)
|
|
}
|
|
|
|
func TestConn(t *testing.T) {
|
|
t.Parallel()
|
|
t.Run("Ping", func(t *testing.T) {
|
|
t.Parallel()
|
|
client, server, _ := createPair(t)
|
|
exchange(t, client, server)
|
|
_, err := client.Ping()
|
|
require.NoError(t, err)
|
|
_, err = server.Ping()
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
t.Run("PingNetworkOffline", func(t *testing.T) {
|
|
t.Parallel()
|
|
client, server, wan := createPair(t)
|
|
exchange(t, client, server)
|
|
_, err := server.Ping()
|
|
require.NoError(t, err)
|
|
err = wan.Stop()
|
|
require.NoError(t, err)
|
|
_, err = server.Ping()
|
|
require.ErrorIs(t, err, peer.ErrFailed)
|
|
})
|
|
|
|
t.Run("PingReconnect", func(t *testing.T) {
|
|
t.Parallel()
|
|
client, server, wan := createPair(t)
|
|
exchange(t, client, server)
|
|
_, err := server.Ping()
|
|
require.NoError(t, err)
|
|
// Create a channel that closes on disconnect.
|
|
channel, err := server.CreateChannel(context.Background(), "wow", nil)
|
|
assert.NoError(t, err)
|
|
defer channel.Close()
|
|
|
|
err = wan.Stop()
|
|
require.NoError(t, err)
|
|
// Once the connection is marked as disconnected, this
|
|
// channel will be closed.
|
|
_, err = channel.Read(make([]byte, 4))
|
|
assert.ErrorIs(t, err, peer.ErrClosed)
|
|
err = wan.Start()
|
|
require.NoError(t, err)
|
|
_, err = server.Ping()
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
t.Run("Accept", func(t *testing.T) {
|
|
t.Parallel()
|
|
client, server, _ := createPair(t)
|
|
exchange(t, client, server)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
|
|
require.NoError(t, err)
|
|
defer cch.Close()
|
|
|
|
sch, err := server.Accept(ctx)
|
|
require.NoError(t, err)
|
|
defer sch.Close()
|
|
|
|
_ = cch.Close()
|
|
_, err = sch.Read(make([]byte, 4))
|
|
require.ErrorIs(t, err, peer.ErrClosed)
|
|
})
|
|
|
|
t.Run("AcceptNetworkOffline", func(t *testing.T) {
|
|
t.Parallel()
|
|
client, server, wan := createPair(t)
|
|
exchange(t, client, server)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
|
|
require.NoError(t, err)
|
|
defer cch.Close()
|
|
sch, err := server.Accept(ctx)
|
|
require.NoError(t, err)
|
|
defer sch.Close()
|
|
|
|
err = wan.Stop()
|
|
require.NoError(t, err)
|
|
_ = cch.Close()
|
|
_, err = sch.Read(make([]byte, 4))
|
|
require.ErrorIs(t, err, peer.ErrClosed)
|
|
})
|
|
|
|
t.Run("Buffering", func(t *testing.T) {
|
|
t.Parallel()
|
|
client, server, _ := createPair(t)
|
|
exchange(t, client, server)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
|
|
require.NoError(t, err)
|
|
defer cch.Close()
|
|
|
|
readErr := make(chan error, 1)
|
|
go func() {
|
|
sch, err := server.Accept(ctx)
|
|
if err != nil {
|
|
readErr <- err
|
|
_ = cch.Close()
|
|
return
|
|
}
|
|
defer sch.Close()
|
|
|
|
bytes := make([]byte, 4096)
|
|
for {
|
|
_, err = sch.Read(bytes)
|
|
if err != nil {
|
|
readErr <- err
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
bytes := make([]byte, 4096)
|
|
for i := 0; i < 1024; i++ {
|
|
_, err = cch.Write(bytes)
|
|
require.NoError(t, err, "write i=%d", i)
|
|
}
|
|
_ = cch.Close()
|
|
|
|
select {
|
|
case err = <-readErr:
|
|
require.ErrorIs(t, err, peer.ErrClosed, "read error")
|
|
case <-ctx.Done():
|
|
require.Fail(t, "timeout waiting for read error")
|
|
}
|
|
})
|
|
|
|
t.Run("NetConn", func(t *testing.T) {
|
|
t.Parallel()
|
|
client, server, _ := createPair(t)
|
|
exchange(t, client, server)
|
|
srv, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
defer srv.Close()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
go func() {
|
|
sch, err := server.Accept(ctx)
|
|
if err != nil {
|
|
assert.NoError(t, err)
|
|
return
|
|
}
|
|
defer sch.Close()
|
|
|
|
nc2 := sch.NetConn()
|
|
defer nc2.Close()
|
|
|
|
nc1, err := net.Dial("tcp", srv.Addr().String())
|
|
if err != nil {
|
|
assert.NoError(t, err)
|
|
return
|
|
}
|
|
defer nc1.Close()
|
|
|
|
go func() {
|
|
defer nc1.Close()
|
|
defer nc2.Close()
|
|
_, _ = io.Copy(nc1, nc2)
|
|
}()
|
|
_, _ = io.Copy(nc2, nc1)
|
|
}()
|
|
go func() {
|
|
server := http.Server{
|
|
Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
rw.WriteHeader(200)
|
|
}),
|
|
}
|
|
defer server.Close()
|
|
_ = server.Serve(srv)
|
|
}()
|
|
|
|
//nolint:forcetypeassert
|
|
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
var cch *peer.Channel
|
|
defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
cch, err = client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return cch.NetConn(), nil
|
|
}
|
|
c := http.Client{
|
|
Transport: defaultTransport,
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost/", nil)
|
|
require.NoError(t, err)
|
|
resp, err := c.Do(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
require.Equal(t, resp.StatusCode, 200)
|
|
// Triggers any connections to close.
|
|
// This test below ensures the DataChannel actually closes.
|
|
defaultTransport.CloseIdleConnections()
|
|
err = cch.Close()
|
|
require.ErrorIs(t, err, peer.ErrClosed)
|
|
})
|
|
|
|
t.Run("CloseBeforeNegotiate", func(t *testing.T) {
|
|
t.Parallel()
|
|
client, server, _ := createPair(t)
|
|
exchange(t, client, server)
|
|
err := client.Close()
|
|
require.NoError(t, err)
|
|
err = server.Close()
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
t.Run("CloseWithError", func(t *testing.T) {
|
|
t.Parallel()
|
|
conn, err := peer.Client([]webrtc.ICEServer{}, nil)
|
|
require.NoError(t, err)
|
|
expectedErr := xerrors.New("wow")
|
|
_ = conn.CloseWithError(expectedErr)
|
|
_, err = conn.CreateChannel(context.Background(), "", nil)
|
|
require.ErrorIs(t, err, expectedErr)
|
|
})
|
|
|
|
t.Run("PingConcurrent", func(t *testing.T) {
|
|
t.Parallel()
|
|
client, server, _ := createPair(t)
|
|
exchange(t, client, server)
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
_, err := client.Ping()
|
|
assert.NoError(t, err)
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
_, err := server.Ping()
|
|
assert.NoError(t, err)
|
|
}()
|
|
wg.Wait()
|
|
})
|
|
|
|
t.Run("CandidateBeforeSessionDescription", func(t *testing.T) {
|
|
t.Parallel()
|
|
client, server, _ := createPair(t)
|
|
server.SetRemoteSessionDescription(<-client.LocalSessionDescription())
|
|
sdp := <-server.LocalSessionDescription()
|
|
client.AddRemoteCandidate(<-server.LocalCandidate())
|
|
client.SetRemoteSessionDescription(sdp)
|
|
server.AddRemoteCandidate(<-client.LocalCandidate())
|
|
_, err := client.Ping()
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
t.Run("ShortBuffer", func(t *testing.T) {
|
|
t.Parallel()
|
|
client, server, _ := createPair(t)
|
|
exchange(t, client, server)
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
go func() {
|
|
channel, err := client.CreateChannel(ctx, "test", nil)
|
|
if err != nil {
|
|
assert.NoError(t, err)
|
|
return
|
|
}
|
|
defer channel.Close()
|
|
_, err = channel.Write([]byte{1, 2})
|
|
assert.NoError(t, err)
|
|
}()
|
|
channel, err := server.Accept(ctx)
|
|
require.NoError(t, err)
|
|
defer channel.Close()
|
|
data := make([]byte, 1)
|
|
_, err = channel.Read(data)
|
|
require.NoError(t, err)
|
|
require.Equal(t, uint8(0x1), data[0])
|
|
_, err = channel.Read(data)
|
|
require.NoError(t, err)
|
|
require.Equal(t, uint8(0x2), data[0])
|
|
})
|
|
}
|
|
|
|
func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) {
|
|
loggingFactory := logging.NewDefaultLoggerFactory()
|
|
loggingFactory.DefaultLogLevel = logging.LogLevelDisabled
|
|
vnetMutex.Lock()
|
|
defer vnetMutex.Unlock()
|
|
wan, err := vnet.NewRouter(&vnet.RouterConfig{
|
|
CIDR: "1.2.3.0/24",
|
|
LoggerFactory: loggingFactory,
|
|
})
|
|
require.NoError(t, err)
|
|
c1Net := vnet.NewNet(&vnet.NetConfig{
|
|
StaticIPs: []string{"1.2.3.4"},
|
|
})
|
|
err = wan.AddNet(c1Net)
|
|
require.NoError(t, err)
|
|
c2Net := vnet.NewNet(&vnet.NetConfig{
|
|
StaticIPs: []string{"1.2.3.5"},
|
|
})
|
|
err = wan.AddNet(c2Net)
|
|
require.NoError(t, err)
|
|
|
|
c1SettingEngine := webrtc.SettingEngine{}
|
|
c1SettingEngine.SetVNet(c1Net)
|
|
c1SettingEngine.SetPrflxAcceptanceMinWait(0)
|
|
c1SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
|
|
channel1, err := peer.Client([]webrtc.ICEServer{{}}, &peer.ConnOptions{
|
|
SettingEngine: c1SettingEngine,
|
|
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
|
})
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
channel1.Close()
|
|
})
|
|
c2SettingEngine := webrtc.SettingEngine{}
|
|
c2SettingEngine.SetVNet(c2Net)
|
|
c2SettingEngine.SetPrflxAcceptanceMinWait(0)
|
|
c2SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
|
|
channel2, err := peer.Server([]webrtc.ICEServer{{}}, &peer.ConnOptions{
|
|
SettingEngine: c2SettingEngine,
|
|
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
|
|
})
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
channel2.Close()
|
|
})
|
|
|
|
err = wan.Start()
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
_ = wan.Stop()
|
|
})
|
|
|
|
return channel1, channel2, wan
|
|
}
|
|
|
|
func exchange(t *testing.T, client, server *peer.Conn) {
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
t.Cleanup(func() {
|
|
_ = client.Close()
|
|
_ = server.Close()
|
|
|
|
wg.Wait()
|
|
})
|
|
go func() {
|
|
defer wg.Done()
|
|
for {
|
|
select {
|
|
case c := <-server.LocalCandidate():
|
|
client.AddRemoteCandidate(c)
|
|
case c := <-server.LocalSessionDescription():
|
|
client.SetRemoteSessionDescription(c)
|
|
case <-server.Closed():
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
for {
|
|
select {
|
|
case c := <-client.LocalCandidate():
|
|
server.AddRemoteCandidate(c)
|
|
case c := <-client.LocalSessionDescription():
|
|
server.SetRemoteSessionDescription(c)
|
|
case <-client.Closed():
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|