mirror of https://github.com/coder/coder.git
feat: add port-forward subcommand (#1350)
This commit is contained in:
parent
76fc59aa79
commit
9141be3656
114
agent/agent.go
114
agent/agent.go
|
@ -9,6 +9,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
|
@ -211,6 +212,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
|
|||
go a.sshServer.HandleConn(channel.NetConn())
|
||||
case "reconnecting-pty":
|
||||
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
|
||||
case "dial":
|
||||
go a.handleDial(ctx, channel.Label(), channel.NetConn())
|
||||
default:
|
||||
a.logger.Warn(ctx, "unhandled protocol from channel",
|
||||
slog.F("protocol", channel.Protocol()),
|
||||
|
@ -617,6 +620,70 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
|
|||
}
|
||||
}
|
||||
|
||||
// dialResponse is written to datachannels with protocol "dial" by the agent as
|
||||
// the first packet to signify whether the dial succeeded or failed.
|
||||
type dialResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
writeError := func(responseError error) error {
|
||||
msg := ""
|
||||
if responseError != nil {
|
||||
msg = responseError.Error()
|
||||
if !xerrors.Is(responseError, io.EOF) {
|
||||
a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError))
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(dialResponse{
|
||||
Error: msg,
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "write dial response", slog.F("label", label), slog.Error(err))
|
||||
return xerrors.Errorf("marshal agent webrtc dial response: %w", err)
|
||||
}
|
||||
|
||||
_, err = conn.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
u, err := url.Parse(label)
|
||||
if err != nil {
|
||||
_ = writeError(xerrors.Errorf("parse URL %q: %w", label, err))
|
||||
return
|
||||
}
|
||||
|
||||
network := u.Scheme
|
||||
addr := u.Host + u.Path
|
||||
if strings.HasPrefix(network, "unix") {
|
||||
if runtime.GOOS == "windows" {
|
||||
_ = writeError(xerrors.New("Unix forwarding is not supported from Windows workspaces"))
|
||||
return
|
||||
}
|
||||
addr, err = ExpandRelativeHomePath(addr)
|
||||
if err != nil {
|
||||
_ = writeError(xerrors.Errorf("expand path %q: %w", addr, err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
d := net.Dialer{Timeout: 3 * time.Second}
|
||||
nconn, err := d.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
_ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err))
|
||||
return
|
||||
}
|
||||
|
||||
err = writeError(nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
Bicopy(ctx, conn, nconn)
|
||||
}
|
||||
|
||||
// isClosed returns whether the API is closed or not.
|
||||
func (a *agent) isClosed() bool {
|
||||
select {
|
||||
|
@ -662,3 +729,50 @@ func (r *reconnectingPTY) Close() {
|
|||
r.circularBuffer.Reset()
|
||||
r.timeout.Stop()
|
||||
}
|
||||
|
||||
// Bicopy copies all of the data between the two connections and will close them
|
||||
// after one or both of them are done writing. If the context is canceled, both
|
||||
// of the connections will be closed.
|
||||
func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) {
|
||||
defer c1.Close()
|
||||
defer c2.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
copyFunc := func(dst io.WriteCloser, src io.Reader) {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(dst, src)
|
||||
}
|
||||
|
||||
wg.Add(2)
|
||||
go copyFunc(c1, c2)
|
||||
go copyFunc(c2, c1)
|
||||
|
||||
// Convert waitgroup to a channel so we can also wait on the context.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
// ExpandRelativeHomePath expands the tilde at the beginning of a path to the
|
||||
// current user's home directory and returns a full absolute path.
|
||||
func ExpandRelativeHomePath(in string) (string, error) {
|
||||
usr, err := user.Current()
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("get current user details: %w", err)
|
||||
}
|
||||
|
||||
if in == "~" {
|
||||
in = usr.HomeDir
|
||||
} else if strings.HasPrefix(in, "~/") {
|
||||
in = filepath.Join(usr.HomeDir, in[2:])
|
||||
}
|
||||
|
||||
return filepath.Abs(in)
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/udp"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/pkg/sftp"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -234,6 +235,112 @@ func TestAgent(t *testing.T) {
|
|||
findEcho()
|
||||
findEcho()
|
||||
})
|
||||
|
||||
t.Run("Dial", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
setup func(t *testing.T) net.Listener
|
||||
}{
|
||||
{
|
||||
name: "TCP",
|
||||
setup: func(t *testing.T) net.Listener {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "create TCP listener")
|
||||
return l
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UDP",
|
||||
setup: func(t *testing.T) net.Listener {
|
||||
addr := net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
}
|
||||
l, err := udp.Listen("udp", &addr)
|
||||
require.NoError(t, err, "create UDP listener")
|
||||
return l
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Unix",
|
||||
setup: func(t *testing.T) net.Listener {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix socket forwarding isn't supported on Windows")
|
||||
}
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
|
||||
require.NoError(t, err, "create temp dir for unix listener")
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
})
|
||||
|
||||
l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
|
||||
require.NoError(t, err, "create UDP listener")
|
||||
return l
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
c := c
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup listener
|
||||
l := c.setup(t)
|
||||
defer l.Close()
|
||||
go func() {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
go testAccept(t, c)
|
||||
}
|
||||
}()
|
||||
|
||||
// Dial the listener over WebRTC twice and test out of order
|
||||
conn := setupAgent(t, agent.Metadata{}, 0)
|
||||
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn1.Close()
|
||||
conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn2.Close()
|
||||
testDial(t, conn2)
|
||||
testDial(t, conn1)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DialError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// This test uses Unix listeners so we can very easily ensure that
|
||||
// no other tests decide to listen on the same random port we
|
||||
// picked.
|
||||
t.Skip("this test is unsupported on Windows")
|
||||
return
|
||||
}
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
|
||||
require.NoError(t, err, "create temp dir")
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
})
|
||||
|
||||
// Try to dial the non-existent Unix socket over WebRTC
|
||||
conn := setupAgent(t, agent.Metadata{}, 0)
|
||||
netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock"))
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "remote dial error")
|
||||
require.ErrorContains(t, err, "no such file")
|
||||
require.Nil(t, netConn)
|
||||
})
|
||||
}
|
||||
|
||||
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
|
||||
|
@ -303,3 +410,34 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
|
|||
Conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
var dialTestPayload = []byte("dean-was-here123")
|
||||
|
||||
func testDial(t *testing.T, c net.Conn) {
|
||||
t.Helper()
|
||||
|
||||
assertWritePayload(t, c, dialTestPayload)
|
||||
assertReadPayload(t, c, dialTestPayload)
|
||||
}
|
||||
|
||||
func testAccept(t *testing.T, c net.Conn) {
|
||||
t.Helper()
|
||||
defer c.Close()
|
||||
|
||||
assertReadPayload(t, c, dialTestPayload)
|
||||
assertWritePayload(t, c, dialTestPayload)
|
||||
}
|
||||
|
||||
func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
|
||||
b := make([]byte, len(payload)+16)
|
||||
n, err := r.Read(b)
|
||||
require.NoError(t, err, "read payload")
|
||||
require.Equal(t, len(payload), n, "read payload length does not match")
|
||||
require.Equal(t, payload, b[:n])
|
||||
}
|
||||
|
||||
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
|
||||
n, err := w.Write(payload)
|
||||
require.NoError(t, err, "write payload")
|
||||
require.Equal(t, len(payload), n, "payload length does not match")
|
||||
}
|
||||
|
|
|
@ -2,8 +2,11 @@ package agent
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
|
@ -32,7 +35,7 @@ type Conn struct {
|
|||
// ReconnectingPTY returns a connection serving a TTY that can
|
||||
// be reconnected to via ID.
|
||||
func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error) {
|
||||
channel, err := c.Dial(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
|
||||
channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
|
||||
Protocol: "reconnecting-pty",
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -43,7 +46,7 @@ func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error
|
|||
|
||||
// SSH dials the built-in SSH server.
|
||||
func (c *Conn) SSH() (net.Conn, error) {
|
||||
channel, err := c.Dial(context.Background(), "ssh", &peer.ChannelOptions{
|
||||
channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{
|
||||
Protocol: "ssh",
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -71,6 +74,42 @@ func (c *Conn) SSHClient() (*ssh.Client, error) {
|
|||
return ssh.NewClient(sshConn, channels, requests), nil
|
||||
}
|
||||
|
||||
// DialContext dials an arbitrary protocol+address from inside the workspace and
|
||||
// proxies it through the provided net.Conn.
|
||||
func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||
u := &url.URL{
|
||||
Scheme: network,
|
||||
}
|
||||
if strings.HasPrefix(network, "unix") {
|
||||
u.Path = addr
|
||||
} else {
|
||||
u.Host = addr
|
||||
}
|
||||
|
||||
channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{
|
||||
Protocol: "dial",
|
||||
Unordered: strings.HasPrefix(network, "udp"),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create datachannel: %w", err)
|
||||
}
|
||||
|
||||
// The first message written from the other side is a JSON payload
|
||||
// containing the dial error.
|
||||
dec := json.NewDecoder(channel)
|
||||
var res dialResponse
|
||||
err = dec.Decode(&res)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to decode initial packet: %w", err)
|
||||
}
|
||||
if res.Error != "" {
|
||||
_ = channel.Close()
|
||||
return nil, xerrors.Errorf("remote dial error: %v", res.Error)
|
||||
}
|
||||
|
||||
return channel.NetConn(), nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
_ = c.Negotiator.DRPCConn().Close()
|
||||
return c.Conn.Close()
|
||||
|
|
|
@ -0,0 +1,379 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/pion/udp"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
coderagent "github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
func portForward() *cobra.Command {
|
||||
var (
|
||||
tcpForwards []string // <port>:<port>
|
||||
udpForwards []string // <port>:<port>
|
||||
unixForwards []string // <path>:<path> OR <port>:<path>
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Use: "port-forward <workspace>",
|
||||
Aliases: []string{"tunnel"},
|
||||
Args: cobra.ExactArgs(1),
|
||||
Example: `
|
||||
- Port forward a single TCP port from 1234 in the workspace to port 5678 on
|
||||
your local machine
|
||||
|
||||
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --tcp 5678:1234") + `
|
||||
|
||||
- Port forward a single UDP port from port 9000 to port 9000 on your local
|
||||
machine
|
||||
|
||||
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --udp 9000") + `
|
||||
|
||||
- Forward a Unix socket in the workspace to a local Unix socket
|
||||
|
||||
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --unix ./local.sock:~/remote.sock") + `
|
||||
|
||||
- Forward a Unix socket in the workspace to a local TCP port
|
||||
|
||||
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --unix 8080:~/remote.sock") + `
|
||||
|
||||
- Port forward multiple TCP ports and a UDP port
|
||||
|
||||
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --tcp 8080:8080 --tcp 9000:3000 --udp 5353:53"),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse port-forward specs: %w", err)
|
||||
}
|
||||
if len(specs) == 0 {
|
||||
err = cmd.Help()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("generate help output: %w", err)
|
||||
}
|
||||
return xerrors.New("no port-forwards requested")
|
||||
}
|
||||
|
||||
client, err := createClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
organization, err := currentOrganization(cmd, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
workspace, agent, err := getWorkspaceAndAgent(cmd, client, organization.ID, codersdk.Me, args[0], false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart {
|
||||
return xerrors.New("workspace must be in start transition to port-forward")
|
||||
}
|
||||
if workspace.LatestBuild.Job.CompletedAt == nil {
|
||||
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
|
||||
WorkspaceName: workspace.Name,
|
||||
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
|
||||
return client.WorkspaceAgent(ctx, agent.ID)
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("await agent: %w", err)
|
||||
}
|
||||
|
||||
conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("dial workspace agent: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Start all listeners.
|
||||
var (
|
||||
ctx, cancel = context.WithCancel(cmd.Context())
|
||||
wg = new(sync.WaitGroup)
|
||||
listeners = make([]net.Listener, len(specs))
|
||||
closeAllListeners = func() {
|
||||
for _, l := range listeners {
|
||||
if l == nil {
|
||||
continue
|
||||
}
|
||||
_ = l.Close()
|
||||
}
|
||||
}
|
||||
)
|
||||
defer cancel()
|
||||
for i, spec := range specs {
|
||||
l, err := listenAndPortForward(ctx, cmd, conn, wg, spec)
|
||||
if err != nil {
|
||||
closeAllListeners()
|
||||
return err
|
||||
}
|
||||
listeners[i] = l
|
||||
}
|
||||
|
||||
// Wait for the context to be canceled or for a signal and close
|
||||
// all listeners.
|
||||
var closeErr error
|
||||
go func() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
closeErr = ctx.Err()
|
||||
case <-sigs:
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Received signal, closing all listeners and active connections")
|
||||
closeErr = xerrors.New("signal received")
|
||||
}
|
||||
|
||||
cancel()
|
||||
closeAllListeners()
|
||||
}()
|
||||
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Ready!")
|
||||
wg.Wait()
|
||||
return closeErr
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringArrayVarP(&tcpForwards, "tcp", "p", []string{}, "Forward a TCP port from the workspace to the local machine")
|
||||
cmd.Flags().StringArrayVar(&udpForwards, "udp", []string{}, "Forward a UDP port from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols")
|
||||
cmd.Flags().StringArrayVar(&unixForwards, "unix", []string{}, "Forward a Unix socket in the workspace to a local Unix socket or TCP port")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *coderagent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) {
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
|
||||
|
||||
var (
|
||||
l net.Listener
|
||||
err error
|
||||
)
|
||||
switch spec.listenNetwork {
|
||||
case "tcp":
|
||||
l, err = net.Listen(spec.listenNetwork, spec.listenAddress)
|
||||
case "udp":
|
||||
var host, port string
|
||||
host, port, err = net.SplitHostPort(spec.listenAddress)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("split %q: %w", spec.listenAddress, err)
|
||||
}
|
||||
|
||||
var portInt int
|
||||
portInt, err = strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err)
|
||||
}
|
||||
|
||||
l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{
|
||||
IP: net.ParseIP(host),
|
||||
Port: portInt,
|
||||
})
|
||||
case "unix":
|
||||
l, err = net.Listen(spec.listenNetwork, spec.listenAddress)
|
||||
default:
|
||||
return nil, xerrors.Errorf("unknown listen network %q", spec.listenNetwork)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(spec portForwardSpec) {
|
||||
defer wg.Done()
|
||||
for {
|
||||
netConn, err := l.Accept()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err)
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Killing listener")
|
||||
return
|
||||
}
|
||||
|
||||
go func(netConn net.Conn) {
|
||||
defer netConn.Close()
|
||||
remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
|
||||
return
|
||||
}
|
||||
defer remoteConn.Close()
|
||||
|
||||
coderagent.Bicopy(ctx, netConn, remoteConn)
|
||||
}(netConn)
|
||||
}
|
||||
}(spec)
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
type portForwardSpec struct {
|
||||
listenNetwork string // tcp, udp, unix
|
||||
listenAddress string // <ip>:<port> or path
|
||||
|
||||
dialNetwork string // tcp, udp, unix
|
||||
dialAddress string // <ip>:<port> or path
|
||||
}
|
||||
|
||||
func parsePortForwards(tcpSpecs, udpSpecs, unixSpecs []string) ([]portForwardSpec, error) {
|
||||
specs := []portForwardSpec{}
|
||||
|
||||
for _, spec := range tcpSpecs {
|
||||
local, remote, err := parsePortPort(spec)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err)
|
||||
}
|
||||
|
||||
specs = append(specs, portForwardSpec{
|
||||
listenNetwork: "tcp",
|
||||
listenAddress: fmt.Sprintf("127.0.0.1:%v", local),
|
||||
dialNetwork: "tcp",
|
||||
dialAddress: fmt.Sprintf("127.0.0.1:%v", remote),
|
||||
})
|
||||
}
|
||||
|
||||
for _, spec := range udpSpecs {
|
||||
local, remote, err := parsePortPort(spec)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err)
|
||||
}
|
||||
|
||||
specs = append(specs, portForwardSpec{
|
||||
listenNetwork: "udp",
|
||||
listenAddress: fmt.Sprintf("127.0.0.1:%v", local),
|
||||
dialNetwork: "udp",
|
||||
dialAddress: fmt.Sprintf("127.0.0.1:%v", remote),
|
||||
})
|
||||
}
|
||||
|
||||
for _, specStr := range unixSpecs {
|
||||
localPath, localTCP, remotePath, err := parseUnixUnix(specStr)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to parse Unix port-forward specification %q: %w", specStr, err)
|
||||
}
|
||||
|
||||
spec := portForwardSpec{
|
||||
dialNetwork: "unix",
|
||||
dialAddress: remotePath,
|
||||
}
|
||||
if localPath == "" {
|
||||
spec.listenNetwork = "tcp"
|
||||
spec.listenAddress = fmt.Sprintf("127.0.0.1:%v", localTCP)
|
||||
} else {
|
||||
if runtime.GOOS == "windows" {
|
||||
return nil, xerrors.Errorf("Unix port-forwarding is not supported on Windows")
|
||||
}
|
||||
spec.listenNetwork = "unix"
|
||||
spec.listenAddress = localPath
|
||||
}
|
||||
specs = append(specs, spec)
|
||||
}
|
||||
|
||||
// Check for duplicate entries.
|
||||
locals := map[string]struct{}{}
|
||||
for _, spec := range specs {
|
||||
localStr := fmt.Sprintf("%v:%v", spec.listenNetwork, spec.listenAddress)
|
||||
if _, ok := locals[localStr]; ok {
|
||||
return nil, xerrors.Errorf("local %v %v is specified twice", spec.listenNetwork, spec.listenAddress)
|
||||
}
|
||||
locals[localStr] = struct{}{}
|
||||
}
|
||||
|
||||
return specs, nil
|
||||
}
|
||||
|
||||
func parsePort(in string) (uint16, error) {
|
||||
port, err := strconv.ParseUint(strings.TrimSpace(in), 10, 16)
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("parse port %q: %w", in, err)
|
||||
}
|
||||
if port == 0 {
|
||||
return 0, xerrors.New("port cannot be 0")
|
||||
}
|
||||
|
||||
return uint16(port), nil
|
||||
}
|
||||
|
||||
func parseUnixPath(in string) (string, error) {
|
||||
path, err := coderagent.ExpandRelativeHomePath(strings.TrimSpace(in))
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("tidy path %q: %w", in, err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func parsePortPort(in string) (local uint16, remote uint16, err error) {
|
||||
parts := strings.Split(in, ":")
|
||||
if len(parts) > 2 {
|
||||
return 0, 0, xerrors.Errorf("invalid port specification %q", in)
|
||||
}
|
||||
if len(parts) == 1 {
|
||||
// Duplicate the single part
|
||||
parts = append(parts, parts[0])
|
||||
}
|
||||
|
||||
local, err = parsePort(parts[0])
|
||||
if err != nil {
|
||||
return 0, 0, xerrors.Errorf("parse local port from %q: %w", in, err)
|
||||
}
|
||||
remote, err = parsePort(parts[1])
|
||||
if err != nil {
|
||||
return 0, 0, xerrors.Errorf("parse remote port from %q: %w", in, err)
|
||||
}
|
||||
|
||||
return local, remote, nil
|
||||
}
|
||||
|
||||
func parsePortOrUnixPath(in string) (string, uint16, error) {
|
||||
port, err := parsePort(in)
|
||||
if err == nil {
|
||||
return "", port, nil
|
||||
}
|
||||
|
||||
path, err := parseUnixPath(in)
|
||||
if err != nil {
|
||||
return "", 0, xerrors.Errorf("could not parse port or unix path %q: %w", in, err)
|
||||
}
|
||||
|
||||
return path, 0, nil
|
||||
}
|
||||
|
||||
func parseUnixUnix(in string) (string, uint16, string, error) {
|
||||
parts := strings.Split(in, ":")
|
||||
if len(parts) > 2 {
|
||||
return "", 0, "", xerrors.Errorf("invalid port-forward specification %q", in)
|
||||
}
|
||||
if len(parts) == 1 {
|
||||
// Duplicate the single part
|
||||
parts = append(parts, parts[0])
|
||||
}
|
||||
|
||||
localPath, localPort, err := parsePortOrUnixPath(parts[0])
|
||||
if err != nil {
|
||||
return "", 0, "", xerrors.Errorf("parse local part of spec %q: %w", in, err)
|
||||
}
|
||||
|
||||
// We don't really touch the remote path at all since it gets cleaned
|
||||
// up/expanded on the remote.
|
||||
return localPath, localPort, parts[1], nil
|
||||
}
|
|
@ -0,0 +1,532 @@
|
|||
package cli_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/udp"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/cli/clitest"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/provisionersdk/proto"
|
||||
)
|
||||
|
||||
func TestPortForward(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("None", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
cmd, root := clitest.New(t, "port-forward", "blah")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := newThreadSafeBuffer()
|
||||
cmd.SetOut(buf)
|
||||
|
||||
err := cmd.Execute()
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "no port-forwards")
|
||||
|
||||
// Check that the help was printed.
|
||||
require.Contains(t, buf.String(), "port-forward <workspace>")
|
||||
})
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
network string
|
||||
// The flag to pass to `coder port-forward X` to port-forward this type
|
||||
// of connection. Has two format args (both strings), the first is the
|
||||
// local address and the second is the remote address.
|
||||
flag string
|
||||
// setupRemote creates a "remote" listener to emulate a service in the
|
||||
// workspace.
|
||||
setupRemote func(t *testing.T) net.Listener
|
||||
// setupLocal returns an available port or Unix socket path that the
|
||||
// port-forward command will listen on "locally". Returns the address
|
||||
// you pass to net.Dial, and the port/path you pass to `coder
|
||||
// port-forward`.
|
||||
setupLocal func(t *testing.T) (string, string)
|
||||
}{
|
||||
{
|
||||
name: "TCP",
|
||||
network: "tcp",
|
||||
flag: "--tcp=%v:%v",
|
||||
setupRemote: func(t *testing.T) net.Listener {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "create TCP listener")
|
||||
return l
|
||||
},
|
||||
setupLocal: func(t *testing.T) (string, string) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "create TCP listener to generate random port")
|
||||
defer l.Close()
|
||||
|
||||
_, port, err := net.SplitHostPort(l.Addr().String())
|
||||
require.NoErrorf(t, err, "split TCP address %q", l.Addr().String())
|
||||
return l.Addr().String(), port
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UDP",
|
||||
network: "udp",
|
||||
flag: "--udp=%v:%v",
|
||||
setupRemote: func(t *testing.T) net.Listener {
|
||||
addr := net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
}
|
||||
l, err := udp.Listen("udp", &addr)
|
||||
require.NoError(t, err, "create UDP listener")
|
||||
return l
|
||||
},
|
||||
setupLocal: func(t *testing.T) (string, string) {
|
||||
addr := net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
}
|
||||
l, err := udp.Listen("udp", &addr)
|
||||
require.NoError(t, err, "create UDP listener to generate random port")
|
||||
defer l.Close()
|
||||
|
||||
_, port, err := net.SplitHostPort(l.Addr().String())
|
||||
require.NoErrorf(t, err, "split UDP address %q", l.Addr().String())
|
||||
return l.Addr().String(), port
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Unix",
|
||||
network: "unix",
|
||||
flag: "--unix=%v:%v",
|
||||
setupRemote: func(t *testing.T) net.Listener {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix socket forwarding isn't supported on Windows")
|
||||
}
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
|
||||
require.NoError(t, err, "create temp dir for unix listener")
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
})
|
||||
|
||||
l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
|
||||
require.NoError(t, err, "create UDP listener")
|
||||
return l
|
||||
},
|
||||
setupLocal: func(t *testing.T) (string, string) {
|
||||
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
|
||||
require.NoError(t, err, "create temp dir for unix listener")
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
})
|
||||
|
||||
path := filepath.Join(tmpDir, "test.sock")
|
||||
return path, path
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases { //nolint:paralleltest // the `c := c` confuses the linter
|
||||
c := c
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("OnePort", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
client = coderdtest.New(t, nil)
|
||||
user = coderdtest.CreateFirstUser(t, client)
|
||||
_, workspace = runAgent(t, client, user.UserID)
|
||||
l1, p1 = setupTestListener(t, c.setupRemote(t))
|
||||
)
|
||||
t.Cleanup(func() {
|
||||
_ = l1.Close()
|
||||
})
|
||||
|
||||
// Create a flag that forwards from local to listener 1.
|
||||
localAddress, localFlag := c.setupLocal(t)
|
||||
flag := fmt.Sprintf(c.flag, localFlag, p1)
|
||||
|
||||
// Launch port-forward in a goroutine so we can start dialing
|
||||
// the "local" listener.
|
||||
cmd, root := clitest.New(t, "port-forward", workspace.Name, flag)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := newThreadSafeBuffer()
|
||||
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go func() {
|
||||
err := cmd.ExecuteContext(ctx)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
}()
|
||||
waitForPortForwardReady(t, buf)
|
||||
|
||||
// Open two connections simultaneously and test them out of
|
||||
// sync.
|
||||
d := net.Dialer{Timeout: 3 * time.Second}
|
||||
c1, err := d.DialContext(ctx, c.network, localAddress)
|
||||
require.NoError(t, err, "open connection 1 to 'local' listener")
|
||||
defer c1.Close()
|
||||
c2, err := d.DialContext(ctx, c.network, localAddress)
|
||||
require.NoError(t, err, "open connection 2 to 'local' listener")
|
||||
defer c2.Close()
|
||||
testDial(t, c2)
|
||||
testDial(t, c1)
|
||||
})
|
||||
|
||||
t.Run("TwoPorts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
client = coderdtest.New(t, nil)
|
||||
user = coderdtest.CreateFirstUser(t, client)
|
||||
_, workspace = runAgent(t, client, user.UserID)
|
||||
l1, p1 = setupTestListener(t, c.setupRemote(t))
|
||||
l2, p2 = setupTestListener(t, c.setupRemote(t))
|
||||
)
|
||||
t.Cleanup(func() {
|
||||
_ = l1.Close()
|
||||
_ = l2.Close()
|
||||
})
|
||||
|
||||
// Create a flags for listener 1 and listener 2.
|
||||
localAddress1, localFlag1 := c.setupLocal(t)
|
||||
localAddress2, localFlag2 := c.setupLocal(t)
|
||||
flag1 := fmt.Sprintf(c.flag, localFlag1, p1)
|
||||
flag2 := fmt.Sprintf(c.flag, localFlag2, p2)
|
||||
|
||||
// Launch port-forward in a goroutine so we can start dialing
|
||||
// the "local" listeners.
|
||||
cmd, root := clitest.New(t, "port-forward", workspace.Name, flag1, flag2)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := newThreadSafeBuffer()
|
||||
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go func() {
|
||||
err := cmd.ExecuteContext(ctx)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
}()
|
||||
waitForPortForwardReady(t, buf)
|
||||
|
||||
// Open a connection to both listener 1 and 2 simultaneously and
|
||||
// then test them out of order.
|
||||
d := net.Dialer{Timeout: 3 * time.Second}
|
||||
c1, err := d.DialContext(ctx, c.network, localAddress1)
|
||||
require.NoError(t, err, "open connection 1 to 'local' listener 1")
|
||||
defer c1.Close()
|
||||
c2, err := d.DialContext(ctx, c.network, localAddress2)
|
||||
require.NoError(t, err, "open connection 2 to 'local' listener 2")
|
||||
defer c2.Close()
|
||||
testDial(t, c2)
|
||||
testDial(t, c1)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Test doing a TCP -> Unix forward.
|
||||
t.Run("TCP2Unix", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
client = coderdtest.New(t, nil)
|
||||
user = coderdtest.CreateFirstUser(t, client)
|
||||
_, workspace = runAgent(t, client, user.UserID)
|
||||
|
||||
// Find the TCP and Unix cases so we can use their setupLocal and
|
||||
// setupRemote methods respectively.
|
||||
tcpCase = cases[0]
|
||||
unixCase = cases[2]
|
||||
|
||||
// Setup remote Unix listener.
|
||||
l1, p1 = setupTestListener(t, unixCase.setupRemote(t))
|
||||
)
|
||||
t.Cleanup(func() {
|
||||
_ = l1.Close()
|
||||
})
|
||||
|
||||
// Create a flag that forwards from local TCP to Unix listener 1.
|
||||
// Notably this is a --unix flag.
|
||||
localAddress, localFlag := tcpCase.setupLocal(t)
|
||||
flag := fmt.Sprintf(unixCase.flag, localFlag, p1)
|
||||
|
||||
// Launch port-forward in a goroutine so we can start dialing
|
||||
// the "local" listener.
|
||||
cmd, root := clitest.New(t, "port-forward", workspace.Name, flag)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := newThreadSafeBuffer()
|
||||
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go func() {
|
||||
err := cmd.ExecuteContext(ctx)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
}()
|
||||
waitForPortForwardReady(t, buf)
|
||||
|
||||
// Open two connections simultaneously and test them out of
|
||||
// sync.
|
||||
d := net.Dialer{Timeout: 3 * time.Second}
|
||||
c1, err := d.DialContext(ctx, tcpCase.network, localAddress)
|
||||
require.NoError(t, err, "open connection 1 to 'local' listener")
|
||||
defer c1.Close()
|
||||
c2, err := d.DialContext(ctx, tcpCase.network, localAddress)
|
||||
require.NoError(t, err, "open connection 2 to 'local' listener")
|
||||
defer c2.Close()
|
||||
testDial(t, c2)
|
||||
testDial(t, c1)
|
||||
})
|
||||
|
||||
// Test doing TCP, UDP and Unix at the same time.
|
||||
t.Run("All", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
client = coderdtest.New(t, nil)
|
||||
user = coderdtest.CreateFirstUser(t, client)
|
||||
_, workspace = runAgent(t, client, user.UserID)
|
||||
// These aren't fixed size because we exclude Unix on Windows.
|
||||
dials = []addr{}
|
||||
flags = []string{}
|
||||
)
|
||||
|
||||
// Start listeners and populate arrays with the cases.
|
||||
for _, c := range cases {
|
||||
if strings.HasPrefix(c.network, "unix") && runtime.GOOS == "windows" {
|
||||
// Unix isn't supported on Windows, but we can still
|
||||
// test other protocols together.
|
||||
continue
|
||||
}
|
||||
|
||||
l, p := setupTestListener(t, c.setupRemote(t))
|
||||
t.Cleanup(func() {
|
||||
_ = l.Close()
|
||||
})
|
||||
|
||||
localAddress, localFlag := c.setupLocal(t)
|
||||
dials = append(dials, addr{
|
||||
network: c.network,
|
||||
addr: localAddress,
|
||||
})
|
||||
flags = append(flags, fmt.Sprintf(c.flag, localFlag, p))
|
||||
}
|
||||
|
||||
// Launch port-forward in a goroutine so we can start dialing
|
||||
// the "local" listeners.
|
||||
cmd, root := clitest.New(t, append([]string{"port-forward", workspace.Name}, flags...)...)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := newThreadSafeBuffer()
|
||||
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go func() {
|
||||
err := cmd.ExecuteContext(ctx)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
}()
|
||||
waitForPortForwardReady(t, buf)
|
||||
|
||||
// Open connections to all items in the "dial" array.
|
||||
var (
|
||||
d = net.Dialer{Timeout: 3 * time.Second}
|
||||
conns = make([]net.Conn, len(dials))
|
||||
)
|
||||
for i, a := range dials {
|
||||
c, err := d.DialContext(ctx, a.network, a.addr)
|
||||
require.NoErrorf(t, err, "open connection %v to 'local' listener %v", i+1, i+1)
|
||||
t.Cleanup(func() {
|
||||
_ = c.Close()
|
||||
})
|
||||
conns[i] = c
|
||||
}
|
||||
|
||||
// Test each connection in reverse order.
|
||||
for i := len(conns) - 1; i >= 0; i-- {
|
||||
testDial(t, conns[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// runAgent creates a fake workspace and starts an agent locally for that
|
||||
// workspace. The agent will be cleaned up on test completion.
|
||||
func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]codersdk.WorkspaceResource, codersdk.Workspace) {
|
||||
ctx := context.Background()
|
||||
user, err := client.User(ctx, userID.String())
|
||||
require.NoError(t, err, "specified user does not exist")
|
||||
require.Greater(t, len(user.OrganizationIDs), 0, "user has no organizations")
|
||||
orgID := user.OrganizationIDs[0]
|
||||
|
||||
// Setup echo provisioner
|
||||
agentToken := uuid.NewString()
|
||||
coderdtest.NewProvisionerDaemon(t, client)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, orgID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionDryRun: echo.ProvisionComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "somename",
|
||||
Type: "someinstance",
|
||||
Agents: []*proto.Agent{{
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: agentToken,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
|
||||
// Create template and workspace
|
||||
template := coderdtest.CreateTemplate(t, client, orgID, version.ID)
|
||||
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, orgID, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
// Start workspace agent in a goroutine
|
||||
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String())
|
||||
clitest.SetupConfig(t, client, root)
|
||||
agentCtx, agentCancel := context.WithCancel(ctx)
|
||||
t.Cleanup(agentCancel)
|
||||
go func() {
|
||||
err := cmd.ExecuteContext(agentCtx)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
resources, err := client.WorkspaceResourcesByBuild(context.Background(), workspace.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
return resources, workspace
|
||||
}
|
||||
|
||||
// setupTestListener starts accepting connections and echoing a single packet.
|
||||
// Returns the listener and the listen port or Unix path.
|
||||
func setupTestListener(t *testing.T, l net.Listener) (net.Listener, string) {
|
||||
t.Cleanup(func() {
|
||||
_ = l.Close()
|
||||
})
|
||||
go func() {
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
go testAccept(t, c)
|
||||
}
|
||||
}()
|
||||
|
||||
addr := l.Addr().String()
|
||||
if !strings.HasPrefix(l.Addr().Network(), "unix") {
|
||||
_, port, err := net.SplitHostPort(addr)
|
||||
require.NoErrorf(t, err, "split non-Unix listen path %q", addr)
|
||||
addr = port
|
||||
}
|
||||
|
||||
return l, addr
|
||||
}
|
||||
|
||||
var dialTestPayload = []byte("dean-was-here123")
|
||||
|
||||
func testDial(t *testing.T, c net.Conn) {
|
||||
t.Helper()
|
||||
|
||||
assertWritePayload(t, c, dialTestPayload)
|
||||
assertReadPayload(t, c, dialTestPayload)
|
||||
}
|
||||
|
||||
func testAccept(t *testing.T, c net.Conn) {
|
||||
t.Helper()
|
||||
defer c.Close()
|
||||
|
||||
assertReadPayload(t, c, dialTestPayload)
|
||||
assertWritePayload(t, c, dialTestPayload)
|
||||
}
|
||||
|
||||
func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
|
||||
b := make([]byte, len(payload)+16)
|
||||
n, err := r.Read(b)
|
||||
require.NoError(t, err, "read payload")
|
||||
require.Equal(t, len(payload), n, "read payload length does not match")
|
||||
require.Equal(t, payload, b[:n])
|
||||
}
|
||||
|
||||
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
|
||||
n, err := w.Write(payload)
|
||||
require.NoError(t, err, "write payload")
|
||||
require.Equal(t, len(payload), n, "payload length does not match")
|
||||
}
|
||||
|
||||
func waitForPortForwardReady(t *testing.T, output *threadSafeBuffer) {
|
||||
for i := 0; i < 100; i++ {
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
data := output.String()
|
||||
if strings.Contains(data, "Ready!") {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.Fatal("port-forward command did not become ready in time")
|
||||
}
|
||||
|
||||
type addr struct {
|
||||
network string
|
||||
addr string
|
||||
}
|
||||
|
||||
type threadSafeBuffer struct {
|
||||
b *bytes.Buffer
|
||||
mut *sync.RWMutex
|
||||
}
|
||||
|
||||
func newThreadSafeBuffer() *threadSafeBuffer {
|
||||
return &threadSafeBuffer{
|
||||
b: bytes.NewBuffer(nil),
|
||||
mut: new(sync.RWMutex),
|
||||
}
|
||||
}
|
||||
|
||||
var _ io.Reader = &threadSafeBuffer{}
|
||||
var _ io.Writer = &threadSafeBuffer{}
|
||||
|
||||
// Read implements io.Reader.
|
||||
func (b *threadSafeBuffer) Read(p []byte) (int, error) {
|
||||
b.mut.RLock()
|
||||
defer b.mut.RUnlock()
|
||||
|
||||
return b.b.Read(p)
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (b *threadSafeBuffer) Write(p []byte) (int, error) {
|
||||
b.mut.Lock()
|
||||
defer b.mut.Unlock()
|
||||
|
||||
return b.b.Write(p)
|
||||
}
|
||||
|
||||
func (b *threadSafeBuffer) String() string {
|
||||
b.mut.RLock()
|
||||
defer b.mut.RUnlock()
|
||||
|
||||
return b.b.String()
|
||||
}
|
19
cli/root.go
19
cli/root.go
|
@ -36,6 +36,15 @@ const (
|
|||
varForceTty = "force-tty"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Customizes the color of headings to make subcommands more visually
|
||||
// appealing.
|
||||
header := cliui.Styles.Placeholder
|
||||
cobra.AddTemplateFunc("usageHeader", func(s string) string {
|
||||
return header.Render(s)
|
||||
})
|
||||
}
|
||||
|
||||
func Root() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "coder",
|
||||
|
@ -71,7 +80,7 @@ func Root() *cobra.Command {
|
|||
templates(),
|
||||
update(),
|
||||
users(),
|
||||
tunnel(),
|
||||
portForward(),
|
||||
workspaceAgent(),
|
||||
)
|
||||
|
||||
|
@ -179,13 +188,7 @@ func isTTY(cmd *cobra.Command) bool {
|
|||
}
|
||||
|
||||
func usageTemplate() string {
|
||||
// Customizes the color of headings to make subcommands
|
||||
// more visually appealing.
|
||||
header := cliui.Styles.Placeholder
|
||||
cobra.AddTemplateFunc("usageHeader", func(s string) string {
|
||||
return header.Render(s)
|
||||
})
|
||||
|
||||
// usageHeader is defined in init().
|
||||
return `{{usageHeader "Usage:"}}
|
||||
{{- if .Runnable}}
|
||||
{{.UseLine}}
|
||||
|
|
159
cli/ssh.go
159
cli/ssh.go
|
@ -50,94 +50,23 @@ func ssh() *cobra.Command {
|
|||
return err
|
||||
}
|
||||
|
||||
var workspace codersdk.Workspace
|
||||
var workspaceParts []string
|
||||
if shuffle {
|
||||
err := cobra.ExactArgs(0)(cmd, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
workspaces, err := client.WorkspacesByOwner(cmd.Context(), organization.ID, codersdk.Me)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(workspaces) == 0 {
|
||||
return xerrors.New("no workspaces to shuffle")
|
||||
}
|
||||
|
||||
idx, err := cryptorand.Intn(len(workspaces))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
workspace = workspaces[idx]
|
||||
} else {
|
||||
err := cobra.MinimumNArgs(1)(cmd, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
workspaceParts = strings.Split(args[0], ".")
|
||||
workspace, err = client.WorkspaceByOwnerAndName(cmd.Context(), organization.ID, codersdk.Me, workspaceParts[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart {
|
||||
return xerrors.New("workspace must be in start transition to ssh")
|
||||
}
|
||||
|
||||
if workspace.LatestBuild.Job.CompletedAt == nil {
|
||||
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if workspace.LatestBuild.Transition == database.WorkspaceTransitionDelete {
|
||||
return xerrors.New("workspace is deleting...")
|
||||
}
|
||||
|
||||
resources, err := client.WorkspaceResourcesByBuild(cmd.Context(), workspace.LatestBuild.ID)
|
||||
workspace, agent, err := getWorkspaceAndAgent(cmd, client, organization.ID, codersdk.Me, args[0], shuffle)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
agents := make([]codersdk.WorkspaceAgent, 0)
|
||||
for _, resource := range resources {
|
||||
agents = append(agents, resource.Agents...)
|
||||
}
|
||||
if len(agents) == 0 {
|
||||
return xerrors.New("workspace has no agents")
|
||||
}
|
||||
var agent codersdk.WorkspaceAgent
|
||||
if len(workspaceParts) >= 2 {
|
||||
for _, otherAgent := range agents {
|
||||
if otherAgent.Name != workspaceParts[1] {
|
||||
continue
|
||||
}
|
||||
agent = otherAgent
|
||||
break
|
||||
}
|
||||
if agent.ID == uuid.Nil {
|
||||
return xerrors.Errorf("agent not found by name %q", workspaceParts[1])
|
||||
}
|
||||
}
|
||||
if agent.ID == uuid.Nil {
|
||||
if len(agents) > 1 {
|
||||
if !shuffle {
|
||||
return xerrors.New("you must specify the name of an agent")
|
||||
}
|
||||
idx, err := cryptorand.Intn(len(agents))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
agent = agents[idx]
|
||||
} else {
|
||||
agent = agents[0]
|
||||
}
|
||||
}
|
||||
// OpenSSH passes stderr directly to the calling TTY.
|
||||
// This is required in "stdio" mode so a connecting indicator can be displayed.
|
||||
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
|
||||
|
@ -233,6 +162,92 @@ func ssh() *cobra.Command {
|
|||
return cmd
|
||||
}
|
||||
|
||||
// getWorkspaceAgent returns the workspace and agent selected using either the
|
||||
// `<workspace>[.<agent>]` syntax via `in` or picks a random workspace and agent
|
||||
// if `shuffle` is true.
|
||||
func getWorkspaceAndAgent(cmd *cobra.Command, client *codersdk.Client, orgID uuid.UUID, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive
|
||||
ctx := cmd.Context()
|
||||
|
||||
var (
|
||||
workspace codersdk.Workspace
|
||||
workspaceParts = strings.Split(in, ".")
|
||||
err error
|
||||
)
|
||||
if shuffle {
|
||||
workspaces, err := client.WorkspacesByOwner(cmd.Context(), orgID, userID)
|
||||
if err != nil {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
|
||||
}
|
||||
if len(workspaces) == 0 {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("no workspaces to shuffle")
|
||||
}
|
||||
|
||||
workspace, err = cryptorand.Element(workspaces)
|
||||
if err != nil {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
|
||||
}
|
||||
} else {
|
||||
workspace, err = client.WorkspaceByOwnerAndName(cmd.Context(), orgID, userID, workspaceParts[0])
|
||||
if err != nil {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("workspace must be in start transition to ssh")
|
||||
}
|
||||
if workspace.LatestBuild.Job.CompletedAt == nil {
|
||||
err := cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
|
||||
if err != nil {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
|
||||
}
|
||||
}
|
||||
if workspace.LatestBuild.Transition == database.WorkspaceTransitionDelete {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q is being deleted", workspace.Name)
|
||||
}
|
||||
|
||||
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
|
||||
if err != nil {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("fetch workspace resources: %w", err)
|
||||
}
|
||||
|
||||
agents := make([]codersdk.WorkspaceAgent, 0)
|
||||
for _, resource := range resources {
|
||||
agents = append(agents, resource.Agents...)
|
||||
}
|
||||
if len(agents) == 0 {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q has no agents", workspace.Name)
|
||||
}
|
||||
var agent codersdk.WorkspaceAgent
|
||||
if len(workspaceParts) >= 2 {
|
||||
for _, otherAgent := range agents {
|
||||
if otherAgent.Name != workspaceParts[1] {
|
||||
continue
|
||||
}
|
||||
agent = otherAgent
|
||||
break
|
||||
}
|
||||
if agent.ID == uuid.Nil {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("agent not found by name %q", workspaceParts[1])
|
||||
}
|
||||
}
|
||||
if agent.ID == uuid.Nil {
|
||||
if len(agents) > 1 {
|
||||
if !shuffle {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("you must specify the name of an agent")
|
||||
}
|
||||
agent, err = cryptorand.Element(agents)
|
||||
if err != nil {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
|
||||
}
|
||||
} else {
|
||||
agent = agents[0]
|
||||
}
|
||||
}
|
||||
|
||||
return workspace, agent, nil
|
||||
}
|
||||
|
||||
// Attempt to poll workspace autostop. We write a per-workspace lockfile to
|
||||
// avoid spamming the user with notifications in case of multiple instances
|
||||
// of the CLI running simultaneously.
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"github.com/fatih/color"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
)
|
||||
|
||||
func templates() *cobra.Command {
|
||||
|
@ -13,15 +14,15 @@ func templates() *cobra.Command {
|
|||
Example: `
|
||||
- Create a template for developers to create workspaces
|
||||
|
||||
` + color.New(color.FgHiMagenta).Sprint("$ coder templates create") + `
|
||||
` + cliui.Styles.Code.Render("$ coder templates create") + `
|
||||
|
||||
- Make changes to your template, and plan the changes
|
||||
|
||||
` + color.New(color.FgHiMagenta).Sprint("$ coder templates plan <name>") + `
|
||||
` + cliui.Styles.Code.Render("$ coder templates plan <name>") + `
|
||||
|
||||
- Update the template. Your developers can update their workspaces
|
||||
|
||||
` + color.New(color.FgHiMagenta).Sprint("$ coder templates update <name>"),
|
||||
` + cliui.Styles.Code.Render("$ coder templates update <name>"),
|
||||
}
|
||||
cmd.AddCommand(
|
||||
templateCreate(),
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
package cli
|
||||
|
||||
import "github.com/spf13/cobra"
|
||||
|
||||
func tunnel() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "tunnel",
|
||||
Short: "Forward ports to your local machine",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
package cryptorand
|
||||
|
||||
import (
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Element returns a random element of the slice. An error will be returned if
|
||||
// the slice has no elements in it.
|
||||
func Element[T any](s []T) (out T, err error) {
|
||||
if len(s) == 0 {
|
||||
return out, xerrors.New("slice must have at least one element")
|
||||
}
|
||||
|
||||
i, err := Intn(len(s))
|
||||
if err != nil {
|
||||
return out, xerrors.Errorf("generate random integer from 0-%v: %w", len(s), err)
|
||||
}
|
||||
|
||||
return s[i], nil
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
package cryptorand_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
func TestRandomElement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := []string{}
|
||||
v, err := cryptorand.Element(s)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "slice must have at least one element")
|
||||
require.Empty(t, v)
|
||||
})
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Generate random slices of ints and strings
|
||||
var (
|
||||
ints = make([]int, 20)
|
||||
strings = make([]string, 20)
|
||||
)
|
||||
for i := range ints {
|
||||
v, err := cryptorand.Intn(1024)
|
||||
require.NoError(t, err, "generate random int for test slice")
|
||||
ints[i] = v
|
||||
}
|
||||
for i := range strings {
|
||||
v, err := cryptorand.String(10)
|
||||
require.NoError(t, err, "generate random string for test slice")
|
||||
strings[i] = v
|
||||
}
|
||||
|
||||
// Get a random value from each 20 times.
|
||||
for i := 0; i < 20; i++ {
|
||||
iv, err := cryptorand.Element(ints)
|
||||
require.NoError(t, err, "unexpected error from Element(ints)")
|
||||
t.Logf("random int slice element: %v", iv)
|
||||
require.Contains(t, ints, iv)
|
||||
|
||||
sv, err := cryptorand.Element(strings)
|
||||
require.NoError(t, err, "unexpected error from Element(strings)")
|
||||
t.Logf("random string slice element: %v", sv)
|
||||
require.Contains(t, strings, sv)
|
||||
}
|
||||
})
|
||||
}
|
2
go.mod
2
go.mod
|
@ -90,6 +90,7 @@ require (
|
|||
github.com/pion/logging v0.2.2
|
||||
github.com/pion/transport v0.13.0
|
||||
github.com/pion/turn/v2 v2.0.8
|
||||
github.com/pion/udp v0.1.1
|
||||
github.com/pion/webrtc/v3 v3.1.39
|
||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8
|
||||
github.com/pkg/sftp v1.13.4
|
||||
|
@ -217,7 +218,6 @@ require (
|
|||
github.com/pion/sdp/v3 v3.0.5 // indirect
|
||||
github.com/pion/srtp/v2 v2.0.7 // indirect
|
||||
github.com/pion/stun v0.3.5 // indirect
|
||||
github.com/pion/udp v0.1.1 // indirect
|
||||
github.com/pires/go-proxyproto v0.6.2 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
|
|
|
@ -53,8 +53,8 @@ type ChannelOptions struct {
|
|||
// Arbitrary string that can be parsed on `Accept`.
|
||||
Protocol string
|
||||
|
||||
// Ordered determines whether the channel acts like
|
||||
// a TCP connection. Defaults to false.
|
||||
// Unordered determines whether the channel acts like
|
||||
// a UDP connection. Defaults to false.
|
||||
Unordered bool
|
||||
|
||||
// Whether the channel will be left open on disconnect or not.
|
||||
|
|
19
peer/conn.go
19
peer/conn.go
|
@ -68,7 +68,7 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp
|
|||
closed: make(chan struct{}),
|
||||
closedRTC: make(chan struct{}),
|
||||
closedICE: make(chan struct{}),
|
||||
dcOpenChannel: make(chan *webrtc.DataChannel),
|
||||
dcOpenChannel: make(chan *webrtc.DataChannel, 8),
|
||||
dcDisconnectChannel: make(chan struct{}),
|
||||
dcFailedChannel: make(chan struct{}),
|
||||
localCandidateChannel: make(chan webrtc.ICECandidateInit),
|
||||
|
@ -264,12 +264,13 @@ func (c *Conn) init() error {
|
|||
}()
|
||||
})
|
||||
c.rtc.OnDataChannel(func(dc *webrtc.DataChannel) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
case c.dcOpenChannel <- dc:
|
||||
default:
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
case c.dcOpenChannel <- dc:
|
||||
}
|
||||
}()
|
||||
})
|
||||
_, err := c.pingChannel()
|
||||
if err != nil {
|
||||
|
@ -469,8 +470,8 @@ func (c *Conn) Accept(ctx context.Context) (*Channel, error) {
|
|||
return newChannel(c, dataChannel, &ChannelOptions{}), nil
|
||||
}
|
||||
|
||||
// Dial creates a new DataChannel.
|
||||
func (c *Conn) Dial(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
|
||||
// CreateChannel creates a new DataChannel.
|
||||
func (c *Conn) CreateChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
|
||||
if opts == nil {
|
||||
opts = &ChannelOptions{}
|
||||
}
|
||||
|
|
|
@ -90,7 +90,7 @@ func TestConn(t *testing.T) {
|
|||
_, err := server.Ping()
|
||||
require.NoError(t, err)
|
||||
// Create a channel that closes on disconnect.
|
||||
channel, err := server.Dial(context.Background(), "wow", nil)
|
||||
channel, err := server.CreateChannel(context.Background(), "wow", nil)
|
||||
assert.NoError(t, err)
|
||||
err = wan.Stop()
|
||||
require.NoError(t, err)
|
||||
|
@ -108,7 +108,7 @@ func TestConn(t *testing.T) {
|
|||
t.Parallel()
|
||||
client, server, _ := createPair(t)
|
||||
exchange(t, client, server)
|
||||
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
|
||||
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
sch, err := server.Accept(context.Background())
|
||||
|
@ -124,7 +124,7 @@ func TestConn(t *testing.T) {
|
|||
t.Parallel()
|
||||
client, server, wan := createPair(t)
|
||||
exchange(t, client, server)
|
||||
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
|
||||
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
|
||||
require.NoError(t, err)
|
||||
sch, err := server.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
@ -141,7 +141,7 @@ func TestConn(t *testing.T) {
|
|||
t.Parallel()
|
||||
client, server, _ := createPair(t)
|
||||
exchange(t, client, server)
|
||||
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
|
||||
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
|
||||
require.NoError(t, err)
|
||||
sch, err := server.Accept(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
@ -196,7 +196,7 @@ func TestConn(t *testing.T) {
|
|||
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
var cch *peer.Channel
|
||||
defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
cch, err = client.Dial(ctx, "hello", &peer.ChannelOptions{})
|
||||
cch, err = client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -234,7 +234,7 @@ func TestConn(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
expectedErr := xerrors.New("wow")
|
||||
_ = conn.CloseWithError(expectedErr)
|
||||
_, err = conn.Dial(context.Background(), "", nil)
|
||||
_, err = conn.CreateChannel(context.Background(), "", nil)
|
||||
require.ErrorIs(t, err, expectedErr)
|
||||
})
|
||||
|
||||
|
@ -274,7 +274,7 @@ func TestConn(t *testing.T) {
|
|||
client, server, _ := createPair(t)
|
||||
exchange(t, client, server)
|
||||
go func() {
|
||||
channel, err := client.Dial(context.Background(), "test", nil)
|
||||
channel, err := client.CreateChannel(context.Background(), "test", nil)
|
||||
require.NoError(t, err)
|
||||
_, err = channel.Write([]byte{1, 2})
|
||||
require.NoError(t, err)
|
||||
|
|
Loading…
Reference in New Issue