chore: Remove WebRTC networking (#3881)

* chore: Remove WebRTC networking

* Fix race condition

* Fix WebSocket not closing
Kyle Carberry, 2022-09-19 19:46:29 -05:00, committed by GitHub
parent 1186e643ec
commit 714c366d16
GPG Key ID: 4AEE18F83AFDEB23
44 changed files with 310 additions and 4225 deletions

View File

@ -386,7 +386,6 @@ lint/shellcheck: $(shell shfmt -f .)
gen: \
coderd/database/dump.sql \
coderd/database/querier.go \
peerbroker/proto/peerbroker.pb.go \
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
site/src/api/typesGenerated.ts
@ -395,7 +394,7 @@ gen: \
# Mark all generated files as fresh so make thinks they're up-to-date. This is
# used during releases so we don't run generation scripts.
gen/mark-fresh:
files="coderd/database/dump.sql coderd/database/querier.go peerbroker/proto/peerbroker.pb.go provisionersdk/proto/provisioner.pb.go provisionerd/proto/provisionerd.pb.go site/src/api/typesGenerated.ts"
files="coderd/database/dump.sql coderd/database/querier.go provisionersdk/proto/provisioner.pb.go provisionerd/proto/provisionerd.pb.go site/src/api/typesGenerated.ts"
for file in $$files; do
echo "$$file"
if [ ! -f "$$file" ]; then
@ -417,14 +416,6 @@ coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/dat
coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $(wildcard coderd/database/queries/*.sql) coderd/database/gen/enum/main.go
./coderd/database/generate.sh
peerbroker/proto/peerbroker.pb.go: peerbroker/proto/peerbroker.proto
protoc \
--go_out=. \
--go_opt=paths=source_relative \
--go-drpc_out=. \
--go-drpc_opt=paths=source_relative \
./peerbroker/proto/peerbroker.proto
provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto
protoc \
--go_out=. \

View File

@ -11,7 +11,6 @@ import (
"io"
"net"
"net/netip"
"net/url"
"os"
"os/exec"
"os/user"
@ -34,8 +33,6 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/agent/usershell"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/pty"
"github.com/coder/coder/tailnet"
"github.com/coder/retry"
@ -64,7 +61,6 @@ var (
type Options struct {
CoordinatorDialer CoordinatorDialer
WebRTCDialer WebRTCDialer
FetchMetadata FetchMetadata
StatsReporter StatsReporter
@ -80,8 +76,6 @@ type Metadata struct {
Directory string `json:"directory"`
}
type WebRTCDialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error)
// CoordinatorDialer is a function that dials the tailnet coordinator.
// A dialer must be passed in to allow for reconnects.
type CoordinatorDialer func(ctx context.Context) (net.Conn, error)
@ -95,7 +89,6 @@ func New(options Options) io.Closer {
}
ctx, cancelFunc := context.WithCancel(context.Background())
server := &agent{
webrtcDialer: options.WebRTCDialer,
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
logger: options.Logger,
closeCancel: cancelFunc,
@ -111,8 +104,7 @@ func New(options Options) io.Closer {
}
type agent struct {
webrtcDialer WebRTCDialer
logger slog.Logger
logger slog.Logger
reconnectingPTYs sync.Map
reconnectingPTYTimeout time.Duration
@ -173,9 +165,6 @@ func (a *agent) run(ctx context.Context) {
}
}()
if a.webrtcDialer != nil {
go a.runWebRTCNetworking(ctx)
}
if metadata.DERPMap != nil {
go a.runTailnet(ctx, metadata.DERPMap)
}
@ -326,49 +315,6 @@ func (a *agent) runCoordinator(ctx context.Context) {
}
}
func (a *agent) runWebRTCNetworking(ctx context.Context) {
var peerListener *peerbroker.Listener
var err error
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
peerListener, err = a.webrtcDialer(ctx, a.logger)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
if a.isClosed() {
return
}
a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
continue
}
a.logger.Info(context.Background(), "connected to webrtc broker")
break
}
select {
case <-ctx.Done():
return
default:
}
for {
conn, err := peerListener.Accept()
if err != nil {
if a.isClosed() {
return
}
a.logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err))
a.runWebRTCNetworking(ctx)
return
}
a.closeMutex.Lock()
a.connCloseWait.Add(1)
a.closeMutex.Unlock()
go a.handlePeerConn(ctx, conn)
}
}
func (a *agent) runStartupScript(ctx context.Context, script string) error {
if script == "" {
return nil
@ -401,74 +347,6 @@ func (a *agent) runStartupScript(ctx context.Context, script string) error {
return nil
}
func (a *agent) handlePeerConn(ctx context.Context, peerConn *peer.Conn) {
go func() {
select {
case <-a.closed:
case <-peerConn.Closed():
}
_ = peerConn.Close()
a.connCloseWait.Done()
}()
for {
channel, err := peerConn.Accept(ctx)
if err != nil {
if errors.Is(err, peer.ErrClosed) || a.isClosed() {
return
}
a.logger.Debug(ctx, "accept channel from peer connection", slog.Error(err))
return
}
conn := channel.NetConn()
switch channel.Protocol() {
case ProtocolSSH:
go a.sshServer.HandleConn(a.stats.wrapConn(conn))
case ProtocolReconnectingPTY:
rawID := channel.Label()
// The ID format is referenced in conn.go.
// <uuid>:<height>:<width>:<command>
idParts := strings.SplitN(rawID, ":", 4)
if len(idParts) != 4 {
a.logger.Warn(ctx, "client sent invalid id format", slog.F("raw-id", rawID))
continue
}
id := idParts[0]
// Enforce a consistent format for IDs.
_, err := uuid.Parse(id)
if err != nil {
a.logger.Warn(ctx, "client sent reconnection token that isn't a uuid", slog.F("id", id), slog.Error(err))
continue
}
// Parse the initial terminal dimensions.
height, err := strconv.Atoi(idParts[1])
if err != nil {
a.logger.Warn(ctx, "client sent invalid height", slog.F("id", id), slog.F("height", idParts[1]))
continue
}
width, err := strconv.Atoi(idParts[2])
if err != nil {
a.logger.Warn(ctx, "client sent invalid width", slog.F("id", id), slog.F("width", idParts[2]))
continue
}
go a.handleReconnectingPTY(ctx, reconnectingPTYInit{
ID: id,
Height: uint16(height),
Width: uint16(width),
Command: idParts[3],
}, a.stats.wrapConn(conn))
case ProtocolDial:
go a.handleDial(ctx, channel.Label(), a.stats.wrapConn(conn))
default:
a.logger.Warn(ctx, "unhandled protocol from channel",
slog.F("protocol", channel.Protocol()),
slog.F("label", channel.Label()),
)
}
}
}
func (a *agent) init(ctx context.Context) {
a.logger.Info(ctx, "generating host key")
// Clients should ignore the host key when connecting.
@ -915,70 +793,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg reconnectingPTYIn
}
}
// dialResponse is written to datachannels with protocol "dial" by the agent as
// the first packet to signify whether the dial succeeded or failed.
type dialResponse struct {
Error string `json:"error,omitempty"`
}
func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) {
defer conn.Close()
writeError := func(responseError error) error {
msg := ""
if responseError != nil {
msg = responseError.Error()
if !xerrors.Is(responseError, io.EOF) {
a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError))
}
}
b, err := json.Marshal(dialResponse{
Error: msg,
})
if err != nil {
a.logger.Warn(ctx, "write dial response", slog.F("label", label), slog.Error(err))
return xerrors.Errorf("marshal agent webrtc dial response: %w", err)
}
_, err = conn.Write(b)
return err
}
u, err := url.Parse(label)
if err != nil {
_ = writeError(xerrors.Errorf("parse URL %q: %w", label, err))
return
}
network := u.Scheme
addr := u.Host + u.Path
if strings.HasPrefix(network, "unix") {
if runtime.GOOS == "windows" {
_ = writeError(xerrors.New("Unix forwarding is not supported from Windows workspaces"))
return
}
addr, err = ExpandRelativeHomePath(addr)
if err != nil {
_ = writeError(xerrors.Errorf("expand path %q: %w", addr, err))
return
}
}
d := net.Dialer{Timeout: 3 * time.Second}
nconn, err := d.DialContext(ctx, network, addr)
if err != nil {
_ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err))
return
}
err = writeError(nil)
if err != nil {
return
}
Bicopy(ctx, conn, nconn)
}
// isClosed returns whether the API is closed or not.
func (a *agent) isClosed() bool {
select {

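For orientation, here is a minimal sketch (not part of this commit; names are taken from other hunks in this diff, and error handling is elided) of how an agent is constructed once the WebRTCDialer option is gone — only the metadata fetcher and the coordinator dialer remain:

package main

import (
	"io"

	"cdr.dev/slog"

	"github.com/coder/coder/agent"
	"github.com/coder/coder/codersdk"
)

// startAgent sketches the post-WebRTC wiring: agent.Options no longer has
// a WebRTCDialer field, so the agent reaches coderd exclusively through
// the tailnet coordinator connection.
func startAgent(client *codersdk.Client, logger slog.Logger) io.Closer {
	return agent.New(agent.Options{
		FetchMetadata:     client.WorkspaceAgentMetadata,
		CoordinatorDialer: client.ListenWorkspaceAgentTailnet,
		Logger:            logger,
	})
}
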
View File

@ -20,12 +20,10 @@ import (
"golang.org/x/xerrors"
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"
scp "github.com/bramvdbogaerde/go-scp"
"github.com/google/uuid"
"github.com/pion/udp"
"github.com/pion/webrtc/v3"
"github.com/pkg/sftp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -37,10 +35,6 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/pty/ptytest"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/tailnet/tailnettest"
@ -54,64 +48,49 @@ func TestMain(m *testing.M) {
func TestAgent(t *testing.T) {
t.Parallel()
t.Run("Stats", func(t *testing.T) {
for _, tailscale := range []bool{true, false} {
t.Run(fmt.Sprintf("tailscale=%v", tailscale), func(t *testing.T) {
t.Parallel()
t.Parallel()
setupAgent := func(t *testing.T) (agent.Conn, <-chan *agent.Stats) {
var derpMap *tailcfg.DERPMap
if tailscale {
derpMap = tailnettest.RunDERPAndSTUN(t)
}
conn, stats := setupAgent(t, agent.Metadata{
DERPMap: derpMap,
}, 0)
assert.Empty(t, <-stats)
return conn, stats
}
t.Run("SSH", func(t *testing.T) {
t.Parallel()
conn, stats := setupAgent(t, agent.Metadata{}, 0)
t.Run("SSH", func(t *testing.T) {
t.Parallel()
conn, stats := setupAgent(t)
sshClient, err := conn.SSHClient()
require.NoError(t, err)
defer sshClient.Close()
session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()
sshClient, err := conn.SSHClient()
require.NoError(t, err)
session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()
assert.EqualValues(t, 1, (<-stats).NumConns)
assert.Greater(t, (<-stats).RxBytes, int64(0))
assert.Greater(t, (<-stats).TxBytes, int64(0))
})
assert.EqualValues(t, 1, (<-stats).NumConns)
assert.Greater(t, (<-stats).RxBytes, int64(0))
assert.Greater(t, (<-stats).TxBytes, int64(0))
})
t.Run("ReconnectingPTY", func(t *testing.T) {
t.Parallel()
t.Run("ReconnectingPTY", func(t *testing.T) {
t.Parallel()
conn, stats := setupAgent(t, agent.Metadata{}, 0)
conn, stats := setupAgent(t)
ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash")
require.NoError(t, err)
defer ptyConn.Close()
ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash")
require.NoError(t, err)
defer ptyConn.Close()
data, err := json.Marshal(agent.ReconnectingPTYRequest{
Data: "echo test\r\n",
})
require.NoError(t, err)
_, err = ptyConn.Write(data)
require.NoError(t, err)
var s *agent.Stats
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = (<-stats)
return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s,
)
})
data, err := json.Marshal(agent.ReconnectingPTYRequest{
Data: "echo test\r\n",
})
}
require.NoError(t, err)
_, err = ptyConn.Write(data)
require.NoError(t, err)
var s *agent.Stats
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = (<-stats)
return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s,
)
})
})
t.Run("SessionExec", func(t *testing.T) {
@ -235,6 +214,7 @@ func TestAgent(t *testing.T) {
conn, _ := setupAgent(t, agent.Metadata{}, 0)
sshClient, err := conn.SSHClient()
require.NoError(t, err)
defer sshClient.Close()
client, err := sftp.NewClient(sshClient)
require.NoError(t, err)
tempFile := filepath.Join(t.TempDir(), "sftp")
@ -252,6 +232,7 @@ func TestAgent(t *testing.T) {
conn, _ := setupAgent(t, agent.Metadata{}, 0)
sshClient, err := conn.SSHClient()
require.NoError(t, err)
defer sshClient.Close()
scpClient, err := scp.NewClientBySSH(sshClient)
require.NoError(t, err)
tempFile := filepath.Join(t.TempDir(), "scp")
@ -384,9 +365,7 @@ func TestAgent(t *testing.T) {
t.Skip("ConPTY appears to be inconsistent on Windows.")
}
conn, _ := setupAgent(t, agent.Metadata{
DERPMap: tailnettest.RunDERPAndSTUN(t),
}, 0)
conn, _ := setupAgent(t, agent.Metadata{}, 0)
id := uuid.NewString()
netConn, err := conn.ReconnectingPTY(id, 100, 100, "/bin/bash")
require.NoError(t, err)
@ -462,19 +441,6 @@ func TestAgent(t *testing.T) {
return l
},
},
{
name: "Unix",
setup: func(t *testing.T) net.Listener {
if runtime.GOOS == "windows" {
t.Skip("Unix socket forwarding isn't supported on Windows")
}
tmpDir := t.TempDir()
l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
require.NoError(t, err, "create UDP listener")
return l
},
},
}
for _, c := range cases {
@ -496,8 +462,11 @@ func TestAgent(t *testing.T) {
}
}()
// Dial the listener twice and test out of order
conn, _ := setupAgent(t, agent.Metadata{}, 0)
require.Eventually(t, func() bool {
_, err := conn.Ping()
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn1.Close()
@ -506,36 +475,11 @@ func TestAgent(t *testing.T) {
defer conn2.Close()
testDial(t, conn2)
testDial(t, conn1)
time.Sleep(150 * time.Millisecond)
})
}
})
t.Run("DialError", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
// This test uses Unix listeners so we can very easily ensure that
// no other tests decide to listen on the same random port we
// picked.
t.Skip("this test is unsupported on Windows")
return
}
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
require.NoError(t, err, "create temp dir")
t.Cleanup(func() {
_ = os.RemoveAll(tmpDir)
})
// Try to dial the non-existent Unix socket over WebRTC
conn, _ := setupAgent(t, agent.Metadata{}, 0)
netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock"))
require.Error(t, err)
require.ErrorContains(t, err, "remote dial error")
require.ErrorContains(t, err, "no such file")
require.Nil(t, netConn)
})
t.Run("Tailnet", func(t *testing.T) {
t.Parallel()
derpMap := tailnettest.RunDERPAndSTUN(t)
@ -578,7 +522,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
return
}
ssh, err := agentConn.SSH()
if !assert.NoError(t, err) {
if err != nil {
_ = conn.Close()
return
}
@ -622,11 +566,12 @@ func (c closeFunc) Close() error {
}
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) (
agent.Conn,
*agent.Conn,
<-chan *agent.Stats,
) {
client, server := provisionersdk.TransportPipe()
tailscale := metadata.DERPMap != nil
if metadata.DERPMap == nil {
metadata.DERPMap = tailnettest.RunDERPAndSTUN(t)
}
coordinator := tailnet.NewCoordinator()
agentID := uuid.New()
statsCh := make(chan *agent.Stats)
@ -634,17 +579,18 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
FetchMetadata: func(ctx context.Context) (agent.Metadata, error) {
return metadata, nil
},
WebRTCDialer: func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) {
listener, err := peerbroker.Listen(server, nil)
return listener, err
},
CoordinatorDialer: func(ctx context.Context) (net.Conn, error) {
clientConn, serverConn := net.Pipe()
closed := make(chan struct{})
t.Cleanup(func() {
_ = serverConn.Close()
_ = clientConn.Close()
<-closed
})
go coordinator.ServeAgent(serverConn, agentID)
go func() {
_ = coordinator.ServeAgent(serverConn, agentID)
close(closed)
}()
return clientConn, nil
},
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
@ -683,46 +629,27 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
},
})
t.Cleanup(func() {
_ = client.Close()
_ = server.Close()
_ = closer.Close()
})
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
stream, err := api.NegotiateConnection(context.Background())
assert.NoError(t, err)
if tailscale {
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: metadata.DERPMap,
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
clientConn, serverConn := net.Pipe()
t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
_ = conn.Close()
})
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
return conn.UpdateNodes(node)
})
conn.SetNodeCallback(sendNode)
return &agent.TailnetConn{
Conn: conn,
}, statsCh
}
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
Logger: slogtest.Make(t, nil),
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: metadata.DERPMap,
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
clientConn, serverConn := net.Pipe()
t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
_ = conn.Close()
})
return &agent.WebRTCConn{
Negotiator: api,
Conn: conn,
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
return conn.UpdateNodes(node)
})
conn.SetNodeCallback(sendNode)
return &agent.Conn{
Conn: conn,
}, statsCh
}

View File

@ -4,13 +4,9 @@ import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net"
"net/netip"
"net/url"
"strconv"
"strings"
"time"
"golang.org/x/crypto/ssh"
@ -19,8 +15,6 @@ import (
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/tailnet"
)
@ -32,123 +26,12 @@ type ReconnectingPTYRequest struct {
Width uint16 `json:"width"`
}
// Conn is a temporary interface while we switch from WebRTC to Wireguard networking.
type Conn interface {
io.Closer
Closed() <-chan struct{}
Ping() (time.Duration, error)
CloseWithError(err error) error
ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error)
SSH() (net.Conn, error)
Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error)
SSHClient() (*ssh.Client, error)
DialContext(ctx context.Context, network string, addr string) (net.Conn, error)
}
// Conn wraps a peer connection with helper functions to
// communicate with the agent.
type WebRTCConn struct {
// Negotiator is responsible for exchanging messages.
Negotiator proto.DRPCPeerBrokerClient
*peer.Conn
}
// ReconnectingPTY returns a connection serving a TTY that can
// be reconnected to via ID.
//
// The command is optional and defaults to starting a shell.
func (c *WebRTCConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) {
channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d:%s", id, height, width, command), &peer.ChannelOptions{
Protocol: ProtocolReconnectingPTY,
})
if err != nil {
return nil, xerrors.Errorf("pty: %w", err)
}
return channel.NetConn(), nil
}
// SSH dials the built-in SSH server.
func (c *WebRTCConn) SSH() (net.Conn, error) {
channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{
Protocol: ProtocolSSH,
})
if err != nil {
return nil, xerrors.Errorf("dial: %w", err)
}
return channel.NetConn(), nil
}
func (*WebRTCConn) Speedtest(_ speedtest.Direction, _ time.Duration) ([]speedtest.Result, error) {
return nil, xerrors.New("not implemented")
}
// SSHClient calls SSH to create a client that uses a weak cipher
// for high throughput.
func (c *WebRTCConn) SSHClient() (*ssh.Client, error) {
netConn, err := c.SSH()
if err != nil {
return nil, xerrors.Errorf("ssh: %w", err)
}
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
// SSH host validation isn't helpful, because obtaining a peer
// connection already signifies user-intent to dial a workspace.
// #nosec
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
if err != nil {
return nil, xerrors.Errorf("ssh conn: %w", err)
}
return ssh.NewClient(sshConn, channels, requests), nil
}
// DialContext dials an arbitrary protocol+address from inside the workspace and
// proxies it through the provided net.Conn.
func (c *WebRTCConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
u := &url.URL{
Scheme: network,
}
if strings.HasPrefix(network, "unix") {
u.Path = addr
} else {
u.Host = addr
}
channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{
Protocol: ProtocolDial,
Unordered: strings.HasPrefix(network, "udp"),
})
if err != nil {
return nil, xerrors.Errorf("create datachannel: %w", err)
}
// The first message written from the other side is a JSON payload
// containing the dial error.
dec := json.NewDecoder(channel)
var res dialResponse
err = dec.Decode(&res)
if err != nil {
return nil, xerrors.Errorf("decode agent dial response: %w", err)
}
if res.Error != "" {
_ = channel.Close()
return nil, xerrors.Errorf("remote dial error: %v", res.Error)
}
return channel.NetConn(), nil
}
func (c *WebRTCConn) Close() error {
_ = c.Negotiator.DRPCConn().Close()
return c.Conn.Close()
}
type TailnetConn struct {
type Conn struct {
*tailnet.Conn
CloseFunc func()
}
func (c *TailnetConn) Ping() (time.Duration, error) {
func (c *Conn) Ping() (time.Duration, error) {
errCh := make(chan error, 1)
durCh := make(chan time.Duration, 1)
c.Conn.Ping(tailnetIP, tailcfg.PingICMP, func(pr *ipnstate.PingResult) {
@ -166,11 +49,11 @@ func (c *TailnetConn) Ping() (time.Duration, error) {
}
}
func (c *TailnetConn) CloseWithError(_ error) error {
func (c *Conn) CloseWithError(_ error) error {
return c.Close()
}
func (c *TailnetConn) Close() error {
func (c *Conn) Close() error {
if c.CloseFunc != nil {
c.CloseFunc()
}
@ -184,7 +67,7 @@ type reconnectingPTYInit struct {
Command string
}
func (c *TailnetConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) {
func (c *Conn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) {
conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetReconnectingPTYPort)))
if err != nil {
return nil, err
@ -210,13 +93,13 @@ func (c *TailnetConn) ReconnectingPTY(id string, height, width uint16, command s
return conn, nil
}
func (c *TailnetConn) SSH() (net.Conn, error) {
func (c *Conn) SSH() (net.Conn, error) {
return c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetSSHPort)))
}
// SSHClient calls SSH to create a client that uses a weak cipher
// for high throughput.
func (c *TailnetConn) SSHClient() (*ssh.Client, error) {
func (c *Conn) SSHClient() (*ssh.Client, error) {
netConn, err := c.SSH()
if err != nil {
return nil, xerrors.Errorf("ssh: %w", err)
@ -233,7 +116,7 @@ func (c *TailnetConn) SSHClient() (*ssh.Client, error) {
return ssh.NewClient(sshConn, channels, requests), nil
}
func (c *TailnetConn) Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
func (c *Conn) Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
speedConn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetSpeedtestPort)))
if err != nil {
return nil, xerrors.Errorf("dial speedtest: %w", err)
@ -245,7 +128,10 @@ func (c *TailnetConn) Speedtest(direction speedtest.Direction, duration time.Dur
return results, err
}
func (c *TailnetConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
if network == "unix" {
return nil, xerrors.New("network must be tcp or udp")
}
_, rawPort, _ := net.SplitHostPort(addr)
port, _ := strconv.Atoi(rawPort)
ipp := netip.AddrPortFrom(tailnetIP, uint16(port))

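A short usage sketch (hypothetical caller, assuming only the methods shown in this file's hunks) of the consolidated *agent.Conn now that the WebRTC variant is deleted — Ping, SSHClient, and DialContext are all served by the single tailnet-backed type, and "unix" dials are rejected per the change above:

package main

import (
	"context"

	"golang.org/x/xerrors"

	"github.com/coder/coder/agent"
)

// useConn pings the agent, opens an SSH client, and proxies an arbitrary
// TCP dial through the workspace. Sketch only; error wrapping follows the
// xerrors style used elsewhere in the repository.
func useConn(ctx context.Context, conn *agent.Conn) error {
	if _, err := conn.Ping(); err != nil {
		return xerrors.Errorf("agent unreachable: %w", err)
	}
	sshClient, err := conn.SSHClient()
	if err != nil {
		return err
	}
	defer sshClient.Close()

	// Only tcp/udp networks proxy through the workspace now.
	nc, err := conn.DialContext(ctx, "tcp", "localhost:8080")
	if err != nil {
		return err
	}
	defer nc.Close()
	return nil
}
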
View File

@ -32,7 +32,6 @@ func workspaceAgent() *cobra.Command {
pprofEnabled bool
pprofAddress string
noReap bool
wireguard bool
)
cmd := &cobra.Command{
Use: "agent",
@ -184,7 +183,6 @@ func workspaceAgent() *cobra.Command {
closer := agent.New(agent.Options{
FetchMetadata: client.WorkspaceAgentMetadata,
WebRTCDialer: client.ListenWorkspaceAgent,
Logger: logger,
EnvironmentVariables: map[string]string{
// Override the "CODER_AGENT_TOKEN" variable in all
@ -203,6 +201,5 @@ func workspaceAgent() *cobra.Command {
cliflag.BoolVarP(cmd.Flags(), &pprofEnabled, "pprof-enable", "", "CODER_AGENT_PPROF_ENABLE", false, "Enable serving pprof metrics on the address defined by --pprof-address.")
cliflag.BoolVarP(cmd.Flags(), &noReap, "no-reap", "", "", false, "Do not start a process reaper.")
cliflag.StringVarP(cmd.Flags(), &pprofAddress, "pprof-address", "", "CODER_AGENT_PPROF_ADDRESS", "127.0.0.1:6060", "The address to serve pprof.")
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_AGENT_WIREGUARD", true, "Whether to start the Wireguard interface.")
return cmd
}

View File

@ -7,10 +7,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/testutil"
)
func TestWorkspaceAgent(t *testing.T) {
@ -63,11 +66,13 @@ func TestWorkspaceAgent(t *testing.T) {
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
require.NoError(t, err)
defer dialer.Close()
_, err = dialer.Ping()
require.NoError(t, err)
require.Eventually(t, func() bool {
_, err := dialer.Ping()
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
cancelFunc()
err = <-errC
require.NoError(t, err)
@ -121,11 +126,13 @@ func TestWorkspaceAgent(t *testing.T) {
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
require.NoError(t, err)
defer dialer.Close()
_, err = dialer.Ping()
require.NoError(t, err)
require.Eventually(t, func() bool {
_, err := dialer.Ping()
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
cancelFunc()
err = <-errC
require.NoError(t, err)
@ -179,11 +186,13 @@ func TestWorkspaceAgent(t *testing.T) {
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
require.NoError(t, err)
defer dialer.Close()
_, err = dialer.Ping()
require.NoError(t, err)
require.Eventually(t, func() bool {
_, err := dialer.Ping()
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
cancelFunc()
err = <-errC
require.NoError(t, err)

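The one-shot Ping assertions become retry loops because the tailnet connection is established asynchronously. A hypothetical helper (not in this commit) capturing that pattern:

package cli_test

import (
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/coder/coder/agent"
	"github.com/coder/coder/testutil"
)

// waitForAgent wraps the retry pattern used in these tests: the first
// Ping can fail even after a successful dial, because the wireguard/DERP
// path comes up asynchronously.
func waitForAgent(t *testing.T, conn *agent.Conn) {
	t.Helper()
	require.Eventually(t, func() bool {
		_, err := conn.Ping()
		return err == nil
	}, testutil.WaitMedium, testutil.IntervalFast)
}
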
View File

@ -139,7 +139,6 @@ func configSSH() *cobra.Command {
usePreviousOpts bool
dryRun bool
skipProxyCommand bool
wireguard bool
)
cmd := &cobra.Command{
Annotations: workspaceCommand,
@ -289,15 +288,11 @@ func configSSH() *cobra.Command {
"\tLogLevel ERROR",
)
if !skipProxyCommand {
wgArg := ""
if wireguard {
wgArg = "--wireguard "
}
configOptions = append(
configOptions,
fmt.Sprintf(
"\tProxyCommand %s --global-config %s ssh %s--stdio %s",
escapedCoderBinary, escapedGlobalConfig, wgArg, hostname,
"\tProxyCommand %s --global-config %s ssh --stdio %s",
escapedCoderBinary, escapedGlobalConfig, hostname,
),
)
}
@ -374,9 +369,6 @@ func configSSH() *cobra.Command {
cmd.Flags().BoolVarP(&skipProxyCommand, "skip-proxy-command", "", false, "Specifies whether the ProxyCommand option should be skipped. Useful for testing.")
_ = cmd.Flags().MarkHidden("skip-proxy-command")
cliflag.BoolVarP(cmd.Flags(), &usePreviousOpts, "use-previous-options", "", "CODER_SSH_USE_PREVIOUS_OPTIONS", false, "Specifies whether or not to keep options from previous run of config-ssh.")
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_CONFIG_SSH_WIREGUARD", true, "Whether to use Wireguard for SSH tunneling.")
_ = cmd.Flags().MarkHidden("wireguard")
cliui.AllowSkipPrompt(cmd)
return cmd

View File

@ -12,12 +12,14 @@ import (
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
@ -106,15 +108,14 @@ func TestConfigSSH(t *testing.T) {
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
CoordinatorDialer: client.ListenWorkspaceAgentTailnet,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent"),
})
defer func() {
_ = agentCloser.Close()
}()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
agentConn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil)
agentConn, err := client.DialWorkspaceAgentTailnet(context.Background(), slog.Logger{}, resources[0].Agents[0].ID)
require.NoError(t, err)
defer agentConn.Close()
@ -123,17 +124,28 @@ func TestConfigSSH(t *testing.T) {
defer func() {
_ = listener.Close()
}()
copyDone := make(chan struct{})
go func() {
defer close(copyDone)
var wg sync.WaitGroup
for {
conn, err := listener.Accept()
if err != nil {
return
break
}
ssh, err := agentConn.SSH()
assert.NoError(t, err)
go io.Copy(conn, ssh)
go io.Copy(ssh, conn)
wg.Add(2)
go func() {
defer wg.Done()
_, _ = io.Copy(conn, ssh)
}()
go func() {
defer wg.Done()
_, _ = io.Copy(ssh, conn)
}()
}
wg.Wait()
}()
sshConfigFile := sshConfigFileName(t)
@ -178,6 +190,9 @@ func TestConfigSSH(t *testing.T) {
data, err := sshCmd.Output()
require.NoError(t, err)
require.Equal(t, "test", strings.TrimSpace(string(data)))
_ = listener.Close()
<-copyDone
}
func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) {

View File

@ -20,6 +20,8 @@ import (
"github.com/stretchr/testify/require"
gossh "golang.org/x/crypto/ssh"
"cdr.dev/slog"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
@ -72,7 +74,7 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, str
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
// start workspace agent
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String(), "--wireguard=false")
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String())
agentClient := client
clitest.SetupConfig(t, agentClient, root)
@ -85,11 +87,13 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, str
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
require.NoError(t, err)
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
dialer, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
require.NoError(t, err)
defer dialer.Close()
_, err = dialer.Ping()
require.NoError(t, err)
require.Eventually(t, func() bool {
_, err = dialer.Ping()
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
return agentClient, agentToken, pubkey
}

View File

@ -17,9 +17,7 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/agent"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
)
@ -28,7 +26,6 @@ func portForward() *cobra.Command {
var (
tcpForwards []string // <port>:<port>
udpForwards []string // <port>:<port>
wireguard bool
)
cmd := &cobra.Command{
Use: "port-forward <workspace>",
@ -94,16 +91,7 @@ func portForward() *cobra.Command {
return xerrors.Errorf("await agent: %w", err)
}
var conn agent.Conn
if !wireguard {
conn, err = client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
} else {
logger := slog.Logger{}
if cliflag.IsSetBool(cmd, varVerbose) {
logger = slog.Make(sloghuman.Sink(cmd.ErrOrStderr())).Named("tailnet").Leveled(slog.LevelDebug)
}
conn, err = client.DialWorkspaceAgentTailnet(ctx, logger, workspaceAgent.ID)
}
conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID)
if err != nil {
return err
}
@ -178,12 +166,10 @@ func portForward() *cobra.Command {
cmd.Flags().StringArrayVarP(&tcpForwards, "tcp", "p", []string{}, "Forward a TCP port from the workspace to the local machine")
cmd.Flags().StringArrayVar(&udpForwards, "udp", []string{}, "Forward a UDP port from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols")
cmd.Flags().BoolVarP(&wireguard, "wireguard", "", true, "Specifies whether to use wireguard networking or not.")
_ = cmd.Flags().MarkHidden("wireguard")
return cmd
}
func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn agent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) {
func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *agent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) {
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
var (

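The removed --wireguard branch also dropped the verbose tailnet logger. If debug output is still wanted, a real logger can be passed instead of the zero slog.Logger{} — a sketch (hypothetical helper, reusing the slog calls from the deleted lines):

package main

import (
	"context"
	"io"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/sloghuman"
	"github.com/google/uuid"

	"github.com/coder/coder/agent"
	"github.com/coder/coder/codersdk"
)

// dialVerbose dials a workspace agent over tailnet with debug-level
// logging written to the provided writer.
func dialVerbose(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, stderr io.Writer) (*agent.Conn, error) {
	logger := slog.Make(sloghuman.Sink(stderr)).Named("tailnet").Leveled(slog.LevelDebug)
	return client.DialWorkspaceAgentTailnet(ctx, logger, agentID)
}
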
View File

@ -377,7 +377,7 @@ func setupTestListener(t *testing.T, l net.Listener) string {
addr := l.Addr().String()
_, port, err := net.SplitHostPort(addr)
require.NoErrorf(t, err, "split listen path %q", addr)
require.NoErrorf(t, err, "split non-Unix listen path %q", addr)
addr = port
return addr

View File

@ -28,8 +28,6 @@ import (
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
"github.com/google/go-github/v43/github"
"github.com/google/uuid"
"github.com/pion/turn/v2"
"github.com/pion/webrtc/v3"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/spf13/afero"
@ -59,7 +57,6 @@ import (
"github.com/coder/coder/coderd/prometheusmetrics"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/provisioner/echo"
@ -113,9 +110,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
tlsEnable bool
tlsKeyFile string
tlsMinVersion string
turnRelayAddress string
tunnel bool
stunServers []string
traceEnable bool
secureAuthCookie bool
sshKeygenAlgorithmRaw string
@ -300,22 +295,6 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
return xerrors.Errorf("parse ssh keygen algorithm %s: %w", sshKeygenAlgorithmRaw, err)
}
turnServer, err := turnconn.New(&turn.RelayAddressGeneratorStatic{
RelayAddress: net.ParseIP(turnRelayAddress),
Address: turnRelayAddress,
})
if err != nil {
return xerrors.Errorf("create turn server: %w", err)
}
defer turnServer.Close()
iceServers := make([]webrtc.ICEServer, 0)
for _, stunServer := range stunServers {
iceServers = append(iceServers, webrtc.ICEServer{
URLs: []string{stunServer},
})
}
// Validate provided auto-import templates.
var (
validatedAutoImportTemplates = make([]coderd.AutoImportTemplate, len(autoImportTemplates))
@ -360,7 +339,6 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
options := &coderd.Options{
AccessURL: accessURLParsed,
ICEServers: iceServers,
Logger: logger.Named("coderd"),
Database: databasefake.New(),
DERPMap: derpMap,
@ -369,8 +347,6 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
GoogleTokenValidator: googleTokenValidator,
SecureAuthCookie: secureAuthCookie,
SSHKeygenAlgorithm: sshKeygenAlgorithm,
TailscaleEnable: tailscaleEnable,
TURNServer: turnServer,
TracerProvider: tracerProvider,
Telemetry: telemetry.NewNoop(),
AutoImportTemplates: validatedAutoImportTemplates,
@ -478,7 +454,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
OIDCAuth: oidcClientID != "",
OIDCIssuerURL: oidcIssuerURL,
Prometheus: promEnabled,
STUN: len(stunServers) != 0,
STUN: len(derpServerSTUNAddrs) != 0,
Tunnel: tunnel,
})
if err != nil {
@ -850,13 +826,8 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
`Minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13"`)
cliflag.BoolVarP(root.Flags(), &tunnel, "tunnel", "", "CODER_TUNNEL", false,
"Workspaces must be able to reach the `access-url`. This overrides your access URL with a public access URL that tunnels your Coder deployment.")
cliflag.StringArrayVarP(root.Flags(), &stunServers, "stun-server", "", "CODER_STUN_SERVERS", []string{
"stun:stun.l.google.com:19302",
}, "URLs for STUN servers to enable P2P connections.")
cliflag.BoolVarP(root.Flags(), &traceEnable, "trace", "", "CODER_TRACE", false,
"Whether application tracing data is collected.")
cliflag.StringVarP(root.Flags(), &turnRelayAddress, "turn-relay-address", "", "CODER_TURN_RELAY_ADDRESS", "127.0.0.1",
"The address to bind TURN connections.")
cliflag.BoolVarP(root.Flags(), &secureAuthCookie, "secure-auth-cookie", "", "CODER_SECURE_AUTH_COOKIE", false,
"Controls if the 'Secure' property is set on browser session cookies")
cliflag.StringVarP(root.Flags(), &sshKeygenAlgorithmRaw, "ssh-keygen-algorithm", "", "CODER_SSH_KEYGEN_ALGORITHM", "ed25519",

View File

@ -12,7 +12,6 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/agent"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
@ -73,8 +72,7 @@ func speedtest() *cobra.Command {
if err != nil {
continue
}
tc, _ := conn.(*agent.TailnetConn)
status := tc.Status()
status := conn.Status()
if len(status.Peers()) != 1 {
continue
}

View File

@ -25,7 +25,6 @@ func TestSpeedtest(t *testing.T) {
agentClient.SessionToken = agentToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent"),
})

View File

@ -22,7 +22,6 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/agent"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/autobuild/notify"
@ -43,7 +42,6 @@ func ssh() *cobra.Command {
forwardAgent bool
identityAgent string
wsPollInterval time.Duration
wireguard bool
)
cmd := &cobra.Command{
Annotations: workspaceCommand,
@ -88,12 +86,7 @@ func ssh() *cobra.Command {
return xerrors.Errorf("await agent: %w", err)
}
var conn agent.Conn
if !wireguard {
conn, err = client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
} else {
conn, err = client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID)
}
conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID)
if err != nil {
return err
}
@ -221,9 +214,6 @@ func ssh() *cobra.Command {
cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK")
cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled")
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_SSH_WIREGUARD", true, "Whether to use Wireguard for SSH tunneling.")
_ = cmd.Flags().MarkHidden("wireguard")
return cmd
}

View File

@ -90,7 +90,6 @@ func TestSSH(t *testing.T) {
agentClient.SessionToken = agentToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent"),
})
@ -112,7 +111,6 @@ func TestSSH(t *testing.T) {
agentClient.SessionToken = agentToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent"),
})
@ -181,7 +179,6 @@ func TestSSH(t *testing.T) {
agentClient.SessionToken = agentToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent"),
})

View File

@ -13,7 +13,6 @@ import (
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/klauspost/compress/zstd"
"github.com/pion/webrtc/v3"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
@ -35,7 +34,6 @@ import (
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/site"
@ -65,17 +63,14 @@ type Options struct {
GithubOAuth2Config *GithubOAuth2Config
OIDCConfig *OIDCConfig
PrometheusRegistry *prometheus.Registry
ICEServers []webrtc.ICEServer
SecureAuthCookie bool
SSHKeygenAlgorithm gitsshkey.Algorithm
Telemetry telemetry.Reporter
TURNServer *turnconn.Server
TracerProvider trace.TracerProvider
AutoImportTemplates []AutoImportTemplate
LicenseHandler http.Handler
FeaturesService features.Service
TailscaleEnable bool
TailnetCoordinator *tailnet.Coordinator
DERPMap *tailcfg.DERPMap
@ -92,6 +87,12 @@ func New(options *Options) *API {
// Multiply the update by two to allow for some lag-time.
options.AgentInactiveDisconnectTimeout = options.AgentConnectionUpdateFrequency * 2
}
if options.AgentStatsRefreshInterval == 0 {
options.AgentStatsRefreshInterval = 10 * time.Minute
}
if options.MetricsCacheRefreshInterval == 0 {
options.MetricsCacheRefreshInterval = time.Hour
}
if options.APIRateLimit == 0 {
options.APIRateLimit = 512
}
@ -149,11 +150,7 @@ func New(options *Options) *API {
},
metricsCache: metricsCache,
}
if options.TailscaleEnable {
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
} else {
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0)
}
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger))
oauthConfigs := &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
@ -415,14 +412,8 @@ func New(options *Options) *API {
r.Use(httpmw.ExtractWorkspaceAgent(options.Database))
r.Get("/metadata", api.workspaceAgentMetadata)
r.Post("/version", api.postWorkspaceAgentVersion)
r.Get("/listen", api.workspaceAgentListen)
r.Get("/gitsshkey", api.agentGitSSHKey)
r.Get("/turn", api.workspaceAgentTurn)
r.Get("/iceservers", api.workspaceAgentICEServers)
r.Get("/coordinate", api.workspaceAgentCoordinate)
r.Get("/report-stats", api.workspaceAgentReportStats)
})
r.Route("/{workspaceagent}", func(r chi.Router) {
@ -432,11 +423,7 @@ func New(options *Options) *API {
httpmw.ExtractWorkspaceParam(options.Database),
)
r.Get("/", api.workspaceAgent)
r.Get("/dial", api.workspaceAgentDial)
r.Get("/turn", api.userWorkspaceAgentTurn)
r.Get("/pty", api.workspaceAgentPTY)
r.Get("/iceservers", api.workspaceAgentICEServers)
r.Get("/connection", api.workspaceAgentConnection)
r.Get("/coordinate", api.workspaceAgentClientCoordinate)
})

View File

@ -188,18 +188,14 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
"GET:/api/v2/users/oidc/callback": {NoAuthorize: true},
// All workspaceagents endpoints do not use rbac
"POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/iceservers": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/listen": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/turn": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/iceservers": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true},
// These endpoints have more assertions. This is good, add more endpoints to assert if you can!
"GET:/api/v2/organizations/{organization}": {AssertObject: rbac.ResourceOrganization.InOrg(a.Admin.OrganizationID)},
@ -256,14 +252,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
AssertAction: rbac.ActionRead,
AssertObject: workspaceRBACObj,
},
"GET:/api/v2/workspaceagents/{workspaceagent}/dial": {
AssertAction: rbac.ActionCreate,
AssertObject: workspaceExecObj,
},
"GET:/api/v2/workspaceagents/{workspaceagent}/turn": {
AssertAction: rbac.ActionCreate,
AssertObject: workspaceExecObj,
},
"GET:/api/v2/workspaceagents/{workspaceagent}/pty": {
AssertAction: rbac.ActionCreate,
AssertObject: workspaceExecObj,

View File

@ -54,7 +54,6 @@ import (
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/coderd/util/ptr"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
@ -202,12 +201,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519
}
turnServer, err := turnconn.New(nil)
require.NoError(t, err)
t.Cleanup(func() {
_ = turnServer.Close()
})
features := coderd.DisabledImplementations
if options.Auditor != nil {
features.Auditor = options.Auditor
@ -231,7 +224,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
OIDCConfig: options.OIDCConfig,
GoogleTokenValidator: options.GoogleTokenValidator,
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
TURNServer: turnServer,
APIRateLimit: options.APIRateLimit,
Authorizer: options.Authorizer,
Telemetry: telemetry.NewNoop(),

View File

@ -604,7 +604,6 @@ func TestTemplateDAUs(t *testing.T) {
agentCloser := agent.New(agent.Options{
Logger: slogtest.Make(t, nil),
StatsReporter: agentClient.AgentReportStats,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
})

View File

@ -1,203 +0,0 @@
package turnconn
import (
"io"
"net"
"sync"
"github.com/pion/logging"
"github.com/pion/turn/v2"
"github.com/pion/webrtc/v3"
"golang.org/x/net/proxy"
"golang.org/x/xerrors"
)
var (
// reservedAddress is a magic address that's used exclusively
// for proxying via Coder. We don't proxy all TURN connections,
// because that'd exclude the possibility of a customer using
// their own TURN server.
reservedAddress = "127.0.0.1:12345"
credential = "coder"
localhost = &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
}
// Proxy is an ICE server that uses a special hostname
// to indicate traffic should be proxied.
Proxy = webrtc.ICEServer{
URLs: []string{"turns:" + reservedAddress},
Username: "coder",
Credential: credential,
}
)
// New constructs a new TURN server binding to the relay address provided.
// The relay address is used to broadcast the location of an accepted connection.
func New(relayAddress *turn.RelayAddressGeneratorStatic) (*Server, error) {
if relayAddress == nil {
relayAddress = &turn.RelayAddressGeneratorStatic{
RelayAddress: localhost.IP,
Address: "127.0.0.1",
}
}
logger := logging.NewDefaultLoggerFactory()
logger.DefaultLogLevel = logging.LogLevelDisabled
server := &Server{
conns: make(chan net.Conn, 1),
closed: make(chan struct{}),
}
server.listener = &listener{
srv: server,
}
var err error
server.turn, err = turn.NewServer(turn.ServerConfig{
AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) {
// TURN connections require credentials. It's not important
// for our use-case, because our listener is entirely in-memory.
return turn.GenerateAuthKey(Proxy.Username, "", credential), true
},
ListenerConfigs: []turn.ListenerConfig{{
Listener: server.listener,
RelayAddressGenerator: relayAddress,
}},
LoggerFactory: logger,
})
if err != nil {
return nil, xerrors.Errorf("create server: %w", err)
}
return server, nil
}
// Server accepts and connects TURN allocations.
//
// This is a thin wrapper around pion/turn that pipes
// connections directly to the in-memory handler.
type Server struct {
listener *listener
turn *turn.Server
closeMutex sync.Mutex
closed chan (struct{})
conns chan (net.Conn)
}
// Accept consumes a new connection into the TURN server.
// A unique remote address must exist per-connection.
// pion/turn indexes allocations based on the address.
func (s *Server) Accept(nc net.Conn, remoteAddress, localAddress *net.TCPAddr) *Conn {
if localAddress == nil {
localAddress = localhost
}
conn := &Conn{
Conn: nc,
remoteAddress: remoteAddress,
localAddress: localAddress,
closed: make(chan struct{}),
}
s.conns <- conn
return conn
}
// Close ends the TURN server.
func (s *Server) Close() error {
s.closeMutex.Lock()
defer s.closeMutex.Unlock()
if s.isClosed() {
return nil
}
err := s.turn.Close()
close(s.conns)
close(s.closed)
return err
}
func (s *Server) isClosed() bool {
select {
case <-s.closed:
return true
default:
return false
}
}
// listener implements net.Listener for the TURN
// server to consume.
type listener struct {
srv *Server
}
func (l *listener) Accept() (net.Conn, error) {
conn, ok := <-l.srv.conns
if !ok {
return nil, io.EOF
}
return conn, nil
}
func (*listener) Close() error {
return nil
}
func (*listener) Addr() net.Addr {
return nil
}
type Conn struct {
net.Conn
closed chan struct{}
localAddress *net.TCPAddr
remoteAddress *net.TCPAddr
}
func (c *Conn) LocalAddr() net.Addr {
return c.localAddress
}
func (c *Conn) RemoteAddr() net.Addr {
return c.remoteAddress
}
// Closed returns a channel which is closed when
// the connection is.
func (c *Conn) Closed() <-chan struct{} {
return c.closed
}
func (c *Conn) Close() error {
err := c.Conn.Close()
select {
case <-c.closed:
default:
close(c.closed)
}
return err
}
type dialer func(network, addr string) (c net.Conn, err error)
func (d dialer) Dial(network, addr string) (c net.Conn, err error) {
return d(network, addr)
}
// ProxyDialer accepts a proxy function that's called when the connection
// address matches the reserved host in the "Proxy" ICE server.
//
// This should be passed to WebRTC connections as an ICE dialer.
func ProxyDialer(proxyFunc func() (c net.Conn, err error)) proxy.Dialer {
return dialer(func(network, addr string) (net.Conn, error) {
if addr != reservedAddress {
return proxy.Direct.Dial(network, addr)
}
netConn, err := proxyFunc()
if err != nil {
return nil, err
}
return &Conn{
localAddress: localhost,
closed: make(chan struct{}),
Conn: netConn,
}, nil
})
}

View File

@ -1,107 +0,0 @@
package turnconn_test
import (
"net"
"sync"
"testing"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/peer"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestTURNConn(t *testing.T) {
t.Parallel()
turnServer, err := turnconn.New(nil)
require.NoError(t, err)
defer turnServer.Close()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientDialer, clientTURN := net.Pipe()
turnServer.Accept(clientTURN, &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 16000,
}, nil)
require.NoError(t, err)
clientSettings := webrtc.SettingEngine{}
clientSettings.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6})
clientSettings.SetRelayAcceptanceMinWait(0)
clientSettings.SetICEProxyDialer(turnconn.ProxyDialer(func() (net.Conn, error) {
return clientDialer, nil
}))
client, err := peer.Client([]webrtc.ICEServer{turnconn.Proxy}, &peer.ConnOptions{
SettingEngine: clientSettings,
Logger: logger.Named("client"),
})
require.NoError(t, err)
defer func() {
_ = client.Close()
}()
serverDialer, serverTURN := net.Pipe()
turnServer.Accept(serverTURN, &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 16001,
}, nil)
require.NoError(t, err)
serverSettings := webrtc.SettingEngine{}
serverSettings.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6})
serverSettings.SetRelayAcceptanceMinWait(0)
serverSettings.SetICEProxyDialer(turnconn.ProxyDialer(func() (net.Conn, error) {
return serverDialer, nil
}))
server, err := peer.Server([]webrtc.ICEServer{turnconn.Proxy}, &peer.ConnOptions{
SettingEngine: serverSettings,
Logger: logger.Named("server"),
})
require.NoError(t, err)
defer func() {
_ = server.Close()
}()
exchange(t, client, server)
_, err = client.Ping()
require.NoError(t, err)
}
func exchange(t *testing.T, client, server *peer.Conn) {
var wg sync.WaitGroup
wg.Add(2)
t.Cleanup(wg.Wait)
go func() {
defer wg.Done()
for {
select {
case c := <-server.LocalCandidate():
client.AddRemoteCandidate(c)
case c := <-server.LocalSessionDescription():
client.SetRemoteSessionDescription(c)
case <-server.Closed():
return
}
}
}()
go func() {
defer wg.Done()
for {
select {
case c := <-client.LocalCandidate():
server.AddRemoteCandidate(c)
case c := <-client.LocalSessionDescription():
server.SetRemoteSessionDescription(c)
case <-client.Closed():
return
}
}
}()
}

View File

@ -15,7 +15,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/hashicorp/yamux"
"go.opentelemetry.io/otel/trace"
"golang.org/x/mod/semver"
"golang.org/x/xerrors"
@ -30,12 +29,7 @@ import (
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/tailnet"
)
@ -66,67 +60,6 @@ func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(rw, http.StatusOK, apiAgent)
}
func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgentParam(r)
workspace := httpmw.WorkspaceParam(r)
if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) {
httpapi.ResourceNotFound(rw)
return
}
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
Detail: err.Error(),
})
return
}
if apiAgent.Status != codersdk.WorkspaceAgentConnected {
httpapi.Write(rw, http.StatusPreconditionFailed, codersdk.Response{
Message: fmt.Sprintf("Agent isn't connected! Status: %s.", apiAgent.Status),
})
return
}
conn, err := websocket.Accept(rw, r, nil)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
defer wsNetConn.Close() // Also closes conn.
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(wsNetConn, config)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
// end span so we don't get long-lived trace data
tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx))
err = peerbroker.ProxyListen(ctx, session, peerbroker.ProxyOptions{
ChannelID: workspaceAgent.ID.String(),
Logger: api.Logger.Named("peerbroker-proxy-dial"),
Pubsub: api.Pubsub,
})
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
return
}
}
func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
workspaceAgent := httpmw.WorkspaceAgent(r)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
@ -186,231 +119,6 @@ func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Reques
httpapi.Write(rw, http.StatusOK, nil)
}
func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgent(r)
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Internal error fetching workspace build job.",
Detail: err.Error(),
})
return
}
// Ensure the resource is still valid!
// We only accept agents for resources on the latest build.
ensureLatestBuild := func() error {
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID)
if err != nil {
return err
}
if build.ID != latestBuild.ID {
return xerrors.New("build is outdated")
}
return nil
}
err = ensureLatestBuild()
if err != nil {
api.Logger.Debug(r.Context(), "agent tried to connect from non-latest build",
slog.F("resource", resource),
slog.F("agent", workspaceAgent),
)
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
Message: "Agent trying to connect from non-latest build.",
Detail: err.Error(),
})
return
}
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
defer wsNetConn.Close() // Also closes conn.
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(wsNetConn, config)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{
ChannelID: workspaceAgent.ID.String(),
Pubsub: api.Pubsub,
Logger: api.Logger.Named("peerbroker-proxy-listen"),
})
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
defer closer.Close()
firstConnectedAt := workspaceAgent.FirstConnectedAt
if !firstConnectedAt.Valid {
firstConnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
}
lastConnectedAt := sql.NullTime{
Time: database.Now(),
Valid: true,
}
disconnectedAt := workspaceAgent.DisconnectedAt
updateConnectionTimes := func() error {
err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
ID: workspaceAgent.ID,
FirstConnectedAt: firstConnectedAt,
LastConnectedAt: lastConnectedAt,
DisconnectedAt: disconnectedAt,
UpdatedAt: database.Now(),
})
if err != nil {
return err
}
return nil
}
defer func() {
disconnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
_ = updateConnectionTimes()
}()
err = updateConnectionTimes()
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
// end span so we don't get long-lived trace data
tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx))
api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
defer ticker.Stop()
for {
select {
case <-session.CloseChan():
return
case <-ticker.C:
lastConnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
err = updateConnectionTimes()
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
err = ensureLatestBuild()
if err != nil {
// Disconnect agents that are no longer valid.
_ = conn.Close(websocket.StatusGoingAway, "")
return
}
}
}
}
func (api *API) workspaceAgentICEServers(rw http.ResponseWriter, _ *http.Request) {
httpapi.Write(rw, http.StatusOK, api.ICEServers)
}
// userWorkspaceAgentTurn is a user connecting to a remote workspace agent
// through turn.
func (api *API) userWorkspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
workspace := httpmw.WorkspaceParam(r)
if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) {
httpapi.ResourceNotFound(rw)
return
}
// Passed authorization
api.workspaceAgentTurn(rw, r)
}
// workspaceAgentTurn proxies a WebSocket connection to the TURN server.
func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
localAddress, _ := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr)
remoteAddress := &net.TCPAddr{
IP: net.ParseIP(r.RemoteAddr),
}
// r.RemoteAddr contains both the host and the port by default.
host, port, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid remote address.",
Detail: err.Error(),
})
return
}
remoteAddress.IP = net.ParseIP(host)
remoteAddress.Port, err = strconv.Atoi(port)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Port for remote address %q must be an integer.", r.RemoteAddr),
Detail: err.Error(),
})
return
}
wsConn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
ctx, wsNetConn := websocketNetConn(r.Context(), wsConn, websocket.MessageBinary)
defer wsNetConn.Close() // Also closes conn.
// end span so we don't get long-lived trace data
tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx))
api.Logger.Debug(ctx, "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
select {
case <-api.TURNServer.Accept(wsNetConn, remoteAddress, localAddress).Closed():
case <-ctx.Done():
}
api.Logger.Debug(ctx, "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
}
// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
// This is used for the web terminal.
func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
@ -492,75 +200,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(ptNetConn, wsNetConn)
}
// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
// r.Context() for cancellation if its use is safe or r.Hijack() has
// not been performed.
func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (agent.Conn, error) {
client, server := provisionersdk.TransportPipe()
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
_ = peerbroker.ProxyListen(ctx, server, peerbroker.ProxyOptions{
ChannelID: agentID.String(),
Logger: api.Logger.Named("peerbroker-proxy-dial"),
Pubsub: api.Pubsub,
})
_ = client.Close()
_ = server.Close()
}()
peerClient := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
stream, err := peerClient.NegotiateConnection(ctx)
if err != nil {
cancelFunc()
return nil, xerrors.Errorf("negotiate: %w", err)
}
options := &peer.ConnOptions{
Logger: api.Logger.Named("agent-dialer"),
}
options.SettingEngine.SetSrflxAcceptanceMinWait(0)
options.SettingEngine.SetRelayAcceptanceMinWait(0)
// Use the ProxyDialer for the TURN server.
// This is required for connections where P2P is not enabled.
options.SettingEngine.SetICEProxyDialer(turnconn.ProxyDialer(func() (c net.Conn, err error) {
clientPipe, serverPipe := net.Pipe()
go func() {
<-ctx.Done()
_ = clientPipe.Close()
_ = serverPipe.Close()
}()
localAddress, _ := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr)
remoteAddress := &net.TCPAddr{
IP: net.ParseIP(r.RemoteAddr),
}
// r.RemoteAddr contains both the host and the port by default.
host, port, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return nil, xerrors.Errorf("split remote address: %w", err)
}
remoteAddress.IP = net.ParseIP(host)
remoteAddress.Port, err = strconv.Atoi(port)
if err != nil {
return nil, xerrors.Errorf("convert remote port: %w", err)
}
api.TURNServer.Accept(clientPipe, remoteAddress, localAddress)
return serverPipe, nil
}))
peerConn, err := peerbroker.Dial(stream, append(api.ICEServers, turnconn.Proxy), options)
if err != nil {
cancelFunc()
return nil, xerrors.Errorf("dial: %w", err)
}
go func() {
<-peerConn.Closed()
cancelFunc()
}()
return &agent.WebRTCConn{
Negotiator: peerClient,
Conn: peerConn,
}, nil
}
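// dialWorkspaceAgentTailnet dials an agent over tailnet, serving the
// in-process coordinator across a net.Pipe.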
func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (agent.Conn, error) {
func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
clientConn, serverConn := net.Pipe()
go func() {
<-r.Context().Done()
@ -587,7 +227,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (a
_ = conn.Close()
}
}()
return &agent.TailnetConn{
return &agent.Conn{
Conn: conn,
}, nil
}
@ -609,6 +249,48 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgent(r)
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Internal error fetching workspace build job.",
Detail: err.Error(),
})
return
}
// Ensure the resource is still valid!
// We only accept agents for resources on the latest build.
ensureLatestBuild := func() error {
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID)
if err != nil {
return err
}
if build.ID != latestBuild.ID {
return xerrors.New("build is outdated")
}
return nil
}
err = ensureLatestBuild()
if err != nil {
api.Logger.Debug(r.Context(), "agent tried to connect from non-latest build",
slog.F("resource", resource),
slog.F("agent", workspaceAgent),
)
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
Message: "Agent trying to connect from non-latest build.",
Detail: err.Error(),
})
return
}
conn, err := websocket.Accept(rw, r, nil)
if err != nil {
@ -618,12 +300,88 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
})
return
}
defer conn.Close(websocket.StatusNormalClosure, "")
err = api.TailnetCoordinator.ServeAgent(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), workspaceAgent.ID)
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
defer wsNetConn.Close()
firstConnectedAt := workspaceAgent.FirstConnectedAt
if !firstConnectedAt.Valid {
firstConnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
}
lastConnectedAt := sql.NullTime{
Time: database.Now(),
Valid: true,
}
disconnectedAt := workspaceAgent.DisconnectedAt
updateConnectionTimes := func() error {
err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
ID: workspaceAgent.ID,
FirstConnectedAt: firstConnectedAt,
LastConnectedAt: lastConnectedAt,
DisconnectedAt: disconnectedAt,
UpdatedAt: database.Now(),
})
if err != nil {
return err
}
return nil
}
defer func() {
disconnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
_ = updateConnectionTimes()
}()
err = updateConnectionTimes()
if err != nil {
_ = conn.Close(websocket.StatusInternalError, err.Error())
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
// end span so we don't get long-lived trace data
tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx))
api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
defer conn.Close(websocket.StatusNormalClosure, "")
closeChan := make(chan struct{})
go func() {
defer close(closeChan)
err := api.TailnetCoordinator.ServeAgent(wsNetConn, workspaceAgent.ID)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, err.Error())
return
}
}()
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
defer ticker.Stop()
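// Each tick doubles as a heartbeat: refresh LastConnectedAt and re-verify
// that the agent still belongs to the latest build.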
for {
select {
case <-closeChan:
return
case <-ticker.C:
}
lastConnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
err = updateConnectionTimes()
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
err := ensureLatestBuild()
if err != nil {
// Disconnect agents that are no longer valid.
_ = conn.Close(websocket.StatusGoingAway, "")
return
}
}
}
// workspaceAgentClientCoordinate accepts a WebSocket that reads node network updates.

View File

@ -10,7 +10,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
@ -18,7 +17,6 @@ import (
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peer"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/testutil"
@ -112,7 +110,6 @@ func TestWorkspaceAgentListen(t *testing.T) {
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer func() {
@ -123,13 +120,15 @@ func TestWorkspaceAgentListen(t *testing.T) {
defer cancel()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
require.NoError(t, err)
defer func() {
_ = conn.Close()
}()
_, err = conn.Ping()
require.NoError(t, err)
require.Eventually(t, func() bool {
_, err := conn.Ping()
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
})
t.Run("FailNonLatestBuild", func(t *testing.T) {
@ -202,75 +201,12 @@ func TestWorkspaceAgentListen(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
_, err = agentClient.ListenWorkspaceAgent(ctx, slogtest.Make(t, nil))
_, err = agentClient.ListenWorkspaceAgentTailnet(ctx)
require.Error(t, err)
require.ErrorContains(t, err, "build is outdated")
})
}
func TestWorkspaceAgentTURN(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
})
user := coderdtest.CreateFirstUser(t, client)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: authToken,
},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer func() {
_ = agentCloser.Close()
}()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
opts := &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("client"),
}
// Force a TURN connection!
opts.SettingEngine.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4})
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, opts)
require.NoError(t, err)
defer func() {
_ = conn.Close()
}()
_, err = conn.Ping()
require.NoError(t, err)
}
func TestWorkspaceAgentTailnet(t *testing.T) {
t.Parallel()
client, daemonCloser := coderdtest.NewWithProvisionerCloser(t, nil)
@ -306,7 +242,6 @@ func TestWorkspaceAgentTailnet(t *testing.T) {
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
@ -373,7 +308,6 @@ func TestWorkspaceAgentPTY(t *testing.T) {
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer func() {

View File

@ -103,7 +103,6 @@ func setupProxyTest(t *testing.T) (*codersdk.Client, uuid.UUID, codersdk.Workspa
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
Logger: slogtest.Make(t, nil).Named("agent"),
})
t.Cleanup(func() {

View File

@ -32,11 +32,11 @@ func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
}
// Dialer creates a new agent connection by ID.
type Dialer func(r *http.Request, id uuid.UUID) (agent.Conn, error)
type Dialer func(r *http.Request, id uuid.UUID) (*agent.Conn, error)
// Conn wraps an agent connection with a reusable HTTP transport.
type Conn struct {
agent.Conn
*agent.Conn
locks atomic.Uint64
timeoutMutex sync.Mutex

View File

@ -35,7 +35,7 @@ func TestCache(t *testing.T) {
t.Parallel()
t.Run("Same", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, 0)
defer func() {
@ -50,7 +50,7 @@ func TestCache(t *testing.T) {
t.Run("Expire", func(t *testing.T) {
t.Parallel()
called := atomic.NewInt32(0)
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
called.Add(1)
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
@ -69,7 +69,7 @@ func TestCache(t *testing.T) {
})
t.Run("NoExpireWhenLocked", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
defer func() {
@ -102,7 +102,7 @@ func TestCache(t *testing.T) {
}()
go server.Serve(random)
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
defer func() {
@ -139,7 +139,7 @@ func TestCache(t *testing.T) {
})
}
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) agent.Conn {
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) *agent.Conn {
metadata.DERPMap = tailnettest.RunDERPAndSTUN(t)
coordinator := tailnet.NewCoordinator()
@ -180,7 +180,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
return conn.UpdateNodes(node)
})
conn.SetNodeCallback(sendNode)
return &agent.TailnetConn{
return &agent.Conn{
Conn: conn,
}
}

View File

@ -135,6 +135,7 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText))
go func() {
defer close(logs)
defer conn.Close(websocket.StatusGoingAway, "")
var log ProvisionerJobLog
for {
err = decoder.Decode(&log)

View File

@ -14,9 +14,6 @@ import (
"cloud.google.com/go/compute/metadata"
"github.com/google/uuid"
"github.com/hashicorp/yamux"
"github.com/pion/webrtc/v3"
"golang.org/x/net/proxy"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
@ -25,11 +22,6 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/tailnet"
"github.com/coder/retry"
)
@ -206,69 +198,6 @@ func (c *Client) WorkspaceAgentMetadata(ctx context.Context) (agent.Metadata, er
return agentMetadata, json.NewDecoder(res.Body).Decode(&agentMetadata)
}
// ListenWorkspaceAgent connects as a workspace agent identifying with the session token.
// On each inbound connection request, connection info is fetched.
func (c *Client) ListenWorkspaceAgent(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) {
serverURL, err := c.URL.Parse("/api/v2/workspaceagents/me/listen")
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
jar, err := cookiejar.New(nil)
if err != nil {
return nil, xerrors.Errorf("create cookie jar: %w", err)
}
jar.SetCookies(serverURL, []*http.Cookie{{
Name: SessionTokenKey,
Value: c.SessionToken,
}})
httpClient := &http.Client{
Jar: jar,
}
conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{
HTTPClient: httpClient,
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
if res == nil {
return nil, err
}
return nil, readBodyAsError(res)
}
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config)
if err != nil {
return nil, xerrors.Errorf("multiplex client: %w", err)
}
return peerbroker.Listen(session, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) {
// This can be cached if it adds too much latency.
res, err := c.Request(ctx, http.MethodGet, "/api/v2/workspaceagents/me/iceservers", nil)
if err != nil {
return nil, nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, nil, readBodyAsError(res)
}
var iceServers []webrtc.ICEServer
err = json.NewDecoder(res.Body).Decode(&iceServers)
if err != nil {
return nil, nil, err
}
options := webrtc.SettingEngine{}
options.SetSrflxAcceptanceMinWait(0)
options.SetRelayAcceptanceMinWait(0)
options.SetICEProxyDialer(c.turnProxyDialer(ctx, httpClient, "/api/v2/workspaceagents/me/turn"))
iceServers = append(iceServers, turnconn.Proxy)
return iceServers, &peer.ConnOptions{
SettingEngine: options,
Logger: logger,
}, nil
})
}
func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, error) {
coordinateURL, err := c.URL.Parse("/api/v2/workspaceagents/me/coordinate")
if err != nil {
@ -286,17 +215,20 @@ func (c *Client) ListenWorkspaceAgentTailnet(ctx context.Context) (net.Conn, err
Jar: jar,
}
// nolint:bodyclose
conn, _, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
conn, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
HTTPClient: httpClient,
})
if err != nil {
return nil, err
if res == nil {
return nil, err
}
return nil, readBodyAsError(res)
}
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
}
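// DialWorkspaceAgentTailnet dials a workspace agent over DERP, serving the
// coordinator exchange on a background WebSocket until the returned
// connection is closed.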
func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (agent.Conn, error) {
func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logger, agentID uuid.UUID) (*agent.Conn, error) {
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/connection", agentID), nil)
if err != nil {
return nil, err
@ -349,10 +281,12 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
CompressionMode: websocket.CompressionDisabled,
})
if errors.Is(err, context.Canceled) {
_ = ws.Close(websocket.StatusAbnormalClosure, "")
return
}
if err != nil {
logger.Debug(ctx, "failed to dial", slog.Error(err))
_ = ws.Close(websocket.StatusAbnormalClosure, "")
continue
}
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error {
@ -362,15 +296,18 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
logger.Debug(ctx, "serving coordinator")
err = <-errChan
if errors.Is(err, context.Canceled) {
_ = ws.Close(websocket.StatusAbnormalClosure, "")
return
}
if err != nil {
logger.Debug(ctx, "error serving coordinator", slog.Error(err))
_ = ws.Close(websocket.StatusAbnormalClosure, "")
continue
}
_ = ws.Close(websocket.StatusAbnormalClosure, "")
}
}()
return &agent.TailnetConn{
return &agent.Conn{
Conn: conn,
CloseFunc: func() {
cancelFunc()
@ -379,78 +316,6 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
}, nil
}
// DialWorkspaceAgent creates a connection to the specified resource.
func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *peer.ConnOptions) (agent.Conn, error) {
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/dial", agentID.String()))
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
jar, err := cookiejar.New(nil)
if err != nil {
return nil, xerrors.Errorf("create cookie jar: %w", err)
}
jar.SetCookies(serverURL, []*http.Cookie{{
Name: SessionTokenKey,
Value: c.SessionToken,
}})
httpClient := &http.Client{
Jar: jar,
}
conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{
HTTPClient: httpClient,
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
if res == nil {
return nil, err
}
return nil, readBodyAsError(res)
}
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config)
if err != nil {
return nil, xerrors.Errorf("multiplex client: %w", err)
}
client := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session))
stream, err := client.NegotiateConnection(ctx)
if err != nil {
return nil, xerrors.Errorf("negotiate connection: %w", err)
}
res, err = c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/iceservers", agentID.String()), nil)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, readBodyAsError(res)
}
var iceServers []webrtc.ICEServer
err = json.NewDecoder(res.Body).Decode(&iceServers)
if err != nil {
return nil, err
}
if options == nil {
options = &peer.ConnOptions{}
}
options.SettingEngine.SetSrflxAcceptanceMinWait(0)
options.SettingEngine.SetRelayAcceptanceMinWait(0)
options.SettingEngine.SetICEProxyDialer(c.turnProxyDialer(ctx, httpClient, fmt.Sprintf("/api/v2/workspaceagents/%s/turn", agentID.String())))
iceServers = append(iceServers, turnconn.Proxy)
peerConn, err := peerbroker.Dial(stream, iceServers, options)
if err != nil {
return nil, xerrors.Errorf("dial peer: %w", err)
}
return &agent.WebRTCConn{
Negotiator: client,
Conn: peerConn,
}, nil
}
// WorkspaceAgent returns an agent by ID.
func (c *Client) WorkspaceAgent(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error) {
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s", id), nil)
@ -509,27 +374,6 @@ func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, rec
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
}
func (c *Client) turnProxyDialer(ctx context.Context, httpClient *http.Client, path string) proxy.Dialer {
return turnconn.ProxyDialer(func() (net.Conn, error) {
turnURL, err := c.URL.Parse(path)
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
conn, res, err := websocket.Dial(ctx, turnURL.String(), &websocket.DialOptions{
HTTPClient: httpClient,
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
if res == nil {
return nil, err
}
return nil, readBodyAsError(res)
}
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
})
}
// AgentReportStats begins a stat streaming connection with the Coder server.
// It is resilient to network failures and intermittent coderd issues.
func (c *Client) AgentReportStats(
@ -584,6 +428,7 @@ func (c *Client) AgentReportStats(
var req AgentStatsReportRequest
err := wsjson.Read(ctx, conn, &req)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, "")
return err
}
@ -597,6 +442,7 @@ func (c *Client) AgentReportStats(
err = wsjson.Write(ctx, conn, resp)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, "")
return err
}
}

go.mod
View File

@ -39,7 +39,7 @@ replace github.com/fatedier/kcp-go => github.com/coder/kcp-go v2.0.4-0.202204091
// https://github.com/pion/udp/pull/73
replace github.com/pion/udp => github.com/mafredri/udp v0.1.2-0.20220805105907-b2872e92e98d
// https://github.com/hashicorp/hc-install/pull/68
replace github.com/hashicorp/hc-install => github.com/mafredri/hc-install v0.4.1-0.20220727132613-e91868e28445
// https://github.com/tcnksm/go-httpstat/pull/29
@ -119,12 +119,7 @@ require (
github.com/nhatthm/otelsql v0.4.0
github.com/open-policy-agent/opa v0.41.0
github.com/ory/dockertest/v3 v3.9.1
github.com/pion/datachannel v1.5.2
github.com/pion/logging v0.2.2
github.com/pion/transport v0.13.1
github.com/pion/turn/v2 v2.0.8
github.com/pion/udp v0.1.1
github.com/pion/webrtc/v3 v3.1.43
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e
github.com/pkg/sftp v1.13.5
@ -150,7 +145,6 @@ require (
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167
golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b
golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10
@ -171,6 +165,11 @@ require (
tailscale.com v1.30.0
)
require (
github.com/pion/transport v0.13.1 // indirect
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b // indirect
)
require (
filippo.io/edwards25519 v1.0.0-rc.1 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
@ -258,17 +257,6 @@ require (
github.com/opencontainers/image-spec v1.0.3-0.20220114050600-8b9d41f48198 // indirect
github.com/opencontainers/runc v1.1.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.2 // indirect
github.com/pion/dtls/v2 v2.1.5 // indirect
github.com/pion/ice/v2 v2.2.6 // indirect
github.com/pion/interceptor v0.1.11 // indirect
github.com/pion/mdns v0.0.5 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/rtcp v1.2.9 // indirect
github.com/pion/rtp v1.7.13 // indirect
github.com/pion/sctp v1.8.2 // indirect
github.com/pion/sdp/v3 v3.0.5 // indirect
github.com/pion/srtp/v2 v2.0.10 // indirect
github.com/pion/stun v0.3.5 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.2.0 // indirect

go.sum
View File

@ -1451,7 +1451,6 @@ github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108
github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0=
github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY=
github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c=
github.com/onsi/gomega v0.0.0-20151007035656-2152b45fa28a/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA=
github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA=
@ -1526,43 +1525,9 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pion/datachannel v1.5.2 h1:piB93s8LGmbECrpO84DnkIVWasRMk3IimbcXkTQLE6E=
github.com/pion/datachannel v1.5.2/go.mod h1:FTGQWaHrdCwIJ1rw6xBIfZVkslikjShim5yr05XFuCQ=
github.com/pion/dtls/v2 v2.1.3/go.mod h1:o6+WvyLDAlXF7YiPB/RlskRoeK+/JtuaZa5emwQcWus=
github.com/pion/dtls/v2 v2.1.5 h1:jlh2vtIyUBShchoTDqpCCqiYCyRFJ/lvf/gQ8TALs+c=
github.com/pion/dtls/v2 v2.1.5/go.mod h1:BqCE7xPZbPSubGasRoDFJeTsyJtdD1FanJYL0JGheqY=
github.com/pion/ice/v2 v2.2.6 h1:R/vaLlI1J2gCx141L5PEwtuGAGcyS6e7E0hDeJFq5Ig=
github.com/pion/ice/v2 v2.2.6/go.mod h1:SWuHiOGP17lGromHTFadUe1EuPgFh/oCU6FCMZHooVE=
github.com/pion/interceptor v0.1.11 h1:00U6OlqxA3FFB50HSg25J/8cWi7P6FbSzw4eFn24Bvs=
github.com/pion/interceptor v0.1.11/go.mod h1:tbtKjZY14awXd7Bq0mmWvgtHB5MDaRN7HV3OZ/uy7s8=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/mdns v0.0.5 h1:Q2oj/JB3NqfzY9xGZ1fPzZzK7sDSD8rZPOvcIQ10BCw=
github.com/pion/mdns v0.0.5/go.mod h1:UgssrvdD3mxpi8tMxAXbsppL3vJ4Jipw1mTCW+al01g=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/rtcp v1.2.9 h1:1ujStwg++IOLIEoOiIQ2s+qBuJ1VN81KW+9pMPsif+U=
github.com/pion/rtcp v1.2.9/go.mod h1:qVPhiCzAm4D/rxb6XzKeyZiQK69yJpbUDJSF7TgrqNo=
github.com/pion/rtp v1.7.13 h1:qcHwlmtiI50t1XivvoawdCGTP4Uiypzfrsap+bijcoA=
github.com/pion/rtp v1.7.13/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko=
github.com/pion/sctp v1.8.0/go.mod h1:xFe9cLMZ5Vj6eOzpyiKjT9SwGM4KpK/8Jbw5//jc+0s=
github.com/pion/sctp v1.8.2 h1:yBBCIrUMJ4yFICL3RIvR4eh/H2BTTvlligmSTy+3kiA=
github.com/pion/sctp v1.8.2/go.mod h1:xFe9cLMZ5Vj6eOzpyiKjT9SwGM4KpK/8Jbw5//jc+0s=
github.com/pion/sdp/v3 v3.0.5 h1:ouvI7IgGl+V4CrqskVtr3AaTrPvPisEOxwgpdktctkU=
github.com/pion/sdp/v3 v3.0.5/go.mod h1:iiFWFpQO8Fy3S5ldclBkpXqmWy02ns78NOKoLLL0YQw=
github.com/pion/srtp/v2 v2.0.10 h1:b8ZvEuI+mrL8hbr/f1YiJFB34UMrOac3R3N1yq2UN0w=
github.com/pion/srtp/v2 v2.0.10/go.mod h1:XEeSWaK9PfuMs7zxXyiN252AHPbH12NX5q/CFDWtUuA=
github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg=
github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA=
github.com/pion/transport v0.12.2/go.mod h1:N3+vZQD9HlDP5GWkZ85LohxNsDcNgofQmyL6ojX5d8Q=
github.com/pion/transport v0.12.3/go.mod h1:OViWW9SP2peE/HbwBvARicmAVnesphkNkCVZIWJ6q9A=
github.com/pion/transport v0.13.0/go.mod h1:yxm9uXpK9bpBBWkITk13cLo1y5/ur5VQpG22ny6EP7g=
github.com/pion/transport v0.13.1 h1:/UH5yLeQtwm2VZIPjxwnNFxjS4DFhyLfS4GlfuKUzfA=
github.com/pion/transport v0.13.1/go.mod h1:EBxbqzyv+ZrmDb82XswEE0BjfQFtuw1Nu6sjnjWCsGg=
github.com/pion/turn/v2 v2.0.8 h1:KEstL92OUN3k5k8qxsXHpr7WWfrdp7iJZHx99ud8muw=
github.com/pion/turn/v2 v2.0.8/go.mod h1:+y7xl719J8bAEVpSXBXvTxStjJv3hbz9YFflvkpcGPw=
github.com/pion/webrtc/v3 v3.1.43 h1:YT3ZTO94UT4kSBvZnRAH82+0jJPUruiKr9CEstdlQzk=
github.com/pion/webrtc/v3 v3.1.43/go.mod h1:G/J8k0+grVsjC/rjCZ24AKoCCxcFFODgh7zThNZGs0M=
github.com/pkg/browser v0.0.0-20210706143420-7d21f8c997e2/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU=
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
@ -2042,9 +2007,6 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220516162934-403b01795ae8/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167 h1:O8uGbHCqlTp2P6QJSLmCojM4mN6UemYv8K+dCnmHmu0=
golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -2155,7 +2117,6 @@ golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201201195509-5d6afe98e0b7/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@ -2178,7 +2139,6 @@ golang.org/x/net v0.0.0-20210825183410-e898025ed96a/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211201190559-0a0e4e1bb54c/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220107192237-5cfca573fb4d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
@ -2186,13 +2146,11 @@ golang.org/x/net v0.0.0-20220111093109-d55c255bac03/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220325170049-de3da57026de/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220401154927-543a649e0bdd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220412020605-290c469a71a5/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220531201128-c960675eff93/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220607020251-c690dde0001d/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b h1:ZmngSVLe/wycRns9MKikG9OWIEjGcGAkacif7oYQaUY=
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/oauth2 v0.0.0-20180227000427-d7d64896b5ff/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@ -2389,7 +2347,6 @@ golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220624220833-87e55d714810/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@ -1,317 +0,0 @@
package peer
import (
"bufio"
"context"
"io"
"net"
"sync"
"github.com/pion/datachannel"
"github.com/pion/webrtc/v3"
"golang.org/x/xerrors"
"cdr.dev/slog"
)
const (
bufferedAmountLowThreshold uint64 = 512 * 1024 // 512 KB
maxBufferedAmount uint64 = 1024 * 1024 // 1 MB
// For some reason, messages larger than this just don't work...
// This shouldn't be a huge deal for real-world usage.
// See: https://github.com/pion/datachannel/issues/59
maxMessageLength = 64 * 1024 // 64 KB
)
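// Putting these constants together (see init and Write below): a Write
// blocks once BufferedAmount()+len(p) would reach maxBufferedAmount, and
// unblocks when the buffer drains below bufferedAmountLowThreshold.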
// newChannel creates a new channel and initializes it.
// The initialization overrides listener handlers, and detaches
// the channel on open. The datachannel should not be manually
// mutated after being passed to this function.
func newChannel(conn *Conn, dc *webrtc.DataChannel, opts *ChannelOptions) *Channel {
channel := &Channel{
opts: opts,
conn: conn,
dc: dc,
opened: make(chan struct{}),
closed: make(chan struct{}),
sendMore: make(chan struct{}, 1),
}
channel.init()
return channel
}
type ChannelOptions struct {
// ID is a channel ID that should be used when `Negotiated`
// is true.
ID uint16
// Negotiated indicates whether the data channel will already
// be active on the other end. Defaults to false.
Negotiated bool
// Arbitrary string that can be parsed on `Accept`.
Protocol string
// Unordered determines whether the channel acts like
// a UDP connection. Defaults to false.
Unordered bool
// Whether the channel will be left open on disconnect or not.
// If true, data will be buffered on either end to be sent
// once reconnected. Defaults to false.
OpenOnDisconnect bool
}
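// A minimal sketch of combining these options, assuming a CreateChannel
// helper on Conn (the exported constructor is not part of this diff): a
// pre-negotiated channel must use the same ID and Negotiated on both peers.
//
//	opts := &ChannelOptions{ID: 4, Negotiated: true, Protocol: "ping"}
//	channel, err := conn.CreateChannel(ctx, "ping", opts)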
// Channel represents a WebRTC DataChannel.
//
// This struct wraps webrtc.DataChannel to add concurrent-safe usage,
// data buffering, and standardized errors for connection state.
//
// It modifies the default behavior of a DataChannel by closing on
// WebRTC PeerConnection failure. This is done to emulate TCP connections.
// This option can be changed in the options when creating a Channel.
type Channel struct {
opts *ChannelOptions
conn *Conn
dc *webrtc.DataChannel
// This field can be nil. It becomes set after the DataChannel
// has been opened and is detached.
rwc datachannel.ReadWriteCloser
reader io.Reader
closed chan struct{}
closeMutex sync.Mutex
closeError error
opened chan struct{}
// sendMore is used to block Write operations on a full buffer.
// It's signaled when the buffer can accept more data.
sendMore chan struct{}
writeMutex sync.Mutex
}
// init attaches listeners to the DataChannel to detect opening,
// closing, and when the channel is ready to transmit data.
//
// This should only be called once on creation.
func (c *Channel) init() {
// WebRTC connections maintain an internal buffer that can fill when:
// 1. Data is being sent faster than it can flush.
// 2. The connection is disconnected, but data is still being sent.
//
// This applies a maximum in-memory buffer for data, and will cause
// write operations to block once that maximum is reached.
c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
c.dc.OnBufferedAmountLow(func() {
// Grab the lock to protect the sendMore channel from being
// closed in between the isClosed check and the send.
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
if c.isClosed() {
return
}
select {
case <-c.closed:
case c.sendMore <- struct{}{}:
default:
}
})
c.dc.OnClose(func() {
c.conn.logger().Debug(context.Background(), "datachannel closing from OnClose", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
_ = c.closeWithError(ErrClosed)
})
c.dc.OnOpen(func() {
c.closeMutex.Lock()
c.conn.logger().Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
var err error
c.rwc, err = c.dc.Detach()
if err != nil {
c.closeMutex.Unlock()
_ = c.closeWithError(xerrors.Errorf("detach: %w", err))
return
}
c.closeMutex.Unlock()
// pion/webrtc will return an io.ErrShortBuffer when a read
// is triggered with a buffer size less than the chunks written.
//
// This makes sense when considering UDP connections, because
// buffering of data that has no transmit guarantees is likely
// to cause unexpected behavior.
//
// When ordered, this adds a bufio.Reader. This ensures additional
// data on TCP-like connections can be read in parts, while still
// being buffered.
if c.opts.Unordered {
c.reader = c.rwc
} else {
// This must be the max message length, otherwise a short
// buffer error can occur.
c.reader = bufio.NewReaderSize(c.rwc, maxMessageLength)
}
close(c.opened)
})
c.conn.dcDisconnectListeners.Add(1)
c.conn.dcFailedListeners.Add(1)
c.conn.dcClosedWaitGroup.Add(1)
go func() {
var err error
// A DataChannel can disconnect multiple times, so this needs to loop.
for {
select {
case <-c.conn.closedRTC:
// If this channel was closed, there's no need to close again.
err = c.conn.closeError
case <-c.conn.Closed():
// If the RTC connection closed with an error, this channel
// should end with the same one.
err = c.conn.closeError
case <-c.conn.dcDisconnectChannel:
// If the RTC connection is disconnected, we need to check if
// the DataChannel is supposed to end on disconnect.
if c.opts.OpenOnDisconnect {
continue
}
err = xerrors.Errorf("rtc disconnected. closing: %w", ErrClosed)
case <-c.conn.dcFailedChannel:
// If the RTC connection failed, close the Channel.
err = ErrFailed
}
if err != nil {
break
}
}
_ = c.closeWithError(err)
}()
}
// Read blocks until data is received.
//
// This will block until the underlying DataChannel has been opened.
func (c *Channel) Read(bytes []byte) (int, error) {
err := c.waitOpened()
if err != nil {
return 0, err
}
bytesRead, err := c.reader.Read(bytes)
if err != nil {
if c.isClosed() {
return 0, c.closeError
}
// An EOF always occurs when the connection is closed.
// Alternative close errors will occur first if an unexpected
// close has occurred.
if xerrors.Is(err, io.EOF) {
err = c.closeWithError(ErrClosed)
}
}
return bytesRead, err
}
// Write sends data to the underlying DataChannel.
//
// This function will block if too much data is being sent.
// Data will buffer if the connection is temporarily disconnected,
// and will be flushed upon reconnection.
//
// If the Channel is set up to close on disconnect, any buffered
// data will be lost.
func (c *Channel) Write(bytes []byte) (n int, err error) {
if len(bytes) > maxMessageLength {
return 0, xerrors.Errorf("outbound packet larger than maximum message size: %d", maxMessageLength)
}
c.writeMutex.Lock()
defer c.writeMutex.Unlock()
err = c.waitOpened()
if err != nil {
return 0, err
}
if c.dc.BufferedAmount()+uint64(len(bytes)) >= maxBufferedAmount {
<-c.sendMore
}
return c.rwc.Write(bytes)
}
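// A sketch (illustrative only, not part of this file) of chunking a payload
// that exceeds maxMessageLength before handing it to Write:
//
//	for len(data) > 0 {
//		n := len(data)
//		if n > maxMessageLength {
//			n = maxMessageLength
//		}
//		if _, err := channel.Write(data[:n]); err != nil {
//			return err // assumes an enclosing func returning error
//		}
//		data = data[n:]
//	}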
// Close gracefully closes the DataChannel.
func (c *Channel) Close() error {
return c.closeWithError(nil)
}
// Label returns the label of the underlying DataChannel.
func (c *Channel) Label() string {
return c.dc.Label()
}
// Protocol returns the protocol of the underlying DataChannel.
func (c *Channel) Protocol() string {
return c.dc.Protocol()
}
// NetConn wraps the DataChannel in a struct fulfilling net.Conn.
// Read, Write, and Close operations can still be used on the *Channel struct.
func (c *Channel) NetConn() net.Conn {
return &fakeNetConn{
c: c,
addr: &peerAddr{},
}
}
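// For example (illustrative only), the wrapped conn can feed anything that
// expects a stream, such as the yamux sessions used elsewhere in this change:
//
//	session, err := yamux.Client(channel.NetConn(), yamux.DefaultConfig())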
// closeWithError closes the Channel with the error provided.
// If a graceful close occurs, the error will be nil.
func (c *Channel) closeWithError(err error) error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
if c.isClosed() {
return c.closeError
}
c.conn.logger().Debug(context.Background(), "datachannel closing with error", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()), slog.Error(err))
if err == nil {
c.closeError = ErrClosed
} else {
c.closeError = err
}
if c.rwc != nil {
_ = c.rwc.Close()
}
_ = c.dc.Close()
close(c.closed)
close(c.sendMore)
c.conn.dcDisconnectListeners.Sub(1)
c.conn.dcFailedListeners.Sub(1)
c.conn.dcClosedWaitGroup.Done()
return err
}
func (c *Channel) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
func (c *Channel) waitOpened() error {
select {
case <-c.opened:
// Re-check the closed channel to prioritize closure.
if c.isClosed() {
return c.closeError
}
return nil
case <-c.closed:
return c.closeError
}
}

View File

@ -1,616 +0,0 @@
package peer
import (
"bytes"
"context"
"crypto/rand"
"io"
"sync"
"time"
"github.com/pion/logging"
"github.com/pion/webrtc/v3"
"go.uber.org/atomic"
"golang.org/x/xerrors"
"cdr.dev/slog"
)
var (
// ErrDisconnected occurs when the connection has disconnected.
// The connection will be attempting to reconnect at this point.
ErrDisconnected = xerrors.New("connection is disconnected")
// ErrFailed occurs when the connection has failed.
// The connection will not retry after this point.
ErrFailed = xerrors.New("connection has failed")
// ErrClosed occurs when the connection was closed. It wraps io.EOF
// to fulfill expected read errors from closed pipes.
ErrClosed = xerrors.Errorf("connection was closed: %w", io.EOF)
// The number of random bytes sent in a ping.
pingDataLength = 64
)
// Client creates a new client connection.
func Client(servers []webrtc.ICEServer, opts *ConnOptions) (*Conn, error) {
return newWithClientOrServer(servers, true, opts)
}
// Server creates a new server connection.
func Server(servers []webrtc.ICEServer, opts *ConnOptions) (*Conn, error) {
return newWithClientOrServer(servers, false, opts)
}
// newWithClientOrServer constructs a new connection with the client option.
// nolint:revive
func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOptions) (*Conn, error) {
if opts == nil {
opts = &ConnOptions{}
}
opts.SettingEngine.DetachDataChannels()
logger := logging.NewDefaultLoggerFactory()
logger.DefaultLogLevel = logging.LogLevelDisabled
opts.SettingEngine.LoggerFactory = logger
api := webrtc.NewAPI(webrtc.WithSettingEngine(opts.SettingEngine))
rtc, err := api.NewPeerConnection(webrtc.Configuration{
ICEServers: servers,
})
if err != nil {
return nil, xerrors.Errorf("create peer connection: %w", err)
}
conn := &Conn{
pingChannelID: 1,
pingEchoChannelID: 2,
rtc: rtc,
offerer: client,
closed: make(chan struct{}),
closedRTC: make(chan struct{}),
closedICE: make(chan struct{}),
dcOpenChannel: make(chan *webrtc.DataChannel, 8),
dcDisconnectChannel: make(chan struct{}),
dcFailedChannel: make(chan struct{}),
localCandidateChannel: make(chan webrtc.ICECandidateInit),
localSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1),
negotiated: make(chan struct{}),
remoteSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1),
settingEngine: opts.SettingEngine,
}
conn.loggerValue.Store(opts.Logger)
if client {
// If we're the client, we want to flip the echo and
// ping channel IDs so pings don't accidentally hit each other.
conn.pingChannelID, conn.pingEchoChannelID = conn.pingEchoChannelID, conn.pingChannelID
}
err = conn.init()
if err != nil {
return nil, xerrors.Errorf("init: %w", err)
}
return conn, nil
}
type ConnOptions struct {
Logger slog.Logger
// Enables customization on the underlying WebRTC connection.
SettingEngine webrtc.SettingEngine
}
// Conn represents a WebRTC peer connection.
//
// This struct wraps webrtc.PeerConnection to add bidirectional pings,
// concurrent-safe webrtc.DataChannel, and standardized errors for connection state.
type Conn struct {
rtc *webrtc.PeerConnection
// Determines whether this connection will send the offer or the answer.
offerer bool
closed chan struct{}
closedRTC chan struct{}
closedRTCMutex sync.Mutex
closedICE chan struct{}
closedICEMutex sync.Mutex
closeMutex sync.Mutex
closeError error
dcCreateMutex sync.Mutex
dcOpenChannel chan *webrtc.DataChannel
dcDisconnectChannel chan struct{}
dcDisconnectListeners atomic.Uint32
dcFailedChannel chan struct{}
dcFailedListeners atomic.Uint32
dcClosedWaitGroup sync.WaitGroup
localCandidateChannel chan webrtc.ICECandidateInit
localSessionDescriptionChannel chan webrtc.SessionDescription
remoteSessionDescriptionChannel chan webrtc.SessionDescription
negotiated chan struct{}
loggerValue atomic.Value
settingEngine webrtc.SettingEngine
pingChannelID uint16
pingEchoChannelID uint16
pingEchoChan *Channel
pingEchoOnce sync.Once
pingEchoError error
pingMutex sync.Mutex
pingOnce sync.Once
pingChan *Channel
pingError error
}
func (c *Conn) logger() slog.Logger {
log, valid := c.loggerValue.Load().(slog.Logger)
if !valid {
return slog.Logger{}
}
return log
}
func (c *Conn) init() error {
c.rtc.OnNegotiationNeeded(c.negotiate)
c.rtc.OnICEConnectionStateChange(func(iceConnectionState webrtc.ICEConnectionState) {
c.closedICEMutex.Lock()
defer c.closedICEMutex.Unlock()
select {
case <-c.closedICE:
// Don't log more state changes if we've already closed.
return
default:
c.logger().Debug(context.Background(), "ice connection state updated",
slog.F("state", iceConnectionState))
if iceConnectionState == webrtc.ICEConnectionStateClosed {
// pion/webrtc can update this state multiple times.
// A connection can never become un-closed, so we
// close the channel if it isn't already.
close(c.closedICE)
}
}
})
c.rtc.OnICEGatheringStateChange(func(iceGatherState webrtc.ICEGathererState) {
c.closedICEMutex.Lock()
defer c.closedICEMutex.Unlock()
select {
case <-c.closedICE:
// Don't log more state changes if we've already closed.
return
default:
c.logger().Debug(context.Background(), "ice gathering state updated",
slog.F("state", iceGatherState))
if iceGatherState == webrtc.ICEGathererStateClosed {
// pion/webrtc can update this state multiple times.
// A connection can never become un-closed, so we
// close the channel if it isn't already.
close(c.closedICE)
}
}
})
c.rtc.OnConnectionStateChange(func(peerConnectionState webrtc.PeerConnectionState) {
go func() {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
if c.isClosed() {
return
}
c.logger().Debug(context.Background(), "rtc connection updated",
slog.F("state", peerConnectionState))
}()
switch peerConnectionState {
case webrtc.PeerConnectionStateDisconnected:
for i := 0; i < int(c.dcDisconnectListeners.Load()); i++ {
select {
case c.dcDisconnectChannel <- struct{}{}:
default:
}
}
case webrtc.PeerConnectionStateFailed:
for i := 0; i < int(c.dcFailedListeners.Load()); i++ {
select {
case c.dcFailedChannel <- struct{}{}:
default:
}
}
case webrtc.PeerConnectionStateClosed:
// pion/webrtc can update this state multiple times.
// A connection can never become un-closed, so we
// close the channel if it isn't already.
c.closedRTCMutex.Lock()
defer c.closedRTCMutex.Unlock()
select {
case <-c.closedRTC:
default:
close(c.closedRTC)
}
}
})
// These functions need to check if the conn is closed, because they can be
// called after being closed.
c.rtc.OnSignalingStateChange(func(signalState webrtc.SignalingState) {
c.logger().Debug(context.Background(), "signaling state updated",
slog.F("state", signalState))
})
c.rtc.SCTP().Transport().OnStateChange(func(dtlsTransportState webrtc.DTLSTransportState) {
c.logger().Debug(context.Background(), "dtls transport state updated",
slog.F("state", dtlsTransportState))
})
c.rtc.SCTP().Transport().ICETransport().OnSelectedCandidatePairChange(func(candidatePair *webrtc.ICECandidatePair) {
c.logger().Debug(context.Background(), "selected candidate pair changed",
slog.F("local", candidatePair.Local), slog.F("remote", candidatePair.Remote))
})
c.rtc.OnICECandidate(func(iceCandidate *webrtc.ICECandidate) {
if iceCandidate == nil {
return
}
// Run this in a goroutine so we don't block pion/webrtc
// from continuing.
go func() {
c.logger().Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate))
select {
case <-c.closed:
case c.localCandidateChannel <- iceCandidate.ToJSON():
}
}()
})
c.rtc.OnDataChannel(func(dc *webrtc.DataChannel) {
go func() {
select {
case <-c.closed:
case c.dcOpenChannel <- dc:
}
}()
})
_, err := c.pingChannel()
if err != nil {
return err
}
_, err = c.pingEchoChannel()
if err != nil {
return err
}
return nil
}
// negotiate is triggered when a connection is ready to be established.
// See trickle ICE for the expected exchange: https://webrtchacks.com/trickle-ice/
func (c *Conn) negotiate() {
c.logger().Debug(context.Background(), "negotiating")
// ICE candidates cannot be added until SessionDescriptions have been
// exchanged between peers.
defer func() {
select {
case <-c.negotiated:
default:
close(c.negotiated)
}
}()
if c.offerer {
offer, err := c.rtc.CreateOffer(&webrtc.OfferOptions{})
if err != nil {
_ = c.CloseWithError(xerrors.Errorf("create offer: %w", err))
return
}
// pion/webrtc will panic if Close is called while this
// function is being executed.
c.closeMutex.Lock()
err = c.rtc.SetLocalDescription(offer)
c.closeMutex.Unlock()
if err != nil {
_ = c.CloseWithError(xerrors.Errorf("set local description: %w", err))
return
}
c.logger().Debug(context.Background(), "sending offer", slog.F("offer", offer))
select {
case <-c.closed:
return
case c.localSessionDescriptionChannel <- offer:
}
c.logger().Debug(context.Background(), "sent offer")
}
var sessionDescription webrtc.SessionDescription
c.logger().Debug(context.Background(), "awaiting remote description...")
select {
case <-c.closed:
return
case sessionDescription = <-c.remoteSessionDescriptionChannel:
}
c.logger().Debug(context.Background(), "setting remote description")
err := c.rtc.SetRemoteDescription(sessionDescription)
if err != nil {
_ = c.CloseWithError(xerrors.Errorf("set remote description (closed %v): %w", c.isClosed(), err))
return
}
if !c.offerer {
answer, err := c.rtc.CreateAnswer(&webrtc.AnswerOptions{})
if err != nil {
_ = c.CloseWithError(xerrors.Errorf("create answer: %w", err))
return
}
// pion/webrtc will panic if Close is called while this
// function is being executed.
c.closeMutex.Lock()
err = c.rtc.SetLocalDescription(answer)
c.closeMutex.Unlock()
if err != nil {
_ = c.CloseWithError(xerrors.Errorf("set local description: %w", err))
return
}
c.logger().Debug(context.Background(), "sending answer", slog.F("answer", answer))
select {
case <-c.closed:
return
case c.localSessionDescriptionChannel <- answer:
}
c.logger().Debug(context.Background(), "sent answer")
}
}
// AddRemoteCandidate adds a remote candidate to the RTC connection.
func (c *Conn) AddRemoteCandidate(i webrtc.ICECandidateInit) {
if c.isClosed() {
return
}
// This must occur in a goroutine to allow the SessionDescriptions
// to be exchanged first.
go func() {
select {
case <-c.closed:
case <-c.negotiated:
}
if c.isClosed() {
return
}
c.logger().Debug(context.Background(), "accepting candidate", slog.F("candidate", i.Candidate))
err := c.rtc.AddICECandidate(i)
if err != nil {
if c.rtc.ConnectionState() == webrtc.PeerConnectionStateClosed {
return
}
_ = c.CloseWithError(xerrors.Errorf("accept candidate: %w", err))
}
}()
}
// SetRemoteSessionDescription sets the remote description for the WebRTC connection.
func (c *Conn) SetRemoteSessionDescription(sessionDescription webrtc.SessionDescription) {
select {
case <-c.closed:
case c.remoteSessionDescriptionChannel <- sessionDescription:
}
}
// LocalSessionDescription returns a channel that emits a session description
// when one is required to be exchanged.
func (c *Conn) LocalSessionDescription() <-chan webrtc.SessionDescription {
return c.localSessionDescriptionChannel
}
// LocalCandidate returns a channel that emits when a local candidate
// needs to be exchanged with a remote connection.
func (c *Conn) LocalCandidate() <-chan webrtc.ICECandidateInit {
return c.localCandidateChannel
}
func (c *Conn) pingChannel() (*Channel, error) {
c.pingOnce.Do(func() {
c.pingChan, c.pingError = c.dialChannel(context.Background(), "ping", &ChannelOptions{
ID: c.pingChannelID,
Negotiated: true,
OpenOnDisconnect: true,
})
if c.pingError != nil {
return
}
})
return c.pingChan, c.pingError
}
func (c *Conn) pingEchoChannel() (*Channel, error) {
c.pingEchoOnce.Do(func() {
c.pingEchoChan, c.pingEchoError = c.dialChannel(context.Background(), "echo", &ChannelOptions{
ID: c.pingEchoChannelID,
Negotiated: true,
OpenOnDisconnect: true,
})
if c.pingEchoError != nil {
return
}
go func() {
for {
data := make([]byte, pingDataLength)
bytesRead, err := c.pingEchoChan.Read(data)
if err != nil {
_ = c.CloseWithError(xerrors.Errorf("read ping echo channel: %w", err))
return
}
_, err = c.pingEchoChan.Write(data[:bytesRead])
if err != nil {
_ = c.CloseWithError(xerrors.Errorf("write ping echo channel: %w", err))
return
}
}
}()
})
return c.pingEchoChan, c.pingEchoError
}
// SetConfiguration applies options to the WebRTC connection.
// Generally used for updating transport options, like ICE servers.
func (c *Conn) SetConfiguration(configuration webrtc.Configuration) error {
return c.rtc.SetConfiguration(configuration)
}
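// An illustrative sketch (not from the original source): rotating ICE
// servers on a live connection, e.g. after refreshing TURN credentials.
// The URL and credentials below are placeholders.
//
//	err := conn.SetConfiguration(webrtc.Configuration{
//		ICEServers: []webrtc.ICEServer{{
//			URLs:       []string{"turn:turn.example.com:3478"},
//			Username:   "refreshed-user",
//			Credential: "refreshed-secret",
//		}},
//	})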
// Accept blocks waiting for a channel to be opened.
func (c *Conn) Accept(ctx context.Context) (*Channel, error) {
var dataChannel *webrtc.DataChannel
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-c.closed:
return nil, c.closeError
case dataChannel = <-c.dcOpenChannel:
}
return newChannel(c, dataChannel, &ChannelOptions{}), nil
}
// CreateChannel creates a new DataChannel.
func (c *Conn) CreateChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
if opts == nil {
opts = &ChannelOptions{}
}
if opts.ID == c.pingChannelID || opts.ID == c.pingEchoChannelID {
return nil, xerrors.Errorf("datachannel id %d and %d are reserved for ping", c.pingChannelID, c.pingEchoChannelID)
}
return c.dialChannel(ctx, label, opts)
}
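// An illustrative sketch (not from the original source): a negotiated
// channel. Both peers dial the same ID with Negotiated set, skipping the
// in-band DataChannel handshake. The ID below is arbitrary; it only has
// to avoid the reserved ping IDs checked above.
//
//	ch, err := conn.CreateChannel(ctx, "control", &ChannelOptions{
//		ID:         5,
//		Negotiated: true,
//	})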
func (c *Conn) dialChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
// pion/webrtc is slower when opening multiple channels
// in parallel than it is sequentially.
c.dcCreateMutex.Lock()
defer c.dcCreateMutex.Unlock()
c.logger().Debug(ctx, "creating data channel", slog.F("label", label), slog.F("opts", opts))
var id *uint16
if opts.ID != 0 {
id = &opts.ID
}
ordered := true
if opts.Unordered {
ordered = false
}
if opts.OpenOnDisconnect && !opts.Negotiated {
return nil, xerrors.New("OpenOnDisconnect is only allowed for Negotiated channels")
}
if c.isClosed() {
return nil, xerrors.Errorf("closed: %w", c.closeError)
}
dataChannel, err := c.rtc.CreateDataChannel(label, &webrtc.DataChannelInit{
ID: id,
Negotiated: &opts.Negotiated,
Ordered: &ordered,
Protocol: &opts.Protocol,
})
if err != nil {
return nil, xerrors.Errorf("create data channel: %w", err)
}
return newChannel(c, dataChannel, opts), nil
}
// Ping returns the duration it took to round-trip data.
// Multiple pings cannot occur at the same time, so this function will block.
func (c *Conn) Ping() (time.Duration, error) {
// Pings are not async, so we need a mutex.
c.pingMutex.Lock()
defer c.pingMutex.Unlock()
ping, err := c.pingChannel()
if err != nil {
return 0, xerrors.Errorf("get ping channel: %w", err)
}
pingDataSent := make([]byte, pingDataLength)
_, err = rand.Read(pingDataSent)
if err != nil {
return 0, xerrors.Errorf("read random ping data: %w", err)
}
start := time.Now()
_, err = ping.Write(pingDataSent)
if err != nil {
return 0, xerrors.Errorf("send ping: %w", err)
}
c.logger().Debug(context.Background(), "wrote ping",
slog.F("connection_state", c.rtc.ConnectionState()))
pingDataReceived := make([]byte, pingDataLength)
_, err = ping.Read(pingDataReceived)
if err != nil {
return 0, xerrors.Errorf("read ping: %w", err)
}
end := time.Now()
if !bytes.Equal(pingDataSent, pingDataReceived) {
return 0, xerrors.Errorf("ping data inconsistency sent != received")
}
return end.Sub(start), nil
}
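// An illustrative sketch (not from the original source): a periodic
// latency probe built on Ping. Because Ping serializes on pingMutex,
// concurrent callers queue rather than interleave payloads.
//
//	ticker := time.NewTicker(5 * time.Second)
//	defer ticker.Stop()
//	for range ticker.C {
//		latency, err := conn.Ping()
//		if err != nil {
//			return err
//		}
//		logger.Debug(ctx, "peer latency", slog.F("latency", latency))
//	}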
func (c *Conn) Closed() <-chan struct{} {
return c.closed
}
// Close closes the connection and frees all associated resources.
func (c *Conn) Close() error {
return c.CloseWithError(nil)
}
func (c *Conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
// CloseWithError closes the connection; subsequent reads/writes will return the error err.
func (c *Conn) CloseWithError(err error) error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
if c.isClosed() {
return c.closeError
}
logger := c.logger()
logger.Debug(context.Background(), "closing conn with error", slog.Error(err))
if err == nil {
c.closeError = ErrClosed
} else {
c.closeError = err
}
if ch, _ := c.pingChannel(); ch != nil {
_ = ch.closeWithError(c.closeError)
}
// If the WebRTC connection has already been closed (due to failure or disconnect),
// this call will return an error that isn't typed. We don't check the error because
// closing an already closed connection isn't an issue for us.
_ = c.rtc.Close()
// Waiting for pion/webrtc to report closed state on both of these
// ensures no goroutine leaks.
if c.rtc.ConnectionState() != webrtc.PeerConnectionStateNew {
logger.Debug(context.Background(), "waiting for rtc connection close...")
<-c.closedRTC
}
if c.rtc.ICEConnectionState() != webrtc.ICEConnectionStateNew {
logger.Debug(context.Background(), "waiting for ice connection close...")
<-c.closedICE
}
// Waits for all DataChannels to exit before officially labeling as closed.
// All logging, goroutines, and async functionality are cleaned up after this.
c.dcClosedWaitGroup.Wait()
// Disable logging!
c.loggerValue.Store(slog.Logger{})
logger.Sync()
logger.Debug(context.Background(), "closed")
close(c.closed)
return err
}

View File

@ -1,434 +0,0 @@
package peer_test
import (
"context"
"io"
"net"
"net/http"
"os"
"sync"
"testing"
"time"
"github.com/pion/logging"
"github.com/pion/transport/vnet"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/peer"
"github.com/coder/coder/testutil"
)
var (
disconnectedTimeout = func() time.Duration {
// Connection state is unfortunately time-based. When resources are
// contended, a connection can take greater than this timeout to
// handshake, which results in a test flake.
//
// During local testing resources are rarely contended. Reducing this
// timeout leads to faster local development.
//
// In CI resources are frequently contended, so increasing this value
// results in less flakes.
if os.Getenv("CI") == "true" {
return time.Second
}
return 100 * time.Millisecond
}()
failedTimeout = disconnectedTimeout * 3
keepAliveInterval = time.Millisecond * 2
// There's a global race in the vnet library allocation code.
// This mutex locks around the creation of the vnet.
vnetMutex = sync.Mutex{}
)
func TestMain(m *testing.M) {
// pion/ice doesn't shut down all of its goroutines immediately. A proper fix isn't yet known. See:
// https://github.com/pion/ice/pull/413
goleak.VerifyTestMain(m,
goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).startOnConnectionStateChangeRoutine.func1"),
goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).startOnConnectionStateChangeRoutine.func2"),
goleak.IgnoreTopFunction("github.com/pion/ice/v2.(*Agent).taskLoop"),
goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"),
)
}
func TestConn(t *testing.T) {
t.Parallel()
t.Run("Ping", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
_, err := client.Ping()
require.NoError(t, err)
_, err = server.Ping()
require.NoError(t, err)
})
t.Run("PingNetworkOffline", func(t *testing.T) {
t.Parallel()
client, server, wan := createPair(t)
exchange(t, client, server)
_, err := server.Ping()
require.NoError(t, err)
err = wan.Stop()
require.NoError(t, err)
_, err = server.Ping()
require.ErrorIs(t, err, peer.ErrFailed)
})
t.Run("PingReconnect", func(t *testing.T) {
t.Parallel()
client, server, wan := createPair(t)
exchange(t, client, server)
_, err := server.Ping()
require.NoError(t, err)
// Create a channel that closes on disconnect.
channel, err := server.CreateChannel(context.Background(), "wow", nil)
require.NoError(t, err)
defer channel.Close()
err = wan.Stop()
require.NoError(t, err)
// Once the connection is marked as disconnected, this
// channel will be closed.
_, err = channel.Read(make([]byte, 4))
assert.ErrorIs(t, err, peer.ErrClosed)
err = wan.Start()
require.NoError(t, err)
_, err = server.Ping()
require.NoError(t, err)
})
t.Run("Accept", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
require.NoError(t, err)
defer cch.Close()
sch, err := server.Accept(ctx)
require.NoError(t, err)
defer sch.Close()
_ = cch.Close()
_, err = sch.Read(make([]byte, 4))
require.ErrorIs(t, err, peer.ErrClosed)
})
t.Run("AcceptNetworkOffline", func(t *testing.T) {
t.Parallel()
client, server, wan := createPair(t)
exchange(t, client, server)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
require.NoError(t, err)
defer cch.Close()
sch, err := server.Accept(ctx)
require.NoError(t, err)
defer sch.Close()
err = wan.Stop()
require.NoError(t, err)
_ = cch.Close()
_, err = sch.Read(make([]byte, 4))
require.ErrorIs(t, err, peer.ErrClosed)
})
t.Run("Buffering", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
require.NoError(t, err)
defer cch.Close()
readErr := make(chan error, 1)
go func() {
sch, err := server.Accept(ctx)
if err != nil {
readErr <- err
_ = cch.Close()
return
}
defer sch.Close()
bytes := make([]byte, 4096)
for {
_, err = sch.Read(bytes)
if err != nil {
readErr <- err
return
}
}
}()
bytes := make([]byte, 4096)
for i := 0; i < 1024; i++ {
_, err = cch.Write(bytes)
require.NoError(t, err, "write i=%d", i)
}
_ = cch.Close()
select {
case err = <-readErr:
require.ErrorIs(t, err, peer.ErrClosed, "read error")
case <-ctx.Done():
require.Fail(t, "timeout waiting for read error")
}
})
t.Run("NetConn", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
srv, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer srv.Close()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
go func() {
sch, err := server.Accept(ctx)
if err != nil {
assert.NoError(t, err)
return
}
defer sch.Close()
nc2 := sch.NetConn()
defer nc2.Close()
nc1, err := net.Dial("tcp", srv.Addr().String())
if err != nil {
assert.NoError(t, err)
return
}
defer nc1.Close()
go func() {
defer nc1.Close()
defer nc2.Close()
_, _ = io.Copy(nc1, nc2)
}()
_, _ = io.Copy(nc2, nc1)
}()
go func() {
server := http.Server{
ReadHeaderTimeout: time.Minute,
Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(200)
}),
}
defer server.Close()
_ = server.Serve(srv)
}()
//nolint:forcetypeassert
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
var cch *peer.Channel
defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
cch, err = client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
if err != nil {
return nil, err
}
return cch.NetConn(), nil
}
c := http.Client{
Transport: defaultTransport,
}
req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost/", nil)
require.NoError(t, err)
resp, err := c.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, 200, resp.StatusCode)
// Triggers any connections to close.
// This test below ensures the DataChannel actually closes.
defaultTransport.CloseIdleConnections()
err = cch.Close()
require.ErrorIs(t, err, peer.ErrClosed)
})
t.Run("CloseBeforeNegotiate", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
err := client.Close()
require.NoError(t, err)
err = server.Close()
require.NoError(t, err)
})
t.Run("CloseWithError", func(t *testing.T) {
t.Parallel()
conn, err := peer.Client([]webrtc.ICEServer{}, nil)
require.NoError(t, err)
expectedErr := xerrors.New("wow")
_ = conn.CloseWithError(expectedErr)
_, err = conn.CreateChannel(context.Background(), "", nil)
require.ErrorIs(t, err, expectedErr)
})
t.Run("PingConcurrent", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, err := client.Ping()
assert.NoError(t, err)
}()
go func() {
defer wg.Done()
_, err := server.Ping()
assert.NoError(t, err)
}()
wg.Wait()
})
t.Run("CandidateBeforeSessionDescription", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
server.SetRemoteSessionDescription(<-client.LocalSessionDescription())
sdp := <-server.LocalSessionDescription()
client.AddRemoteCandidate(<-server.LocalCandidate())
client.SetRemoteSessionDescription(sdp)
server.AddRemoteCandidate(<-client.LocalCandidate())
_, err := client.Ping()
require.NoError(t, err)
})
t.Run("ShortBuffer", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
go func() {
channel, err := client.CreateChannel(ctx, "test", nil)
if err != nil {
assert.NoError(t, err)
return
}
defer channel.Close()
_, err = channel.Write([]byte{1, 2})
assert.NoError(t, err)
}()
channel, err := server.Accept(ctx)
require.NoError(t, err)
defer channel.Close()
data := make([]byte, 1)
_, err = channel.Read(data)
require.NoError(t, err)
require.Equal(t, uint8(0x1), data[0])
_, err = channel.Read(data)
require.NoError(t, err)
require.Equal(t, uint8(0x2), data[0])
})
}
func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) {
loggingFactory := logging.NewDefaultLoggerFactory()
loggingFactory.DefaultLogLevel = logging.LogLevelDisabled
vnetMutex.Lock()
defer vnetMutex.Unlock()
wan, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: "1.2.3.0/24",
LoggerFactory: loggingFactory,
})
require.NoError(t, err)
c1Net := vnet.NewNet(&vnet.NetConfig{
StaticIPs: []string{"1.2.3.4"},
})
err = wan.AddNet(c1Net)
require.NoError(t, err)
c2Net := vnet.NewNet(&vnet.NetConfig{
StaticIPs: []string{"1.2.3.5"},
})
err = wan.AddNet(c2Net)
require.NoError(t, err)
c1SettingEngine := webrtc.SettingEngine{}
c1SettingEngine.SetVNet(c1Net)
c1SettingEngine.SetPrflxAcceptanceMinWait(0)
c1SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
channel1, err := peer.Client([]webrtc.ICEServer{{}}, &peer.ConnOptions{
SettingEngine: c1SettingEngine,
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
t.Cleanup(func() {
channel1.Close()
})
c2SettingEngine := webrtc.SettingEngine{}
c2SettingEngine.SetVNet(c2Net)
c2SettingEngine.SetPrflxAcceptanceMinWait(0)
c2SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
channel2, err := peer.Server([]webrtc.ICEServer{{}}, &peer.ConnOptions{
SettingEngine: c2SettingEngine,
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
t.Cleanup(func() {
channel2.Close()
})
err = wan.Start()
require.NoError(t, err)
t.Cleanup(func() {
_ = wan.Stop()
})
return channel1, channel2, wan
}
func exchange(t *testing.T, client, server *peer.Conn) {
var wg sync.WaitGroup
wg.Add(2)
t.Cleanup(func() {
_ = client.Close()
_ = server.Close()
wg.Wait()
})
go func() {
defer wg.Done()
for {
select {
case c := <-server.LocalCandidate():
client.AddRemoteCandidate(c)
case c := <-server.LocalSessionDescription():
client.SetRemoteSessionDescription(c)
case <-server.Closed():
return
}
}
}()
go func() {
defer wg.Done()
for {
select {
case c := <-client.LocalCandidate():
server.AddRemoteCandidate(c)
case c := <-client.LocalSessionDescription():
server.SetRemoteSessionDescription(c)
case <-client.Closed():
return
}
}
}()
}

View File

@ -1,59 +0,0 @@
package peer
import (
"net"
"time"
)
type peerAddr struct{}
// Statically checks if we properly implement net.Addr.
var _ net.Addr = &peerAddr{}
func (*peerAddr) Network() string {
return "peer"
}
func (*peerAddr) String() string {
return "peer/unknown-addr"
}
type fakeNetConn struct {
c *Channel
addr *peerAddr
}
// Statically checks if we properly implement net.Conn.
var _ net.Conn = &fakeNetConn{}
func (c *fakeNetConn) Read(b []byte) (n int, err error) {
return c.c.Read(b)
}
func (c *fakeNetConn) Write(b []byte) (n int, err error) {
return c.c.Write(b)
}
func (c *fakeNetConn) Close() error {
return c.c.Close()
}
func (c *fakeNetConn) LocalAddr() net.Addr {
return c.addr
}
func (c *fakeNetConn) RemoteAddr() net.Addr {
return c.addr
}
func (*fakeNetConn) SetDeadline(_ time.Time) error {
return nil
}
func (*fakeNetConn) SetReadDeadline(_ time.Time) error {
return nil
}
func (*fakeNetConn) SetWriteDeadline(_ time.Time) error {
return nil
}

View File

@ -1,87 +0,0 @@
package peerbroker
import (
"context"
"errors"
"io"
"reflect"
"github.com/pion/webrtc/v3"
"golang.org/x/xerrors"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker/proto"
)
// Dial consumes the PeerBroker gRPC connection negotiation stream to produce a WebRTC peered connection.
func Dial(stream proto.DRPCPeerBroker_NegotiateConnectionClient, iceServers []webrtc.ICEServer, opts *peer.ConnOptions) (*peer.Conn, error) {
peerConn, err := peer.Client(iceServers, opts)
if err != nil {
return nil, xerrors.Errorf("create peer connection: %w", err)
}
go func() {
defer stream.Close()
// Exchanging messages from the peer connection to negotiate a connection.
for {
select {
case <-peerConn.Closed():
return
case sessionDescription := <-peerConn.LocalSessionDescription():
err = stream.Send(&proto.Exchange{
Message: &proto.Exchange_Sdp{
Sdp: &proto.WebRTCSessionDescription{
SdpType: int32(sessionDescription.Type),
Sdp: sessionDescription.SDP,
},
},
})
if err != nil {
_ = peerConn.CloseWithError(xerrors.Errorf("send local session description: %w", err))
return
}
case iceCandidate := <-peerConn.LocalCandidate():
err = stream.Send(&proto.Exchange{
Message: &proto.Exchange_IceCandidate{
IceCandidate: iceCandidate.Candidate,
},
})
if err != nil {
_ = peerConn.CloseWithError(xerrors.Errorf("send local candidate: %w", err))
return
}
}
}
}()
go func() {
// Exchanging messages from the server to negotiate a connection.
for {
serverToClientMessage, err := stream.Recv()
if err != nil {
// A clean closure or context cancellation of this stream
// must never kill the p2p connection, so return silently.
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
return
}
_ = peerConn.CloseWithError(xerrors.Errorf("recv: %w", err))
return
}
switch {
case serverToClientMessage.GetSdp() != nil:
peerConn.SetRemoteSessionDescription(webrtc.SessionDescription{
Type: webrtc.SDPType(serverToClientMessage.GetSdp().SdpType),
SDP: serverToClientMessage.GetSdp().Sdp,
})
case serverToClientMessage.GetIceCandidate() != "":
peerConn.AddRemoteCandidate(webrtc.ICECandidateInit{
Candidate: serverToClientMessage.GetIceCandidate(),
})
default:
_ = peerConn.CloseWithError(xerrors.Errorf("unhandled message: %s", reflect.TypeOf(serverToClientMessage).String()))
return
}
}
}()
return peerConn, nil
}

View File

@ -1,67 +0,0 @@
package peerbroker_test
import (
"context"
"testing"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestDial(t *testing.T) {
t.Parallel()
t.Run("Connect", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client, server := provisionersdk.TransportPipe()
defer client.Close()
defer server.Close()
settingEngine := webrtc.SettingEngine{}
listener, err := peerbroker.Listen(server, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) {
return []webrtc.ICEServer{{
URLs: []string{"stun:stun.l.google.com:19302"},
}}, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
SettingEngine: settingEngine,
}, nil
})
require.NoError(t, err)
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
stream, err := api.NegotiateConnection(ctx)
require.NoError(t, err)
clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{
URLs: []string{"stun:stun.l.google.com:19302"},
}}, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
SettingEngine: settingEngine,
})
require.NoError(t, err)
defer clientConn.Close()
serverConn, err := listener.Accept()
require.NoError(t, err)
defer serverConn.Close()
_, err = serverConn.Ping()
require.NoError(t, err)
_, err = clientConn.Ping()
require.NoError(t, err)
})
}

View File

@ -1,188 +0,0 @@
package peerbroker
import (
"context"
"errors"
"io"
"net"
"reflect"
"sync"
"github.com/pion/webrtc/v3"
"golang.org/x/xerrors"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker/proto"
)
// ConnSettingsFunc returns initialization options for a connection
type ConnSettingsFunc func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error)
// Listen consumes the transport as the server-side of the PeerBroker dRPC service.
// The Accept function must be serviced, or new connections will hang.
func Listen(connListener net.Listener, connSettingsFunc ConnSettingsFunc) (*Listener, error) {
if connSettingsFunc == nil {
connSettingsFunc = func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) {
return []webrtc.ICEServer{}, nil, nil
}
}
ctx, cancelFunc := context.WithCancel(context.Background())
listener := &Listener{
connectionChannel: make(chan *peer.Conn),
connectionListener: connListener,
closeFunc: cancelFunc,
closed: make(chan struct{}),
}
mux := drpcmux.New()
err := proto.DRPCRegisterPeerBroker(mux, &peerBrokerService{
connSettingsFunc: connSettingsFunc,
listener: listener,
})
if err != nil {
return nil, xerrors.Errorf("register peer broker: %w", err)
}
srv := drpcserver.New(mux)
go func() {
err := srv.Serve(ctx, connListener)
_ = listener.closeWithError(err)
}()
return listener, nil
}
type Listener struct {
connectionChannel chan *peer.Conn
connectionListener net.Listener
closeFunc context.CancelFunc
closed chan struct{}
closeMutex sync.Mutex
closeError error
}
// Accept blocks until a connection arrives or the listener is closed.
func (l *Listener) Accept() (*peer.Conn, error) {
select {
case <-l.closed:
return nil, l.closeError
case conn := <-l.connectionChannel:
return conn, nil
}
}
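// An illustrative sketch (not from the original source): the accept loop
// that must be serviced per the Listen docs above. connectionChannel is
// unbuffered, so an unserviced listener blocks NegotiateConnection;
// handlePeer is a hypothetical handler.
//
//	go func() {
//		for {
//			conn, err := listener.Accept()
//			if err != nil {
//				return // listener closed
//			}
//			go handlePeer(conn)
//		}
//	}()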
// Close ends the listener. This will block all new WebRTC connections
// from establishing, but will not close active connections.
func (l *Listener) Close() error {
return l.closeWithError(io.EOF)
}
func (l *Listener) closeWithError(err error) error {
l.closeMutex.Lock()
defer l.closeMutex.Unlock()
if l.isClosed() {
return l.closeError
}
_ = l.connectionListener.Close()
l.closeError = err
l.closeFunc()
close(l.closed)
return nil
}
func (l *Listener) isClosed() bool {
select {
case <-l.closed:
return true
default:
return false
}
}
// Implements the PeerBroker service protobuf definition.
type peerBrokerService struct {
listener *Listener
connSettingsFunc ConnSettingsFunc
}
// NegotiateConnection negotiates a WebRTC connection.
func (b *peerBrokerService) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error {
iceServers, connOptions, err := b.connSettingsFunc(stream.Context())
if err != nil {
return xerrors.Errorf("get connection settings: %w", err)
}
peerConn, err := peer.Server(iceServers, connOptions)
if err != nil {
return xerrors.Errorf("create peer connection: %w", err)
}
select {
case <-b.listener.closed:
return peerConn.CloseWithError(b.listener.closeError)
case b.listener.connectionChannel <- peerConn:
}
go func() {
defer stream.Close()
for {
select {
case <-peerConn.Closed():
return
case sessionDescription := <-peerConn.LocalSessionDescription():
err = stream.Send(&proto.Exchange{
Message: &proto.Exchange_Sdp{
Sdp: &proto.WebRTCSessionDescription{
SdpType: int32(sessionDescription.Type),
Sdp: sessionDescription.SDP,
},
},
})
if err != nil {
_ = peerConn.CloseWithError(xerrors.Errorf("send local session description: %w", err))
return
}
case iceCandidate := <-peerConn.LocalCandidate():
err = stream.Send(&proto.Exchange{
Message: &proto.Exchange_IceCandidate{
IceCandidate: iceCandidate.Candidate,
},
})
if err != nil {
_ = peerConn.CloseWithError(xerrors.Errorf("send local candidate: %w", err))
return
}
}
}
}()
for {
clientToServerMessage, err := stream.Recv()
if err != nil {
// A clean closure or context cancellation of this stream
// must never kill the p2p connection, so return silently.
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
return nil
}
return peerConn.CloseWithError(xerrors.Errorf("recv: %w", err))
}
switch {
case clientToServerMessage.GetSdp() != nil:
peerConn.SetRemoteSessionDescription(webrtc.SessionDescription{
Type: webrtc.SDPType(clientToServerMessage.GetSdp().SdpType),
SDP: clientToServerMessage.GetSdp().Sdp,
})
case clientToServerMessage.GetIceCandidate() != "":
peerConn.AddRemoteCandidate(webrtc.ICECandidateInit{
Candidate: clientToServerMessage.GetIceCandidate(),
})
default:
return peerConn.CloseWithError(xerrors.Errorf("unhandled message: %s", reflect.TypeOf(clientToServerMessage).String()))
}
}
}

View File

@ -1,52 +0,0 @@
package peerbroker_test
import (
"context"
"io"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
)
func TestListen(t *testing.T) {
t.Parallel()
// Ensures connections blocked on Accept() are
// closed if the listener is.
t.Run("NoAcceptClosed", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client, server := provisionersdk.TransportPipe()
defer client.Close()
defer server.Close()
listener, err := peerbroker.Listen(server, nil)
require.NoError(t, err)
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
stream, err := api.NegotiateConnection(ctx)
require.NoError(t, err)
clientConn, err := peerbroker.Dial(stream, nil, nil)
require.NoError(t, err)
defer clientConn.Close()
_ = listener.Close()
})
// Ensures Accept() properly exits when Close() is called.
t.Run("AcceptClosed", func(t *testing.T) {
t.Parallel()
client, server := provisionersdk.TransportPipe()
defer client.Close()
defer server.Close()
listener, err := peerbroker.Listen(server, nil)
require.NoError(t, err)
go listener.Close()
_, err = listener.Accept()
require.ErrorIs(t, err, io.EOF)
})
}

View File

@ -1,269 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.21.5
// source: peerbroker/proto/peerbroker.proto
package proto
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type WebRTCSessionDescription struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
SdpType int32 `protobuf:"varint,1,opt,name=sdp_type,json=sdpType,proto3" json:"sdp_type,omitempty"`
Sdp string `protobuf:"bytes,2,opt,name=sdp,proto3" json:"sdp,omitempty"`
}
func (x *WebRTCSessionDescription) Reset() {
*x = WebRTCSessionDescription{}
if protoimpl.UnsafeEnabled {
mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *WebRTCSessionDescription) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*WebRTCSessionDescription) ProtoMessage() {}
func (x *WebRTCSessionDescription) ProtoReflect() protoreflect.Message {
mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use WebRTCSessionDescription.ProtoReflect.Descriptor instead.
func (*WebRTCSessionDescription) Descriptor() ([]byte, []int) {
return file_peerbroker_proto_peerbroker_proto_rawDescGZIP(), []int{0}
}
func (x *WebRTCSessionDescription) GetSdpType() int32 {
if x != nil {
return x.SdpType
}
return 0
}
func (x *WebRTCSessionDescription) GetSdp() string {
if x != nil {
return x.Sdp
}
return ""
}
type Exchange struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Types that are assignable to Message:
//
// *Exchange_Sdp
// *Exchange_IceCandidate
Message isExchange_Message `protobuf_oneof:"message"`
}
func (x *Exchange) Reset() {
*x = Exchange{}
if protoimpl.UnsafeEnabled {
mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Exchange) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Exchange) ProtoMessage() {}
func (x *Exchange) ProtoReflect() protoreflect.Message {
mi := &file_peerbroker_proto_peerbroker_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Exchange.ProtoReflect.Descriptor instead.
func (*Exchange) Descriptor() ([]byte, []int) {
return file_peerbroker_proto_peerbroker_proto_rawDescGZIP(), []int{1}
}
func (m *Exchange) GetMessage() isExchange_Message {
if m != nil {
return m.Message
}
return nil
}
func (x *Exchange) GetSdp() *WebRTCSessionDescription {
if x, ok := x.GetMessage().(*Exchange_Sdp); ok {
return x.Sdp
}
return nil
}
func (x *Exchange) GetIceCandidate() string {
if x, ok := x.GetMessage().(*Exchange_IceCandidate); ok {
return x.IceCandidate
}
return ""
}
type isExchange_Message interface {
isExchange_Message()
}
type Exchange_Sdp struct {
Sdp *WebRTCSessionDescription `protobuf:"bytes,1,opt,name=sdp,proto3,oneof"`
}
type Exchange_IceCandidate struct {
IceCandidate string `protobuf:"bytes,2,opt,name=ice_candidate,json=iceCandidate,proto3,oneof"`
}
func (*Exchange_Sdp) isExchange_Message() {}
func (*Exchange_IceCandidate) isExchange_Message() {}
var File_peerbroker_proto_peerbroker_proto protoreflect.FileDescriptor
var file_peerbroker_proto_peerbroker_proto_rawDesc = []byte{
0x0a, 0x21, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2f, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x2f, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x12, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x22,
0x47, 0x0a, 0x18, 0x57, 0x65, 0x62, 0x52, 0x54, 0x43, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e,
0x44, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19, 0x0a, 0x08, 0x73,
0x64, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x73,
0x64, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x73, 0x64, 0x70, 0x18, 0x02, 0x20,
0x01, 0x28, 0x09, 0x52, 0x03, 0x73, 0x64, 0x70, 0x22, 0x76, 0x0a, 0x08, 0x45, 0x78, 0x63, 0x68,
0x61, 0x6e, 0x67, 0x65, 0x12, 0x38, 0x0a, 0x03, 0x73, 0x64, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28,
0x0b, 0x32, 0x24, 0x2e, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2e, 0x57,
0x65, 0x62, 0x52, 0x54, 0x43, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x44, 0x65, 0x73, 0x63,
0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x00, 0x52, 0x03, 0x73, 0x64, 0x70, 0x12, 0x25,
0x0a, 0x0d, 0x69, 0x63, 0x65, 0x5f, 0x63, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x18,
0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, 0x69, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64,
0x69, 0x64, 0x61, 0x74, 0x65, 0x42, 0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
0x32, 0x53, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x42, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x12, 0x45,
0x0a, 0x13, 0x4e, 0x65, 0x67, 0x6f, 0x74, 0x69, 0x61, 0x74, 0x65, 0x43, 0x6f, 0x6e, 0x6e, 0x65,
0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x14, 0x2e, 0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b,
0x65, 0x72, 0x2e, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x14, 0x2e, 0x70, 0x65,
0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2e, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67,
0x65, 0x28, 0x01, 0x30, 0x01, 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e,
0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f,
0x70, 0x65, 0x65, 0x72, 0x62, 0x72, 0x6f, 0x6b, 0x65, 0x72, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_peerbroker_proto_peerbroker_proto_rawDescOnce sync.Once
file_peerbroker_proto_peerbroker_proto_rawDescData = file_peerbroker_proto_peerbroker_proto_rawDesc
)
func file_peerbroker_proto_peerbroker_proto_rawDescGZIP() []byte {
file_peerbroker_proto_peerbroker_proto_rawDescOnce.Do(func() {
file_peerbroker_proto_peerbroker_proto_rawDescData = protoimpl.X.CompressGZIP(file_peerbroker_proto_peerbroker_proto_rawDescData)
})
return file_peerbroker_proto_peerbroker_proto_rawDescData
}
var file_peerbroker_proto_peerbroker_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_peerbroker_proto_peerbroker_proto_goTypes = []interface{}{
(*WebRTCSessionDescription)(nil), // 0: peerbroker.WebRTCSessionDescription
(*Exchange)(nil), // 1: peerbroker.Exchange
}
var file_peerbroker_proto_peerbroker_proto_depIdxs = []int32{
0, // 0: peerbroker.Exchange.sdp:type_name -> peerbroker.WebRTCSessionDescription
1, // 1: peerbroker.PeerBroker.NegotiateConnection:input_type -> peerbroker.Exchange
1, // 2: peerbroker.PeerBroker.NegotiateConnection:output_type -> peerbroker.Exchange
2, // [2:3] is the sub-list for method output_type
1, // [1:2] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_peerbroker_proto_peerbroker_proto_init() }
func file_peerbroker_proto_peerbroker_proto_init() {
if File_peerbroker_proto_peerbroker_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_peerbroker_proto_peerbroker_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*WebRTCSessionDescription); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_peerbroker_proto_peerbroker_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Exchange); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_peerbroker_proto_peerbroker_proto_msgTypes[1].OneofWrappers = []interface{}{
(*Exchange_Sdp)(nil),
(*Exchange_IceCandidate)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_peerbroker_proto_peerbroker_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_peerbroker_proto_peerbroker_proto_goTypes,
DependencyIndexes: file_peerbroker_proto_peerbroker_proto_depIdxs,
MessageInfos: file_peerbroker_proto_peerbroker_proto_msgTypes,
}.Build()
File_peerbroker_proto_peerbroker_proto = out.File
file_peerbroker_proto_peerbroker_proto_rawDesc = nil
file_peerbroker_proto_peerbroker_proto_goTypes = nil
file_peerbroker_proto_peerbroker_proto_depIdxs = nil
}

View File

@ -1,28 +0,0 @@
syntax = "proto3";
option go_package = "github.com/coder/coder/peerbroker/proto";
package peerbroker;
message WebRTCSessionDescription {
int32 sdp_type = 1;
string sdp = 2;
}
message Exchange {
oneof message {
WebRTCSessionDescription sdp = 1;
string ice_candidate = 2;
}
}
// PeerBroker mediates WebRTC connection signaling.
service PeerBroker {
// NegotiateConnection establishes a bidirectional stream to negotiate a new WebRTC connection.
// 1. Client sends WebRTCSessionDescription to the server.
// 2. Server sends WebRTCSessionDescription to the client, exchanging encryption keys.
// 3. Client<->Server exchange ICE Candidates to establish a peered connection.
//
// See: https://davekilian.com/webrtc-the-hard-way.html
rpc NegotiateConnection(stream Exchange) returns (stream Exchange);
}

View File

@ -1,146 +0,0 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.26
// source: peerbroker/proto/peerbroker.proto
package proto
import (
context "context"
errors "errors"
protojson "google.golang.org/protobuf/encoding/protojson"
proto "google.golang.org/protobuf/proto"
drpc "storj.io/drpc"
drpcerr "storj.io/drpc/drpcerr"
)
type drpcEncoding_File_peerbroker_proto_peerbroker_proto struct{}
func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) Marshal(msg drpc.Message) ([]byte, error) {
return proto.Marshal(msg.(proto.Message))
}
func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) MarshalAppend(buf []byte, msg drpc.Message) ([]byte, error) {
return proto.MarshalOptions{}.MarshalAppend(buf, msg.(proto.Message))
}
func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) Unmarshal(buf []byte, msg drpc.Message) error {
return proto.Unmarshal(buf, msg.(proto.Message))
}
func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) JSONMarshal(msg drpc.Message) ([]byte, error) {
return protojson.Marshal(msg.(proto.Message))
}
func (drpcEncoding_File_peerbroker_proto_peerbroker_proto) JSONUnmarshal(buf []byte, msg drpc.Message) error {
return protojson.Unmarshal(buf, msg.(proto.Message))
}
type DRPCPeerBrokerClient interface {
DRPCConn() drpc.Conn
NegotiateConnection(ctx context.Context) (DRPCPeerBroker_NegotiateConnectionClient, error)
}
type drpcPeerBrokerClient struct {
cc drpc.Conn
}
func NewDRPCPeerBrokerClient(cc drpc.Conn) DRPCPeerBrokerClient {
return &drpcPeerBrokerClient{cc}
}
func (c *drpcPeerBrokerClient) DRPCConn() drpc.Conn { return c.cc }
func (c *drpcPeerBrokerClient) NegotiateConnection(ctx context.Context) (DRPCPeerBroker_NegotiateConnectionClient, error) {
stream, err := c.cc.NewStream(ctx, "/peerbroker.PeerBroker/NegotiateConnection", drpcEncoding_File_peerbroker_proto_peerbroker_proto{})
if err != nil {
return nil, err
}
x := &drpcPeerBroker_NegotiateConnectionClient{stream}
return x, nil
}
type DRPCPeerBroker_NegotiateConnectionClient interface {
drpc.Stream
Send(*Exchange) error
Recv() (*Exchange, error)
}
type drpcPeerBroker_NegotiateConnectionClient struct {
drpc.Stream
}
func (x *drpcPeerBroker_NegotiateConnectionClient) Send(m *Exchange) error {
return x.MsgSend(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{})
}
func (x *drpcPeerBroker_NegotiateConnectionClient) Recv() (*Exchange, error) {
m := new(Exchange)
if err := x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}); err != nil {
return nil, err
}
return m, nil
}
func (x *drpcPeerBroker_NegotiateConnectionClient) RecvMsg(m *Exchange) error {
return x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{})
}
type DRPCPeerBrokerServer interface {
NegotiateConnection(DRPCPeerBroker_NegotiateConnectionStream) error
}
type DRPCPeerBrokerUnimplementedServer struct{}
func (s *DRPCPeerBrokerUnimplementedServer) NegotiateConnection(DRPCPeerBroker_NegotiateConnectionStream) error {
return drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
type DRPCPeerBrokerDescription struct{}
func (DRPCPeerBrokerDescription) NumMethods() int { return 1 }
func (DRPCPeerBrokerDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n {
case 0:
return "/peerbroker.PeerBroker/NegotiateConnection", drpcEncoding_File_peerbroker_proto_peerbroker_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return nil, srv.(DRPCPeerBrokerServer).
NegotiateConnection(
&drpcPeerBroker_NegotiateConnectionStream{in1.(drpc.Stream)},
)
}, DRPCPeerBrokerServer.NegotiateConnection, true
default:
return "", nil, nil, nil, false
}
}
func DRPCRegisterPeerBroker(mux drpc.Mux, impl DRPCPeerBrokerServer) error {
return mux.Register(impl, DRPCPeerBrokerDescription{})
}
type DRPCPeerBroker_NegotiateConnectionStream interface {
drpc.Stream
Send(*Exchange) error
Recv() (*Exchange, error)
}
type drpcPeerBroker_NegotiateConnectionStream struct {
drpc.Stream
}
func (x *drpcPeerBroker_NegotiateConnectionStream) Send(m *Exchange) error {
return x.MsgSend(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{})
}
func (x *drpcPeerBroker_NegotiateConnectionStream) Recv() (*Exchange, error) {
m := new(Exchange)
if err := x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{}); err != nil {
return nil, err
}
return m, nil
}
func (x *drpcPeerBroker_NegotiateConnectionStream) RecvMsg(m *Exchange) error {
return x.MsgRecv(m, drpcEncoding_File_peerbroker_proto_peerbroker_proto{})
}

View File

@ -1,283 +0,0 @@
package peerbroker
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"sync"
"github.com/google/uuid"
"github.com/hashicorp/yamux"
"golang.org/x/xerrors"
protobuf "google.golang.org/protobuf/proto"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/peerbroker/proto"
)
var (
// Each NegotiateConnection() function call spawns a new stream.
streamIDLength = len(uuid.NewString())
// We shouldn't PubSub anything larger than this!
maxPayloadSizeBytes = 8192
)
// ProxyOptions provides values to configure a proxy.
type ProxyOptions struct {
ChannelID string
Logger slog.Logger
Pubsub database.Pubsub
}
// ProxyDial writes client negotiation streams over PubSub.
//
// PubSub is used to geodistribute WebRTC handshakes. All negotiation
// messages are small (<=8KB), and delivery guarantees aren't required
// because connections can always be renegotiated.
//
// ┌─────────────────────┐    ┌────────────────────┐     ┌─────────────────────────────┐
// │       client        │    │       coderd       │     │           coderd            │
// │                     │    │/<agent-id>/connect │     │     /<agent-id>/listen      │    ┌─────┐
// │NegotiateConnection()├───►│Creates a stream ID │◄───►│Subscribe() to the <agent-id>│◄───┤agent│
// └─────────────────────┘    │and Publish() to the│     │channel. Parse the stream ID │    └─────┘
//                            │<agent-id> channel: │     │from payloads to create new  │
//                            │                    │     │NegotiateConnection() streams│
//                            │<stream-id><payload>│     │or write to existing ones.   │
//                            └────────────────────┘     └─────────────────────────────┘
func ProxyDial(client proto.DRPCPeerBrokerClient, options ProxyOptions) (io.Closer, error) {
proxyDial := &proxyDial{
channelID: options.ChannelID,
logger: options.Logger,
pubsub: options.Pubsub,
connection: client,
streams: make(map[string]proto.DRPCPeerBroker_NegotiateConnectionClient),
}
return proxyDial, proxyDial.listen()
}
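// An illustrative sketch (not from the original source): the two proxy
// halves rendezvous purely through pubsub channel names derived from a
// shared channel ID, so they can run in different coderd replicas.
// agentID, logger, pubsub, client, and listener are assumed to exist.
//
//	closer, err := ProxyDial(client, ProxyOptions{
//		ChannelID: agentID.String(),
//		Logger:    logger,
//		Pubsub:    pubsub,
//	})
//	// ...and in another replica:
//	err = ProxyListen(ctx, listener, ProxyOptions{
//		ChannelID: agentID.String(),
//		Logger:    logger,
//		Pubsub:    pubsub,
//	})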
// ProxyListen accepts client negotiation streams over PubSub and writes them to the listener
// as new NegotiateConnection() streams.
func ProxyListen(ctx context.Context, connListener net.Listener, options ProxyOptions) error {
mux := drpcmux.New()
err := proto.DRPCRegisterPeerBroker(mux, &proxyListen{
channelID: options.ChannelID,
pubsub: options.Pubsub,
logger: options.Logger,
})
if err != nil {
return xerrors.Errorf("register peer broker: %w", err)
}
server := drpcserver.New(mux)
err = server.Serve(ctx, connListener)
if err != nil {
if errors.Is(err, yamux.ErrSessionShutdown) {
return nil
}
return xerrors.Errorf("serve: %w", err)
}
return nil
}
type proxyListen struct {
channelID string
pubsub database.Pubsub
logger slog.Logger
}
func (p *proxyListen) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error {
streamID := uuid.NewString()
var err error
closeSubscribe, err := p.pubsub.Subscribe(proxyInID(p.channelID), func(ctx context.Context, message []byte) {
err := p.onServerToClientMessage(streamID, stream, message)
if err != nil {
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
}
})
if err != nil {
return xerrors.Errorf("subscribe: %w", err)
}
defer closeSubscribe()
for {
clientToServerMessage, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return xerrors.Errorf("recv: %w", err)
}
data, err := protobuf.Marshal(clientToServerMessage)
if err != nil {
return xerrors.Errorf("marshal: %w", err)
}
if len(data) > maxPayloadSizeBytes {
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
}
data = append([]byte(streamID), data...)
err = p.pubsub.Publish(proxyOutID(p.channelID), marshal(data))
if err != nil {
return xerrors.Errorf("publish: %w", err)
}
}
return nil
}
func (*proxyListen) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionStream, message []byte) error {
var err error
message, err = unmarshal(message)
if err != nil {
return xerrors.Errorf("decode: %w", err)
}
if len(message) < streamIDLength {
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
}
serverStreamID := string(message[0:streamIDLength])
if serverStreamID != streamID {
// It's not trying to communicate with this stream!
return nil
}
var msg proto.Exchange
err = protobuf.Unmarshal(message[streamIDLength:], &msg)
if err != nil {
return xerrors.Errorf("unmarshal message: %w", err)
}
err = stream.Send(&msg)
if err != nil {
return xerrors.Errorf("send message: %w", err)
}
return nil
}
type proxyDial struct {
channelID string
pubsub database.Pubsub
logger slog.Logger
connection proto.DRPCPeerBrokerClient
closeSubscribe func()
streamMutex sync.Mutex
streams map[string]proto.DRPCPeerBroker_NegotiateConnectionClient
}
func (p *proxyDial) listen() error {
var err error
p.closeSubscribe, err = p.pubsub.Subscribe(proxyOutID(p.channelID), func(ctx context.Context, message []byte) {
err := p.onClientToServerMessage(ctx, message)
if err != nil {
p.logger.Debug(ctx, "failed to accept client message", slog.Error(err))
}
})
if err != nil {
return err
}
return nil
}
func (p *proxyDial) onClientToServerMessage(ctx context.Context, message []byte) error {
var err error
message, err = unmarshal(message)
if err != nil {
return xerrors.Errorf("decode: %w", err)
}
if len(message) < streamIDLength {
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
}
streamID := string(message[0:streamIDLength])
p.streamMutex.Lock()
stream, ok := p.streams[streamID]
if !ok {
stream, err = p.connection.NegotiateConnection(ctx)
if err != nil {
p.streamMutex.Unlock()
return xerrors.Errorf("negotiate connection: %w", err)
}
p.streams[streamID] = stream
go func() {
defer stream.Close()
err := p.onServerToClientMessage(streamID, stream)
if err != nil {
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
}
}()
go func() {
<-stream.Context().Done()
p.streamMutex.Lock()
delete(p.streams, streamID)
p.streamMutex.Unlock()
}()
}
p.streamMutex.Unlock()
var msg proto.Exchange
err = protobuf.Unmarshal(message[streamIDLength:], &msg)
if err != nil {
return xerrors.Errorf("unmarshal message: %w", err)
}
err = stream.Send(&msg)
if err != nil {
return xerrors.Errorf("write message: %w", err)
}
return nil
}
func (p *proxyDial) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionClient) error {
for {
serverToClientMessage, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
if errors.Is(err, context.Canceled) {
break
}
return xerrors.Errorf("recv: %w", err)
}
data, err := protobuf.Marshal(serverToClientMessage)
if err != nil {
return xerrors.Errorf("marshal: %w", err)
}
if len(data) > maxPayloadSizeBytes {
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
}
data = append([]byte(streamID), data...)
err = p.pubsub.Publish(proxyInID(p.channelID), marshal(data))
if err != nil {
return xerrors.Errorf("publish: %w", err)
}
}
return nil
}
func (p *proxyDial) Close() error {
p.streamMutex.Lock()
defer p.streamMutex.Unlock()
p.closeSubscribe()
return nil
}
// base64 is used here to keep pubsub messages within the UTF-8 range.
// PostgreSQL cannot handle non-UTF-8 messages over pubsub.
func marshal(data []byte) []byte {
return []byte(base64.StdEncoding.EncodeToString(data))
}
func unmarshal(data []byte) ([]byte, error) {
return base64.StdEncoding.DecodeString(string(data))
}
func proxyOutID(channelID string) string {
return fmt.Sprintf("%s-out", channelID)
}
func proxyInID(channelID string) string {
return fmt.Sprintf("%s-in", channelID)
}
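// An illustrative sketch (not from the original source) of the pubsub
// wire format used above: a 36-character UUID stream ID prefixes the
// marshaled Exchange, and the whole frame is base64-encoded so
// PostgreSQL NOTIFY only ever sees UTF-8.
//
//	payload, _ := protobuf.Marshal(&proto.Exchange{})
//	frame := append([]byte(streamID), payload...) // streamID from uuid.NewString()
//	wire := marshal(frame)                        // base64-encode for UTF-8 safety
//	decoded, _ := unmarshal(wire)
//	gotStreamID := string(decoded[:streamIDLength])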

View File

@ -1,84 +0,0 @@
package peerbroker_test
import (
"context"
"sync"
"testing"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
)
func TestProxy(t *testing.T) {
t.Parallel()
ctx := context.Background()
channelID := "hello"
pubsub := database.NewPubsubInMemory()
dialerClient, dialerServer := provisionersdk.TransportPipe()
defer dialerClient.Close()
defer dialerServer.Close()
listenerClient, listenerServer := provisionersdk.TransportPipe()
defer listenerClient.Close()
defer listenerServer.Close()
listener, err := peerbroker.Listen(listenerServer, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) {
return nil, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
}, nil
})
require.NoError(t, err)
proxyCloser, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(listenerClient)), peerbroker.ProxyOptions{
ChannelID: channelID,
Logger: slogtest.Make(t, nil).Named("proxy-listen").Leveled(slog.LevelDebug),
Pubsub: pubsub,
})
require.NoError(t, err)
defer func() {
_ = proxyCloser.Close()
}()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
err = peerbroker.ProxyListen(ctx, dialerServer, peerbroker.ProxyOptions{
ChannelID: channelID,
Logger: slogtest.Make(t, nil).Named("proxy-dial").Leveled(slog.LevelDebug),
Pubsub: pubsub,
})
assert.NoError(t, err)
}()
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(dialerClient))
stream, err := api.NegotiateConnection(ctx)
require.NoError(t, err)
clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{
URLs: []string{"stun:stun.l.google.com:19302"},
}}, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
defer clientConn.Close()
serverConn, err := listener.Accept()
require.NoError(t, err)
defer serverConn.Close()
_, err = serverConn.Ping()
require.NoError(t, err)
_, err = clientConn.Ping()
require.NoError(t, err)
_ = dialerServer.Close()
wg.Wait()
}