feat: add port-forward subcommand (#1350)

This commit is contained in:
Dean Sheather 2022-05-19 00:10:40 +10:00 committed by GitHub
parent 76fc59aa79
commit 9141be3656
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 1403 additions and 119 deletions

View File

@ -9,6 +9,7 @@ import (
"fmt"
"io"
"net"
"net/url"
"os"
"os/exec"
"os/user"
@ -211,6 +212,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
go a.sshServer.HandleConn(channel.NetConn())
case "reconnecting-pty":
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
case "dial":
go a.handleDial(ctx, channel.Label(), channel.NetConn())
default:
a.logger.Warn(ctx, "unhandled protocol from channel",
slog.F("protocol", channel.Protocol()),
@ -617,6 +620,70 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
}
}
// dialResponse is written to datachannels with protocol "dial" by the agent as
// the first packet to signify whether the dial succeeded or failed.
type dialResponse struct {
Error string `json:"error,omitempty"`
}
func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) {
defer conn.Close()
writeError := func(responseError error) error {
msg := ""
if responseError != nil {
msg = responseError.Error()
if !xerrors.Is(responseError, io.EOF) {
a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError))
}
}
b, err := json.Marshal(dialResponse{
Error: msg,
})
if err != nil {
a.logger.Warn(ctx, "write dial response", slog.F("label", label), slog.Error(err))
return xerrors.Errorf("marshal agent webrtc dial response: %w", err)
}
_, err = conn.Write(b)
return err
}
u, err := url.Parse(label)
if err != nil {
_ = writeError(xerrors.Errorf("parse URL %q: %w", label, err))
return
}
network := u.Scheme
addr := u.Host + u.Path
if strings.HasPrefix(network, "unix") {
if runtime.GOOS == "windows" {
_ = writeError(xerrors.New("Unix forwarding is not supported from Windows workspaces"))
return
}
addr, err = ExpandRelativeHomePath(addr)
if err != nil {
_ = writeError(xerrors.Errorf("expand path %q: %w", addr, err))
return
}
}
d := net.Dialer{Timeout: 3 * time.Second}
nconn, err := d.DialContext(ctx, network, addr)
if err != nil {
_ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err))
return
}
err = writeError(nil)
if err != nil {
return
}
Bicopy(ctx, conn, nconn)
}
// isClosed returns whether the API is closed or not.
func (a *agent) isClosed() bool {
select {
@ -662,3 +729,50 @@ func (r *reconnectingPTY) Close() {
r.circularBuffer.Reset()
r.timeout.Stop()
}
// Bicopy copies all of the data between the two connections and will close them
// after one or both of them are done writing. If the context is canceled, both
// of the connections will be closed.
func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) {
defer c1.Close()
defer c2.Close()
var wg sync.WaitGroup
copyFunc := func(dst io.WriteCloser, src io.Reader) {
defer wg.Done()
_, _ = io.Copy(dst, src)
}
wg.Add(2)
go copyFunc(c1, c2)
go copyFunc(c2, c1)
// Convert waitgroup to a channel so we can also wait on the context.
done := make(chan struct{})
go func() {
defer close(done)
wg.Wait()
}()
select {
case <-ctx.Done():
case <-done:
}
}
// ExpandRelativeHomePath expands the tilde at the beginning of a path to the
// current user's home directory and returns a full absolute path.
func ExpandRelativeHomePath(in string) (string, error) {
usr, err := user.Current()
if err != nil {
return "", xerrors.Errorf("get current user details: %w", err)
}
if in == "~" {
in = usr.HomeDir
} else if strings.HasPrefix(in, "~/") {
in = filepath.Join(usr.HomeDir, in[2:])
}
return filepath.Abs(in)
}

View File

@ -16,6 +16,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/pion/udp"
"github.com/pion/webrtc/v3"
"github.com/pkg/sftp"
"github.com/stretchr/testify/require"
@ -234,6 +235,112 @@ func TestAgent(t *testing.T) {
findEcho()
findEcho()
})
t.Run("Dial", func(t *testing.T) {
t.Parallel()
cases := []struct {
name string
setup func(t *testing.T) net.Listener
}{
{
name: "TCP",
setup: func(t *testing.T) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener")
return l
},
},
{
name: "UDP",
setup: func(t *testing.T) net.Listener {
addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
}
l, err := udp.Listen("udp", &addr)
require.NoError(t, err, "create UDP listener")
return l
},
},
{
name: "Unix",
setup: func(t *testing.T) net.Listener {
if runtime.GOOS == "windows" {
t.Skip("Unix socket forwarding isn't supported on Windows")
}
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
require.NoError(t, err, "create temp dir for unix listener")
t.Cleanup(func() {
_ = os.RemoveAll(tmpDir)
})
l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
require.NoError(t, err, "create UDP listener")
return l
},
},
}
for _, c := range cases {
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()
// Setup listener
l := c.setup(t)
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
return
}
go testAccept(t, c)
}
}()
// Dial the listener over WebRTC twice and test out of order
conn := setupAgent(t, agent.Metadata{}, 0)
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn1.Close()
conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn2.Close()
testDial(t, conn2)
testDial(t, conn1)
})
}
})
t.Run("DialError", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
// This test uses Unix listeners so we can very easily ensure that
// no other tests decide to listen on the same random port we
// picked.
t.Skip("this test is unsupported on Windows")
return
}
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
require.NoError(t, err, "create temp dir")
t.Cleanup(func() {
_ = os.RemoveAll(tmpDir)
})
// Try to dial the non-existent Unix socket over WebRTC
conn := setupAgent(t, agent.Metadata{}, 0)
netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock"))
require.Error(t, err)
require.ErrorContains(t, err, "remote dial error")
require.ErrorContains(t, err, "no such file")
require.Nil(t, netConn)
})
}
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
@ -303,3 +410,34 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
Conn: conn,
}
}
var dialTestPayload = []byte("dean-was-here123")
func testDial(t *testing.T, c net.Conn) {
t.Helper()
assertWritePayload(t, c, dialTestPayload)
assertReadPayload(t, c, dialTestPayload)
}
func testAccept(t *testing.T, c net.Conn) {
t.Helper()
defer c.Close()
assertReadPayload(t, c, dialTestPayload)
assertWritePayload(t, c, dialTestPayload)
}
func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
b := make([]byte, len(payload)+16)
n, err := r.Read(b)
require.NoError(t, err, "read payload")
require.Equal(t, len(payload), n, "read payload length does not match")
require.Equal(t, payload, b[:n])
}
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
n, err := w.Write(payload)
require.NoError(t, err, "write payload")
require.Equal(t, len(payload), n, "payload length does not match")
}

View File

@ -2,8 +2,11 @@ package agent
import (
"context"
"encoding/json"
"fmt"
"net"
"net/url"
"strings"
"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
@ -32,7 +35,7 @@ type Conn struct {
// ReconnectingPTY returns a connection serving a TTY that can
// be reconnected to via ID.
func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error) {
channel, err := c.Dial(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
Protocol: "reconnecting-pty",
})
if err != nil {
@ -43,7 +46,7 @@ func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error
// SSH dials the built-in SSH server.
func (c *Conn) SSH() (net.Conn, error) {
channel, err := c.Dial(context.Background(), "ssh", &peer.ChannelOptions{
channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{
Protocol: "ssh",
})
if err != nil {
@ -71,6 +74,42 @@ func (c *Conn) SSHClient() (*ssh.Client, error) {
return ssh.NewClient(sshConn, channels, requests), nil
}
// DialContext dials an arbitrary protocol+address from inside the workspace and
// proxies it through the provided net.Conn.
func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
u := &url.URL{
Scheme: network,
}
if strings.HasPrefix(network, "unix") {
u.Path = addr
} else {
u.Host = addr
}
channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{
Protocol: "dial",
Unordered: strings.HasPrefix(network, "udp"),
})
if err != nil {
return nil, xerrors.Errorf("create datachannel: %w", err)
}
// The first message written from the other side is a JSON payload
// containing the dial error.
dec := json.NewDecoder(channel)
var res dialResponse
err = dec.Decode(&res)
if err != nil {
return nil, xerrors.Errorf("failed to decode initial packet: %w", err)
}
if res.Error != "" {
_ = channel.Close()
return nil, xerrors.Errorf("remote dial error: %v", res.Error)
}
return channel.NetConn(), nil
}
func (c *Conn) Close() error {
_ = c.Negotiator.DRPCConn().Close()
return c.Conn.Close()

379
cli/portforward.go Normal file
View File

@ -0,0 +1,379 @@
package cli
import (
"context"
"fmt"
"net"
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
"github.com/pion/udp"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
coderagent "github.com/coder/coder/agent"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk"
)
func portForward() *cobra.Command {
var (
tcpForwards []string // <port>:<port>
udpForwards []string // <port>:<port>
unixForwards []string // <path>:<path> OR <port>:<path>
)
cmd := &cobra.Command{
Use: "port-forward <workspace>",
Aliases: []string{"tunnel"},
Args: cobra.ExactArgs(1),
Example: `
- Port forward a single TCP port from 1234 in the workspace to port 5678 on
your local machine
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --tcp 5678:1234") + `
- Port forward a single UDP port from port 9000 to port 9000 on your local
machine
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --udp 9000") + `
- Forward a Unix socket in the workspace to a local Unix socket
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --unix ./local.sock:~/remote.sock") + `
- Forward a Unix socket in the workspace to a local TCP port
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --unix 8080:~/remote.sock") + `
- Port forward multiple TCP ports and a UDP port
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --tcp 8080:8080 --tcp 9000:3000 --udp 5353:53"),
RunE: func(cmd *cobra.Command, args []string) error {
specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards)
if err != nil {
return xerrors.Errorf("parse port-forward specs: %w", err)
}
if len(specs) == 0 {
err = cmd.Help()
if err != nil {
return xerrors.Errorf("generate help output: %w", err)
}
return xerrors.New("no port-forwards requested")
}
client, err := createClient(cmd)
if err != nil {
return err
}
organization, err := currentOrganization(cmd, client)
if err != nil {
return err
}
workspace, agent, err := getWorkspaceAndAgent(cmd, client, organization.ID, codersdk.Me, args[0], false)
if err != nil {
return err
}
if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart {
return xerrors.New("workspace must be in start transition to port-forward")
}
if workspace.LatestBuild.Job.CompletedAt == nil {
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
if err != nil {
return err
}
}
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, agent.ID)
},
})
if err != nil {
return xerrors.Errorf("await agent: %w", err)
}
conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil)
if err != nil {
return xerrors.Errorf("dial workspace agent: %w", err)
}
defer conn.Close()
// Start all listeners.
var (
ctx, cancel = context.WithCancel(cmd.Context())
wg = new(sync.WaitGroup)
listeners = make([]net.Listener, len(specs))
closeAllListeners = func() {
for _, l := range listeners {
if l == nil {
continue
}
_ = l.Close()
}
}
)
defer cancel()
for i, spec := range specs {
l, err := listenAndPortForward(ctx, cmd, conn, wg, spec)
if err != nil {
closeAllListeners()
return err
}
listeners[i] = l
}
// Wait for the context to be canceled or for a signal and close
// all listeners.
var closeErr error
go func() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
select {
case <-ctx.Done():
closeErr = ctx.Err()
case <-sigs:
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Received signal, closing all listeners and active connections")
closeErr = xerrors.New("signal received")
}
cancel()
closeAllListeners()
}()
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Ready!")
wg.Wait()
return closeErr
},
}
cmd.Flags().StringArrayVarP(&tcpForwards, "tcp", "p", []string{}, "Forward a TCP port from the workspace to the local machine")
cmd.Flags().StringArrayVar(&udpForwards, "udp", []string{}, "Forward a UDP port from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols")
cmd.Flags().StringArrayVar(&unixForwards, "unix", []string{}, "Forward a Unix socket in the workspace to a local Unix socket or TCP port")
return cmd
}
func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *coderagent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) {
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
var (
l net.Listener
err error
)
switch spec.listenNetwork {
case "tcp":
l, err = net.Listen(spec.listenNetwork, spec.listenAddress)
case "udp":
var host, port string
host, port, err = net.SplitHostPort(spec.listenAddress)
if err != nil {
return nil, xerrors.Errorf("split %q: %w", spec.listenAddress, err)
}
var portInt int
portInt, err = strconv.Atoi(port)
if err != nil {
return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err)
}
l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{
IP: net.ParseIP(host),
Port: portInt,
})
case "unix":
l, err = net.Listen(spec.listenNetwork, spec.listenAddress)
default:
return nil, xerrors.Errorf("unknown listen network %q", spec.listenNetwork)
}
if err != nil {
return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err)
}
wg.Add(1)
go func(spec portForwardSpec) {
defer wg.Done()
for {
netConn, err := l.Accept()
if err != nil {
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err)
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Killing listener")
return
}
go func(netConn net.Conn) {
defer netConn.Close()
remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress)
if err != nil {
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
return
}
defer remoteConn.Close()
coderagent.Bicopy(ctx, netConn, remoteConn)
}(netConn)
}
}(spec)
return l, nil
}
type portForwardSpec struct {
listenNetwork string // tcp, udp, unix
listenAddress string // <ip>:<port> or path
dialNetwork string // tcp, udp, unix
dialAddress string // <ip>:<port> or path
}
func parsePortForwards(tcpSpecs, udpSpecs, unixSpecs []string) ([]portForwardSpec, error) {
specs := []portForwardSpec{}
for _, spec := range tcpSpecs {
local, remote, err := parsePortPort(spec)
if err != nil {
return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err)
}
specs = append(specs, portForwardSpec{
listenNetwork: "tcp",
listenAddress: fmt.Sprintf("127.0.0.1:%v", local),
dialNetwork: "tcp",
dialAddress: fmt.Sprintf("127.0.0.1:%v", remote),
})
}
for _, spec := range udpSpecs {
local, remote, err := parsePortPort(spec)
if err != nil {
return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err)
}
specs = append(specs, portForwardSpec{
listenNetwork: "udp",
listenAddress: fmt.Sprintf("127.0.0.1:%v", local),
dialNetwork: "udp",
dialAddress: fmt.Sprintf("127.0.0.1:%v", remote),
})
}
for _, specStr := range unixSpecs {
localPath, localTCP, remotePath, err := parseUnixUnix(specStr)
if err != nil {
return nil, xerrors.Errorf("failed to parse Unix port-forward specification %q: %w", specStr, err)
}
spec := portForwardSpec{
dialNetwork: "unix",
dialAddress: remotePath,
}
if localPath == "" {
spec.listenNetwork = "tcp"
spec.listenAddress = fmt.Sprintf("127.0.0.1:%v", localTCP)
} else {
if runtime.GOOS == "windows" {
return nil, xerrors.Errorf("Unix port-forwarding is not supported on Windows")
}
spec.listenNetwork = "unix"
spec.listenAddress = localPath
}
specs = append(specs, spec)
}
// Check for duplicate entries.
locals := map[string]struct{}{}
for _, spec := range specs {
localStr := fmt.Sprintf("%v:%v", spec.listenNetwork, spec.listenAddress)
if _, ok := locals[localStr]; ok {
return nil, xerrors.Errorf("local %v %v is specified twice", spec.listenNetwork, spec.listenAddress)
}
locals[localStr] = struct{}{}
}
return specs, nil
}
func parsePort(in string) (uint16, error) {
port, err := strconv.ParseUint(strings.TrimSpace(in), 10, 16)
if err != nil {
return 0, xerrors.Errorf("parse port %q: %w", in, err)
}
if port == 0 {
return 0, xerrors.New("port cannot be 0")
}
return uint16(port), nil
}
func parseUnixPath(in string) (string, error) {
path, err := coderagent.ExpandRelativeHomePath(strings.TrimSpace(in))
if err != nil {
return "", xerrors.Errorf("tidy path %q: %w", in, err)
}
return path, nil
}
func parsePortPort(in string) (local uint16, remote uint16, err error) {
parts := strings.Split(in, ":")
if len(parts) > 2 {
return 0, 0, xerrors.Errorf("invalid port specification %q", in)
}
if len(parts) == 1 {
// Duplicate the single part
parts = append(parts, parts[0])
}
local, err = parsePort(parts[0])
if err != nil {
return 0, 0, xerrors.Errorf("parse local port from %q: %w", in, err)
}
remote, err = parsePort(parts[1])
if err != nil {
return 0, 0, xerrors.Errorf("parse remote port from %q: %w", in, err)
}
return local, remote, nil
}
func parsePortOrUnixPath(in string) (string, uint16, error) {
port, err := parsePort(in)
if err == nil {
return "", port, nil
}
path, err := parseUnixPath(in)
if err != nil {
return "", 0, xerrors.Errorf("could not parse port or unix path %q: %w", in, err)
}
return path, 0, nil
}
func parseUnixUnix(in string) (string, uint16, string, error) {
parts := strings.Split(in, ":")
if len(parts) > 2 {
return "", 0, "", xerrors.Errorf("invalid port-forward specification %q", in)
}
if len(parts) == 1 {
// Duplicate the single part
parts = append(parts, parts[0])
}
localPath, localPort, err := parsePortOrUnixPath(parts[0])
if err != nil {
return "", 0, "", xerrors.Errorf("parse local part of spec %q: %w", in, err)
}
// We don't really touch the remote path at all since it gets cleaned
// up/expanded on the remote.
return localPath, localPort, parts[1], nil
}

532
cli/portforward_test.go Normal file
View File

@ -0,0 +1,532 @@
package cli_test
import (
"bytes"
"context"
"fmt"
"io"
"net"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/pion/udp"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
)
func TestPortForward(t *testing.T) {
t.Parallel()
t.Run("None", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
cmd, root := clitest.New(t, "port-forward", "blah")
clitest.SetupConfig(t, client, root)
buf := newThreadSafeBuffer()
cmd.SetOut(buf)
err := cmd.Execute()
require.Error(t, err)
require.ErrorContains(t, err, "no port-forwards")
// Check that the help was printed.
require.Contains(t, buf.String(), "port-forward <workspace>")
})
cases := []struct {
name string
network string
// The flag to pass to `coder port-forward X` to port-forward this type
// of connection. Has two format args (both strings), the first is the
// local address and the second is the remote address.
flag string
// setupRemote creates a "remote" listener to emulate a service in the
// workspace.
setupRemote func(t *testing.T) net.Listener
// setupLocal returns an available port or Unix socket path that the
// port-forward command will listen on "locally". Returns the address
// you pass to net.Dial, and the port/path you pass to `coder
// port-forward`.
setupLocal func(t *testing.T) (string, string)
}{
{
name: "TCP",
network: "tcp",
flag: "--tcp=%v:%v",
setupRemote: func(t *testing.T) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener")
return l
},
setupLocal: func(t *testing.T) (string, string) {
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener to generate random port")
defer l.Close()
_, port, err := net.SplitHostPort(l.Addr().String())
require.NoErrorf(t, err, "split TCP address %q", l.Addr().String())
return l.Addr().String(), port
},
},
{
name: "UDP",
network: "udp",
flag: "--udp=%v:%v",
setupRemote: func(t *testing.T) net.Listener {
addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
}
l, err := udp.Listen("udp", &addr)
require.NoError(t, err, "create UDP listener")
return l
},
setupLocal: func(t *testing.T) (string, string) {
addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
}
l, err := udp.Listen("udp", &addr)
require.NoError(t, err, "create UDP listener to generate random port")
defer l.Close()
_, port, err := net.SplitHostPort(l.Addr().String())
require.NoErrorf(t, err, "split UDP address %q", l.Addr().String())
return l.Addr().String(), port
},
},
{
name: "Unix",
network: "unix",
flag: "--unix=%v:%v",
setupRemote: func(t *testing.T) net.Listener {
if runtime.GOOS == "windows" {
t.Skip("Unix socket forwarding isn't supported on Windows")
}
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
require.NoError(t, err, "create temp dir for unix listener")
t.Cleanup(func() {
_ = os.RemoveAll(tmpDir)
})
l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
require.NoError(t, err, "create UDP listener")
return l
},
setupLocal: func(t *testing.T) (string, string) {
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
require.NoError(t, err, "create temp dir for unix listener")
t.Cleanup(func() {
_ = os.RemoveAll(tmpDir)
})
path := filepath.Join(tmpDir, "test.sock")
return path, path
},
},
}
for _, c := range cases { //nolint:paralleltest // the `c := c` confuses the linter
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()
t.Run("OnePort", func(t *testing.T) {
t.Parallel()
var (
client = coderdtest.New(t, nil)
user = coderdtest.CreateFirstUser(t, client)
_, workspace = runAgent(t, client, user.UserID)
l1, p1 = setupTestListener(t, c.setupRemote(t))
)
t.Cleanup(func() {
_ = l1.Close()
})
// Create a flag that forwards from local to listener 1.
localAddress, localFlag := c.setupLocal(t)
flag := fmt.Sprintf(c.flag, localFlag, p1)
// Launch port-forward in a goroutine so we can start dialing
// the "local" listener.
cmd, root := clitest.New(t, "port-forward", workspace.Name, flag)
clitest.SetupConfig(t, client, root)
buf := newThreadSafeBuffer()
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
err := cmd.ExecuteContext(ctx)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
}()
waitForPortForwardReady(t, buf)
// Open two connections simultaneously and test them out of
// sync.
d := net.Dialer{Timeout: 3 * time.Second}
c1, err := d.DialContext(ctx, c.network, localAddress)
require.NoError(t, err, "open connection 1 to 'local' listener")
defer c1.Close()
c2, err := d.DialContext(ctx, c.network, localAddress)
require.NoError(t, err, "open connection 2 to 'local' listener")
defer c2.Close()
testDial(t, c2)
testDial(t, c1)
})
t.Run("TwoPorts", func(t *testing.T) {
t.Parallel()
var (
client = coderdtest.New(t, nil)
user = coderdtest.CreateFirstUser(t, client)
_, workspace = runAgent(t, client, user.UserID)
l1, p1 = setupTestListener(t, c.setupRemote(t))
l2, p2 = setupTestListener(t, c.setupRemote(t))
)
t.Cleanup(func() {
_ = l1.Close()
_ = l2.Close()
})
// Create a flags for listener 1 and listener 2.
localAddress1, localFlag1 := c.setupLocal(t)
localAddress2, localFlag2 := c.setupLocal(t)
flag1 := fmt.Sprintf(c.flag, localFlag1, p1)
flag2 := fmt.Sprintf(c.flag, localFlag2, p2)
// Launch port-forward in a goroutine so we can start dialing
// the "local" listeners.
cmd, root := clitest.New(t, "port-forward", workspace.Name, flag1, flag2)
clitest.SetupConfig(t, client, root)
buf := newThreadSafeBuffer()
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
err := cmd.ExecuteContext(ctx)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
}()
waitForPortForwardReady(t, buf)
// Open a connection to both listener 1 and 2 simultaneously and
// then test them out of order.
d := net.Dialer{Timeout: 3 * time.Second}
c1, err := d.DialContext(ctx, c.network, localAddress1)
require.NoError(t, err, "open connection 1 to 'local' listener 1")
defer c1.Close()
c2, err := d.DialContext(ctx, c.network, localAddress2)
require.NoError(t, err, "open connection 2 to 'local' listener 2")
defer c2.Close()
testDial(t, c2)
testDial(t, c1)
})
})
}
// Test doing a TCP -> Unix forward.
t.Run("TCP2Unix", func(t *testing.T) {
t.Parallel()
var (
client = coderdtest.New(t, nil)
user = coderdtest.CreateFirstUser(t, client)
_, workspace = runAgent(t, client, user.UserID)
// Find the TCP and Unix cases so we can use their setupLocal and
// setupRemote methods respectively.
tcpCase = cases[0]
unixCase = cases[2]
// Setup remote Unix listener.
l1, p1 = setupTestListener(t, unixCase.setupRemote(t))
)
t.Cleanup(func() {
_ = l1.Close()
})
// Create a flag that forwards from local TCP to Unix listener 1.
// Notably this is a --unix flag.
localAddress, localFlag := tcpCase.setupLocal(t)
flag := fmt.Sprintf(unixCase.flag, localFlag, p1)
// Launch port-forward in a goroutine so we can start dialing
// the "local" listener.
cmd, root := clitest.New(t, "port-forward", workspace.Name, flag)
clitest.SetupConfig(t, client, root)
buf := newThreadSafeBuffer()
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
err := cmd.ExecuteContext(ctx)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
}()
waitForPortForwardReady(t, buf)
// Open two connections simultaneously and test them out of
// sync.
d := net.Dialer{Timeout: 3 * time.Second}
c1, err := d.DialContext(ctx, tcpCase.network, localAddress)
require.NoError(t, err, "open connection 1 to 'local' listener")
defer c1.Close()
c2, err := d.DialContext(ctx, tcpCase.network, localAddress)
require.NoError(t, err, "open connection 2 to 'local' listener")
defer c2.Close()
testDial(t, c2)
testDial(t, c1)
})
// Test doing TCP, UDP and Unix at the same time.
t.Run("All", func(t *testing.T) {
t.Parallel()
var (
client = coderdtest.New(t, nil)
user = coderdtest.CreateFirstUser(t, client)
_, workspace = runAgent(t, client, user.UserID)
// These aren't fixed size because we exclude Unix on Windows.
dials = []addr{}
flags = []string{}
)
// Start listeners and populate arrays with the cases.
for _, c := range cases {
if strings.HasPrefix(c.network, "unix") && runtime.GOOS == "windows" {
// Unix isn't supported on Windows, but we can still
// test other protocols together.
continue
}
l, p := setupTestListener(t, c.setupRemote(t))
t.Cleanup(func() {
_ = l.Close()
})
localAddress, localFlag := c.setupLocal(t)
dials = append(dials, addr{
network: c.network,
addr: localAddress,
})
flags = append(flags, fmt.Sprintf(c.flag, localFlag, p))
}
// Launch port-forward in a goroutine so we can start dialing
// the "local" listeners.
cmd, root := clitest.New(t, append([]string{"port-forward", workspace.Name}, flags...)...)
clitest.SetupConfig(t, client, root)
buf := newThreadSafeBuffer()
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
err := cmd.ExecuteContext(ctx)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
}()
waitForPortForwardReady(t, buf)
// Open connections to all items in the "dial" array.
var (
d = net.Dialer{Timeout: 3 * time.Second}
conns = make([]net.Conn, len(dials))
)
for i, a := range dials {
c, err := d.DialContext(ctx, a.network, a.addr)
require.NoErrorf(t, err, "open connection %v to 'local' listener %v", i+1, i+1)
t.Cleanup(func() {
_ = c.Close()
})
conns[i] = c
}
// Test each connection in reverse order.
for i := len(conns) - 1; i >= 0; i-- {
testDial(t, conns[i])
}
})
}
// runAgent creates a fake workspace and starts an agent locally for that
// workspace. The agent will be cleaned up on test completion.
func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]codersdk.WorkspaceResource, codersdk.Workspace) {
ctx := context.Background()
user, err := client.User(ctx, userID.String())
require.NoError(t, err, "specified user does not exist")
require.Greater(t, len(user.OrganizationIDs), 0, "user has no organizations")
orgID := user.OrganizationIDs[0]
// Setup echo provisioner
agentToken := uuid.NewString()
coderdtest.NewProvisionerDaemon(t, client)
version := coderdtest.CreateTemplateVersion(t, client, orgID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "somename",
Type: "someinstance",
Agents: []*proto.Agent{{
Auth: &proto.Agent_Token{
Token: agentToken,
},
}},
}},
},
},
}},
})
// Create template and workspace
template := coderdtest.CreateTemplate(t, client, orgID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, orgID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
// Start workspace agent in a goroutine
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String())
clitest.SetupConfig(t, client, root)
agentCtx, agentCancel := context.WithCancel(ctx)
t.Cleanup(agentCancel)
go func() {
err := cmd.ExecuteContext(agentCtx)
require.NoError(t, err)
}()
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
resources, err := client.WorkspaceResourcesByBuild(context.Background(), workspace.LatestBuild.ID)
require.NoError(t, err)
return resources, workspace
}
// setupTestListener starts accepting connections and echoing a single packet.
// Returns the listener and the listen port or Unix path.
func setupTestListener(t *testing.T, l net.Listener) (net.Listener, string) {
t.Cleanup(func() {
_ = l.Close()
})
go func() {
for {
c, err := l.Accept()
if err != nil {
return
}
go testAccept(t, c)
}
}()
addr := l.Addr().String()
if !strings.HasPrefix(l.Addr().Network(), "unix") {
_, port, err := net.SplitHostPort(addr)
require.NoErrorf(t, err, "split non-Unix listen path %q", addr)
addr = port
}
return l, addr
}
var dialTestPayload = []byte("dean-was-here123")
func testDial(t *testing.T, c net.Conn) {
t.Helper()
assertWritePayload(t, c, dialTestPayload)
assertReadPayload(t, c, dialTestPayload)
}
func testAccept(t *testing.T, c net.Conn) {
t.Helper()
defer c.Close()
assertReadPayload(t, c, dialTestPayload)
assertWritePayload(t, c, dialTestPayload)
}
func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
b := make([]byte, len(payload)+16)
n, err := r.Read(b)
require.NoError(t, err, "read payload")
require.Equal(t, len(payload), n, "read payload length does not match")
require.Equal(t, payload, b[:n])
}
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
n, err := w.Write(payload)
require.NoError(t, err, "write payload")
require.Equal(t, len(payload), n, "payload length does not match")
}
func waitForPortForwardReady(t *testing.T, output *threadSafeBuffer) {
for i := 0; i < 100; i++ {
time.Sleep(250 * time.Millisecond)
data := output.String()
if strings.Contains(data, "Ready!") {
return
}
}
t.Fatal("port-forward command did not become ready in time")
}
type addr struct {
network string
addr string
}
type threadSafeBuffer struct {
b *bytes.Buffer
mut *sync.RWMutex
}
func newThreadSafeBuffer() *threadSafeBuffer {
return &threadSafeBuffer{
b: bytes.NewBuffer(nil),
mut: new(sync.RWMutex),
}
}
var _ io.Reader = &threadSafeBuffer{}
var _ io.Writer = &threadSafeBuffer{}
// Read implements io.Reader.
func (b *threadSafeBuffer) Read(p []byte) (int, error) {
b.mut.RLock()
defer b.mut.RUnlock()
return b.b.Read(p)
}
// Write implements io.Writer.
func (b *threadSafeBuffer) Write(p []byte) (int, error) {
b.mut.Lock()
defer b.mut.Unlock()
return b.b.Write(p)
}
func (b *threadSafeBuffer) String() string {
b.mut.RLock()
defer b.mut.RUnlock()
return b.b.String()
}

View File

@ -36,6 +36,15 @@ const (
varForceTty = "force-tty"
)
func init() {
// Customizes the color of headings to make subcommands more visually
// appealing.
header := cliui.Styles.Placeholder
cobra.AddTemplateFunc("usageHeader", func(s string) string {
return header.Render(s)
})
}
func Root() *cobra.Command {
cmd := &cobra.Command{
Use: "coder",
@ -71,7 +80,7 @@ func Root() *cobra.Command {
templates(),
update(),
users(),
tunnel(),
portForward(),
workspaceAgent(),
)
@ -179,13 +188,7 @@ func isTTY(cmd *cobra.Command) bool {
}
func usageTemplate() string {
// Customizes the color of headings to make subcommands
// more visually appealing.
header := cliui.Styles.Placeholder
cobra.AddTemplateFunc("usageHeader", func(s string) string {
return header.Render(s)
})
// usageHeader is defined in init().
return `{{usageHeader "Usage:"}}
{{- if .Runnable}}
{{.UseLine}}

View File

@ -50,94 +50,23 @@ func ssh() *cobra.Command {
return err
}
var workspace codersdk.Workspace
var workspaceParts []string
if shuffle {
err := cobra.ExactArgs(0)(cmd, args)
if err != nil {
return err
}
workspaces, err := client.WorkspacesByOwner(cmd.Context(), organization.ID, codersdk.Me)
if err != nil {
return err
}
if len(workspaces) == 0 {
return xerrors.New("no workspaces to shuffle")
}
idx, err := cryptorand.Intn(len(workspaces))
if err != nil {
return err
}
workspace = workspaces[idx]
} else {
err := cobra.MinimumNArgs(1)(cmd, args)
if err != nil {
return err
}
workspaceParts = strings.Split(args[0], ".")
workspace, err = client.WorkspaceByOwnerAndName(cmd.Context(), organization.ID, codersdk.Me, workspaceParts[0])
if err != nil {
return err
}
}
if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart {
return xerrors.New("workspace must be in start transition to ssh")
}
if workspace.LatestBuild.Job.CompletedAt == nil {
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
if err != nil {
return err
}
}
if workspace.LatestBuild.Transition == database.WorkspaceTransitionDelete {
return xerrors.New("workspace is deleting...")
}
resources, err := client.WorkspaceResourcesByBuild(cmd.Context(), workspace.LatestBuild.ID)
workspace, agent, err := getWorkspaceAndAgent(cmd, client, organization.ID, codersdk.Me, args[0], shuffle)
if err != nil {
return err
}
agents := make([]codersdk.WorkspaceAgent, 0)
for _, resource := range resources {
agents = append(agents, resource.Agents...)
}
if len(agents) == 0 {
return xerrors.New("workspace has no agents")
}
var agent codersdk.WorkspaceAgent
if len(workspaceParts) >= 2 {
for _, otherAgent := range agents {
if otherAgent.Name != workspaceParts[1] {
continue
}
agent = otherAgent
break
}
if agent.ID == uuid.Nil {
return xerrors.Errorf("agent not found by name %q", workspaceParts[1])
}
}
if agent.ID == uuid.Nil {
if len(agents) > 1 {
if !shuffle {
return xerrors.New("you must specify the name of an agent")
}
idx, err := cryptorand.Intn(len(agents))
if err != nil {
return err
}
agent = agents[idx]
} else {
agent = agents[0]
}
}
// OpenSSH passes stderr directly to the calling TTY.
// This is required in "stdio" mode so a connecting indicator can be displayed.
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
@ -233,6 +162,92 @@ func ssh() *cobra.Command {
return cmd
}
// getWorkspaceAgent returns the workspace and agent selected using either the
// `<workspace>[.<agent>]` syntax via `in` or picks a random workspace and agent
// if `shuffle` is true.
func getWorkspaceAndAgent(cmd *cobra.Command, client *codersdk.Client, orgID uuid.UUID, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive
ctx := cmd.Context()
var (
workspace codersdk.Workspace
workspaceParts = strings.Split(in, ".")
err error
)
if shuffle {
workspaces, err := client.WorkspacesByOwner(cmd.Context(), orgID, userID)
if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
}
if len(workspaces) == 0 {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("no workspaces to shuffle")
}
workspace, err = cryptorand.Element(workspaces)
if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
}
} else {
workspace, err = client.WorkspaceByOwnerAndName(cmd.Context(), orgID, userID, workspaceParts[0])
if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
}
}
if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("workspace must be in start transition to ssh")
}
if workspace.LatestBuild.Job.CompletedAt == nil {
err := cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
}
}
if workspace.LatestBuild.Transition == database.WorkspaceTransitionDelete {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q is being deleted", workspace.Name)
}
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("fetch workspace resources: %w", err)
}
agents := make([]codersdk.WorkspaceAgent, 0)
for _, resource := range resources {
agents = append(agents, resource.Agents...)
}
if len(agents) == 0 {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q has no agents", workspace.Name)
}
var agent codersdk.WorkspaceAgent
if len(workspaceParts) >= 2 {
for _, otherAgent := range agents {
if otherAgent.Name != workspaceParts[1] {
continue
}
agent = otherAgent
break
}
if agent.ID == uuid.Nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("agent not found by name %q", workspaceParts[1])
}
}
if agent.ID == uuid.Nil {
if len(agents) > 1 {
if !shuffle {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("you must specify the name of an agent")
}
agent, err = cryptorand.Element(agents)
if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
}
} else {
agent = agents[0]
}
}
return workspace, agent, nil
}
// Attempt to poll workspace autostop. We write a per-workspace lockfile to
// avoid spamming the user with notifications in case of multiple instances
// of the CLI running simultaneously.

View File

@ -1,8 +1,9 @@
package cli
import (
"github.com/fatih/color"
"github.com/spf13/cobra"
"github.com/coder/coder/cli/cliui"
)
func templates() *cobra.Command {
@ -13,15 +14,15 @@ func templates() *cobra.Command {
Example: `
- Create a template for developers to create workspaces
` + color.New(color.FgHiMagenta).Sprint("$ coder templates create") + `
` + cliui.Styles.Code.Render("$ coder templates create") + `
- Make changes to your template, and plan the changes
` + color.New(color.FgHiMagenta).Sprint("$ coder templates plan <name>") + `
` + cliui.Styles.Code.Render("$ coder templates plan <name>") + `
- Update the template. Your developers can update their workspaces
` + color.New(color.FgHiMagenta).Sprint("$ coder templates update <name>"),
` + cliui.Styles.Code.Render("$ coder templates update <name>"),
}
cmd.AddCommand(
templateCreate(),

View File

@ -1,14 +0,0 @@
package cli
import "github.com/spf13/cobra"
func tunnel() *cobra.Command {
return &cobra.Command{
Annotations: workspaceCommand,
Use: "tunnel",
Short: "Forward ports to your local machine",
RunE: func(cmd *cobra.Command, args []string) error {
return nil
},
}
}

20
cryptorand/slices.go Normal file
View File

@ -0,0 +1,20 @@
package cryptorand
import (
"golang.org/x/xerrors"
)
// Element returns a random element of the slice. An error will be returned if
// the slice has no elements in it.
func Element[T any](s []T) (out T, err error) {
if len(s) == 0 {
return out, xerrors.New("slice must have at least one element")
}
i, err := Intn(len(s))
if err != nil {
return out, xerrors.Errorf("generate random integer from 0-%v: %w", len(s), err)
}
return s[i], nil
}

56
cryptorand/slices_test.go Normal file
View File

@ -0,0 +1,56 @@
package cryptorand_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cryptorand"
)
func TestRandomElement(t *testing.T) {
t.Parallel()
t.Run("Empty", func(t *testing.T) {
t.Parallel()
s := []string{}
v, err := cryptorand.Element(s)
require.Error(t, err)
require.ErrorContains(t, err, "slice must have at least one element")
require.Empty(t, v)
})
t.Run("OK", func(t *testing.T) {
t.Parallel()
// Generate random slices of ints and strings
var (
ints = make([]int, 20)
strings = make([]string, 20)
)
for i := range ints {
v, err := cryptorand.Intn(1024)
require.NoError(t, err, "generate random int for test slice")
ints[i] = v
}
for i := range strings {
v, err := cryptorand.String(10)
require.NoError(t, err, "generate random string for test slice")
strings[i] = v
}
// Get a random value from each 20 times.
for i := 0; i < 20; i++ {
iv, err := cryptorand.Element(ints)
require.NoError(t, err, "unexpected error from Element(ints)")
t.Logf("random int slice element: %v", iv)
require.Contains(t, ints, iv)
sv, err := cryptorand.Element(strings)
require.NoError(t, err, "unexpected error from Element(strings)")
t.Logf("random string slice element: %v", sv)
require.Contains(t, strings, sv)
}
})
}

2
go.mod
View File

@ -90,6 +90,7 @@ require (
github.com/pion/logging v0.2.2
github.com/pion/transport v0.13.0
github.com/pion/turn/v2 v2.0.8
github.com/pion/udp v0.1.1
github.com/pion/webrtc/v3 v3.1.39
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8
github.com/pkg/sftp v1.13.4
@ -217,7 +218,6 @@ require (
github.com/pion/sdp/v3 v3.0.5 // indirect
github.com/pion/srtp/v2 v2.0.7 // indirect
github.com/pion/stun v0.3.5 // indirect
github.com/pion/udp v0.1.1 // indirect
github.com/pires/go-proxyproto v0.6.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect

View File

@ -53,8 +53,8 @@ type ChannelOptions struct {
// Arbitrary string that can be parsed on `Accept`.
Protocol string
// Ordered determines whether the channel acts like
// a TCP connection. Defaults to false.
// Unordered determines whether the channel acts like
// a UDP connection. Defaults to false.
Unordered bool
// Whether the channel will be left open on disconnect or not.

View File

@ -68,7 +68,7 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp
closed: make(chan struct{}),
closedRTC: make(chan struct{}),
closedICE: make(chan struct{}),
dcOpenChannel: make(chan *webrtc.DataChannel),
dcOpenChannel: make(chan *webrtc.DataChannel, 8),
dcDisconnectChannel: make(chan struct{}),
dcFailedChannel: make(chan struct{}),
localCandidateChannel: make(chan webrtc.ICECandidateInit),
@ -264,12 +264,13 @@ func (c *Conn) init() error {
}()
})
c.rtc.OnDataChannel(func(dc *webrtc.DataChannel) {
select {
case <-c.closed:
return
case c.dcOpenChannel <- dc:
default:
}
go func() {
select {
case <-c.closed:
return
case c.dcOpenChannel <- dc:
}
}()
})
_, err := c.pingChannel()
if err != nil {
@ -469,8 +470,8 @@ func (c *Conn) Accept(ctx context.Context) (*Channel, error) {
return newChannel(c, dataChannel, &ChannelOptions{}), nil
}
// Dial creates a new DataChannel.
func (c *Conn) Dial(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
// CreateChannel creates a new DataChannel.
func (c *Conn) CreateChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
if opts == nil {
opts = &ChannelOptions{}
}

View File

@ -90,7 +90,7 @@ func TestConn(t *testing.T) {
_, err := server.Ping()
require.NoError(t, err)
// Create a channel that closes on disconnect.
channel, err := server.Dial(context.Background(), "wow", nil)
channel, err := server.CreateChannel(context.Background(), "wow", nil)
assert.NoError(t, err)
err = wan.Stop()
require.NoError(t, err)
@ -108,7 +108,7 @@ func TestConn(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
@ -124,7 +124,7 @@ func TestConn(t *testing.T) {
t.Parallel()
client, server, wan := createPair(t)
exchange(t, client, server)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
require.NoError(t, err)
@ -141,7 +141,7 @@ func TestConn(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(t, client, server)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
require.NoError(t, err)
@ -196,7 +196,7 @@ func TestConn(t *testing.T) {
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
var cch *peer.Channel
defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
cch, err = client.Dial(ctx, "hello", &peer.ChannelOptions{})
cch, err = client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
if err != nil {
return nil, err
}
@ -234,7 +234,7 @@ func TestConn(t *testing.T) {
require.NoError(t, err)
expectedErr := xerrors.New("wow")
_ = conn.CloseWithError(expectedErr)
_, err = conn.Dial(context.Background(), "", nil)
_, err = conn.CreateChannel(context.Background(), "", nil)
require.ErrorIs(t, err, expectedErr)
})
@ -274,7 +274,7 @@ func TestConn(t *testing.T) {
client, server, _ := createPair(t)
exchange(t, client, server)
go func() {
channel, err := client.Dial(context.Background(), "test", nil)
channel, err := client.CreateChannel(context.Background(), "test", nil)
require.NoError(t, err)
_, err = channel.Write([]byte{1, 2})
require.NoError(t, err)