mirror of https://github.com/coder/coder.git
330 lines
8.4 KiB
Go
330 lines
8.4 KiB
Go
package peer
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
|
|
"github.com/pion/datachannel"
|
|
"github.com/pion/webrtc/v3"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog"
|
|
)
|
|
|
|
const (
	// bufferedAmountLowThreshold is the buffered byte count at or below
	// which the DataChannel reports that more data can be queued.
	bufferedAmountLowThreshold uint64 = 512 << 10 // 512 KB
	// maxBufferedAmount is the buffered byte count at which Write starts
	// blocking until the buffer drains.
	maxBufferedAmount uint64 = 1 << 20 // 1 MB
	// maxMessageLength caps the size of a single outbound message.
	// Larger messages have been observed not to work; this shouldn't
	// be a huge deal for real-world usage.
	// See: https://github.com/pion/datachannel/issues/59
	maxMessageLength = 64 << 10 // 64 KB
)
|
|
|
|
// newChannel creates a new channel and initializes it.
|
|
// The initialization overrides listener handles, and detaches
|
|
// the channel on open. The datachannel should not be manually
|
|
// mutated after being passed to this function.
|
|
func newChannel(conn *Conn, dc *webrtc.DataChannel, opts *ChannelOptions) *Channel {
|
|
channel := &Channel{
|
|
opts: opts,
|
|
conn: conn,
|
|
dc: dc,
|
|
|
|
opened: make(chan struct{}),
|
|
closed: make(chan struct{}),
|
|
sendMore: make(chan struct{}, 1),
|
|
}
|
|
channel.init()
|
|
return channel
|
|
}
|
|
|
|
// ChannelOptions configures how a Channel is created and how it reacts
// to connection state changes.
type ChannelOptions struct {
	// ID is the channel ID to use when Negotiated is true.
	ID uint16

	// Negotiated reports whether the data channel is already set up on
	// the remote end. Defaults to false.
	Negotiated bool

	// Protocol is an arbitrary string that the accepting side can parse.
	Protocol string

	// Unordered makes the channel behave like a UDP connection.
	// Defaults to false.
	Unordered bool

	// OpenOnDisconnect keeps the channel open when the connection
	// disconnects. When true, data is buffered on either end and sent
	// once reconnected. Defaults to false.
	OpenOnDisconnect bool
}
|
|
|
|
// Channel represents a WebRTC DataChannel.
|
|
//
|
|
// This struct wraps webrtc.DataChannel to add concurrent-safe usage,
|
|
// data bufferring, and standardized errors for connection state.
|
|
//
|
|
// It modifies the default behavior of a DataChannel by closing on
|
|
// WebRTC PeerConnection failure. This is done to emulate TCP connections.
|
|
// This option can be changed in the options when creating a Channel.
|
|
type Channel struct {
|
|
opts *ChannelOptions
|
|
|
|
conn *Conn
|
|
dc *webrtc.DataChannel
|
|
// This field can be nil. It becomes set after the DataChannel
|
|
// has been opened and is detached.
|
|
rwc datachannel.ReadWriteCloser
|
|
reader io.Reader
|
|
|
|
closed chan struct{}
|
|
closeMutex sync.Mutex
|
|
closeError error
|
|
|
|
opened chan struct{}
|
|
|
|
// sendMore is used to block Write operations on a full buffer.
|
|
// It's signaled when the buffer can accept more data.
|
|
sendMore chan struct{}
|
|
writeMutex sync.Mutex
|
|
}
|
|
|
|
// init attaches listeners to the DataChannel to detect opening,
|
|
// closing, and when the channel is ready to transmit data.
|
|
//
|
|
// This should only be called once on creation.
|
|
func (c *Channel) init() {
|
|
// WebRTC connections maintain an internal buffer that can fill when:
|
|
// 1. Data is being sent faster than it can flush.
|
|
// 2. The connection is disconnected, but data is still being sent.
|
|
//
|
|
// This applies a maximum in-memory buffer for data, and will cause
|
|
// write operations to block once the threshold is set.
|
|
c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
|
|
c.dc.OnBufferedAmountLow(func() {
|
|
if c.isClosed() {
|
|
return
|
|
}
|
|
select {
|
|
case <-c.closed:
|
|
return
|
|
case c.sendMore <- struct{}{}:
|
|
default:
|
|
}
|
|
})
|
|
c.dc.OnClose(func() {
|
|
c.conn.logger().Debug(context.Background(), "datachannel closing from OnClose", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
|
|
_ = c.closeWithError(ErrClosed)
|
|
})
|
|
c.dc.OnOpen(func() {
|
|
c.closeMutex.Lock()
|
|
defer c.closeMutex.Unlock()
|
|
|
|
c.conn.logger().Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
|
|
var err error
|
|
c.rwc, err = c.dc.Detach()
|
|
if err != nil {
|
|
_ = c.closeWithError(xerrors.Errorf("detach: %w", err))
|
|
return
|
|
}
|
|
// pion/webrtc will return an io.ErrShortBuffer when a read
|
|
// is triggerred with a buffer size less than the chunks written.
|
|
//
|
|
// This makes sense when considering UDP connections, because
|
|
// bufferring of data that has no transmit guarantees is likely
|
|
// to cause unexpected behavior.
|
|
//
|
|
// When ordered, this adds a bufio.Reader. This ensures additional
|
|
// data on TCP-like connections can be read in parts, while still
|
|
// being bufferred.
|
|
if c.opts.Unordered {
|
|
c.reader = c.rwc
|
|
} else {
|
|
// This must be the max message length otherwise a short
|
|
// buffer error can occur.
|
|
c.reader = bufio.NewReaderSize(c.rwc, maxMessageLength)
|
|
}
|
|
close(c.opened)
|
|
})
|
|
|
|
c.conn.dcDisconnectListeners.Add(1)
|
|
c.conn.dcFailedListeners.Add(1)
|
|
c.conn.dcClosedWaitGroup.Add(1)
|
|
go func() {
|
|
var err error
|
|
// A DataChannel can disconnect multiple times, so this needs to loop.
|
|
for {
|
|
select {
|
|
case <-c.conn.closedRTC:
|
|
// If this channel was closed, there's no need to close again.
|
|
err = c.conn.closeError
|
|
case <-c.conn.Closed():
|
|
// If the RTC connection closed with an error, this channel
|
|
// should end with the same one.
|
|
err = c.conn.closeError
|
|
case <-c.conn.dcDisconnectChannel:
|
|
// If the RTC connection is disconnected, we need to check if
|
|
// the DataChannel is supposed to end on disconnect.
|
|
if c.opts.OpenOnDisconnect {
|
|
continue
|
|
}
|
|
err = xerrors.Errorf("rtc disconnected. closing: %w", ErrClosed)
|
|
case <-c.conn.dcFailedChannel:
|
|
// If the RTC connection failed, close the Channel.
|
|
err = ErrFailed
|
|
}
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
_ = c.closeWithError(err)
|
|
}()
|
|
}
|
|
|
|
// Read blocks until data is received.
|
|
//
|
|
// This will block until the underlying DataChannel has been opened.
|
|
func (c *Channel) Read(bytes []byte) (int, error) {
|
|
if c.isClosed() {
|
|
return 0, c.closeError
|
|
}
|
|
if !c.isOpened() {
|
|
err := c.waitOpened()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
bytesRead, err := c.reader.Read(bytes)
|
|
if err != nil {
|
|
if c.isClosed() {
|
|
return 0, c.closeError
|
|
}
|
|
// An EOF always occurs when the connection is closed.
|
|
// Alternative close errors will occur first if an unexpected
|
|
// close has occurred.
|
|
if xerrors.Is(err, io.EOF) {
|
|
err = c.closeWithError(ErrClosed)
|
|
}
|
|
return bytesRead, err
|
|
}
|
|
return bytesRead, err
|
|
}
|
|
|
|
// Write sends data to the underlying DataChannel.
|
|
//
|
|
// This function will block if too much data is being sent.
|
|
// Data will buffer if the connection is temporarily disconnected,
|
|
// and will be flushed upon reconnection.
|
|
//
|
|
// If the Channel is setup to close on disconnect, any buffered
|
|
// data will be lost.
|
|
func (c *Channel) Write(bytes []byte) (n int, err error) {
|
|
if len(bytes) > maxMessageLength {
|
|
return 0, xerrors.Errorf("outbound packet larger than maximum message size: %d", maxMessageLength)
|
|
}
|
|
|
|
c.writeMutex.Lock()
|
|
defer c.writeMutex.Unlock()
|
|
|
|
if c.isClosed() {
|
|
return 0, c.closeWithError(nil)
|
|
}
|
|
if !c.isOpened() {
|
|
err := c.waitOpened()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
if c.dc.BufferedAmount()+uint64(len(bytes)) >= maxBufferedAmount {
|
|
<-c.sendMore
|
|
}
|
|
return c.rwc.Write(bytes)
|
|
}
|
|
|
|
// Close gracefully closes the DataChannel.
|
|
func (c *Channel) Close() error {
|
|
return c.closeWithError(nil)
|
|
}
|
|
|
|
// Label returns the label of the underlying DataChannel.
|
|
func (c *Channel) Label() string {
|
|
return c.dc.Label()
|
|
}
|
|
|
|
// Protocol returns the protocol of the underlying DataChannel.
|
|
func (c *Channel) Protocol() string {
|
|
return c.dc.Protocol()
|
|
}
|
|
|
|
// NetConn wraps the DataChannel in a struct fulfilling net.Conn.
|
|
// Read, Write, and Close operations can still be used on the *Channel struct.
|
|
func (c *Channel) NetConn() net.Conn {
|
|
return &fakeNetConn{
|
|
c: c,
|
|
addr: &peerAddr{},
|
|
}
|
|
}
|
|
|
|
// closeWithError closes the Channel with the error provided.
|
|
// If a graceful close occurs, the error will be nil.
|
|
func (c *Channel) closeWithError(err error) error {
|
|
c.closeMutex.Lock()
|
|
defer c.closeMutex.Unlock()
|
|
|
|
if c.isClosed() {
|
|
return c.closeError
|
|
}
|
|
|
|
c.conn.logger().Debug(context.Background(), "datachannel closing with error", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()), slog.Error(err))
|
|
if err == nil {
|
|
c.closeError = ErrClosed
|
|
} else {
|
|
c.closeError = err
|
|
}
|
|
if c.rwc != nil {
|
|
_ = c.rwc.Close()
|
|
}
|
|
_ = c.dc.Close()
|
|
|
|
close(c.closed)
|
|
close(c.sendMore)
|
|
c.conn.dcDisconnectListeners.Sub(1)
|
|
c.conn.dcFailedListeners.Sub(1)
|
|
c.conn.dcClosedWaitGroup.Done()
|
|
|
|
return err
|
|
}
|
|
|
|
func (c *Channel) isClosed() bool {
|
|
select {
|
|
case <-c.closed:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (c *Channel) isOpened() bool {
|
|
select {
|
|
case <-c.opened:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (c *Channel) waitOpened() error {
|
|
select {
|
|
case <-c.opened:
|
|
return nil
|
|
case <-c.closed:
|
|
return c.closeError
|
|
}
|
|
}
|