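// Package agentconn contains the scaletest runner for workspace agent
// connections: it dials an agent directly or via DERP, verifies the
// connection, fetches each configured URL once, and can then hold the
// connection open while polling those URLs on an interval.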
package agentconn

import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"time"

	"golang.org/x/sync/errgroup"
	"golang.org/x/xerrors"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/sloghuman"
	"github.com/coder/coder/v2/coderd/tracing"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/codersdk/workspacesdk"
	"github.com/coder/coder/v2/scaletest/harness"
	"github.com/coder/coder/v2/scaletest/loadtestutil"
)
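
// The Config and Connection types used throughout this file are defined
// elsewhere in this package. As exercised here (a sketch of the fields this
// file relies on, not the authoritative definitions):
//
//	Config.AgentID        - the workspace agent to dial
//	Config.ConnectionMode - ConnectionModeDirect or ConnectionModeDerp
//	Config.HoldDuration   - how long to hold the connection open after setup
//	Config.Connections    - []Connection to probe through the tunnel
//	Connection.URL        - HTTP URL fetched through the agent
//	Connection.Timeout    - per-request timeout (defaultRequestTimeout if unset)
//	Connection.Interval   - polling interval during the hold phase (0 disables)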

const defaultRequestTimeout = 5 * time.Second

// holdDurationEndedError is a sentinel error used to tear down the errgroup
// in holdConnection once the configured hold duration has elapsed.
type holdDurationEndedError struct{}

func (holdDurationEndedError) Error() string {
	return "hold duration ended"
}
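
// Note (added for clarity): holdDurationEndedError is a comparable zero-sized
// struct, so xerrors.Is can match it by simple equality without a target
// variable:
//
//	err := xerrors.Errorf("wrapped: %w", holdDurationEndedError{})
//	_ = xerrors.Is(err, holdDurationEndedError{}) // true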

// Runner dials a workspace agent and performs the connection checks
// described by its Config.
type Runner struct {
	client *codersdk.Client
	cfg    Config
}

var _ harness.Runnable = &Runner{}

// NewRunner returns a Runner that connects using the given client and
// behaves according to cfg.
func NewRunner(client *codersdk.Client, cfg Config) *Runner {
	return &Runner{
		client: client,
		cfg:    cfg,
	}
}
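
// Usage sketch (illustrative only; serverURL, token, agentID, and ctx are
// assumptions, not part of this file):
//
//	client := codersdk.New(serverURL) // serverURL is a *url.URL of the deployment
//	client.SetSessionToken(token)
//	runner := NewRunner(client, Config{
//		AgentID:        agentID, // uuid.UUID of a running workspace agent
//		ConnectionMode: ConnectionModeDirect,
//	})
//	err := runner.Run(ctx, "", os.Stdout)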

// Run implements Runnable.
func (r *Runner) Run(ctx context.Context, _ string, w io.Writer) error {
	ctx, span := tracing.StartSpan(ctx)
	defer span.End()

	logs := loadtestutil.NewSyncWriter(w)
	defer logs.Close()
	logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
	r.client.SetLogger(logger)
	r.client.SetLogBodies(true)

	_, _ = fmt.Fprintln(logs, "Opening connection to workspace agent")
	switch r.cfg.ConnectionMode {
	case ConnectionModeDirect:
		_, _ = fmt.Fprintln(logs, "\tUsing direct connection...")
	case ConnectionModeDerp:
		_, _ = fmt.Fprintln(logs, "\tUsing proxied DERP connection through coder server...")
	}

	conn, err := workspacesdk.New(r.client).
		DialAgent(ctx, r.cfg.AgentID, &workspacesdk.DialAgentOptions{
			Logger: logger.Named("agentconn"),
			// If the config requested DERP, then force DERP.
			BlockEndpoints: r.cfg.ConnectionMode == ConnectionModeDerp,
		})
	if err != nil {
		return xerrors.Errorf("dial workspace agent: %w", err)
	}
	defer conn.Close()

	err = waitForDisco(ctx, logs, conn)
	if err != nil {
		return xerrors.Errorf("wait for discovery connection: %w", err)
	}

	// Wait for a direct connection if requested.
	if r.cfg.ConnectionMode == ConnectionModeDirect {
		err = waitForDirectConnection(ctx, logs, conn)
		if err != nil {
			return xerrors.Errorf("wait for direct connection: %w", err)
		}
	}

	// Ensure DERP for completeness.
	if r.cfg.ConnectionMode == ConnectionModeDerp {
		status := conn.Status()
		if len(status.Peers()) != 1 {
			return xerrors.Errorf("check connection mode: expected 1 peer, got %d", len(status.Peers()))
		}
		peer := status.Peer[status.Peers()[0]]
		if peer.Relay == "" || peer.CurAddr != "" {
			return xerrors.Errorf("check connection mode: peer is connected directly, not via DERP")
		}
	}

	_, _ = fmt.Fprint(logs, "\nConnection established.\n\n")

	// HACK: even though the ping passed above, we still need to open a
	// connection to the agent to ensure it's ready to accept connections. Not
	// sure why this is the case but it seems to be necessary.
	err = verifyConnection(ctx, logs, conn)
	if err != nil {
		return xerrors.Errorf("verify connection: %w", err)
	}

	_, _ = fmt.Fprint(logs, "\nConnection verified.\n\n")

	// Make initial connections sequentially to ensure the services are
	// reachable before we start spawning a bunch of goroutines and tickers.
	err = performInitialConnections(ctx, logs, conn, r.cfg.Connections)
	if err != nil {
		return xerrors.Errorf("perform initial connections: %w", err)
	}

	if r.cfg.HoldDuration > 0 {
		err = holdConnection(ctx, logs, conn, time.Duration(r.cfg.HoldDuration), r.cfg.Connections)
		if err != nil {
			return xerrors.Errorf("hold connection: %w", err)
		}
	}

	err = conn.Close()
	if err != nil {
		return xerrors.Errorf("close connection: %w", err)
	}

	return nil
}

// waitForDisco pings the agent repeatedly until a peer connection is
// established, giving up after a fixed number of attempts.
func waitForDisco(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error {
	const pingAttempts = 10
	const pingDelay = 1 * time.Second

	ctx, span := tracing.StartSpan(ctx)
	defer span.End()

	for i := 0; i < pingAttempts; i++ {
		_, _ = fmt.Fprintf(logs, "\tDisco ping attempt %d/%d...\n", i+1, pingAttempts)
		pingCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
		_, p2p, _, err := conn.Ping(pingCtx)
		cancel()
		if err == nil {
			_, _ = fmt.Fprintf(logs, "\tDisco ping succeeded after %d attempts, p2p = %v\n", i+1, p2p)
			break
		}
		if i == pingAttempts-1 {
			return xerrors.Errorf("ping workspace agent: %w", err)
		}

		select {
		case <-ctx.Done():
			return xerrors.Errorf("wait for connection to be established: %w", ctx.Err())
		// We use time.After here since it's a very short duration so leaking a
		// timer is fine.
		case <-time.After(pingDelay):
		}
	}

	return nil
}

// waitForDirectConnection polls the connection status until the single agent
// peer reports a direct (non-DERP) connection, giving up after a fixed number
// of attempts.
func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error {
	const directConnectionAttempts = 30
	const directConnectionDelay = 1 * time.Second

	ctx, span := tracing.StartSpan(ctx)
	defer span.End()

	for i := 0; i < directConnectionAttempts; i++ {
		_, _ = fmt.Fprintf(logs, "\tDirect connection check %d/%d...\n", i+1, directConnectionAttempts)
		status := conn.Status()

		var err error
		if len(status.Peers()) != 1 {
			_, _ = fmt.Fprintf(logs, "\t\tExpected 1 peer, found %d\n", len(status.Peers()))
			err = xerrors.Errorf("expected 1 peer, got %d", len(status.Peers()))
		} else {
			peer := status.Peer[status.Peers()[0]]
			_, _ = fmt.Fprintf(logs, "\t\tCurAddr: %s\n", peer.CurAddr)
			_, _ = fmt.Fprintf(logs, "\t\tRelay: %s\n", peer.Relay)
			if peer.Relay != "" && peer.CurAddr == "" {
				err = xerrors.Errorf("peer is connected via DERP, not direct")
			}
		}
		if err == nil {
			break
		}
		if i == directConnectionAttempts-1 {
			return xerrors.Errorf("wait for direct connection to agent: %w", err)
		}

		select {
		case <-ctx.Done():
			return xerrors.Errorf("wait for direct connection to agent: %w", ctx.Err())
		// We use time.After here since it's a very short duration so
		// leaking a timer is fine.
		case <-time.After(directConnectionDelay):
		}
	}

	return nil
}

// verifyConnection sends HTTP requests to the agent's HTTP API over the
// tunnel until one succeeds, confirming the agent is ready to accept
// connections.
func verifyConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error {
	const verifyConnectionAttempts = 30
	const verifyConnectionDelay = 1 * time.Second

	ctx, span := tracing.StartSpan(ctx)
	defer span.End()

	client := agentHTTPClient(conn)
	for i := 0; i < verifyConnectionAttempts; i++ {
		_, _ = fmt.Fprintf(logs, "\tVerify connection attempt %d/%d...\n", i+1, verifyConnectionAttempts)
		verifyCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)

		u := &url.URL{
			Scheme: "http",
			Host:   net.JoinHostPort("localhost", strconv.Itoa(workspacesdk.AgentHTTPAPIServerPort)),
			Path:   "/",
		}
		req, err := http.NewRequestWithContext(verifyCtx, http.MethodGet, u.String(), nil)
		if err != nil {
			cancel()
			return xerrors.Errorf("new verify connection request to %q: %w", u.String(), err)
		}
		resp, err := client.Do(req)
		cancel()
		if err == nil {
			_ = resp.Body.Close()
			break
		}
		if i == verifyConnectionAttempts-1 {
			return xerrors.Errorf("verify connection: %w", err)
		}

		select {
		case <-ctx.Done():
			return xerrors.Errorf("verify connection: %w", ctx.Err())
		case <-time.After(verifyConnectionDelay):
		}
	}

	return nil
}

// performInitialConnections fetches each configured URL once, sequentially,
// returning an error if any request fails.
func performInitialConnections(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn, specs []Connection) error {
	if len(specs) == 0 {
		return nil
	}

	ctx, span := tracing.StartSpan(ctx)
	defer span.End()

	_, _ = fmt.Fprintln(logs, "Performing initial service connections...")
	client := agentHTTPClient(conn)
	for i, connSpec := range specs {
		_, _ = fmt.Fprintf(logs, "\t%d. %s\n", i, connSpec.URL)

		timeout := defaultRequestTimeout
		if connSpec.Timeout > 0 {
			timeout = time.Duration(connSpec.Timeout)
		}
		ctx, cancel := context.WithTimeout(ctx, timeout)
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, connSpec.URL, nil)
		if err != nil {
			cancel()
			return xerrors.Errorf("create request: %w", err)
		}

		res, err := client.Do(req)
		cancel()
		if err != nil {
			_, _ = fmt.Fprintf(logs, "\t\tFailed: %+v\n", err)
			return xerrors.Errorf("make initial connection to conn spec %d %q: %w", i, connSpec.URL, err)
		}
		_ = res.Body.Close()

		_, _ = fmt.Fprintln(logs, "\t\tOK")
	}

	return nil
}

// holdConnection keeps the agent connection open for holdDur, polling each
// configured URL on its interval. It returns early if any request fails or
// the context is canceled.
func holdConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn, holdDur time.Duration, specs []Connection) error {
	ctx, span := tracing.StartSpan(ctx)
	defer span.End()

	eg, egCtx := errgroup.WithContext(ctx)
	client := agentHTTPClient(conn)
	if len(specs) > 0 {
		_, _ = fmt.Fprintln(logs, "\nStarting connection loops...")
	}
	for i, connSpec := range specs {
		i, connSpec := i, connSpec
		if connSpec.Interval <= 0 {
			continue
		}

		eg.Go(func() error {
			t := time.NewTicker(time.Duration(connSpec.Interval))
			defer t.Stop()

			timeout := defaultRequestTimeout
			if connSpec.Timeout > 0 {
				timeout = time.Duration(connSpec.Timeout)
			}

			for {
				select {
				case <-egCtx.Done():
					return egCtx.Err()
				case <-t.C:
					ctx, cancel := context.WithTimeout(ctx, timeout)
					req, err := http.NewRequestWithContext(ctx, http.MethodGet, connSpec.URL, nil)
					if err != nil {
						cancel()
						return xerrors.Errorf("create request: %w", err)
					}

					res, err := client.Do(req)
					cancel()
					if err != nil {
						_, _ = fmt.Fprintf(logs, "\tERR: %s (%d): %+v\n", connSpec.URL, i, err)
						return xerrors.Errorf("make connection to conn spec %d %q: %w", i, connSpec.URL, err)
					}
					res.Body.Close()

					_, _ = fmt.Fprintf(logs, "\tOK: %s (%d)\n", connSpec.URL, i)
					t.Reset(time.Duration(connSpec.Interval))
				}
			}
		})
	}

	// Wait for the hold duration to end. We use a fake error to signal that
	// the hold duration has ended.
	_, _ = fmt.Fprintf(logs, "\nWaiting for %s...\n", holdDur)
	eg.Go(func() error {
		t := time.NewTicker(holdDur)
		defer t.Stop()

		select {
		case <-egCtx.Done():
			return egCtx.Err()
		case <-t.C:
			// Returning an error here will cause the errgroup context to
			// be canceled, which is what we want. This fake error is
			// ignored below.
			return holdDurationEndedError{}
		}
	})

	err := eg.Wait()
	if err != nil && !xerrors.Is(err, holdDurationEndedError{}) {
		return xerrors.Errorf("run connections loop: %w", err)
	}

	return nil
}
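
// Pattern note (added for clarity): returning a sentinel error is a common
// way to stop an errgroup without treating the shutdown as a failure, since
// errgroup.WithContext cancels the shared context on the first non-nil
// return. A minimal standalone sketch using a hypothetical errDone:
//
//	eg, egCtx := errgroup.WithContext(ctx)
//	eg.Go(func() error { <-egCtx.Done(); return egCtx.Err() })
//	eg.Go(func() error { return errDone }) // cancels egCtx for siblings
//	if err := eg.Wait(); err != nil && !xerrors.Is(err, errDone) {
//		return err
//	}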

// agentHTTPClient returns an HTTP client whose connections are dialed over
// the tunnel to the agent rather than over the local network.
func agentHTTPClient(conn *workspacesdk.AgentConn) *http.Client {
	return &http.Client{
		Transport: &http.Transport{
			DisableKeepAlives: true,
			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
				_, port, err := net.SplitHostPort(addr)
				if err != nil {
					return nil, xerrors.Errorf("split host port %q: %w", addr, err)
				}

				portUint, err := strconv.ParseUint(port, 10, 16)
				if err != nil {
					return nil, xerrors.Errorf("parse port %q: %w", port, err)
				}

				// Addr doesn't matter here, besides the port. DialContext will
				// automatically choose the right IP to dial.
				return conn.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", portUint))
			},
		},
	}
}
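
// Illustrative note (not part of the original file): because the DialContext
// above keeps only the port from the requested address, the hostname in a URL
// is irrelevant; both hypothetical requests below reach the same service on
// the agent:
//
//	client := agentHTTPClient(conn)
//	resp1, err1 := client.Get("http://localhost:8080/") // dialed over the tunnel
//	resp2, err2 := client.Get("http://anything:8080/")  // same destination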