feat: replace vscodeipc with vscodessh (#5645)

The VS Code extension has been refactored to use VS Code
Remote SSH instead of using the private API.

This changes the structure to continue using SSH, but
output network information periodically to a file.
This commit is contained in:
Kyle Carberry 2023-01-09 22:23:17 -06:00 committed by GitHub
parent fa7deaaa5c
commit 9f6edab53b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 324 additions and 650 deletions

View File

@ -304,7 +304,18 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
if err != nil {
return
}
go a.sshServer.HandleConn(conn)
closed := make(chan struct{})
_ = a.trackConnGoroutine(func() {
select {
case <-network.Closed():
case <-closed:
}
_ = conn.Close()
})
_ = a.trackConnGoroutine(func() {
defer close(closed)
a.sshServer.HandleConn(conn)
})
}
}); err != nil {
return nil, err

View File

@ -97,7 +97,7 @@ func Core() []*cobra.Command {
update(),
users(),
versionCmd(),
vscodeipcCmd(),
vscodeSSH(),
workspaceAgent(),
}
}

View File

@ -256,7 +256,7 @@ func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *coder
)
if shuffle {
res, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{
Owner: codersdk.Me,
Owner: userID,
})
if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err

View File

@ -1,88 +0,0 @@
package cli
import (
"fmt"
"net"
"net/http"
"net/url"
"github.com/google/uuid"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/vscodeipc"
"github.com/coder/coder/codersdk"
)
// vscodeipcCmd spawns a local HTTP server on the provided port that listens to messages.
// It's made for use by the Coder VS Code extension. See: https://github.com/coder/vscode-coder
func vscodeipcCmd() *cobra.Command {
var (
rawURL string
token string
port uint16
)
cmd := &cobra.Command{
Use: "vscodeipc <workspace-agent>",
Args: cobra.ExactArgs(1),
SilenceUsage: true,
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
if rawURL == "" {
return xerrors.New("CODER_URL must be set!")
}
// token is validated in a header on each request to prevent
// unauthenticated clients from connecting.
if token == "" {
return xerrors.New("CODER_TOKEN must be set!")
}
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return xerrors.Errorf("listen: %w", err)
}
defer listener.Close()
addr, ok := listener.Addr().(*net.TCPAddr)
if !ok {
return xerrors.Errorf("listener.Addr() is not a *net.TCPAddr: %T", listener.Addr())
}
url, err := url.Parse(rawURL)
if err != nil {
return err
}
agentID, err := uuid.Parse(args[0])
if err != nil {
return err
}
client := codersdk.New(url)
client.SetSessionToken(token)
handler, closer, err := vscodeipc.New(cmd.Context(), client, agentID, nil)
if err != nil {
return err
}
defer closer.Close()
// nolint:gosec
server := http.Server{
Handler: handler,
}
defer server.Close()
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", addr.String())
errChan := make(chan error, 1)
go func() {
err := server.Serve(listener)
errChan <- err
}()
select {
case <-cmd.Context().Done():
return cmd.Context().Err()
case err := <-errChan:
return err
}
},
}
cliflag.StringVarP(cmd.Flags(), &rawURL, "url", "u", "CODER_URL", "", "The URL of the Coder instance!")
cliflag.StringVarP(cmd.Flags(), &token, "token", "t", "CODER_TOKEN", "", "The session token of the user!")
cmd.Flags().Uint16VarP(&port, "port", "p", 0, "The port to listen on!")
return cmd
}

View File

@ -1,313 +0,0 @@
package vscodeipc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
)
const AuthHeader = "Coder-IPC-Token"
// New creates a VS Code IPC client that can be used to communicate with workspaces.
//
// Creating this IPC was required instead of using SSH, because we're unable to get
// connection information to display in the bottom-bar when using SSH. It's possible
// we could jank around this (maybe by using a temporary SSH host), but that's not
// ideal.
//
// This persists a single workspace connection, and lets you execute commands, check
// for network information, and forward ports.
//
// The VS Code extension is located at https://github.com/coder/vscode-coder. The
// extension downloads the slim binary from `/bin/*` and executes `coder vscodeipc`
// which calls this function. This API must maintain backward compatibility with
// the extension to support prior versions of Coder.
func New(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, options *codersdk.DialWorkspaceAgentOptions) (http.Handler, io.Closer, error) {
if options == nil {
options = &codersdk.DialWorkspaceAgentOptions{}
}
// We need this to track upload and download!
options.EnableTrafficStats = true
agentConn, err := client.DialWorkspaceAgent(ctx, agentID, options)
if err != nil {
return nil, nil, err
}
api := &api{
agentConn: agentConn,
}
r := chi.NewRouter()
// This is to prevent unauthorized clients on the same machine from executing
// requests on behalf of the workspace.
r.Use(sessionTokenMiddleware(client.SessionToken()))
r.Route("/v1", func(r chi.Router) {
r.Get("/port/{port}", api.port)
r.Get("/network", api.network)
r.Post("/execute", api.execute)
})
return r, api, nil
}
type api struct {
agentConn *codersdk.AgentConn
sshClient *ssh.Client
sshClientErr error
sshClientOnce sync.Once
lastNetwork time.Time
}
func (api *api) Close() error {
if api.sshClient != nil {
api.sshClient.Close()
}
return api.agentConn.Close()
}
type NetworkResponse struct {
P2P bool `json:"p2p"`
Latency float64 `json:"latency"`
PreferredDERP string `json:"preferred_derp"`
DERPLatency map[string]float64 `json:"derp_latency"`
UploadBytesSec int64 `json:"upload_bytes_sec"`
DownloadBytesSec int64 `json:"download_bytes_sec"`
}
// port accepts an HTTP request to dial a port on the workspace agent.
// It uses an HTTP connection upgrade to transfer the connection to TCP.
func (api *api) port(w http.ResponseWriter, r *http.Request) {
port, err := strconv.Atoi(chi.URLParam(r, "port"))
if err != nil {
httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{
Message: "Port must be an integer!",
})
return
}
remoteConn, err := api.agentConn.DialContext(r.Context(), "tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
httpapi.InternalServerError(w, err)
return
}
defer remoteConn.Close()
// Upgrade an switch to TCP!
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "tcp")
w.WriteHeader(http.StatusSwitchingProtocols)
hijacker, ok := w.(http.Hijacker)
if !ok {
httpapi.InternalServerError(w, xerrors.Errorf("unable to hijack connection: %T", w))
return
}
localConn, brw, err := hijacker.Hijack()
if err != nil {
httpapi.InternalServerError(w, err)
return
}
defer localConn.Close()
_ = brw.Flush()
agent.Bicopy(r.Context(), localConn, remoteConn)
}
// network returns network information about the workspace.
func (api *api) network(w http.ResponseWriter, r *http.Request) {
// Ping the workspace agent to get the latency.
latency, p2p, err := api.agentConn.Ping(r.Context())
if err != nil {
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to ping the workspace agent.",
Detail: err.Error(),
})
return
}
node := api.agentConn.Node()
derpMap := api.agentConn.DERPMap()
derpLatency := map[string]float64{}
// Convert DERP region IDs to friendly names for display in the UI.
for rawRegion, latency := range node.DERPLatency {
regionParts := strings.SplitN(rawRegion, "-", 2)
regionID, err := strconv.Atoi(regionParts[0])
if err != nil {
continue
}
region, found := derpMap.Regions[regionID]
if !found {
// It's possible that a workspace agent is using an old DERPMap
// and reports regions that do not exist. If that's the case,
// report the region as unknown!
region = &tailcfg.DERPRegion{
RegionID: regionID,
RegionName: fmt.Sprintf("Unnamed %d", regionID),
}
}
// Convert the microseconds to milliseconds.
derpLatency[region.RegionName] = latency * 1000
}
totalRx := uint64(0)
totalTx := uint64(0)
for _, stat := range api.agentConn.ExtractTrafficStats() {
totalRx += stat.RxBytes
totalTx += stat.TxBytes
}
// Tracking the time since last request is required because
// ExtractTrafficStats() resets its counters after each call.
dur := time.Since(api.lastNetwork)
uploadSecs := float64(totalTx) / dur.Seconds()
downloadSecs := float64(totalRx) / dur.Seconds()
api.lastNetwork = time.Now()
httpapi.Write(r.Context(), w, http.StatusOK, NetworkResponse{
P2P: p2p,
Latency: float64(latency.Microseconds()) / 1000,
PreferredDERP: derpMap.Regions[node.PreferredDERP].RegionName,
DERPLatency: derpLatency,
UploadBytesSec: int64(uploadSecs),
DownloadBytesSec: int64(downloadSecs),
})
}
type ExecuteRequest struct {
Command string `json:"command"`
Stdin string `json:"stdin"`
}
type ExecuteResponse struct {
Data string `json:"data"`
ExitCode *int `json:"exit_code"`
}
// execute runs the command provided, streams the output back, and returns the exit code.
func (api *api) execute(w http.ResponseWriter, r *http.Request) {
var req ExecuteRequest
if !httpapi.Read(r.Context(), w, r, &req) {
return
}
api.sshClientOnce.Do(func() {
// The SSH client is lazily created because it's not needed for
// all requests. It's only needed for the execute endpoint.
//
// It's alright if this fails on the first execution, because
// a new instance of this API is created for each remote SSH request.
api.sshClient, api.sshClientErr = api.agentConn.SSHClient(context.Background())
})
if api.sshClientErr != nil {
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to create SSH client.",
Detail: api.sshClientErr.Error(),
})
return
}
session, err := api.sshClient.NewSession()
if err != nil {
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to create SSH session.",
Detail: err.Error(),
})
return
}
defer session.Close()
f, ok := w.(http.Flusher)
if !ok {
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
Message: fmt.Sprintf("http.ResponseWriter is not http.Flusher: %T", w),
})
return
}
execWriter := &execWriter{w, f}
session.Stdout = execWriter
session.Stderr = execWriter
session.Stdin = strings.NewReader(req.Stdin)
err = session.Start(req.Command)
if err != nil {
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to start SSH session.",
Detail: err.Error(),
})
return
}
err = session.Wait()
writeExit := func(exitCode int) {
data, _ := json.Marshal(&ExecuteResponse{
ExitCode: &exitCode,
})
_, _ = w.Write(data)
f.Flush()
}
if err != nil {
var exitError *ssh.ExitError
if errors.As(err, &exitError) {
writeExit(exitError.ExitStatus())
return
}
}
writeExit(0)
}
type execWriter struct {
w http.ResponseWriter
f http.Flusher
}
func (e *execWriter) Write(data []byte) (int, error) {
js, err := json.Marshal(&ExecuteResponse{
Data: string(data),
})
if err != nil {
return 0, err
}
_, err = e.w.Write(js)
if err != nil {
return 0, err
}
e.f.Flush()
return len(data), nil
}
func sessionTokenMiddleware(sessionToken string) func(h http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get(AuthHeader)
if token == "" {
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
Message: fmt.Sprintf("A session token must be provided in the `%s` header.", AuthHeader),
})
return
}
if token != sessionToken {
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
Message: "The session token provided doesn't match the one used to create the client.",
})
return
}
w.Header().Set("Access-Control-Allow-Origin", "*")
h.ServeHTTP(w, r)
})
}
}

View File

@ -1,202 +0,0 @@
package vscodeipc_test
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"runtime"
"testing"
"github.com/google/uuid"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"nhooyr.io/websocket"
"github.com/coder/coder/agent"
"github.com/coder/coder/cli/vscodeipc"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/tailnet/tailnettest"
"github.com/coder/coder/testutil"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestVSCodeIPC(t *testing.T) {
t.Parallel()
ctx := context.Background()
id := uuid.New()
derpMap := tailnettest.RunDERPAndSTUN(t)
coordinator := tailnet.NewCoordinator()
t.Cleanup(func() {
_ = coordinator.Close()
})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case fmt.Sprintf("/api/v2/workspaceagents/%s/connection", id):
assert.Equal(t, r.Method, http.MethodGet)
httpapi.Write(ctx, w, http.StatusOK, codersdk.WorkspaceAgentConnectionInfo{
DERPMap: derpMap,
})
return
case fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", id):
assert.Equal(t, r.Method, http.MethodGet)
ws, err := websocket.Accept(w, r, nil)
require.NoError(t, err)
conn := websocket.NetConn(ctx, ws, websocket.MessageBinary)
_ = coordinator.ServeClient(conn, uuid.New(), id)
return
case "/api/v2/workspaceagents/me/version":
assert.Equal(t, r.Method, http.MethodPost)
w.WriteHeader(http.StatusOK)
return
case "/api/v2/workspaceagents/me/metadata":
assert.Equal(t, r.Method, http.MethodGet)
httpapi.Write(ctx, w, http.StatusOK, codersdk.WorkspaceAgentMetadata{
DERPMap: derpMap,
})
return
case "/api/v2/workspaceagents/me/coordinate":
assert.Equal(t, r.Method, http.MethodGet)
ws, err := websocket.Accept(w, r, nil)
require.NoError(t, err)
conn := websocket.NetConn(ctx, ws, websocket.MessageBinary)
_ = coordinator.ServeAgent(conn, id)
return
case "/api/v2/workspaceagents/me/report-stats":
assert.Equal(t, r.Method, http.MethodPost)
w.WriteHeader(http.StatusOK)
return
case "/":
w.WriteHeader(http.StatusOK)
return
default:
t.Fatalf("unexpected request %s", r.URL.Path)
}
}))
t.Cleanup(srv.Close)
srvURL, _ := url.Parse(srv.URL)
client := codersdk.New(srvURL)
token := uuid.New().String()
client.SetSessionToken(token)
agentConn := agent.New(agent.Options{
Client: client,
Filesystem: afero.NewMemMapFs(),
TempDir: t.TempDir(),
})
t.Cleanup(func() {
_ = agentConn.Close()
})
handler, closer, err := vscodeipc.New(ctx, client, id, nil)
require.NoError(t, err)
t.Cleanup(func() {
_ = closer.Close()
})
// Ensure that we're actually connected!
require.Eventually(t, func() bool {
res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/v1/network", nil)
req.Header.Set(vscodeipc.AuthHeader, token)
handler.ServeHTTP(res, req)
network := &vscodeipc.NetworkResponse{}
err = json.NewDecoder(res.Body).Decode(&network)
assert.NoError(t, err)
return network.Latency != 0
}, testutil.WaitLong, testutil.IntervalFast)
_, port, err := net.SplitHostPort(srvURL.Host)
require.NoError(t, err)
t.Run("NoSessionToken", func(t *testing.T) {
t.Parallel()
res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil)
handler.ServeHTTP(res, req)
require.Equal(t, http.StatusUnauthorized, res.Code)
})
t.Run("MismatchedSessionToken", func(t *testing.T) {
t.Parallel()
res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil)
req.Header.Set(vscodeipc.AuthHeader, uuid.NewString())
handler.ServeHTTP(res, req)
require.Equal(t, http.StatusUnauthorized, res.Code)
})
t.Run("Port", func(t *testing.T) {
// Tests that the port endpoint can be used for forward traffic.
// For this test, we simply use the already listening httptest server.
t.Parallel()
input, output := net.Pipe()
defer input.Close()
defer output.Close()
res := &hijackable{httptest.NewRecorder(), output}
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil)
req.Header.Set(vscodeipc.AuthHeader, token)
go handler.ServeHTTP(res, req)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://127.0.0.1/", nil)
require.NoError(t, err)
client := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return input, nil
},
},
}
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
})
t.Run("Execute", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("Execute isn't supported on Windows yet!")
return
}
res := httptest.NewRecorder()
data, _ := json.Marshal(vscodeipc.ExecuteRequest{
Command: "echo test",
})
req := httptest.NewRequest(http.MethodPost, "/v1/execute", bytes.NewReader(data))
req.Header.Set(vscodeipc.AuthHeader, token)
handler.ServeHTTP(res, req)
decoder := json.NewDecoder(res.Body)
var msg vscodeipc.ExecuteResponse
err = decoder.Decode(&msg)
require.NoError(t, err)
require.Equal(t, "test\n", msg.Data)
err = decoder.Decode(&msg)
require.NoError(t, err)
require.Equal(t, 0, *msg.ExitCode)
})
}
type hijackable struct {
*httptest.ResponseRecorder
conn net.Conn
}
func (h *hijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return h.conn, bufio.NewReadWriter(bufio.NewReader(h.conn), bufio.NewWriter(h.conn)), nil
}

View File

@ -1,44 +0,0 @@
package cli_test
import (
"io"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/testutil"
)
func TestVSCodeIPC(t *testing.T) {
t.Parallel()
// Ensures the vscodeipc command outputs it's running port!
// This signifies to the caller that it's ready to accept requests.
t.Run("PortOutputs", func(t *testing.T) {
t.Parallel()
client, workspace, _ := setupWorkspaceForAgent(t, nil)
cmd, _ := clitest.New(t, "vscodeipc", workspace.LatestBuild.Resources[0].Agents[0].ID.String(),
"--token", client.SessionToken(), "--url", client.URL.String())
rdr, wtr := io.Pipe()
cmd.SetOut(wtr)
ctx, cancelFunc := testutil.Context(t)
defer cancelFunc()
done := make(chan error, 1)
go func() {
err := cmd.ExecuteContext(ctx)
done <- err
}()
buf := make([]byte, 64)
require.Eventually(t, func() bool {
t.Log("Looking for address!")
var err error
_, err = rdr.Read(buf)
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
t.Logf("Address: %s\n", buf)
cancelFunc()
<-done
})
}

237
cli/vscodessh.go Normal file
View File

@ -0,0 +1,237 @@
package cli
import (
"context"
"encoding/json"
"fmt"
"io"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"github.com/coder/coder/codersdk"
)
// vscodeSSH is used by the Coder VS Code extension to establish
// a connection to a workspace.
//
// This command needs to remain stable for compatibility with
// various VS Code versions, so it's kept separate from our
// standard SSH command.
func vscodeSSH() *cobra.Command {
var (
sessionTokenFile string
urlFile string
networkInfoDir string
networkInfoInterval time.Duration
)
cmd := &cobra.Command{
// A SSH config entry is added by the VS Code extension that
// passes %h to ProxyCommand. The prefix of `coder-vscode--`
// is a magical string represented in our VS Cod extension.
// It's not important here, only the delimiter `--` is.
Use: "vscodessh <coder-vscode--<owner>-<workspace>-<agent?>>",
Hidden: true,
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
if networkInfoDir == "" {
return xerrors.New("network-info-dir must be specified")
}
if sessionTokenFile == "" {
return xerrors.New("session-token-file must be specified")
}
if urlFile == "" {
return xerrors.New("url-file must be specified")
}
fs, ok := cmd.Context().Value("fs").(afero.Fs)
if !ok {
fs = afero.NewOsFs()
}
sessionToken, err := afero.ReadFile(fs, sessionTokenFile)
if err != nil {
return xerrors.Errorf("read session token: %w", err)
}
rawURL, err := afero.ReadFile(fs, urlFile)
if err != nil {
return xerrors.Errorf("read url: %w", err)
}
serverURL, err := url.Parse(string(rawURL))
if err != nil {
return xerrors.Errorf("parse url: %w", err)
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
err = fs.MkdirAll(networkInfoDir, 0700)
if err != nil {
return xerrors.Errorf("mkdir: %w", err)
}
client := codersdk.New(serverURL)
client.SetSessionToken(string(sessionToken))
parts := strings.Split(args[0], "--")
if len(parts) < 3 {
return xerrors.Errorf("invalid argument format. must be: coder-vscode--<owner>-<name>-<agent?>")
}
owner := parts[1]
name := parts[2]
workspace, err := client.WorkspaceByOwnerAndName(ctx, owner, name, codersdk.WorkspaceOptions{})
if err != nil {
return xerrors.Errorf("find workspace: %w", err)
}
var agent codersdk.WorkspaceAgent
var found bool
for _, resource := range workspace.LatestBuild.Resources {
if len(resource.Agents) == 0 {
continue
}
for _, resourceAgent := range resource.Agents {
// If an agent name isn't included we default to
// the first agent!
if len(parts) != 4 {
agent = resourceAgent
found = true
break
}
if resourceAgent.Name != parts[3] {
continue
}
agent = resourceAgent
found = true
break
}
if found {
break
}
}
agentConn, err := client.DialWorkspaceAgent(ctx, agent.ID, &codersdk.DialWorkspaceAgentOptions{
EnableTrafficStats: true,
})
if err != nil {
return xerrors.Errorf("dial workspace agent: %w", err)
}
defer agentConn.Close()
agentConn.AwaitReachable(ctx)
rawSSH, err := agentConn.SSH(ctx)
if err != nil {
return err
}
defer rawSSH.Close()
// Copy SSH traffic over stdio.
go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
}()
go func() {
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
}()
// The VS Code extension obtains the PID of the SSH process to
// read the file below which contains network information to display.
//
// We get the parent PID because it's assumed `ssh` is calling this
// command via the ProxyCommand SSH option.
networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", os.Getppid()))
ticker := time.NewTicker(networkInfoInterval)
defer ticker.Stop()
lastCollected := time.Now()
for {
select {
case <-ctx.Done():
return nil
case <-ticker.C:
}
stats, err := collectNetworkStats(ctx, agentConn, lastCollected)
if err != nil {
return err
}
rawStats, err := json.Marshal(stats)
if err != nil {
return err
}
err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0600)
if err != nil {
return err
}
lastCollected = time.Now()
}
},
}
cmd.Flags().StringVarP(&networkInfoDir, "network-info-dir", "", "", "Specifies a directory to write network information periodically.")
cmd.Flags().StringVarP(&sessionTokenFile, "session-token-file", "", "", "Specifies a file that contains a session token.")
cmd.Flags().StringVarP(&urlFile, "url-file", "", "", "Specifies a file that contains the Coder URL.")
cmd.Flags().DurationVarP(&networkInfoInterval, "network-info-interval", "", 3*time.Second, "Specifies the interval to update network information.")
return cmd
}
type sshNetworkStats struct {
P2P bool `json:"p2p"`
Latency float64 `json:"latency"`
PreferredDERP string `json:"preferred_derp"`
DERPLatency map[string]float64 `json:"derp_latency"`
UploadBytesSec int64 `json:"upload_bytes_sec"`
DownloadBytesSec int64 `json:"download_bytes_sec"`
}
func collectNetworkStats(ctx context.Context, agentConn *codersdk.AgentConn, lastCollected time.Time) (*sshNetworkStats, error) {
latency, p2p, err := agentConn.Ping(ctx)
if err != nil {
return nil, err
}
node := agentConn.Node()
derpMap := agentConn.DERPMap()
derpLatency := map[string]float64{}
// Convert DERP region IDs to friendly names for display in the UI.
for rawRegion, latency := range node.DERPLatency {
regionParts := strings.SplitN(rawRegion, "-", 2)
regionID, err := strconv.Atoi(regionParts[0])
if err != nil {
continue
}
region, found := derpMap.Regions[regionID]
if !found {
// It's possible that a workspace agent is using an old DERPMap
// and reports regions that do not exist. If that's the case,
// report the region as unknown!
region = &tailcfg.DERPRegion{
RegionID: regionID,
RegionName: fmt.Sprintf("Unnamed %d", regionID),
}
}
// Convert the microseconds to milliseconds.
derpLatency[region.RegionName] = latency * 1000
}
totalRx := uint64(0)
totalTx := uint64(0)
for _, stat := range agentConn.ExtractTrafficStats() {
totalRx += stat.RxBytes
totalTx += stat.TxBytes
}
// Tracking the time since last request is required because
// ExtractTrafficStats() resets its counters after each call.
dur := time.Since(lastCollected)
uploadSecs := float64(totalTx) / dur.Seconds()
downloadSecs := float64(totalRx) / dur.Seconds()
return &sshNetworkStats{
P2P: p2p,
Latency: float64(latency.Microseconds()) / 1000,
PreferredDERP: derpMap.Regions[node.PreferredDERP].RegionName,
DERPLatency: derpLatency,
UploadBytesSec: int64(uploadSecs),
DownloadBytesSec: int64(downloadSecs),
}, nil
}

73
cli/vscodessh_test.go Normal file
View File

@ -0,0 +1,73 @@
package cli_test
import (
"context"
"fmt"
"testing"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/testutil"
)
// TestVSCodeSSH ensures the agent connects properly with SSH
// and that network information is properly written to the FS.
func TestVSCodeSSH(t *testing.T) {
t.Parallel()
ctx, cancel := testutil.Context(t)
defer cancel()
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
user, err := client.User(ctx, codersdk.Me)
require.NoError(t, err)
agentClient := codersdk.New(client.URL)
agentClient.SetSessionToken(agentToken)
agentCloser := agent.New(agent.Options{
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"),
})
defer func() {
_ = agentCloser.Close()
}()
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
fs := afero.NewMemMapFs()
err = afero.WriteFile(fs, "/url", []byte(client.URL.String()), 0600)
require.NoError(t, err)
err = afero.WriteFile(fs, "/token", []byte(client.SessionToken()), 0600)
require.NoError(t, err)
cmd, _ := clitest.New(t,
"vscodessh",
"--url-file", "/url",
"--session-token-file", "/token",
"--network-info-dir", "/net",
"--network-info-interval", "25ms",
fmt.Sprintf("coder-vscode--%s--%s", user.Username, workspace.Name))
done := make(chan struct{})
go func() {
//nolint // The above seems reasonable for a one-off test.
err := cmd.ExecuteContext(context.WithValue(ctx, "fs", fs))
if err != nil {
assert.ErrorIs(t, err, context.Canceled)
}
close(done)
}()
require.Eventually(t, func() bool {
entries, err := afero.ReadDir(fs, "/net")
if err != nil {
return false
}
return len(entries) > 0
}, testutil.WaitLong, testutil.IntervalFast)
cancel()
<-done
}