mirror of https://github.com/coder/coder.git
feat: Add `vscodeipc` subcommand for VS Code Extension (#5326)
* Add extio * feat: Add `vscodeipc` subcommand for VS Code Extension This enables the VS Code extension to communicate with a Coder client. The extension will download the slim binary from `/bin/*` for the respective client architecture and OS, then execute `coder vscodeipc` for the connecting workspace. * Add authentication header, improve comments, and add tests for the CLI * Update cli/vscodeipc_test.go Co-authored-by: Mathias Fredriksson <mafredri@gmail.com> * Update cli/vscodeipc_test.go Co-authored-by: Mathias Fredriksson <mafredri@gmail.com> * Update cli/vscodeipc/vscodeipc_test.go Co-authored-by: Mathias Fredriksson <mafredri@gmail.com> * Fix requested changes * Fix IPC tests * Fix shell execution * Fix nix flake * Silence usage Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>
This commit is contained in:
parent
d1f8fec1d3
commit
e61234f260
|
@ -98,6 +98,7 @@ func Core() []*cobra.Command {
|
|||
users(),
|
||||
versionCmd(),
|
||||
workspaceAgent(),
|
||||
vscodeipcCmd(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -71,7 +71,7 @@ func speedtest() *cobra.Command {
|
|||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
dur, err := conn.Ping(ctx)
|
||||
dur, p2p, err := conn.Ping(ctx)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
@ -80,7 +80,7 @@ func speedtest() *cobra.Command {
|
|||
continue
|
||||
}
|
||||
peer := status.Peer[status.Peers()[0]]
|
||||
if peer.CurAddr == "" && direct {
|
||||
if !p2p && direct {
|
||||
cmd.Printf("Waiting for a direct connection... (%dms via %s)\n", dur.Milliseconds(), peer.Relay)
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -65,6 +65,8 @@ func setupWorkspaceForAgent(t *testing.T, mutate func([]*proto.Agent) []*proto.A
|
|||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
workspace, err := client.Workspace(context.Background(), workspace.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
return client, workspace, agentToken
|
||||
}
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cli/vscodeipc"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
// vscodeipcCmd spawns a local HTTP server on the provided port that listens to messages.
|
||||
// It's made for use by the Coder VS Code extension. See: https://github.com/coder/vscode-coder
|
||||
func vscodeipcCmd() *cobra.Command {
|
||||
var (
|
||||
rawURL string
|
||||
token string
|
||||
port uint16
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Use: "vscodeipc <workspace-agent>",
|
||||
Args: cobra.ExactArgs(1),
|
||||
SilenceUsage: true,
|
||||
Hidden: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if rawURL == "" {
|
||||
return xerrors.New("CODER_URL must be set!")
|
||||
}
|
||||
// token is validated in a header on each request to prevent
|
||||
// unauthenticated clients from connecting.
|
||||
if token == "" {
|
||||
return xerrors.New("CODER_TOKEN must be set!")
|
||||
}
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("listen: %w", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
addr, ok := listener.Addr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
return xerrors.Errorf("listener.Addr() is not a *net.TCPAddr: %T", listener.Addr())
|
||||
}
|
||||
url, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
agentID, err := uuid.Parse(args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client := codersdk.New(url)
|
||||
client.SetSessionToken(token)
|
||||
|
||||
handler, closer, err := vscodeipc.New(cmd.Context(), client, agentID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer closer.Close()
|
||||
// nolint:gosec
|
||||
server := http.Server{
|
||||
Handler: handler,
|
||||
}
|
||||
defer server.Close()
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", addr.String())
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := server.Serve(listener)
|
||||
errChan <- err
|
||||
}()
|
||||
select {
|
||||
case <-cmd.Context().Done():
|
||||
return cmd.Context().Err()
|
||||
case err := <-errChan:
|
||||
return err
|
||||
}
|
||||
},
|
||||
}
|
||||
cliflag.StringVarP(cmd.Flags(), &rawURL, "url", "u", "CODER_URL", "", "The URL of the Coder instance!")
|
||||
cliflag.StringVarP(cmd.Flags(), &token, "token", "t", "CODER_TOKEN", "", "The session token of the user!")
|
||||
cmd.Flags().Uint16VarP(&port, "port", "p", 0, "The port to listen on!")
|
||||
return cmd
|
||||
}
|
|
@ -0,0 +1,313 @@
|
|||
package vscodeipc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
const AuthHeader = "Coder-IPC-Token"
|
||||
|
||||
// New creates a VS Code IPC client that can be used to communicate with workspaces.
|
||||
//
|
||||
// Creating this IPC was required instead of using SSH, because we're unable to get
|
||||
// connection information to display in the bottom-bar when using SSH. It's possible
|
||||
// we could jank around this (maybe by using a temporary SSH host), but that's not
|
||||
// ideal.
|
||||
//
|
||||
// This persists a single workspace connection, and lets you execute commands, check
|
||||
// for network information, and forward ports.
|
||||
//
|
||||
// The VS Code extension is located at https://github.com/coder/vscode-coder. The
|
||||
// extension downloads the slim binary from `/bin/*` and executes `coder vscodeipc`
|
||||
// which calls this function. This API must maintain backward compatibility with
|
||||
// the extension to support prior versions of Coder.
|
||||
func New(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, options *codersdk.DialWorkspaceAgentOptions) (http.Handler, io.Closer, error) {
|
||||
if options == nil {
|
||||
options = &codersdk.DialWorkspaceAgentOptions{}
|
||||
}
|
||||
// We need this to track upload and download!
|
||||
options.EnableTrafficStats = true
|
||||
|
||||
agentConn, err := client.DialWorkspaceAgent(ctx, agentID, options)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
api := &api{
|
||||
agentConn: agentConn,
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
// This is to prevent unauthorized clients on the same machine from executing
|
||||
// requests on behalf of the workspace.
|
||||
r.Use(sessionTokenMiddleware(client.SessionToken()))
|
||||
r.Route("/v1", func(r chi.Router) {
|
||||
r.Get("/port/{port}", api.port)
|
||||
r.Get("/network", api.network)
|
||||
r.Post("/execute", api.execute)
|
||||
})
|
||||
return r, api, nil
|
||||
}
|
||||
|
||||
type api struct {
|
||||
agentConn *codersdk.AgentConn
|
||||
sshClient *ssh.Client
|
||||
sshClientErr error
|
||||
sshClientOnce sync.Once
|
||||
|
||||
lastNetwork time.Time
|
||||
}
|
||||
|
||||
func (api *api) Close() error {
|
||||
if api.sshClient != nil {
|
||||
api.sshClient.Close()
|
||||
}
|
||||
return api.agentConn.Close()
|
||||
}
|
||||
|
||||
type NetworkResponse struct {
|
||||
P2P bool `json:"p2p"`
|
||||
Latency float64 `json:"latency"`
|
||||
PreferredDERP string `json:"preferred_derp"`
|
||||
DERPLatency map[string]float64 `json:"derp_latency"`
|
||||
UploadBytesSec int64 `json:"upload_bytes_sec"`
|
||||
DownloadBytesSec int64 `json:"download_bytes_sec"`
|
||||
}
|
||||
|
||||
// port accepts an HTTP request to dial a port on the workspace agent.
|
||||
// It uses an HTTP connection upgrade to transfer the connection to TCP.
|
||||
func (api *api) port(w http.ResponseWriter, r *http.Request) {
|
||||
port, err := strconv.Atoi(chi.URLParam(r, "port"))
|
||||
if err != nil {
|
||||
httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Port must be an integer!",
|
||||
})
|
||||
return
|
||||
}
|
||||
remoteConn, err := api.agentConn.DialContext(r.Context(), "tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(w, err)
|
||||
return
|
||||
}
|
||||
defer remoteConn.Close()
|
||||
|
||||
// Upgrade an switch to TCP!
|
||||
w.Header().Set("Connection", "Upgrade")
|
||||
w.Header().Set("Upgrade", "tcp")
|
||||
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
httpapi.InternalServerError(w, xerrors.Errorf("unable to hijack connection: %T", w))
|
||||
return
|
||||
}
|
||||
|
||||
localConn, brw, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(w, err)
|
||||
return
|
||||
}
|
||||
defer localConn.Close()
|
||||
|
||||
_ = brw.Flush()
|
||||
agent.Bicopy(r.Context(), localConn, remoteConn)
|
||||
}
|
||||
|
||||
// network returns network information about the workspace.
|
||||
func (api *api) network(w http.ResponseWriter, r *http.Request) {
|
||||
// Ping the workspace agent to get the latency.
|
||||
latency, p2p, err := api.agentConn.Ping(r.Context())
|
||||
if err != nil {
|
||||
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to ping the workspace agent.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
node := api.agentConn.Node()
|
||||
derpMap := api.agentConn.DERPMap()
|
||||
derpLatency := map[string]float64{}
|
||||
|
||||
// Convert DERP region IDs to friendly names for display in the UI.
|
||||
for rawRegion, latency := range node.DERPLatency {
|
||||
regionParts := strings.SplitN(rawRegion, "-", 2)
|
||||
regionID, err := strconv.Atoi(regionParts[0])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
region, found := derpMap.Regions[regionID]
|
||||
if !found {
|
||||
// It's possible that a workspace agent is using an old DERPMap
|
||||
// and reports regions that do not exist. If that's the case,
|
||||
// report the region as unknown!
|
||||
region = &tailcfg.DERPRegion{
|
||||
RegionID: regionID,
|
||||
RegionName: fmt.Sprintf("Unnamed %d", regionID),
|
||||
}
|
||||
}
|
||||
// Convert the microseconds to milliseconds.
|
||||
derpLatency[region.RegionName] = latency * 1000
|
||||
}
|
||||
|
||||
totalRx := uint64(0)
|
||||
totalTx := uint64(0)
|
||||
for _, stat := range api.agentConn.ExtractTrafficStats() {
|
||||
totalRx += stat.RxBytes
|
||||
totalTx += stat.TxBytes
|
||||
}
|
||||
// Tracking the time since last request is required because
|
||||
// ExtractTrafficStats() resets its counters after each call.
|
||||
dur := time.Since(api.lastNetwork)
|
||||
uploadSecs := float64(totalTx) / dur.Seconds()
|
||||
downloadSecs := float64(totalRx) / dur.Seconds()
|
||||
|
||||
api.lastNetwork = time.Now()
|
||||
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, NetworkResponse{
|
||||
P2P: p2p,
|
||||
Latency: float64(latency.Microseconds()) / 1000,
|
||||
PreferredDERP: derpMap.Regions[node.PreferredDERP].RegionName,
|
||||
DERPLatency: derpLatency,
|
||||
UploadBytesSec: int64(uploadSecs),
|
||||
DownloadBytesSec: int64(downloadSecs),
|
||||
})
|
||||
}
|
||||
|
||||
type ExecuteRequest struct {
|
||||
Command string `json:"command"`
|
||||
Stdin string `json:"stdin"`
|
||||
}
|
||||
|
||||
type ExecuteResponse struct {
|
||||
Data string `json:"data"`
|
||||
ExitCode *int `json:"exit_code"`
|
||||
}
|
||||
|
||||
// execute runs the command provided, streams the output back, and returns the exit code.
|
||||
func (api *api) execute(w http.ResponseWriter, r *http.Request) {
|
||||
var req ExecuteRequest
|
||||
if !httpapi.Read(r.Context(), w, r, &req) {
|
||||
return
|
||||
}
|
||||
api.sshClientOnce.Do(func() {
|
||||
// The SSH client is lazily created because it's not needed for
|
||||
// all requests. It's only needed for the execute endpoint.
|
||||
//
|
||||
// It's alright if this fails on the first execution, because
|
||||
// a new instance of this API is created for each remote SSH request.
|
||||
api.sshClient, api.sshClientErr = api.agentConn.SSHClient(context.Background())
|
||||
})
|
||||
if api.sshClientErr != nil {
|
||||
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to create SSH client.",
|
||||
Detail: api.sshClientErr.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
session, err := api.sshClient.NewSession()
|
||||
if err != nil {
|
||||
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to create SSH session.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
f, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: fmt.Sprintf("http.ResponseWriter is not http.Flusher: %T", w),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
execWriter := &execWriter{w, f}
|
||||
session.Stdout = execWriter
|
||||
session.Stderr = execWriter
|
||||
session.Stdin = strings.NewReader(req.Stdin)
|
||||
err = session.Start(req.Command)
|
||||
if err != nil {
|
||||
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to start SSH session.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
err = session.Wait()
|
||||
|
||||
writeExit := func(exitCode int) {
|
||||
data, _ := json.Marshal(&ExecuteResponse{
|
||||
ExitCode: &exitCode,
|
||||
})
|
||||
_, _ = w.Write(data)
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
var exitError *ssh.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
writeExit(exitError.ExitStatus())
|
||||
return
|
||||
}
|
||||
}
|
||||
writeExit(0)
|
||||
}
|
||||
|
||||
type execWriter struct {
|
||||
w http.ResponseWriter
|
||||
f http.Flusher
|
||||
}
|
||||
|
||||
func (e *execWriter) Write(data []byte) (int, error) {
|
||||
js, err := json.Marshal(&ExecuteResponse{
|
||||
Data: string(data),
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = e.w.Write(js)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
e.f.Flush()
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func sessionTokenMiddleware(sessionToken string) func(h http.Handler) http.Handler {
|
||||
return func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.Header.Get(AuthHeader)
|
||||
if token == "" {
|
||||
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
|
||||
Message: fmt.Sprintf("A session token must be provided in the `%s` header.", AuthHeader),
|
||||
})
|
||||
return
|
||||
}
|
||||
if token != sessionToken {
|
||||
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
|
||||
Message: "The session token provided doesn't match the one used to create the client.",
|
||||
})
|
||||
return
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,202 @@
|
|||
package vscodeipc_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/cli/vscodeipc"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/tailnet"
|
||||
"github.com/coder/coder/tailnet/tailnettest"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestVSCodeIPC(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
id := uuid.New()
|
||||
derpMap := tailnettest.RunDERPAndSTUN(t)
|
||||
coordinator := tailnet.NewCoordinator()
|
||||
t.Cleanup(func() {
|
||||
_ = coordinator.Close()
|
||||
})
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case fmt.Sprintf("/api/v2/workspaceagents/%s/connection", id):
|
||||
assert.Equal(t, r.Method, http.MethodGet)
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.WorkspaceAgentConnectionInfo{
|
||||
DERPMap: derpMap,
|
||||
})
|
||||
return
|
||||
case fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", id):
|
||||
assert.Equal(t, r.Method, http.MethodGet)
|
||||
ws, err := websocket.Accept(w, r, nil)
|
||||
require.NoError(t, err)
|
||||
conn := websocket.NetConn(ctx, ws, websocket.MessageBinary)
|
||||
_ = coordinator.ServeClient(conn, uuid.New(), id)
|
||||
return
|
||||
case "/api/v2/workspaceagents/me/version":
|
||||
assert.Equal(t, r.Method, http.MethodPost)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
case "/api/v2/workspaceagents/me/metadata":
|
||||
assert.Equal(t, r.Method, http.MethodGet)
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.WorkspaceAgentMetadata{
|
||||
DERPMap: derpMap,
|
||||
})
|
||||
return
|
||||
case "/api/v2/workspaceagents/me/coordinate":
|
||||
assert.Equal(t, r.Method, http.MethodGet)
|
||||
ws, err := websocket.Accept(w, r, nil)
|
||||
require.NoError(t, err)
|
||||
conn := websocket.NetConn(ctx, ws, websocket.MessageBinary)
|
||||
_ = coordinator.ServeAgent(conn, id)
|
||||
return
|
||||
case "/api/v2/workspaceagents/me/report-stats":
|
||||
assert.Equal(t, r.Method, http.MethodPost)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
case "/":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
default:
|
||||
t.Fatalf("unexpected request %s", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
srvURL, _ := url.Parse(srv.URL)
|
||||
|
||||
client := codersdk.New(srvURL)
|
||||
token := uuid.New().String()
|
||||
client.SetSessionToken(token)
|
||||
agentConn := agent.New(agent.Options{
|
||||
Client: client,
|
||||
Filesystem: afero.NewMemMapFs(),
|
||||
TempDir: t.TempDir(),
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = agentConn.Close()
|
||||
})
|
||||
|
||||
handler, closer, err := vscodeipc.New(ctx, client, id, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = closer.Close()
|
||||
})
|
||||
|
||||
// Ensure that we're actually connected!
|
||||
require.Eventually(t, func() bool {
|
||||
res := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/network", nil)
|
||||
req.Header.Set(vscodeipc.AuthHeader, token)
|
||||
handler.ServeHTTP(res, req)
|
||||
network := &vscodeipc.NetworkResponse{}
|
||||
err = json.NewDecoder(res.Body).Decode(&network)
|
||||
assert.NoError(t, err)
|
||||
return network.Latency != 0
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
_, port, err := net.SplitHostPort(srvURL.Host)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("NoSessionToken", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil)
|
||||
handler.ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusUnauthorized, res.Code)
|
||||
})
|
||||
|
||||
t.Run("MismatchedSessionToken", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil)
|
||||
req.Header.Set(vscodeipc.AuthHeader, uuid.NewString())
|
||||
handler.ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusUnauthorized, res.Code)
|
||||
})
|
||||
|
||||
t.Run("Port", func(t *testing.T) {
|
||||
// Tests that the port endpoint can be used for forward traffic.
|
||||
// For this test, we simply use the already listening httptest server.
|
||||
t.Parallel()
|
||||
input, output := net.Pipe()
|
||||
defer input.Close()
|
||||
defer output.Close()
|
||||
res := &hijackable{httptest.NewRecorder(), output}
|
||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil)
|
||||
req.Header.Set(vscodeipc.AuthHeader, token)
|
||||
go handler.ServeHTTP(res, req)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://127.0.0.1/", nil)
|
||||
require.NoError(t, err)
|
||||
client := http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return input, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("Execute", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Execute isn't supported on Windows yet!")
|
||||
return
|
||||
}
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
data, _ := json.Marshal(vscodeipc.ExecuteRequest{
|
||||
Command: "echo test",
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/execute", bytes.NewReader(data))
|
||||
req.Header.Set(vscodeipc.AuthHeader, token)
|
||||
handler.ServeHTTP(res, req)
|
||||
|
||||
decoder := json.NewDecoder(res.Body)
|
||||
var msg vscodeipc.ExecuteResponse
|
||||
err = decoder.Decode(&msg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test\n", msg.Data)
|
||||
err = decoder.Decode(&msg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, *msg.ExitCode)
|
||||
})
|
||||
}
|
||||
|
||||
type hijackable struct {
|
||||
*httptest.ResponseRecorder
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (h *hijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return h.conn, bufio.NewReadWriter(bufio.NewReader(h.conn), bufio.NewWriter(h.conn)), nil
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
package cli_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/cli/clitest"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestVSCodeIPC(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Ensures the vscodeipc command outputs it's running port!
|
||||
// This signifies to the caller that it's ready to accept requests.
|
||||
t.Run("PortOutputs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, workspace, _ := setupWorkspaceForAgent(t, nil)
|
||||
cmd, _ := clitest.New(t, "vscodeipc", workspace.LatestBuild.Resources[0].Agents[0].ID.String(),
|
||||
"--token", client.SessionToken(), "--url", client.URL.String())
|
||||
rdr, wtr := io.Pipe()
|
||||
cmd.SetOut(wtr)
|
||||
ctx, cancelFunc := testutil.Context(t)
|
||||
defer cancelFunc()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
err := cmd.ExecuteContext(ctx)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
buf := make([]byte, 64)
|
||||
require.Eventually(t, func() bool {
|
||||
t.Log("Looking for address!")
|
||||
var err error
|
||||
_, err = rdr.Read(buf)
|
||||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
t.Logf("Address: %s\n", buf)
|
||||
|
||||
cancelFunc()
|
||||
<-done
|
||||
})
|
||||
}
|
|
@ -139,7 +139,9 @@ func (c *AgentConn) AwaitReachable(ctx context.Context) bool {
|
|||
return c.Conn.AwaitReachable(ctx, TailnetIP)
|
||||
}
|
||||
|
||||
func (c *AgentConn) Ping(ctx context.Context) (time.Duration, error) {
|
||||
// Ping pings the agent and returns the round-trip time.
|
||||
// The bool returns true if the ping was made P2P.
|
||||
func (c *AgentConn) Ping(ctx context.Context) (time.Duration, bool, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
|
|
@ -346,7 +346,8 @@ func (c *Client) ListenWorkspaceAgent(ctx context.Context) (net.Conn, error) {
|
|||
type DialWorkspaceAgentOptions struct {
|
||||
Logger slog.Logger
|
||||
// BlockEndpoints forced a direct connection through DERP.
|
||||
BlockEndpoints bool
|
||||
BlockEndpoints bool
|
||||
EnableTrafficStats bool
|
||||
}
|
||||
|
||||
func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (*AgentConn, error) {
|
||||
|
@ -369,10 +370,11 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
|
|||
|
||||
ip := tailnet.IP()
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
|
||||
DERPMap: connInfo.DERPMap,
|
||||
Logger: options.Logger,
|
||||
BlockEndpoints: options.BlockEndpoints,
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
|
||||
DERPMap: connInfo.DERPMap,
|
||||
Logger: options.Logger,
|
||||
BlockEndpoints: options.BlockEndpoints,
|
||||
EnableTrafficStats: options.EnableTrafficStats,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create tailnet: %w", err)
|
||||
|
|
|
@ -81,7 +81,7 @@ func TestReplicas(t *testing.T) {
|
|||
require.Eventually(t, func() bool {
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancelFunc()
|
||||
_, err = conn.Ping(ctx)
|
||||
_, _, err = conn.Ping(ctx)
|
||||
return err == nil
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
_ = conn.Close()
|
||||
|
@ -124,7 +124,7 @@ func TestReplicas(t *testing.T) {
|
|||
require.Eventually(t, func() bool {
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.IntervalSlow)
|
||||
defer cancelFunc()
|
||||
_, err = conn.Ping(ctx)
|
||||
_, _, err = conn.Ping(ctx)
|
||||
return err == nil
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
_ = conn.Close()
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
buildInputs = with pkgs; [
|
||||
bash
|
||||
bat
|
||||
cairo
|
||||
drpc.defaultPackage.${system}
|
||||
exa
|
||||
getopt
|
||||
|
@ -34,7 +35,10 @@
|
|||
nodejs
|
||||
openssh
|
||||
openssl
|
||||
pango
|
||||
pixman
|
||||
postgresql
|
||||
pkg-config
|
||||
protoc-gen-go
|
||||
ripgrep
|
||||
shellcheck
|
||||
|
|
|
@ -141,7 +141,7 @@ func waitForDisco(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn)
|
|||
for i := 0; i < pingAttempts; i++ {
|
||||
_, _ = fmt.Fprintf(logs, "\tDisco ping attempt %d/%d...\n", i+1, pingAttempts)
|
||||
pingCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
|
||||
_, err := conn.Ping(pingCtx)
|
||||
_, _, err := conn.Ping(pingCtx)
|
||||
cancel()
|
||||
if err == nil {
|
||||
break
|
||||
|
|
|
@ -77,6 +77,7 @@ func NewConn(options *Options) (*Conn, error) {
|
|||
nodePublicKey := nodePrivateKey.Public()
|
||||
|
||||
netMap := &netmap.NetworkMap{
|
||||
DERPMap: options.DERPMap,
|
||||
NodeKey: nodePublicKey,
|
||||
PrivateKey: nodePrivateKey,
|
||||
Addresses: options.Addresses,
|
||||
|
@ -407,26 +408,34 @@ func (c *Conn) Status() *ipnstate.Status {
|
|||
}
|
||||
|
||||
// Ping sends a Disco ping to the Wireguard engine.
|
||||
func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, error) {
|
||||
// The bool returned is true if the ping was performed P2P.
|
||||
func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, bool, error) {
|
||||
errCh := make(chan error, 1)
|
||||
durCh := make(chan time.Duration, 1)
|
||||
prChan := make(chan *ipnstate.PingResult, 1)
|
||||
go c.wireguardEngine.Ping(ip, tailcfg.PingDisco, func(pr *ipnstate.PingResult) {
|
||||
if pr.Err != "" {
|
||||
errCh <- xerrors.New(pr.Err)
|
||||
return
|
||||
}
|
||||
durCh <- time.Duration(pr.LatencySeconds * float64(time.Second))
|
||||
prChan <- pr
|
||||
})
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return 0, err
|
||||
return 0, false, err
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
case dur := <-durCh:
|
||||
return dur, nil
|
||||
return 0, false, ctx.Err()
|
||||
case pr := <-prChan:
|
||||
return time.Duration(pr.LatencySeconds * float64(time.Second)), pr.Endpoint != "", nil
|
||||
}
|
||||
}
|
||||
|
||||
// DERPMap returns the currently set DERP mapping.
|
||||
func (c *Conn) DERPMap() *tailcfg.DERPMap {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
return c.netMap.DERPMap
|
||||
}
|
||||
|
||||
// AwaitReachable pings the provided IP continually until the
|
||||
// address is reachable. It's the callers responsibility to provide
|
||||
// a timeout, otherwise this function will block forever.
|
||||
|
@ -445,7 +454,7 @@ func (c *Conn) AwaitReachable(ctx context.Context, ip netip.Addr) bool {
|
|||
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
_, err := c.Ping(ctx, ip)
|
||||
_, _, err := c.Ping(ctx, ip)
|
||||
if err == nil {
|
||||
completed()
|
||||
}
|
||||
|
@ -523,20 +532,7 @@ func (c *Conn) sendNode() {
|
|||
c.nodeChanged = true
|
||||
return
|
||||
}
|
||||
node := &Node{
|
||||
ID: c.netMap.SelfNode.ID,
|
||||
AsOf: database.Now(),
|
||||
Key: c.netMap.SelfNode.Key,
|
||||
Addresses: c.netMap.SelfNode.Addresses,
|
||||
AllowedIPs: c.netMap.SelfNode.AllowedIPs,
|
||||
DiscoKey: c.magicConn.DiscoPublicKey(),
|
||||
Endpoints: c.lastEndpoints,
|
||||
PreferredDERP: c.lastPreferredDERP,
|
||||
DERPLatency: c.lastDERPLatency,
|
||||
}
|
||||
if c.blockEndpoints {
|
||||
node.Endpoints = nil
|
||||
}
|
||||
node := c.selfNode()
|
||||
nodeCallback := c.nodeCallback
|
||||
if nodeCallback == nil {
|
||||
return
|
||||
|
@ -557,6 +553,31 @@ func (c *Conn) sendNode() {
|
|||
}()
|
||||
}
|
||||
|
||||
// Node returns the last node that was sent to the node callback.
|
||||
func (c *Conn) Node() *Node {
|
||||
c.lastMutex.Lock()
|
||||
defer c.lastMutex.Unlock()
|
||||
return c.selfNode()
|
||||
}
|
||||
|
||||
func (c *Conn) selfNode() *Node {
|
||||
node := &Node{
|
||||
ID: c.netMap.SelfNode.ID,
|
||||
AsOf: database.Now(),
|
||||
Key: c.netMap.SelfNode.Key,
|
||||
Addresses: c.netMap.SelfNode.Addresses,
|
||||
AllowedIPs: c.netMap.SelfNode.AllowedIPs,
|
||||
DiscoKey: c.magicConn.DiscoPublicKey(),
|
||||
Endpoints: c.lastEndpoints,
|
||||
PreferredDERP: c.lastPreferredDERP,
|
||||
DERPLatency: c.lastDERPLatency,
|
||||
}
|
||||
if c.blockEndpoints {
|
||||
node.Endpoints = nil
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// This and below is taken _mostly_ verbatim from Tailscale:
|
||||
// https://github.com/tailscale/tailscale/blob/c88bd53b1b7b2fcf7ba302f2e53dd1ce8c32dad4/tsnet/tsnet.go#L459-L494
|
||||
|
||||
|
|
Loading…
Reference in New Issue