feat: Add `vscodeipc` subcommand for VS Code Extension (#5326)

* Add extio

* feat: Add `vscodeipc` subcommand for VS Code Extension

This enables the VS Code extension to communicate with a Coder client.
The extension will download the slim binary from `/bin/*` for the
respective client architecture and OS, then execute `coder vscodeipc`
for the connecting workspace.

* Add authentication header, improve comments, and add tests for the CLI

* Update cli/vscodeipc_test.go

Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>

* Update cli/vscodeipc_test.go

Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>

* Update cli/vscodeipc/vscodeipc_test.go

Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>

* Fix requested changes

* Fix IPC tests

* Fix shell execution

* Fix nix flake

* Silence usage

Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>
This commit is contained in:
Kyle Carberry 2022-12-18 17:50:06 -06:00 committed by GitHub
parent d1f8fec1d3
commit e61234f260
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 712 additions and 33 deletions

View File

@ -98,6 +98,7 @@ func Core() []*cobra.Command {
users(),
versionCmd(),
workspaceAgent(),
vscodeipcCmd(),
}
}

View File

@ -71,7 +71,7 @@ func speedtest() *cobra.Command {
return ctx.Err()
case <-ticker.C:
}
dur, err := conn.Ping(ctx)
dur, p2p, err := conn.Ping(ctx)
if err != nil {
continue
}
@ -80,7 +80,7 @@ func speedtest() *cobra.Command {
continue
}
peer := status.Peer[status.Peers()[0]]
if peer.CurAddr == "" && direct {
if !p2p && direct {
cmd.Printf("Waiting for a direct connection... (%dms via %s)\n", dur.Milliseconds(), peer.Relay)
continue
}

View File

@ -65,6 +65,8 @@ func setupWorkspaceForAgent(t *testing.T, mutate func([]*proto.Agent) []*proto.A
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
workspace, err := client.Workspace(context.Background(), workspace.ID)
require.NoError(t, err)
return client, workspace, agentToken
}

88
cli/vscodeipc.go Normal file
View File

@ -0,0 +1,88 @@
package cli
import (
"fmt"
"net"
"net/http"
"net/url"
"github.com/google/uuid"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/vscodeipc"
"github.com/coder/coder/codersdk"
)
// vscodeipcCmd spawns a local HTTP server on the provided port that listens to messages.
// It's made for use by the Coder VS Code extension. See: https://github.com/coder/vscode-coder
func vscodeipcCmd() *cobra.Command {
var (
rawURL string
token string
port uint16
)
cmd := &cobra.Command{
Use: "vscodeipc <workspace-agent>",
Args: cobra.ExactArgs(1),
SilenceUsage: true,
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
if rawURL == "" {
return xerrors.New("CODER_URL must be set!")
}
// token is validated in a header on each request to prevent
// unauthenticated clients from connecting.
if token == "" {
return xerrors.New("CODER_TOKEN must be set!")
}
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return xerrors.Errorf("listen: %w", err)
}
defer listener.Close()
addr, ok := listener.Addr().(*net.TCPAddr)
if !ok {
return xerrors.Errorf("listener.Addr() is not a *net.TCPAddr: %T", listener.Addr())
}
url, err := url.Parse(rawURL)
if err != nil {
return err
}
agentID, err := uuid.Parse(args[0])
if err != nil {
return err
}
client := codersdk.New(url)
client.SetSessionToken(token)
handler, closer, err := vscodeipc.New(cmd.Context(), client, agentID, nil)
if err != nil {
return err
}
defer closer.Close()
// nolint:gosec
server := http.Server{
Handler: handler,
}
defer server.Close()
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", addr.String())
errChan := make(chan error, 1)
go func() {
err := server.Serve(listener)
errChan <- err
}()
select {
case <-cmd.Context().Done():
return cmd.Context().Err()
case err := <-errChan:
return err
}
},
}
cliflag.StringVarP(cmd.Flags(), &rawURL, "url", "u", "CODER_URL", "", "The URL of the Coder instance!")
cliflag.StringVarP(cmd.Flags(), &token, "token", "t", "CODER_TOKEN", "", "The session token of the user!")
cmd.Flags().Uint16VarP(&port, "port", "p", 0, "The port to listen on!")
return cmd
}

313
cli/vscodeipc/vscodeipc.go Normal file
View File

@ -0,0 +1,313 @@
package vscodeipc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
)
const AuthHeader = "Coder-IPC-Token"
// New creates a VS Code IPC client that can be used to communicate with workspaces.
//
// Creating this IPC was required instead of using SSH, because we're unable to get
// connection information to display in the bottom-bar when using SSH. It's possible
// we could jank around this (maybe by using a temporary SSH host), but that's not
// ideal.
//
// This persists a single workspace connection, and lets you execute commands, check
// for network information, and forward ports.
//
// The VS Code extension is located at https://github.com/coder/vscode-coder. The
// extension downloads the slim binary from `/bin/*` and executes `coder vscodeipc`
// which calls this function. This API must maintain backward compatibility with
// the extension to support prior versions of Coder.
func New(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, options *codersdk.DialWorkspaceAgentOptions) (http.Handler, io.Closer, error) {
if options == nil {
options = &codersdk.DialWorkspaceAgentOptions{}
}
// We need this to track upload and download!
options.EnableTrafficStats = true
agentConn, err := client.DialWorkspaceAgent(ctx, agentID, options)
if err != nil {
return nil, nil, err
}
api := &api{
agentConn: agentConn,
}
r := chi.NewRouter()
// This is to prevent unauthorized clients on the same machine from executing
// requests on behalf of the workspace.
r.Use(sessionTokenMiddleware(client.SessionToken()))
r.Route("/v1", func(r chi.Router) {
r.Get("/port/{port}", api.port)
r.Get("/network", api.network)
r.Post("/execute", api.execute)
})
return r, api, nil
}
type api struct {
agentConn *codersdk.AgentConn
sshClient *ssh.Client
sshClientErr error
sshClientOnce sync.Once
lastNetwork time.Time
}
func (api *api) Close() error {
if api.sshClient != nil {
api.sshClient.Close()
}
return api.agentConn.Close()
}
type NetworkResponse struct {
P2P bool `json:"p2p"`
Latency float64 `json:"latency"`
PreferredDERP string `json:"preferred_derp"`
DERPLatency map[string]float64 `json:"derp_latency"`
UploadBytesSec int64 `json:"upload_bytes_sec"`
DownloadBytesSec int64 `json:"download_bytes_sec"`
}
// port accepts an HTTP request to dial a port on the workspace agent.
// It uses an HTTP connection upgrade to transfer the connection to TCP.
func (api *api) port(w http.ResponseWriter, r *http.Request) {
port, err := strconv.Atoi(chi.URLParam(r, "port"))
if err != nil {
httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{
Message: "Port must be an integer!",
})
return
}
remoteConn, err := api.agentConn.DialContext(r.Context(), "tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
httpapi.InternalServerError(w, err)
return
}
defer remoteConn.Close()
// Upgrade an switch to TCP!
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "tcp")
w.WriteHeader(http.StatusSwitchingProtocols)
hijacker, ok := w.(http.Hijacker)
if !ok {
httpapi.InternalServerError(w, xerrors.Errorf("unable to hijack connection: %T", w))
return
}
localConn, brw, err := hijacker.Hijack()
if err != nil {
httpapi.InternalServerError(w, err)
return
}
defer localConn.Close()
_ = brw.Flush()
agent.Bicopy(r.Context(), localConn, remoteConn)
}
// network returns network information about the workspace.
func (api *api) network(w http.ResponseWriter, r *http.Request) {
// Ping the workspace agent to get the latency.
latency, p2p, err := api.agentConn.Ping(r.Context())
if err != nil {
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to ping the workspace agent.",
Detail: err.Error(),
})
return
}
node := api.agentConn.Node()
derpMap := api.agentConn.DERPMap()
derpLatency := map[string]float64{}
// Convert DERP region IDs to friendly names for display in the UI.
for rawRegion, latency := range node.DERPLatency {
regionParts := strings.SplitN(rawRegion, "-", 2)
regionID, err := strconv.Atoi(regionParts[0])
if err != nil {
continue
}
region, found := derpMap.Regions[regionID]
if !found {
// It's possible that a workspace agent is using an old DERPMap
// and reports regions that do not exist. If that's the case,
// report the region as unknown!
region = &tailcfg.DERPRegion{
RegionID: regionID,
RegionName: fmt.Sprintf("Unnamed %d", regionID),
}
}
// Convert the microseconds to milliseconds.
derpLatency[region.RegionName] = latency * 1000
}
totalRx := uint64(0)
totalTx := uint64(0)
for _, stat := range api.agentConn.ExtractTrafficStats() {
totalRx += stat.RxBytes
totalTx += stat.TxBytes
}
// Tracking the time since last request is required because
// ExtractTrafficStats() resets its counters after each call.
dur := time.Since(api.lastNetwork)
uploadSecs := float64(totalTx) / dur.Seconds()
downloadSecs := float64(totalRx) / dur.Seconds()
api.lastNetwork = time.Now()
httpapi.Write(r.Context(), w, http.StatusOK, NetworkResponse{
P2P: p2p,
Latency: float64(latency.Microseconds()) / 1000,
PreferredDERP: derpMap.Regions[node.PreferredDERP].RegionName,
DERPLatency: derpLatency,
UploadBytesSec: int64(uploadSecs),
DownloadBytesSec: int64(downloadSecs),
})
}
type ExecuteRequest struct {
Command string `json:"command"`
Stdin string `json:"stdin"`
}
type ExecuteResponse struct {
Data string `json:"data"`
ExitCode *int `json:"exit_code"`
}
// execute runs the command provided, streams the output back, and returns the exit code.
func (api *api) execute(w http.ResponseWriter, r *http.Request) {
var req ExecuteRequest
if !httpapi.Read(r.Context(), w, r, &req) {
return
}
api.sshClientOnce.Do(func() {
// The SSH client is lazily created because it's not needed for
// all requests. It's only needed for the execute endpoint.
//
// It's alright if this fails on the first execution, because
// a new instance of this API is created for each remote SSH request.
api.sshClient, api.sshClientErr = api.agentConn.SSHClient(context.Background())
})
if api.sshClientErr != nil {
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to create SSH client.",
Detail: api.sshClientErr.Error(),
})
return
}
session, err := api.sshClient.NewSession()
if err != nil {
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to create SSH session.",
Detail: err.Error(),
})
return
}
defer session.Close()
f, ok := w.(http.Flusher)
if !ok {
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
Message: fmt.Sprintf("http.ResponseWriter is not http.Flusher: %T", w),
})
return
}
execWriter := &execWriter{w, f}
session.Stdout = execWriter
session.Stderr = execWriter
session.Stdin = strings.NewReader(req.Stdin)
err = session.Start(req.Command)
if err != nil {
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to start SSH session.",
Detail: err.Error(),
})
return
}
err = session.Wait()
writeExit := func(exitCode int) {
data, _ := json.Marshal(&ExecuteResponse{
ExitCode: &exitCode,
})
_, _ = w.Write(data)
f.Flush()
}
if err != nil {
var exitError *ssh.ExitError
if errors.As(err, &exitError) {
writeExit(exitError.ExitStatus())
return
}
}
writeExit(0)
}
type execWriter struct {
w http.ResponseWriter
f http.Flusher
}
func (e *execWriter) Write(data []byte) (int, error) {
js, err := json.Marshal(&ExecuteResponse{
Data: string(data),
})
if err != nil {
return 0, err
}
_, err = e.w.Write(js)
if err != nil {
return 0, err
}
e.f.Flush()
return len(data), nil
}
func sessionTokenMiddleware(sessionToken string) func(h http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get(AuthHeader)
if token == "" {
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
Message: fmt.Sprintf("A session token must be provided in the `%s` header.", AuthHeader),
})
return
}
if token != sessionToken {
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
Message: "The session token provided doesn't match the one used to create the client.",
})
return
}
w.Header().Set("Access-Control-Allow-Origin", "*")
h.ServeHTTP(w, r)
})
}
}

View File

@ -0,0 +1,202 @@
package vscodeipc_test
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"runtime"
"testing"
"github.com/google/uuid"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"nhooyr.io/websocket"
"github.com/coder/coder/agent"
"github.com/coder/coder/cli/vscodeipc"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/tailnet/tailnettest"
"github.com/coder/coder/testutil"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestVSCodeIPC(t *testing.T) {
t.Parallel()
ctx := context.Background()
id := uuid.New()
derpMap := tailnettest.RunDERPAndSTUN(t)
coordinator := tailnet.NewCoordinator()
t.Cleanup(func() {
_ = coordinator.Close()
})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case fmt.Sprintf("/api/v2/workspaceagents/%s/connection", id):
assert.Equal(t, r.Method, http.MethodGet)
httpapi.Write(ctx, w, http.StatusOK, codersdk.WorkspaceAgentConnectionInfo{
DERPMap: derpMap,
})
return
case fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", id):
assert.Equal(t, r.Method, http.MethodGet)
ws, err := websocket.Accept(w, r, nil)
require.NoError(t, err)
conn := websocket.NetConn(ctx, ws, websocket.MessageBinary)
_ = coordinator.ServeClient(conn, uuid.New(), id)
return
case "/api/v2/workspaceagents/me/version":
assert.Equal(t, r.Method, http.MethodPost)
w.WriteHeader(http.StatusOK)
return
case "/api/v2/workspaceagents/me/metadata":
assert.Equal(t, r.Method, http.MethodGet)
httpapi.Write(ctx, w, http.StatusOK, codersdk.WorkspaceAgentMetadata{
DERPMap: derpMap,
})
return
case "/api/v2/workspaceagents/me/coordinate":
assert.Equal(t, r.Method, http.MethodGet)
ws, err := websocket.Accept(w, r, nil)
require.NoError(t, err)
conn := websocket.NetConn(ctx, ws, websocket.MessageBinary)
_ = coordinator.ServeAgent(conn, id)
return
case "/api/v2/workspaceagents/me/report-stats":
assert.Equal(t, r.Method, http.MethodPost)
w.WriteHeader(http.StatusOK)
return
case "/":
w.WriteHeader(http.StatusOK)
return
default:
t.Fatalf("unexpected request %s", r.URL.Path)
}
}))
t.Cleanup(srv.Close)
srvURL, _ := url.Parse(srv.URL)
client := codersdk.New(srvURL)
token := uuid.New().String()
client.SetSessionToken(token)
agentConn := agent.New(agent.Options{
Client: client,
Filesystem: afero.NewMemMapFs(),
TempDir: t.TempDir(),
})
t.Cleanup(func() {
_ = agentConn.Close()
})
handler, closer, err := vscodeipc.New(ctx, client, id, nil)
require.NoError(t, err)
t.Cleanup(func() {
_ = closer.Close()
})
// Ensure that we're actually connected!
require.Eventually(t, func() bool {
res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/v1/network", nil)
req.Header.Set(vscodeipc.AuthHeader, token)
handler.ServeHTTP(res, req)
network := &vscodeipc.NetworkResponse{}
err = json.NewDecoder(res.Body).Decode(&network)
assert.NoError(t, err)
return network.Latency != 0
}, testutil.WaitLong, testutil.IntervalFast)
_, port, err := net.SplitHostPort(srvURL.Host)
require.NoError(t, err)
t.Run("NoSessionToken", func(t *testing.T) {
t.Parallel()
res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil)
handler.ServeHTTP(res, req)
require.Equal(t, http.StatusUnauthorized, res.Code)
})
t.Run("MismatchedSessionToken", func(t *testing.T) {
t.Parallel()
res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil)
req.Header.Set(vscodeipc.AuthHeader, uuid.NewString())
handler.ServeHTTP(res, req)
require.Equal(t, http.StatusUnauthorized, res.Code)
})
t.Run("Port", func(t *testing.T) {
// Tests that the port endpoint can be used for forward traffic.
// For this test, we simply use the already listening httptest server.
t.Parallel()
input, output := net.Pipe()
defer input.Close()
defer output.Close()
res := &hijackable{httptest.NewRecorder(), output}
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil)
req.Header.Set(vscodeipc.AuthHeader, token)
go handler.ServeHTTP(res, req)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://127.0.0.1/", nil)
require.NoError(t, err)
client := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return input, nil
},
},
}
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
})
t.Run("Execute", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("Execute isn't supported on Windows yet!")
return
}
res := httptest.NewRecorder()
data, _ := json.Marshal(vscodeipc.ExecuteRequest{
Command: "echo test",
})
req := httptest.NewRequest(http.MethodPost, "/v1/execute", bytes.NewReader(data))
req.Header.Set(vscodeipc.AuthHeader, token)
handler.ServeHTTP(res, req)
decoder := json.NewDecoder(res.Body)
var msg vscodeipc.ExecuteResponse
err = decoder.Decode(&msg)
require.NoError(t, err)
require.Equal(t, "test\n", msg.Data)
err = decoder.Decode(&msg)
require.NoError(t, err)
require.Equal(t, 0, *msg.ExitCode)
})
}
type hijackable struct {
*httptest.ResponseRecorder
conn net.Conn
}
func (h *hijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return h.conn, bufio.NewReadWriter(bufio.NewReader(h.conn), bufio.NewWriter(h.conn)), nil
}

44
cli/vscodeipc_test.go Normal file
View File

@ -0,0 +1,44 @@
package cli_test
import (
"io"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/testutil"
)
func TestVSCodeIPC(t *testing.T) {
t.Parallel()
// Ensures the vscodeipc command outputs it's running port!
// This signifies to the caller that it's ready to accept requests.
t.Run("PortOutputs", func(t *testing.T) {
t.Parallel()
client, workspace, _ := setupWorkspaceForAgent(t, nil)
cmd, _ := clitest.New(t, "vscodeipc", workspace.LatestBuild.Resources[0].Agents[0].ID.String(),
"--token", client.SessionToken(), "--url", client.URL.String())
rdr, wtr := io.Pipe()
cmd.SetOut(wtr)
ctx, cancelFunc := testutil.Context(t)
defer cancelFunc()
done := make(chan error, 1)
go func() {
err := cmd.ExecuteContext(ctx)
done <- err
}()
buf := make([]byte, 64)
require.Eventually(t, func() bool {
t.Log("Looking for address!")
var err error
_, err = rdr.Read(buf)
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
t.Logf("Address: %s\n", buf)
cancelFunc()
<-done
})
}

View File

@ -139,7 +139,9 @@ func (c *AgentConn) AwaitReachable(ctx context.Context) bool {
return c.Conn.AwaitReachable(ctx, TailnetIP)
}
func (c *AgentConn) Ping(ctx context.Context) (time.Duration, error) {
// Ping pings the agent and returns the round-trip time.
// The bool returns true if the ping was made P2P.
func (c *AgentConn) Ping(ctx context.Context) (time.Duration, bool, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()

View File

@ -346,7 +346,8 @@ func (c *Client) ListenWorkspaceAgent(ctx context.Context) (net.Conn, error) {
type DialWorkspaceAgentOptions struct {
Logger slog.Logger
// BlockEndpoints forced a direct connection through DERP.
BlockEndpoints bool
BlockEndpoints bool
EnableTrafficStats bool
}
func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (*AgentConn, error) {
@ -369,10 +370,11 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
ip := tailnet.IP()
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
DERPMap: connInfo.DERPMap,
Logger: options.Logger,
BlockEndpoints: options.BlockEndpoints,
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
DERPMap: connInfo.DERPMap,
Logger: options.Logger,
BlockEndpoints: options.BlockEndpoints,
EnableTrafficStats: options.EnableTrafficStats,
})
if err != nil {
return nil, xerrors.Errorf("create tailnet: %w", err)

View File

@ -81,7 +81,7 @@ func TestReplicas(t *testing.T) {
require.Eventually(t, func() bool {
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancelFunc()
_, err = conn.Ping(ctx)
_, _, err = conn.Ping(ctx)
return err == nil
}, testutil.WaitLong, testutil.IntervalFast)
_ = conn.Close()
@ -124,7 +124,7 @@ func TestReplicas(t *testing.T) {
require.Eventually(t, func() bool {
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.IntervalSlow)
defer cancelFunc()
_, err = conn.Ping(ctx)
_, _, err = conn.Ping(ctx)
return err == nil
}, testutil.WaitLong, testutil.IntervalFast)
_ = conn.Close()

View File

@ -18,6 +18,7 @@
buildInputs = with pkgs; [
bash
bat
cairo
drpc.defaultPackage.${system}
exa
getopt
@ -34,7 +35,10 @@
nodejs
openssh
openssl
pango
pixman
postgresql
pkg-config
protoc-gen-go
ripgrep
shellcheck

View File

@ -141,7 +141,7 @@ func waitForDisco(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn)
for i := 0; i < pingAttempts; i++ {
_, _ = fmt.Fprintf(logs, "\tDisco ping attempt %d/%d...\n", i+1, pingAttempts)
pingCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
_, err := conn.Ping(pingCtx)
_, _, err := conn.Ping(pingCtx)
cancel()
if err == nil {
break

View File

@ -77,6 +77,7 @@ func NewConn(options *Options) (*Conn, error) {
nodePublicKey := nodePrivateKey.Public()
netMap := &netmap.NetworkMap{
DERPMap: options.DERPMap,
NodeKey: nodePublicKey,
PrivateKey: nodePrivateKey,
Addresses: options.Addresses,
@ -407,26 +408,34 @@ func (c *Conn) Status() *ipnstate.Status {
}
// Ping sends a Disco ping to the Wireguard engine.
func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, error) {
// The bool returned is true if the ping was performed P2P.
func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, bool, error) {
errCh := make(chan error, 1)
durCh := make(chan time.Duration, 1)
prChan := make(chan *ipnstate.PingResult, 1)
go c.wireguardEngine.Ping(ip, tailcfg.PingDisco, func(pr *ipnstate.PingResult) {
if pr.Err != "" {
errCh <- xerrors.New(pr.Err)
return
}
durCh <- time.Duration(pr.LatencySeconds * float64(time.Second))
prChan <- pr
})
select {
case err := <-errCh:
return 0, err
return 0, false, err
case <-ctx.Done():
return 0, ctx.Err()
case dur := <-durCh:
return dur, nil
return 0, false, ctx.Err()
case pr := <-prChan:
return time.Duration(pr.LatencySeconds * float64(time.Second)), pr.Endpoint != "", nil
}
}
// DERPMap returns the currently set DERP mapping.
func (c *Conn) DERPMap() *tailcfg.DERPMap {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.netMap.DERPMap
}
// AwaitReachable pings the provided IP continually until the
// address is reachable. It's the callers responsibility to provide
// a timeout, otherwise this function will block forever.
@ -445,7 +454,7 @@ func (c *Conn) AwaitReachable(ctx context.Context, ip netip.Addr) bool {
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
_, err := c.Ping(ctx, ip)
_, _, err := c.Ping(ctx, ip)
if err == nil {
completed()
}
@ -523,20 +532,7 @@ func (c *Conn) sendNode() {
c.nodeChanged = true
return
}
node := &Node{
ID: c.netMap.SelfNode.ID,
AsOf: database.Now(),
Key: c.netMap.SelfNode.Key,
Addresses: c.netMap.SelfNode.Addresses,
AllowedIPs: c.netMap.SelfNode.AllowedIPs,
DiscoKey: c.magicConn.DiscoPublicKey(),
Endpoints: c.lastEndpoints,
PreferredDERP: c.lastPreferredDERP,
DERPLatency: c.lastDERPLatency,
}
if c.blockEndpoints {
node.Endpoints = nil
}
node := c.selfNode()
nodeCallback := c.nodeCallback
if nodeCallback == nil {
return
@ -557,6 +553,31 @@ func (c *Conn) sendNode() {
}()
}
// Node returns the last node that was sent to the node callback.
func (c *Conn) Node() *Node {
c.lastMutex.Lock()
defer c.lastMutex.Unlock()
return c.selfNode()
}
func (c *Conn) selfNode() *Node {
node := &Node{
ID: c.netMap.SelfNode.ID,
AsOf: database.Now(),
Key: c.netMap.SelfNode.Key,
Addresses: c.netMap.SelfNode.Addresses,
AllowedIPs: c.netMap.SelfNode.AllowedIPs,
DiscoKey: c.magicConn.DiscoPublicKey(),
Endpoints: c.lastEndpoints,
PreferredDERP: c.lastPreferredDERP,
DERPLatency: c.lastDERPLatency,
}
if c.blockEndpoints {
node.Endpoints = nil
}
return node
}
// This and below is taken _mostly_ verbatim from Tailscale:
// https://github.com/tailscale/tailscale/blob/c88bd53b1b7b2fcf7ba302f2e53dd1ce8c32dad4/tsnet/tsnet.go#L459-L494