mirror of https://github.com/coder/coder.git
feat: add port-forward subcommand (#1350)
This commit is contained in:
parent
76fc59aa79
commit
9141be3656
114
agent/agent.go
114
agent/agent.go
|
@ -9,6 +9,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
|
@ -211,6 +212,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
|
||||||
go a.sshServer.HandleConn(channel.NetConn())
|
go a.sshServer.HandleConn(channel.NetConn())
|
||||||
case "reconnecting-pty":
|
case "reconnecting-pty":
|
||||||
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
|
go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
|
||||||
|
case "dial":
|
||||||
|
go a.handleDial(ctx, channel.Label(), channel.NetConn())
|
||||||
default:
|
default:
|
||||||
a.logger.Warn(ctx, "unhandled protocol from channel",
|
a.logger.Warn(ctx, "unhandled protocol from channel",
|
||||||
slog.F("protocol", channel.Protocol()),
|
slog.F("protocol", channel.Protocol()),
|
||||||
|
@ -617,6 +620,70 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// dialResponse is written to datachannels with protocol "dial" by the agent as
|
||||||
|
// the first packet to signify whether the dial succeeded or failed.
|
||||||
|
type dialResponse struct {
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) {
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
writeError := func(responseError error) error {
|
||||||
|
msg := ""
|
||||||
|
if responseError != nil {
|
||||||
|
msg = responseError.Error()
|
||||||
|
if !xerrors.Is(responseError, io.EOF) {
|
||||||
|
a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(dialResponse{
|
||||||
|
Error: msg,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Warn(ctx, "write dial response", slog.F("label", label), slog.Error(err))
|
||||||
|
return xerrors.Errorf("marshal agent webrtc dial response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.Write(b)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(label)
|
||||||
|
if err != nil {
|
||||||
|
_ = writeError(xerrors.Errorf("parse URL %q: %w", label, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
network := u.Scheme
|
||||||
|
addr := u.Host + u.Path
|
||||||
|
if strings.HasPrefix(network, "unix") {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
_ = writeError(xerrors.New("Unix forwarding is not supported from Windows workspaces"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addr, err = ExpandRelativeHomePath(addr)
|
||||||
|
if err != nil {
|
||||||
|
_ = writeError(xerrors.Errorf("expand path %q: %w", addr, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
d := net.Dialer{Timeout: 3 * time.Second}
|
||||||
|
nconn, err := d.DialContext(ctx, network, addr)
|
||||||
|
if err != nil {
|
||||||
|
_ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = writeError(nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Bicopy(ctx, conn, nconn)
|
||||||
|
}
|
||||||
|
|
||||||
// isClosed returns whether the API is closed or not.
|
// isClosed returns whether the API is closed or not.
|
||||||
func (a *agent) isClosed() bool {
|
func (a *agent) isClosed() bool {
|
||||||
select {
|
select {
|
||||||
|
@ -662,3 +729,50 @@ func (r *reconnectingPTY) Close() {
|
||||||
r.circularBuffer.Reset()
|
r.circularBuffer.Reset()
|
||||||
r.timeout.Stop()
|
r.timeout.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Bicopy copies all of the data between the two connections and will close them
|
||||||
|
// after one or both of them are done writing. If the context is canceled, both
|
||||||
|
// of the connections will be closed.
|
||||||
|
func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) {
|
||||||
|
defer c1.Close()
|
||||||
|
defer c2.Close()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
copyFunc := func(dst io.WriteCloser, src io.Reader) {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = io.Copy(dst, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(2)
|
||||||
|
go copyFunc(c1, c2)
|
||||||
|
go copyFunc(c2, c1)
|
||||||
|
|
||||||
|
// Convert waitgroup to a channel so we can also wait on the context.
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
wg.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpandRelativeHomePath expands the tilde at the beginning of a path to the
|
||||||
|
// current user's home directory and returns a full absolute path.
|
||||||
|
func ExpandRelativeHomePath(in string) (string, error) {
|
||||||
|
usr, err := user.Current()
|
||||||
|
if err != nil {
|
||||||
|
return "", xerrors.Errorf("get current user details: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if in == "~" {
|
||||||
|
in = usr.HomeDir
|
||||||
|
} else if strings.HasPrefix(in, "~/") {
|
||||||
|
in = filepath.Join(usr.HomeDir, in[2:])
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Abs(in)
|
||||||
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/pion/udp"
|
||||||
"github.com/pion/webrtc/v3"
|
"github.com/pion/webrtc/v3"
|
||||||
"github.com/pkg/sftp"
|
"github.com/pkg/sftp"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -234,6 +235,112 @@ func TestAgent(t *testing.T) {
|
||||||
findEcho()
|
findEcho()
|
||||||
findEcho()
|
findEcho()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("Dial", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
setup func(t *testing.T) net.Listener
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "TCP",
|
||||||
|
setup: func(t *testing.T) net.Listener {
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err, "create TCP listener")
|
||||||
|
return l
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP",
|
||||||
|
setup: func(t *testing.T) net.Listener {
|
||||||
|
addr := net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: 0,
|
||||||
|
}
|
||||||
|
l, err := udp.Listen("udp", &addr)
|
||||||
|
require.NoError(t, err, "create UDP listener")
|
||||||
|
return l
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unix",
|
||||||
|
setup: func(t *testing.T) net.Listener {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("Unix socket forwarding isn't supported on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
|
||||||
|
require.NoError(t, err, "create temp dir for unix listener")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = os.RemoveAll(tmpDir)
|
||||||
|
})
|
||||||
|
|
||||||
|
l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
|
||||||
|
require.NoError(t, err, "create UDP listener")
|
||||||
|
return l
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
c := c
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Setup listener
|
||||||
|
l := c.setup(t)
|
||||||
|
defer l.Close()
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
c, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go testAccept(t, c)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Dial the listener over WebRTC twice and test out of order
|
||||||
|
conn := setupAgent(t, agent.Metadata{}, 0)
|
||||||
|
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn1.Close()
|
||||||
|
conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn2.Close()
|
||||||
|
testDial(t, conn2)
|
||||||
|
testDial(t, conn1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DialError", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
// This test uses Unix listeners so we can very easily ensure that
|
||||||
|
// no other tests decide to listen on the same random port we
|
||||||
|
// picked.
|
||||||
|
t.Skip("this test is unsupported on Windows")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
|
||||||
|
require.NoError(t, err, "create temp dir")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = os.RemoveAll(tmpDir)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Try to dial the non-existent Unix socket over WebRTC
|
||||||
|
conn := setupAgent(t, agent.Metadata{}, 0)
|
||||||
|
netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock"))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorContains(t, err, "remote dial error")
|
||||||
|
require.ErrorContains(t, err, "no such file")
|
||||||
|
require.Nil(t, netConn)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
|
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
|
||||||
|
@ -303,3 +410,34 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var dialTestPayload = []byte("dean-was-here123")
|
||||||
|
|
||||||
|
func testDial(t *testing.T, c net.Conn) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
assertWritePayload(t, c, dialTestPayload)
|
||||||
|
assertReadPayload(t, c, dialTestPayload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAccept(t *testing.T, c net.Conn) {
|
||||||
|
t.Helper()
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
assertReadPayload(t, c, dialTestPayload)
|
||||||
|
assertWritePayload(t, c, dialTestPayload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
|
||||||
|
b := make([]byte, len(payload)+16)
|
||||||
|
n, err := r.Read(b)
|
||||||
|
require.NoError(t, err, "read payload")
|
||||||
|
require.Equal(t, len(payload), n, "read payload length does not match")
|
||||||
|
require.Equal(t, payload, b[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
|
||||||
|
n, err := w.Write(payload)
|
||||||
|
require.NoError(t, err, "write payload")
|
||||||
|
require.Equal(t, len(payload), n, "payload length does not match")
|
||||||
|
}
|
||||||
|
|
|
@ -2,8 +2,11 @@ package agent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
@ -32,7 +35,7 @@ type Conn struct {
|
||||||
// ReconnectingPTY returns a connection serving a TTY that can
|
// ReconnectingPTY returns a connection serving a TTY that can
|
||||||
// be reconnected to via ID.
|
// be reconnected to via ID.
|
||||||
func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error) {
|
func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error) {
|
||||||
channel, err := c.Dial(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
|
channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d", id, height, width), &peer.ChannelOptions{
|
||||||
Protocol: "reconnecting-pty",
|
Protocol: "reconnecting-pty",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -43,7 +46,7 @@ func (c *Conn) ReconnectingPTY(id string, height, width uint16) (net.Conn, error
|
||||||
|
|
||||||
// SSH dials the built-in SSH server.
|
// SSH dials the built-in SSH server.
|
||||||
func (c *Conn) SSH() (net.Conn, error) {
|
func (c *Conn) SSH() (net.Conn, error) {
|
||||||
channel, err := c.Dial(context.Background(), "ssh", &peer.ChannelOptions{
|
channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{
|
||||||
Protocol: "ssh",
|
Protocol: "ssh",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -71,6 +74,42 @@ func (c *Conn) SSHClient() (*ssh.Client, error) {
|
||||||
return ssh.NewClient(sshConn, channels, requests), nil
|
return ssh.NewClient(sshConn, channels, requests), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DialContext dials an arbitrary protocol+address from inside the workspace and
|
||||||
|
// proxies it through the provided net.Conn.
|
||||||
|
func (c *Conn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||||
|
u := &url.URL{
|
||||||
|
Scheme: network,
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(network, "unix") {
|
||||||
|
u.Path = addr
|
||||||
|
} else {
|
||||||
|
u.Host = addr
|
||||||
|
}
|
||||||
|
|
||||||
|
channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{
|
||||||
|
Protocol: "dial",
|
||||||
|
Unordered: strings.HasPrefix(network, "udp"),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("create datachannel: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The first message written from the other side is a JSON payload
|
||||||
|
// containing the dial error.
|
||||||
|
dec := json.NewDecoder(channel)
|
||||||
|
var res dialResponse
|
||||||
|
err = dec.Decode(&res)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("failed to decode initial packet: %w", err)
|
||||||
|
}
|
||||||
|
if res.Error != "" {
|
||||||
|
_ = channel.Close()
|
||||||
|
return nil, xerrors.Errorf("remote dial error: %v", res.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return channel.NetConn(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
_ = c.Negotiator.DRPCConn().Close()
|
_ = c.Negotiator.DRPCConn().Close()
|
||||||
return c.Conn.Close()
|
return c.Conn.Close()
|
||||||
|
|
|
@ -0,0 +1,379 @@
|
||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/pion/udp"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
coderagent "github.com/coder/coder/agent"
|
||||||
|
"github.com/coder/coder/cli/cliui"
|
||||||
|
"github.com/coder/coder/coderd/database"
|
||||||
|
"github.com/coder/coder/codersdk"
|
||||||
|
)
|
||||||
|
|
||||||
|
func portForward() *cobra.Command {
|
||||||
|
var (
|
||||||
|
tcpForwards []string // <port>:<port>
|
||||||
|
udpForwards []string // <port>:<port>
|
||||||
|
unixForwards []string // <path>:<path> OR <port>:<path>
|
||||||
|
)
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "port-forward <workspace>",
|
||||||
|
Aliases: []string{"tunnel"},
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
Example: `
|
||||||
|
- Port forward a single TCP port from 1234 in the workspace to port 5678 on
|
||||||
|
your local machine
|
||||||
|
|
||||||
|
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --tcp 5678:1234") + `
|
||||||
|
|
||||||
|
- Port forward a single UDP port from port 9000 to port 9000 on your local
|
||||||
|
machine
|
||||||
|
|
||||||
|
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --udp 9000") + `
|
||||||
|
|
||||||
|
- Forward a Unix socket in the workspace to a local Unix socket
|
||||||
|
|
||||||
|
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --unix ./local.sock:~/remote.sock") + `
|
||||||
|
|
||||||
|
- Forward a Unix socket in the workspace to a local TCP port
|
||||||
|
|
||||||
|
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --unix 8080:~/remote.sock") + `
|
||||||
|
|
||||||
|
- Port forward multiple TCP ports and a UDP port
|
||||||
|
|
||||||
|
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --tcp 8080:8080 --tcp 9000:3000 --udp 5353:53"),
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("parse port-forward specs: %w", err)
|
||||||
|
}
|
||||||
|
if len(specs) == 0 {
|
||||||
|
err = cmd.Help()
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("generate help output: %w", err)
|
||||||
|
}
|
||||||
|
return xerrors.New("no port-forwards requested")
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := createClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
organization, err := currentOrganization(cmd, client)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
workspace, agent, err := getWorkspaceAndAgent(cmd, client, organization.ID, codersdk.Me, args[0], false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart {
|
||||||
|
return xerrors.New("workspace must be in start transition to port-forward")
|
||||||
|
}
|
||||||
|
if workspace.LatestBuild.Job.CompletedAt == nil {
|
||||||
|
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
|
||||||
|
WorkspaceName: workspace.Name,
|
||||||
|
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
|
||||||
|
return client.WorkspaceAgent(ctx, agent.ID)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("await agent: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("dial workspace agent: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Start all listeners.
|
||||||
|
var (
|
||||||
|
ctx, cancel = context.WithCancel(cmd.Context())
|
||||||
|
wg = new(sync.WaitGroup)
|
||||||
|
listeners = make([]net.Listener, len(specs))
|
||||||
|
closeAllListeners = func() {
|
||||||
|
for _, l := range listeners {
|
||||||
|
if l == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_ = l.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
defer cancel()
|
||||||
|
for i, spec := range specs {
|
||||||
|
l, err := listenAndPortForward(ctx, cmd, conn, wg, spec)
|
||||||
|
if err != nil {
|
||||||
|
closeAllListeners()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
listeners[i] = l
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the context to be canceled or for a signal and close
|
||||||
|
// all listeners.
|
||||||
|
var closeErr error
|
||||||
|
go func() {
|
||||||
|
sigs := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
closeErr = ctx.Err()
|
||||||
|
case <-sigs:
|
||||||
|
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Received signal, closing all listeners and active connections")
|
||||||
|
closeErr = xerrors.New("signal received")
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
closeAllListeners()
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Ready!")
|
||||||
|
wg.Wait()
|
||||||
|
return closeErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Flags().StringArrayVarP(&tcpForwards, "tcp", "p", []string{}, "Forward a TCP port from the workspace to the local machine")
|
||||||
|
cmd.Flags().StringArrayVar(&udpForwards, "udp", []string{}, "Forward a UDP port from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols")
|
||||||
|
cmd.Flags().StringArrayVar(&unixForwards, "unix", []string{}, "Forward a Unix socket in the workspace to a local Unix socket or TCP port")
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *coderagent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) {
|
||||||
|
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
|
||||||
|
|
||||||
|
var (
|
||||||
|
l net.Listener
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
switch spec.listenNetwork {
|
||||||
|
case "tcp":
|
||||||
|
l, err = net.Listen(spec.listenNetwork, spec.listenAddress)
|
||||||
|
case "udp":
|
||||||
|
var host, port string
|
||||||
|
host, port, err = net.SplitHostPort(spec.listenAddress)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("split %q: %w", spec.listenAddress, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var portInt int
|
||||||
|
portInt, err = strconv.Atoi(port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{
|
||||||
|
IP: net.ParseIP(host),
|
||||||
|
Port: portInt,
|
||||||
|
})
|
||||||
|
case "unix":
|
||||||
|
l, err = net.Listen(spec.listenNetwork, spec.listenAddress)
|
||||||
|
default:
|
||||||
|
return nil, xerrors.Errorf("unknown listen network %q", spec.listenNetwork)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(spec portForwardSpec) {
|
||||||
|
defer wg.Done()
|
||||||
|
for {
|
||||||
|
netConn, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err)
|
||||||
|
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Killing listener")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func(netConn net.Conn) {
|
||||||
|
defer netConn.Close()
|
||||||
|
remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer remoteConn.Close()
|
||||||
|
|
||||||
|
coderagent.Bicopy(ctx, netConn, remoteConn)
|
||||||
|
}(netConn)
|
||||||
|
}
|
||||||
|
}(spec)
|
||||||
|
|
||||||
|
return l, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type portForwardSpec struct {
|
||||||
|
listenNetwork string // tcp, udp, unix
|
||||||
|
listenAddress string // <ip>:<port> or path
|
||||||
|
|
||||||
|
dialNetwork string // tcp, udp, unix
|
||||||
|
dialAddress string // <ip>:<port> or path
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePortForwards(tcpSpecs, udpSpecs, unixSpecs []string) ([]portForwardSpec, error) {
|
||||||
|
specs := []portForwardSpec{}
|
||||||
|
|
||||||
|
for _, spec := range tcpSpecs {
|
||||||
|
local, remote, err := parsePortPort(spec)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
specs = append(specs, portForwardSpec{
|
||||||
|
listenNetwork: "tcp",
|
||||||
|
listenAddress: fmt.Sprintf("127.0.0.1:%v", local),
|
||||||
|
dialNetwork: "tcp",
|
||||||
|
dialAddress: fmt.Sprintf("127.0.0.1:%v", remote),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, spec := range udpSpecs {
|
||||||
|
local, remote, err := parsePortPort(spec)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
specs = append(specs, portForwardSpec{
|
||||||
|
listenNetwork: "udp",
|
||||||
|
listenAddress: fmt.Sprintf("127.0.0.1:%v", local),
|
||||||
|
dialNetwork: "udp",
|
||||||
|
dialAddress: fmt.Sprintf("127.0.0.1:%v", remote),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, specStr := range unixSpecs {
|
||||||
|
localPath, localTCP, remotePath, err := parseUnixUnix(specStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("failed to parse Unix port-forward specification %q: %w", specStr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
spec := portForwardSpec{
|
||||||
|
dialNetwork: "unix",
|
||||||
|
dialAddress: remotePath,
|
||||||
|
}
|
||||||
|
if localPath == "" {
|
||||||
|
spec.listenNetwork = "tcp"
|
||||||
|
spec.listenAddress = fmt.Sprintf("127.0.0.1:%v", localTCP)
|
||||||
|
} else {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
return nil, xerrors.Errorf("Unix port-forwarding is not supported on Windows")
|
||||||
|
}
|
||||||
|
spec.listenNetwork = "unix"
|
||||||
|
spec.listenAddress = localPath
|
||||||
|
}
|
||||||
|
specs = append(specs, spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for duplicate entries.
|
||||||
|
locals := map[string]struct{}{}
|
||||||
|
for _, spec := range specs {
|
||||||
|
localStr := fmt.Sprintf("%v:%v", spec.listenNetwork, spec.listenAddress)
|
||||||
|
if _, ok := locals[localStr]; ok {
|
||||||
|
return nil, xerrors.Errorf("local %v %v is specified twice", spec.listenNetwork, spec.listenAddress)
|
||||||
|
}
|
||||||
|
locals[localStr] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return specs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePort(in string) (uint16, error) {
|
||||||
|
port, err := strconv.ParseUint(strings.TrimSpace(in), 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return 0, xerrors.Errorf("parse port %q: %w", in, err)
|
||||||
|
}
|
||||||
|
if port == 0 {
|
||||||
|
return 0, xerrors.New("port cannot be 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint16(port), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseUnixPath(in string) (string, error) {
|
||||||
|
path, err := coderagent.ExpandRelativeHomePath(strings.TrimSpace(in))
|
||||||
|
if err != nil {
|
||||||
|
return "", xerrors.Errorf("tidy path %q: %w", in, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePortPort(in string) (local uint16, remote uint16, err error) {
|
||||||
|
parts := strings.Split(in, ":")
|
||||||
|
if len(parts) > 2 {
|
||||||
|
return 0, 0, xerrors.Errorf("invalid port specification %q", in)
|
||||||
|
}
|
||||||
|
if len(parts) == 1 {
|
||||||
|
// Duplicate the single part
|
||||||
|
parts = append(parts, parts[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
local, err = parsePort(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, xerrors.Errorf("parse local port from %q: %w", in, err)
|
||||||
|
}
|
||||||
|
remote, err = parsePort(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, xerrors.Errorf("parse remote port from %q: %w", in, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return local, remote, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePortOrUnixPath(in string) (string, uint16, error) {
|
||||||
|
port, err := parsePort(in)
|
||||||
|
if err == nil {
|
||||||
|
return "", port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
path, err := parseUnixPath(in)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, xerrors.Errorf("could not parse port or unix path %q: %w", in, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return path, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseUnixUnix(in string) (string, uint16, string, error) {
|
||||||
|
parts := strings.Split(in, ":")
|
||||||
|
if len(parts) > 2 {
|
||||||
|
return "", 0, "", xerrors.Errorf("invalid port-forward specification %q", in)
|
||||||
|
}
|
||||||
|
if len(parts) == 1 {
|
||||||
|
// Duplicate the single part
|
||||||
|
parts = append(parts, parts[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
localPath, localPort, err := parsePortOrUnixPath(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, "", xerrors.Errorf("parse local part of spec %q: %w", in, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We don't really touch the remote path at all since it gets cleaned
|
||||||
|
// up/expanded on the remote.
|
||||||
|
return localPath, localPort, parts[1], nil
|
||||||
|
}
|
|
@ -0,0 +1,532 @@
|
||||||
|
package cli_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/pion/udp"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/coder/coder/cli/clitest"
|
||||||
|
"github.com/coder/coder/coderd/coderdtest"
|
||||||
|
"github.com/coder/coder/codersdk"
|
||||||
|
"github.com/coder/coder/provisioner/echo"
|
||||||
|
"github.com/coder/coder/provisionersdk/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPortForward(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("None", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
client := coderdtest.New(t, nil)
|
||||||
|
_ = coderdtest.CreateFirstUser(t, client)
|
||||||
|
|
||||||
|
cmd, root := clitest.New(t, "port-forward", "blah")
|
||||||
|
clitest.SetupConfig(t, client, root)
|
||||||
|
buf := newThreadSafeBuffer()
|
||||||
|
cmd.SetOut(buf)
|
||||||
|
|
||||||
|
err := cmd.Execute()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorContains(t, err, "no port-forwards")
|
||||||
|
|
||||||
|
// Check that the help was printed.
|
||||||
|
require.Contains(t, buf.String(), "port-forward <workspace>")
|
||||||
|
})
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
network string
|
||||||
|
// The flag to pass to `coder port-forward X` to port-forward this type
|
||||||
|
// of connection. Has two format args (both strings), the first is the
|
||||||
|
// local address and the second is the remote address.
|
||||||
|
flag string
|
||||||
|
// setupRemote creates a "remote" listener to emulate a service in the
|
||||||
|
// workspace.
|
||||||
|
setupRemote func(t *testing.T) net.Listener
|
||||||
|
// setupLocal returns an available port or Unix socket path that the
|
||||||
|
// port-forward command will listen on "locally". Returns the address
|
||||||
|
// you pass to net.Dial, and the port/path you pass to `coder
|
||||||
|
// port-forward`.
|
||||||
|
setupLocal func(t *testing.T) (string, string)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "TCP",
|
||||||
|
network: "tcp",
|
||||||
|
flag: "--tcp=%v:%v",
|
||||||
|
setupRemote: func(t *testing.T) net.Listener {
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err, "create TCP listener")
|
||||||
|
return l
|
||||||
|
},
|
||||||
|
setupLocal: func(t *testing.T) (string, string) {
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err, "create TCP listener to generate random port")
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(l.Addr().String())
|
||||||
|
require.NoErrorf(t, err, "split TCP address %q", l.Addr().String())
|
||||||
|
return l.Addr().String(), port
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP",
|
||||||
|
network: "udp",
|
||||||
|
flag: "--udp=%v:%v",
|
||||||
|
setupRemote: func(t *testing.T) net.Listener {
|
||||||
|
addr := net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: 0,
|
||||||
|
}
|
||||||
|
l, err := udp.Listen("udp", &addr)
|
||||||
|
require.NoError(t, err, "create UDP listener")
|
||||||
|
return l
|
||||||
|
},
|
||||||
|
setupLocal: func(t *testing.T) (string, string) {
|
||||||
|
addr := net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: 0,
|
||||||
|
}
|
||||||
|
l, err := udp.Listen("udp", &addr)
|
||||||
|
require.NoError(t, err, "create UDP listener to generate random port")
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(l.Addr().String())
|
||||||
|
require.NoErrorf(t, err, "split UDP address %q", l.Addr().String())
|
||||||
|
return l.Addr().String(), port
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unix",
|
||||||
|
network: "unix",
|
||||||
|
flag: "--unix=%v:%v",
|
||||||
|
setupRemote: func(t *testing.T) net.Listener {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("Unix socket forwarding isn't supported on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
|
||||||
|
require.NoError(t, err, "create temp dir for unix listener")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = os.RemoveAll(tmpDir)
|
||||||
|
})
|
||||||
|
|
||||||
|
l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
|
||||||
|
require.NoError(t, err, "create UDP listener")
|
||||||
|
return l
|
||||||
|
},
|
||||||
|
setupLocal: func(t *testing.T) (string, string) {
|
||||||
|
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
|
||||||
|
require.NoError(t, err, "create temp dir for unix listener")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = os.RemoveAll(tmpDir)
|
||||||
|
})
|
||||||
|
|
||||||
|
path := filepath.Join(tmpDir, "test.sock")
|
||||||
|
return path, path
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases { //nolint:paralleltest // the `c := c` confuses the linter
|
||||||
|
c := c
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("OnePort", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var (
|
||||||
|
client = coderdtest.New(t, nil)
|
||||||
|
user = coderdtest.CreateFirstUser(t, client)
|
||||||
|
_, workspace = runAgent(t, client, user.UserID)
|
||||||
|
l1, p1 = setupTestListener(t, c.setupRemote(t))
|
||||||
|
)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = l1.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a flag that forwards from local to listener 1.
|
||||||
|
localAddress, localFlag := c.setupLocal(t)
|
||||||
|
flag := fmt.Sprintf(c.flag, localFlag, p1)
|
||||||
|
|
||||||
|
// Launch port-forward in a goroutine so we can start dialing
|
||||||
|
// the "local" listener.
|
||||||
|
cmd, root := clitest.New(t, "port-forward", workspace.Name, flag)
|
||||||
|
clitest.SetupConfig(t, client, root)
|
||||||
|
buf := newThreadSafeBuffer()
|
||||||
|
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
err := cmd.ExecuteContext(ctx)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
}()
|
||||||
|
waitForPortForwardReady(t, buf)
|
||||||
|
|
||||||
|
// Open two connections simultaneously and test them out of
|
||||||
|
// sync.
|
||||||
|
d := net.Dialer{Timeout: 3 * time.Second}
|
||||||
|
c1, err := d.DialContext(ctx, c.network, localAddress)
|
||||||
|
require.NoError(t, err, "open connection 1 to 'local' listener")
|
||||||
|
defer c1.Close()
|
||||||
|
c2, err := d.DialContext(ctx, c.network, localAddress)
|
||||||
|
require.NoError(t, err, "open connection 2 to 'local' listener")
|
||||||
|
defer c2.Close()
|
||||||
|
testDial(t, c2)
|
||||||
|
testDial(t, c1)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TwoPorts", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var (
|
||||||
|
client = coderdtest.New(t, nil)
|
||||||
|
user = coderdtest.CreateFirstUser(t, client)
|
||||||
|
_, workspace = runAgent(t, client, user.UserID)
|
||||||
|
l1, p1 = setupTestListener(t, c.setupRemote(t))
|
||||||
|
l2, p2 = setupTestListener(t, c.setupRemote(t))
|
||||||
|
)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = l1.Close()
|
||||||
|
_ = l2.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a flags for listener 1 and listener 2.
|
||||||
|
localAddress1, localFlag1 := c.setupLocal(t)
|
||||||
|
localAddress2, localFlag2 := c.setupLocal(t)
|
||||||
|
flag1 := fmt.Sprintf(c.flag, localFlag1, p1)
|
||||||
|
flag2 := fmt.Sprintf(c.flag, localFlag2, p2)
|
||||||
|
|
||||||
|
// Launch port-forward in a goroutine so we can start dialing
|
||||||
|
// the "local" listeners.
|
||||||
|
cmd, root := clitest.New(t, "port-forward", workspace.Name, flag1, flag2)
|
||||||
|
clitest.SetupConfig(t, client, root)
|
||||||
|
buf := newThreadSafeBuffer()
|
||||||
|
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
err := cmd.ExecuteContext(ctx)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
}()
|
||||||
|
waitForPortForwardReady(t, buf)
|
||||||
|
|
||||||
|
// Open a connection to both listener 1 and 2 simultaneously and
|
||||||
|
// then test them out of order.
|
||||||
|
d := net.Dialer{Timeout: 3 * time.Second}
|
||||||
|
c1, err := d.DialContext(ctx, c.network, localAddress1)
|
||||||
|
require.NoError(t, err, "open connection 1 to 'local' listener 1")
|
||||||
|
defer c1.Close()
|
||||||
|
c2, err := d.DialContext(ctx, c.network, localAddress2)
|
||||||
|
require.NoError(t, err, "open connection 2 to 'local' listener 2")
|
||||||
|
defer c2.Close()
|
||||||
|
testDial(t, c2)
|
||||||
|
testDial(t, c1)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test doing a TCP -> Unix forward.
|
||||||
|
t.Run("TCP2Unix", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var (
|
||||||
|
client = coderdtest.New(t, nil)
|
||||||
|
user = coderdtest.CreateFirstUser(t, client)
|
||||||
|
_, workspace = runAgent(t, client, user.UserID)
|
||||||
|
|
||||||
|
// Find the TCP and Unix cases so we can use their setupLocal and
|
||||||
|
// setupRemote methods respectively.
|
||||||
|
tcpCase = cases[0]
|
||||||
|
unixCase = cases[2]
|
||||||
|
|
||||||
|
// Setup remote Unix listener.
|
||||||
|
l1, p1 = setupTestListener(t, unixCase.setupRemote(t))
|
||||||
|
)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = l1.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a flag that forwards from local TCP to Unix listener 1.
|
||||||
|
// Notably this is a --unix flag.
|
||||||
|
localAddress, localFlag := tcpCase.setupLocal(t)
|
||||||
|
flag := fmt.Sprintf(unixCase.flag, localFlag, p1)
|
||||||
|
|
||||||
|
// Launch port-forward in a goroutine so we can start dialing
|
||||||
|
// the "local" listener.
|
||||||
|
cmd, root := clitest.New(t, "port-forward", workspace.Name, flag)
|
||||||
|
clitest.SetupConfig(t, client, root)
|
||||||
|
buf := newThreadSafeBuffer()
|
||||||
|
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
err := cmd.ExecuteContext(ctx)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
}()
|
||||||
|
waitForPortForwardReady(t, buf)
|
||||||
|
|
||||||
|
// Open two connections simultaneously and test them out of
|
||||||
|
// sync.
|
||||||
|
d := net.Dialer{Timeout: 3 * time.Second}
|
||||||
|
c1, err := d.DialContext(ctx, tcpCase.network, localAddress)
|
||||||
|
require.NoError(t, err, "open connection 1 to 'local' listener")
|
||||||
|
defer c1.Close()
|
||||||
|
c2, err := d.DialContext(ctx, tcpCase.network, localAddress)
|
||||||
|
require.NoError(t, err, "open connection 2 to 'local' listener")
|
||||||
|
defer c2.Close()
|
||||||
|
testDial(t, c2)
|
||||||
|
testDial(t, c1)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test doing TCP, UDP and Unix at the same time.
|
||||||
|
t.Run("All", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
var (
|
||||||
|
client = coderdtest.New(t, nil)
|
||||||
|
user = coderdtest.CreateFirstUser(t, client)
|
||||||
|
_, workspace = runAgent(t, client, user.UserID)
|
||||||
|
// These aren't fixed size because we exclude Unix on Windows.
|
||||||
|
dials = []addr{}
|
||||||
|
flags = []string{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Start listeners and populate arrays with the cases.
|
||||||
|
for _, c := range cases {
|
||||||
|
if strings.HasPrefix(c.network, "unix") && runtime.GOOS == "windows" {
|
||||||
|
// Unix isn't supported on Windows, but we can still
|
||||||
|
// test other protocols together.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
l, p := setupTestListener(t, c.setupRemote(t))
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = l.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
localAddress, localFlag := c.setupLocal(t)
|
||||||
|
dials = append(dials, addr{
|
||||||
|
network: c.network,
|
||||||
|
addr: localAddress,
|
||||||
|
})
|
||||||
|
flags = append(flags, fmt.Sprintf(c.flag, localFlag, p))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launch port-forward in a goroutine so we can start dialing
|
||||||
|
// the "local" listeners.
|
||||||
|
cmd, root := clitest.New(t, append([]string{"port-forward", workspace.Name}, flags...)...)
|
||||||
|
clitest.SetupConfig(t, client, root)
|
||||||
|
buf := newThreadSafeBuffer()
|
||||||
|
cmd.SetOut(io.MultiWriter(buf, os.Stderr))
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
err := cmd.ExecuteContext(ctx)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
}()
|
||||||
|
waitForPortForwardReady(t, buf)
|
||||||
|
|
||||||
|
// Open connections to all items in the "dial" array.
|
||||||
|
var (
|
||||||
|
d = net.Dialer{Timeout: 3 * time.Second}
|
||||||
|
conns = make([]net.Conn, len(dials))
|
||||||
|
)
|
||||||
|
for i, a := range dials {
|
||||||
|
c, err := d.DialContext(ctx, a.network, a.addr)
|
||||||
|
require.NoErrorf(t, err, "open connection %v to 'local' listener %v", i+1, i+1)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = c.Close()
|
||||||
|
})
|
||||||
|
conns[i] = c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test each connection in reverse order.
|
||||||
|
for i := len(conns) - 1; i >= 0; i-- {
|
||||||
|
testDial(t, conns[i])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// runAgent creates a fake workspace and starts an agent locally for that
|
||||||
|
// workspace. The agent will be cleaned up on test completion.
|
||||||
|
func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]codersdk.WorkspaceResource, codersdk.Workspace) {
|
||||||
|
ctx := context.Background()
|
||||||
|
user, err := client.User(ctx, userID.String())
|
||||||
|
require.NoError(t, err, "specified user does not exist")
|
||||||
|
require.Greater(t, len(user.OrganizationIDs), 0, "user has no organizations")
|
||||||
|
orgID := user.OrganizationIDs[0]
|
||||||
|
|
||||||
|
// Setup echo provisioner
|
||||||
|
agentToken := uuid.NewString()
|
||||||
|
coderdtest.NewProvisionerDaemon(t, client)
|
||||||
|
version := coderdtest.CreateTemplateVersion(t, client, orgID, &echo.Responses{
|
||||||
|
Parse: echo.ParseComplete,
|
||||||
|
ProvisionDryRun: echo.ProvisionComplete,
|
||||||
|
Provision: []*proto.Provision_Response{{
|
||||||
|
Type: &proto.Provision_Response_Complete{
|
||||||
|
Complete: &proto.Provision_Complete{
|
||||||
|
Resources: []*proto.Resource{{
|
||||||
|
Name: "somename",
|
||||||
|
Type: "someinstance",
|
||||||
|
Agents: []*proto.Agent{{
|
||||||
|
Auth: &proto.Agent_Token{
|
||||||
|
Token: agentToken,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create template and workspace
|
||||||
|
template := coderdtest.CreateTemplate(t, client, orgID, version.ID)
|
||||||
|
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
|
||||||
|
workspace := coderdtest.CreateWorkspace(t, client, orgID, template.ID)
|
||||||
|
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||||
|
|
||||||
|
// Start workspace agent in a goroutine
|
||||||
|
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String())
|
||||||
|
clitest.SetupConfig(t, client, root)
|
||||||
|
agentCtx, agentCancel := context.WithCancel(ctx)
|
||||||
|
t.Cleanup(agentCancel)
|
||||||
|
go func() {
|
||||||
|
err := cmd.ExecuteContext(agentCtx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||||
|
resources, err := client.WorkspaceResourcesByBuild(context.Background(), workspace.LatestBuild.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return resources, workspace
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupTestListener starts accepting connections and echoing a single packet.
|
||||||
|
// Returns the listener and the listen port or Unix path.
|
||||||
|
func setupTestListener(t *testing.T, l net.Listener) (net.Listener, string) {
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = l.Close()
|
||||||
|
})
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
c, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go testAccept(t, c)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
addr := l.Addr().String()
|
||||||
|
if !strings.HasPrefix(l.Addr().Network(), "unix") {
|
||||||
|
_, port, err := net.SplitHostPort(addr)
|
||||||
|
require.NoErrorf(t, err, "split non-Unix listen path %q", addr)
|
||||||
|
addr = port
|
||||||
|
}
|
||||||
|
|
||||||
|
return l, addr
|
||||||
|
}
|
||||||
|
|
||||||
|
var dialTestPayload = []byte("dean-was-here123")
|
||||||
|
|
||||||
|
func testDial(t *testing.T, c net.Conn) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
assertWritePayload(t, c, dialTestPayload)
|
||||||
|
assertReadPayload(t, c, dialTestPayload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAccept(t *testing.T, c net.Conn) {
|
||||||
|
t.Helper()
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
assertReadPayload(t, c, dialTestPayload)
|
||||||
|
assertWritePayload(t, c, dialTestPayload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
|
||||||
|
b := make([]byte, len(payload)+16)
|
||||||
|
n, err := r.Read(b)
|
||||||
|
require.NoError(t, err, "read payload")
|
||||||
|
require.Equal(t, len(payload), n, "read payload length does not match")
|
||||||
|
require.Equal(t, payload, b[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
|
||||||
|
n, err := w.Write(payload)
|
||||||
|
require.NoError(t, err, "write payload")
|
||||||
|
require.Equal(t, len(payload), n, "payload length does not match")
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForPortForwardReady(t *testing.T, output *threadSafeBuffer) {
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
|
||||||
|
data := output.String()
|
||||||
|
if strings.Contains(data, "Ready!") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Fatal("port-forward command did not become ready in time")
|
||||||
|
}
|
||||||
|
|
||||||
|
type addr struct {
|
||||||
|
network string
|
||||||
|
addr string
|
||||||
|
}
|
||||||
|
|
||||||
|
type threadSafeBuffer struct {
|
||||||
|
b *bytes.Buffer
|
||||||
|
mut *sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newThreadSafeBuffer() *threadSafeBuffer {
|
||||||
|
return &threadSafeBuffer{
|
||||||
|
b: bytes.NewBuffer(nil),
|
||||||
|
mut: new(sync.RWMutex),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ io.Reader = &threadSafeBuffer{}
|
||||||
|
var _ io.Writer = &threadSafeBuffer{}
|
||||||
|
|
||||||
|
// Read implements io.Reader.
|
||||||
|
func (b *threadSafeBuffer) Read(p []byte) (int, error) {
|
||||||
|
b.mut.RLock()
|
||||||
|
defer b.mut.RUnlock()
|
||||||
|
|
||||||
|
return b.b.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements io.Writer.
|
||||||
|
func (b *threadSafeBuffer) Write(p []byte) (int, error) {
|
||||||
|
b.mut.Lock()
|
||||||
|
defer b.mut.Unlock()
|
||||||
|
|
||||||
|
return b.b.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *threadSafeBuffer) String() string {
|
||||||
|
b.mut.RLock()
|
||||||
|
defer b.mut.RUnlock()
|
||||||
|
|
||||||
|
return b.b.String()
|
||||||
|
}
|
19
cli/root.go
19
cli/root.go
|
@ -36,6 +36,15 @@ const (
|
||||||
varForceTty = "force-tty"
|
varForceTty = "force-tty"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Customizes the color of headings to make subcommands more visually
|
||||||
|
// appealing.
|
||||||
|
header := cliui.Styles.Placeholder
|
||||||
|
cobra.AddTemplateFunc("usageHeader", func(s string) string {
|
||||||
|
return header.Render(s)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func Root() *cobra.Command {
|
func Root() *cobra.Command {
|
||||||
cmd := &cobra.Command{
|
cmd := &cobra.Command{
|
||||||
Use: "coder",
|
Use: "coder",
|
||||||
|
@ -71,7 +80,7 @@ func Root() *cobra.Command {
|
||||||
templates(),
|
templates(),
|
||||||
update(),
|
update(),
|
||||||
users(),
|
users(),
|
||||||
tunnel(),
|
portForward(),
|
||||||
workspaceAgent(),
|
workspaceAgent(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -179,13 +188,7 @@ func isTTY(cmd *cobra.Command) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func usageTemplate() string {
|
func usageTemplate() string {
|
||||||
// Customizes the color of headings to make subcommands
|
// usageHeader is defined in init().
|
||||||
// more visually appealing.
|
|
||||||
header := cliui.Styles.Placeholder
|
|
||||||
cobra.AddTemplateFunc("usageHeader", func(s string) string {
|
|
||||||
return header.Render(s)
|
|
||||||
})
|
|
||||||
|
|
||||||
return `{{usageHeader "Usage:"}}
|
return `{{usageHeader "Usage:"}}
|
||||||
{{- if .Runnable}}
|
{{- if .Runnable}}
|
||||||
{{.UseLine}}
|
{{.UseLine}}
|
||||||
|
|
159
cli/ssh.go
159
cli/ssh.go
|
@ -50,94 +50,23 @@ func ssh() *cobra.Command {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var workspace codersdk.Workspace
|
|
||||||
var workspaceParts []string
|
|
||||||
if shuffle {
|
if shuffle {
|
||||||
err := cobra.ExactArgs(0)(cmd, args)
|
err := cobra.ExactArgs(0)(cmd, args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
workspaces, err := client.WorkspacesByOwner(cmd.Context(), organization.ID, codersdk.Me)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if len(workspaces) == 0 {
|
|
||||||
return xerrors.New("no workspaces to shuffle")
|
|
||||||
}
|
|
||||||
|
|
||||||
idx, err := cryptorand.Intn(len(workspaces))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
workspace = workspaces[idx]
|
|
||||||
} else {
|
} else {
|
||||||
err := cobra.MinimumNArgs(1)(cmd, args)
|
err := cobra.MinimumNArgs(1)(cmd, args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
workspaceParts = strings.Split(args[0], ".")
|
|
||||||
workspace, err = client.WorkspaceByOwnerAndName(cmd.Context(), organization.ID, codersdk.Me, workspaceParts[0])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart {
|
workspace, agent, err := getWorkspaceAndAgent(cmd, client, organization.ID, codersdk.Me, args[0], shuffle)
|
||||||
return xerrors.New("workspace must be in start transition to ssh")
|
|
||||||
}
|
|
||||||
|
|
||||||
if workspace.LatestBuild.Job.CompletedAt == nil {
|
|
||||||
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if workspace.LatestBuild.Transition == database.WorkspaceTransitionDelete {
|
|
||||||
return xerrors.New("workspace is deleting...")
|
|
||||||
}
|
|
||||||
|
|
||||||
resources, err := client.WorkspaceResourcesByBuild(cmd.Context(), workspace.LatestBuild.ID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
agents := make([]codersdk.WorkspaceAgent, 0)
|
|
||||||
for _, resource := range resources {
|
|
||||||
agents = append(agents, resource.Agents...)
|
|
||||||
}
|
|
||||||
if len(agents) == 0 {
|
|
||||||
return xerrors.New("workspace has no agents")
|
|
||||||
}
|
|
||||||
var agent codersdk.WorkspaceAgent
|
|
||||||
if len(workspaceParts) >= 2 {
|
|
||||||
for _, otherAgent := range agents {
|
|
||||||
if otherAgent.Name != workspaceParts[1] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
agent = otherAgent
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if agent.ID == uuid.Nil {
|
|
||||||
return xerrors.Errorf("agent not found by name %q", workspaceParts[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if agent.ID == uuid.Nil {
|
|
||||||
if len(agents) > 1 {
|
|
||||||
if !shuffle {
|
|
||||||
return xerrors.New("you must specify the name of an agent")
|
|
||||||
}
|
|
||||||
idx, err := cryptorand.Intn(len(agents))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
agent = agents[idx]
|
|
||||||
} else {
|
|
||||||
agent = agents[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// OpenSSH passes stderr directly to the calling TTY.
|
// OpenSSH passes stderr directly to the calling TTY.
|
||||||
// This is required in "stdio" mode so a connecting indicator can be displayed.
|
// This is required in "stdio" mode so a connecting indicator can be displayed.
|
||||||
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
|
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
|
||||||
|
@ -233,6 +162,92 @@ func ssh() *cobra.Command {
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getWorkspaceAgent returns the workspace and agent selected using either the
|
||||||
|
// `<workspace>[.<agent>]` syntax via `in` or picks a random workspace and agent
|
||||||
|
// if `shuffle` is true.
|
||||||
|
func getWorkspaceAndAgent(cmd *cobra.Command, client *codersdk.Client, orgID uuid.UUID, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive
|
||||||
|
ctx := cmd.Context()
|
||||||
|
|
||||||
|
var (
|
||||||
|
workspace codersdk.Workspace
|
||||||
|
workspaceParts = strings.Split(in, ".")
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if shuffle {
|
||||||
|
workspaces, err := client.WorkspacesByOwner(cmd.Context(), orgID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
|
||||||
|
}
|
||||||
|
if len(workspaces) == 0 {
|
||||||
|
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("no workspaces to shuffle")
|
||||||
|
}
|
||||||
|
|
||||||
|
workspace, err = cryptorand.Element(workspaces)
|
||||||
|
if err != nil {
|
||||||
|
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
workspace, err = client.WorkspaceByOwnerAndName(cmd.Context(), orgID, userID, workspaceParts[0])
|
||||||
|
if err != nil {
|
||||||
|
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart {
|
||||||
|
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("workspace must be in start transition to ssh")
|
||||||
|
}
|
||||||
|
if workspace.LatestBuild.Job.CompletedAt == nil {
|
||||||
|
err := cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if workspace.LatestBuild.Transition == database.WorkspaceTransitionDelete {
|
||||||
|
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q is being deleted", workspace.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
|
||||||
|
if err != nil {
|
||||||
|
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("fetch workspace resources: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
agents := make([]codersdk.WorkspaceAgent, 0)
|
||||||
|
for _, resource := range resources {
|
||||||
|
agents = append(agents, resource.Agents...)
|
||||||
|
}
|
||||||
|
if len(agents) == 0 {
|
||||||
|
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q has no agents", workspace.Name)
|
||||||
|
}
|
||||||
|
var agent codersdk.WorkspaceAgent
|
||||||
|
if len(workspaceParts) >= 2 {
|
||||||
|
for _, otherAgent := range agents {
|
||||||
|
if otherAgent.Name != workspaceParts[1] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
agent = otherAgent
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if agent.ID == uuid.Nil {
|
||||||
|
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("agent not found by name %q", workspaceParts[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if agent.ID == uuid.Nil {
|
||||||
|
if len(agents) > 1 {
|
||||||
|
if !shuffle {
|
||||||
|
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("you must specify the name of an agent")
|
||||||
|
}
|
||||||
|
agent, err = cryptorand.Element(agents)
|
||||||
|
if err != nil {
|
||||||
|
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
agent = agents[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return workspace, agent, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Attempt to poll workspace autostop. We write a per-workspace lockfile to
|
// Attempt to poll workspace autostop. We write a per-workspace lockfile to
|
||||||
// avoid spamming the user with notifications in case of multiple instances
|
// avoid spamming the user with notifications in case of multiple instances
|
||||||
// of the CLI running simultaneously.
|
// of the CLI running simultaneously.
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
package cli
|
package cli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/fatih/color"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/coder/coder/cli/cliui"
|
||||||
)
|
)
|
||||||
|
|
||||||
func templates() *cobra.Command {
|
func templates() *cobra.Command {
|
||||||
|
@ -13,15 +14,15 @@ func templates() *cobra.Command {
|
||||||
Example: `
|
Example: `
|
||||||
- Create a template for developers to create workspaces
|
- Create a template for developers to create workspaces
|
||||||
|
|
||||||
` + color.New(color.FgHiMagenta).Sprint("$ coder templates create") + `
|
` + cliui.Styles.Code.Render("$ coder templates create") + `
|
||||||
|
|
||||||
- Make changes to your template, and plan the changes
|
- Make changes to your template, and plan the changes
|
||||||
|
|
||||||
` + color.New(color.FgHiMagenta).Sprint("$ coder templates plan <name>") + `
|
` + cliui.Styles.Code.Render("$ coder templates plan <name>") + `
|
||||||
|
|
||||||
- Update the template. Your developers can update their workspaces
|
- Update the template. Your developers can update their workspaces
|
||||||
|
|
||||||
` + color.New(color.FgHiMagenta).Sprint("$ coder templates update <name>"),
|
` + cliui.Styles.Code.Render("$ coder templates update <name>"),
|
||||||
}
|
}
|
||||||
cmd.AddCommand(
|
cmd.AddCommand(
|
||||||
templateCreate(),
|
templateCreate(),
|
||||||
|
|
|
@ -1,14 +0,0 @@
|
||||||
package cli
|
|
||||||
|
|
||||||
import "github.com/spf13/cobra"
|
|
||||||
|
|
||||||
func tunnel() *cobra.Command {
|
|
||||||
return &cobra.Command{
|
|
||||||
Annotations: workspaceCommand,
|
|
||||||
Use: "tunnel",
|
|
||||||
Short: "Forward ports to your local machine",
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
package cryptorand
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Element returns a random element of the slice. An error will be returned if
|
||||||
|
// the slice has no elements in it.
|
||||||
|
func Element[T any](s []T) (out T, err error) {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return out, xerrors.New("slice must have at least one element")
|
||||||
|
}
|
||||||
|
|
||||||
|
i, err := Intn(len(s))
|
||||||
|
if err != nil {
|
||||||
|
return out, xerrors.Errorf("generate random integer from 0-%v: %w", len(s), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s[i], nil
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
package cryptorand_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/coder/coder/cryptorand"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRandomElement(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("Empty", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
s := []string{}
|
||||||
|
v, err := cryptorand.Element(s)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorContains(t, err, "slice must have at least one element")
|
||||||
|
require.Empty(t, v)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OK", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Generate random slices of ints and strings
|
||||||
|
var (
|
||||||
|
ints = make([]int, 20)
|
||||||
|
strings = make([]string, 20)
|
||||||
|
)
|
||||||
|
for i := range ints {
|
||||||
|
v, err := cryptorand.Intn(1024)
|
||||||
|
require.NoError(t, err, "generate random int for test slice")
|
||||||
|
ints[i] = v
|
||||||
|
}
|
||||||
|
for i := range strings {
|
||||||
|
v, err := cryptorand.String(10)
|
||||||
|
require.NoError(t, err, "generate random string for test slice")
|
||||||
|
strings[i] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get a random value from each 20 times.
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
iv, err := cryptorand.Element(ints)
|
||||||
|
require.NoError(t, err, "unexpected error from Element(ints)")
|
||||||
|
t.Logf("random int slice element: %v", iv)
|
||||||
|
require.Contains(t, ints, iv)
|
||||||
|
|
||||||
|
sv, err := cryptorand.Element(strings)
|
||||||
|
require.NoError(t, err, "unexpected error from Element(strings)")
|
||||||
|
t.Logf("random string slice element: %v", sv)
|
||||||
|
require.Contains(t, strings, sv)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
2
go.mod
2
go.mod
|
@ -90,6 +90,7 @@ require (
|
||||||
github.com/pion/logging v0.2.2
|
github.com/pion/logging v0.2.2
|
||||||
github.com/pion/transport v0.13.0
|
github.com/pion/transport v0.13.0
|
||||||
github.com/pion/turn/v2 v2.0.8
|
github.com/pion/turn/v2 v2.0.8
|
||||||
|
github.com/pion/udp v0.1.1
|
||||||
github.com/pion/webrtc/v3 v3.1.39
|
github.com/pion/webrtc/v3 v3.1.39
|
||||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8
|
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8
|
||||||
github.com/pkg/sftp v1.13.4
|
github.com/pkg/sftp v1.13.4
|
||||||
|
@ -217,7 +218,6 @@ require (
|
||||||
github.com/pion/sdp/v3 v3.0.5 // indirect
|
github.com/pion/sdp/v3 v3.0.5 // indirect
|
||||||
github.com/pion/srtp/v2 v2.0.7 // indirect
|
github.com/pion/srtp/v2 v2.0.7 // indirect
|
||||||
github.com/pion/stun v0.3.5 // indirect
|
github.com/pion/stun v0.3.5 // indirect
|
||||||
github.com/pion/udp v0.1.1 // indirect
|
|
||||||
github.com/pires/go-proxyproto v0.6.2 // indirect
|
github.com/pires/go-proxyproto v0.6.2 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
|
|
@ -53,8 +53,8 @@ type ChannelOptions struct {
|
||||||
// Arbitrary string that can be parsed on `Accept`.
|
// Arbitrary string that can be parsed on `Accept`.
|
||||||
Protocol string
|
Protocol string
|
||||||
|
|
||||||
// Ordered determines whether the channel acts like
|
// Unordered determines whether the channel acts like
|
||||||
// a TCP connection. Defaults to false.
|
// a UDP connection. Defaults to false.
|
||||||
Unordered bool
|
Unordered bool
|
||||||
|
|
||||||
// Whether the channel will be left open on disconnect or not.
|
// Whether the channel will be left open on disconnect or not.
|
||||||
|
|
19
peer/conn.go
19
peer/conn.go
|
@ -68,7 +68,7 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp
|
||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
closedRTC: make(chan struct{}),
|
closedRTC: make(chan struct{}),
|
||||||
closedICE: make(chan struct{}),
|
closedICE: make(chan struct{}),
|
||||||
dcOpenChannel: make(chan *webrtc.DataChannel),
|
dcOpenChannel: make(chan *webrtc.DataChannel, 8),
|
||||||
dcDisconnectChannel: make(chan struct{}),
|
dcDisconnectChannel: make(chan struct{}),
|
||||||
dcFailedChannel: make(chan struct{}),
|
dcFailedChannel: make(chan struct{}),
|
||||||
localCandidateChannel: make(chan webrtc.ICECandidateInit),
|
localCandidateChannel: make(chan webrtc.ICECandidateInit),
|
||||||
|
@ -264,12 +264,13 @@ func (c *Conn) init() error {
|
||||||
}()
|
}()
|
||||||
})
|
})
|
||||||
c.rtc.OnDataChannel(func(dc *webrtc.DataChannel) {
|
c.rtc.OnDataChannel(func(dc *webrtc.DataChannel) {
|
||||||
select {
|
go func() {
|
||||||
case <-c.closed:
|
select {
|
||||||
return
|
case <-c.closed:
|
||||||
case c.dcOpenChannel <- dc:
|
return
|
||||||
default:
|
case c.dcOpenChannel <- dc:
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
})
|
})
|
||||||
_, err := c.pingChannel()
|
_, err := c.pingChannel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -469,8 +470,8 @@ func (c *Conn) Accept(ctx context.Context) (*Channel, error) {
|
||||||
return newChannel(c, dataChannel, &ChannelOptions{}), nil
|
return newChannel(c, dataChannel, &ChannelOptions{}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial creates a new DataChannel.
|
// CreateChannel creates a new DataChannel.
|
||||||
func (c *Conn) Dial(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
|
func (c *Conn) CreateChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
|
||||||
if opts == nil {
|
if opts == nil {
|
||||||
opts = &ChannelOptions{}
|
opts = &ChannelOptions{}
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,7 +90,7 @@ func TestConn(t *testing.T) {
|
||||||
_, err := server.Ping()
|
_, err := server.Ping()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Create a channel that closes on disconnect.
|
// Create a channel that closes on disconnect.
|
||||||
channel, err := server.Dial(context.Background(), "wow", nil)
|
channel, err := server.CreateChannel(context.Background(), "wow", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = wan.Stop()
|
err = wan.Stop()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -108,7 +108,7 @@ func TestConn(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
client, server, _ := createPair(t)
|
client, server, _ := createPair(t)
|
||||||
exchange(t, client, server)
|
exchange(t, client, server)
|
||||||
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
|
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
sch, err := server.Accept(context.Background())
|
sch, err := server.Accept(context.Background())
|
||||||
|
@ -124,7 +124,7 @@ func TestConn(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
client, server, wan := createPair(t)
|
client, server, wan := createPair(t)
|
||||||
exchange(t, client, server)
|
exchange(t, client, server)
|
||||||
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
|
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sch, err := server.Accept(context.Background())
|
sch, err := server.Accept(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -141,7 +141,7 @@ func TestConn(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
client, server, _ := createPair(t)
|
client, server, _ := createPair(t)
|
||||||
exchange(t, client, server)
|
exchange(t, client, server)
|
||||||
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
|
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
sch, err := server.Accept(context.Background())
|
sch, err := server.Accept(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -196,7 +196,7 @@ func TestConn(t *testing.T) {
|
||||||
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
|
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
var cch *peer.Channel
|
var cch *peer.Channel
|
||||||
defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
cch, err = client.Dial(ctx, "hello", &peer.ChannelOptions{})
|
cch, err = client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -234,7 +234,7 @@ func TestConn(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
expectedErr := xerrors.New("wow")
|
expectedErr := xerrors.New("wow")
|
||||||
_ = conn.CloseWithError(expectedErr)
|
_ = conn.CloseWithError(expectedErr)
|
||||||
_, err = conn.Dial(context.Background(), "", nil)
|
_, err = conn.CreateChannel(context.Background(), "", nil)
|
||||||
require.ErrorIs(t, err, expectedErr)
|
require.ErrorIs(t, err, expectedErr)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -274,7 +274,7 @@ func TestConn(t *testing.T) {
|
||||||
client, server, _ := createPair(t)
|
client, server, _ := createPair(t)
|
||||||
exchange(t, client, server)
|
exchange(t, client, server)
|
||||||
go func() {
|
go func() {
|
||||||
channel, err := client.Dial(context.Background(), "test", nil)
|
channel, err := client.CreateChannel(context.Background(), "test", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = channel.Write([]byte{1, 2})
|
_, err = channel.Write([]byte{1, 2})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
Loading…
Reference in New Issue