mirror of https://github.com/coder/coder.git
feat: make trace provider in loadtest, add tracing to sdk (#4939)
This commit is contained in:
parent
fa844d0878
commit
d82364b9b5
|
@ -58,9 +58,12 @@ func TestAgent(t *testing.T) {
|
|||
|
||||
t.Run("SSH", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
conn, stats := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
|
||||
|
||||
sshClient, err := conn.SSHClient()
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
session, err := sshClient.NewSession()
|
||||
|
@ -75,9 +78,12 @@ func TestAgent(t *testing.T) {
|
|||
t.Run("ReconnectingPTY", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
conn, stats := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
|
||||
|
||||
ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash")
|
||||
ptyConn, err := conn.ReconnectingPTY(ctx, uuid.NewString(), 128, 128, "/bin/bash")
|
||||
require.NoError(t, err)
|
||||
defer ptyConn.Close()
|
||||
|
||||
|
@ -217,6 +223,8 @@ func TestAgent(t *testing.T) {
|
|||
|
||||
t.Run("SFTP", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err, "get current user")
|
||||
home := u.HomeDir
|
||||
|
@ -224,7 +232,7 @@ func TestAgent(t *testing.T) {
|
|||
home = "/" + strings.ReplaceAll(home, "\\", "/")
|
||||
}
|
||||
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
|
||||
sshClient, err := conn.SSHClient()
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
client, err := sftp.NewClient(sshClient)
|
||||
|
@ -250,8 +258,11 @@ func TestAgent(t *testing.T) {
|
|||
t.Run("SCP", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
|
||||
sshClient, err := conn.SSHClient()
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
scpClient, err := scp.NewClientBySSH(sshClient)
|
||||
|
@ -386,9 +397,12 @@ func TestAgent(t *testing.T) {
|
|||
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
|
||||
id := uuid.NewString()
|
||||
netConn, err := conn.ReconnectingPTY(id, 100, 100, "/bin/bash")
|
||||
netConn, err := conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash")
|
||||
require.NoError(t, err)
|
||||
bufRead := bufio.NewReader(netConn)
|
||||
|
||||
|
@ -426,7 +440,7 @@ func TestAgent(t *testing.T) {
|
|||
expectLine(matchEchoOutput)
|
||||
|
||||
_ = netConn.Close()
|
||||
netConn, err = conn.ReconnectingPTY(id, 100, 100, "/bin/bash")
|
||||
netConn, err = conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash")
|
||||
require.NoError(t, err)
|
||||
bufRead = bufio.NewReader(netConn)
|
||||
|
||||
|
@ -504,12 +518,14 @@ func TestAgent(t *testing.T) {
|
|||
t.Run("Speedtest", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("This test is relatively flakey because of Tailscale's speedtest code...")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
derpMap := tailnettest.RunDERPAndSTUN(t)
|
||||
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{
|
||||
DERPMap: derpMap,
|
||||
}, 0)
|
||||
defer conn.Close()
|
||||
res, err := conn.Speedtest(speedtest.Upload, 250*time.Millisecond)
|
||||
res, err := conn.Speedtest(ctx, speedtest.Upload, 250*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
t.Logf("%.2f MBits/s", res[len(res)-1].MBitsPerSecond())
|
||||
})
|
||||
|
@ -599,7 +615,10 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
ssh, err := agentConn.SSH()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
ssh, err := agentConn.SSH(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
|
@ -626,8 +645,10 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
|
|||
}
|
||||
|
||||
func setupSSHSession(t *testing.T, options codersdk.WorkspaceAgentMetadata) *ssh.Session {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
conn, _ := setupAgent(t, options, 0)
|
||||
sshClient, err := conn.SSHClient()
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = sshClient.Close()
|
||||
|
|
|
@ -198,7 +198,7 @@ func TestWorkspaceAgent(t *testing.T) {
|
|||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
sshClient, err := dialer.SSHClient()
|
||||
sshClient, err := dialer.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
session, err := sshClient.NewSession()
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/provisionersdk/proto"
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func sshConfigFileName(t *testing.T) (sshConfig string) {
|
||||
|
@ -131,7 +132,9 @@ func TestConfigSSH(t *testing.T) {
|
|||
if err != nil {
|
||||
break
|
||||
}
|
||||
ssh, err := agentConn.SSH()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
ssh, err := agentConn.SSH(ctx)
|
||||
cancel()
|
||||
assert.NoError(t, err)
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
|
|
131
cli/loadtest.go
131
cli/loadtest.go
|
@ -8,20 +8,30 @@ import (
|
|||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/loadtest/harness"
|
||||
)
|
||||
|
||||
const loadtestTracerName = "coder_loadtest"
|
||||
|
||||
func loadtest() *cobra.Command {
|
||||
var (
|
||||
configPath string
|
||||
outputSpecs []string
|
||||
|
||||
traceEnable bool
|
||||
traceCoder bool
|
||||
traceHoneycombAPIKey string
|
||||
tracePropagate bool
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Use: "loadtest --config <path> [--output json[:path]] [--output text[:path]]]",
|
||||
|
@ -53,6 +63,8 @@ func loadtest() *cobra.Command {
|
|||
Hidden: true,
|
||||
Args: cobra.ExactArgs(0),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx := tracing.SetTracerName(cmd.Context(), loadtestTracerName)
|
||||
|
||||
config, err := loadLoadTestConfigFile(configPath, cmd.InOrStdin())
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -67,7 +79,7 @@ func loadtest() *cobra.Command {
|
|||
return err
|
||||
}
|
||||
|
||||
me, err := client.User(cmd.Context(), codersdk.Me)
|
||||
me, err := client.User(ctx, codersdk.Me)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("fetch current user: %w", err)
|
||||
}
|
||||
|
@ -84,11 +96,43 @@ func loadtest() *cobra.Command {
|
|||
}
|
||||
}
|
||||
if !ok {
|
||||
return xerrors.Errorf("Not logged in as site owner. Load testing is only available to site owners.")
|
||||
return xerrors.Errorf("Not logged in as a site owner. Load testing is only available to site owners.")
|
||||
}
|
||||
|
||||
// Disable ratelimits for future requests.
|
||||
// Setup tracing and start a span.
|
||||
var (
|
||||
shouldTrace = traceEnable || traceCoder || traceHoneycombAPIKey != ""
|
||||
tracerProvider trace.TracerProvider = trace.NewNoopTracerProvider()
|
||||
closeTracingOnce sync.Once
|
||||
closeTracing = func(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
)
|
||||
if shouldTrace {
|
||||
tracerProvider, closeTracing, err = tracing.TracerProvider(ctx, loadtestTracerName, tracing.TracerOpts{
|
||||
Default: traceEnable,
|
||||
Coder: traceCoder,
|
||||
Honeycomb: traceHoneycombAPIKey,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("initialize tracing: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
closeTracingOnce.Do(func() {
|
||||
// Allow time for traces to flush even if command
|
||||
// context is canceled.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
_ = closeTracing(ctx)
|
||||
})
|
||||
}()
|
||||
}
|
||||
tracer := tracerProvider.Tracer(loadtestTracerName)
|
||||
|
||||
// Disable ratelimits and propagate tracing spans for future
|
||||
// requests. Individual tests will setup their own loggers.
|
||||
client.BypassRatelimits = true
|
||||
client.PropagateTracing = tracePropagate
|
||||
|
||||
// Prepare the test.
|
||||
strategy := config.Strategy.ExecutionStrategy()
|
||||
|
@ -99,18 +143,22 @@ func loadtest() *cobra.Command {
|
|||
|
||||
for j := 0; j < t.Count; j++ {
|
||||
id := strconv.Itoa(j)
|
||||
runner, err := t.NewRunner(client)
|
||||
runner, err := t.NewRunner(client.Clone())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create %q runner for %s/%s: %w", t.Type, name, id, err)
|
||||
}
|
||||
|
||||
th.AddRun(name, id, runner)
|
||||
th.AddRun(name, id, &runnableTraceWrapper{
|
||||
tracer: tracer,
|
||||
spanName: fmt.Sprintf("%s/%s", name, id),
|
||||
runner: runner,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Running load test...")
|
||||
|
||||
testCtx := cmd.Context()
|
||||
testCtx := ctx
|
||||
if config.Timeout > 0 {
|
||||
var cancel func()
|
||||
testCtx, cancel = context.WithTimeout(testCtx, time.Duration(config.Timeout))
|
||||
|
@ -158,11 +206,24 @@ func loadtest() *cobra.Command {
|
|||
|
||||
// Cleanup.
|
||||
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "\nCleaning up...")
|
||||
err = th.Cleanup(cmd.Context())
|
||||
err = th.Cleanup(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("cleanup tests: %w", err)
|
||||
}
|
||||
|
||||
// Upload traces.
|
||||
if shouldTrace {
|
||||
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "\nUploading traces...")
|
||||
closeTracingOnce.Do(func() {
|
||||
ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
|
||||
defer cancel()
|
||||
err := closeTracing(ctx)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "\nError uploading traces: %+v\n", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if res.TotalFail > 0 {
|
||||
return xerrors.New("load test failed, see above for more details")
|
||||
}
|
||||
|
@ -173,6 +234,12 @@ func loadtest() *cobra.Command {
|
|||
|
||||
cliflag.StringVarP(cmd.Flags(), &configPath, "config", "", "CODER_LOADTEST_CONFIG_PATH", "", "Path to the load test configuration file, or - to read from stdin.")
|
||||
cliflag.StringArrayVarP(cmd.Flags(), &outputSpecs, "output", "", "CODER_LOADTEST_OUTPUTS", []string{"text"}, "Output formats, see usage for more information.")
|
||||
|
||||
cliflag.BoolVarP(cmd.Flags(), &traceEnable, "trace", "", "CODER_LOADTEST_TRACE", false, "Whether application tracing data is collected. It exports to a backend configured by environment variables. See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md")
|
||||
cliflag.BoolVarP(cmd.Flags(), &traceCoder, "trace-coder", "", "CODER_LOADTEST_TRACE_CODER", false, "Whether opentelemetry traces are sent to Coder. We recommend keeping this disabled unless we advise you to enable it.")
|
||||
cliflag.StringVarP(cmd.Flags(), &traceHoneycombAPIKey, "trace-honeycomb-api-key", "", "CODER_LOADTEST_TRACE_HONEYCOMB_API_KEY", "", "Enables trace exporting to Honeycomb.io using the provided API key.")
|
||||
cliflag.BoolVarP(cmd.Flags(), &tracePropagate, "trace-propagate", "", "CODER_LOADTEST_TRACE_PROPAGATE", false, "Enables trace propagation to the Coder backend, which will be used to correlate server-side spans with client-side spans. Only enable this if the server is configured with the exact same tracing configuration as the client.")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
|
@ -271,3 +338,53 @@ func parseLoadTestOutputs(outputs []string) ([]loadTestOutput, error) {
|
|||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type runnableTraceWrapper struct {
|
||||
tracer trace.Tracer
|
||||
spanName string
|
||||
runner harness.Runnable
|
||||
|
||||
span trace.Span
|
||||
}
|
||||
|
||||
var _ harness.Runnable = &runnableTraceWrapper{}
|
||||
var _ harness.Cleanable = &runnableTraceWrapper{}
|
||||
|
||||
func (r *runnableTraceWrapper) Run(ctx context.Context, id string, logs io.Writer) error {
|
||||
ctx, span := r.tracer.Start(ctx, r.spanName, trace.WithNewRoot())
|
||||
defer span.End()
|
||||
r.span = span
|
||||
|
||||
traceID := "unknown trace ID"
|
||||
spanID := "unknown span ID"
|
||||
if span.SpanContext().HasTraceID() {
|
||||
traceID = span.SpanContext().TraceID().String()
|
||||
}
|
||||
if span.SpanContext().HasSpanID() {
|
||||
spanID = span.SpanContext().SpanID().String()
|
||||
}
|
||||
_, _ = fmt.Fprintf(logs, "Trace ID: %s\n", traceID)
|
||||
_, _ = fmt.Fprintf(logs, "Span ID: %s\n\n", spanID)
|
||||
|
||||
// Make a separate span for the run itself so the sub-spans are grouped
|
||||
// neatly. The cleanup span is also a child of the above span so this is
|
||||
// important for readability.
|
||||
ctx2, span2 := r.tracer.Start(ctx, r.spanName+" run")
|
||||
defer span2.End()
|
||||
return r.runner.Run(ctx2, id, logs)
|
||||
}
|
||||
|
||||
func (r *runnableTraceWrapper) Cleanup(ctx context.Context, id string) error {
|
||||
c, ok := r.runner.(harness.Cleanable)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.span != nil {
|
||||
ctx = trace.ContextWithSpanContext(ctx, r.span.SpanContext())
|
||||
}
|
||||
ctx, span := r.tracer.Start(ctx, r.spanName+" cleanup")
|
||||
defer span.End()
|
||||
|
||||
return c.Cleanup(ctx, id)
|
||||
}
|
||||
|
|
|
@ -277,6 +277,8 @@ func TestLoadTest(t *testing.T) {
|
|||
require.NoError(t, err, msg)
|
||||
}
|
||||
|
||||
t.Logf("output %d:\n\n%s", i, string(b))
|
||||
|
||||
switch output.format {
|
||||
case "text":
|
||||
require.Contains(t, string(b), "Test results:", msg)
|
||||
|
|
|
@ -128,8 +128,9 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co
|
|||
|
||||
if cfg.Trace.Enable.Value || shouldCoderTrace {
|
||||
sdkTracerProvider, closeTracing, err := tracing.TracerProvider(ctx, "coderd", tracing.TracerOpts{
|
||||
Default: cfg.Trace.Enable.Value,
|
||||
Coder: shouldCoderTrace,
|
||||
Default: cfg.Trace.Enable.Value,
|
||||
Coder: shouldCoderTrace,
|
||||
Honeycomb: cfg.Trace.HoneycombAPIKey.Value,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "start telemetry exporter", slog.Error(err))
|
||||
|
|
|
@ -95,7 +95,7 @@ func speedtest() *cobra.Command {
|
|||
dir = tsspeedtest.Upload
|
||||
}
|
||||
cmd.Printf("Starting a %ds %s test...\n", int(duration.Seconds()), dir)
|
||||
results, err := conn.Speedtest(dir, duration)
|
||||
results, err := conn.Speedtest(ctx, dir, duration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -100,7 +100,7 @@ func ssh() *cobra.Command {
|
|||
defer stopPolling()
|
||||
|
||||
if stdio {
|
||||
rawSSH, err := conn.SSH()
|
||||
rawSSH, err := conn.SSH(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -113,7 +113,7 @@ func ssh() *cobra.Command {
|
|||
return nil
|
||||
}
|
||||
|
||||
sshClient, err := conn.SSHClient()
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -88,7 +88,7 @@ func TestWorkspaceActivityBump(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
sshConn, err := conn.SSHClient()
|
||||
sshConn, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
_ = sshConn.Close()
|
||||
|
||||
|
|
|
@ -633,7 +633,7 @@ func TestTemplateMetrics(t *testing.T) {
|
|||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
sshConn, err := conn.SSHClient()
|
||||
sshConn, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
_ = sshConn.Close()
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
|
||||
"github.com/go-logr/logr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
|
||||
|
@ -82,11 +83,23 @@ func TracerProvider(ctx context.Context, service string, opts TracerOpts) (*sdkt
|
|||
otel.SetLogger(logr.Discard())
|
||||
|
||||
return tracerProvider, func(ctx context.Context) error {
|
||||
for _, close := range closers {
|
||||
_ = close(ctx)
|
||||
var merr error
|
||||
err := tracerProvider.ForceFlush(ctx)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, xerrors.Errorf("tracerProvider.ForceFlush(): %w", err))
|
||||
}
|
||||
_ = tracerProvider.Shutdown(ctx)
|
||||
return nil
|
||||
for i, closer := range closers {
|
||||
err = closer(ctx)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, xerrors.Errorf("closer() %d: %w", i, err))
|
||||
}
|
||||
}
|
||||
err = tracerProvider.Shutdown(ctx)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, xerrors.Errorf("tracerProvider.Shutdown(): %w", err))
|
||||
}
|
||||
|
||||
return merr
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.11.0"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
@ -23,11 +25,27 @@ func Middleware(tracerProvider trace.TracerProvider) func(http.Handler) http.Han
|
|||
return
|
||||
}
|
||||
|
||||
// Extract the trace context from the request headers.
|
||||
tmp := otel.GetTextMapPropagator()
|
||||
hc := propagation.HeaderCarrier(r.Header)
|
||||
ctx := tmp.Extract(r.Context(), hc)
|
||||
|
||||
// start span with default span name. Span name will be updated to "method route" format once request finishes.
|
||||
ctx, span := tracer.Start(r.Context(), fmt.Sprintf("%s %s", r.Method, r.RequestURI))
|
||||
ctx, span := tracer.Start(ctx, fmt.Sprintf("%s %s", r.Method, r.RequestURI))
|
||||
defer span.End()
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
if span.SpanContext().HasTraceID() && span.SpanContext().HasSpanID() {
|
||||
// Technically these values are included in the Traceparent
|
||||
// header, but they are easier to read for humans this way.
|
||||
rw.Header().Set("X-Trace-ID", span.SpanContext().TraceID().String())
|
||||
rw.Header().Set("X-Span-ID", span.SpanContext().SpanID().String())
|
||||
|
||||
// Inject the trace context into the response headers.
|
||||
hc := propagation.HeaderCarrier(rw.Header())
|
||||
tmp.Inject(ctx, hc)
|
||||
}
|
||||
|
||||
sw, ok := rw.(*StatusWriter)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("ResponseWriter not a *tracing.StatusWriter; got %T", rw))
|
||||
|
@ -62,6 +80,37 @@ func EndHTTPSpan(r *http.Request, status int, span trace.Span) {
|
|||
span.End()
|
||||
}
|
||||
|
||||
func StartSpan(ctx context.Context, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
|
||||
return trace.SpanFromContext(ctx).TracerProvider().Tracer(TracerName).Start(ctx, FuncNameSkip(1), opts...)
|
||||
type tracerNameKey struct{}
|
||||
|
||||
// SetTracerName sets the tracer name that will be used by all spans created
|
||||
// from the context.
|
||||
func SetTracerName(ctx context.Context, tracerName string) context.Context {
|
||||
return context.WithValue(ctx, tracerNameKey{}, tracerName)
|
||||
}
|
||||
|
||||
// GetTracerName returns the tracer name from the context, or TracerName if none
|
||||
// is set.
|
||||
func GetTracerName(ctx context.Context) string {
|
||||
if tracerName, ok := ctx.Value(tracerNameKey{}).(string); ok {
|
||||
return tracerName
|
||||
}
|
||||
|
||||
return TracerName
|
||||
}
|
||||
|
||||
// StartSpan calls StartSpanWithName with the name set to the caller's function
|
||||
// name.
|
||||
func StartSpan(ctx context.Context, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
|
||||
return StartSpanWithName(ctx, FuncNameSkip(1), opts...)
|
||||
}
|
||||
|
||||
// StartSpanWithName starts a new span with the given name from the context. If
|
||||
// a tracer name was set on the context (or one of its parents), it will be used
|
||||
// as the tracer name instead of the default TracerName.
|
||||
func StartSpanWithName(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
|
||||
tracerName := GetTracerName(ctx)
|
||||
return trace.SpanFromContext(ctx).
|
||||
TracerProvider().
|
||||
Tracer(tracerName).
|
||||
Start(ctx, name, opts...)
|
||||
}
|
||||
|
|
|
@ -247,7 +247,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
defer release()
|
||||
ptNetConn, err := agentConn.ReconnectingPTY(reconnect.String(), uint16(height), uint16(width), r.URL.Query().Get("command"))
|
||||
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect.String(), uint16(height), uint16(width), r.URL.Query().Get("command"))
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
|
||||
return
|
||||
|
|
|
@ -260,7 +260,7 @@ func TestWorkspaceAgentTailnet(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
sshClient, err := conn.SSHClient()
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"tailscale.com/net/speedtest"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
"github.com/coder/coder/tailnet"
|
||||
)
|
||||
|
||||
|
@ -133,6 +134,9 @@ type AgentConn struct {
|
|||
}
|
||||
|
||||
func (c *AgentConn) Ping(ctx context.Context) (time.Duration, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
durCh := make(chan time.Duration, 1)
|
||||
go c.Conn.Ping(TailnetIP, tailcfg.PingDisco, func(pr *ipnstate.PingResult) {
|
||||
|
@ -171,8 +175,11 @@ type ReconnectingPTYInit struct {
|
|||
Command string
|
||||
}
|
||||
|
||||
func (c *AgentConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) {
|
||||
conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(TailnetIP, uint16(TailnetReconnectingPTYPort)))
|
||||
func (c *AgentConn) ReconnectingPTY(ctx context.Context, id string, height, width uint16, command string) (net.Conn, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(TailnetIP, uint16(TailnetReconnectingPTYPort)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -197,14 +204,18 @@ func (c *AgentConn) ReconnectingPTY(id string, height, width uint16, command str
|
|||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *AgentConn) SSH() (net.Conn, error) {
|
||||
return c.DialContextTCP(context.Background(), netip.AddrPortFrom(TailnetIP, uint16(TailnetSSHPort)))
|
||||
func (c *AgentConn) SSH(ctx context.Context) (net.Conn, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
return c.DialContextTCP(ctx, netip.AddrPortFrom(TailnetIP, uint16(TailnetSSHPort)))
|
||||
}
|
||||
|
||||
// SSHClient calls SSH to create a client that uses a weak cipher
|
||||
// for high throughput.
|
||||
func (c *AgentConn) SSHClient() (*ssh.Client, error) {
|
||||
netConn, err := c.SSH()
|
||||
func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
netConn, err := c.SSH(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("ssh: %w", err)
|
||||
}
|
||||
|
@ -220,8 +231,10 @@ func (c *AgentConn) SSHClient() (*ssh.Client, error) {
|
|||
return ssh.NewClient(sshConn, channels, requests), nil
|
||||
}
|
||||
|
||||
func (c *AgentConn) Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
|
||||
speedConn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(TailnetIP, uint16(TailnetSpeedtestPort)))
|
||||
func (c *AgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(TailnetIP, uint16(TailnetSpeedtestPort)))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("dial speedtest: %w", err)
|
||||
}
|
||||
|
@ -233,6 +246,8 @@ func (c *AgentConn) Speedtest(direction speedtest.Direction, duration time.Durat
|
|||
}
|
||||
|
||||
func (c *AgentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
if network == "unix" {
|
||||
return nil, xerrors.New("network must be tcp or udp")
|
||||
}
|
||||
|
@ -277,6 +292,8 @@ func (c *AgentConn) statisticsClient() *http.Client {
|
|||
}
|
||||
|
||||
func (c *AgentConn) doStatisticsRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
host := net.JoinHostPort(TailnetIP.String(), strconv.Itoa(TailnetStatisticsPort))
|
||||
url := fmt.Sprintf("http://%s%s", host, path)
|
||||
|
||||
|
@ -309,6 +326,8 @@ type ListeningPort struct {
|
|||
}
|
||||
|
||||
func (c *AgentConn) ListeningPorts(ctx context.Context) (ListeningPortsResponse, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
res, err := c.doStatisticsRequest(ctx, http.MethodGet, "/api/v0/listening-ports", nil)
|
||||
if err != nil {
|
||||
return ListeningPortsResponse{}, xerrors.Errorf("do request: %w", err)
|
||||
|
|
|
@ -12,7 +12,15 @@ import (
|
|||
"net/url"
|
||||
"strings"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.11.0"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
// These cookies are Coder-specific. If a new one is added or changed, the name
|
||||
|
@ -30,6 +38,13 @@ const (
|
|||
BypassRatelimitHeader = "X-Coder-Bypass-Ratelimit"
|
||||
)
|
||||
|
||||
var loggableMimeTypes = map[string]struct{}{
|
||||
"application/json": {},
|
||||
"text/plain": {},
|
||||
// lots of webserver error pages are HTML
|
||||
"text/html": {},
|
||||
}
|
||||
|
||||
// New creates a Coder client for the provided URL.
|
||||
func New(serverURL *url.URL) *Client {
|
||||
return &Client{
|
||||
|
@ -45,9 +60,35 @@ type Client struct {
|
|||
SessionToken string
|
||||
URL *url.URL
|
||||
|
||||
// Logger can be provided to log requests. Request method, URL and response
|
||||
// status code will be logged by default.
|
||||
Logger slog.Logger
|
||||
// LogBodies determines whether the request and response bodies are logged
|
||||
// to the provided Logger. This is useful for debugging or testing.
|
||||
LogBodies bool
|
||||
|
||||
// BypassRatelimits is an optional flag that can be set by the site owner to
|
||||
// disable ratelimit checks for the client.
|
||||
BypassRatelimits bool
|
||||
|
||||
// PropagateTracing is an optional flag that can be set to propagate tracing
|
||||
// spans to the Coder API. This is useful for seeing the entire request
|
||||
// from end-to-end.
|
||||
PropagateTracing bool
|
||||
}
|
||||
|
||||
func (c *Client) Clone() *Client {
|
||||
hc := *c.HTTPClient
|
||||
u := *c.URL
|
||||
return &Client{
|
||||
HTTPClient: &hc,
|
||||
SessionToken: c.SessionToken,
|
||||
URL: &u,
|
||||
Logger: c.Logger,
|
||||
LogBodies: c.LogBodies,
|
||||
BypassRatelimits: c.BypassRatelimits,
|
||||
PropagateTracing: c.PropagateTracing,
|
||||
}
|
||||
}
|
||||
|
||||
type RequestOption func(*http.Request)
|
||||
|
@ -63,30 +104,46 @@ func WithQueryParam(key, value string) RequestOption {
|
|||
}
|
||||
}
|
||||
|
||||
// Request performs an HTTP request with the body provided.
|
||||
// The caller is responsible for closing the response body.
|
||||
// Request performs a HTTP request with the body provided. The caller is
|
||||
// responsible for closing the response body.
|
||||
func (c *Client) Request(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) {
|
||||
ctx, span := tracing.StartSpanWithName(ctx, tracing.FuncNameSkip(1))
|
||||
defer span.End()
|
||||
|
||||
serverURL, err := c.URL.Parse(path)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
var r io.Reader
|
||||
if body != nil {
|
||||
if data, ok := body.([]byte); ok {
|
||||
buf = *bytes.NewBuffer(data)
|
||||
r = bytes.NewReader(data)
|
||||
} else {
|
||||
// Assume JSON if not bytes.
|
||||
enc := json.NewEncoder(&buf)
|
||||
buf := bytes.NewBuffer(nil)
|
||||
enc := json.NewEncoder(buf)
|
||||
enc.SetEscapeHTML(false)
|
||||
err = enc.Encode(body)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("encode body: %w", err)
|
||||
}
|
||||
|
||||
r = buf
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, serverURL.String(), &buf)
|
||||
// Copy the request body so we can log it.
|
||||
var reqBody []byte
|
||||
if r != nil && c.LogBodies {
|
||||
reqBody, err = io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("read request body: %w", err)
|
||||
}
|
||||
r = bytes.NewReader(reqBody)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, serverURL.String(), r)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create request: %w", err)
|
||||
}
|
||||
|
@ -95,17 +152,61 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac
|
|||
req.Header.Set(BypassRatelimitHeader, "true")
|
||||
}
|
||||
|
||||
if body != nil {
|
||||
if r != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(req)
|
||||
}
|
||||
|
||||
span.SetAttributes(semconv.NetAttributesFromHTTPRequest("tcp", req)...)
|
||||
span.SetAttributes(semconv.HTTPClientAttributesFromHTTPRequest(req)...)
|
||||
|
||||
// Inject tracing headers if enabled.
|
||||
if c.PropagateTracing {
|
||||
tmp := otel.GetTextMapPropagator()
|
||||
hc := propagation.HeaderCarrier(req.Header)
|
||||
tmp.Inject(ctx, hc)
|
||||
}
|
||||
|
||||
ctx = slog.With(ctx,
|
||||
slog.F("method", req.Method),
|
||||
slog.F("url", req.URL.String()),
|
||||
)
|
||||
c.Logger.Debug(ctx, "sdk request", slog.F("body", string(reqBody)))
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("do: %w", err)
|
||||
}
|
||||
|
||||
span.SetAttributes(semconv.HTTPStatusCodeKey.Int(resp.StatusCode))
|
||||
span.SetStatus(semconv.SpanStatusFromHTTPStatusCodeAndSpanKind(resp.StatusCode, trace.SpanKindClient))
|
||||
|
||||
// Copy the response body so we can log it if it's a loggable mime type.
|
||||
var respBody []byte
|
||||
if resp.Body != nil && c.LogBodies {
|
||||
mimeType := parseMimeType(resp.Header.Get("Content-Type"))
|
||||
if _, ok := loggableMimeTypes[mimeType]; ok {
|
||||
respBody, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("copy response body for logs: %w", err)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("close response body: %w", err)
|
||||
}
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
c.Logger.Debug(ctx, "sdk response",
|
||||
slog.F("status", resp.StatusCode),
|
||||
slog.F("body", string(respBody)),
|
||||
slog.F("trace_id", resp.Header.Get("X-Trace-Id")),
|
||||
slog.F("span_id", resp.Header.Get("X-Span-Id")),
|
||||
)
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
|
@ -138,10 +239,7 @@ func readBodyAsError(res *http.Response) error {
|
|||
return xerrors.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
mimeType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
mimeType = strings.TrimSpace(strings.Split(contentType, ";")[0])
|
||||
}
|
||||
mimeType := parseMimeType(contentType)
|
||||
if mimeType != "application/json" {
|
||||
if len(resp) > 1024 {
|
||||
resp = append(resp[:1024], []byte("...")...)
|
||||
|
@ -238,3 +336,12 @@ type closeFunc func() error
|
|||
func (c closeFunc) Close() error {
|
||||
return c()
|
||||
}
|
||||
|
||||
func parseMimeType(contentType string) string {
|
||||
mimeType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
mimeType = strings.TrimSpace(strings.Split(contentType, ";")[0])
|
||||
}
|
||||
|
||||
return mimeType
|
||||
}
|
||||
|
|
|
@ -2,23 +2,114 @@ package codersdk
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-logr/logr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
const (
|
||||
jsonCT = "application/json"
|
||||
)
|
||||
const jsonCT = "application/json"
|
||||
|
||||
func Test_Client(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const method = http.MethodPost
|
||||
const path = "/ok"
|
||||
const token = "token"
|
||||
const reqBody = `{"msg": "request body"}`
|
||||
const resBody = `{"status": "ok"}`
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, method, r.Method)
|
||||
assert.Equal(t, path, r.URL.Path)
|
||||
assert.Equal(t, token, r.Header.Get(SessionCustomHeader))
|
||||
assert.Equal(t, "true", r.Header.Get(BypassRatelimitHeader))
|
||||
assert.NotEmpty(t, r.Header.Get("Traceparent"))
|
||||
for k, v := range r.Header {
|
||||
t.Logf("header %q: %q", k, strings.Join(v, ", "))
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", jsonCT)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, resBody)
|
||||
}))
|
||||
|
||||
u, err := url.Parse(s.URL)
|
||||
require.NoError(t, err)
|
||||
client := New(u)
|
||||
client.SessionToken = token
|
||||
client.BypassRatelimits = true
|
||||
|
||||
logBuf := bytes.NewBuffer(nil)
|
||||
client.Logger = slog.Make(sloghuman.Sink(logBuf)).Leveled(slog.LevelDebug)
|
||||
client.LogBodies = true
|
||||
|
||||
// Setup tracing.
|
||||
res := resource.NewWithAttributes(
|
||||
semconv.SchemaURL,
|
||||
semconv.ServiceNameKey.String("codersdk_test"),
|
||||
)
|
||||
tracerOpts := []sdktrace.TracerProviderOption{
|
||||
sdktrace.WithResource(res),
|
||||
}
|
||||
tracerProvider := sdktrace.NewTracerProvider(tracerOpts...)
|
||||
otel.SetTracerProvider(tracerProvider)
|
||||
otel.SetErrorHandler(otel.ErrorHandlerFunc(func(err error) {}))
|
||||
otel.SetTextMapPropagator(
|
||||
propagation.NewCompositeTextMapPropagator(
|
||||
propagation.TraceContext{},
|
||||
propagation.Baggage{},
|
||||
),
|
||||
)
|
||||
otel.SetLogger(logr.Discard())
|
||||
client.PropagateTracing = true
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
ctx, span := tracerProvider.Tracer("codersdk_test").Start(ctx, "codersdk client test 1")
|
||||
defer span.End()
|
||||
|
||||
resp, err := client.Request(ctx, method, path, []byte(reqBody))
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, jsonCT, resp.Header.Get("Content-Type"))
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, resBody, string(body))
|
||||
|
||||
logStr := logBuf.String()
|
||||
require.Contains(t, logStr, "sdk request")
|
||||
require.Contains(t, logStr, method)
|
||||
require.Contains(t, logStr, path)
|
||||
require.Contains(t, logStr, strings.ReplaceAll(reqBody, `"`, `\"`))
|
||||
require.Contains(t, logStr, "sdk response")
|
||||
require.Contains(t, logStr, "200")
|
||||
require.Contains(t, logStr, strings.ReplaceAll(resBody, `"`, `\"`))
|
||||
}
|
||||
|
||||
func Test_readBodyAsError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
|
@ -2,11 +2,14 @@ package codersdk
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
type ServerSentEvent struct {
|
||||
|
@ -22,7 +25,10 @@ const (
|
|||
ServerSentEventTypeError ServerSentEventType = "error"
|
||||
)
|
||||
|
||||
func ServerSentEventReader(rc io.ReadCloser) func() (*ServerSentEvent, error) {
|
||||
func ServerSentEventReader(ctx context.Context, rc io.ReadCloser) func() (*ServerSentEvent, error) {
|
||||
_, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
reader := bufio.NewReader(rc)
|
||||
nextLineValue := func(prefix string) ([]byte, error) {
|
||||
var (
|
||||
|
|
|
@ -10,6 +10,8 @@ import (
|
|||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
// Workspace is a deployment of a template. It references a specific
|
||||
|
@ -137,6 +139,8 @@ func (c *Client) CreateWorkspaceBuild(ctx context.Context, workspace uuid.UUID,
|
|||
}
|
||||
|
||||
func (c *Client) WatchWorkspace(ctx context.Context, id uuid.UUID) (<-chan Workspace, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
//nolint:bodyclose
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaces/%s/watch", id), nil)
|
||||
if err != nil {
|
||||
|
@ -145,7 +149,7 @@ func (c *Client) WatchWorkspace(ctx context.Context, id uuid.UUID) (<-chan Works
|
|||
if res.StatusCode != http.StatusOK {
|
||||
return nil, readBodyAsError(res)
|
||||
}
|
||||
nextEvent := ServerSentEventReader(res.Body)
|
||||
nextEvent := ServerSentEventReader(ctx, res.Body)
|
||||
|
||||
wc := make(chan Workspace, 256)
|
||||
go func() {
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"net/netip"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
@ -17,8 +16,10 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/loadtest/harness"
|
||||
"github.com/coder/coder/loadtest/loadtestutil"
|
||||
)
|
||||
|
||||
const defaultRequestTimeout = 5 * time.Second
|
||||
|
@ -45,11 +46,13 @@ func NewRunner(client *codersdk.Client, cfg Config) *Runner {
|
|||
|
||||
// Run implements Runnable.
|
||||
func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
|
||||
logs = syncWriter{
|
||||
mut: &sync.Mutex{},
|
||||
w: logs,
|
||||
}
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
logs = loadtestutil.NewSyncWriter(logs)
|
||||
logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
|
||||
r.client.Logger = logger
|
||||
r.client.LogBodies = true
|
||||
|
||||
_, _ = fmt.Fprintln(logs, "Opening connection to workspace agent")
|
||||
switch r.cfg.ConnectionMode {
|
||||
|
@ -69,9 +72,72 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
|
|||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Wait for the disco connection to be established.
|
||||
err = waitForDisco(ctx, logs, conn)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("wait for discovery connection: %w", err)
|
||||
}
|
||||
|
||||
// Wait for a direct connection if requested.
|
||||
if r.cfg.ConnectionMode == ConnectionModeDirect {
|
||||
err = waitForDirectConnection(ctx, logs, conn)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("wait for direct connection: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure DERP for completeness.
|
||||
if r.cfg.ConnectionMode == ConnectionModeDerp {
|
||||
status := conn.Status()
|
||||
if len(status.Peers()) != 1 {
|
||||
return xerrors.Errorf("check connection mode: expected 1 peer, got %d", len(status.Peers()))
|
||||
}
|
||||
peer := status.Peer[status.Peers()[0]]
|
||||
if peer.Relay == "" || peer.CurAddr != "" {
|
||||
return xerrors.Errorf("check connection mode: peer is connected directly, not via DERP")
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprint(logs, "\nConnection established.\n\n")
|
||||
|
||||
// HACK: even though the ping passed above, we still need to open a
|
||||
// connection to the agent to ensure it's ready to accept connections. Not
|
||||
// sure why this is the case but it seems to be necessary.
|
||||
err = verifyConnection(ctx, logs, conn)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("verify connection: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprint(logs, "\nConnection verified.\n\n")
|
||||
|
||||
// Make initial connections sequentially to ensure the services are
|
||||
// reachable before we start spawning a bunch of goroutines and tickers.
|
||||
err = performInitialConnections(ctx, logs, conn, r.cfg.Connections)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("perform initial connections: %w", err)
|
||||
}
|
||||
|
||||
if r.cfg.HoldDuration > 0 {
|
||||
err = holdConnection(ctx, logs, conn, time.Duration(r.cfg.HoldDuration), r.cfg.Connections)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("hold connection: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("close connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func waitForDisco(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn) error {
|
||||
const pingAttempts = 10
|
||||
const pingDelay = 1 * time.Second
|
||||
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
for i := 0; i < pingAttempts; i++ {
|
||||
_, _ = fmt.Fprintf(logs, "\tDisco ping attempt %d/%d...\n", i+1, pingAttempts)
|
||||
pingCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
|
||||
|
@ -93,80 +159,59 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Wait for a direct connection if requested.
|
||||
if r.cfg.ConnectionMode == ConnectionModeDirect {
|
||||
const directConnectionAttempts = 30
|
||||
const directConnectionDelay = 1 * time.Second
|
||||
for i := 0; i < directConnectionAttempts; i++ {
|
||||
_, _ = fmt.Fprintf(logs, "\tDirect connection check %d/%d...\n", i+1, directConnectionAttempts)
|
||||
status := conn.Status()
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if len(status.Peers()) != 1 {
|
||||
_, _ = fmt.Fprintf(logs, "\t\tExpected 1 peer, found %d", len(status.Peers()))
|
||||
err = xerrors.Errorf("expected 1 peer, got %d", len(status.Peers()))
|
||||
} else {
|
||||
peer := status.Peer[status.Peers()[0]]
|
||||
_, _ = fmt.Fprintf(logs, "\t\tCurAddr: %s\n", peer.CurAddr)
|
||||
_, _ = fmt.Fprintf(logs, "\t\tRelay: %s\n", peer.Relay)
|
||||
if peer.Relay != "" && peer.CurAddr == "" {
|
||||
err = xerrors.Errorf("peer is connected via DERP, not direct")
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if i == directConnectionAttempts-1 {
|
||||
return xerrors.Errorf("wait for direct connection to agent: %w", err)
|
||||
}
|
||||
func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn) error {
|
||||
const directConnectionAttempts = 30
|
||||
const directConnectionDelay = 1 * time.Second
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return xerrors.Errorf("wait for direct connection to agent: %w", ctx.Err())
|
||||
// We use time.After here since it's a very short duration so
|
||||
// leaking a timer is fine.
|
||||
case <-time.After(directConnectionDelay):
|
||||
}
|
||||
}
|
||||
}
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
// Ensure DERP for completeness.
|
||||
if r.cfg.ConnectionMode == ConnectionModeDerp {
|
||||
for i := 0; i < directConnectionAttempts; i++ {
|
||||
_, _ = fmt.Fprintf(logs, "\tDirect connection check %d/%d...\n", i+1, directConnectionAttempts)
|
||||
status := conn.Status()
|
||||
|
||||
var err error
|
||||
if len(status.Peers()) != 1 {
|
||||
return xerrors.Errorf("check connection mode: expected 1 peer, got %d", len(status.Peers()))
|
||||
_, _ = fmt.Fprintf(logs, "\t\tExpected 1 peer, found %d", len(status.Peers()))
|
||||
err = xerrors.Errorf("expected 1 peer, got %d", len(status.Peers()))
|
||||
} else {
|
||||
peer := status.Peer[status.Peers()[0]]
|
||||
_, _ = fmt.Fprintf(logs, "\t\tCurAddr: %s\n", peer.CurAddr)
|
||||
_, _ = fmt.Fprintf(logs, "\t\tRelay: %s\n", peer.Relay)
|
||||
if peer.Relay != "" && peer.CurAddr == "" {
|
||||
err = xerrors.Errorf("peer is connected via DERP, not direct")
|
||||
}
|
||||
}
|
||||
peer := status.Peer[status.Peers()[0]]
|
||||
if peer.Relay == "" || peer.CurAddr != "" {
|
||||
return xerrors.Errorf("check connection mode: peer is connected directly, not via DERP")
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if i == directConnectionAttempts-1 {
|
||||
return xerrors.Errorf("wait for direct connection to agent: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return xerrors.Errorf("wait for direct connection to agent: %w", ctx.Err())
|
||||
// We use time.After here since it's a very short duration so
|
||||
// leaking a timer is fine.
|
||||
case <-time.After(directConnectionDelay):
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprint(logs, "\nConnection established.\n\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
_, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("split host port %q: %w", addr, err)
|
||||
}
|
||||
|
||||
portUint, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse port %q: %w", port, err)
|
||||
}
|
||||
return conn.DialContextTCP(ctx, netip.AddrPortFrom(codersdk.TailnetIP, uint16(portUint)))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// HACK: even though the ping passed above, we still need to open a
|
||||
// connection to the agent to ensure it's ready to accept connections. Not
|
||||
// sure why this is the case but it seems to be necessary.
|
||||
func verifyConnection(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn) error {
|
||||
const verifyConnectionAttempts = 30
|
||||
const verifyConnectionDelay = 1 * time.Second
|
||||
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
client := agentHTTPClient(conn)
|
||||
for i := 0; i < verifyConnectionAttempts; i++ {
|
||||
_, _ = fmt.Fprintf(logs, "\tVerify connection attempt %d/%d...\n", i+1, verifyConnectionAttempts)
|
||||
verifyCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
|
||||
|
@ -198,14 +243,20 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
|
|||
}
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprint(logs, "\nConnection verified.\n\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make initial connections sequentially to ensure the services are
|
||||
// reachable before we start spawning a bunch of goroutines and tickers.
|
||||
if len(r.cfg.Connections) > 0 {
|
||||
_, _ = fmt.Fprintln(logs, "Performing initial service connections...")
|
||||
func performInitialConnections(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn, specs []Connection) error {
|
||||
if len(specs) == 0 {
|
||||
return nil
|
||||
}
|
||||
for i, connSpec := range r.cfg.Connections {
|
||||
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
_, _ = fmt.Fprintln(logs, "Performing initial service connections...")
|
||||
client := agentHTTPClient(conn)
|
||||
for i, connSpec := range specs {
|
||||
_, _ = fmt.Fprintf(logs, "\t%d. %s\n", i, connSpec.URL)
|
||||
|
||||
timeout := defaultRequestTimeout
|
||||
|
@ -230,95 +281,102 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
|
|||
_, _ = fmt.Fprintln(logs, "\t\tOK")
|
||||
}
|
||||
|
||||
if r.cfg.HoldDuration > 0 {
|
||||
eg, egCtx := errgroup.WithContext(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(r.cfg.Connections) > 0 {
|
||||
_, _ = fmt.Fprintln(logs, "\nStarting connection loops...")
|
||||
}
|
||||
for i, connSpec := range r.cfg.Connections {
|
||||
i, connSpec := i, connSpec
|
||||
if connSpec.Interval <= 0 {
|
||||
continue
|
||||
}
|
||||
func holdConnection(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn, holdDur time.Duration, specs []Connection) error {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
eg.Go(func() error {
|
||||
t := time.NewTicker(time.Duration(connSpec.Interval))
|
||||
defer t.Stop()
|
||||
|
||||
timeout := defaultRequestTimeout
|
||||
if connSpec.Timeout > 0 {
|
||||
timeout = time.Duration(connSpec.Timeout)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-egCtx.Done():
|
||||
return egCtx.Err()
|
||||
case <-t.C:
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, connSpec.URL, nil)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return xerrors.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
cancel()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(logs, "\tERR: %s (%d): %+v\n", connSpec.URL, i, err)
|
||||
return xerrors.Errorf("make connection to conn spec %d %q: %w", i, connSpec.URL, err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
_, _ = fmt.Fprintf(logs, "\tOK: %s (%d)\n", connSpec.URL, i)
|
||||
t.Reset(time.Duration(connSpec.Interval))
|
||||
}
|
||||
}
|
||||
})
|
||||
eg, egCtx := errgroup.WithContext(ctx)
|
||||
client := agentHTTPClient(conn)
|
||||
if len(specs) > 0 {
|
||||
_, _ = fmt.Fprintln(logs, "\nStarting connection loops...")
|
||||
}
|
||||
for i, connSpec := range specs {
|
||||
i, connSpec := i, connSpec
|
||||
if connSpec.Interval <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Wait for the hold duration to end. We use a fake error to signal that
|
||||
// the hold duration has ended.
|
||||
_, _ = fmt.Fprintf(logs, "\nWaiting for %s...\n", time.Duration(r.cfg.HoldDuration))
|
||||
eg.Go(func() error {
|
||||
t := time.NewTicker(time.Duration(r.cfg.HoldDuration))
|
||||
t := time.NewTicker(time.Duration(connSpec.Interval))
|
||||
defer t.Stop()
|
||||
|
||||
select {
|
||||
case <-egCtx.Done():
|
||||
return egCtx.Err()
|
||||
case <-t.C:
|
||||
// Returning an error here will cause the errgroup context to
|
||||
// be canceled, which is what we want. This fake error is
|
||||
// ignored below.
|
||||
return holdDurationEndedError{}
|
||||
timeout := defaultRequestTimeout
|
||||
if connSpec.Timeout > 0 {
|
||||
timeout = time.Duration(connSpec.Timeout)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-egCtx.Done():
|
||||
return egCtx.Err()
|
||||
case <-t.C:
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, connSpec.URL, nil)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return xerrors.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
cancel()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(logs, "\tERR: %s (%d): %+v\n", connSpec.URL, i, err)
|
||||
return xerrors.Errorf("make connection to conn spec %d %q: %w", i, connSpec.URL, err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
_, _ = fmt.Fprintf(logs, "\tOK: %s (%d)\n", connSpec.URL, i)
|
||||
t.Reset(time.Duration(connSpec.Interval))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
err = eg.Wait()
|
||||
if err != nil && !xerrors.Is(err, holdDurationEndedError{}) {
|
||||
return xerrors.Errorf("run connections loop: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("close connection: %w", err)
|
||||
// Wait for the hold duration to end. We use a fake error to signal that
|
||||
// the hold duration has ended.
|
||||
_, _ = fmt.Fprintf(logs, "\nWaiting for %s...\n", holdDur)
|
||||
eg.Go(func() error {
|
||||
t := time.NewTicker(holdDur)
|
||||
defer t.Stop()
|
||||
|
||||
select {
|
||||
case <-egCtx.Done():
|
||||
return egCtx.Err()
|
||||
case <-t.C:
|
||||
// Returning an error here will cause the errgroup context to
|
||||
// be canceled, which is what we want. This fake error is
|
||||
// ignored below.
|
||||
return holdDurationEndedError{}
|
||||
}
|
||||
})
|
||||
|
||||
err := eg.Wait()
|
||||
if err != nil && !xerrors.Is(err, holdDurationEndedError{}) {
|
||||
return xerrors.Errorf("run connections loop: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncWriter wraps an io.Writer in a sync.Mutex.
|
||||
type syncWriter struct {
|
||||
mut *sync.Mutex
|
||||
w io.Writer
|
||||
}
|
||||
func agentHTTPClient(conn *codersdk.AgentConn) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
_, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("split host port %q: %w", addr, err)
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (sw syncWriter) Write(p []byte) (n int, err error) {
|
||||
sw.mut.Lock()
|
||||
defer sw.mut.Unlock()
|
||||
return sw.w.Write(p)
|
||||
portUint, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse port %q: %w", port, err)
|
||||
}
|
||||
return conn.DialContextTCP(ctx, netip.AddrPortFrom(codersdk.TailnetIP, uint16(portUint)))
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
// ExecutionStrategy defines how a TestHarness should execute a set of runs. It
|
||||
|
@ -49,6 +51,9 @@ func NewTestHarness(strategy ExecutionStrategy) *TestHarness {
|
|||
//
|
||||
// Panics if called more than once.
|
||||
func (h *TestHarness) Run(ctx context.Context) (err error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
h.mut.Lock()
|
||||
if h.started {
|
||||
h.mut.Unlock()
|
||||
|
|
|
@ -95,6 +95,7 @@ type timeoutRunnerWrapper struct {
|
|||
}
|
||||
|
||||
var _ Runnable = timeoutRunnerWrapper{}
|
||||
var _ Cleanable = timeoutRunnerWrapper{}
|
||||
|
||||
func (t timeoutRunnerWrapper) Run(ctx context.Context, id string, logs io.Writer) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, t.timeout)
|
||||
|
@ -103,6 +104,15 @@ func (t timeoutRunnerWrapper) Run(ctx context.Context, id string, logs io.Writer
|
|||
return t.inner.Run(ctx, id, logs)
|
||||
}
|
||||
|
||||
func (t timeoutRunnerWrapper) Cleanup(ctx context.Context, id string) error {
|
||||
c, ok := t.inner.(Cleanable)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.Cleanup(ctx, id)
|
||||
}
|
||||
|
||||
// Execute implements ExecutionStrategy.
|
||||
func (t TimeoutExecutionStrategyWrapper) Execute(ctx context.Context, runs []*TestRun) error {
|
||||
for _, run := range runs {
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
package loadtestutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SyncWriter wraps an io.Writer in a sync.Mutex.
|
||||
type SyncWriter struct {
|
||||
mut *sync.Mutex
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func NewSyncWriter(w io.Writer) *SyncWriter {
|
||||
return &SyncWriter{
|
||||
mut: &sync.Mutex{},
|
||||
w: w,
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (sw *SyncWriter) Write(p []byte) (n int, err error) {
|
||||
sw.mut.Lock()
|
||||
defer sw.mut.Unlock()
|
||||
return sw.w.Write(p)
|
||||
}
|
|
@ -9,9 +9,14 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/coder/coder/loadtest/harness"
|
||||
"github.com/coder/coder/loadtest/loadtestutil"
|
||||
)
|
||||
|
||||
type Runner struct {
|
||||
|
@ -32,6 +37,14 @@ func NewRunner(client *codersdk.Client, cfg Config) *Runner {
|
|||
|
||||
// Run implements Runnable.
|
||||
func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
logs = loadtestutil.NewSyncWriter(logs)
|
||||
logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
|
||||
r.client.Logger = logger
|
||||
r.client.LogBodies = true
|
||||
|
||||
req := r.cfg.Request
|
||||
if req.Name == "" {
|
||||
randName, err := cryptorand.HexString(8)
|
||||
|
@ -66,6 +79,8 @@ func (r *Runner) Cleanup(ctx context.Context, _ string) error {
|
|||
if r.workspaceID == uuid.Nil {
|
||||
return nil
|
||||
}
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
build, err := r.client.CreateWorkspaceBuild(ctx, r.workspaceID, codersdk.CreateWorkspaceBuildRequest{
|
||||
Transition: codersdk.WorkspaceTransitionDelete,
|
||||
|
@ -85,6 +100,8 @@ func (r *Runner) Cleanup(ctx context.Context, _ string) error {
|
|||
}
|
||||
|
||||
func waitForBuild(ctx context.Context, w io.Writer, client *codersdk.Client, buildID uuid.UUID) error {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
_, _ = fmt.Fprint(w, "Build is currently queued...")
|
||||
|
||||
// Wait for build to start.
|
||||
|
@ -154,6 +171,8 @@ func waitForBuild(ctx context.Context, w io.Writer, client *codersdk.Client, bui
|
|||
}
|
||||
|
||||
func waitForAgents(ctx context.Context, w io.Writer, client *codersdk.Client, workspaceID uuid.UUID) error {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
_, _ = fmt.Fprint(w, "Waiting for agents to connect...\n\n")
|
||||
|
||||
for {
|
||||
|
|
Loading…
Reference in New Issue