diff --git a/agent/agent.go b/agent/agent.go index b946166056..75787b4cfc 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net" + "net/url" "os" "os/exec" "os/user" @@ -211,6 +212,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) { go a.sshServer.HandleConn(channel.NetConn()) case "reconnecting-pty": go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn()) + case "dial": + go a.handleDial(ctx, channel.Label(), channel.NetConn()) default: a.logger.Warn(ctx, "unhandled protocol from channel", slog.F("protocol", channel.Protocol()), @@ -617,6 +620,70 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne } } +// dialResponse is written to datachannels with protocol "dial" by the agent as +// the first packet to signify whether the dial succeeded or failed. +type dialResponse struct { + Error string `json:"error,omitempty"` +} + +func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) { + defer conn.Close() + + writeError := func(responseError error) error { + msg := "" + if responseError != nil { + msg = responseError.Error() + if !xerrors.Is(responseError, io.EOF) { + a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError)) + } + } + b, err := json.Marshal(dialResponse{ + Error: msg, + }) + if err != nil { + a.logger.Warn(ctx, "write dial response", slog.F("label", label), slog.Error(err)) + return xerrors.Errorf("marshal agent webrtc dial response: %w", err) + } + + _, err = conn.Write(b) + return err + } + + u, err := url.Parse(label) + if err != nil { + _ = writeError(xerrors.Errorf("parse URL %q: %w", label, err)) + return + } + + network := u.Scheme + addr := u.Host + u.Path + if strings.HasPrefix(network, "unix") { + if runtime.GOOS == "windows" { + _ = writeError(xerrors.New("Unix forwarding is not supported from Windows workspaces")) + return + } + addr, err = ExpandRelativeHomePath(addr) + if err != nil { + _ = writeError(xerrors.Errorf("expand path %q: %w", addr, err)) + return + } + } + + d := net.Dialer{Timeout: 3 * time.Second} + nconn, err := d.DialContext(ctx, network, addr) + if err != nil { + _ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err)) + return + } + + err = writeError(nil) + if err != nil { + return + } + + Bicopy(ctx, conn, nconn) +} + // isClosed returns whether the API is closed or not. func (a *agent) isClosed() bool { select { @@ -662,3 +729,50 @@ func (r *reconnectingPTY) Close() { r.circularBuffer.Reset() r.timeout.Stop() } + +// Bicopy copies all of the data between the two connections and will close them +// after one or both of them are done writing. If the context is canceled, both +// of the connections will be closed. +func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { + defer c1.Close() + defer c2.Close() + + var wg sync.WaitGroup + copyFunc := func(dst io.WriteCloser, src io.Reader) { + defer wg.Done() + _, _ = io.Copy(dst, src) + } + + wg.Add(2) + go copyFunc(c1, c2) + go copyFunc(c2, c1) + + // Convert waitgroup to a channel so we can also wait on the context. + done := make(chan struct{}) + go func() { + defer close(done) + wg.Wait() + }() + + select { + case <-ctx.Done(): + case <-done: + } +} + +// ExpandRelativeHomePath expands the tilde at the beginning of a path to the +// current user's home directory and returns a full absolute path. +func ExpandRelativeHomePath(in string) (string, error) { + usr, err := user.Current() + if err != nil { + return "", xerrors.Errorf("get current user details: %w", err) + } + + if in == "~" { + in = usr.HomeDir + } else if strings.HasPrefix(in, "~/") { + in = filepath.Join(usr.HomeDir, in[2:]) + } + + return filepath.Abs(in) +} diff --git a/agent/agent_test.go b/agent/agent_test.go index bd26fae7f0..2a8a2224ab 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -16,6 +16,7 @@ import ( "time" "github.com/google/uuid" + "github.com/pion/udp" "github.com/pion/webrtc/v3" "github.com/pkg/sftp" "github.com/stretchr/testify/require" @@ -234,6 +235,112 @@ func TestAgent(t *testing.T) { findEcho() findEcho() }) + + t.Run("Dial", func(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + setup func(t *testing.T) net.Listener + }{ + { + name: "TCP", + setup: func(t *testing.T) net.Listener { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "create TCP listener") + return l + }, + }, + { + name: "UDP", + setup: func(t *testing.T) net.Listener { + addr := net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + } + l, err := udp.Listen("udp", &addr) + require.NoError(t, err, "create UDP listener") + return l + }, + }, + { + name: "Unix", + setup: func(t *testing.T) net.Listener { + if runtime.GOOS == "windows" { + t.Skip("Unix socket forwarding isn't supported on Windows") + } + + tmpDir, err := os.MkdirTemp("", "coderd_agent_test_") + require.NoError(t, err, "create temp dir for unix listener") + t.Cleanup(func() { + _ = os.RemoveAll(tmpDir) + }) + + l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock")) + require.NoError(t, err, "create UDP listener") + return l + }, + }, + } + + for _, c := range cases { + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + // Setup listener + l := c.setup(t) + defer l.Close() + go func() { + for { + c, err := l.Accept() + if err != nil { + return + } + + go testAccept(t, c) + } + }() + + // Dial the listener over WebRTC twice and test out of order + conn := setupAgent(t, agent.Metadata{}, 0) + conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) + require.NoError(t, err) + defer conn1.Close() + conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) + require.NoError(t, err) + defer conn2.Close() + testDial(t, conn2) + testDial(t, conn1) + }) + } + }) + + t.Run("DialError", func(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + // This test uses Unix listeners so we can very easily ensure that + // no other tests decide to listen on the same random port we + // picked. + t.Skip("this test is unsupported on Windows") + return + } + + tmpDir, err := os.MkdirTemp("", "coderd_agent_test_") + require.NoError(t, err, "create temp dir") + t.Cleanup(func() { + _ = os.RemoveAll(tmpDir) + }) + + // Try to dial the non-existent Unix socket over WebRTC + conn := setupAgent(t, agent.Metadata{}, 0) + netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock")) + require.Error(t, err) + require.ErrorContains(t, err, "remote dial error") + require.ErrorContains(t, err, "no such file") + require.Nil(t, netConn) + }) } func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd { @@ -303,3 +410,34 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) Conn: conn, } } + +var dialTestPayload = []byte("dean-was-here123") + +func testDial(t *testing.T, c net.Conn) { + t.Helper() + + assertWritePayload(t, c, dialTestPayload) + assertReadPayload(t, c, dialTestPayload) +} + +func testAccept(t *testing.T, c net.Conn) { + t.Helper() + defer c.Close() + + assertReadPayload(t, c, dialTestPayload) + assertWritePayload(t, c, dialTestPayload) +} + +func assertReadPayload(t *testing.T, r io.Reader, payload []byte) { + b := make([]byte, len(payload)+16) + n, err := r.Read(b) + require.NoError(t, err, "read payload") + require.Equal(t, len(payload), n, "read payload length does not match") + require.Equal(t, payload, b[:n]) +} + +func assertWritePayload(t *testing.T, w io.Writer, payload []byte) { + n, err := w.Write(payload) + require.NoError(t, err, "write payload") + require.Equal(t, len(payload), n, "payload length does not match") +} diff --git a/agent/conn.go b/agent/conn.go index 81a6315af2..56d3d42ea1 100644 --- a/agent/conn.go +++ b/agent/conn.go @@ -2,8 +2,11 @@ package agent import ( "context" + "encoding/json" "fmt" "net" + "net/url" + "strings" "golang.org/x/crypto/ssh" "golang.org/x/xerrors" @@ -32,7 +35,7 @@ type Conn struct { // ReconnectingPTY returns a connection serving a TTY that can // be reconnected to via ID. func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error) { - channel, err := c.Dial(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{ + channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{ Protocol: "reconnecting-pty", }) if err != nil { @@ -43,7 +46,7 @@ func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error // SSH dials the built-in SSH server. func (c *Conn) SSH() (net.Conn, error) { - channel, err := c.Dial(context.Background(), "ssh", &peer.ChannelOptions{ + channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{ Protocol: "ssh", }) if err != nil { @@ -71,6 +74,42 @@ func (c *Conn) SSHClient() (*ssh.Client, error) { return ssh.NewClient(sshConn, channels, requests), nil } +// DialContext dials an arbitrary protocol+address from inside the workspace and +// proxies it through the provided net.Conn. +func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + u := &url.URL{ + Scheme: network, + } + if strings.HasPrefix(network, "unix") { + u.Path = addr + } else { + u.Host = addr + } + + channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{ + Protocol: "dial", + Unordered: strings.HasPrefix(network, "udp"), + }) + if err != nil { + return nil, xerrors.Errorf("create datachannel: %w", err) + } + + // The first message written from the other side is a JSON payload + // containing the dial error. + dec := json.NewDecoder(channel) + var res dialResponse + err = dec.Decode(&res) + if err != nil { + return nil, xerrors.Errorf("failed to decode initial packet: %w", err) + } + if res.Error != "" { + _ = channel.Close() + return nil, xerrors.Errorf("remote dial error: %v", res.Error) + } + + return channel.NetConn(), nil +} + func (c *Conn) Close() error { _ = c.Negotiator.DRPCConn().Close() return c.Conn.Close() diff --git a/cli/portforward.go b/cli/portforward.go new file mode 100644 index 0000000000..51206687f9 --- /dev/null +++ b/cli/portforward.go @@ -0,0 +1,379 @@ +package cli + +import ( + "context" + "fmt" + "net" + "os" + "os/signal" + "runtime" + "strconv" + "strings" + "sync" + "syscall" + + "github.com/pion/udp" + "github.com/spf13/cobra" + "golang.org/x/xerrors" + + coderagent "github.com/coder/coder/agent" + "github.com/coder/coder/cli/cliui" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/codersdk" +) + +func portForward() *cobra.Command { + var ( + tcpForwards []string // : + udpForwards []string // : + unixForwards []string // : OR : + ) + cmd := &cobra.Command{ + Use: "port-forward ", + Aliases: []string{"tunnel"}, + Args: cobra.ExactArgs(1), + Example: ` + - Port forward a single TCP port from 1234 in the workspace to port 5678 on + your local machine + + ` + cliui.Styles.Code.Render("$ coder port-forward --tcp 5678:1234") + ` + + - Port forward a single UDP port from port 9000 to port 9000 on your local + machine + + ` + cliui.Styles.Code.Render("$ coder port-forward --udp 9000") + ` + + - Forward a Unix socket in the workspace to a local Unix socket + + ` + cliui.Styles.Code.Render("$ coder port-forward --unix ./local.sock:~/remote.sock") + ` + + - Forward a Unix socket in the workspace to a local TCP port + + ` + cliui.Styles.Code.Render("$ coder port-forward --unix 8080:~/remote.sock") + ` + + - Port forward multiple TCP ports and a UDP port + + ` + cliui.Styles.Code.Render("$ coder port-forward --tcp 8080:8080 --tcp 9000:3000 --udp 5353:53"), + RunE: func(cmd *cobra.Command, args []string) error { + specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards) + if err != nil { + return xerrors.Errorf("parse port-forward specs: %w", err) + } + if len(specs) == 0 { + err = cmd.Help() + if err != nil { + return xerrors.Errorf("generate help output: %w", err) + } + return xerrors.New("no port-forwards requested") + } + + client, err := createClient(cmd) + if err != nil { + return err + } + organization, err := currentOrganization(cmd, client) + if err != nil { + return err + } + + workspace, agent, err := getWorkspaceAndAgent(cmd, client, organization.ID, codersdk.Me, args[0], false) + if err != nil { + return err + } + if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart { + return xerrors.New("workspace must be in start transition to port-forward") + } + if workspace.LatestBuild.Job.CompletedAt == nil { + err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) + if err != nil { + return err + } + } + + err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{ + WorkspaceName: workspace.Name, + Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { + return client.WorkspaceAgent(ctx, agent.ID) + }, + }) + if err != nil { + return xerrors.Errorf("await agent: %w", err) + } + + conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil) + if err != nil { + return xerrors.Errorf("dial workspace agent: %w", err) + } + defer conn.Close() + + // Start all listeners. + var ( + ctx, cancel = context.WithCancel(cmd.Context()) + wg = new(sync.WaitGroup) + listeners = make([]net.Listener, len(specs)) + closeAllListeners = func() { + for _, l := range listeners { + if l == nil { + continue + } + _ = l.Close() + } + } + ) + defer cancel() + for i, spec := range specs { + l, err := listenAndPortForward(ctx, cmd, conn, wg, spec) + if err != nil { + closeAllListeners() + return err + } + listeners[i] = l + } + + // Wait for the context to be canceled or for a signal and close + // all listeners. + var closeErr error + go func() { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + + select { + case <-ctx.Done(): + closeErr = ctx.Err() + case <-sigs: + _, _ = fmt.Fprintln(cmd.OutOrStderr(), "Received signal, closing all listeners and active connections") + closeErr = xerrors.New("signal received") + } + + cancel() + closeAllListeners() + }() + + _, _ = fmt.Fprintln(cmd.OutOrStderr(), "Ready!") + wg.Wait() + return closeErr + }, + } + + cmd.Flags().StringArrayVarP(&tcpForwards, "tcp", "p", []string{}, "Forward a TCP port from the workspace to the local machine") + cmd.Flags().StringArrayVar(&udpForwards, "udp", []string{}, "Forward a UDP port from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols") + cmd.Flags().StringArrayVar(&unixForwards, "unix", []string{}, "Forward a Unix socket in the workspace to a local Unix socket or TCP port") + + return cmd +} + +func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *coderagent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) { + _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) + + var ( + l net.Listener + err error + ) + switch spec.listenNetwork { + case "tcp": + l, err = net.Listen(spec.listenNetwork, spec.listenAddress) + case "udp": + var host, port string + host, port, err = net.SplitHostPort(spec.listenAddress) + if err != nil { + return nil, xerrors.Errorf("split %q: %w", spec.listenAddress, err) + } + + var portInt int + portInt, err = strconv.Atoi(port) + if err != nil { + return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err) + } + + l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{ + IP: net.ParseIP(host), + Port: portInt, + }) + case "unix": + l, err = net.Listen(spec.listenNetwork, spec.listenAddress) + default: + return nil, xerrors.Errorf("unknown listen network %q", spec.listenNetwork) + } + if err != nil { + return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err) + } + + wg.Add(1) + go func(spec portForwardSpec) { + defer wg.Done() + for { + netConn, err := l.Accept() + if err != nil { + _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err) + _, _ = fmt.Fprintln(cmd.OutOrStderr(), "Killing listener") + return + } + + go func(netConn net.Conn) { + defer netConn.Close() + remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress) + if err != nil { + _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err) + return + } + defer remoteConn.Close() + + coderagent.Bicopy(ctx, netConn, remoteConn) + }(netConn) + } + }(spec) + + return l, nil +} + +type portForwardSpec struct { + listenNetwork string // tcp, udp, unix + listenAddress string // : or path + + dialNetwork string // tcp, udp, unix + dialAddress string // : or path +} + +func parsePortForwards(tcpSpecs, udpSpecs, unixSpecs []string) ([]portForwardSpec, error) { + specs := []portForwardSpec{} + + for _, spec := range tcpSpecs { + local, remote, err := parsePortPort(spec) + if err != nil { + return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err) + } + + specs = append(specs, portForwardSpec{ + listenNetwork: "tcp", + listenAddress: fmt.Sprintf("127.0.0.1:%v", local), + dialNetwork: "tcp", + dialAddress: fmt.Sprintf("127.0.0.1:%v", remote), + }) + } + + for _, spec := range udpSpecs { + local, remote, err := parsePortPort(spec) + if err != nil { + return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err) + } + + specs = append(specs, portForwardSpec{ + listenNetwork: "udp", + listenAddress: fmt.Sprintf("127.0.0.1:%v", local), + dialNetwork: "udp", + dialAddress: fmt.Sprintf("127.0.0.1:%v", remote), + }) + } + + for _, specStr := range unixSpecs { + localPath, localTCP, remotePath, err := parseUnixUnix(specStr) + if err != nil { + return nil, xerrors.Errorf("failed to parse Unix port-forward specification %q: %w", specStr, err) + } + + spec := portForwardSpec{ + dialNetwork: "unix", + dialAddress: remotePath, + } + if localPath == "" { + spec.listenNetwork = "tcp" + spec.listenAddress = fmt.Sprintf("127.0.0.1:%v", localTCP) + } else { + if runtime.GOOS == "windows" { + return nil, xerrors.Errorf("Unix port-forwarding is not supported on Windows") + } + spec.listenNetwork = "unix" + spec.listenAddress = localPath + } + specs = append(specs, spec) + } + + // Check for duplicate entries. + locals := map[string]struct{}{} + for _, spec := range specs { + localStr := fmt.Sprintf("%v:%v", spec.listenNetwork, spec.listenAddress) + if _, ok := locals[localStr]; ok { + return nil, xerrors.Errorf("local %v %v is specified twice", spec.listenNetwork, spec.listenAddress) + } + locals[localStr] = struct{}{} + } + + return specs, nil +} + +func parsePort(in string) (uint16, error) { + port, err := strconv.ParseUint(strings.TrimSpace(in), 10, 16) + if err != nil { + return 0, xerrors.Errorf("parse port %q: %w", in, err) + } + if port == 0 { + return 0, xerrors.New("port cannot be 0") + } + + return uint16(port), nil +} + +func parseUnixPath(in string) (string, error) { + path, err := coderagent.ExpandRelativeHomePath(strings.TrimSpace(in)) + if err != nil { + return "", xerrors.Errorf("tidy path %q: %w", in, err) + } + + return path, nil +} + +func parsePortPort(in string) (local uint16, remote uint16, err error) { + parts := strings.Split(in, ":") + if len(parts) > 2 { + return 0, 0, xerrors.Errorf("invalid port specification %q", in) + } + if len(parts) == 1 { + // Duplicate the single part + parts = append(parts, parts[0]) + } + + local, err = parsePort(parts[0]) + if err != nil { + return 0, 0, xerrors.Errorf("parse local port from %q: %w", in, err) + } + remote, err = parsePort(parts[1]) + if err != nil { + return 0, 0, xerrors.Errorf("parse remote port from %q: %w", in, err) + } + + return local, remote, nil +} + +func parsePortOrUnixPath(in string) (string, uint16, error) { + port, err := parsePort(in) + if err == nil { + return "", port, nil + } + + path, err := parseUnixPath(in) + if err != nil { + return "", 0, xerrors.Errorf("could not parse port or unix path %q: %w", in, err) + } + + return path, 0, nil +} + +func parseUnixUnix(in string) (string, uint16, string, error) { + parts := strings.Split(in, ":") + if len(parts) > 2 { + return "", 0, "", xerrors.Errorf("invalid port-forward specification %q", in) + } + if len(parts) == 1 { + // Duplicate the single part + parts = append(parts, parts[0]) + } + + localPath, localPort, err := parsePortOrUnixPath(parts[0]) + if err != nil { + return "", 0, "", xerrors.Errorf("parse local part of spec %q: %w", in, err) + } + + // We don't really touch the remote path at all since it gets cleaned + // up/expanded on the remote. + return localPath, localPort, parts[1], nil +} diff --git a/cli/portforward_test.go b/cli/portforward_test.go new file mode 100644 index 0000000000..0c0d3ddc5f --- /dev/null +++ b/cli/portforward_test.go @@ -0,0 +1,532 @@ +package cli_test + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/pion/udp" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/cli/clitest" + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/provisioner/echo" + "github.com/coder/coder/provisionersdk/proto" +) + +func TestPortForward(t *testing.T) { + t.Parallel() + + t.Run("None", func(t *testing.T) { + t.Parallel() + + client := coderdtest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + + cmd, root := clitest.New(t, "port-forward", "blah") + clitest.SetupConfig(t, client, root) + buf := newThreadSafeBuffer() + cmd.SetOut(buf) + + err := cmd.Execute() + require.Error(t, err) + require.ErrorContains(t, err, "no port-forwards") + + // Check that the help was printed. + require.Contains(t, buf.String(), "port-forward ") + }) + + cases := []struct { + name string + network string + // The flag to pass to `coder port-forward X` to port-forward this type + // of connection. Has two format args (both strings), the first is the + // local address and the second is the remote address. + flag string + // setupRemote creates a "remote" listener to emulate a service in the + // workspace. + setupRemote func(t *testing.T) net.Listener + // setupLocal returns an available port or Unix socket path that the + // port-forward command will listen on "locally". Returns the address + // you pass to net.Dial, and the port/path you pass to `coder + // port-forward`. + setupLocal func(t *testing.T) (string, string) + }{ + { + name: "TCP", + network: "tcp", + flag: "--tcp=%v:%v", + setupRemote: func(t *testing.T) net.Listener { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "create TCP listener") + return l + }, + setupLocal: func(t *testing.T) (string, string) { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "create TCP listener to generate random port") + defer l.Close() + + _, port, err := net.SplitHostPort(l.Addr().String()) + require.NoErrorf(t, err, "split TCP address %q", l.Addr().String()) + return l.Addr().String(), port + }, + }, + { + name: "UDP", + network: "udp", + flag: "--udp=%v:%v", + setupRemote: func(t *testing.T) net.Listener { + addr := net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + } + l, err := udp.Listen("udp", &addr) + require.NoError(t, err, "create UDP listener") + return l + }, + setupLocal: func(t *testing.T) (string, string) { + addr := net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + } + l, err := udp.Listen("udp", &addr) + require.NoError(t, err, "create UDP listener to generate random port") + defer l.Close() + + _, port, err := net.SplitHostPort(l.Addr().String()) + require.NoErrorf(t, err, "split UDP address %q", l.Addr().String()) + return l.Addr().String(), port + }, + }, + { + name: "Unix", + network: "unix", + flag: "--unix=%v:%v", + setupRemote: func(t *testing.T) net.Listener { + if runtime.GOOS == "windows" { + t.Skip("Unix socket forwarding isn't supported on Windows") + } + + tmpDir, err := os.MkdirTemp("", "coderd_agent_test_") + require.NoError(t, err, "create temp dir for unix listener") + t.Cleanup(func() { + _ = os.RemoveAll(tmpDir) + }) + + l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock")) + require.NoError(t, err, "create UDP listener") + return l + }, + setupLocal: func(t *testing.T) (string, string) { + tmpDir, err := os.MkdirTemp("", "coderd_agent_test_") + require.NoError(t, err, "create temp dir for unix listener") + t.Cleanup(func() { + _ = os.RemoveAll(tmpDir) + }) + + path := filepath.Join(tmpDir, "test.sock") + return path, path + }, + }, + } + + for _, c := range cases { //nolint:paralleltest // the `c := c` confuses the linter + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + t.Run("OnePort", func(t *testing.T) { + t.Parallel() + var ( + client = coderdtest.New(t, nil) + user = coderdtest.CreateFirstUser(t, client) + _, workspace = runAgent(t, client, user.UserID) + l1, p1 = setupTestListener(t, c.setupRemote(t)) + ) + t.Cleanup(func() { + _ = l1.Close() + }) + + // Create a flag that forwards from local to listener 1. + localAddress, localFlag := c.setupLocal(t) + flag := fmt.Sprintf(c.flag, localFlag, p1) + + // Launch port-forward in a goroutine so we can start dialing + // the "local" listener. + cmd, root := clitest.New(t, "port-forward", workspace.Name, flag) + clitest.SetupConfig(t, client, root) + buf := newThreadSafeBuffer() + cmd.SetOut(io.MultiWriter(buf, os.Stderr)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + err := cmd.ExecuteContext(ctx) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + }() + waitForPortForwardReady(t, buf) + + // Open two connections simultaneously and test them out of + // sync. + d := net.Dialer{Timeout: 3 * time.Second} + c1, err := d.DialContext(ctx, c.network, localAddress) + require.NoError(t, err, "open connection 1 to 'local' listener") + defer c1.Close() + c2, err := d.DialContext(ctx, c.network, localAddress) + require.NoError(t, err, "open connection 2 to 'local' listener") + defer c2.Close() + testDial(t, c2) + testDial(t, c1) + }) + + t.Run("TwoPorts", func(t *testing.T) { + t.Parallel() + var ( + client = coderdtest.New(t, nil) + user = coderdtest.CreateFirstUser(t, client) + _, workspace = runAgent(t, client, user.UserID) + l1, p1 = setupTestListener(t, c.setupRemote(t)) + l2, p2 = setupTestListener(t, c.setupRemote(t)) + ) + t.Cleanup(func() { + _ = l1.Close() + _ = l2.Close() + }) + + // Create a flags for listener 1 and listener 2. + localAddress1, localFlag1 := c.setupLocal(t) + localAddress2, localFlag2 := c.setupLocal(t) + flag1 := fmt.Sprintf(c.flag, localFlag1, p1) + flag2 := fmt.Sprintf(c.flag, localFlag2, p2) + + // Launch port-forward in a goroutine so we can start dialing + // the "local" listeners. + cmd, root := clitest.New(t, "port-forward", workspace.Name, flag1, flag2) + clitest.SetupConfig(t, client, root) + buf := newThreadSafeBuffer() + cmd.SetOut(io.MultiWriter(buf, os.Stderr)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + err := cmd.ExecuteContext(ctx) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + }() + waitForPortForwardReady(t, buf) + + // Open a connection to both listener 1 and 2 simultaneously and + // then test them out of order. + d := net.Dialer{Timeout: 3 * time.Second} + c1, err := d.DialContext(ctx, c.network, localAddress1) + require.NoError(t, err, "open connection 1 to 'local' listener 1") + defer c1.Close() + c2, err := d.DialContext(ctx, c.network, localAddress2) + require.NoError(t, err, "open connection 2 to 'local' listener 2") + defer c2.Close() + testDial(t, c2) + testDial(t, c1) + }) + }) + } + + // Test doing a TCP -> Unix forward. + t.Run("TCP2Unix", func(t *testing.T) { + t.Parallel() + var ( + client = coderdtest.New(t, nil) + user = coderdtest.CreateFirstUser(t, client) + _, workspace = runAgent(t, client, user.UserID) + + // Find the TCP and Unix cases so we can use their setupLocal and + // setupRemote methods respectively. + tcpCase = cases[0] + unixCase = cases[2] + + // Setup remote Unix listener. + l1, p1 = setupTestListener(t, unixCase.setupRemote(t)) + ) + t.Cleanup(func() { + _ = l1.Close() + }) + + // Create a flag that forwards from local TCP to Unix listener 1. + // Notably this is a --unix flag. + localAddress, localFlag := tcpCase.setupLocal(t) + flag := fmt.Sprintf(unixCase.flag, localFlag, p1) + + // Launch port-forward in a goroutine so we can start dialing + // the "local" listener. + cmd, root := clitest.New(t, "port-forward", workspace.Name, flag) + clitest.SetupConfig(t, client, root) + buf := newThreadSafeBuffer() + cmd.SetOut(io.MultiWriter(buf, os.Stderr)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + err := cmd.ExecuteContext(ctx) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + }() + waitForPortForwardReady(t, buf) + + // Open two connections simultaneously and test them out of + // sync. + d := net.Dialer{Timeout: 3 * time.Second} + c1, err := d.DialContext(ctx, tcpCase.network, localAddress) + require.NoError(t, err, "open connection 1 to 'local' listener") + defer c1.Close() + c2, err := d.DialContext(ctx, tcpCase.network, localAddress) + require.NoError(t, err, "open connection 2 to 'local' listener") + defer c2.Close() + testDial(t, c2) + testDial(t, c1) + }) + + // Test doing TCP, UDP and Unix at the same time. + t.Run("All", func(t *testing.T) { + t.Parallel() + var ( + client = coderdtest.New(t, nil) + user = coderdtest.CreateFirstUser(t, client) + _, workspace = runAgent(t, client, user.UserID) + // These aren't fixed size because we exclude Unix on Windows. + dials = []addr{} + flags = []string{} + ) + + // Start listeners and populate arrays with the cases. + for _, c := range cases { + if strings.HasPrefix(c.network, "unix") && runtime.GOOS == "windows" { + // Unix isn't supported on Windows, but we can still + // test other protocols together. + continue + } + + l, p := setupTestListener(t, c.setupRemote(t)) + t.Cleanup(func() { + _ = l.Close() + }) + + localAddress, localFlag := c.setupLocal(t) + dials = append(dials, addr{ + network: c.network, + addr: localAddress, + }) + flags = append(flags, fmt.Sprintf(c.flag, localFlag, p)) + } + + // Launch port-forward in a goroutine so we can start dialing + // the "local" listeners. + cmd, root := clitest.New(t, append([]string{"port-forward", workspace.Name}, flags...)...) + clitest.SetupConfig(t, client, root) + buf := newThreadSafeBuffer() + cmd.SetOut(io.MultiWriter(buf, os.Stderr)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + err := cmd.ExecuteContext(ctx) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + }() + waitForPortForwardReady(t, buf) + + // Open connections to all items in the "dial" array. + var ( + d = net.Dialer{Timeout: 3 * time.Second} + conns = make([]net.Conn, len(dials)) + ) + for i, a := range dials { + c, err := d.DialContext(ctx, a.network, a.addr) + require.NoErrorf(t, err, "open connection %v to 'local' listener %v", i+1, i+1) + t.Cleanup(func() { + _ = c.Close() + }) + conns[i] = c + } + + // Test each connection in reverse order. + for i := len(conns) - 1; i >= 0; i-- { + testDial(t, conns[i]) + } + }) +} + +// runAgent creates a fake workspace and starts an agent locally for that +// workspace. The agent will be cleaned up on test completion. +func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]codersdk.WorkspaceResource, codersdk.Workspace) { + ctx := context.Background() + user, err := client.User(ctx, userID.String()) + require.NoError(t, err, "specified user does not exist") + require.Greater(t, len(user.OrganizationIDs), 0, "user has no organizations") + orgID := user.OrganizationIDs[0] + + // Setup echo provisioner + agentToken := uuid.NewString() + coderdtest.NewProvisionerDaemon(t, client) + version := coderdtest.CreateTemplateVersion(t, client, orgID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionDryRun: echo.ProvisionComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "somename", + Type: "someinstance", + Agents: []*proto.Agent{{ + Auth: &proto.Agent_Token{ + Token: agentToken, + }, + }}, + }}, + }, + }, + }}, + }) + + // Create template and workspace + template := coderdtest.CreateTemplate(t, client, orgID, version.ID) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, orgID, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + + // Start workspace agent in a goroutine + cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String()) + clitest.SetupConfig(t, client, root) + agentCtx, agentCancel := context.WithCancel(ctx) + t.Cleanup(agentCancel) + go func() { + err := cmd.ExecuteContext(agentCtx) + require.NoError(t, err) + }() + + coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) + resources, err := client.WorkspaceResourcesByBuild(context.Background(), workspace.LatestBuild.ID) + require.NoError(t, err) + + return resources, workspace +} + +// setupTestListener starts accepting connections and echoing a single packet. +// Returns the listener and the listen port or Unix path. +func setupTestListener(t *testing.T, l net.Listener) (net.Listener, string) { + t.Cleanup(func() { + _ = l.Close() + }) + go func() { + for { + c, err := l.Accept() + if err != nil { + return + } + + go testAccept(t, c) + } + }() + + addr := l.Addr().String() + if !strings.HasPrefix(l.Addr().Network(), "unix") { + _, port, err := net.SplitHostPort(addr) + require.NoErrorf(t, err, "split non-Unix listen path %q", addr) + addr = port + } + + return l, addr +} + +var dialTestPayload = []byte("dean-was-here123") + +func testDial(t *testing.T, c net.Conn) { + t.Helper() + + assertWritePayload(t, c, dialTestPayload) + assertReadPayload(t, c, dialTestPayload) +} + +func testAccept(t *testing.T, c net.Conn) { + t.Helper() + defer c.Close() + + assertReadPayload(t, c, dialTestPayload) + assertWritePayload(t, c, dialTestPayload) +} + +func assertReadPayload(t *testing.T, r io.Reader, payload []byte) { + b := make([]byte, len(payload)+16) + n, err := r.Read(b) + require.NoError(t, err, "read payload") + require.Equal(t, len(payload), n, "read payload length does not match") + require.Equal(t, payload, b[:n]) +} + +func assertWritePayload(t *testing.T, w io.Writer, payload []byte) { + n, err := w.Write(payload) + require.NoError(t, err, "write payload") + require.Equal(t, len(payload), n, "payload length does not match") +} + +func waitForPortForwardReady(t *testing.T, output *threadSafeBuffer) { + for i := 0; i < 100; i++ { + time.Sleep(250 * time.Millisecond) + + data := output.String() + if strings.Contains(data, "Ready!") { + return + } + } + + t.Fatal("port-forward command did not become ready in time") +} + +type addr struct { + network string + addr string +} + +type threadSafeBuffer struct { + b *bytes.Buffer + mut *sync.RWMutex +} + +func newThreadSafeBuffer() *threadSafeBuffer { + return &threadSafeBuffer{ + b: bytes.NewBuffer(nil), + mut: new(sync.RWMutex), + } +} + +var _ io.Reader = &threadSafeBuffer{} +var _ io.Writer = &threadSafeBuffer{} + +// Read implements io.Reader. +func (b *threadSafeBuffer) Read(p []byte) (int, error) { + b.mut.RLock() + defer b.mut.RUnlock() + + return b.b.Read(p) +} + +// Write implements io.Writer. +func (b *threadSafeBuffer) Write(p []byte) (int, error) { + b.mut.Lock() + defer b.mut.Unlock() + + return b.b.Write(p) +} + +func (b *threadSafeBuffer) String() string { + b.mut.RLock() + defer b.mut.RUnlock() + + return b.b.String() +} diff --git a/cli/root.go b/cli/root.go index 9e96be9be3..424ec54155 100644 --- a/cli/root.go +++ b/cli/root.go @@ -36,6 +36,15 @@ const ( varForceTty = "force-tty" ) +func init() { + // Customizes the color of headings to make subcommands more visually + // appealing. + header := cliui.Styles.Placeholder + cobra.AddTemplateFunc("usageHeader", func(s string) string { + return header.Render(s) + }) +} + func Root() *cobra.Command { cmd := &cobra.Command{ Use: "coder", @@ -71,7 +80,7 @@ func Root() *cobra.Command { templates(), update(), users(), - tunnel(), + portForward(), workspaceAgent(), ) @@ -179,13 +188,7 @@ func isTTY(cmd *cobra.Command) bool { } func usageTemplate() string { - // Customizes the color of headings to make subcommands - // more visually appealing. - header := cliui.Styles.Placeholder - cobra.AddTemplateFunc("usageHeader", func(s string) string { - return header.Render(s) - }) - + // usageHeader is defined in init(). return `{{usageHeader "Usage:"}} {{- if .Runnable}} {{.UseLine}} diff --git a/cli/ssh.go b/cli/ssh.go index b7f82f00dd..4dfd68463a 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -50,94 +50,23 @@ func ssh() *cobra.Command { return err } - var workspace codersdk.Workspace - var workspaceParts []string if shuffle { err := cobra.ExactArgs(0)(cmd, args) if err != nil { return err } - - workspaces, err := client.WorkspacesByOwner(cmd.Context(), organization.ID, codersdk.Me) - if err != nil { - return err - } - if len(workspaces) == 0 { - return xerrors.New("no workspaces to shuffle") - } - - idx, err := cryptorand.Intn(len(workspaces)) - if err != nil { - return err - } - workspace = workspaces[idx] } else { err := cobra.MinimumNArgs(1)(cmd, args) if err != nil { return err } - - workspaceParts = strings.Split(args[0], ".") - workspace, err = client.WorkspaceByOwnerAndName(cmd.Context(), organization.ID, codersdk.Me, workspaceParts[0]) - if err != nil { - return err - } } - if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart { - return xerrors.New("workspace must be in start transition to ssh") - } - - if workspace.LatestBuild.Job.CompletedAt == nil { - err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) - if err != nil { - return err - } - } - - if workspace.LatestBuild.Transition == database.WorkspaceTransitionDelete { - return xerrors.New("workspace is deleting...") - } - - resources, err := client.WorkspaceResourcesByBuild(cmd.Context(), workspace.LatestBuild.ID) + workspace, agent, err := getWorkspaceAndAgent(cmd, client, organization.ID, codersdk.Me, args[0], shuffle) if err != nil { return err } - agents := make([]codersdk.WorkspaceAgent, 0) - for _, resource := range resources { - agents = append(agents, resource.Agents...) - } - if len(agents) == 0 { - return xerrors.New("workspace has no agents") - } - var agent codersdk.WorkspaceAgent - if len(workspaceParts) >= 2 { - for _, otherAgent := range agents { - if otherAgent.Name != workspaceParts[1] { - continue - } - agent = otherAgent - break - } - if agent.ID == uuid.Nil { - return xerrors.Errorf("agent not found by name %q", workspaceParts[1]) - } - } - if agent.ID == uuid.Nil { - if len(agents) > 1 { - if !shuffle { - return xerrors.New("you must specify the name of an agent") - } - idx, err := cryptorand.Intn(len(agents)) - if err != nil { - return err - } - agent = agents[idx] - } else { - agent = agents[0] - } - } // OpenSSH passes stderr directly to the calling TTY. // This is required in "stdio" mode so a connecting indicator can be displayed. err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{ @@ -233,6 +162,92 @@ func ssh() *cobra.Command { return cmd } +// getWorkspaceAgent returns the workspace and agent selected using either the +// `[.]` syntax via `in` or picks a random workspace and agent +// if `shuffle` is true. +func getWorkspaceAndAgent(cmd *cobra.Command, client *codersdk.Client, orgID uuid.UUID, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive + ctx := cmd.Context() + + var ( + workspace codersdk.Workspace + workspaceParts = strings.Split(in, ".") + err error + ) + if shuffle { + workspaces, err := client.WorkspacesByOwner(cmd.Context(), orgID, userID) + if err != nil { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err + } + if len(workspaces) == 0 { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("no workspaces to shuffle") + } + + workspace, err = cryptorand.Element(workspaces) + if err != nil { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err + } + } else { + workspace, err = client.WorkspaceByOwnerAndName(cmd.Context(), orgID, userID, workspaceParts[0]) + if err != nil { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err + } + } + + if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("workspace must be in start transition to ssh") + } + if workspace.LatestBuild.Job.CompletedAt == nil { + err := cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) + if err != nil { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err + } + } + if workspace.LatestBuild.Transition == database.WorkspaceTransitionDelete { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q is being deleted", workspace.Name) + } + + resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID) + if err != nil { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("fetch workspace resources: %w", err) + } + + agents := make([]codersdk.WorkspaceAgent, 0) + for _, resource := range resources { + agents = append(agents, resource.Agents...) + } + if len(agents) == 0 { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q has no agents", workspace.Name) + } + var agent codersdk.WorkspaceAgent + if len(workspaceParts) >= 2 { + for _, otherAgent := range agents { + if otherAgent.Name != workspaceParts[1] { + continue + } + agent = otherAgent + break + } + if agent.ID == uuid.Nil { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("agent not found by name %q", workspaceParts[1]) + } + } + if agent.ID == uuid.Nil { + if len(agents) > 1 { + if !shuffle { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("you must specify the name of an agent") + } + agent, err = cryptorand.Element(agents) + if err != nil { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err + } + } else { + agent = agents[0] + } + } + + return workspace, agent, nil +} + // Attempt to poll workspace autostop. We write a per-workspace lockfile to // avoid spamming the user with notifications in case of multiple instances // of the CLI running simultaneously. diff --git a/cli/templates.go b/cli/templates.go index 45f6224369..2e5b179c90 100644 --- a/cli/templates.go +++ b/cli/templates.go @@ -1,8 +1,9 @@ package cli import ( - "github.com/fatih/color" "github.com/spf13/cobra" + + "github.com/coder/coder/cli/cliui" ) func templates() *cobra.Command { @@ -13,15 +14,15 @@ func templates() *cobra.Command { Example: ` - Create a template for developers to create workspaces - ` + color.New(color.FgHiMagenta).Sprint("$ coder templates create") + ` + ` + cliui.Styles.Code.Render("$ coder templates create") + ` - Make changes to your template, and plan the changes - ` + color.New(color.FgHiMagenta).Sprint("$ coder templates plan ") + ` + ` + cliui.Styles.Code.Render("$ coder templates plan ") + ` - Update the template. Your developers can update their workspaces - ` + color.New(color.FgHiMagenta).Sprint("$ coder templates update "), + ` + cliui.Styles.Code.Render("$ coder templates update "), } cmd.AddCommand( templateCreate(), diff --git a/cli/tunnel.go b/cli/tunnel.go deleted file mode 100644 index 887d766a09..0000000000 --- a/cli/tunnel.go +++ /dev/null @@ -1,14 +0,0 @@ -package cli - -import "github.com/spf13/cobra" - -func tunnel() *cobra.Command { - return &cobra.Command{ - Annotations: workspaceCommand, - Use: "tunnel", - Short: "Forward ports to your local machine", - RunE: func(cmd *cobra.Command, args []string) error { - return nil - }, - } -} diff --git a/cryptorand/slices.go b/cryptorand/slices.go new file mode 100644 index 0000000000..90bc6a9b20 --- /dev/null +++ b/cryptorand/slices.go @@ -0,0 +1,20 @@ +package cryptorand + +import ( + "golang.org/x/xerrors" +) + +// Element returns a random element of the slice. An error will be returned if +// the slice has no elements in it. +func Element[T any](s []T) (out T, err error) { + if len(s) == 0 { + return out, xerrors.New("slice must have at least one element") + } + + i, err := Intn(len(s)) + if err != nil { + return out, xerrors.Errorf("generate random integer from 0-%v: %w", len(s), err) + } + + return s[i], nil +} diff --git a/cryptorand/slices_test.go b/cryptorand/slices_test.go new file mode 100644 index 0000000000..f4c7be248c --- /dev/null +++ b/cryptorand/slices_test.go @@ -0,0 +1,56 @@ +package cryptorand_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/cryptorand" +) + +func TestRandomElement(t *testing.T) { + t.Parallel() + + t.Run("Empty", func(t *testing.T) { + t.Parallel() + + s := []string{} + v, err := cryptorand.Element(s) + require.Error(t, err) + require.ErrorContains(t, err, "slice must have at least one element") + require.Empty(t, v) + }) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + // Generate random slices of ints and strings + var ( + ints = make([]int, 20) + strings = make([]string, 20) + ) + for i := range ints { + v, err := cryptorand.Intn(1024) + require.NoError(t, err, "generate random int for test slice") + ints[i] = v + } + for i := range strings { + v, err := cryptorand.String(10) + require.NoError(t, err, "generate random string for test slice") + strings[i] = v + } + + // Get a random value from each 20 times. + for i := 0; i < 20; i++ { + iv, err := cryptorand.Element(ints) + require.NoError(t, err, "unexpected error from Element(ints)") + t.Logf("random int slice element: %v", iv) + require.Contains(t, ints, iv) + + sv, err := cryptorand.Element(strings) + require.NoError(t, err, "unexpected error from Element(strings)") + t.Logf("random string slice element: %v", sv) + require.Contains(t, strings, sv) + } + }) +} diff --git a/go.mod b/go.mod index 47d47a5e1f..1876731d77 100644 --- a/go.mod +++ b/go.mod @@ -90,6 +90,7 @@ require ( github.com/pion/logging v0.2.2 github.com/pion/transport v0.13.0 github.com/pion/turn/v2 v2.0.8 + github.com/pion/udp v0.1.1 github.com/pion/webrtc/v3 v3.1.39 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 github.com/pkg/sftp v1.13.4 @@ -217,7 +218,6 @@ require ( github.com/pion/sdp/v3 v3.0.5 // indirect github.com/pion/srtp/v2 v2.0.7 // indirect github.com/pion/stun v0.3.5 // indirect - github.com/pion/udp v0.1.1 // indirect github.com/pires/go-proxyproto v0.6.2 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/peer/channel.go b/peer/channel.go index 5a4424f8bf..7db76d984f 100644 --- a/peer/channel.go +++ b/peer/channel.go @@ -53,8 +53,8 @@ type ChannelOptions struct { // Arbitrary string that can be parsed on `Accept`. Protocol string - // Ordered determines whether the channel acts like - // a TCP connection. Defaults to false. + // Unordered determines whether the channel acts like + // a UDP connection. Defaults to false. Unordered bool // Whether the channel will be left open on disconnect or not. diff --git a/peer/conn.go b/peer/conn.go index e9126443b8..c81b29d0bb 100644 --- a/peer/conn.go +++ b/peer/conn.go @@ -68,7 +68,7 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp closed: make(chan struct{}), closedRTC: make(chan struct{}), closedICE: make(chan struct{}), - dcOpenChannel: make(chan *webrtc.DataChannel), + dcOpenChannel: make(chan *webrtc.DataChannel, 8), dcDisconnectChannel: make(chan struct{}), dcFailedChannel: make(chan struct{}), localCandidateChannel: make(chan webrtc.ICECandidateInit), @@ -264,12 +264,13 @@ func (c *Conn) init() error { }() }) c.rtc.OnDataChannel(func(dc *webrtc.DataChannel) { - select { - case <-c.closed: - return - case c.dcOpenChannel <- dc: - default: - } + go func() { + select { + case <-c.closed: + return + case c.dcOpenChannel <- dc: + } + }() }) _, err := c.pingChannel() if err != nil { @@ -469,8 +470,8 @@ func (c *Conn) Accept(ctx context.Context) (*Channel, error) { return newChannel(c, dataChannel, &ChannelOptions{}), nil } -// Dial creates a new DataChannel. -func (c *Conn) Dial(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) { +// CreateChannel creates a new DataChannel. +func (c *Conn) CreateChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) { if opts == nil { opts = &ChannelOptions{} } diff --git a/peer/conn_test.go b/peer/conn_test.go index 960ec34cfa..46bcea980e 100644 --- a/peer/conn_test.go +++ b/peer/conn_test.go @@ -90,7 +90,7 @@ func TestConn(t *testing.T) { _, err := server.Ping() require.NoError(t, err) // Create a channel that closes on disconnect. - channel, err := server.Dial(context.Background(), "wow", nil) + channel, err := server.CreateChannel(context.Background(), "wow", nil) assert.NoError(t, err) err = wan.Stop() require.NoError(t, err) @@ -108,7 +108,7 @@ func TestConn(t *testing.T) { t.Parallel() client, server, _ := createPair(t) exchange(t, client, server) - cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{}) + cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{}) require.NoError(t, err) sch, err := server.Accept(context.Background()) @@ -124,7 +124,7 @@ func TestConn(t *testing.T) { t.Parallel() client, server, wan := createPair(t) exchange(t, client, server) - cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{}) + cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{}) require.NoError(t, err) sch, err := server.Accept(context.Background()) require.NoError(t, err) @@ -141,7 +141,7 @@ func TestConn(t *testing.T) { t.Parallel() client, server, _ := createPair(t) exchange(t, client, server) - cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{}) + cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{}) require.NoError(t, err) sch, err := server.Accept(context.Background()) require.NoError(t, err) @@ -196,7 +196,7 @@ func TestConn(t *testing.T) { defaultTransport := http.DefaultTransport.(*http.Transport).Clone() var cch *peer.Channel defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - cch, err = client.Dial(ctx, "hello", &peer.ChannelOptions{}) + cch, err = client.CreateChannel(ctx, "hello", &peer.ChannelOptions{}) if err != nil { return nil, err } @@ -234,7 +234,7 @@ func TestConn(t *testing.T) { require.NoError(t, err) expectedErr := xerrors.New("wow") _ = conn.CloseWithError(expectedErr) - _, err = conn.Dial(context.Background(), "", nil) + _, err = conn.CreateChannel(context.Background(), "", nil) require.ErrorIs(t, err, expectedErr) }) @@ -274,7 +274,7 @@ func TestConn(t *testing.T) { client, server, _ := createPair(t) exchange(t, client, server) go func() { - channel, err := client.Dial(context.Background(), "test", nil) + channel, err := client.CreateChannel(context.Background(), "test", nil) require.NoError(t, err) _, err = channel.Write([]byte{1, 2}) require.NoError(t, err)