feat: make trace provider in loadtest, add tracing to sdk (#4939)

This commit is contained in:
Dean Sheather 2022-11-09 08:10:48 +10:00 committed by GitHub
parent fa844d0878
commit d82364b9b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 757 additions and 206 deletions

View File

@ -58,9 +58,12 @@ func TestAgent(t *testing.T) {
t.Run("SSH", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
conn, stats := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
sshClient, err := conn.SSHClient()
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
session, err := sshClient.NewSession()
@ -75,9 +78,12 @@ func TestAgent(t *testing.T) {
t.Run("ReconnectingPTY", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
conn, stats := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash")
ptyConn, err := conn.ReconnectingPTY(ctx, uuid.NewString(), 128, 128, "/bin/bash")
require.NoError(t, err)
defer ptyConn.Close()
@ -217,6 +223,8 @@ func TestAgent(t *testing.T) {
t.Run("SFTP", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
u, err := user.Current()
require.NoError(t, err, "get current user")
home := u.HomeDir
@ -224,7 +232,7 @@ func TestAgent(t *testing.T) {
home = "/" + strings.ReplaceAll(home, "\\", "/")
}
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
sshClient, err := conn.SSHClient()
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
client, err := sftp.NewClient(sshClient)
@ -250,8 +258,11 @@ func TestAgent(t *testing.T) {
t.Run("SCP", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
sshClient, err := conn.SSHClient()
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
scpClient, err := scp.NewClientBySSH(sshClient)
@ -386,9 +397,12 @@ func TestAgent(t *testing.T) {
t.Skip("ConPTY appears to be inconsistent on Windows.")
}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
id := uuid.NewString()
netConn, err := conn.ReconnectingPTY(id, 100, 100, "/bin/bash")
netConn, err := conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash")
require.NoError(t, err)
bufRead := bufio.NewReader(netConn)
@ -426,7 +440,7 @@ func TestAgent(t *testing.T) {
expectLine(matchEchoOutput)
_ = netConn.Close()
netConn, err = conn.ReconnectingPTY(id, 100, 100, "/bin/bash")
netConn, err = conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash")
require.NoError(t, err)
bufRead = bufio.NewReader(netConn)
@ -504,12 +518,14 @@ func TestAgent(t *testing.T) {
t.Run("Speedtest", func(t *testing.T) {
t.Parallel()
t.Skip("This test is relatively flakey because of Tailscale's speedtest code...")
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
derpMap := tailnettest.RunDERPAndSTUN(t)
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{
DERPMap: derpMap,
}, 0)
defer conn.Close()
res, err := conn.Speedtest(speedtest.Upload, 250*time.Millisecond)
res, err := conn.Speedtest(ctx, speedtest.Upload, 250*time.Millisecond)
require.NoError(t, err)
t.Logf("%.2f MBits/s", res[len(res)-1].MBitsPerSecond())
})
@ -599,7 +615,10 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
if err != nil {
return
}
ssh, err := agentConn.SSH()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
ssh, err := agentConn.SSH(ctx)
cancel()
if err != nil {
_ = conn.Close()
return
@ -626,8 +645,10 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
}
func setupSSHSession(t *testing.T, options codersdk.WorkspaceAgentMetadata) *ssh.Session {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
conn, _ := setupAgent(t, options, 0)
sshClient, err := conn.SSHClient()
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
t.Cleanup(func() {
_ = sshClient.Close()

View File

@ -198,7 +198,7 @@ func TestWorkspaceAgent(t *testing.T) {
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
sshClient, err := dialer.SSHClient()
sshClient, err := dialer.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
session, err := sshClient.NewSession()

View File

@ -28,6 +28,7 @@ import (
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/pty/ptytest"
"github.com/coder/coder/testutil"
)
func sshConfigFileName(t *testing.T) (sshConfig string) {
@ -131,7 +132,9 @@ func TestConfigSSH(t *testing.T) {
if err != nil {
break
}
ssh, err := agentConn.SSH()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
ssh, err := agentConn.SSH(ctx)
cancel()
assert.NoError(t, err)
wg.Add(2)
go func() {

View File

@ -8,20 +8,30 @@ import (
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/spf13/cobra"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/loadtest/harness"
)
const loadtestTracerName = "coder_loadtest"
func loadtest() *cobra.Command {
var (
configPath string
outputSpecs []string
traceEnable bool
traceCoder bool
traceHoneycombAPIKey string
tracePropagate bool
)
cmd := &cobra.Command{
Use: "loadtest --config <path> [--output json[:path]] [--output text[:path]]]",
@ -53,6 +63,8 @@ func loadtest() *cobra.Command {
Hidden: true,
Args: cobra.ExactArgs(0),
RunE: func(cmd *cobra.Command, args []string) error {
ctx := tracing.SetTracerName(cmd.Context(), loadtestTracerName)
config, err := loadLoadTestConfigFile(configPath, cmd.InOrStdin())
if err != nil {
return err
@ -67,7 +79,7 @@ func loadtest() *cobra.Command {
return err
}
me, err := client.User(cmd.Context(), codersdk.Me)
me, err := client.User(ctx, codersdk.Me)
if err != nil {
return xerrors.Errorf("fetch current user: %w", err)
}
@ -84,11 +96,43 @@ func loadtest() *cobra.Command {
}
}
if !ok {
return xerrors.Errorf("Not logged in as site owner. Load testing is only available to site owners.")
return xerrors.Errorf("Not logged in as a site owner. Load testing is only available to site owners.")
}
// Disable ratelimits for future requests.
// Setup tracing and start a span.
var (
shouldTrace = traceEnable || traceCoder || traceHoneycombAPIKey != ""
tracerProvider trace.TracerProvider = trace.NewNoopTracerProvider()
closeTracingOnce sync.Once
closeTracing = func(_ context.Context) error {
return nil
}
)
if shouldTrace {
tracerProvider, closeTracing, err = tracing.TracerProvider(ctx, loadtestTracerName, tracing.TracerOpts{
Default: traceEnable,
Coder: traceCoder,
Honeycomb: traceHoneycombAPIKey,
})
if err != nil {
return xerrors.Errorf("initialize tracing: %w", err)
}
defer func() {
closeTracingOnce.Do(func() {
// Allow time for traces to flush even if command
// context is canceled.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
_ = closeTracing(ctx)
})
}()
}
tracer := tracerProvider.Tracer(loadtestTracerName)
// Disable ratelimits and propagate tracing spans for future
// requests. Individual tests will setup their own loggers.
client.BypassRatelimits = true
client.PropagateTracing = tracePropagate
// Prepare the test.
strategy := config.Strategy.ExecutionStrategy()
@ -99,18 +143,22 @@ func loadtest() *cobra.Command {
for j := 0; j < t.Count; j++ {
id := strconv.Itoa(j)
runner, err := t.NewRunner(client)
runner, err := t.NewRunner(client.Clone())
if err != nil {
return xerrors.Errorf("create %q runner for %s/%s: %w", t.Type, name, id, err)
}
th.AddRun(name, id, runner)
th.AddRun(name, id, &runnableTraceWrapper{
tracer: tracer,
spanName: fmt.Sprintf("%s/%s", name, id),
runner: runner,
})
}
}
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Running load test...")
testCtx := cmd.Context()
testCtx := ctx
if config.Timeout > 0 {
var cancel func()
testCtx, cancel = context.WithTimeout(testCtx, time.Duration(config.Timeout))
@ -158,11 +206,24 @@ func loadtest() *cobra.Command {
// Cleanup.
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "\nCleaning up...")
err = th.Cleanup(cmd.Context())
err = th.Cleanup(ctx)
if err != nil {
return xerrors.Errorf("cleanup tests: %w", err)
}
// Upload traces.
if shouldTrace {
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "\nUploading traces...")
closeTracingOnce.Do(func() {
ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
err := closeTracing(ctx)
if err != nil {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "\nError uploading traces: %+v\n", err)
}
})
}
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
@ -173,6 +234,12 @@ func loadtest() *cobra.Command {
cliflag.StringVarP(cmd.Flags(), &configPath, "config", "", "CODER_LOADTEST_CONFIG_PATH", "", "Path to the load test configuration file, or - to read from stdin.")
cliflag.StringArrayVarP(cmd.Flags(), &outputSpecs, "output", "", "CODER_LOADTEST_OUTPUTS", []string{"text"}, "Output formats, see usage for more information.")
cliflag.BoolVarP(cmd.Flags(), &traceEnable, "trace", "", "CODER_LOADTEST_TRACE", false, "Whether application tracing data is collected. It exports to a backend configured by environment variables. See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md")
cliflag.BoolVarP(cmd.Flags(), &traceCoder, "trace-coder", "", "CODER_LOADTEST_TRACE_CODER", false, "Whether opentelemetry traces are sent to Coder. We recommend keeping this disabled unless we advise you to enable it.")
cliflag.StringVarP(cmd.Flags(), &traceHoneycombAPIKey, "trace-honeycomb-api-key", "", "CODER_LOADTEST_TRACE_HONEYCOMB_API_KEY", "", "Enables trace exporting to Honeycomb.io using the provided API key.")
cliflag.BoolVarP(cmd.Flags(), &tracePropagate, "trace-propagate", "", "CODER_LOADTEST_TRACE_PROPAGATE", false, "Enables trace propagation to the Coder backend, which will be used to correlate server-side spans with client-side spans. Only enable this if the server is configured with the exact same tracing configuration as the client.")
return cmd
}
@ -271,3 +338,53 @@ func parseLoadTestOutputs(outputs []string) ([]loadTestOutput, error) {
return out, nil
}
type runnableTraceWrapper struct {
tracer trace.Tracer
spanName string
runner harness.Runnable
span trace.Span
}
var _ harness.Runnable = &runnableTraceWrapper{}
var _ harness.Cleanable = &runnableTraceWrapper{}
func (r *runnableTraceWrapper) Run(ctx context.Context, id string, logs io.Writer) error {
ctx, span := r.tracer.Start(ctx, r.spanName, trace.WithNewRoot())
defer span.End()
r.span = span
traceID := "unknown trace ID"
spanID := "unknown span ID"
if span.SpanContext().HasTraceID() {
traceID = span.SpanContext().TraceID().String()
}
if span.SpanContext().HasSpanID() {
spanID = span.SpanContext().SpanID().String()
}
_, _ = fmt.Fprintf(logs, "Trace ID: %s\n", traceID)
_, _ = fmt.Fprintf(logs, "Span ID: %s\n\n", spanID)
// Make a separate span for the run itself so the sub-spans are grouped
// neatly. The cleanup span is also a child of the above span so this is
// important for readability.
ctx2, span2 := r.tracer.Start(ctx, r.spanName+" run")
defer span2.End()
return r.runner.Run(ctx2, id, logs)
}
func (r *runnableTraceWrapper) Cleanup(ctx context.Context, id string) error {
c, ok := r.runner.(harness.Cleanable)
if !ok {
return nil
}
if r.span != nil {
ctx = trace.ContextWithSpanContext(ctx, r.span.SpanContext())
}
ctx, span := r.tracer.Start(ctx, r.spanName+" cleanup")
defer span.End()
return c.Cleanup(ctx, id)
}

View File

@ -277,6 +277,8 @@ func TestLoadTest(t *testing.T) {
require.NoError(t, err, msg)
}
t.Logf("output %d:\n\n%s", i, string(b))
switch output.format {
case "text":
require.Contains(t, string(b), "Test results:", msg)

View File

@ -128,8 +128,9 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co
if cfg.Trace.Enable.Value || shouldCoderTrace {
sdkTracerProvider, closeTracing, err := tracing.TracerProvider(ctx, "coderd", tracing.TracerOpts{
Default: cfg.Trace.Enable.Value,
Coder: shouldCoderTrace,
Default: cfg.Trace.Enable.Value,
Coder: shouldCoderTrace,
Honeycomb: cfg.Trace.HoneycombAPIKey.Value,
})
if err != nil {
logger.Warn(ctx, "start telemetry exporter", slog.Error(err))

View File

@ -95,7 +95,7 @@ func speedtest() *cobra.Command {
dir = tsspeedtest.Upload
}
cmd.Printf("Starting a %ds %s test...\n", int(duration.Seconds()), dir)
results, err := conn.Speedtest(dir, duration)
results, err := conn.Speedtest(ctx, dir, duration)
if err != nil {
return err
}

View File

@ -100,7 +100,7 @@ func ssh() *cobra.Command {
defer stopPolling()
if stdio {
rawSSH, err := conn.SSH()
rawSSH, err := conn.SSH(ctx)
if err != nil {
return err
}
@ -113,7 +113,7 @@ func ssh() *cobra.Command {
return nil
}
sshClient, err := conn.SSHClient()
sshClient, err := conn.SSHClient(ctx)
if err != nil {
return err
}

View File

@ -88,7 +88,7 @@ func TestWorkspaceActivityBump(t *testing.T) {
require.NoError(t, err)
defer conn.Close()
sshConn, err := conn.SSHClient()
sshConn, err := conn.SSHClient(ctx)
require.NoError(t, err)
_ = sshConn.Close()

View File

@ -633,7 +633,7 @@ func TestTemplateMetrics(t *testing.T) {
_ = conn.Close()
}()
sshConn, err := conn.SSHClient()
sshConn, err := conn.SSHClient(ctx)
require.NoError(t, err)
_ = sshConn.Close()

View File

@ -4,6 +4,7 @@ import (
"context"
"github.com/go-logr/logr"
"github.com/hashicorp/go-multierror"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
@ -82,11 +83,23 @@ func TracerProvider(ctx context.Context, service string, opts TracerOpts) (*sdkt
otel.SetLogger(logr.Discard())
return tracerProvider, func(ctx context.Context) error {
for _, close := range closers {
_ = close(ctx)
var merr error
err := tracerProvider.ForceFlush(ctx)
if err != nil {
merr = multierror.Append(merr, xerrors.Errorf("tracerProvider.ForceFlush(): %w", err))
}
_ = tracerProvider.Shutdown(ctx)
return nil
for i, closer := range closers {
err = closer(ctx)
if err != nil {
merr = multierror.Append(merr, xerrors.Errorf("closer() %d: %w", i, err))
}
}
err = tracerProvider.Shutdown(ctx)
if err != nil {
merr = multierror.Append(merr, xerrors.Errorf("tracerProvider.Shutdown(): %w", err))
}
return merr
}, nil
}

View File

@ -6,6 +6,8 @@ import (
"net/http"
"github.com/go-chi/chi/v5"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
semconv "go.opentelemetry.io/otel/semconv/v1.11.0"
"go.opentelemetry.io/otel/trace"
)
@ -23,11 +25,27 @@ func Middleware(tracerProvider trace.TracerProvider) func(http.Handler) http.Han
return
}
// Extract the trace context from the request headers.
tmp := otel.GetTextMapPropagator()
hc := propagation.HeaderCarrier(r.Header)
ctx := tmp.Extract(r.Context(), hc)
// start span with default span name. Span name will be updated to "method route" format once request finishes.
ctx, span := tracer.Start(r.Context(), fmt.Sprintf("%s %s", r.Method, r.RequestURI))
ctx, span := tracer.Start(ctx, fmt.Sprintf("%s %s", r.Method, r.RequestURI))
defer span.End()
r = r.WithContext(ctx)
if span.SpanContext().HasTraceID() && span.SpanContext().HasSpanID() {
// Technically these values are included in the Traceparent
// header, but they are easier to read for humans this way.
rw.Header().Set("X-Trace-ID", span.SpanContext().TraceID().String())
rw.Header().Set("X-Span-ID", span.SpanContext().SpanID().String())
// Inject the trace context into the response headers.
hc := propagation.HeaderCarrier(rw.Header())
tmp.Inject(ctx, hc)
}
sw, ok := rw.(*StatusWriter)
if !ok {
panic(fmt.Sprintf("ResponseWriter not a *tracing.StatusWriter; got %T", rw))
@ -62,6 +80,37 @@ func EndHTTPSpan(r *http.Request, status int, span trace.Span) {
span.End()
}
func StartSpan(ctx context.Context, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
return trace.SpanFromContext(ctx).TracerProvider().Tracer(TracerName).Start(ctx, FuncNameSkip(1), opts...)
type tracerNameKey struct{}
// SetTracerName sets the tracer name that will be used by all spans created
// from the context.
func SetTracerName(ctx context.Context, tracerName string) context.Context {
return context.WithValue(ctx, tracerNameKey{}, tracerName)
}
// GetTracerName returns the tracer name from the context, or TracerName if none
// is set.
func GetTracerName(ctx context.Context) string {
if tracerName, ok := ctx.Value(tracerNameKey{}).(string); ok {
return tracerName
}
return TracerName
}
// StartSpan calls StartSpanWithName with the name set to the caller's function
// name.
func StartSpan(ctx context.Context, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
return StartSpanWithName(ctx, FuncNameSkip(1), opts...)
}
// StartSpanWithName starts a new span with the given name from the context. If
// a tracer name was set on the context (or one of its parents), it will be used
// as the tracer name instead of the default TracerName.
func StartSpanWithName(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
tracerName := GetTracerName(ctx)
return trace.SpanFromContext(ctx).
TracerProvider().
Tracer(tracerName).
Start(ctx, name, opts...)
}

View File

@ -247,7 +247,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
return
}
defer release()
ptNetConn, err := agentConn.ReconnectingPTY(reconnect.String(), uint16(height), uint16(width), r.URL.Query().Get("command"))
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect.String(), uint16(height), uint16(width), r.URL.Query().Get("command"))
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
return

View File

@ -260,7 +260,7 @@ func TestWorkspaceAgentTailnet(t *testing.T) {
})
require.NoError(t, err)
defer conn.Close()
sshClient, err := conn.SSHClient()
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
session, err := sshClient.NewSession()
require.NoError(t, err)

View File

@ -20,6 +20,7 @@ import (
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/tailnet"
)
@ -133,6 +134,9 @@ type AgentConn struct {
}
func (c *AgentConn) Ping(ctx context.Context) (time.Duration, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
errCh := make(chan error, 1)
durCh := make(chan time.Duration, 1)
go c.Conn.Ping(TailnetIP, tailcfg.PingDisco, func(pr *ipnstate.PingResult) {
@ -171,8 +175,11 @@ type ReconnectingPTYInit struct {
Command string
}
func (c *AgentConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) {
conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(TailnetIP, uint16(TailnetReconnectingPTYPort)))
func (c *AgentConn) ReconnectingPTY(ctx context.Context, id string, height, width uint16, command string) (net.Conn, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(TailnetIP, uint16(TailnetReconnectingPTYPort)))
if err != nil {
return nil, err
}
@ -197,14 +204,18 @@ func (c *AgentConn) ReconnectingPTY(id string, height, width uint16, command str
return conn, nil
}
func (c *AgentConn) SSH() (net.Conn, error) {
return c.DialContextTCP(context.Background(), netip.AddrPortFrom(TailnetIP, uint16(TailnetSSHPort)))
func (c *AgentConn) SSH(ctx context.Context) (net.Conn, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
return c.DialContextTCP(ctx, netip.AddrPortFrom(TailnetIP, uint16(TailnetSSHPort)))
}
// SSHClient calls SSH to create a client that uses a weak cipher
// for high throughput.
func (c *AgentConn) SSHClient() (*ssh.Client, error) {
netConn, err := c.SSH()
func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
netConn, err := c.SSH(ctx)
if err != nil {
return nil, xerrors.Errorf("ssh: %w", err)
}
@ -220,8 +231,10 @@ func (c *AgentConn) SSHClient() (*ssh.Client, error) {
return ssh.NewClient(sshConn, channels, requests), nil
}
func (c *AgentConn) Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
speedConn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(TailnetIP, uint16(TailnetSpeedtestPort)))
func (c *AgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(TailnetIP, uint16(TailnetSpeedtestPort)))
if err != nil {
return nil, xerrors.Errorf("dial speedtest: %w", err)
}
@ -233,6 +246,8 @@ func (c *AgentConn) Speedtest(direction speedtest.Direction, duration time.Durat
}
func (c *AgentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
if network == "unix" {
return nil, xerrors.New("network must be tcp or udp")
}
@ -277,6 +292,8 @@ func (c *AgentConn) statisticsClient() *http.Client {
}
func (c *AgentConn) doStatisticsRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
host := net.JoinHostPort(TailnetIP.String(), strconv.Itoa(TailnetStatisticsPort))
url := fmt.Sprintf("http://%s%s", host, path)
@ -309,6 +326,8 @@ type ListeningPort struct {
}
func (c *AgentConn) ListeningPorts(ctx context.Context) (ListeningPortsResponse, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
res, err := c.doStatisticsRequest(ctx, http.MethodGet, "/api/v0/listening-ports", nil)
if err != nil {
return ListeningPortsResponse{}, xerrors.Errorf("do request: %w", err)

View File

@ -12,7 +12,15 @@ import (
"net/url"
"strings"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
semconv "go.opentelemetry.io/otel/semconv/v1.11.0"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/tracing"
"cdr.dev/slog"
)
// These cookies are Coder-specific. If a new one is added or changed, the name
@ -30,6 +38,13 @@ const (
BypassRatelimitHeader = "X-Coder-Bypass-Ratelimit"
)
var loggableMimeTypes = map[string]struct{}{
"application/json": {},
"text/plain": {},
// lots of webserver error pages are HTML
"text/html": {},
}
// New creates a Coder client for the provided URL.
func New(serverURL *url.URL) *Client {
return &Client{
@ -45,9 +60,35 @@ type Client struct {
SessionToken string
URL *url.URL
// Logger can be provided to log requests. Request method, URL and response
// status code will be logged by default.
Logger slog.Logger
// LogBodies determines whether the request and response bodies are logged
// to the provided Logger. This is useful for debugging or testing.
LogBodies bool
// BypassRatelimits is an optional flag that can be set by the site owner to
// disable ratelimit checks for the client.
BypassRatelimits bool
// PropagateTracing is an optional flag that can be set to propagate tracing
// spans to the Coder API. This is useful for seeing the entire request
// from end-to-end.
PropagateTracing bool
}
func (c *Client) Clone() *Client {
hc := *c.HTTPClient
u := *c.URL
return &Client{
HTTPClient: &hc,
SessionToken: c.SessionToken,
URL: &u,
Logger: c.Logger,
LogBodies: c.LogBodies,
BypassRatelimits: c.BypassRatelimits,
PropagateTracing: c.PropagateTracing,
}
}
type RequestOption func(*http.Request)
@ -63,30 +104,46 @@ func WithQueryParam(key, value string) RequestOption {
}
}
// Request performs an HTTP request with the body provided.
// The caller is responsible for closing the response body.
// Request performs a HTTP request with the body provided. The caller is
// responsible for closing the response body.
func (c *Client) Request(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) {
ctx, span := tracing.StartSpanWithName(ctx, tracing.FuncNameSkip(1))
defer span.End()
serverURL, err := c.URL.Parse(path)
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
var buf bytes.Buffer
var r io.Reader
if body != nil {
if data, ok := body.([]byte); ok {
buf = *bytes.NewBuffer(data)
r = bytes.NewReader(data)
} else {
// Assume JSON if not bytes.
enc := json.NewEncoder(&buf)
buf := bytes.NewBuffer(nil)
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false)
err = enc.Encode(body)
if err != nil {
return nil, xerrors.Errorf("encode body: %w", err)
}
r = buf
}
}
req, err := http.NewRequestWithContext(ctx, method, serverURL.String(), &buf)
// Copy the request body so we can log it.
var reqBody []byte
if r != nil && c.LogBodies {
reqBody, err = io.ReadAll(r)
if err != nil {
return nil, xerrors.Errorf("read request body: %w", err)
}
r = bytes.NewReader(reqBody)
}
req, err := http.NewRequestWithContext(ctx, method, serverURL.String(), r)
if err != nil {
return nil, xerrors.Errorf("create request: %w", err)
}
@ -95,17 +152,61 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac
req.Header.Set(BypassRatelimitHeader, "true")
}
if body != nil {
if r != nil {
req.Header.Set("Content-Type", "application/json")
}
for _, opt := range opts {
opt(req)
}
span.SetAttributes(semconv.NetAttributesFromHTTPRequest("tcp", req)...)
span.SetAttributes(semconv.HTTPClientAttributesFromHTTPRequest(req)...)
// Inject tracing headers if enabled.
if c.PropagateTracing {
tmp := otel.GetTextMapPropagator()
hc := propagation.HeaderCarrier(req.Header)
tmp.Inject(ctx, hc)
}
ctx = slog.With(ctx,
slog.F("method", req.Method),
slog.F("url", req.URL.String()),
)
c.Logger.Debug(ctx, "sdk request", slog.F("body", string(reqBody)))
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, xerrors.Errorf("do: %w", err)
}
span.SetAttributes(semconv.HTTPStatusCodeKey.Int(resp.StatusCode))
span.SetStatus(semconv.SpanStatusFromHTTPStatusCodeAndSpanKind(resp.StatusCode, trace.SpanKindClient))
// Copy the response body so we can log it if it's a loggable mime type.
var respBody []byte
if resp.Body != nil && c.LogBodies {
mimeType := parseMimeType(resp.Header.Get("Content-Type"))
if _, ok := loggableMimeTypes[mimeType]; ok {
respBody, err = io.ReadAll(resp.Body)
if err != nil {
return nil, xerrors.Errorf("copy response body for logs: %w", err)
}
err = resp.Body.Close()
if err != nil {
return nil, xerrors.Errorf("close response body: %w", err)
}
resp.Body = io.NopCloser(bytes.NewReader(respBody))
}
}
c.Logger.Debug(ctx, "sdk response",
slog.F("status", resp.StatusCode),
slog.F("body", string(respBody)),
slog.F("trace_id", resp.Header.Get("X-Trace-Id")),
slog.F("span_id", resp.Header.Get("X-Span-Id")),
)
return resp, err
}
@ -138,10 +239,7 @@ func readBodyAsError(res *http.Response) error {
return xerrors.Errorf("read body: %w", err)
}
mimeType, _, err := mime.ParseMediaType(contentType)
if err != nil {
mimeType = strings.TrimSpace(strings.Split(contentType, ";")[0])
}
mimeType := parseMimeType(contentType)
if mimeType != "application/json" {
if len(resp) > 1024 {
resp = append(resp[:1024], []byte("...")...)
@ -238,3 +336,12 @@ type closeFunc func() error
func (c closeFunc) Close() error {
return c()
}
func parseMimeType(contentType string) string {
mimeType, _, err := mime.ParseMediaType(contentType)
if err != nil {
mimeType = strings.TrimSpace(strings.Split(contentType, ";")[0])
}
return mimeType
}

View File

@ -2,23 +2,114 @@ package codersdk
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"testing"
"github.com/go-logr/logr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/testutil"
)
const (
jsonCT = "application/json"
)
const jsonCT = "application/json"
func Test_Client(t *testing.T) {
t.Parallel()
const method = http.MethodPost
const path = "/ok"
const token = "token"
const reqBody = `{"msg": "request body"}`
const resBody = `{"status": "ok"}`
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, method, r.Method)
assert.Equal(t, path, r.URL.Path)
assert.Equal(t, token, r.Header.Get(SessionCustomHeader))
assert.Equal(t, "true", r.Header.Get(BypassRatelimitHeader))
assert.NotEmpty(t, r.Header.Get("Traceparent"))
for k, v := range r.Header {
t.Logf("header %q: %q", k, strings.Join(v, ", "))
}
w.Header().Set("Content-Type", jsonCT)
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, resBody)
}))
u, err := url.Parse(s.URL)
require.NoError(t, err)
client := New(u)
client.SessionToken = token
client.BypassRatelimits = true
logBuf := bytes.NewBuffer(nil)
client.Logger = slog.Make(sloghuman.Sink(logBuf)).Leveled(slog.LevelDebug)
client.LogBodies = true
// Setup tracing.
res := resource.NewWithAttributes(
semconv.SchemaURL,
semconv.ServiceNameKey.String("codersdk_test"),
)
tracerOpts := []sdktrace.TracerProviderOption{
sdktrace.WithResource(res),
}
tracerProvider := sdktrace.NewTracerProvider(tracerOpts...)
otel.SetTracerProvider(tracerProvider)
otel.SetErrorHandler(otel.ErrorHandlerFunc(func(err error) {}))
otel.SetTextMapPropagator(
propagation.NewCompositeTextMapPropagator(
propagation.TraceContext{},
propagation.Baggage{},
),
)
otel.SetLogger(logr.Discard())
client.PropagateTracing = true
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
ctx, span := tracerProvider.Tracer("codersdk_test").Start(ctx, "codersdk client test 1")
defer span.End()
resp, err := client.Request(ctx, method, path, []byte(reqBody))
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, jsonCT, resp.Header.Get("Content-Type"))
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, resBody, string(body))
logStr := logBuf.String()
require.Contains(t, logStr, "sdk request")
require.Contains(t, logStr, method)
require.Contains(t, logStr, path)
require.Contains(t, logStr, strings.ReplaceAll(reqBody, `"`, `\"`))
require.Contains(t, logStr, "sdk response")
require.Contains(t, logStr, "200")
require.Contains(t, logStr, strings.ReplaceAll(resBody, `"`, `\"`))
}
func Test_readBodyAsError(t *testing.T) {
t.Parallel()

View File

@ -2,11 +2,14 @@ package codersdk
import (
"bufio"
"context"
"fmt"
"io"
"strings"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/tracing"
)
type ServerSentEvent struct {
@ -22,7 +25,10 @@ const (
ServerSentEventTypeError ServerSentEventType = "error"
)
func ServerSentEventReader(rc io.ReadCloser) func() (*ServerSentEvent, error) {
func ServerSentEventReader(ctx context.Context, rc io.ReadCloser) func() (*ServerSentEvent, error) {
_, span := tracing.StartSpan(ctx)
defer span.End()
reader := bufio.NewReader(rc)
nextLineValue := func(prefix string) ([]byte, error) {
var (

View File

@ -10,6 +10,8 @@ import (
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/tracing"
)
// Workspace is a deployment of a template. It references a specific
@ -137,6 +139,8 @@ func (c *Client) CreateWorkspaceBuild(ctx context.Context, workspace uuid.UUID,
}
func (c *Client) WatchWorkspace(ctx context.Context, id uuid.UUID) (<-chan Workspace, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
//nolint:bodyclose
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaces/%s/watch", id), nil)
if err != nil {
@ -145,7 +149,7 @@ func (c *Client) WatchWorkspace(ctx context.Context, id uuid.UUID) (<-chan Works
if res.StatusCode != http.StatusOK {
return nil, readBodyAsError(res)
}
nextEvent := ServerSentEventReader(res.Body)
nextEvent := ServerSentEventReader(ctx, res.Body)
wc := make(chan Workspace, 256)
go func() {

View File

@ -9,7 +9,6 @@ import (
"net/netip"
"net/url"
"strconv"
"sync"
"time"
"golang.org/x/sync/errgroup"
@ -17,8 +16,10 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/loadtest/harness"
"github.com/coder/coder/loadtest/loadtestutil"
)
const defaultRequestTimeout = 5 * time.Second
@ -45,11 +46,13 @@ func NewRunner(client *codersdk.Client, cfg Config) *Runner {
// Run implements Runnable.
func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
logs = syncWriter{
mut: &sync.Mutex{},
w: logs,
}
ctx, span := tracing.StartSpan(ctx)
defer span.End()
logs = loadtestutil.NewSyncWriter(logs)
logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
r.client.Logger = logger
r.client.LogBodies = true
_, _ = fmt.Fprintln(logs, "Opening connection to workspace agent")
switch r.cfg.ConnectionMode {
@ -69,9 +72,72 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
}
defer conn.Close()
// Wait for the disco connection to be established.
err = waitForDisco(ctx, logs, conn)
if err != nil {
return xerrors.Errorf("wait for discovery connection: %w", err)
}
// Wait for a direct connection if requested.
if r.cfg.ConnectionMode == ConnectionModeDirect {
err = waitForDirectConnection(ctx, logs, conn)
if err != nil {
return xerrors.Errorf("wait for direct connection: %w", err)
}
}
// Ensure DERP for completeness.
if r.cfg.ConnectionMode == ConnectionModeDerp {
status := conn.Status()
if len(status.Peers()) != 1 {
return xerrors.Errorf("check connection mode: expected 1 peer, got %d", len(status.Peers()))
}
peer := status.Peer[status.Peers()[0]]
if peer.Relay == "" || peer.CurAddr != "" {
return xerrors.Errorf("check connection mode: peer is connected directly, not via DERP")
}
}
_, _ = fmt.Fprint(logs, "\nConnection established.\n\n")
// HACK: even though the ping passed above, we still need to open a
// connection to the agent to ensure it's ready to accept connections. Not
// sure why this is the case but it seems to be necessary.
err = verifyConnection(ctx, logs, conn)
if err != nil {
return xerrors.Errorf("verify connection: %w", err)
}
_, _ = fmt.Fprint(logs, "\nConnection verified.\n\n")
// Make initial connections sequentially to ensure the services are
// reachable before we start spawning a bunch of goroutines and tickers.
err = performInitialConnections(ctx, logs, conn, r.cfg.Connections)
if err != nil {
return xerrors.Errorf("perform initial connections: %w", err)
}
if r.cfg.HoldDuration > 0 {
err = holdConnection(ctx, logs, conn, time.Duration(r.cfg.HoldDuration), r.cfg.Connections)
if err != nil {
return xerrors.Errorf("hold connection: %w", err)
}
}
err = conn.Close()
if err != nil {
return xerrors.Errorf("close connection: %w", err)
}
return nil
}
func waitForDisco(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn) error {
const pingAttempts = 10
const pingDelay = 1 * time.Second
ctx, span := tracing.StartSpan(ctx)
defer span.End()
for i := 0; i < pingAttempts; i++ {
_, _ = fmt.Fprintf(logs, "\tDisco ping attempt %d/%d...\n", i+1, pingAttempts)
pingCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
@ -93,80 +159,59 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
}
}
// Wait for a direct connection if requested.
if r.cfg.ConnectionMode == ConnectionModeDirect {
const directConnectionAttempts = 30
const directConnectionDelay = 1 * time.Second
for i := 0; i < directConnectionAttempts; i++ {
_, _ = fmt.Fprintf(logs, "\tDirect connection check %d/%d...\n", i+1, directConnectionAttempts)
status := conn.Status()
return nil
}
var err error
if len(status.Peers()) != 1 {
_, _ = fmt.Fprintf(logs, "\t\tExpected 1 peer, found %d", len(status.Peers()))
err = xerrors.Errorf("expected 1 peer, got %d", len(status.Peers()))
} else {
peer := status.Peer[status.Peers()[0]]
_, _ = fmt.Fprintf(logs, "\t\tCurAddr: %s\n", peer.CurAddr)
_, _ = fmt.Fprintf(logs, "\t\tRelay: %s\n", peer.Relay)
if peer.Relay != "" && peer.CurAddr == "" {
err = xerrors.Errorf("peer is connected via DERP, not direct")
}
}
if err == nil {
break
}
if i == directConnectionAttempts-1 {
return xerrors.Errorf("wait for direct connection to agent: %w", err)
}
func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn) error {
const directConnectionAttempts = 30
const directConnectionDelay = 1 * time.Second
select {
case <-ctx.Done():
return xerrors.Errorf("wait for direct connection to agent: %w", ctx.Err())
// We use time.After here since it's a very short duration so
// leaking a timer is fine.
case <-time.After(directConnectionDelay):
}
}
}
ctx, span := tracing.StartSpan(ctx)
defer span.End()
// Ensure DERP for completeness.
if r.cfg.ConnectionMode == ConnectionModeDerp {
for i := 0; i < directConnectionAttempts; i++ {
_, _ = fmt.Fprintf(logs, "\tDirect connection check %d/%d...\n", i+1, directConnectionAttempts)
status := conn.Status()
var err error
if len(status.Peers()) != 1 {
return xerrors.Errorf("check connection mode: expected 1 peer, got %d", len(status.Peers()))
_, _ = fmt.Fprintf(logs, "\t\tExpected 1 peer, found %d", len(status.Peers()))
err = xerrors.Errorf("expected 1 peer, got %d", len(status.Peers()))
} else {
peer := status.Peer[status.Peers()[0]]
_, _ = fmt.Fprintf(logs, "\t\tCurAddr: %s\n", peer.CurAddr)
_, _ = fmt.Fprintf(logs, "\t\tRelay: %s\n", peer.Relay)
if peer.Relay != "" && peer.CurAddr == "" {
err = xerrors.Errorf("peer is connected via DERP, not direct")
}
}
peer := status.Peer[status.Peers()[0]]
if peer.Relay == "" || peer.CurAddr != "" {
return xerrors.Errorf("check connection mode: peer is connected directly, not via DERP")
if err == nil {
break
}
if i == directConnectionAttempts-1 {
return xerrors.Errorf("wait for direct connection to agent: %w", err)
}
select {
case <-ctx.Done():
return xerrors.Errorf("wait for direct connection to agent: %w", ctx.Err())
// We use time.After here since it's a very short duration so
// leaking a timer is fine.
case <-time.After(directConnectionDelay):
}
}
_, _ = fmt.Fprint(logs, "\nConnection established.\n\n")
return nil
}
client := &http.Client{
Transport: &http.Transport{
DisableKeepAlives: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
_, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, xerrors.Errorf("split host port %q: %w", addr, err)
}
portUint, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return nil, xerrors.Errorf("parse port %q: %w", port, err)
}
return conn.DialContextTCP(ctx, netip.AddrPortFrom(codersdk.TailnetIP, uint16(portUint)))
},
},
}
// HACK: even though the ping passed above, we still need to open a
// connection to the agent to ensure it's ready to accept connections. Not
// sure why this is the case but it seems to be necessary.
func verifyConnection(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn) error {
const verifyConnectionAttempts = 30
const verifyConnectionDelay = 1 * time.Second
ctx, span := tracing.StartSpan(ctx)
defer span.End()
client := agentHTTPClient(conn)
for i := 0; i < verifyConnectionAttempts; i++ {
_, _ = fmt.Fprintf(logs, "\tVerify connection attempt %d/%d...\n", i+1, verifyConnectionAttempts)
verifyCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
@ -198,14 +243,20 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
}
}
_, _ = fmt.Fprint(logs, "\nConnection verified.\n\n")
return nil
}
// Make initial connections sequentially to ensure the services are
// reachable before we start spawning a bunch of goroutines and tickers.
if len(r.cfg.Connections) > 0 {
_, _ = fmt.Fprintln(logs, "Performing initial service connections...")
func performInitialConnections(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn, specs []Connection) error {
if len(specs) == 0 {
return nil
}
for i, connSpec := range r.cfg.Connections {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
_, _ = fmt.Fprintln(logs, "Performing initial service connections...")
client := agentHTTPClient(conn)
for i, connSpec := range specs {
_, _ = fmt.Fprintf(logs, "\t%d. %s\n", i, connSpec.URL)
timeout := defaultRequestTimeout
@ -230,95 +281,102 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
_, _ = fmt.Fprintln(logs, "\t\tOK")
}
if r.cfg.HoldDuration > 0 {
eg, egCtx := errgroup.WithContext(ctx)
return nil
}
if len(r.cfg.Connections) > 0 {
_, _ = fmt.Fprintln(logs, "\nStarting connection loops...")
}
for i, connSpec := range r.cfg.Connections {
i, connSpec := i, connSpec
if connSpec.Interval <= 0 {
continue
}
func holdConnection(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn, holdDur time.Duration, specs []Connection) error {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
eg.Go(func() error {
t := time.NewTicker(time.Duration(connSpec.Interval))
defer t.Stop()
timeout := defaultRequestTimeout
if connSpec.Timeout > 0 {
timeout = time.Duration(connSpec.Timeout)
}
for {
select {
case <-egCtx.Done():
return egCtx.Err()
case <-t.C:
ctx, cancel := context.WithTimeout(ctx, timeout)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, connSpec.URL, nil)
if err != nil {
cancel()
return xerrors.Errorf("create request: %w", err)
}
res, err := client.Do(req)
cancel()
if err != nil {
_, _ = fmt.Fprintf(logs, "\tERR: %s (%d): %+v\n", connSpec.URL, i, err)
return xerrors.Errorf("make connection to conn spec %d %q: %w", i, connSpec.URL, err)
}
res.Body.Close()
_, _ = fmt.Fprintf(logs, "\tOK: %s (%d)\n", connSpec.URL, i)
t.Reset(time.Duration(connSpec.Interval))
}
}
})
eg, egCtx := errgroup.WithContext(ctx)
client := agentHTTPClient(conn)
if len(specs) > 0 {
_, _ = fmt.Fprintln(logs, "\nStarting connection loops...")
}
for i, connSpec := range specs {
i, connSpec := i, connSpec
if connSpec.Interval <= 0 {
continue
}
// Wait for the hold duration to end. We use a fake error to signal that
// the hold duration has ended.
_, _ = fmt.Fprintf(logs, "\nWaiting for %s...\n", time.Duration(r.cfg.HoldDuration))
eg.Go(func() error {
t := time.NewTicker(time.Duration(r.cfg.HoldDuration))
t := time.NewTicker(time.Duration(connSpec.Interval))
defer t.Stop()
select {
case <-egCtx.Done():
return egCtx.Err()
case <-t.C:
// Returning an error here will cause the errgroup context to
// be canceled, which is what we want. This fake error is
// ignored below.
return holdDurationEndedError{}
timeout := defaultRequestTimeout
if connSpec.Timeout > 0 {
timeout = time.Duration(connSpec.Timeout)
}
for {
select {
case <-egCtx.Done():
return egCtx.Err()
case <-t.C:
ctx, cancel := context.WithTimeout(ctx, timeout)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, connSpec.URL, nil)
if err != nil {
cancel()
return xerrors.Errorf("create request: %w", err)
}
res, err := client.Do(req)
cancel()
if err != nil {
_, _ = fmt.Fprintf(logs, "\tERR: %s (%d): %+v\n", connSpec.URL, i, err)
return xerrors.Errorf("make connection to conn spec %d %q: %w", i, connSpec.URL, err)
}
res.Body.Close()
_, _ = fmt.Fprintf(logs, "\tOK: %s (%d)\n", connSpec.URL, i)
t.Reset(time.Duration(connSpec.Interval))
}
}
})
err = eg.Wait()
if err != nil && !xerrors.Is(err, holdDurationEndedError{}) {
return xerrors.Errorf("run connections loop: %w", err)
}
}
err = conn.Close()
if err != nil {
return xerrors.Errorf("close connection: %w", err)
// Wait for the hold duration to end. We use a fake error to signal that
// the hold duration has ended.
_, _ = fmt.Fprintf(logs, "\nWaiting for %s...\n", holdDur)
eg.Go(func() error {
t := time.NewTicker(holdDur)
defer t.Stop()
select {
case <-egCtx.Done():
return egCtx.Err()
case <-t.C:
// Returning an error here will cause the errgroup context to
// be canceled, which is what we want. This fake error is
// ignored below.
return holdDurationEndedError{}
}
})
err := eg.Wait()
if err != nil && !xerrors.Is(err, holdDurationEndedError{}) {
return xerrors.Errorf("run connections loop: %w", err)
}
return nil
}
// syncWriter wraps an io.Writer in a sync.Mutex.
type syncWriter struct {
mut *sync.Mutex
w io.Writer
}
func agentHTTPClient(conn *codersdk.AgentConn) *http.Client {
return &http.Client{
Transport: &http.Transport{
DisableKeepAlives: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
_, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, xerrors.Errorf("split host port %q: %w", addr, err)
}
// Write implements io.Writer.
func (sw syncWriter) Write(p []byte) (n int, err error) {
sw.mut.Lock()
defer sw.mut.Unlock()
return sw.w.Write(p)
portUint, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return nil, xerrors.Errorf("parse port %q: %w", port, err)
}
return conn.DialContextTCP(ctx, netip.AddrPortFrom(codersdk.TailnetIP, uint16(portUint)))
},
},
}
}

View File

@ -7,6 +7,8 @@ import (
"github.com/hashicorp/go-multierror"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/tracing"
)
// ExecutionStrategy defines how a TestHarness should execute a set of runs. It
@ -49,6 +51,9 @@ func NewTestHarness(strategy ExecutionStrategy) *TestHarness {
//
// Panics if called more than once.
func (h *TestHarness) Run(ctx context.Context) (err error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
h.mut.Lock()
if h.started {
h.mut.Unlock()

View File

@ -95,6 +95,7 @@ type timeoutRunnerWrapper struct {
}
var _ Runnable = timeoutRunnerWrapper{}
var _ Cleanable = timeoutRunnerWrapper{}
func (t timeoutRunnerWrapper) Run(ctx context.Context, id string, logs io.Writer) error {
ctx, cancel := context.WithTimeout(ctx, t.timeout)
@ -103,6 +104,15 @@ func (t timeoutRunnerWrapper) Run(ctx context.Context, id string, logs io.Writer
return t.inner.Run(ctx, id, logs)
}
// Cleanup forwards the cleanup request to the wrapped runner when it
// implements Cleanable. Runners without cleanup behavior are a no-op.
func (t timeoutRunnerWrapper) Cleanup(ctx context.Context, id string) error {
	if c, ok := t.inner.(Cleanable); ok {
		return c.Cleanup(ctx, id)
	}
	return nil
}
// Execute implements ExecutionStrategy.
func (t TimeoutExecutionStrategyWrapper) Execute(ctx context.Context, runs []*TestRun) error {
for _, run := range runs {

View File

@ -0,0 +1,26 @@
package loadtestutil
import (
"io"
"sync"
)
// SyncWriter wraps an io.Writer in a sync.Mutex.
type SyncWriter struct {
mut *sync.Mutex
w io.Writer
}
func NewSyncWriter(w io.Writer) *SyncWriter {
return &SyncWriter{
mut: &sync.Mutex{},
w: w,
}
}
// Write implements io.Writer.
func (sw *SyncWriter) Write(p []byte) (n int, err error) {
sw.mut.Lock()
defer sw.mut.Unlock()
return sw.w.Write(p)
}

View File

@ -9,9 +9,14 @@ import (
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/loadtest/harness"
"github.com/coder/coder/loadtest/loadtestutil"
)
type Runner struct {
@ -32,6 +37,14 @@ func NewRunner(client *codersdk.Client, cfg Config) *Runner {
// Run implements Runnable.
func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
logs = loadtestutil.NewSyncWriter(logs)
logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
r.client.Logger = logger
r.client.LogBodies = true
req := r.cfg.Request
if req.Name == "" {
randName, err := cryptorand.HexString(8)
@ -66,6 +79,8 @@ func (r *Runner) Cleanup(ctx context.Context, _ string) error {
if r.workspaceID == uuid.Nil {
return nil
}
ctx, span := tracing.StartSpan(ctx)
defer span.End()
build, err := r.client.CreateWorkspaceBuild(ctx, r.workspaceID, codersdk.CreateWorkspaceBuildRequest{
Transition: codersdk.WorkspaceTransitionDelete,
@ -85,6 +100,8 @@ func (r *Runner) Cleanup(ctx context.Context, _ string) error {
}
func waitForBuild(ctx context.Context, w io.Writer, client *codersdk.Client, buildID uuid.UUID) error {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
_, _ = fmt.Fprint(w, "Build is currently queued...")
// Wait for build to start.
@ -154,6 +171,8 @@ func waitForBuild(ctx context.Context, w io.Writer, client *codersdk.Client, bui
}
func waitForAgents(ctx context.Context, w io.Writer, client *codersdk.Client, workspaceID uuid.UUID) error {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
_, _ = fmt.Fprint(w, "Waiting for agents to connect...\n\n")
for {