mirror of https://github.com/coder/coder.git
chore: Remove WebRTC networking (#3881)
* chore: Remove WebRTC networking * Fix race condition * Fix WebSocket not closing
This commit is contained in:
parent
1186e643ec
commit
714c366d16
11
Makefile
11
Makefile
|
@ -386,7 +386,6 @@ lint/shellcheck: $(shell shfmt -f .)
|
|||
gen: \
|
||||
coderd/database/dump.sql \
|
||||
coderd/database/querier.go \
|
||||
peerbroker/proto/peerbroker.pb.go \
|
||||
provisionersdk/proto/provisioner.pb.go \
|
||||
provisionerd/proto/provisionerd.pb.go \
|
||||
site/src/api/typesGenerated.ts
|
||||
|
@ -395,7 +394,7 @@ gen: \
|
|||
# Mark all generated files as fresh so make thinks they're up-to-date. This is
|
||||
# used during releases so we don't run generation scripts.
|
||||
gen/mark-fresh:
|
||||
files="coderd/database/dump.sql coderd/database/querier.go peerbroker/proto/peerbroker.pb.go provisionersdk/proto/provisioner.pb.go provisionerd/proto/provisionerd.pb.go site/src/api/typesGenerated.ts"
|
||||
files="coderd/database/dump.sql coderd/database/querier.go provisionersdk/proto/provisioner.pb.go provisionerd/proto/provisionerd.pb.go site/src/api/typesGenerated.ts"
|
||||
for file in $$files; do
|
||||
echo "$$file"
|
||||
if [ ! -f "$$file" ]; then
|
||||
|
@ -417,14 +416,6 @@ coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/dat
|
|||
coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $(wildcard coderd/database/queries/*.sql) coderd/database/gen/enum/main.go
|
||||
./coderd/database/generate.sh
|
||||
|
||||
peerbroker/proto/peerbroker.pb.go: peerbroker/proto/peerbroker.proto
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./peerbroker/proto/peerbroker.proto
|
||||
|
||||
provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto
|
||||
protoc \
|
||||
--go_out=. \
|
||||
|
|
188
agent/agent.go
188
agent/agent.go
|
@ -11,7 +11,6 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
|
@ -34,8 +33,6 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/agent/usershell"
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/peerbroker"
|
||||
"github.com/coder/coder/pty"
|
||||
"github.com/coder/coder/tailnet"
|
||||
"github.com/coder/retry"
|
||||
|
@ -64,7 +61,6 @@ var (
|
|||
|
||||
type Options struct {
|
||||
CoordinatorDialer CoordinatorDialer
|
||||
WebRTCDialer WebRTCDialer
|
||||
FetchMetadata FetchMetadata
|
||||
|
||||
StatsReporter StatsReporter
|
||||
|
@ -80,8 +76,6 @@ type Metadata struct {
|
|||
Directory string `json:"directory"`
|
||||
}
|
||||
|
||||
type WebRTCDialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error)
|
||||
|
||||
// CoordinatorDialer is a function that constructs a new broker.
|
||||
// A dialer must be passed in to allow for reconnects.
|
||||
type CoordinatorDialer func(ctx context.Context) (net.Conn, error)
|
||||
|
@ -95,7 +89,6 @@ func New(options Options) io.Closer {
|
|||
}
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
server := &agent{
|
||||
webrtcDialer: options.WebRTCDialer,
|
||||
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
|
||||
logger: options.Logger,
|
||||
closeCancel: cancelFunc,
|
||||
|
@ -111,8 +104,7 @@ func New(options Options) io.Closer {
|
|||
}
|
||||
|
||||
type agent struct {
|
||||
webrtcDialer WebRTCDialer
|
||||
logger slog.Logger
|
||||
logger slog.Logger
|
||||
|
||||
reconnectingPTYs sync.Map
|
||||
reconnectingPTYTimeout time.Duration
|
||||
|
@ -173,9 +165,6 @@ func (a *agent) run(ctx context.Context) {
|
|||
}
|
||||
}()
|
||||
|
||||
if a.webrtcDialer != nil {
|
||||
go a.runWebRTCNetworking(ctx)
|
||||
}
|
||||
if metadata.DERPMap != nil {
|
||||
go a.runTailnet(ctx, metadata.DERPMap)
|
||||
}
|
||||
|
@ -326,49 +315,6 @@ func (a *agent) runCoordinator(ctx context.Context) {
|
|||
}
|
||||
}
|
||||
|
||||
func (a *agent) runWebRTCNetworking(ctx context.Context) {
|
||||
var peerListener *peerbroker.Listener
|
||||
var err error
|
||||
// An exponential back-off occurs when the connection is failing to dial.
|
||||
// This is to prevent server spam in case of a coderd outage.
|
||||
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
|
||||
peerListener, err = a.webrtcDialer(ctx, a.logger)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
if a.isClosed() {
|
||||
return
|
||||
}
|
||||
a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
a.logger.Info(context.Background(), "connected to webrtc broker")
|
||||
break
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := peerListener.Accept()
|
||||
if err != nil {
|
||||
if a.isClosed() {
|
||||
return
|
||||
}
|
||||
a.logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err))
|
||||
a.runWebRTCNetworking(ctx)
|
||||
return
|
||||
}
|
||||
a.closeMutex.Lock()
|
||||
a.connCloseWait.Add(1)
|
||||
a.closeMutex.Unlock()
|
||||
go a.handlePeerConn(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *agent) runStartupScript(ctx context.Context, script string) error {
|
||||
if script == "" {
|
||||
return nil
|
||||
|
@ -401,74 +347,6 @@ func (a *agent) runStartupScript(ctx context.Context, script string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *agent) handlePeerConn(ctx context.Context, peerConn *peer.Conn) {
|
||||
go func() {
|
||||
select {
|
||||
case <-a.closed:
|
||||
case <-peerConn.Closed():
|
||||
}
|
||||
_ = peerConn.Close()
|
||||
a.connCloseWait.Done()
|
||||
}()
|
||||
for {
|
||||
channel, err := peerConn.Accept(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, peer.ErrClosed) || a.isClosed() {
|
||||
return
|
||||
}
|
||||
a.logger.Debug(ctx, "accept channel from peer connection", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
conn := channel.NetConn()
|
||||
|
||||
switch channel.Protocol() {
|
||||
case ProtocolSSH:
|
||||
go a.sshServer.HandleConn(a.stats.wrapConn(conn))
|
||||
case ProtocolReconnectingPTY:
|
||||
rawID := channel.Label()
|
||||
// The ID format is referenced in conn.go.
|
||||
// <uuid>:<height>:<width>
|
||||
idParts := strings.SplitN(rawID, ":", 4)
|
||||
if len(idParts) != 4 {
|
||||
a.logger.Warn(ctx, "client sent invalid id format", slog.F("raw-id", rawID))
|
||||
continue
|
||||
}
|
||||
id := idParts[0]
|
||||
// Enforce a consistent format for IDs.
|
||||
_, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "client sent reconnection token that isn't a uuid", slog.F("id", id), slog.Error(err))
|
||||
continue
|
||||
}
|
||||
// Parse the initial terminal dimensions.
|
||||
height, err := strconv.Atoi(idParts[1])
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "client sent invalid height", slog.F("id", id), slog.F("height", idParts[1]))
|
||||
continue
|
||||
}
|
||||
width, err := strconv.Atoi(idParts[2])
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "client sent invalid width", slog.F("id", id), slog.F("width", idParts[2]))
|
||||
continue
|
||||
}
|
||||
go a.handleReconnectingPTY(ctx, reconnectingPTYInit{
|
||||
ID: id,
|
||||
Height: uint16(height),
|
||||
Width: uint16(width),
|
||||
Command: idParts[3],
|
||||
}, a.stats.wrapConn(conn))
|
||||
case ProtocolDial:
|
||||
go a.handleDial(ctx, channel.Label(), a.stats.wrapConn(conn))
|
||||
default:
|
||||
a.logger.Warn(ctx, "unhandled protocol from channel",
|
||||
slog.F("protocol", channel.Protocol()),
|
||||
slog.F("label", channel.Label()),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *agent) init(ctx context.Context) {
|
||||
a.logger.Info(ctx, "generating host key")
|
||||
// Clients' should ignore the host key when connecting.
|
||||
|
@ -915,70 +793,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg reconnectingPTYIn
|
|||
}
|
||||
}
|
||||
|
||||
// dialResponse is written to datachannels with protocol "dial" by the agent as
|
||||
// the first packet to signify whether the dial succeeded or failed.
|
||||
type dialResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
writeError := func(responseError error) error {
|
||||
msg := ""
|
||||
if responseError != nil {
|
||||
msg = responseError.Error()
|
||||
if !xerrors.Is(responseError, io.EOF) {
|
||||
a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError))
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(dialResponse{
|
||||
Error: msg,
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "write dial response", slog.F("label", label), slog.Error(err))
|
||||
return xerrors.Errorf("marshal agent webrtc dial response: %w", err)
|
||||
}
|
||||
|
||||
_, err = conn.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
u, err := url.Parse(label)
|
||||
if err != nil {
|
||||
_ = writeError(xerrors.Errorf("parse URL %q: %w", label, err))
|
||||
return
|
||||
}
|
||||
|
||||
network := u.Scheme
|
||||
addr := u.Host + u.Path
|
||||
if strings.HasPrefix(network, "unix") {
|
||||
if runtime.GOOS == "windows" {
|
||||
_ = writeError(xerrors.New("Unix forwarding is not supported from Windows workspaces"))
|
||||
return
|
||||
}
|
||||
addr, err = ExpandRelativeHomePath(addr)
|
||||
if err != nil {
|
||||
_ = writeError(xerrors.Errorf("expand path %q: %w", addr, err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
d := net.Dialer{Timeout: 3 * time.Second}
|
||||
nconn, err := d.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
_ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err))
|
||||
return
|
||||
}
|
||||
|
||||
err = writeError(nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
Bicopy(ctx, conn, nconn)
|
||||
}
|
||||
|
||||
// isClosed returns whether the API is closed or not.
|
||||
func (a *agent) isClosed() bool {
|
||||
select {
|
||||
|
|
|
@ -20,12 +20,10 @@ import (
|
|||
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/net/speedtest"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
scp "github.com/bramvdbogaerde/go-scp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/udp"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/pkg/sftp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -37,10 +35,6 @@ import (
|
|||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/peerbroker"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
"github.com/coder/coder/tailnet"
|
||||
"github.com/coder/coder/tailnet/tailnettest"
|
||||
|
@ -54,64 +48,49 @@ func TestMain(m *testing.M) {
|
|||
func TestAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Stats", func(t *testing.T) {
|
||||
for _, tailscale := range []bool{true, false} {
|
||||
t.Run(fmt.Sprintf("tailscale=%v", tailscale), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Parallel()
|
||||
|
||||
setupAgent := func(t *testing.T) (agent.Conn, <-chan *agent.Stats) {
|
||||
var derpMap *tailcfg.DERPMap
|
||||
if tailscale {
|
||||
derpMap = tailnettest.RunDERPAndSTUN(t)
|
||||
}
|
||||
conn, stats := setupAgent(t, agent.Metadata{
|
||||
DERPMap: derpMap,
|
||||
}, 0)
|
||||
assert.Empty(t, <-stats)
|
||||
return conn, stats
|
||||
}
|
||||
t.Run("SSH", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, stats := setupAgent(t, agent.Metadata{}, 0)
|
||||
|
||||
t.Run("SSH", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, stats := setupAgent(t)
|
||||
sshClient, err := conn.SSHClient()
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
sshClient, err := conn.SSHClient()
|
||||
require.NoError(t, err)
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
assert.EqualValues(t, 1, (<-stats).NumConns)
|
||||
assert.Greater(t, (<-stats).RxBytes, int64(0))
|
||||
assert.Greater(t, (<-stats).TxBytes, int64(0))
|
||||
})
|
||||
|
||||
assert.EqualValues(t, 1, (<-stats).NumConns)
|
||||
assert.Greater(t, (<-stats).RxBytes, int64(0))
|
||||
assert.Greater(t, (<-stats).TxBytes, int64(0))
|
||||
})
|
||||
t.Run("ReconnectingPTY", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ReconnectingPTY", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, stats := setupAgent(t, agent.Metadata{}, 0)
|
||||
|
||||
conn, stats := setupAgent(t)
|
||||
ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash")
|
||||
require.NoError(t, err)
|
||||
defer ptyConn.Close()
|
||||
|
||||
ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash")
|
||||
require.NoError(t, err)
|
||||
defer ptyConn.Close()
|
||||
|
||||
data, err := json.Marshal(agent.ReconnectingPTYRequest{
|
||||
Data: "echo test\r\n",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ptyConn.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
var s *agent.Stats
|
||||
require.Eventuallyf(t, func() bool {
|
||||
var ok bool
|
||||
s, ok = (<-stats)
|
||||
return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0
|
||||
}, testutil.WaitLong, testutil.IntervalFast,
|
||||
"never saw stats: %+v", s,
|
||||
)
|
||||
})
|
||||
data, err := json.Marshal(agent.ReconnectingPTYRequest{
|
||||
Data: "echo test\r\n",
|
||||
})
|
||||
}
|
||||
require.NoError(t, err)
|
||||
_, err = ptyConn.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
var s *agent.Stats
|
||||
require.Eventuallyf(t, func() bool {
|
||||
var ok bool
|
||||
s, ok = (<-stats)
|
||||
return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0
|
||||
}, testutil.WaitLong, testutil.IntervalFast,
|
||||
"never saw stats: %+v", s,
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("SessionExec", func(t *testing.T) {
|
||||
|
@ -235,6 +214,7 @@ func TestAgent(t *testing.T) {
|
|||
conn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
sshClient, err := conn.SSHClient()
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
client, err := sftp.NewClient(sshClient)
|
||||
require.NoError(t, err)
|
||||
tempFile := filepath.Join(t.TempDir(), "sftp")
|
||||
|
@ -252,6 +232,7 @@ func TestAgent(t *testing.T) {
|
|||
conn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
sshClient, err := conn.SSHClient()
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
scpClient, err := scp.NewClientBySSH(sshClient)
|
||||
require.NoError(t, err)
|
||||
tempFile := filepath.Join(t.TempDir(), "scp")
|
||||
|
@ -384,9 +365,7 @@ func TestAgent(t *testing.T) {
|
|||
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
||||
}
|
||||
|
||||
conn, _ := setupAgent(t, agent.Metadata{
|
||||
DERPMap: tailnettest.RunDERPAndSTUN(t),
|
||||
}, 0)
|
||||
conn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
id := uuid.NewString()
|
||||
netConn, err := conn.ReconnectingPTY(id, 100, 100, "/bin/bash")
|
||||
require.NoError(t, err)
|
||||
|
@ -462,19 +441,6 @@ func TestAgent(t *testing.T) {
|
|||
return l
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Unix",
|
||||
setup: func(t *testing.T) net.Listener {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix socket forwarding isn't supported on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
|
||||
require.NoError(t, err, "create UDP listener")
|
||||
return l
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
|
@ -496,8 +462,11 @@ func TestAgent(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
// Dial the listener over WebRTC twice and test out of order
|
||||
conn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := conn.Ping()
|
||||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn1.Close()
|
||||
|
@ -506,36 +475,11 @@ func TestAgent(t *testing.T) {
|
|||
defer conn2.Close()
|
||||
testDial(t, conn2)
|
||||
testDial(t, conn1)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DialError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// This test uses Unix listeners so we can very easily ensure that
|
||||
// no other tests decide to listen on the same random port we
|
||||
// picked.
|
||||
t.Skip("this test is unsupported on Windows")
|
||||
return
|
||||
}
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
|
||||
require.NoError(t, err, "create temp dir")
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
})
|
||||
|
||||
// Try to dial the non-existent Unix socket over WebRTC
|
||||
conn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock"))
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "remote dial error")
|
||||
require.ErrorContains(t, err, "no such file")
|
||||
require.Nil(t, netConn)
|
||||
})
|
||||
|
||||
t.Run("Tailnet", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
derpMap := tailnettest.RunDERPAndSTUN(t)
|
||||
|
@ -578,7 +522,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
|
|||
return
|
||||
}
|
||||
ssh, err := agentConn.SSH()
|
||||
if !assert.NoError(t, err) {
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
@ -622,11 +566,12 @@ func (c closeFunc) Close() error {
|
|||
}
|
||||
|
||||
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) (
|
||||
agent.Conn,
|
||||
*agent.Conn,
|
||||
<-chan *agent.Stats,
|
||||
) {
|
||||
client, server := provisionersdk.TransportPipe()
|
||||
tailscale := metadata.DERPMap != nil
|
||||
if metadata.DERPMap == nil {
|
||||
metadata.DERPMap = tailnettest.RunDERPAndSTUN(t)
|
||||
}
|
||||
coordinator := tailnet.NewCoordinator()
|
||||
agentID := uuid.New()
|
||||
statsCh := make(chan *agent.Stats)
|
||||
|
@ -634,17 +579,18 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
|
|||
FetchMetadata: func(ctx context.Context) (agent.Metadata, error) {
|
||||
return metadata, nil
|
||||
},
|
||||
WebRTCDialer: func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) {
|
||||
listener, err := peerbroker.Listen(server, nil)
|
||||
return listener, err
|
||||
},
|
||||
CoordinatorDialer: func(ctx context.Context) (net.Conn, error) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
closed := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
_ = serverConn.Close()
|
||||
_ = clientConn.Close()
|
||||
<-closed
|
||||
})
|
||||
go coordinator.ServeAgent(serverConn, agentID)
|
||||
go func() {
|
||||
_ = coordinator.ServeAgent(serverConn, agentID)
|
||||
close(closed)
|
||||
}()
|
||||
return clientConn, nil
|
||||
},
|
||||
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
|
||||
|
@ -683,46 +629,27 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
|
|||
},
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = client.Close()
|
||||
_ = server.Close()
|
||||
_ = closer.Close()
|
||||
})
|
||||
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
|
||||
stream, err := api.NegotiateConnection(context.Background())
|
||||
assert.NoError(t, err)
|
||||
if tailscale {
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
DERPMap: metadata.DERPMap,
|
||||
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
_ = conn.Close()
|
||||
})
|
||||
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(node)
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
return &agent.TailnetConn{
|
||||
Conn: conn,
|
||||
}, statsCh
|
||||
}
|
||||
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
|
||||
Logger: slogtest.Make(t, nil),
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
DERPMap: metadata.DERPMap,
|
||||
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
_ = conn.Close()
|
||||
})
|
||||
|
||||
return &agent.WebRTCConn{
|
||||
Negotiator: api,
|
||||
Conn: conn,
|
||||
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(node)
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
return &agent.Conn{
|
||||
Conn: conn,
|
||||
}, statsCh
|
||||
}
|
||||
|
||||
|
|
138
agent/conn.go
138
agent/conn.go
|
@ -4,13 +4,9 @@ import (
|
|||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
@ -19,8 +15,6 @@ import (
|
|||
"tailscale.com/net/speedtest"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
"github.com/coder/coder/tailnet"
|
||||
)
|
||||
|
||||
|
@ -32,123 +26,12 @@ type ReconnectingPTYRequest struct {
|
|||
Width uint16 `json:"width"`
|
||||
}
|
||||
|
||||
// Conn is a temporary interface while we switch from WebRTC to Wireguard networking.
|
||||
type Conn interface {
|
||||
io.Closer
|
||||
Closed() <-chan struct{}
|
||||
Ping() (time.Duration, error)
|
||||
CloseWithError(err error) error
|
||||
ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error)
|
||||
SSH() (net.Conn, error)
|
||||
Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error)
|
||||
SSHClient() (*ssh.Client, error)
|
||||
DialContext(ctx context.Context, network string, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// Conn wraps a peer connection with helper functions to
|
||||
// communicate with the agent.
|
||||
type WebRTCConn struct {
|
||||
// Negotiator is responsible for exchanging messages.
|
||||
Negotiator proto.DRPCPeerBrokerClient
|
||||
|
||||
*peer.Conn
|
||||
}
|
||||
|
||||
// ReconnectingPTY returns a connection serving a TTY that can
|
||||
// be reconnected to via ID.
|
||||
//
|
||||
// The command is optional and defaults to start a shell.
|
||||
func (c *WebRTCConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) {
|
||||
channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d:%s", id, height, width, command), &peer.ChannelOptions{
|
||||
Protocol: ProtocolReconnectingPTY,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("pty: %w", err)
|
||||
}
|
||||
return channel.NetConn(), nil
|
||||
}
|
||||
|
||||
// SSH dials the built-in SSH server.
|
||||
func (c *WebRTCConn) SSH() (net.Conn, error) {
|
||||
channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{
|
||||
Protocol: ProtocolSSH,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("dial: %w", err)
|
||||
}
|
||||
return channel.NetConn(), nil
|
||||
}
|
||||
|
||||
func (*WebRTCConn) Speedtest(_ speedtest.Direction, _ time.Duration) ([]speedtest.Result, error) {
|
||||
return nil, xerrors.New("not implemented")
|
||||
}
|
||||
|
||||
// SSHClient calls SSH to create a client that uses a weak cipher
|
||||
// for high throughput.
|
||||
func (c *WebRTCConn) SSHClient() (*ssh.Client, error) {
|
||||
netConn, err := c.SSH()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("ssh: %w", err)
|
||||
}
|
||||
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
|
||||
// SSH host validation isn't helpful, because obtaining a peer
|
||||
// connection already signifies user-intent to dial a workspace.
|
||||
// #nosec
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("ssh conn: %w", err)
|
||||
}
|
||||
return ssh.NewClient(sshConn, channels, requests), nil
|
||||
}
|
||||
|
||||
// DialContext dials an arbitrary protocol+address from inside the workspace and
|
||||
// proxies it through the provided net.Conn.
|
||||
func (c *WebRTCConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||
u := &url.URL{
|
||||
Scheme: network,
|
||||
}
|
||||
if strings.HasPrefix(network, "unix") {
|
||||
u.Path = addr
|
||||
} else {
|
||||
u.Host = addr
|
||||
}
|
||||
|
||||
channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{
|
||||
Protocol: ProtocolDial,
|
||||
Unordered: strings.HasPrefix(network, "udp"),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create datachannel: %w", err)
|
||||
}
|
||||
|
||||
// The first message written from the other side is a JSON payload
|
||||
// containing the dial error.
|
||||
dec := json.NewDecoder(channel)
|
||||
var res dialResponse
|
||||
err = dec.Decode(&res)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("decode agent dial response: %w", err)
|
||||
}
|
||||
if res.Error != "" {
|
||||
_ = channel.Close()
|
||||
return nil, xerrors.Errorf("remote dial error: %v", res.Error)
|
||||
}
|
||||
|
||||
return channel.NetConn(), nil
|
||||
}
|
||||
|
||||
func (c *WebRTCConn) Close() error {
|
||||
_ = c.Negotiator.DRPCConn().Close()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
type TailnetConn struct {
|
||||
type Conn struct {
|
||||
*tailnet.Conn
|
||||
CloseFunc func()
|
||||
}
|
||||
|
||||
func (c *TailnetConn) Ping() (time.Duration, error) {
|
||||
func (c *Conn) Ping() (time.Duration, error) {
|
||||
errCh := make(chan error, 1)
|
||||
durCh := make(chan time.Duration, 1)
|
||||
c.Conn.Ping(tailnetIP, tailcfg.PingICMP, func(pr *ipnstate.PingResult) {
|
||||
|
@ -166,11 +49,11 @@ func (c *TailnetConn) Ping() (time.Duration, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *TailnetConn) CloseWithError(_ error) error {
|
||||
func (c *Conn) CloseWithError(_ error) error {
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
func (c *TailnetConn) Close() error {
|
||||
func (c *Conn) Close() error {
|
||||
if c.CloseFunc != nil {
|
||||
c.CloseFunc()
|
||||
}
|
||||
|
@ -184,7 +67,7 @@ type reconnectingPTYInit struct {
|
|||
Command string
|
||||
}
|
||||
|
||||
func (c *TailnetConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) {
|
||||
func (c *Conn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) {
|
||||
conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetReconnectingPTYPort)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -210,13 +93,13 @@ func (c *TailnetConn) ReconnectingPTY(id string, height, width uint16, command s
|
|||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *TailnetConn) SSH() (net.Conn, error) {
|
||||
func (c *Conn) SSH() (net.Conn, error) {
|
||||
return c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetSSHPort)))
|
||||
}
|
||||
|
||||
// SSHClient calls SSH to create a client that uses a weak cipher
|
||||
// for high throughput.
|
||||
func (c *TailnetConn) SSHClient() (*ssh.Client, error) {
|
||||
func (c *Conn) SSHClient() (*ssh.Client, error) {
|
||||
netConn, err := c.SSH()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("ssh: %w", err)
|
||||
|
@ -233,7 +116,7 @@ func (c *TailnetConn) SSHClient() (*ssh.Client, error) {
|
|||
return ssh.NewClient(sshConn, channels, requests), nil
|
||||
}
|
||||
|
||||
func (c *TailnetConn) Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
|
||||
func (c *Conn) Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
|
||||
speedConn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetSpeedtestPort)))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("dial speedtest: %w", err)
|
||||
|
@ -245,7 +128,10 @@ func (c *TailnetConn) Speedtest(direction speedtest.Direction, duration time.Dur
|
|||
return results, err
|
||||
}
|
||||
|
||||
func (c *TailnetConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||
func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||
if network == "unix" {
|
||||
return nil, xerrors.New("network must be tcp or udp")
|
||||
}
|
||||
_, rawPort, _ := net.SplitHostPort(addr)
|
||||
port, _ := strconv.Atoi(rawPort)
|
||||
ipp := netip.AddrPortFrom(tailnetIP, uint16(port))
|
||||
|
|
|
@ -32,7 +32,6 @@ func workspaceAgent() *cobra.Command {
|
|||
pprofEnabled bool
|
||||
pprofAddress string
|
||||
noReap bool
|
||||
wireguard bool
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Use: "agent",
|
||||
|
@ -184,7 +183,6 @@ func workspaceAgent() *cobra.Command {
|
|||
|
||||
closer := agent.New(agent.Options{
|
||||
FetchMetadata: client.WorkspaceAgentMetadata,
|
||||
WebRTCDialer: client.ListenWorkspaceAgent,
|
||||
Logger: logger,
|
||||
EnvironmentVariables: map[string]string{
|
||||
// Override the "CODER_AGENT_TOKEN" variable in all
|
||||
|
@ -203,6 +201,5 @@ func workspaceAgent() *cobra.Command {
|
|||
cliflag.BoolVarP(cmd.Flags(), &pprofEnabled, "pprof-enable", "", "CODER_AGENT_PPROF_ENABLE", false, "Enable serving pprof metrics on the address defined by --pprof-address.")
|
||||
cliflag.BoolVarP(cmd.Flags(), &noReap, "no-reap", "", "", false, "Do not start a process reaper.")
|
||||
cliflag.StringVarP(cmd.Flags(), &pprofAddress, "pprof-address", "", "CODER_AGENT_PPROF_ADDRESS", "127.0.0.1:6060", "The address to serve pprof.")
|
||||
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_AGENT_WIREGUARD", true, "Whether to start the Wireguard interface.")
|
||||
return cmd
|
||||
}
|
||||
|
|
|
@ -7,10 +7,13 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/cli/clitest"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/provisionersdk/proto"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestWorkspaceAgent(t *testing.T) {
|
||||
|
@ -63,11 +66,13 @@ func TestWorkspaceAgent(t *testing.T) {
|
|||
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
|
||||
assert.NotEmpty(t, resources[0].Agents[0].Version)
|
||||
}
|
||||
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
|
||||
dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
defer dialer.Close()
|
||||
_, err = dialer.Ping()
|
||||
require.NoError(t, err)
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := dialer.Ping()
|
||||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
cancelFunc()
|
||||
err = <-errC
|
||||
require.NoError(t, err)
|
||||
|
@ -121,11 +126,13 @@ func TestWorkspaceAgent(t *testing.T) {
|
|||
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
|
||||
assert.NotEmpty(t, resources[0].Agents[0].Version)
|
||||
}
|
||||
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
|
||||
dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
defer dialer.Close()
|
||||
_, err = dialer.Ping()
|
||||
require.NoError(t, err)
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := dialer.Ping()
|
||||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
cancelFunc()
|
||||
err = <-errC
|
||||
require.NoError(t, err)
|
||||
|
@ -179,11 +186,13 @@ func TestWorkspaceAgent(t *testing.T) {
|
|||
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
|
||||
assert.NotEmpty(t, resources[0].Agents[0].Version)
|
||||
}
|
||||
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
|
||||
dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
defer dialer.Close()
|
||||
_, err = dialer.Ping()
|
||||
require.NoError(t, err)
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := dialer.Ping()
|
||||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
cancelFunc()
|
||||
err = <-errC
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -139,7 +139,6 @@ func configSSH() *cobra.Command {
|
|||
usePreviousOpts bool
|
||||
dryRun bool
|
||||
skipProxyCommand bool
|
||||
wireguard bool
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Annotations: workspaceCommand,
|
||||
|
@ -289,15 +288,11 @@ func configSSH() *cobra.Command {
|
|||
"\tLogLevel ERROR",
|
||||
)
|
||||
if !skipProxyCommand {
|
||||
wgArg := ""
|
||||
if wireguard {
|
||||
wgArg = "--wireguard "
|
||||
}
|
||||
configOptions = append(
|
||||
configOptions,
|
||||
fmt.Sprintf(
|
||||
"\tProxyCommand %s --global-config %s ssh %s--stdio %s",
|
||||
escapedCoderBinary, escapedGlobalConfig, wgArg, hostname,
|
||||
"\tProxyCommand %s --global-config %s ssh --stdio %s",
|
||||
escapedCoderBinary, escapedGlobalConfig, hostname,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
@ -374,9 +369,6 @@ func configSSH() *cobra.Command {
|
|||
cmd.Flags().BoolVarP(&skipProxyCommand, "skip-proxy-command", "", false, "Specifies whether the ProxyCommand option should be skipped. Useful for testing.")
|
||||
_ = cmd.Flags().MarkHidden("skip-proxy-command")
|
||||
cliflag.BoolVarP(cmd.Flags(), &usePreviousOpts, "use-previous-options", "", "CODER_SSH_USE_PREVIOUS_OPTIONS", false, "Specifies whether or not to keep options from previous run of config-ssh.")
|
||||
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_CONFIG_SSH_WIREGUARD", true, "Whether to use Wireguard for SSH tunneling.")
|
||||
_ = cmd.Flags().MarkHidden("wireguard")
|
||||
|
||||
cliui.AllowSkipPrompt(cmd)
|
||||
|
||||
return cmd
|
||||
|
|
|
@ -12,12 +12,14 @@ import (
|
|||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/agent"
|
||||
|
@ -106,15 +108,14 @@ func TestConfigSSH(t *testing.T) {
|
|||
agentClient.SessionToken = authToken
|
||||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
CoordinatorDialer: client.ListenWorkspaceAgentTailnet,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
Logger: slogtest.Make(t, nil).Named("agent"),
|
||||
})
|
||||
defer func() {
|
||||
_ = agentCloser.Close()
|
||||
}()
|
||||
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
agentConn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil)
|
||||
agentConn, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, resources[0].Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
defer agentConn.Close()
|
||||
|
||||
|
@ -123,17 +124,28 @@ func TestConfigSSH(t *testing.T) {
|
|||
defer func() {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
copyDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(copyDone)
|
||||
var wg sync.WaitGroup
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
break
|
||||
}
|
||||
ssh, err := agentConn.SSH()
|
||||
assert.NoError(t, err)
|
||||
go io.Copy(conn, ssh)
|
||||
go io.Copy(ssh, conn)
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(conn, ssh)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(ssh, conn)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
sshConfigFile := sshConfigFileName(t)
|
||||
|
@ -178,6 +190,9 @@ func TestConfigSSH(t *testing.T) {
|
|||
data, err := sshCmd.Output()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", strings.TrimSpace(string(data)))
|
||||
|
||||
_ = listener.Close()
|
||||
<-copyDone
|
||||
}
|
||||
|
||||
func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) {
|
||||
|
|
|
@ -20,6 +20,8 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/cli/clitest"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
|
@ -72,7 +74,7 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, str
|
|||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
// start workspace agent
|
||||
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String(), "--wireguard=false")
|
||||
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String())
|
||||
agentClient := client
|
||||
clitest.SetupConfig(t, agentClient, root)
|
||||
|
||||
|
@ -85,11 +87,13 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, str
|
|||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
|
||||
dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
defer dialer.Close()
|
||||
_, err = dialer.Ping()
|
||||
require.NoError(t, err)
|
||||
require.Eventually(t, func() bool {
|
||||
_, err = dialer.Ping()
|
||||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
return agentClient, agentToken, pubkey
|
||||
}
|
||||
|
|
|
@ -17,9 +17,7 @@ import (
|
|||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
@ -28,7 +26,6 @@ func portForward() *cobra.Command {
|
|||
var (
|
||||
tcpForwards []string // <port>:<port>
|
||||
udpForwards []string // <port>:<port>
|
||||
wireguard bool
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Use: "port-forward <workspace>",
|
||||
|
@ -94,16 +91,7 @@ func portForward() *cobra.Command {
|
|||
return xerrors.Errorf("await agent: %w", err)
|
||||
}
|
||||
|
||||
var conn agent.Conn
|
||||
if !wireguard {
|
||||
conn, err = client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
|
||||
} else {
|
||||
logger := slog.Logger{}
|
||||
if cliflag.IsSetBool(cmd, varVerbose) {
|
||||
logger = slog.Make(sloghuman.Sink(cmd.ErrOrStderr())).Named("tailnet").Leveled(slog.LevelDebug)
|
||||
}
|
||||
conn, err = client.DialWorkspaceAgentTailnet(ctx, logger, workspaceAgent.ID)
|
||||
}
|
||||
conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -178,12 +166,10 @@ func portForward() *cobra.Command {
|
|||
|
||||
cmd.Flags().StringArrayVarP(&tcpForwards, "tcp", "p", []string{}, "Forward a TCP port from the workspace to the local machine")
|
||||
cmd.Flags().StringArrayVar(&udpForwards, "udp", []string{}, "Forward a UDP port from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols")
|
||||
cmd.Flags().BoolVarP(&wireguard, "wireguard", "", true, "Specifies whether to use wireguard networking or not.")
|
||||
_ = cmd.Flags().MarkHidden("wireguard")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn agent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) {
|
||||
func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *agent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) {
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
|
||||
|
||||
var (
|
||||
|
|
|
@ -377,7 +377,7 @@ func setupTestListener(t *testing.T, l net.Listener) string {
|
|||
|
||||
addr := l.Addr().String()
|
||||
_, port, err := net.SplitHostPort(addr)
|
||||
require.NoErrorf(t, err, "split listen path %q", addr)
|
||||
require.NoErrorf(t, err, "split non-Unix listen path %q", addr)
|
||||
addr = port
|
||||
|
||||
return addr
|
||||
|
|
|
@ -28,8 +28,6 @@ import (
|
|||
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/turn/v2"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/spf13/afero"
|
||||
|
@ -59,7 +57,6 @@ import (
|
|||
"github.com/coder/coder/coderd/prometheusmetrics"
|
||||
"github.com/coder/coder/coderd/telemetry"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
"github.com/coder/coder/coderd/turnconn"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
|
@ -113,9 +110,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
|||
tlsEnable bool
|
||||
tlsKeyFile string
|
||||
tlsMinVersion string
|
||||
turnRelayAddress string
|
||||
tunnel bool
|
||||
stunServers []string
|
||||
traceEnable bool
|
||||
secureAuthCookie bool
|
||||
sshKeygenAlgorithmRaw string
|
||||
|
@ -300,22 +295,6 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
|||
return xerrors.Errorf("parse ssh keygen algorithm %s: %w", sshKeygenAlgorithmRaw, err)
|
||||
}
|
||||
|
||||
turnServer, err := turnconn.New(&turn.RelayAddressGeneratorStatic{
|
||||
RelayAddress: net.ParseIP(turnRelayAddress),
|
||||
Address: turnRelayAddress,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create turn server: %w", err)
|
||||
}
|
||||
defer turnServer.Close()
|
||||
|
||||
iceServers := make([]webrtc.ICEServer, 0)
|
||||
for _, stunServer := range stunServers {
|
||||
iceServers = append(iceServers, webrtc.ICEServer{
|
||||
URLs: []string{stunServer},
|
||||
})
|
||||
}
|
||||
|
||||
// Validate provided auto-import templates.
|
||||
var (
|
||||
validatedAutoImportTemplates = make([]coderd.AutoImportTemplate, len(autoImportTemplates))
|
||||
|
@ -360,7 +339,6 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
|||
|
||||
options := &coderd.Options{
|
||||
AccessURL: accessURLParsed,
|
||||
ICEServers: iceServers,
|
||||
Logger: logger.Named("coderd"),
|
||||
Database: databasefake.New(),
|
||||
DERPMap: derpMap,
|
||||
|
@ -369,8 +347,6 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
|||
GoogleTokenValidator: googleTokenValidator,
|
||||
SecureAuthCookie: secureAuthCookie,
|
||||
SSHKeygenAlgorithm: sshKeygenAlgorithm,
|
||||
TailscaleEnable: tailscaleEnable,
|
||||
TURNServer: turnServer,
|
||||
TracerProvider: tracerProvider,
|
||||
Telemetry: telemetry.NewNoop(),
|
||||
AutoImportTemplates: validatedAutoImportTemplates,
|
||||
|
@ -478,7 +454,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
|||
OIDCAuth: oidcClientID != "",
|
||||
OIDCIssuerURL: oidcIssuerURL,
|
||||
Prometheus: promEnabled,
|
||||
STUN: len(stunServers) != 0,
|
||||
STUN: len(derpServerSTUNAddrs) != 0,
|
||||
Tunnel: tunnel,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -850,13 +826,8 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
|||
`Minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13"`)
|
||||
cliflag.BoolVarP(root.Flags(), &tunnel, "tunnel", "", "CODER_TUNNEL", false,
|
||||
"Workspaces must be able to reach the `access-url`. This overrides your access URL with a public access URL that tunnels your Coder deployment.")
|
||||
cliflag.StringArrayVarP(root.Flags(), &stunServers, "stun-server", "", "CODER_STUN_SERVERS", []string{
|
||||
"stun:stun.l.google.com:19302",
|
||||
}, "URLs for STUN servers to enable P2P connections.")
|
||||
cliflag.BoolVarP(root.Flags(), &traceEnable, "trace", "", "CODER_TRACE", false,
|
||||
"Whether application tracing data is collected.")
|
||||
cliflag.StringVarP(root.Flags(), &turnRelayAddress, "turn-relay-address", "", "CODER_TURN_RELAY_ADDRESS", "127.0.0.1",
|
||||
"The address to bind TURN connections.")
|
||||
cliflag.BoolVarP(root.Flags(), &secureAuthCookie, "secure-auth-cookie", "", "CODER_SECURE_AUTH_COOKIE", false,
|
||||
"Controls if the 'Secure' property is set on browser session cookies")
|
||||
cliflag.StringVarP(root.Flags(), &sshKeygenAlgorithmRaw, "ssh-keygen-algorithm", "", "CODER_SSH_KEYGEN_ALGORITHM", "ed25519",
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/codersdk"
|
||||
|
@ -73,8 +72,7 @@ func speedtest() *cobra.Command {
|
|||
if err != nil {
|
||||
continue
|
||||
}
|
||||
tc, _ := conn.(*agent.TailnetConn)
|
||||
status := tc.Status()
|
||||
status := conn.Status()
|
||||
if len(status.Peers()) != 1 {
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -25,7 +25,6 @@ func TestSpeedtest(t *testing.T) {
|
|||
agentClient.SessionToken = agentToken
|
||||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
Logger: slogtest.Make(t, nil).Named("agent"),
|
||||
})
|
||||
|
|
12
cli/ssh.go
12
cli/ssh.go
|
@ -22,7 +22,6 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/coderd/autobuild/notify"
|
||||
|
@ -43,7 +42,6 @@ func ssh() *cobra.Command {
|
|||
forwardAgent bool
|
||||
identityAgent string
|
||||
wsPollInterval time.Duration
|
||||
wireguard bool
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Annotations: workspaceCommand,
|
||||
|
@ -88,12 +86,7 @@ func ssh() *cobra.Command {
|
|||
return xerrors.Errorf("await agent: %w", err)
|
||||
}
|
||||
|
||||
var conn agent.Conn
|
||||
if !wireguard {
|
||||
conn, err = client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
|
||||
} else {
|
||||
conn, err = client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID)
|
||||
}
|
||||
conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -221,9 +214,6 @@ func ssh() *cobra.Command {
|
|||
cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK")
|
||||
cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled")
|
||||
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
|
||||
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_SSH_WIREGUARD", true, "Whether to use Wireguard for SSH tunneling.")
|
||||
_ = cmd.Flags().MarkHidden("wireguard")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
|
|
|
@ -90,7 +90,6 @@ func TestSSH(t *testing.T) {
|
|||
agentClient.SessionToken = agentToken
|
||||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
Logger: slogtest.Make(t, nil).Named("agent"),
|
||||
})
|
||||
|
@ -112,7 +111,6 @@ func TestSSH(t *testing.T) {
|
|||
agentClient.SessionToken = agentToken
|
||||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
Logger: slogtest.Make(t, nil).Named("agent"),
|
||||
})
|
||||
|
@ -181,7 +179,6 @@ func TestSSH(t *testing.T) {
|
|||
agentClient.SessionToken = agentToken
|
||||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
Logger: slogtest.Make(t, nil).Named("agent"),
|
||||
})
|
||||
|
|
|
@ -13,7 +13,6 @@ import (
|
|||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
|
@ -35,7 +34,6 @@ import (
|
|||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/telemetry"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
"github.com/coder/coder/coderd/turnconn"
|
||||
"github.com/coder/coder/coderd/wsconncache"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/site"
|
||||
|
@ -65,17 +63,14 @@ type Options struct {
|
|||
GithubOAuth2Config *GithubOAuth2Config
|
||||
OIDCConfig *OIDCConfig
|
||||
PrometheusRegistry *prometheus.Registry
|
||||
ICEServers []webrtc.ICEServer
|
||||
SecureAuthCookie bool
|
||||
SSHKeygenAlgorithm gitsshkey.Algorithm
|
||||
Telemetry telemetry.Reporter
|
||||
TURNServer *turnconn.Server
|
||||
TracerProvider trace.TracerProvider
|
||||
AutoImportTemplates []AutoImportTemplate
|
||||
LicenseHandler http.Handler
|
||||
FeaturesService features.Service
|
||||
|
||||
TailscaleEnable bool
|
||||
TailnetCoordinator *tailnet.Coordinator
|
||||
DERPMap *tailcfg.DERPMap
|
||||
|
||||
|
@ -92,6 +87,12 @@ func New(options *Options) *API {
|
|||
// Multiply the update by two to allow for some lag-time.
|
||||
options.AgentInactiveDisconnectTimeout = options.AgentConnectionUpdateFrequency * 2
|
||||
}
|
||||
if options.AgentStatsRefreshInterval == 0 {
|
||||
options.AgentStatsRefreshInterval = 10 * time.Minute
|
||||
}
|
||||
if options.MetricsCacheRefreshInterval == 0 {
|
||||
options.MetricsCacheRefreshInterval = time.Hour
|
||||
}
|
||||
if options.APIRateLimit == 0 {
|
||||
options.APIRateLimit = 512
|
||||
}
|
||||
|
@ -149,11 +150,7 @@ func New(options *Options) *API {
|
|||
},
|
||||
metricsCache: metricsCache,
|
||||
}
|
||||
if options.TailscaleEnable {
|
||||
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
|
||||
} else {
|
||||
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0)
|
||||
}
|
||||
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
|
||||
api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger))
|
||||
oauthConfigs := &httpmw.OAuth2Configs{
|
||||
Github: options.GithubOAuth2Config,
|
||||
|
@ -415,14 +412,8 @@ func New(options *Options) *API {
|
|||
r.Use(httpmw.ExtractWorkspaceAgent(options.Database))
|
||||
r.Get("/metadata", api.workspaceAgentMetadata)
|
||||
r.Post("/version", api.postWorkspaceAgentVersion)
|
||||
r.Get("/listen", api.workspaceAgentListen)
|
||||
|
||||
r.Get("/gitsshkey", api.agentGitSSHKey)
|
||||
r.Get("/turn", api.workspaceAgentTurn)
|
||||
r.Get("/iceservers", api.workspaceAgentICEServers)
|
||||
|
||||
r.Get("/coordinate", api.workspaceAgentCoordinate)
|
||||
|
||||
r.Get("/report-stats", api.workspaceAgentReportStats)
|
||||
})
|
||||
r.Route("/{workspaceagent}", func(r chi.Router) {
|
||||
|
@ -432,11 +423,7 @@ func New(options *Options) *API {
|
|||
httpmw.ExtractWorkspaceParam(options.Database),
|
||||
)
|
||||
r.Get("/", api.workspaceAgent)
|
||||
r.Get("/dial", api.workspaceAgentDial)
|
||||
r.Get("/turn", api.userWorkspaceAgentTurn)
|
||||
r.Get("/pty", api.workspaceAgentPTY)
|
||||
r.Get("/iceservers", api.workspaceAgentICEServers)
|
||||
|
||||
r.Get("/connection", api.workspaceAgentConnection)
|
||||
r.Get("/coordinate", api.workspaceAgentClientCoordinate)
|
||||
})
|
||||
|
|
|
@ -188,18 +188,14 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
|||
"GET:/api/v2/users/oidc/callback": {NoAuthorize: true},
|
||||
|
||||
// All workspaceagents endpoints do not use rbac
|
||||
"POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/iceservers": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/listen": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/turn": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/{workspaceagent}/iceservers": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true},
|
||||
|
||||
// These endpoints have more assertions. This is good, add more endpoints to assert if you can!
|
||||
"GET:/api/v2/organizations/{organization}": {AssertObject: rbac.ResourceOrganization.InOrg(a.Admin.OrganizationID)},
|
||||
|
@ -256,14 +252,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
|||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: workspaceRBACObj,
|
||||
},
|
||||
"GET:/api/v2/workspaceagents/{workspaceagent}/dial": {
|
||||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: workspaceExecObj,
|
||||
},
|
||||
"GET:/api/v2/workspaceagents/{workspaceagent}/turn": {
|
||||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: workspaceExecObj,
|
||||
},
|
||||
"GET:/api/v2/workspaceagents/{workspaceagent}/pty": {
|
||||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: workspaceExecObj,
|
||||
|
|
|
@ -54,7 +54,6 @@ import (
|
|||
"github.com/coder/coder/coderd/gitsshkey"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/telemetry"
|
||||
"github.com/coder/coder/coderd/turnconn"
|
||||
"github.com/coder/coder/coderd/util/ptr"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
|
@ -202,12 +201,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
|
|||
options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519
|
||||
}
|
||||
|
||||
turnServer, err := turnconn.New(nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = turnServer.Close()
|
||||
})
|
||||
|
||||
features := coderd.DisabledImplementations
|
||||
if options.Auditor != nil {
|
||||
features.Auditor = options.Auditor
|
||||
|
@ -231,7 +224,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
|
|||
OIDCConfig: options.OIDCConfig,
|
||||
GoogleTokenValidator: options.GoogleTokenValidator,
|
||||
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
|
||||
TURNServer: turnServer,
|
||||
APIRateLimit: options.APIRateLimit,
|
||||
Authorizer: options.Authorizer,
|
||||
Telemetry: telemetry.NewNoop(),
|
||||
|
|
|
@ -604,7 +604,6 @@ func TestTemplateDAUs(t *testing.T) {
|
|||
agentCloser := agent.New(agent.Options{
|
||||
Logger: slogtest.Make(t, nil),
|
||||
StatsReporter: agentClient.AgentReportStats,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
})
|
||||
|
|
|
@ -1,203 +0,0 @@
|
|||
package turnconn
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/turn/v2"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"golang.org/x/net/proxy"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
var (
|
||||
// reservedAddress is a magic address that's used exclusively
|
||||
// for proxying via Coder. We don't proxy all TURN connections,
|
||||
// because that'd exclude the possibility of a customer using
|
||||
// their own TURN server.
|
||||
reservedAddress = "127.0.0.1:12345"
|
||||
credential = "coder"
|
||||
localhost = &net.TCPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
}
|
||||
|
||||
// Proxy is a an ICE Server that uses a special hostname
|
||||
// to indicate traffic should be proxied.
|
||||
Proxy = webrtc.ICEServer{
|
||||
URLs: []string{"turns:" + reservedAddress},
|
||||
Username: "coder",
|
||||
Credential: credential,
|
||||
}
|
||||
)
|
||||
|
||||
// New constructs a new TURN server binding to the relay address provided.
|
||||
// The relay address is used to broadcast the location of an accepted connection.
|
||||
func New(relayAddress *turn.RelayAddressGeneratorStatic) (*Server, error) {
|
||||
if relayAddress == nil {
|
||||
relayAddress = &turn.RelayAddressGeneratorStatic{
|
||||
RelayAddress: localhost.IP,
|
||||
Address: "127.0.0.1",
|
||||
}
|
||||
}
|
||||
logger := logging.NewDefaultLoggerFactory()
|
||||
logger.DefaultLogLevel = logging.LogLevelDisabled
|
||||
server := &Server{
|
||||
conns: make(chan net.Conn, 1),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
server.listener = &listener{
|
||||
srv: server,
|
||||
}
|
||||
var err error
|
||||
server.turn, err = turn.NewServer(turn.ServerConfig{
|
||||
AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) {
|
||||
// TURN connections require credentials. It's not important
|
||||
// for our use-case, because our listener is entirely in-memory.
|
||||
return turn.GenerateAuthKey(Proxy.Username, "", credential), true
|
||||
},
|
||||
ListenerConfigs: []turn.ListenerConfig{{
|
||||
Listener: server.listener,
|
||||
RelayAddressGenerator: relayAddress,
|
||||
}},
|
||||
LoggerFactory: logger,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create server: %w", err)
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// Server accepts and connects TURN allocations.
|
||||
//
|
||||
// This is a thin wrapper around pion/turn that pipes
|
||||
// connections directly to the in-memory handler.
|
||||
type Server struct {
|
||||
listener *listener
|
||||
turn *turn.Server
|
||||
|
||||
closeMutex sync.Mutex
|
||||
closed chan (struct{})
|
||||
conns chan (net.Conn)
|
||||
}
|
||||
|
||||
// Accept consumes a new connection into the TURN server.
|
||||
// A unique remote address must exist per-connection.
|
||||
// pion/turn indexes allocations based on the address.
|
||||
func (s *Server) Accept(nc net.Conn, remoteAddress, localAddress *net.TCPAddr) *Conn {
|
||||
if localAddress == nil {
|
||||
localAddress = localhost
|
||||
}
|
||||
conn := &Conn{
|
||||
Conn: nc,
|
||||
remoteAddress: remoteAddress,
|
||||
localAddress: localAddress,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
s.conns <- conn
|
||||
return conn
|
||||
}
|
||||
|
||||
// Close ends the TURN server.
|
||||
func (s *Server) Close() error {
|
||||
s.closeMutex.Lock()
|
||||
defer s.closeMutex.Unlock()
|
||||
if s.isClosed() {
|
||||
return nil
|
||||
}
|
||||
err := s.turn.Close()
|
||||
close(s.conns)
|
||||
close(s.closed)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) isClosed() bool {
|
||||
select {
|
||||
case <-s.closed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// listener implements net.Listener for the TURN
|
||||
// server to consume.
|
||||
type listener struct {
|
||||
srv *Server
|
||||
}
|
||||
|
||||
func (l *listener) Accept() (net.Conn, error) {
|
||||
conn, ok := <-l.srv.conns
|
||||
if !ok {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (*listener) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*listener) Addr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
closed chan struct{}
|
||||
localAddress *net.TCPAddr
|
||||
remoteAddress *net.TCPAddr
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.localAddress
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddress
|
||||
}
|
||||
|
||||
// Closed returns a channel which is closed when
|
||||
// the connection is.
|
||||
func (c *Conn) Closed() <-chan struct{} {
|
||||
return c.closed
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
err := c.Conn.Close()
|
||||
select {
|
||||
case <-c.closed:
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type dialer func(network, addr string) (c net.Conn, err error)
|
||||
|
||||
func (d dialer) Dial(network, addr string) (c net.Conn, err error) {
|
||||
return d(network, addr)
|
||||
}
|
||||
|
||||
// ProxyDialer accepts a proxy function that's called when the connection
|
||||
// address matches the reserved host in the "Proxy" ICE server.
|
||||
//
|
||||
// This should be passed to WebRTC connections as an ICE dialer.
|
||||
func ProxyDialer(proxyFunc func() (c net.Conn, err error)) proxy.Dialer {
|
||||
return dialer(func(network, addr string) (net.Conn, error) {
|
||||
if addr != reservedAddress {
|
||||
return proxy.Direct.Dial(network, addr)
|
||||
}
|
||||
netConn, err := proxyFunc()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Conn{
|
||||
localAddress: localhost,
|
||||
closed: make(chan struct{}),
|
||||
Conn: netConn,
|
||||
}, nil
|
||||
})
|
||||
}
|
|
@ -1,107 +0,0 @@
|
|||
package turnconn_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/coderd/turnconn"
|
||||
"github.com/coder/coder/peer"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestTURNConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
turnServer, err := turnconn.New(nil)
|
||||
require.NoError(t, err)
|
||||
defer turnServer.Close()
|
||||
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
|
||||
clientDialer, clientTURN := net.Pipe()
|
||||
turnServer.Accept(clientTURN, &net.TCPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 16000,
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
clientSettings := webrtc.SettingEngine{}
|
||||
clientSettings.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6})
|
||||
clientSettings.SetRelayAcceptanceMinWait(0)
|
||||
clientSettings.SetICEProxyDialer(turnconn.ProxyDialer(func() (net.Conn, error) {
|
||||
return clientDialer, nil
|
||||
}))
|
||||
client, err := peer.Client([]webrtc.ICEServer{turnconn.Proxy}, &peer.ConnOptions{
|
||||
SettingEngine: clientSettings,
|
||||
Logger: logger.Named("client"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = client.Close()
|
||||
}()
|
||||
|
||||
serverDialer, serverTURN := net.Pipe()
|
||||
turnServer.Accept(serverTURN, &net.TCPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 16001,
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
serverSettings := webrtc.SettingEngine{}
|
||||
serverSettings.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6})
|
||||
serverSettings.SetRelayAcceptanceMinWait(0)
|
||||
serverSettings.SetICEProxyDialer(turnconn.ProxyDialer(func() (net.Conn, error) {
|
||||
return serverDialer, nil
|
||||
}))
|
||||
server, err := peer.Server([]webrtc.ICEServer{turnconn.Proxy}, &peer.ConnOptions{
|
||||
SettingEngine: serverSettings,
|
||||
Logger: logger.Named("server"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = server.Close()
|
||||
}()
|
||||
exchange(t, client, server)
|
||||
|
||||
_, err = client.Ping()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func exchange(t *testing.T, client, server *peer.Conn) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
t.Cleanup(wg.Wait)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case c := <-server.LocalCandidate():
|
||||
client.AddRemoteCandidate(c)
|
||||
case c := <-server.LocalSessionDescription():
|
||||
client.SetRemoteSessionDescription(c)
|
||||
case <-server.Closed():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case c := <-client.LocalCandidate():
|
||||
server.AddRemoteCandidate(c)
|
||||
case c := <-client.LocalSessionDescription():
|
||||
server.SetRemoteSessionDescription(c)
|
||||
case <-client.Closed():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
|
@ -15,7 +15,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/yamux"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/mod/semver"
|
||||
"golang.org/x/xerrors"
|
||||
|
@ -30,12 +29,7 @@ import (
|
|||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
"github.com/coder/coder/coderd/turnconn"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/peerbroker"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
"github.com/coder/coder/tailnet"
|
||||
)
|
||||
|
||||
|
@ -66,67 +60,6 @@ func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) {
|
|||
httpapi.Write(rw, http.StatusOK, apiAgent)
|
||||
}
|
||||
|
||||
func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
|
||||
api.websocketWaitMutex.Lock()
|
||||
api.websocketWaitGroup.Add(1)
|
||||
api.websocketWaitMutex.Unlock()
|
||||
defer api.websocketWaitGroup.Done()
|
||||
|
||||
workspaceAgent := httpmw.WorkspaceAgentParam(r)
|
||||
workspace := httpmw.WorkspaceParam(r)
|
||||
if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error reading workspace agent.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if apiAgent.Status != codersdk.WorkspaceAgentConnected {
|
||||
httpapi.Write(rw, http.StatusPreconditionFailed, codersdk.Response{
|
||||
Message: fmt.Sprintf("Agent isn't connected! Status: %s.", apiAgent.Status),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to accept websocket.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close() // Also closes conn.
|
||||
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
session, err := yamux.Server(wsNetConn, config)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// end span so we don't get long lived trace data
|
||||
tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx))
|
||||
|
||||
err = peerbroker.ProxyListen(ctx, session, peerbroker.ProxyOptions{
|
||||
ChannelID: workspaceAgent.ID.String(),
|
||||
Logger: api.Logger.Named("peerbroker-proxy-dial"),
|
||||
Pubsub: api.Pubsub,
|
||||
})
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
|
||||
|
@ -186,231 +119,6 @@ func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Reques
|
|||
httpapi.Write(rw, http.StatusOK, nil)
|
||||
}
|
||||
|
||||
func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
|
||||
api.websocketWaitMutex.Lock()
|
||||
api.websocketWaitGroup.Add(1)
|
||||
api.websocketWaitMutex.Unlock()
|
||||
defer api.websocketWaitGroup.Done()
|
||||
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to accept websocket.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Internal error fetching workspace build job.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// Ensure the resource is still valid!
|
||||
// We only accept agents for resources on the latest build.
|
||||
ensureLatestBuild := func() error {
|
||||
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if build.ID != latestBuild.ID {
|
||||
return xerrors.New("build is outdated")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err = ensureLatestBuild()
|
||||
if err != nil {
|
||||
api.Logger.Debug(r.Context(), "agent tried to connect from non-latest built",
|
||||
slog.F("resource", resource),
|
||||
slog.F("agent", workspaceAgent),
|
||||
)
|
||||
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Agent trying to connect from non-latest build.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to accept websocket.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close() // Also closes conn.
|
||||
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
session, err := yamux.Server(wsNetConn, config)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{
|
||||
ChannelID: workspaceAgent.ID.String(),
|
||||
Pubsub: api.Pubsub,
|
||||
Logger: api.Logger.Named("peerbroker-proxy-listen"),
|
||||
})
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
defer closer.Close()
|
||||
|
||||
firstConnectedAt := workspaceAgent.FirstConnectedAt
|
||||
if !firstConnectedAt.Valid {
|
||||
firstConnectedAt = sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
lastConnectedAt := sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
disconnectedAt := workspaceAgent.DisconnectedAt
|
||||
updateConnectionTimes := func() error {
|
||||
err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
|
||||
ID: workspaceAgent.ID,
|
||||
FirstConnectedAt: firstConnectedAt,
|
||||
LastConnectedAt: lastConnectedAt,
|
||||
DisconnectedAt: disconnectedAt,
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() {
|
||||
disconnectedAt = sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
_ = updateConnectionTimes()
|
||||
}()
|
||||
|
||||
err = updateConnectionTimes()
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// end span so we don't get long lived trace data
|
||||
tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx))
|
||||
|
||||
api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
|
||||
|
||||
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-session.CloseChan():
|
||||
return
|
||||
case <-ticker.C:
|
||||
lastConnectedAt = sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
err = updateConnectionTimes()
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
err = ensureLatestBuild()
|
||||
if err != nil {
|
||||
// Disconnect agents that are no longer valid.
|
||||
_ = conn.Close(websocket.StatusGoingAway, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (api *API) workspaceAgentICEServers(rw http.ResponseWriter, _ *http.Request) {
|
||||
httpapi.Write(rw, http.StatusOK, api.ICEServers)
|
||||
}
|
||||
|
||||
// userWorkspaceAgentTurn is a user connecting to a remote workspace agent
|
||||
// through turn.
|
||||
func (api *API) userWorkspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
|
||||
workspace := httpmw.WorkspaceParam(r)
|
||||
if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
// Passed authorization
|
||||
api.workspaceAgentTurn(rw, r)
|
||||
}
|
||||
|
||||
// workspaceAgentTurn proxies a WebSocket connection to the TURN server.
|
||||
func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
|
||||
api.websocketWaitMutex.Lock()
|
||||
api.websocketWaitGroup.Add(1)
|
||||
api.websocketWaitMutex.Unlock()
|
||||
defer api.websocketWaitGroup.Done()
|
||||
|
||||
localAddress, _ := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr)
|
||||
remoteAddress := &net.TCPAddr{
|
||||
IP: net.ParseIP(r.RemoteAddr),
|
||||
}
|
||||
// By default requests have the remote address and port.
|
||||
host, port, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid remote address.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
remoteAddress.IP = net.ParseIP(host)
|
||||
remoteAddress.Port, err = strconv.Atoi(port)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Port for remote address %q must be an integer.", r.RemoteAddr),
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
wsConn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to accept websocket.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, wsNetConn := websocketNetConn(r.Context(), wsConn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close() // Also closes conn.
|
||||
// end span so we don't get long lived trace data
|
||||
tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx))
|
||||
|
||||
api.Logger.Debug(ctx, "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
|
||||
select {
|
||||
case <-api.TURNServer.Accept(wsNetConn, remoteAddress, localAddress).Closed():
|
||||
case <-ctx.Done():
|
||||
}
|
||||
api.Logger.Debug(ctx, "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
|
||||
}
|
||||
|
||||
// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
|
||||
// This is used for the web terminal.
|
||||
func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
|
@ -492,75 +200,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
|||
_, _ = io.Copy(ptNetConn, wsNetConn)
|
||||
}
|
||||
|
||||
// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
|
||||
// r.Context() for cancellation if it's use is safe or r.Hijack() has
|
||||
// not been performed.
|
||||
func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (agent.Conn, error) {
|
||||
client, server := provisionersdk.TransportPipe()
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
_ = peerbroker.ProxyListen(ctx, server, peerbroker.ProxyOptions{
|
||||
ChannelID: agentID.String(),
|
||||
Logger: api.Logger.Named("peerbroker-proxy-dial"),
|
||||
Pubsub: api.Pubsub,
|
||||
})
|
||||
_ = client.Close()
|
||||
_ = server.Close()
|
||||
}()
|
||||
|
||||
peerClient := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
|
||||
stream, err := peerClient.NegotiateConnection(ctx)
|
||||
if err != nil {
|
||||
cancelFunc()
|
||||
return nil, xerrors.Errorf("negotiate: %w", err)
|
||||
}
|
||||
options := &peer.ConnOptions{
|
||||
Logger: api.Logger.Named("agent-dialer"),
|
||||
}
|
||||
options.SettingEngine.SetSrflxAcceptanceMinWait(0)
|
||||
options.SettingEngine.SetRelayAcceptanceMinWait(0)
|
||||
// Use the ProxyDialer for the TURN server.
|
||||
// This is required for connections where P2P is not enabled.
|
||||
options.SettingEngine.SetICEProxyDialer(turnconn.ProxyDialer(func() (c net.Conn, err error) {
|
||||
clientPipe, serverPipe := net.Pipe()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = clientPipe.Close()
|
||||
_ = serverPipe.Close()
|
||||
}()
|
||||
localAddress, _ := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr)
|
||||
remoteAddress := &net.TCPAddr{
|
||||
IP: net.ParseIP(r.RemoteAddr),
|
||||
}
|
||||
// By default requests have the remote address and port.
|
||||
host, port, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("split remote address: %w", err)
|
||||
}
|
||||
remoteAddress.IP = net.ParseIP(host)
|
||||
remoteAddress.Port, err = strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("convert remote port: %w", err)
|
||||
}
|
||||
api.TURNServer.Accept(clientPipe, remoteAddress, localAddress)
|
||||
return serverPipe, nil
|
||||
}))
|
||||
peerConn, err := peerbroker.Dial(stream, append(api.ICEServers, turnconn.Proxy), options)
|
||||
if err != nil {
|
||||
cancelFunc()
|
||||
return nil, xerrors.Errorf("dial: %w", err)
|
||||
}
|
||||
go func() {
|
||||
<-peerConn.Closed()
|
||||
cancelFunc()
|
||||
}()
|
||||
return &agent.WebRTCConn{
|
||||
Negotiator: peerClient,
|
||||
Conn: peerConn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (agent.Conn, error) {
|
||||
func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
go func() {
|
||||
<-r.Context().Done()
|
||||
|
@ -587,7 +227,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (a
|
|||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
return &agent.TailnetConn{
|
||||
return &agent.Conn{
|
||||
Conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
@ -609,6 +249,48 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
|
|||
api.websocketWaitMutex.Unlock()
|
||||
defer api.websocketWaitGroup.Done()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to accept websocket.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Internal error fetching workspace build job.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// Ensure the resource is still valid!
|
||||
// We only accept agents for resources on the latest build.
|
||||
ensureLatestBuild := func() error {
|
||||
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if build.ID != latestBuild.ID {
|
||||
return xerrors.New("build is outdated")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err = ensureLatestBuild()
|
||||
if err != nil {
|
||||
api.Logger.Debug(r.Context(), "agent tried to connect from non-latest built",
|
||||
slog.F("resource", resource),
|
||||
slog.F("agent", workspaceAgent),
|
||||
)
|
||||
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Agent trying to connect from non-latest build.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
|
@ -618,12 +300,88 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
|
|||
})
|
||||
return
|
||||
}
|
||||
defer conn.Close(websocket.StatusNormalClosure, "")
|
||||
err = api.TailnetCoordinator.ServeAgent(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), workspaceAgent.ID)
|
||||
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
firstConnectedAt := workspaceAgent.FirstConnectedAt
|
||||
if !firstConnectedAt.Valid {
|
||||
firstConnectedAt = sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
lastConnectedAt := sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
disconnectedAt := workspaceAgent.DisconnectedAt
|
||||
updateConnectionTimes := func() error {
|
||||
err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
|
||||
ID: workspaceAgent.ID,
|
||||
FirstConnectedAt: firstConnectedAt,
|
||||
LastConnectedAt: lastConnectedAt,
|
||||
DisconnectedAt: disconnectedAt,
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() {
|
||||
disconnectedAt = sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
_ = updateConnectionTimes()
|
||||
}()
|
||||
|
||||
err = updateConnectionTimes()
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusInternalError, err.Error())
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// end span so we don't get long lived trace data
|
||||
tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx))
|
||||
api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
|
||||
|
||||
defer conn.Close(websocket.StatusNormalClosure, "")
|
||||
|
||||
closeChan := make(chan struct{})
|
||||
go func() {
|
||||
defer close(closeChan)
|
||||
err := api.TailnetCoordinator.ServeAgent(wsNetConn, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusInternalError, err.Error())
|
||||
return
|
||||
}
|
||||
}()
|
||||
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-closeChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
lastConnectedAt = sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
err = updateConnectionTimes()
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
err := ensureLatestBuild()
|
||||
if err != nil {
|
||||
// Disconnect agents that are no longer valid.
|
||||
_ = conn.Close(websocket.StatusGoingAway, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// workspaceAgentClientCoordinate accepts a WebSocket that reads node network updates.
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
@ -18,7 +17,6 @@ import (
|
|||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/provisionersdk/proto"
|
||||
"github.com/coder/coder/testutil"
|
||||
|
@ -112,7 +110,6 @@ func TestWorkspaceAgentListen(t *testing.T) {
|
|||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
|
||||
})
|
||||
defer func() {
|
||||
|
@ -123,13 +120,15 @@ func TestWorkspaceAgentListen(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
|
||||
conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
_, err = conn.Ping()
|
||||
require.NoError(t, err)
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := conn.Ping()
|
||||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
})
|
||||
|
||||
t.Run("FailNonLatestBuild", func(t *testing.T) {
|
||||
|
@ -202,75 +201,12 @@ func TestWorkspaceAgentListen(t *testing.T) {
|
|||
agentClient := codersdk.New(client.URL)
|
||||
agentClient.SessionToken = authToken
|
||||
|
||||
_, err = agentClient.ListenWorkspaceAgent(ctx, slogtest.Make(t, nil))
|
||||
_, err = agentClient.ListenWorkspaceAgentTailnet(ctx)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "build is outdated")
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkspaceAgentTURN(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
})
|
||||
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionDryRun: echo.ProvisionComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "example",
|
||||
Type: "aws_instance",
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: authToken,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
agentClient := codersdk.New(client.URL)
|
||||
agentClient.SessionToken = authToken
|
||||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
|
||||
})
|
||||
defer func() {
|
||||
_ = agentCloser.Close()
|
||||
}()
|
||||
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
opts := &peer.ConnOptions{
|
||||
Logger: slogtest.Make(t, nil).Named("client"),
|
||||
}
|
||||
// Force a TURN connection!
|
||||
opts.SettingEngine.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4})
|
||||
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, opts)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
_, err = conn.Ping()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestWorkspaceAgentTailnet(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, daemonCloser := coderdtest.NewWithProvisionerCloser(t, nil)
|
||||
|
@ -306,7 +242,6 @@ func TestWorkspaceAgentTailnet(t *testing.T) {
|
|||
agentClient.SessionToken = authToken
|
||||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
|
||||
})
|
||||
|
@ -373,7 +308,6 @@ func TestWorkspaceAgentPTY(t *testing.T) {
|
|||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
|
||||
})
|
||||
defer func() {
|
||||
|
|
|
@ -103,7 +103,6 @@ func setupProxyTest(t *testing.T) (*codersdk.Client, uuid.UUID, codersdk.Workspa
|
|||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
Logger: slogtest.Make(t, nil).Named("agent"),
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
|
|
|
@ -32,11 +32,11 @@ func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
|
|||
}
|
||||
|
||||
// Dialer creates a new agent connection by ID.
|
||||
type Dialer func(r *http.Request, id uuid.UUID) (agent.Conn, error)
|
||||
type Dialer func(r *http.Request, id uuid.UUID) (*agent.Conn, error)
|
||||
|
||||
// Conn wraps an agent connection with a reusable HTTP transport.
|
||||
type Conn struct {
|
||||
agent.Conn
|
||||
*agent.Conn
|
||||
|
||||
locks atomic.Uint64
|
||||
timeoutMutex sync.Mutex
|
||||
|
|
|
@ -35,7 +35,7 @@ func TestCache(t *testing.T) {
|
|||
t.Parallel()
|
||||
t.Run("Same", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
|
||||
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
|
||||
return setupAgent(t, agent.Metadata{}, 0), nil
|
||||
}, 0)
|
||||
defer func() {
|
||||
|
@ -50,7 +50,7 @@ func TestCache(t *testing.T) {
|
|||
t.Run("Expire", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
called := atomic.NewInt32(0)
|
||||
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
|
||||
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
|
||||
called.Add(1)
|
||||
return setupAgent(t, agent.Metadata{}, 0), nil
|
||||
}, time.Microsecond)
|
||||
|
@ -69,7 +69,7 @@ func TestCache(t *testing.T) {
|
|||
})
|
||||
t.Run("NoExpireWhenLocked", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
|
||||
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
|
||||
return setupAgent(t, agent.Metadata{}, 0), nil
|
||||
}, time.Microsecond)
|
||||
defer func() {
|
||||
|
@ -102,7 +102,7 @@ func TestCache(t *testing.T) {
|
|||
}()
|
||||
go server.Serve(random)
|
||||
|
||||
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
|
||||
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
|
||||
return setupAgent(t, agent.Metadata{}, 0), nil
|
||||
}, time.Microsecond)
|
||||
defer func() {
|
||||
|
@ -139,7 +139,7 @@ func TestCache(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) agent.Conn {
|
||||
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) *agent.Conn {
|
||||
metadata.DERPMap = tailnettest.RunDERPAndSTUN(t)
|
||||
|
||||
coordinator := tailnet.NewCoordinator()
|
||||
|
@ -180,7 +180,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
|
|||
return conn.UpdateNodes(node)
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
return &agent.TailnetConn{
|
||||
return &agent.Conn{
|
||||
Conn: conn,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -135,6 +135,7 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
|
|||
decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText))
|
||||
go func() {
|
||||
defer close(logs)
|
||||
defer conn.Close(websocket.StatusGoingAway, "")
|
||||
var log ProvisionerJobLog
|
||||
for {
|
||||
err = decoder.Decode(&log)
|
||||
|
|
|
@ -14,9 +14,6 @@ import (
|
|||
|
||||
"cloud.google.com/go/compute/metadata"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/yamux"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"golang.org/x/net/proxy"
|
||||
"golang.org/x/xerrors"
|
||||
"nhooyr.io/websocket"
|
||||
"nhooyr.io/websocket/wsjson"
|
||||
|
@ -25,11 +22,6 @@ import (
|
|||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/coderd/turnconn"
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/peerbroker"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
"github.com/coder/coder/tailnet"
|
||||
"github.com/coder/retry"
|
||||
)
|
||||
|
@ -206,69 +198,6 @@ func (c *Client) WorkspaceAgentMetadata(ctx context.Context) (agent.Metadata, er
|
|||
return agentMetadata, json.NewDecoder(res.Body).Decode(&agentMetadata)
|
||||
}
|
||||
|
||||
// ListenWorkspaceAgent connects as a workspace agent identifying with the session token.
|
||||
// On each inbound connection request, connection info is fetched.
|
||||
func (c *Client) ListenWorkspaceAgent(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) {
|
||||
serverURL, err := c.URL.Parse("/api/v2/workspaceagents/me/listen")
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create cookie jar: %w", err)
|
||||
}
|
||||
jar.SetCookies(serverURL, []*http.Cookie{{
|
||||
Name: SessionTokenKey,
|
||||
Value: c.SessionToken,
|
||||
}})
|
||||
httpClient := &http.Client{
|
||||
Jar: jar,
|
||||
}
|
||||
conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{
|
||||
HTTPClient: httpClient,
|
||||
// Need to disable compression to avoid a data-race.
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
if res == nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, readBodyAsError(res)
|
||||
}
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("multiplex client: %w", err)
|
||||
}
|
||||
return peerbroker.Listen(session, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) {
|
||||
// This can be cached if it adds to latency too much.
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/v2/workspaceagents/me/iceservers", nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, nil, readBodyAsError(res)
|
||||
}
|
||||
var iceServers []webrtc.ICEServer
|
||||
err = json.NewDecoder(res.Body).Decode(&iceServers)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
options := webrtc.SettingEngine{}
|
||||
options.SetSrflxAcceptanceMinWait(0)
|
||||
options.SetRelayAcceptanceMinWait(0)
|
||||
options.SetICEProxyDialer(c.turnProxyDialer(ctx, httpClient, "/api/v2/workspaceagents/me/turn"))
|
||||
iceServers = append(iceServers, turnconn.Proxy)
|
||||
return iceServers, &peer.ConnOptions{
|
||||
SettingEngine: options,
|
||||
Logger: logger,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, error) {
|
||||
coordinateURL, err := c.URL.Parse("/api/v2/workspaceagents/me/coordinate")
|
||||
if err != nil {
|
||||
|
@ -286,17 +215,20 @@ func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, err
|
|||
Jar: jar,
|
||||
}
|
||||
// nolint:bodyclose
|
||||
conn, _, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
|
||||
conn, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
|
||||
HTTPClient: httpClient,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if res == nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, readBodyAsError(res)
|
||||
}
|
||||
|
||||
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
|
||||
}
|
||||
|
||||
func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (agent.Conn, error) {
|
||||
func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (*agent.Conn, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/connection", agentID), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -349,10 +281,12 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
|
|||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if errors.Is(err, context.Canceled) {
|
||||
_ = ws.Close(websocket.StatusAbnormalClosure, "")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "failed to dial", slog.Error(err))
|
||||
_ = ws.Close(websocket.StatusAbnormalClosure, "")
|
||||
continue
|
||||
}
|
||||
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error {
|
||||
|
@ -362,15 +296,18 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
|
|||
logger.Debug(ctx, "serving coordinator")
|
||||
err = <-errChan
|
||||
if errors.Is(err, context.Canceled) {
|
||||
_ = ws.Close(websocket.StatusAbnormalClosure, "")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "error serving coordinator", slog.Error(err))
|
||||
_ = ws.Close(websocket.StatusAbnormalClosure, "")
|
||||
continue
|
||||
}
|
||||
_ = ws.Close(websocket.StatusAbnormalClosure, "")
|
||||
}
|
||||
}()
|
||||
return &agent.TailnetConn{
|
||||
return &agent.Conn{
|
||||
Conn: conn,
|
||||
CloseFunc: func() {
|
||||
cancelFunc()
|
||||
|
@ -379,78 +316,6 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
|
|||
}, nil
|
||||
}
|
||||
|
||||
// DialWorkspaceAgent creates a connection to the specified resource.
|
||||
func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *peer.ConnOptions) (agent.Conn, error) {
|
||||
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/dial", agentID.String()))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create cookie jar: %w", err)
|
||||
}
|
||||
jar.SetCookies(serverURL, []*http.Cookie{{
|
||||
Name: SessionTokenKey,
|
||||
Value: c.SessionToken,
|
||||
}})
|
||||
httpClient := &http.Client{
|
||||
Jar: jar,
|
||||
}
|
||||
conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{
|
||||
HTTPClient: httpClient,
|
||||
// Need to disable compression to avoid a data-race.
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
if res == nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, readBodyAsError(res)
|
||||
}
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("multiplex client: %w", err)
|
||||
}
|
||||
client := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session))
|
||||
stream, err := client.NegotiateConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("negotiate connection: %w", err)
|
||||
}
|
||||
|
||||
res, err = c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/iceservers", agentID.String()), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, readBodyAsError(res)
|
||||
}
|
||||
var iceServers []webrtc.ICEServer
|
||||
err = json.NewDecoder(res.Body).Decode(&iceServers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if options == nil {
|
||||
options = &peer.ConnOptions{}
|
||||
}
|
||||
options.SettingEngine.SetSrflxAcceptanceMinWait(0)
|
||||
options.SettingEngine.SetRelayAcceptanceMinWait(0)
|
||||
options.SettingEngine.SetICEProxyDialer(c.turnProxyDialer(ctx, httpClient, fmt.Sprintf("/api/v2/workspaceagents/%s/turn", agentID.String())))
|
||||
iceServers = append(iceServers, turnconn.Proxy)
|
||||
|
||||
peerConn, err := peerbroker.Dial(stream, iceServers, options)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("dial peer: %w", err)
|
||||
}
|
||||
return &agent.WebRTCConn{
|
||||
Negotiator: client,
|
||||
Conn: peerConn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// WorkspaceAgent returns an agent by ID.
|
||||
func (c *Client) WorkspaceAgent(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s", id), nil)
|
||||
|
@ -509,27 +374,6 @@ func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, rec
|
|||
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
|
||||
}
|
||||
|
||||
func (c *Client) turnProxyDialer(ctx context.Context, httpClient *http.Client, path string) proxy.Dialer {
|
||||
return turnconn.ProxyDialer(func() (net.Conn, error) {
|
||||
turnURL, err := c.URL.Parse(path)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
conn, res, err := websocket.Dial(ctx, turnURL.String(), &websocket.DialOptions{
|
||||
HTTPClient: httpClient,
|
||||
// Need to disable compression to avoid a data-race.
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
if res == nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, readBodyAsError(res)
|
||||
}
|
||||
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
|
||||
})
|
||||
}
|
||||
|
||||
// AgentReportStats begins a stat streaming connection with the Coder server.
|
||||
// It is resilient to network failures and intermittent coderd issues.
|
||||
func (c *Client) AgentReportStats(
|
||||
|
@ -584,6 +428,7 @@ func (c *Client) AgentReportStats(
|
|||
var req AgentStatsReportRequest
|
||||
err := wsjson.Read(ctx, conn, &req)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, "")
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -597,6 +442,7 @@ func (c *Client) AgentReportStats(
|
|||
|
||||
err = wsjson.Write(ctx, conn, resp)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, "")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
24
go.mod
24
go.mod
|
@ -39,7 +39,7 @@ replace github.com/fatedier/kcp-go => github.com/coder/kcp-go v2.0.4-0.202204091
|
|||
// https://github.com/pion/udp/pull/73
|
||||
replace github.com/pion/udp => github.com/mafredri/udp v0.1.2-0.20220805105907-b2872e92e98d
|
||||
|
||||
// https://github.com/hashicorp/hc-install/pull/68
|
||||
// https://github.com/hashicorp/hc-dinstall/pull/68
|
||||
replace github.com/hashicorp/hc-install => github.com/mafredri/hc-install v0.4.1-0.20220727132613-e91868e28445
|
||||
|
||||
// https://github.com/tcnksm/go-httpstat/pull/29
|
||||
|
@ -119,12 +119,7 @@ require (
|
|||
github.com/nhatthm/otelsql v0.4.0
|
||||
github.com/open-policy-agent/opa v0.41.0
|
||||
github.com/ory/dockertest/v3 v3.9.1
|
||||
github.com/pion/datachannel v1.5.2
|
||||
github.com/pion/logging v0.2.2
|
||||
github.com/pion/transport v0.13.1
|
||||
github.com/pion/turn/v2 v2.0.8
|
||||
github.com/pion/udp v0.1.1
|
||||
github.com/pion/webrtc/v3 v3.1.43
|
||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e
|
||||
github.com/pkg/sftp v1.13.5
|
||||
|
@ -150,7 +145,6 @@ require (
|
|||
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167
|
||||
golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4
|
||||
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b
|
||||
golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094
|
||||
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10
|
||||
|
@ -171,6 +165,11 @@ require (
|
|||
tailscale.com v1.30.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/pion/transport v0.13.1 // indirect
|
||||
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.0.0-rc.1 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
|
||||
|
@ -258,17 +257,6 @@ require (
|
|||
github.com/opencontainers/image-spec v1.0.3-0.20220114050600-8b9d41f48198 // indirect
|
||||
github.com/opencontainers/runc v1.1.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.2 // indirect
|
||||
github.com/pion/dtls/v2 v2.1.5 // indirect
|
||||
github.com/pion/ice/v2 v2.2.6 // indirect
|
||||
github.com/pion/interceptor v0.1.11 // indirect
|
||||
github.com/pion/mdns v0.0.5 // indirect
|
||||
github.com/pion/randutil v0.1.0 // indirect
|
||||
github.com/pion/rtcp v1.2.9 // indirect
|
||||
github.com/pion/rtp v1.7.13 // indirect
|
||||
github.com/pion/sctp v1.8.2 // indirect
|
||||
github.com/pion/sdp/v3 v3.0.5 // indirect
|
||||
github.com/pion/srtp/v2 v2.0.10 // indirect
|
||||
github.com/pion/stun v0.3.5 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.2.0 // indirect
|
||||
|
|
43
go.sum
43
go.sum
|
@ -1451,7 +1451,6 @@ github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108
|
|||
github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0=
|
||||
github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY=
|
||||
github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c=
|
||||
github.com/onsi/gomega v0.0.0-20151007035656-2152b45fa28a/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA=
|
||||
github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA=
|
||||
|
@ -1526,43 +1525,9 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
|
|||
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
||||
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
|
||||
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pion/datachannel v1.5.2 h1:piB93s8LGmbECrpO84DnkIVWasRMk3IimbcXkTQLE6E=
|
||||
github.com/pion/datachannel v1.5.2/go.mod h1:FTGQWaHrdCwIJ1rw6xBIfZVkslikjShim5yr05XFuCQ=
|
||||
github.com/pion/dtls/v2 v2.1.3/go.mod h1:o6+WvyLDAlXF7YiPB/RlskRoeK+/JtuaZa5emwQcWus=
|
||||
github.com/pion/dtls/v2 v2.1.5 h1:jlh2vtIyUBShchoTDqpCCqiYCyRFJ/lvf/gQ8TALs+c=
|
||||
github.com/pion/dtls/v2 v2.1.5/go.mod h1:BqCE7xPZbPSubGasRoDFJeTsyJtdD1FanJYL0JGheqY=
|
||||
github.com/pion/ice/v2 v2.2.6 h1:R/vaLlI1J2gCx141L5PEwtuGAGcyS6e7E0hDeJFq5Ig=
|
||||
github.com/pion/ice/v2 v2.2.6/go.mod h1:SWuHiOGP17lGromHTFadUe1EuPgFh/oCU6FCMZHooVE=
|
||||
github.com/pion/interceptor v0.1.11 h1:00U6OlqxA3FFB50HSg25J/8cWi7P6FbSzw4eFn24Bvs=
|
||||
github.com/pion/interceptor v0.1.11/go.mod h1:tbtKjZY14awXd7Bq0mmWvgtHB5MDaRN7HV3OZ/uy7s8=
|
||||
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
|
||||
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
|
||||
github.com/pion/mdns v0.0.5 h1:Q2oj/JB3NqfzY9xGZ1fPzZzK7sDSD8rZPOvcIQ10BCw=
|
||||
github.com/pion/mdns v0.0.5/go.mod h1:UgssrvdD3mxpi8tMxAXbsppL3vJ4Jipw1mTCW+al01g=
|
||||
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
|
||||
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
|
||||
github.com/pion/rtcp v1.2.9 h1:1ujStwg++IOLIEoOiIQ2s+qBuJ1VN81KW+9pMPsif+U=
|
||||
github.com/pion/rtcp v1.2.9/go.mod h1:qVPhiCzAm4D/rxb6XzKeyZiQK69yJpbUDJSF7TgrqNo=
|
||||
github.com/pion/rtp v1.7.13 h1:qcHwlmtiI50t1XivvoawdCGTP4Uiypzfrsap+bijcoA=
|
||||
github.com/pion/rtp v1.7.13/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko=
|
||||
github.com/pion/sctp v1.8.0/go.mod h1:xFe9cLMZ5Vj6eOzpyiKjT9SwGM4KpK/8Jbw5//jc+0s=
|
||||
github.com/pion/sctp v1.8.2 h1:yBBCIrUMJ4yFICL3RIvR4eh/H2BTTvlligmSTy+3kiA=
|
||||
github.com/pion/sctp v1.8.2/go.mod h1:xFe9cLMZ5Vj6eOzpyiKjT9SwGM4KpK/8Jbw5//jc+0s=
|
||||
github.com/pion/sdp/v3 v3.0.5 h1:ouvI7IgGl+V4CrqskVtr3AaTrPvPisEOxwgpdktctkU=
|
||||
github.com/pion/sdp/v3 v3.0.5/go.mod h1:iiFWFpQO8Fy3S5ldclBkpXqmWy02ns78NOKoLLL0YQw=
|
||||
github.com/pion/srtp/v2 v2.0.10 h1:b8ZvEuI+mrL8hbr/f1YiJFB34UMrOac3R3N1yq2UN0w=
|
||||
github.com/pion/srtp/v2 v2.0.10/go.mod h1:XEeSWaK9PfuMs7zxXyiN252AHPbH12NX5q/CFDWtUuA=
|
||||
github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg=
|
||||
github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA=
|
||||
github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q=
|
||||
github.com/pion/transport v0.12.3/go.mod h1:OViWW9SP2peE/HbwBvARicmAVnesphkNkCVZIWJ6q9A=
|
||||
github.com/pion/transport v0.13.0/go.mod h1:yxm9uXpK9bpBBWkITk13cLo1y5/ur5VQpG22ny6EP7g=
|
||||
github.com/pion/transport v0.13.1 h1:/UH5yLeQtwm2VZIPjxwnNFxjS4DFhyLfS4GlfuKUzfA=
|
||||
github.com/pion/transport v0.13.1/go.mod h1:EBxbqzyv+ZrmDb82XswEE0BjfQFtuw1Nu6sjnjWCsGg=
|
||||
github.com/pion/turn/v2 v2.0.8 h1:KEstL92OUN3k5k8qxsXHpr7WWfrdp7iJZHx99ud8muw=
|
||||
github.com/pion/turn/v2 v2.0.8/go.mod h1:+y7xl719J8bAEVpSXBXvTxStjJv3hbz9YFflvkpcGPw=
|
||||
github.com/pion/webrtc/v3 v3.1.43 h1:YT3ZTO94UT4kSBvZnRAH82+0jJPUruiKr9CEstdlQzk=
|
||||
github.com/pion/webrtc/v3 v3.1.43/go.mod h1:G/J8k0+grVsjC/rjCZ24AKoCCxcFFODgh7zThNZGs0M=
|
||||
github.com/pkg/browser v0.0.0-20210706143420-7d21f8c997e2/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
|
||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU=
|
||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
|
||||
|
@ -2042,9 +2007,6 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
|
|||
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220516162934-403b01795ae8/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167 h1:O8uGbHCqlTp2P6QJSLmCojM4mN6UemYv8K+dCnmHmu0=
|
||||
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
|
@ -2155,7 +2117,6 @@ golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwY
|
|||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20201201195509-5d6afe98e0b7/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
|
@ -2178,7 +2139,6 @@ golang.org/x/net v0.0.0-20210825183410-e898025ed96a/go.mod h1:9nx3DQGgdP8bBQD5qx
|
|||
golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20211201190559-0a0e4e1bb54c/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220107192237-5cfca573fb4d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
|
@ -2186,13 +2146,11 @@ golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qx
|
|||
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.0.0-20220325170049-de3da57026de/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.0.0-20220401154927-543a649e0bdd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.0.0-20220412020605-290c469a71a5/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.0.0-20220531201128-c960675eff93/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.0.0-20220607020251-c690dde0001d/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b h1:ZmngSVLe/wycRns9MKikG9OWIEjGcGAkacif7oYQaUY=
|
||||
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
|
||||
golang.org/x/oauth2 v0.0.0-20180227000427-d7d64896b5ff/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
|
@ -2389,7 +2347,6 @@ golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220624220833-87e55d714810/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
|
317
peer/channel.go
317
peer/channel.go
|
@ -1,317 +0,0 @@
|
|||
package peer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/datachannel"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferedAmountLowThreshold uint64 = 512 * 1024 // 512 KB
|
||||
maxBufferedAmount uint64 = 1024 * 1024 // 1 MB
|
||||
// For some reason messages larger just don't work...
|
||||
// This shouldn't be a huge deal for real-world usage.
|
||||
// See: https://github.com/pion/datachannel/issues/59
|
||||
maxMessageLength = 64 * 1024 // 64 KB
|
||||
)
|
||||
|
||||
// newChannel creates a new channel and initializes it.
|
||||
// The initialization overrides listener handles, and detaches
|
||||
// the channel on open. The datachannel should not be manually
|
||||
// mutated after being passed to this function.
|
||||
func newChannel(conn *Conn, dc *webrtc.DataChannel, opts *ChannelOptions) *Channel {
|
||||
channel := &Channel{
|
||||
opts: opts,
|
||||
conn: conn,
|
||||
dc: dc,
|
||||
|
||||
opened: make(chan struct{}),
|
||||
closed: make(chan struct{}),
|
||||
sendMore: make(chan struct{}, 1),
|
||||
}
|
||||
channel.init()
|
||||
return channel
|
||||
}
|
||||
|
||||
type ChannelOptions struct {
|
||||
// ID is a channel ID that should be used when `Negotiated`
|
||||
// is true.
|
||||
ID uint16
|
||||
|
||||
// Negotiated returns whether the data channel will already
|
||||
// be active on the other end. Defaults to false.
|
||||
Negotiated bool
|
||||
|
||||
// Arbitrary string that can be parsed on `Accept`.
|
||||
Protocol string
|
||||
|
||||
// Unordered determines whether the channel acts like
|
||||
// a UDP connection. Defaults to false.
|
||||
Unordered bool
|
||||
|
||||
// Whether the channel will be left open on disconnect or not.
|
||||
// If true, data will be buffered on either end to be sent
|
||||
// once reconnected. Defaults to false.
|
||||
OpenOnDisconnect bool
|
||||
}
|
||||
|
||||
// Channel represents a WebRTC DataChannel.
|
||||
//
|
||||
// This struct wraps webrtc.DataChannel to add concurrent-safe usage,
|
||||
// data bufferring, and standardized errors for connection state.
|
||||
//
|
||||
// It modifies the default behavior of a DataChannel by closing on
|
||||
// WebRTC PeerConnection failure. This is done to emulate TCP connections.
|
||||
// This option can be changed in the options when creating a Channel.
|
||||
type Channel struct {
|
||||
opts *ChannelOptions
|
||||
|
||||
conn *Conn
|
||||
dc *webrtc.DataChannel
|
||||
// This field can be nil. It becomes set after the DataChannel
|
||||
// has been opened and is detached.
|
||||
rwc datachannel.ReadWriteCloser
|
||||
reader io.Reader
|
||||
|
||||
closed chan struct{}
|
||||
closeMutex sync.Mutex
|
||||
closeError error
|
||||
|
||||
opened chan struct{}
|
||||
|
||||
// sendMore is used to block Write operations on a full buffer.
|
||||
// It's signaled when the buffer can accept more data.
|
||||
sendMore chan struct{}
|
||||
writeMutex sync.Mutex
|
||||
}
|
||||
|
||||
// init attaches listeners to the DataChannel to detect opening,
|
||||
// closing, and when the channel is ready to transmit data.
|
||||
//
|
||||
// This should only be called once on creation.
|
||||
func (c *Channel) init() {
|
||||
// WebRTC connections maintain an internal buffer that can fill when:
|
||||
// 1. Data is being sent faster than it can flush.
|
||||
// 2. The connection is disconnected, but data is still being sent.
|
||||
//
|
||||
// This applies a maximum in-memory buffer for data, and will cause
|
||||
// write operations to block once the threshold is set.
|
||||
c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
|
||||
c.dc.OnBufferedAmountLow(func() {
|
||||
// Grab the lock to protect the sendMore channel from being
|
||||
// closed in between the isClosed check and the send.
|
||||
c.closeMutex.Lock()
|
||||
defer c.closeMutex.Unlock()
|
||||
if c.isClosed() {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-c.closed:
|
||||
case c.sendMore <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
})
|
||||
c.dc.OnClose(func() {
|
||||
c.conn.logger().Debug(context.Background(), "datachannel closing from OnClose", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
|
||||
_ = c.closeWithError(ErrClosed)
|
||||
})
|
||||
c.dc.OnOpen(func() {
|
||||
c.closeMutex.Lock()
|
||||
c.conn.logger().Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
|
||||
var err error
|
||||
c.rwc, err = c.dc.Detach()
|
||||
if err != nil {
|
||||
c.closeMutex.Unlock()
|
||||
_ = c.closeWithError(xerrors.Errorf("detach: %w", err))
|
||||
return
|
||||
}
|
||||
c.closeMutex.Unlock()
|
||||
|
||||
// pion/webrtc will return an io.ErrShortBuffer when a read
|
||||
// is triggered with a buffer size less than the chunks written.
|
||||
//
|
||||
// This makes sense when considering UDP connections, because
|
||||
// buffering of data that has no transmit guarantees is likely
|
||||
// to cause unexpected behavior.
|
||||
//
|
||||
// When ordered, this adds a bufio.Reader. This ensures additional
|
||||
// data on TCP-like connections can be read in parts, while still
|
||||
// being buffered.
|
||||
if c.opts.Unordered {
|
||||
c.reader = c.rwc
|
||||
} else {
|
||||
// This must be the max message length otherwise a short
|
||||
// buffer error can occur.
|
||||
c.reader = bufio.NewReaderSize(c.rwc, maxMessageLength)
|
||||
}
|
||||
close(c.opened)
|
||||
})
|
||||
|
||||
c.conn.dcDisconnectListeners.Add(1)
|
||||
c.conn.dcFailedListeners.Add(1)
|
||||
c.conn.dcClosedWaitGroup.Add(1)
|
||||
go func() {
|
||||
var err error
|
||||
// A DataChannel can disconnect multiple times, so this needs to loop.
|
||||
for {
|
||||
select {
|
||||
case <-c.conn.closedRTC:
|
||||
// If this channel was closed, there's no need to close again.
|
||||
err = c.conn.closeError
|
||||
case <-c.conn.Closed():
|
||||
// If the RTC connection closed with an error, this channel
|
||||
// should end with the same one.
|
||||
err = c.conn.closeError
|
||||
case <-c.conn.dcDisconnectChannel:
|
||||
// If the RTC connection is disconnected, we need to check if
|
||||
// the DataChannel is supposed to end on disconnect.
|
||||
if c.opts.OpenOnDisconnect {
|
||||
continue
|
||||
}
|
||||
err = xerrors.Errorf("rtc disconnected. closing: %w", ErrClosed)
|
||||
case <-c.conn.dcFailedChannel:
|
||||
// If the RTC connection failed, close the Channel.
|
||||
err = ErrFailed
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
_ = c.closeWithError(err)
|
||||
}()
|
||||
}
|
||||
|
||||
// Read blocks until data is received.
|
||||
//
|
||||
// This will block until the underlying DataChannel has been opened.
|
||||
func (c *Channel) Read(bytes []byte) (int, error) {
|
||||
err := c.waitOpened()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
bytesRead, err := c.reader.Read(bytes)
|
||||
if err != nil {
|
||||
if c.isClosed() {
|
||||
return 0, c.closeError
|
||||
}
|
||||
// An EOF always occurs when the connection is closed.
|
||||
// Alternative close errors will occur first if an unexpected
|
||||
// close has occurred.
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
err = c.closeWithError(ErrClosed)
|
||||
}
|
||||
}
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
// Write sends data to the underlying DataChannel.
|
||||
//
|
||||
// This function will block if too much data is being sent.
|
||||
// Data will buffer if the connection is temporarily disconnected,
|
||||
// and will be flushed upon reconnection.
|
||||
//
|
||||
// If the Channel is setup to close on disconnect, any buffered
|
||||
// data will be lost.
|
||||
func (c *Channel) Write(bytes []byte) (n int, err error) {
|
||||
if len(bytes) > maxMessageLength {
|
||||
return 0, xerrors.Errorf("outbound packet larger than maximum message size: %d", maxMessageLength)
|
||||
}
|
||||
|
||||
c.writeMutex.Lock()
|
||||
defer c.writeMutex.Unlock()
|
||||
|
||||
err = c.waitOpened()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if c.dc.BufferedAmount()+uint64(len(bytes)) >= maxBufferedAmount {
|
||||
<-c.sendMore
|
||||
}
|
||||
|
||||
return c.rwc.Write(bytes)
|
||||
}
|
||||
|
||||
// Close gracefully closes the DataChannel.
|
||||
func (c *Channel) Close() error {
|
||||
return c.closeWithError(nil)
|
||||
}
|
||||
|
||||
// Label returns the label of the underlying DataChannel.
|
||||
func (c *Channel) Label() string {
|
||||
return c.dc.Label()
|
||||
}
|
||||
|
||||
// Protocol returns the protocol of the underlying DataChannel.
|
||||
func (c *Channel) Protocol() string {
|
||||
return c.dc.Protocol()
|
||||
}
|
||||
|
||||
// NetConn wraps the DataChannel in a struct fulfilling net.Conn.
|
||||
// Read, Write, and Close operations can still be used on the *Channel struct.
|
||||
func (c *Channel) NetConn() net.Conn {
|
||||
return &fakeNetConn{
|
||||
c: c,
|
||||
addr: &peerAddr{},
|
||||
}
|
||||
}
|
||||
|
||||
// closeWithError closes the Channel with the error provided.
|
||||
// If a graceful close occurs, the error will be nil.
|
||||
func (c *Channel) closeWithError(err error) error {
|
||||
c.closeMutex.Lock()
|
||||
defer c.closeMutex.Unlock()
|
||||
|
||||
if c.isClosed() {
|
||||
return c.closeError
|
||||
}
|
||||
|
||||
c.conn.logger().Debug(context.Background(), "datachannel closing with error", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()), slog.Error(err))
|
||||
if err == nil {
|
||||
c.closeError = ErrClosed
|
||||
} else {
|
||||
c.closeError = err
|
||||
}
|
||||
if c.rwc != nil {
|
||||
_ = c.rwc.Close()
|
||||
}
|
||||
_ = c.dc.Close()
|
||||
|
||||
close(c.closed)
|
||||
close(c.sendMore)
|
||||
c.conn.dcDisconnectListeners.Sub(1)
|
||||
c.conn.dcFailedListeners.Sub(1)
|
||||
c.conn.dcClosedWaitGroup.Done()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Channel) isClosed() bool {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Channel) waitOpened() error {
|
||||
select {
|
||||
case <-c.opened:
|
||||
// Re-check the closed channel to prioritize closure.
|
||||
if c.isClosed() {
|
||||
return c.closeError
|
||||
}
|
||||
return nil
|
||||
case <-c.closed:
|
||||
return c.closeError
|
||||
}
|
||||
}
|
616
peer/conn.go
616
peer/conn.go
|
@ -1,616 +0,0 @@
|
|||
package peer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"go.uber.org/atomic"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDisconnected occurs when the connection has disconnected.
|
||||
// The connection will be attempting to reconnect at this point.
|
||||
ErrDisconnected = xerrors.New("connection is disconnected")
|
||||
// ErrFailed occurs when the connection has failed.
|
||||
// The connection will not retry after this point.
|
||||
ErrFailed = xerrors.New("connection has failed")
|
||||
// ErrClosed occurs when the connection was closed. It wraps io.EOF
|
||||
// to fulfill expected read errors from closed pipes.
|
||||
ErrClosed = xerrors.Errorf("connection was closed: %w", io.EOF)
|
||||
|
||||
// The amount of random bytes sent in a ping.
|
||||
pingDataLength = 64
|
||||
)
|
||||
|
||||
// Client creates a new client connection.
|
||||
func Client(servers []webrtc.ICEServer, opts *ConnOptions) (*Conn, error) {
|
||||
return newWithClientOrServer(servers, true, opts)
|
||||
}
|
||||
|
||||
// Server creates a new server connection.
|
||||
func Server(servers []webrtc.ICEServer, opts *ConnOptions) (*Conn, error) {
|
||||
return newWithClientOrServer(servers, false, opts)
|
||||
}
|
||||
|
||||
// newWithClientOrServer constructs a new connection with the client option.
|
||||
// nolint:revive
|
||||
func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOptions) (*Conn, error) {
|
||||
if opts == nil {
|
||||
opts = &ConnOptions{}
|
||||
}
|
||||
|
||||
opts.SettingEngine.DetachDataChannels()
|
||||
logger := logging.NewDefaultLoggerFactory()
|
||||
logger.DefaultLogLevel = logging.LogLevelDisabled
|
||||
opts.SettingEngine.LoggerFactory = logger
|
||||
api := webrtc.NewAPI(webrtc.WithSettingEngine(opts.SettingEngine))
|
||||
rtc, err := api.NewPeerConnection(webrtc.Configuration{
|
||||
ICEServers: servers,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create peer connection: %w", err)
|
||||
}
|
||||
conn := &Conn{
|
||||
pingChannelID: 1,
|
||||
pingEchoChannelID: 2,
|
||||
rtc: rtc,
|
||||
offerer: client,
|
||||
closed: make(chan struct{}),
|
||||
closedRTC: make(chan struct{}),
|
||||
closedICE: make(chan struct{}),
|
||||
dcOpenChannel: make(chan *webrtc.DataChannel, 8),
|
||||
dcDisconnectChannel: make(chan struct{}),
|
||||
dcFailedChannel: make(chan struct{}),
|
||||
localCandidateChannel: make(chan webrtc.ICECandidateInit),
|
||||
localSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1),
|
||||
negotiated: make(chan struct{}),
|
||||
remoteSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1),
|
||||
settingEngine: opts.SettingEngine,
|
||||
}
|
||||
conn.loggerValue.Store(opts.Logger)
|
||||
if client {
|
||||
// If we're the client, we want to flip the echo and
|
||||
// ping channel IDs so pings don't accidentally hit each other.
|
||||
conn.pingChannelID, conn.pingEchoChannelID = conn.pingEchoChannelID, conn.pingChannelID
|
||||
}
|
||||
err = conn.init()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("init: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
type ConnOptions struct {
|
||||
Logger slog.Logger
|
||||
|
||||
// Enables customization on the underlying WebRTC connection.
|
||||
SettingEngine webrtc.SettingEngine
|
||||
}
|
||||
|
||||
// Conn represents a WebRTC peer connection.
|
||||
//
|
||||
// This struct wraps webrtc.PeerConnection to add bidirectional pings,
|
||||
// concurrent-safe webrtc.DataChannel, and standardized errors for connection state.
|
||||
type Conn struct {
|
||||
rtc *webrtc.PeerConnection
|
||||
// Determines whether this connection will send the offer or the answer.
|
||||
offerer bool
|
||||
|
||||
closed chan struct{}
|
||||
closedRTC chan struct{}
|
||||
closedRTCMutex sync.Mutex
|
||||
closedICE chan struct{}
|
||||
closedICEMutex sync.Mutex
|
||||
closeMutex sync.Mutex
|
||||
closeError error
|
||||
|
||||
dcCreateMutex sync.Mutex
|
||||
dcOpenChannel chan *webrtc.DataChannel
|
||||
dcDisconnectChannel chan struct{}
|
||||
dcDisconnectListeners atomic.Uint32
|
||||
dcFailedChannel chan struct{}
|
||||
dcFailedListeners atomic.Uint32
|
||||
dcClosedWaitGroup sync.WaitGroup
|
||||
|
||||
localCandidateChannel chan webrtc.ICECandidateInit
|
||||
localSessionDescriptionChannel chan webrtc.SessionDescription
|
||||
remoteSessionDescriptionChannel chan webrtc.SessionDescription
|
||||
|
||||
negotiated chan struct{}
|
||||
|
||||
loggerValue atomic.Value
|
||||
settingEngine webrtc.SettingEngine
|
||||
|
||||
pingChannelID uint16
|
||||
pingEchoChannelID uint16
|
||||
|
||||
pingEchoChan *Channel
|
||||
pingEchoOnce sync.Once
|
||||
pingEchoError error
|
||||
pingMutex sync.Mutex
|
||||
pingOnce sync.Once
|
||||
pingChan *Channel
|
||||
pingError error
|
||||
}
|
||||
|
||||
func (c *Conn) logger() slog.Logger {
|
||||
log, valid := c.loggerValue.Load().(slog.Logger)
|
||||
if !valid {
|
||||
return slog.Logger{}
|
||||
}
|
||||
|
||||
return log
|
||||
}
|
||||
|
||||
func (c *Conn) init() error {
|
||||
c.rtc.OnNegotiationNeeded(c.negotiate)
|
||||
c.rtc.OnICEConnectionStateChange(func(iceConnectionState webrtc.ICEConnectionState) {
|
||||
c.closedICEMutex.Lock()
|
||||
defer c.closedICEMutex.Unlock()
|
||||
select {
|
||||
case <-c.closedICE:
|
||||
// Don't log more state changes if we've already closed.
|
||||
return
|
||||
default:
|
||||
c.logger().Debug(context.Background(), "ice connection state updated",
|
||||
slog.F("state", iceConnectionState))
|
||||
|
||||
if iceConnectionState == webrtc.ICEConnectionStateClosed {
|
||||
// pion/webrtc can update this state multiple times.
|
||||
// A connection can never become un-closed, so we
|
||||
// close the channel if it isn't already.
|
||||
close(c.closedICE)
|
||||
}
|
||||
}
|
||||
})
|
||||
c.rtc.OnICEGatheringStateChange(func(iceGatherState webrtc.ICEGathererState) {
|
||||
c.closedICEMutex.Lock()
|
||||
defer c.closedICEMutex.Unlock()
|
||||
select {
|
||||
case <-c.closedICE:
|
||||
// Don't log more state changes if we've already closed.
|
||||
return
|
||||
default:
|
||||
c.logger().Debug(context.Background(), "ice gathering state updated",
|
||||
slog.F("state", iceGatherState))
|
||||
|
||||
if iceGatherState == webrtc.ICEGathererStateClosed {
|
||||
// pion/webrtc can update this state multiple times.
|
||||
// A connection can never become un-closed, so we
|
||||
// close the channel if it isn't already.
|
||||
close(c.closedICE)
|
||||
}
|
||||
}
|
||||
})
|
||||
c.rtc.OnConnectionStateChange(func(peerConnectionState webrtc.PeerConnectionState) {
|
||||
go func() {
|
||||
c.closeMutex.Lock()
|
||||
defer c.closeMutex.Unlock()
|
||||
if c.isClosed() {
|
||||
return
|
||||
}
|
||||
c.logger().Debug(context.Background(), "rtc connection updated",
|
||||
slog.F("state", peerConnectionState))
|
||||
}()
|
||||
|
||||
switch peerConnectionState {
|
||||
case webrtc.PeerConnectionStateDisconnected:
|
||||
for i := 0; i < int(c.dcDisconnectListeners.Load()); i++ {
|
||||
select {
|
||||
case c.dcDisconnectChannel <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
case webrtc.PeerConnectionStateFailed:
|
||||
for i := 0; i < int(c.dcFailedListeners.Load()); i++ {
|
||||
select {
|
||||
case c.dcFailedChannel <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
case webrtc.PeerConnectionStateClosed:
|
||||
// pion/webrtc can update this state multiple times.
|
||||
// A connection can never become un-closed, so we
|
||||
// close the channel if it isn't already.
|
||||
c.closedRTCMutex.Lock()
|
||||
defer c.closedRTCMutex.Unlock()
|
||||
select {
|
||||
case <-c.closedRTC:
|
||||
default:
|
||||
close(c.closedRTC)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// These functions need to check if the conn is closed, because they can be
|
||||
// called after being closed.
|
||||
c.rtc.OnSignalingStateChange(func(signalState webrtc.SignalingState) {
|
||||
c.logger().Debug(context.Background(), "signaling state updated",
|
||||
slog.F("state", signalState))
|
||||
})
|
||||
c.rtc.SCTP().Transport().OnStateChange(func(dtlsTransportState webrtc.DTLSTransportState) {
|
||||
c.logger().Debug(context.Background(), "dtls transport state updated",
|
||||
slog.F("state", dtlsTransportState))
|
||||
})
|
||||
c.rtc.SCTP().Transport().ICETransport().OnSelectedCandidatePairChange(func(candidatePair *webrtc.ICECandidatePair) {
|
||||
c.logger().Debug(context.Background(), "selected candidate pair changed",
|
||||
slog.F("local", candidatePair.Local), slog.F("remote", candidatePair.Remote))
|
||||
})
|
||||
c.rtc.OnICECandidate(func(iceCandidate *webrtc.ICECandidate) {
|
||||
if iceCandidate == nil {
|
||||
return
|
||||
}
|
||||
// Run this in a goroutine so we don't block pion/webrtc
|
||||
// from continuing.
|
||||
go func() {
|
||||
c.logger().Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate))
|
||||
select {
|
||||
case <-c.closed:
|
||||
case c.localCandidateChannel <- iceCandidate.ToJSON():
|
||||
}
|
||||
}()
|
||||
})
|
||||
c.rtc.OnDataChannel(func(dc *webrtc.DataChannel) {
|
||||
go func() {
|
||||
select {
|
||||
case <-c.closed:
|
||||
case c.dcOpenChannel <- dc:
|
||||
}
|
||||
}()
|
||||
})
|
||||
_, err := c.pingChannel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = c.pingEchoChannel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// negotiate is triggered when a connection is ready to be established.
|
||||
// See trickle ICE for the expected exchange: https://webrtchacks.com/trickle-ice/
|
||||
func (c *Conn) negotiate() {
|
||||
c.logger().Debug(context.Background(), "negotiating")
|
||||
// ICE candidates cannot be added until SessionDescriptions have been
|
||||
// exchanged between peers.
|
||||
defer func() {
|
||||
select {
|
||||
case <-c.negotiated:
|
||||
default:
|
||||
close(c.negotiated)
|
||||
}
|
||||
}()
|
||||
|
||||
if c.offerer {
|
||||
offer, err := c.rtc.CreateOffer(&webrtc.OfferOptions{})
|
||||
if err != nil {
|
||||
_ = c.CloseWithError(xerrors.Errorf("create offer: %w", err))
|
||||
return
|
||||
}
|
||||
// pion/webrtc will panic if Close is called while this
|
||||
// function is being executed.
|
||||
c.closeMutex.Lock()
|
||||
err = c.rtc.SetLocalDescription(offer)
|
||||
c.closeMutex.Unlock()
|
||||
if err != nil {
|
||||
_ = c.CloseWithError(xerrors.Errorf("set local description: %w", err))
|
||||
return
|
||||
}
|
||||
c.logger().Debug(context.Background(), "sending offer", slog.F("offer", offer))
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
case c.localSessionDescriptionChannel <- offer:
|
||||
}
|
||||
c.logger().Debug(context.Background(), "sent offer")
|
||||
}
|
||||
|
||||
var sessionDescription webrtc.SessionDescription
|
||||
c.logger().Debug(context.Background(), "awaiting remote description...")
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
case sessionDescription = <-c.remoteSessionDescriptionChannel:
|
||||
}
|
||||
c.logger().Debug(context.Background(), "setting remote description")
|
||||
|
||||
err := c.rtc.SetRemoteDescription(sessionDescription)
|
||||
if err != nil {
|
||||
_ = c.CloseWithError(xerrors.Errorf("set remote description (closed %v): %w", c.isClosed(), err))
|
||||
return
|
||||
}
|
||||
|
||||
if !c.offerer {
|
||||
answer, err := c.rtc.CreateAnswer(&webrtc.AnswerOptions{})
|
||||
if err != nil {
|
||||
_ = c.CloseWithError(xerrors.Errorf("create answer: %w", err))
|
||||
return
|
||||
}
|
||||
// pion/webrtc will panic if Close is called while this
|
||||
// function is being executed.
|
||||
c.closeMutex.Lock()
|
||||
err = c.rtc.SetLocalDescription(answer)
|
||||
c.closeMutex.Unlock()
|
||||
if err != nil {
|
||||
_ = c.CloseWithError(xerrors.Errorf("set local description: %w", err))
|
||||
return
|
||||
}
|
||||
c.logger().Debug(context.Background(), "sending answer", slog.F("answer", answer))
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
case c.localSessionDescriptionChannel <- answer:
|
||||
}
|
||||
c.logger().Debug(context.Background(), "sent answer")
|
||||
}
|
||||
}
|
||||
|
||||
// AddRemoteCandidate adds a remote candidate to the RTC connection.
|
||||
func (c *Conn) AddRemoteCandidate(i webrtc.ICECandidateInit) {
|
||||
if c.isClosed() {
|
||||
return
|
||||
}
|
||||
// This must occur in a goroutine to allow the SessionDescriptions
|
||||
// to be exchanged first.
|
||||
go func() {
|
||||
select {
|
||||
case <-c.closed:
|
||||
case <-c.negotiated:
|
||||
}
|
||||
if c.isClosed() {
|
||||
return
|
||||
}
|
||||
c.logger().Debug(context.Background(), "accepting candidate", slog.F("candidate", i.Candidate))
|
||||
err := c.rtc.AddICECandidate(i)
|
||||
if err != nil {
|
||||
if c.rtc.ConnectionState() == webrtc.PeerConnectionStateClosed {
|
||||
return
|
||||
}
|
||||
_ = c.CloseWithError(xerrors.Errorf("accept candidate: %w", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// SetRemoteSessionDescription sets the remote description for the WebRTC connection.
|
||||
func (c *Conn) SetRemoteSessionDescription(sessionDescription webrtc.SessionDescription) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
case c.remoteSessionDescriptionChannel <- sessionDescription:
|
||||
}
|
||||
}
|
||||
|
||||
// LocalSessionDescription returns a channel that emits a session description
|
||||
// when one is required to be exchanged.
|
||||
func (c *Conn) LocalSessionDescription() <-chan webrtc.SessionDescription {
|
||||
return c.localSessionDescriptionChannel
|
||||
}
|
||||
|
||||
// LocalCandidate returns a channel that emits when a local candidate
|
||||
// needs to be exchanged with a remote connection.
|
||||
func (c *Conn) LocalCandidate() <-chan webrtc.ICECandidateInit {
|
||||
return c.localCandidateChannel
|
||||
}
|
||||
|
||||
func (c *Conn) pingChannel() (*Channel, error) {
|
||||
c.pingOnce.Do(func() {
|
||||
c.pingChan, c.pingError = c.dialChannel(context.Background(), "ping", &ChannelOptions{
|
||||
ID: c.pingChannelID,
|
||||
Negotiated: true,
|
||||
OpenOnDisconnect: true,
|
||||
})
|
||||
if c.pingError != nil {
|
||||
return
|
||||
}
|
||||
})
|
||||
return c.pingChan, c.pingError
|
||||
}
|
||||
|
||||
func (c *Conn) pingEchoChannel() (*Channel, error) {
|
||||
c.pingEchoOnce.Do(func() {
|
||||
c.pingEchoChan, c.pingEchoError = c.dialChannel(context.Background(), "echo", &ChannelOptions{
|
||||
ID: c.pingEchoChannelID,
|
||||
Negotiated: true,
|
||||
OpenOnDisconnect: true,
|
||||
})
|
||||
if c.pingEchoError != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
data := make([]byte, pingDataLength)
|
||||
bytesRead, err := c.pingEchoChan.Read(data)
|
||||
if err != nil {
|
||||
_ = c.CloseWithError(xerrors.Errorf("read ping echo channel: %w", err))
|
||||
return
|
||||
}
|
||||
_, err = c.pingEchoChan.Write(data[:bytesRead])
|
||||
if err != nil {
|
||||
_ = c.CloseWithError(xerrors.Errorf("write ping echo channel: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
return c.pingEchoChan, c.pingEchoError
|
||||
}
|
||||
|
||||
// SetConfiguration applies options to the WebRTC connection.
|
||||
// Generally used for updating transport options, like ICE servers.
|
||||
func (c *Conn) SetConfiguration(configuration webrtc.Configuration) error {
|
||||
return c.rtc.SetConfiguration(configuration)
|
||||
}
|
||||
|
||||
// Accept blocks waiting for a channel to be opened.
|
||||
func (c *Conn) Accept(ctx context.Context) (*Channel, error) {
|
||||
var dataChannel *webrtc.DataChannel
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-c.closed:
|
||||
return nil, c.closeError
|
||||
case dataChannel = <-c.dcOpenChannel:
|
||||
}
|
||||
|
||||
return newChannel(c, dataChannel, &ChannelOptions{}), nil
|
||||
}
|
||||
|
||||
// CreateChannel creates a new DataChannel.
|
||||
func (c *Conn) CreateChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
|
||||
if opts == nil {
|
||||
opts = &ChannelOptions{}
|
||||
}
|
||||
if opts.ID == c.pingChannelID || opts.ID == c.pingEchoChannelID {
|
||||
return nil, xerrors.Errorf("datachannel id %d and %d are reserved for ping", c.pingChannelID, c.pingEchoChannelID)
|
||||
}
|
||||
return c.dialChannel(ctx, label, opts)
|
||||
}
|
||||
|
||||
func (c *Conn) dialChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
|
||||
// pion/webrtc is slower when opening multiple channels
|
||||
// in parallel than it is sequentially.
|
||||
c.dcCreateMutex.Lock()
|
||||
defer c.dcCreateMutex.Unlock()
|
||||
|
||||
c.logger().Debug(ctx, "creating data channel", slog.F("label", label), slog.F("opts", opts))
|
||||
var id *uint16
|
||||
if opts.ID != 0 {
|
||||
id = &opts.ID
|
||||
}
|
||||
ordered := true
|
||||
if opts.Unordered {
|
||||
ordered = false
|
||||
}
|
||||
if opts.OpenOnDisconnect && !opts.Negotiated {
|
||||
return nil, xerrors.New("OpenOnDisconnect is only allowed for Negotiated channels")
|
||||
}
|
||||
if c.isClosed() {
|
||||
return nil, xerrors.Errorf("closed: %w", c.closeError)
|
||||
}
|
||||
|
||||
dataChannel, err := c.rtc.CreateDataChannel(label, &webrtc.DataChannelInit{
|
||||
ID: id,
|
||||
Negotiated: &opts.Negotiated,
|
||||
Ordered: &ordered,
|
||||
Protocol: &opts.Protocol,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create data channel: %w", err)
|
||||
}
|
||||
return newChannel(c, dataChannel, opts), nil
|
||||
}
|
||||
|
||||
// Ping returns the duration it took to round-trip data.
|
||||
// Multiple pings cannot occur at the same time, so this function will block.
|
||||
func (c *Conn) Ping() (time.Duration, error) {
|
||||
// Pings are not async, so we need a mutex.
|
||||
c.pingMutex.Lock()
|
||||
defer c.pingMutex.Unlock()
|
||||
|
||||
ping, err := c.pingChannel()
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("get ping channel: %w", err)
|
||||
}
|
||||
pingDataSent := make([]byte, pingDataLength)
|
||||
_, err = rand.Read(pingDataSent)
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("read random ping data: %w", err)
|
||||
}
|
||||
start := time.Now()
|
||||
_, err = ping.Write(pingDataSent)
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("send ping: %w", err)
|
||||
}
|
||||
c.logger().Debug(context.Background(), "wrote ping",
|
||||
slog.F("connection_state", c.rtc.ConnectionState()))
|
||||
|
||||
pingDataReceived := make([]byte, pingDataLength)
|
||||
_, err = ping.Read(pingDataReceived)
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("read ping: %w", err)
|
||||
}
|
||||
end := time.Now()
|
||||
if !bytes.Equal(pingDataSent, pingDataReceived) {
|
||||
return 0, xerrors.Errorf("ping data inconsistency sent != received")
|
||||
}
|
||||
return end.Sub(start), nil
|
||||
}
|
||||
|
||||
func (c *Conn) Closed() <-chan struct{} {
|
||||
return c.closed
|
||||
}
|
||||
|
||||
// Close closes the connection and frees all associated resources.
|
||||
func (c *Conn) Close() error {
|
||||
return c.CloseWithError(nil)
|
||||
}
|
||||
|
||||
func (c *Conn) isClosed() bool {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// CloseWithError closes the connection; subsequent reads/writes will return the error err.
|
||||
func (c *Conn) CloseWithError(err error) error {
|
||||
c.closeMutex.Lock()
|
||||
defer c.closeMutex.Unlock()
|
||||
|
||||
if c.isClosed() {
|
||||
return c.closeError
|
||||
}
|
||||
|
||||
logger := c.logger()
|
||||
|
||||
logger.Debug(context.Background(), "closing conn with error", slog.Error(err))
|
||||
if err == nil {
|
||||
c.closeError = ErrClosed
|
||||
} else {
|
||||
c.closeError = err
|
||||
}
|
||||
|
||||
if ch, _ := c.pingChannel(); ch != nil {
|
||||
_ = ch.closeWithError(c.closeError)
|
||||
}
|
||||
// If the WebRTC connection has already been closed (due to failure or disconnect),
|
||||
// this call will return an error that isn't typed. We don't check the error because
|
||||
// closing an already closed connection isn't an issue for us.
|
||||
_ = c.rtc.Close()
|
||||
|
||||
// Waiting for pion/webrtc to report closed state on both of these
|
||||
// ensures no goroutine leaks.
|
||||
if c.rtc.ConnectionState() != webrtc.PeerConnectionStateNew {
|
||||
logger.Debug(context.Background(), "waiting for rtc connection close...")
|
||||
<-c.closedRTC
|
||||
}
|
||||
if c.rtc.ICEConnectionState() != webrtc.ICEConnectionStateNew {
|
||||
logger.Debug(context.Background(), "waiting for ice connection close...")
|
||||
<-c.closedICE
|
||||
}
|
||||
|
||||
// Waits for all DataChannels to exit before officially labeling as closed.
|
||||
// All logging, goroutines, and async functionality is cleaned up after this.
|
||||
c.dcClosedWaitGroup.Wait()
|
||||
|
||||
// Disable logging!
|
||||
c.loggerValue.Store(slog.Logger{})
|
||||
logger.Sync()
|
||||
|
||||
logger.Debug(context.Background(), "closed")
|
||||
close(c.closed)
|
||||
return err
|
||||
}
|
|
@ -1,434 +0,0 @@
|
|||
package peer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/transport/vnet"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
var (
|
||||
disconnectedTimeout = func() time.Duration {
|
||||
// Connection state is unfortunately time-based. When resources are
|
||||
// contended, a connection can take greater than this timeout to
|
||||
// handshake, which results in a test flake.
|
||||
//
|
||||
// During local testing resources are rarely contended. Reducing this
|
||||
// timeout leads to faster local development.
|
||||
//
|
||||
// In CI resources are frequently contended, so increasing this value
|
||||
// results in less flakes.
|
||||
if os.Getenv("CI") == "true" {
|
||||
return time.Second
|
||||
}
|
||||
return 100 * time.Millisecond
|
||||
}()
|
||||
failedTimeout = disconnectedTimeout * 3
|
||||
keepAliveInterval = time.Millisecond * 2
|
||||
|
||||
// There's a global race in the vnet library allocation code.
|
||||
// This mutex locks around the creation of the vnet.
|
||||
vnetMutex = sync.Mutex{}
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// pion/ice doesn't properly close immediately. The solution for this isn't yet known. See:
|
||||
// https://github.com/pion/ice/pull/413
|
||||
goleak.VerifyTestMain(m,
|
||||
goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).startOnConnectionStateChangeRoutine.func1"),
|
||||
goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).startOnConnectionStateChangeRoutine.func2"),
|
||||
goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).taskLoop"),
|
||||
goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"),
|
||||
)
|
||||
}
|
||||
|
||||
func TestConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, _ := createPair(t)
|
||||
exchange(t, client, server)
|
||||
_, err := client.Ping()
|
||||
require.NoError(t, err)
|
||||
_, err = server.Ping()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("PingNetworkOffline", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, wan := createPair(t)
|
||||
exchange(t, client, server)
|
||||
_, err := server.Ping()
|
||||
require.NoError(t, err)
|
||||
err = wan.Stop()
|
||||
require.NoError(t, err)
|
||||
_, err = server.Ping()
|
||||
require.ErrorIs(t, err, peer.ErrFailed)
|
||||
})
|
||||
|
||||
t.Run("PingReconnect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, wan := createPair(t)
|
||||
exchange(t, client, server)
|
||||
_, err := server.Ping()
|
||||
require.NoError(t, err)
|
||||
// Create a channel that closes on disconnect.
|
||||
channel, err := server.CreateChannel(context.Background(), "wow", nil)
|
||||
assert.NoError(t, err)
|
||||
defer channel.Close()
|
||||
|
||||
err = wan.Stop()
|
||||
require.NoError(t, err)
|
||||
// Once the connection is marked as disconnected, this
|
||||
// channel will be closed.
|
||||
_, err = channel.Read(make([]byte, 4))
|
||||
assert.ErrorIs(t, err, peer.ErrClosed)
|
||||
err = wan.Start()
|
||||
require.NoError(t, err)
|
||||
_, err = server.Ping()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Accept", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, _ := createPair(t)
|
||||
exchange(t, client, server)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
|
||||
require.NoError(t, err)
|
||||
defer cch.Close()
|
||||
|
||||
sch, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sch.Close()
|
||||
|
||||
_ = cch.Close()
|
||||
_, err = sch.Read(make([]byte, 4))
|
||||
require.ErrorIs(t, err, peer.ErrClosed)
|
||||
})
|
||||
|
||||
t.Run("AcceptNetworkOffline", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, wan := createPair(t)
|
||||
exchange(t, client, server)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
|
||||
require.NoError(t, err)
|
||||
defer cch.Close()
|
||||
sch, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sch.Close()
|
||||
|
||||
err = wan.Stop()
|
||||
require.NoError(t, err)
|
||||
_ = cch.Close()
|
||||
_, err = sch.Read(make([]byte, 4))
|
||||
require.ErrorIs(t, err, peer.ErrClosed)
|
||||
})
|
||||
|
||||
t.Run("Buffering", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, _ := createPair(t)
|
||||
exchange(t, client, server)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
|
||||
require.NoError(t, err)
|
||||
defer cch.Close()
|
||||
|
||||
readErr := make(chan error, 1)
|
||||
go func() {
|
||||
sch, err := server.Accept(ctx)
|
||||
if err != nil {
|
||||
readErr <- err
|
||||
_ = cch.Close()
|
||||
return
|
||||
}
|
||||
defer sch.Close()
|
||||
|
||||
bytes := make([]byte, 4096)
|
||||
for {
|
||||
_, err = sch.Read(bytes)
|
||||
if err != nil {
|
||||
readErr <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
bytes := make([]byte, 4096)
|
||||
for i := 0; i < 1024; i++ {
|
||||
_, err = cch.Write(bytes)
|
||||
require.NoError(t, err, "write i=%d", i)
|
||||
}
|
||||
_ = cch.Close()
|
||||
|
||||
select {
|
||||
case err = <-readErr:
|
||||
require.ErrorIs(t, err, peer.ErrClosed, "read error")
|
||||
case <-ctx.Done():
|
||||
require.Fail(t, "timeout waiting for read error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NetConn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, _ := createPair(t)
|
||||
exchange(t, client, server)
|
||||
srv, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
go func() {
|
||||
sch, err := server.Accept(ctx)
|
||||
if err != nil {
|
||||
assert.NoError(t, err)
|
||||
return
|
||||
}
|
||||
defer sch.Close()
|
||||
|
||||
nc2 := sch.NetConn()
|
||||
defer nc2.Close()
|
||||
|
||||
nc1, err := net.Dial("tcp", srv.Addr().String())
|
||||
if err != nil {
|
||||
assert.NoError(t, err)
|
||||
return
|
||||
}
|
||||
defer nc1.Close()
|
||||
|
||||
go func() {
|
||||
defer nc1.Close()
|
||||
defer nc2.Close()
|
||||
_, _ = io.Copy(nc1, nc2)
|
||||
}()
|
||||
_, _ = io.Copy(nc2, nc1)
|
||||
}()
|
||||
go func() {
|
||||
server := http.Server{
|
||||
ReadHeaderTimeout: time.Minute,
|
||||
Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(200)
|
||||
}),
|
||||
}
|
||||
defer server.Close()
|
||||
_ = server.Serve(srv)
|
||||
}()
|
||||
|
||||
//nolint:forcetypeassert
|
||||
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
var cch *peer.Channel
|
||||
defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
cch, err = client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cch.NetConn(), nil
|
||||
}
|
||||
c := http.Client{
|
||||
Transport: defaultTransport,
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost/", nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := c.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, resp.StatusCode, 200)
|
||||
// Triggers any connections to close.
|
||||
// This test below ensures the DataChannel actually closes.
|
||||
defaultTransport.CloseIdleConnections()
|
||||
err = cch.Close()
|
||||
require.ErrorIs(t, err, peer.ErrClosed)
|
||||
})
|
||||
|
||||
t.Run("CloseBeforeNegotiate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, _ := createPair(t)
|
||||
exchange(t, client, server)
|
||||
err := client.Close()
|
||||
require.NoError(t, err)
|
||||
err = server.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("CloseWithError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := peer.Client([]webrtc.ICEServer{}, nil)
|
||||
require.NoError(t, err)
|
||||
expectedErr := xerrors.New("wow")
|
||||
_ = conn.CloseWithError(expectedErr)
|
||||
_, err = conn.CreateChannel(context.Background(), "", nil)
|
||||
require.ErrorIs(t, err, expectedErr)
|
||||
})
|
||||
|
||||
t.Run("PingConcurrent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, _ := createPair(t)
|
||||
exchange(t, client, server)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := client.Ping()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := server.Ping()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run("CandidateBeforeSessionDescription", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, _ := createPair(t)
|
||||
server.SetRemoteSessionDescription(<-client.LocalSessionDescription())
|
||||
sdp := <-server.LocalSessionDescription()
|
||||
client.AddRemoteCandidate(<-server.LocalCandidate())
|
||||
client.SetRemoteSessionDescription(sdp)
|
||||
server.AddRemoteCandidate(<-client.LocalCandidate())
|
||||
_, err := client.Ping()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("ShortBuffer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server, _ := createPair(t)
|
||||
exchange(t, client, server)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
go func() {
|
||||
channel, err := client.CreateChannel(ctx, "test", nil)
|
||||
if err != nil {
|
||||
assert.NoError(t, err)
|
||||
return
|
||||
}
|
||||
defer channel.Close()
|
||||
_, err = channel.Write([]byte{1, 2})
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
channel, err := server.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
defer channel.Close()
|
||||
data := make([]byte, 1)
|
||||
_, err = channel.Read(data)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint8(0x1), data[0])
|
||||
_, err = channel.Read(data)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint8(0x2), data[0])
|
||||
})
|
||||
}
|
||||
|
||||
func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) {
|
||||
loggingFactory := logging.NewDefaultLoggerFactory()
|
||||
loggingFactory.DefaultLogLevel = logging.LogLevelDisabled
|
||||
vnetMutex.Lock()
|
||||
defer vnetMutex.Unlock()
|
||||
wan, err := vnet.NewRouter(&vnet.RouterConfig{
|
||||
CIDR: "1.2.3.0/24",
|
||||
LoggerFactory: loggingFactory,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
c1Net := vnet.NewNet(&vnet.NetConfig{
|
||||
StaticIPs: []string{"1.2.3.4"},
|
||||
})
|
||||
err = wan.AddNet(c1Net)
|
||||
require.NoError(t, err)
|
||||
c2Net := vnet.NewNet(&vnet.NetConfig{
|
||||
StaticIPs: []string{"1.2.3.5"},
|
||||
})
|
||||
err = wan.AddNet(c2Net)
|
||||
require.NoError(t, err)
|
||||
|
||||
c1SettingEngine := webrtc.SettingEngine{}
|
||||
c1SettingEngine.SetVNet(c1Net)
|
||||
c1SettingEngine.SetPrflxAcceptanceMinWait(0)
|
||||
c1SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
|
||||
channel1, err := peer.Client([]webrtc.ICEServer{{}}, &peer.ConnOptions{
|
||||
SettingEngine: c1SettingEngine,
|
||||
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
channel1.Close()
|
||||
})
|
||||
c2SettingEngine := webrtc.SettingEngine{}
|
||||
c2SettingEngine.SetVNet(c2Net)
|
||||
c2SettingEngine.SetPrflxAcceptanceMinWait(0)
|
||||
c2SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
|
||||
channel2, err := peer.Server([]webrtc.ICEServer{{}}, &peer.ConnOptions{
|
||||
SettingEngine: c2SettingEngine,
|
||||
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
channel2.Close()
|
||||
})
|
||||
|
||||
err = wan.Start()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = wan.Stop()
|
||||
})
|
||||
|
||||
return channel1, channel2, wan
|
||||
}
|
||||
|
||||
func exchange(t *testing.T, client, server *peer.Conn) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
t.Cleanup(func() {
|
||||
_ = client.Close()
|
||||
_ = server.Close()
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case c := <-server.LocalCandidate():
|
||||
client.AddRemoteCandidate(c)
|
||||
case c := <-server.LocalSessionDescription():
|
||||
client.SetRemoteSessionDescription(c)
|
||||
case <-server.Closed():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case c := <-client.LocalCandidate():
|
||||
server.AddRemoteCandidate(c)
|
||||
case c := <-client.LocalSessionDescription():
|
||||
server.SetRemoteSessionDescription(c)
|
||||
case <-client.Closed():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
|
@ -1,59 +0,0 @@
|
|||
package peer
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type peerAddr struct{}
|
||||
|
||||
// Statically checks if we properly implement net.Addr.
|
||||
var _ net.Addr = &peerAddr{}
|
||||
|
||||
func (*peerAddr) Network() string {
|
||||
return "peer"
|
||||
}
|
||||
|
||||
func (*peerAddr) String() string {
|
||||
return "peer/unknown-addr"
|
||||
}
|
||||
|
||||
type fakeNetConn struct {
|
||||
c *Channel
|
||||
addr *peerAddr
|
||||
}
|
||||
|
||||
// Statically checks if we properly implement net.Conn.
|
||||
var _ net.Conn = &fakeNetConn{}
|
||||
|
||||
func (c *fakeNetConn) Read(b []byte) (n int, err error) {
|
||||
return c.c.Read(b)
|
||||
}
|
||||
|
||||
func (c *fakeNetConn) Write(b []byte) (n int, err error) {
|
||||
return c.c.Write(b)
|
||||
}
|
||||
|
||||
func (c *fakeNetConn) Close() error {
|
||||
return c.c.Close()
|
||||
}
|
||||
|
||||
func (c *fakeNetConn) LocalAddr() net.Addr {
|
||||
return c.addr
|
||||
}
|
||||
|
||||
func (c *fakeNetConn) RemoteAddr() net.Addr {
|
||||
return c.addr
|
||||
}
|
||||
|
||||
func (*fakeNetConn) SetDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*fakeNetConn) SetReadDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*fakeNetConn) SetWriteDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
|
@ -1,87 +0,0 @@
|
|||
package peerbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"reflect"
|
||||
|
||||
"github.com/pion/webrtc/v3"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
)
|
||||
|
||||
// Dial consumes the PeerBroker gRPC connection negotiation stream to produce a WebRTC peered connection.
|
||||
func Dial(stream proto.DRPCPeerBroker_NegotiateConnectionClient, iceServers []webrtc.ICEServer, opts *peer.ConnOptions) (*peer.Conn, error) {
|
||||
peerConn, err := peer.Client(iceServers, opts)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create peer connection: %w", err)
|
||||
}
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
// Exchanging messages from the peer connection to negotiate a connection.
|
||||
for {
|
||||
select {
|
||||
case <-peerConn.Closed():
|
||||
return
|
||||
case sessionDescription := <-peerConn.LocalSessionDescription():
|
||||
err = stream.Send(&proto.Exchange{
|
||||
Message: &proto.Exchange_Sdp{
|
||||
Sdp: &proto.WebRTCSessionDescription{
|
||||
SdpType: int32(sessionDescription.Type),
|
||||
Sdp: sessionDescription.SDP,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
_ = peerConn.CloseWithError(xerrors.Errorf("send local session description: %w", err))
|
||||
return
|
||||
}
|
||||
case iceCandidate := <-peerConn.LocalCandidate():
|
||||
err = stream.Send(&proto.Exchange{
|
||||
Message: &proto.Exchange_IceCandidate{
|
||||
IceCandidate: iceCandidate.Candidate,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
_ = peerConn.CloseWithError(xerrors.Errorf("send local candidate: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
// Exchanging messages from the server to negotiate a connection.
|
||||
for {
|
||||
serverToClientMessage, err := stream.Recv()
|
||||
if err != nil {
|
||||
// p2p connections should never die if this stream does due
|
||||
// to proper closure or context cancellation!
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
_ = peerConn.CloseWithError(xerrors.Errorf("recv: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case serverToClientMessage.GetSdp() != nil:
|
||||
peerConn.SetRemoteSessionDescription(webrtc.SessionDescription{
|
||||
Type: webrtc.SDPType(serverToClientMessage.GetSdp().SdpType),
|
||||
SDP: serverToClientMessage.GetSdp().Sdp,
|
||||
})
|
||||
case serverToClientMessage.GetIceCandidate() != "":
|
||||
peerConn.AddRemoteCandidate(webrtc.ICECandidateInit{
|
||||
Candidate: serverToClientMessage.GetIceCandidate(),
|
||||
})
|
||||
default:
|
||||
_ = peerConn.CloseWithError(xerrors.Errorf("unhandled message: %s", reflect.TypeOf(serverToClientMessage).String()))
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return peerConn, nil
|
||||
}
|
|
@ -1,67 +0,0 @@
|
|||
package peerbroker_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/peerbroker"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestDial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Connect", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
client, server := provisionersdk.TransportPipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
settingEngine := webrtc.SettingEngine{}
|
||||
listener, err := peerbroker.Listen(server, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) {
|
||||
return []webrtc.ICEServer{{
|
||||
URLs: []string{"stun:stun.l.google.com:19302"},
|
||||
}}, &peer.ConnOptions{
|
||||
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
|
||||
SettingEngine: settingEngine,
|
||||
}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
|
||||
stream, err := api.NegotiateConnection(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{
|
||||
URLs: []string{"stun:stun.l.google.com:19302"},
|
||||
}}, &peer.ConnOptions{
|
||||
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
||||
SettingEngine: settingEngine,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer clientConn.Close()
|
||||
|
||||
serverConn, err := listener.Accept()
|
||||
require.NoError(t, err)
|
||||
defer serverConn.Close()
|
||||
_, err = serverConn.Ping()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = clientConn.Ping()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
|
@ -1,188 +0,0 @@
|
|||
package peerbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/webrtc/v3"
|
||||
"golang.org/x/xerrors"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
)
|
||||
|
||||
// ConnSettingsFunc returns initialization options for a connection
|
||||
type ConnSettingsFunc func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error)
|
||||
|
||||
// Listen consumes the transport as the server-side of the PeerBroker dRPC service.
|
||||
// The Accept function must be serviced, or new connections will hang.
|
||||
func Listen(connListener net.Listener, connSettingsFunc ConnSettingsFunc) (*Listener, error) {
|
||||
if connSettingsFunc == nil {
|
||||
connSettingsFunc = func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) {
|
||||
return []webrtc.ICEServer{}, nil, nil
|
||||
}
|
||||
}
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
listener := &Listener{
|
||||
connectionChannel: make(chan *peer.Conn),
|
||||
connectionListener: connListener,
|
||||
|
||||
closeFunc: cancelFunc,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
|
||||
mux := drpcmux.New()
|
||||
err := proto.DRPCRegisterPeerBroker(mux, &peerBrokerService{
|
||||
connSettingsFunc: connSettingsFunc,
|
||||
|
||||
listener: listener,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("register peer broker: %w", err)
|
||||
}
|
||||
srv := drpcserver.New(mux)
|
||||
go func() {
|
||||
err := srv.Serve(ctx, connListener)
|
||||
_ = listener.closeWithError(err)
|
||||
}()
|
||||
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
connectionChannel chan *peer.Conn
|
||||
connectionListener net.Listener
|
||||
|
||||
closeFunc context.CancelFunc
|
||||
closed chan struct{}
|
||||
closeMutex sync.Mutex
|
||||
closeError error
|
||||
}
|
||||
|
||||
// Accept blocks until a connection arrives or the listener is closed.
|
||||
func (l *Listener) Accept() (*peer.Conn, error) {
|
||||
select {
|
||||
case <-l.closed:
|
||||
return nil, l.closeError
|
||||
case conn := <-l.connectionChannel:
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Close ends the listener. This will block all new WebRTC connections
|
||||
// from establishing, but will not close active connections.
|
||||
func (l *Listener) Close() error {
|
||||
return l.closeWithError(io.EOF)
|
||||
}
|
||||
|
||||
func (l *Listener) closeWithError(err error) error {
|
||||
l.closeMutex.Lock()
|
||||
defer l.closeMutex.Unlock()
|
||||
|
||||
if l.isClosed() {
|
||||
return l.closeError
|
||||
}
|
||||
|
||||
_ = l.connectionListener.Close()
|
||||
l.closeError = err
|
||||
l.closeFunc()
|
||||
close(l.closed)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) isClosed() bool {
|
||||
select {
|
||||
case <-l.closed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Implements the PeerBroker service protobuf definition.
|
||||
type peerBrokerService struct {
|
||||
listener *Listener
|
||||
|
||||
connSettingsFunc ConnSettingsFunc
|
||||
}
|
||||
|
||||
// NegotiateConnection negotiates a WebRTC connection.
|
||||
func (b *peerBrokerService) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error {
|
||||
iceServers, connOptions, err := b.connSettingsFunc(stream.Context())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get connection settings: %w", err)
|
||||
}
|
||||
peerConn, err := peer.Server(iceServers, connOptions)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create peer connection: %w", err)
|
||||
}
|
||||
select {
|
||||
case <-b.listener.closed:
|
||||
return peerConn.CloseWithError(b.listener.closeError)
|
||||
case b.listener.connectionChannel <- peerConn:
|
||||
}
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
for {
|
||||
select {
|
||||
case <-peerConn.Closed():
|
||||
return
|
||||
case sessionDescription := <-peerConn.LocalSessionDescription():
|
||||
err = stream.Send(&proto.Exchange{
|
||||
Message: &proto.Exchange_Sdp{
|
||||
Sdp: &proto.WebRTCSessionDescription{
|
||||
SdpType: int32(sessionDescription.Type),
|
||||
Sdp: sessionDescription.SDP,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
_ = peerConn.CloseWithError(xerrors.Errorf("send local session description: %w", err))
|
||||
return
|
||||
}
|
||||
case iceCandidate := <-peerConn.LocalCandidate():
|
||||
err = stream.Send(&proto.Exchange{
|
||||
Message: &proto.Exchange_IceCandidate{
|
||||
IceCandidate: iceCandidate.Candidate,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
_ = peerConn.CloseWithError(xerrors.Errorf("send local candidate: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
for {
|
||||
clientToServerMessage, err := stream.Recv()
|
||||
if err != nil {
|
||||
// p2p connections should never die if this stream does due
|
||||
// to proper closure or context cancellation!
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
return peerConn.CloseWithError(xerrors.Errorf("recv: %w", err))
|
||||
}
|
||||
|
||||
switch {
|
||||
case clientToServerMessage.GetSdp() != nil:
|
||||
peerConn.SetRemoteSessionDescription(webrtc.SessionDescription{
|
||||
Type: webrtc.SDPType(clientToServerMessage.GetSdp().SdpType),
|
||||
SDP: clientToServerMessage.GetSdp().Sdp,
|
||||
})
|
||||
case clientToServerMessage.GetIceCandidate() != "":
|
||||
peerConn.AddRemoteCandidate(webrtc.ICECandidateInit{
|
||||
Candidate: clientToServerMessage.GetIceCandidate(),
|
||||
})
|
||||
default:
|
||||
return peerConn.CloseWithError(xerrors.Errorf("unhandled message: %s", reflect.TypeOf(clientToServerMessage).String()))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,52 +0,0 @@
|
|||
package peerbroker_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/peerbroker"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
)
|
||||
|
||||
func TestListen(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Ensures connections blocked on Accept() are
|
||||
// closed if the listener is.
|
||||
t.Run("NoAcceptClosed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
client, server := provisionersdk.TransportPipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
listener, err := peerbroker.Listen(server, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
|
||||
stream, err := api.NegotiateConnection(ctx)
|
||||
require.NoError(t, err)
|
||||
clientConn, err := peerbroker.Dial(stream, nil, nil)
|
||||
require.NoError(t, err)
|
||||
defer clientConn.Close()
|
||||
|
||||
_ = listener.Close()
|
||||
})
|
||||
|
||||
// Ensures Accept() properly exits when Close() is called.
|
||||
t.Run("AcceptClosed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, server := provisionersdk.TransportPipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
listener, err := peerbroker.Listen(server, nil)
|
||||
require.NoError(t, err)
|
||||
go listener.Close()
|
||||
_, err = listener.Accept()
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
})
|
||||
}
|
|
@ -1,269 +0,0 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.21.5
|
||||
// source: peerbroker/proto/peerbroker.proto
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type WebRTCSessionDescription struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
SdpType int32 `protobuf:"varint,1,opt,name=sdp_type,json=sdpType,proto3" json:"sdp_type,omitempty"`
|
||||
Sdp string `protobuf:"bytes,2,opt,name=sdp,proto3" json:"sdp,omitempty"`
|
||||
}
|
||||
|
||||
func (x *WebRTCSessionDescription) Reset() {
|
||||
*x = WebRTCSessionDescription{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *WebRTCSessionDescription) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*WebRTCSessionDescription) ProtoMessage() {}
|
||||
|
||||
func (x *WebRTCSessionDescription) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use WebRTCSessionDescription.ProtoReflect.Descriptor instead.
|
||||
func (*WebRTCSessionDescription) Descriptor() ([]byte, []int) {
|
||||
return file_peerbroker_proto_peerbroker_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *WebRTCSessionDescription) GetSdpType() int32 {
|
||||
if x != nil {
|
||||
return x.SdpType
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *WebRTCSessionDescription) GetSdp() string {
|
||||
if x != nil {
|
||||
return x.Sdp
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type Exchange struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
// Types that are assignable to Message:
|
||||
//
|
||||
// *Exchange_Sdp
|
||||
// *Exchange_IceCandidate
|
||||
Message isExchange_Message `protobuf_oneof:"message"`
|
||||
}
|
||||
|
||||
func (x *Exchange) Reset() {
|
||||
*x = Exchange{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *Exchange) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Exchange) ProtoMessage() {}
|
||||
|
||||
func (x *Exchange) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Exchange.ProtoReflect.Descriptor instead.
|
||||
func (*Exchange) Descriptor() ([]byte, []int) {
|
||||
return file_peerbroker_proto_peerbroker_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (m *Exchange) GetMessage() isExchange_Message {
|
||||
if m != nil {
|
||||
return m.Message
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Exchange) GetSdp() *WebRTCSessionDescription {
|
||||
if x, ok := x.GetMessage().(*Exchange_Sdp); ok {
|
||||
return x.Sdp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Exchange) GetIceCandidate() string {
|
||||
if x, ok := x.GetMessage().(*Exchange_IceCandidate); ok {
|
||||
return x.IceCandidate
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type isExchange_Message interface {
|
||||
isExchange_Message()
|
||||
}
|
||||
|
||||
type Exchange_Sdp struct {
|
||||
Sdp *WebRTCSessionDescription `protobuf:"bytes,1,opt,name=sdp,proto3,oneof"`
|
||||
}
|
||||
|
||||
type Exchange_IceCandidate struct {
|
||||
IceCandidate string `protobuf:"bytes,2,opt,name=ice_candidate,json=iceCandidate,proto3,oneof"`
|
||||
}
|
||||
|
||||
func (*Exchange_Sdp) isExchange_Message() {}
|
||||
|
||||
func (*Exchange_IceCandidate) isExchange_Message() {}
|
||||
|
||||
var File_peerbroker_proto_peerbroker_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_peerbroker_proto_peerbroker_proto_rawDesc = []byte{
|
||||
0x0a, 0x21, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2f, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x2f, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2e, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x12, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x22,
|
||||
0x47, 0x0a, 0x18, 0x57, 0x65, 0x62, 0x52, 0x54, 0x43, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e,
|
||||
0x44, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19, 0x0a, 0x08, 0x73,
|
||||
0x64, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x73,
|
||||
0x64, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x64, 0x70, 0x18, 0x02, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x03, 0x73, 0x64, 0x70, 0x22, 0x76, 0x0a, 0x08, 0x45, 0x78, 0x63, 0x68,
|
||||
0x61, 0x6e, 0x67, 0x65, 0x12, 0x38, 0x0a, 0x03, 0x73, 0x64, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28,
|
||||
0x0b, 0x32, 0x24, 0x2e, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2e, 0x57,
|
||||
0x65, 0x62, 0x52, 0x54, 0x43, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x44, 0x65, 0x73, 0x63,
|
||||
0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x00, 0x52, 0x03, 0x73, 0x64, 0x70, 0x12, 0x25,
|
||||
0x0a, 0x0d, 0x69, 0x63, 0x65, 0x5f, 0x63, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x18,
|
||||
0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, 0x69, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64,
|
||||
0x69, 0x64, 0x61, 0x74, 0x65, 0x42, 0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
|
||||
0x32, 0x53, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x42, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x12, 0x45,
|
||||
0x0a, 0x13, 0x4e, 0x65, 0x67, 0x6f, 0x74, 0x69, 0x61, 0x74, 0x65, 0x43, 0x6f, 0x6e, 0x6e, 0x65,
|
||||
0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x14, 0x2e, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b,
|
||||
0x65, 0x72, 0x2e, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x14, 0x2e, 0x70, 0x65,
|
||||
0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2e, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67,
|
||||
0x65, 0x28, 0x01, 0x30, 0x01, 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e,
|
||||
0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f,
|
||||
0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_peerbroker_proto_peerbroker_proto_rawDescOnce sync.Once
|
||||
file_peerbroker_proto_peerbroker_proto_rawDescData = file_peerbroker_proto_peerbroker_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_peerbroker_proto_peerbroker_proto_rawDescGZIP() []byte {
|
||||
file_peerbroker_proto_peerbroker_proto_rawDescOnce.Do(func() {
|
||||
file_peerbroker_proto_peerbroker_proto_rawDescData = protoimpl.X.CompressGZIP(file_peerbroker_proto_peerbroker_proto_rawDescData)
|
||||
})
|
||||
return file_peerbroker_proto_peerbroker_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_peerbroker_proto_peerbroker_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_peerbroker_proto_peerbroker_proto_goTypes = []interface{}{
|
||||
(*WebRTCSessionDescription)(nil), // 0: peerbroker.WebRTCSessionDescription
|
||||
(*Exchange)(nil), // 1: peerbroker.Exchange
|
||||
}
|
||||
var file_peerbroker_proto_peerbroker_proto_depIdxs = []int32{
|
||||
0, // 0: peerbroker.Exchange.sdp:type_name -> peerbroker.WebRTCSessionDescription
|
||||
1, // 1: peerbroker.PeerBroker.NegotiateConnection:input_type -> peerbroker.Exchange
|
||||
1, // 2: peerbroker.PeerBroker.NegotiateConnection:output_type -> peerbroker.Exchange
|
||||
2, // [2:3] is the sub-list for method output_type
|
||||
1, // [1:2] is the sub-list for method input_type
|
||||
1, // [1:1] is the sub-list for extension type_name
|
||||
1, // [1:1] is the sub-list for extension extendee
|
||||
0, // [0:1] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_peerbroker_proto_peerbroker_proto_init() }
|
||||
func file_peerbroker_proto_peerbroker_proto_init() {
|
||||
if File_peerbroker_proto_peerbroker_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_peerbroker_proto_peerbroker_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*WebRTCSessionDescription); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_peerbroker_proto_peerbroker_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*Exchange); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
file_peerbroker_proto_peerbroker_proto_msgTypes[1].OneofWrappers = []interface{}{
|
||||
(*Exchange_Sdp)(nil),
|
||||
(*Exchange_IceCandidate)(nil),
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_peerbroker_proto_peerbroker_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 2,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_peerbroker_proto_peerbroker_proto_goTypes,
|
||||
DependencyIndexes: file_peerbroker_proto_peerbroker_proto_depIdxs,
|
||||
MessageInfos: file_peerbroker_proto_peerbroker_proto_msgTypes,
|
||||
}.Build()
|
||||
File_peerbroker_proto_peerbroker_proto = out.File
|
||||
file_peerbroker_proto_peerbroker_proto_rawDesc = nil
|
||||
file_peerbroker_proto_peerbroker_proto_goTypes = nil
|
||||
file_peerbroker_proto_peerbroker_proto_depIdxs = nil
|
||||
}
|
|
@ -1,28 +0,0 @@
|
|||
|
||||
syntax = "proto3";
|
||||
option go_package = "github.com/coder/coder/peerbroker/proto";
|
||||
|
||||
package peerbroker;
|
||||
|
||||
message WebRTCSessionDescription {
|
||||
int32 sdp_type = 1;
|
||||
string sdp = 2;
|
||||
}
|
||||
|
||||
message Exchange {
|
||||
oneof message {
|
||||
WebRTCSessionDescription sdp = 1;
|
||||
string ice_candidate = 2;
|
||||
}
|
||||
}
|
||||
|
||||
// PeerBroker mediates WebRTC connection signaling.
|
||||
service PeerBroker {
|
||||
// NegotiateConnection establishes a bidirectional stream to negotiate a new WebRTC connection.
|
||||
// 1. Client sends WebRTCSessionDescription to the server.
|
||||
// 2. Server sends WebRTCSessionDescription to the client, exchanging encryption keys.
|
||||
// 3. Client<->Server exchange ICE Candidates to establish a peered connection.
|
||||
//
|
||||
// See: https://davekilian.com/webrtc-the-hard-way.html
|
||||
rpc NegotiateConnection(stream Exchange) returns (stream Exchange);
|
||||
}
|
|
@ -1,146 +0,0 @@
|
|||
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
|
||||
// protoc-gen-go-drpc version: v0.0.26
|
||||
// source: peerbroker/proto/peerbroker.proto
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
context "context"
|
||||
errors "errors"
|
||||
protojson "google.golang.org/protobuf/encoding/protojson"
|
||||
proto "google.golang.org/protobuf/proto"
|
||||
drpc "storj.io/drpc"
|
||||
drpcerr "storj.io/drpc/drpcerr"
|
||||
)
|
||||
|
||||
type drpcEncoding_File_peerbroker_proto_peerbroker_proto struct{}
|
||||
|
||||
func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) Marshal(msg drpc.Message) ([]byte, error) {
|
||||
return proto.Marshal(msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) MarshalAppend(buf []byte, msg drpc.Message) ([]byte, error) {
|
||||
return proto.MarshalOptions{}.MarshalAppend(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) Unmarshal(buf []byte, msg drpc.Message) error {
|
||||
return proto.Unmarshal(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) JSONMarshal(msg drpc.Message) ([]byte, error) {
|
||||
return protojson.Marshal(msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) JSONUnmarshal(buf []byte, msg drpc.Message) error {
|
||||
return protojson.Unmarshal(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
type DRPCPeerBrokerClient interface {
|
||||
DRPCConn() drpc.Conn
|
||||
|
||||
NegotiateConnection(ctx context.Context) (DRPCPeerBroker_NegotiateConnectionClient, error)
|
||||
}
|
||||
|
||||
type drpcPeerBrokerClient struct {
|
||||
cc drpc.Conn
|
||||
}
|
||||
|
||||
func NewDRPCPeerBrokerClient(cc drpc.Conn) DRPCPeerBrokerClient {
|
||||
return &drpcPeerBrokerClient{cc}
|
||||
}
|
||||
|
||||
func (c *drpcPeerBrokerClient) DRPCConn() drpc.Conn { return c.cc }
|
||||
|
||||
func (c *drpcPeerBrokerClient) NegotiateConnection(ctx context.Context) (DRPCPeerBroker_NegotiateConnectionClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, "/peerbroker.PeerBroker/NegotiateConnection", drpcEncoding_File_peerbroker_proto_peerbroker_proto{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &drpcPeerBroker_NegotiateConnectionClient{stream}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type DRPCPeerBroker_NegotiateConnectionClient interface {
|
||||
drpc.Stream
|
||||
Send(*Exchange) error
|
||||
Recv() (*Exchange, error)
|
||||
}
|
||||
|
||||
type drpcPeerBroker_NegotiateConnectionClient struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcPeerBroker_NegotiateConnectionClient) Send(m *Exchange) error {
|
||||
return x.MsgSend(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{})
|
||||
}
|
||||
|
||||
func (x *drpcPeerBroker_NegotiateConnectionClient) Recv() (*Exchange, error) {
|
||||
m := new(Exchange)
|
||||
if err := x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (x *drpcPeerBroker_NegotiateConnectionClient) RecvMsg(m *Exchange) error {
|
||||
return x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{})
|
||||
}
|
||||
|
||||
type DRPCPeerBrokerServer interface {
|
||||
NegotiateConnection(DRPCPeerBroker_NegotiateConnectionStream) error
|
||||
}
|
||||
|
||||
type DRPCPeerBrokerUnimplementedServer struct{}
|
||||
|
||||
func (s *DRPCPeerBrokerUnimplementedServer) NegotiateConnection(DRPCPeerBroker_NegotiateConnectionStream) error {
|
||||
return drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
type DRPCPeerBrokerDescription struct{}
|
||||
|
||||
func (DRPCPeerBrokerDescription) NumMethods() int { return 1 }
|
||||
|
||||
func (DRPCPeerBrokerDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
|
||||
switch n {
|
||||
case 0:
|
||||
return "/peerbroker.PeerBroker/NegotiateConnection", drpcEncoding_File_peerbroker_proto_peerbroker_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return nil, srv.(DRPCPeerBrokerServer).
|
||||
NegotiateConnection(
|
||||
&drpcPeerBroker_NegotiateConnectionStream{in1.(drpc.Stream)},
|
||||
)
|
||||
}, DRPCPeerBrokerServer.NegotiateConnection, true
|
||||
default:
|
||||
return "", nil, nil, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func DRPCRegisterPeerBroker(mux drpc.Mux, impl DRPCPeerBrokerServer) error {
|
||||
return mux.Register(impl, DRPCPeerBrokerDescription{})
|
||||
}
|
||||
|
||||
type DRPCPeerBroker_NegotiateConnectionStream interface {
|
||||
drpc.Stream
|
||||
Send(*Exchange) error
|
||||
Recv() (*Exchange, error)
|
||||
}
|
||||
|
||||
type drpcPeerBroker_NegotiateConnectionStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcPeerBroker_NegotiateConnectionStream) Send(m *Exchange) error {
|
||||
return x.MsgSend(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{})
|
||||
}
|
||||
|
||||
func (x *drpcPeerBroker_NegotiateConnectionStream) Recv() (*Exchange, error) {
|
||||
m := new(Exchange)
|
||||
if err := x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (x *drpcPeerBroker_NegotiateConnectionStream) RecvMsg(m *Exchange) error {
|
||||
return x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{})
|
||||
}
|
|
@ -1,283 +0,0 @@
|
|||
package peerbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/yamux"
|
||||
"golang.org/x/xerrors"
|
||||
protobuf "google.golang.org/protobuf/proto"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
// Each NegotiateConnection() function call spawns a new stream.
|
||||
streamIDLength = len(uuid.NewString())
|
||||
// We shouldn't PubSub anything larger than this!
|
||||
maxPayloadSizeBytes = 8192
|
||||
)
|
||||
|
||||
// ProxyOptions provides values to configure a proxy.
|
||||
type ProxyOptions struct {
|
||||
ChannelID string
|
||||
Logger slog.Logger
|
||||
Pubsub database.Pubsub
|
||||
}
|
||||
|
||||
// ProxyDial writes client negotiation streams over PubSub.
|
||||
//
|
||||
// PubSub is used to geodistribute WebRTC handshakes. All negotiation
|
||||
// messages are small in size (<=8KB), and we don't require delivery
|
||||
// guarantees because connections can always be renegotiated.
|
||||
//
|
||||
// ┌────────────────────┐ ┌─────────────────────────────┐
|
||||
// │ coderd │ │ coderd │
|
||||
//
|
||||
// ┌─────────────────────┐ │/<agent-id>/connect │ │ /<agent-id>/listen │
|
||||
// │ client │ │ │ │ │ ┌─────┐
|
||||
// │ ├──►│Creates a stream ID │◄─►│Subscribe() to the <agent-id>│◄──┤agent│
|
||||
// │NegotiateConnection()│ │and Publish() to the│ │channel. Parse the stream ID │ └─────┘
|
||||
// └─────────────────────┘ │<agent-id> channel: │ │from payloads to create new │
|
||||
//
|
||||
// │ │ │NegotiateConnection() streams│
|
||||
// │<stream-id><payload>│ │or write to existing ones. │
|
||||
// └────────────────────┘ └─────────────────────────────┘
|
||||
func ProxyDial(client proto.DRPCPeerBrokerClient, options ProxyOptions) (io.Closer, error) {
|
||||
proxyDial := &proxyDial{
|
||||
channelID: options.ChannelID,
|
||||
logger: options.Logger,
|
||||
pubsub: options.Pubsub,
|
||||
connection: client,
|
||||
streams: make(map[string]proto.DRPCPeerBroker_NegotiateConnectionClient),
|
||||
}
|
||||
return proxyDial, proxyDial.listen()
|
||||
}
|
||||
|
||||
// ProxyListen accepts client negotiation streams over PubSub and writes them to the listener
|
||||
// as new NegotiateConnection() streams.
|
||||
func ProxyListen(ctx context.Context, connListener net.Listener, options ProxyOptions) error {
|
||||
mux := drpcmux.New()
|
||||
err := proto.DRPCRegisterPeerBroker(mux, &proxyListen{
|
||||
channelID: options.ChannelID,
|
||||
pubsub: options.Pubsub,
|
||||
logger: options.Logger,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("register peer broker: %w", err)
|
||||
}
|
||||
server := drpcserver.New(mux)
|
||||
err = server.Serve(ctx, connListener)
|
||||
if err != nil {
|
||||
if errors.Is(err, yamux.ErrSessionShutdown) {
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("serve: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type proxyListen struct {
|
||||
channelID string
|
||||
pubsub database.Pubsub
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
func (p *proxyListen) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error {
|
||||
streamID := uuid.NewString()
|
||||
var err error
|
||||
closeSubscribe, err := p.pubsub.Subscribe(proxyInID(p.channelID), func(ctx context.Context, message []byte) {
|
||||
err := p.onServerToClientMessage(streamID, stream, message)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("subscribe: %w", err)
|
||||
}
|
||||
defer closeSubscribe()
|
||||
for {
|
||||
clientToServerMessage, err := stream.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
return xerrors.Errorf("recv: %w", err)
|
||||
}
|
||||
data, err := protobuf.Marshal(clientToServerMessage)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal: %w", err)
|
||||
}
|
||||
if len(data) > maxPayloadSizeBytes {
|
||||
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
|
||||
}
|
||||
data = append([]byte(streamID), data...)
|
||||
err = p.pubsub.Publish(proxyOutID(p.channelID), marshal(data))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("publish: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*proxyListen) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionStream, message []byte) error {
|
||||
var err error
|
||||
message, err = unmarshal(message)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("decode: %w", err)
|
||||
}
|
||||
if len(message) < streamIDLength {
|
||||
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
|
||||
}
|
||||
serverStreamID := string(message[0:streamIDLength])
|
||||
if serverStreamID != streamID {
|
||||
// It's not trying to communicate with this stream!
|
||||
return nil
|
||||
}
|
||||
var msg proto.Exchange
|
||||
err = protobuf.Unmarshal(message[streamIDLength:], &msg)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("unmarshal message: %w", err)
|
||||
}
|
||||
err = stream.Send(&msg)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("send message: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type proxyDial struct {
|
||||
channelID string
|
||||
pubsub database.Pubsub
|
||||
logger slog.Logger
|
||||
|
||||
connection proto.DRPCPeerBrokerClient
|
||||
closeSubscribe func()
|
||||
streamMutex sync.Mutex
|
||||
streams map[string]proto.DRPCPeerBroker_NegotiateConnectionClient
|
||||
}
|
||||
|
||||
func (p *proxyDial) listen() error {
|
||||
var err error
|
||||
p.closeSubscribe, err = p.pubsub.Subscribe(proxyOutID(p.channelID), func(ctx context.Context, message []byte) {
|
||||
err := p.onClientToServerMessage(ctx, message)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "failed to accept client message", slog.Error(err))
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *proxyDial) onClientToServerMessage(ctx context.Context, message []byte) error {
|
||||
var err error
|
||||
message, err = unmarshal(message)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("decode: %w", err)
|
||||
}
|
||||
if len(message) < streamIDLength {
|
||||
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
|
||||
}
|
||||
streamID := string(message[0:streamIDLength])
|
||||
p.streamMutex.Lock()
|
||||
stream, ok := p.streams[streamID]
|
||||
if !ok {
|
||||
stream, err = p.connection.NegotiateConnection(ctx)
|
||||
if err != nil {
|
||||
p.streamMutex.Unlock()
|
||||
return xerrors.Errorf("negotiate connection: %w", err)
|
||||
}
|
||||
p.streams[streamID] = stream
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
|
||||
err := p.onServerToClientMessage(streamID, stream)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
<-stream.Context().Done()
|
||||
p.streamMutex.Lock()
|
||||
delete(p.streams, streamID)
|
||||
p.streamMutex.Unlock()
|
||||
}()
|
||||
}
|
||||
p.streamMutex.Unlock()
|
||||
|
||||
var msg proto.Exchange
|
||||
err = protobuf.Unmarshal(message[streamIDLength:], &msg)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("unmarshal message: %w", err)
|
||||
}
|
||||
err = stream.Send(&msg)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write message: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *proxyDial) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionClient) error {
|
||||
for {
|
||||
serverToClientMessage, err := stream.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
break
|
||||
}
|
||||
return xerrors.Errorf("recv: %w", err)
|
||||
}
|
||||
data, err := protobuf.Marshal(serverToClientMessage)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal: %w", err)
|
||||
}
|
||||
if len(data) > maxPayloadSizeBytes {
|
||||
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
|
||||
}
|
||||
data = append([]byte(streamID), data...)
|
||||
err = p.pubsub.Publish(proxyInID(p.channelID), marshal(data))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("publish: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *proxyDial) Close() error {
|
||||
p.streamMutex.Lock()
|
||||
defer p.streamMutex.Unlock()
|
||||
p.closeSubscribe()
|
||||
return nil
|
||||
}
|
||||
|
||||
// base64 needs to be used here to keep the pubsub messages in UTF-8 range.
|
||||
// PostgreSQL cannot handle non UTF-8 messages over pubsub.
|
||||
func marshal(data []byte) []byte {
|
||||
return []byte(base64.StdEncoding.EncodeToString(data))
|
||||
}
|
||||
|
||||
func unmarshal(data []byte) ([]byte, error) {
|
||||
return base64.StdEncoding.DecodeString(string(data))
|
||||
}
|
||||
|
||||
func proxyOutID(channelID string) string {
|
||||
return fmt.Sprintf("%s-out", channelID)
|
||||
}
|
||||
|
||||
func proxyInID(channelID string) string {
|
||||
return fmt.Sprintf("%s-in", channelID)
|
||||
}
|
|
@ -1,84 +0,0 @@
|
|||
package peerbroker_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/peerbroker"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
)
|
||||
|
||||
func TestProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
channelID := "hello"
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
dialerClient, dialerServer := provisionersdk.TransportPipe()
|
||||
defer dialerClient.Close()
|
||||
defer dialerServer.Close()
|
||||
listenerClient, listenerServer := provisionersdk.TransportPipe()
|
||||
defer listenerClient.Close()
|
||||
defer listenerServer.Close()
|
||||
|
||||
listener, err := peerbroker.Listen(listenerServer, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) {
|
||||
return nil, &peer.ConnOptions{
|
||||
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
|
||||
}, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyCloser, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(listenerClient)), peerbroker.ProxyOptions{
|
||||
ChannelID: channelID,
|
||||
Logger: slogtest.Make(t, nil).Named("proxy-listen").Leveled(slog.LevelDebug),
|
||||
Pubsub: pubsub,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = proxyCloser.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err = peerbroker.ProxyListen(ctx, dialerServer, peerbroker.ProxyOptions{
|
||||
ChannelID: channelID,
|
||||
Logger: slogtest.Make(t, nil).Named("proxy-dial").Leveled(slog.LevelDebug),
|
||||
Pubsub: pubsub,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(dialerClient))
|
||||
stream, err := api.NegotiateConnection(ctx)
|
||||
require.NoError(t, err)
|
||||
clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{
|
||||
URLs: []string{"stun:stun.l.google.com:19302"},
|
||||
}}, &peer.ConnOptions{
|
||||
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer clientConn.Close()
|
||||
|
||||
serverConn, err := listener.Accept()
|
||||
require.NoError(t, err)
|
||||
defer serverConn.Close()
|
||||
_, err = serverConn.Ping()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = clientConn.Ping()
|
||||
require.NoError(t, err)
|
||||
|
||||
_ = dialerServer.Close()
|
||||
wg.Wait()
|
||||
}
|
Loading…
Reference in New Issue