mirror of https://github.com/coder/coder.git
feat(scaletest): replace bash with dd in ssh/rpty traffic and use pseudorandomness (#10821)
Fixes #10795 Refs #8556
This commit is contained in:
parent
433be7b16d
commit
99151183bc
|
@ -10,6 +10,7 @@ import (
|
|||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -173,11 +174,12 @@ func (s *scaletestStrategyFlags) attach(opts *clibase.OptionSet) {
|
|||
|
||||
func (s *scaletestStrategyFlags) toStrategy() harness.ExecutionStrategy {
|
||||
var strategy harness.ExecutionStrategy
|
||||
if s.concurrency == 1 {
|
||||
switch s.concurrency {
|
||||
case 1:
|
||||
strategy = harness.LinearExecutionStrategy{}
|
||||
} else if s.concurrency == 0 {
|
||||
case 0:
|
||||
strategy = harness.ConcurrentExecutionStrategy{}
|
||||
} else {
|
||||
default:
|
||||
strategy = harness.ParallelExecutionStrategy{
|
||||
Limit: int(s.concurrency),
|
||||
}
|
||||
|
@ -244,7 +246,9 @@ func (o *scaleTestOutput) write(res harness.Results, stdout io.Writer) error {
|
|||
err := s.Sync()
|
||||
// On Linux, EINVAL is returned when calling fsync on /dev/stdout. We
|
||||
// can safely ignore this error.
|
||||
if err != nil && !xerrors.Is(err, syscall.EINVAL) {
|
||||
// On macOS, ENOTTY is returned when calling sync on /dev/stdout. We
|
||||
// can safely ignore this error.
|
||||
if err != nil && !xerrors.Is(err, syscall.EINVAL) && !xerrors.Is(err, syscall.ENOTTY) {
|
||||
return xerrors.Errorf("flush output file: %w", err)
|
||||
}
|
||||
}
|
||||
|
@ -871,9 +875,13 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd {
|
|||
Middleware: clibase.Chain(
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *clibase.Invocation) error {
|
||||
Handler: func(inv *clibase.Invocation) (err error) {
|
||||
ctx := inv.Context()
|
||||
|
||||
notifyCtx, stop := signal.NotifyContext(ctx, InterruptSignals...) // Checked later.
|
||||
defer stop()
|
||||
ctx = notifyCtx
|
||||
|
||||
me, err := requireAdmin(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -965,6 +973,7 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd {
|
|||
ReadMetrics: metrics.ReadMetrics(ws.OwnerName, ws.Name, agentName),
|
||||
WriteMetrics: metrics.WriteMetrics(ws.OwnerName, ws.Name, agentName),
|
||||
SSH: ssh,
|
||||
Echo: ssh,
|
||||
}
|
||||
|
||||
if err := config.Validate(); err != nil {
|
||||
|
@ -990,6 +999,11 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd {
|
|||
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
|
||||
}
|
||||
|
||||
// If the command was interrupted, skip stats.
|
||||
if notifyCtx.Err() != nil {
|
||||
return notifyCtx.Err()
|
||||
}
|
||||
|
||||
res := th.Results()
|
||||
for _, o := range outputs {
|
||||
err = o.write(res, inv.Stdout)
|
||||
|
|
|
@ -25,6 +25,12 @@ type Config struct {
|
|||
WriteMetrics ConnMetrics `json:"-"`
|
||||
|
||||
SSH bool `json:"ssh"`
|
||||
|
||||
// Echo controls whether the agent should echo the data it receives.
|
||||
// If false, the agent will discard the data. Note that setting this
|
||||
// to true will double the amount of data read from the agent for
|
||||
// PTYs (e.g. reconnecting pty or SSH connections that request PTY).
|
||||
Echo bool `json:"echo"`
|
||||
}
|
||||
|
||||
func (c Config) Validate() error {
|
||||
|
|
|
@ -2,95 +2,245 @@ package workspacetraffic
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
func connectPTY(ctx context.Context, client *codersdk.Client, agentID, reconnect uuid.UUID) (*countReadWriteCloser, error) {
|
||||
const (
|
||||
// Set a timeout for graceful close of the connection.
|
||||
connCloseTimeout = 30 * time.Second
|
||||
// Set a timeout for waiting for the connection to close.
|
||||
waitCloseTimeout = connCloseTimeout + 5*time.Second
|
||||
|
||||
// In theory, we can send larger payloads to push bandwidth, but we need to
|
||||
// be careful not to send too much data at once or the server will close the
|
||||
// connection. We see this more readily as our JSON payloads approach 28KB.
|
||||
//
|
||||
// failed to write frame: WebSocket closed: received close frame: status = StatusMessageTooBig and reason = "read limited at 32769 bytes"
|
||||
//
|
||||
// Since we can't control fragmentation/buffer sizes, we keep it simple and
|
||||
// match the conservative payload size used by agent/reconnectingpty (1024).
|
||||
rptyJSONMaxDataSize = 1024
|
||||
)
|
||||
|
||||
func connectRPTY(ctx context.Context, client *codersdk.Client, agentID, reconnect uuid.UUID, cmd string) (*countReadWriteCloser, error) {
|
||||
width, height := 80, 25
|
||||
conn, err := client.WorkspaceAgentReconnectingPTY(ctx, codersdk.WorkspaceAgentReconnectingPTYOpts{
|
||||
AgentID: agentID,
|
||||
Reconnect: reconnect,
|
||||
Height: 25,
|
||||
Width: 80,
|
||||
Command: "sh",
|
||||
Width: uint16(width),
|
||||
Height: uint16(height),
|
||||
Command: cmd,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("connect pty: %w", err)
|
||||
}
|
||||
|
||||
// Wrap the conn in a countReadWriteCloser so we can monitor bytes sent/rcvd.
|
||||
crw := countReadWriteCloser{ctx: ctx, rwc: conn}
|
||||
crw := countReadWriteCloser{rwc: newPTYConn(conn)}
|
||||
return &crw, nil
|
||||
}
|
||||
|
||||
func connectSSH(ctx context.Context, client *codersdk.Client, agentID uuid.UUID) (*countReadWriteCloser, error) {
|
||||
type rptyConn struct {
|
||||
conn io.ReadWriteCloser
|
||||
wenc *json.Encoder
|
||||
|
||||
readOnce sync.Once
|
||||
readErr chan error
|
||||
|
||||
mu sync.Mutex // Protects following.
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newPTYConn(conn io.ReadWriteCloser) *rptyConn {
|
||||
rc := &rptyConn{
|
||||
conn: conn,
|
||||
wenc: json.NewEncoder(conn),
|
||||
readErr: make(chan error, 1),
|
||||
}
|
||||
return rc
|
||||
}
|
||||
|
||||
func (c *rptyConn) Read(p []byte) (int, error) {
|
||||
n, err := c.conn.Read(p)
|
||||
if err != nil {
|
||||
c.readOnce.Do(func() {
|
||||
c.readErr <- err
|
||||
close(c.readErr)
|
||||
})
|
||||
return n, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *rptyConn) Write(p []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Early exit in case we're closing, this is to let call write Ctrl+C
|
||||
// without a flood of other writes.
|
||||
if c.closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
return c.writeNoLock(p)
|
||||
}
|
||||
|
||||
func (c *rptyConn) writeNoLock(p []byte) (n int, err error) {
|
||||
// If we try to send more than the max payload size, the server will close the connection.
|
||||
for len(p) > 0 {
|
||||
pp := p
|
||||
if len(pp) > rptyJSONMaxDataSize {
|
||||
pp = p[:rptyJSONMaxDataSize]
|
||||
}
|
||||
p = p[len(pp):]
|
||||
req := codersdk.ReconnectingPTYRequest{Data: string(pp)}
|
||||
if err := c.wenc.Encode(req); err != nil {
|
||||
return n, xerrors.Errorf("encode pty request: %w", err)
|
||||
}
|
||||
n += len(pp)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *rptyConn) Close() (err error) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
c.closed = true
|
||||
c.mu.Unlock()
|
||||
|
||||
defer c.conn.Close()
|
||||
|
||||
// Send Ctrl+C to interrupt the command.
|
||||
_, err = c.writeNoLock([]byte("\u0003"))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write ctrl+c: %w", err)
|
||||
}
|
||||
select {
|
||||
case <-time.After(connCloseTimeout):
|
||||
return xerrors.Errorf("timeout waiting for read to finish")
|
||||
case err = <-c.readErr:
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:revive // Ignore requestPTY control flag.
|
||||
func connectSSH(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, cmd string, requestPTY bool) (rwc *countReadWriteCloser, err error) {
|
||||
var closers []func() error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
for _, c := range closers {
|
||||
if err2 := c(); err2 != nil {
|
||||
err = errors.Join(err, err2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
agentConn, err := client.DialWorkspaceAgent(ctx, agentID, &codersdk.DialWorkspaceAgentOptions{})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("dial workspace agent: %w", err)
|
||||
}
|
||||
agentConn.AwaitReachable(ctx)
|
||||
closers = append(closers, agentConn.Close)
|
||||
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get ssh client: %w", err)
|
||||
}
|
||||
closers = append(closers, sshClient.Close)
|
||||
|
||||
sshSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
_ = agentConn.Close()
|
||||
return nil, xerrors.Errorf("new ssh session: %w", err)
|
||||
}
|
||||
wrappedConn := &wrappedSSHConn{ctx: ctx}
|
||||
closers = append(closers, sshSession.Close)
|
||||
|
||||
wrappedConn := &wrappedSSHConn{}
|
||||
|
||||
// Do some plumbing to hook up the wrappedConn
|
||||
pr1, pw1 := io.Pipe()
|
||||
closers = append(closers, pr1.Close, pw1.Close)
|
||||
wrappedConn.stdout = pr1
|
||||
sshSession.Stdout = pw1
|
||||
|
||||
pr2, pw2 := io.Pipe()
|
||||
closers = append(closers, pr2.Close, pw2.Close)
|
||||
sshSession.Stdin = pr2
|
||||
wrappedConn.stdin = pw2
|
||||
err = sshSession.RequestPty("xterm", 25, 80, gossh.TerminalModes{})
|
||||
if err != nil {
|
||||
_ = pr1.Close()
|
||||
_ = pr2.Close()
|
||||
_ = pw1.Close()
|
||||
_ = pw2.Close()
|
||||
_ = sshSession.Close()
|
||||
_ = agentConn.Close()
|
||||
return nil, xerrors.Errorf("request pty: %w", err)
|
||||
|
||||
if requestPTY {
|
||||
err = sshSession.RequestPty("xterm", 25, 80, gossh.TerminalModes{})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("request pty: %w", err)
|
||||
}
|
||||
}
|
||||
err = sshSession.Shell()
|
||||
err = sshSession.Start(cmd)
|
||||
if err != nil {
|
||||
_ = sshSession.Close()
|
||||
_ = agentConn.Close()
|
||||
return nil, xerrors.Errorf("shell: %w", err)
|
||||
}
|
||||
waitErr := make(chan error, 1)
|
||||
go func() {
|
||||
waitErr <- sshSession.Wait()
|
||||
}()
|
||||
|
||||
closeFn := func() error {
|
||||
var merr error
|
||||
if err := sshSession.Close(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
// Start by closing stdin so we stop writing to the ssh session.
|
||||
merr := pw2.Close()
|
||||
if err := sshSession.Signal(gossh.SIGHUP); err != nil {
|
||||
merr = errors.Join(merr, err)
|
||||
}
|
||||
if err := agentConn.Close(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
select {
|
||||
case <-time.After(connCloseTimeout):
|
||||
merr = errors.Join(merr, xerrors.Errorf("timeout waiting for ssh session to close"))
|
||||
case err := <-waitErr:
|
||||
if err != nil {
|
||||
var exitErr *gossh.ExitError
|
||||
if xerrors.As(err, &exitErr) {
|
||||
// The exit status is 255 when the command is
|
||||
// interrupted by a signal. This is expected.
|
||||
if exitErr.ExitStatus() != 255 {
|
||||
merr = errors.Join(merr, xerrors.Errorf("ssh session exited with unexpected status: %d", int32(exitErr.ExitStatus())))
|
||||
}
|
||||
} else {
|
||||
merr = errors.Join(merr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, c := range closers {
|
||||
if err := c(); err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
merr = errors.Join(merr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return merr
|
||||
}
|
||||
wrappedConn.close = closeFn
|
||||
|
||||
crw := &countReadWriteCloser{ctx: ctx, rwc: wrappedConn}
|
||||
crw := &countReadWriteCloser{rwc: wrappedConn}
|
||||
|
||||
return crw, nil
|
||||
}
|
||||
|
||||
// wrappedSSHConn wraps an ssh.Session to implement io.ReadWriteCloser.
|
||||
type wrappedSSHConn struct {
|
||||
ctx context.Context
|
||||
stdout io.Reader
|
||||
stdin io.Writer
|
||||
stdin io.WriteCloser
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
close func() error
|
||||
|
@ -98,26 +248,15 @@ type wrappedSSHConn struct {
|
|||
|
||||
func (w *wrappedSSHConn) Close() error {
|
||||
w.closeOnce.Do(func() {
|
||||
_, _ = w.stdin.Write([]byte("exit\n"))
|
||||
w.closeErr = w.close()
|
||||
})
|
||||
return w.closeErr
|
||||
}
|
||||
|
||||
func (w *wrappedSSHConn) Read(p []byte) (n int, err error) {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return 0, xerrors.Errorf("read: %w", w.ctx.Err())
|
||||
default:
|
||||
return w.stdout.Read(p)
|
||||
}
|
||||
return w.stdout.Read(p)
|
||||
}
|
||||
|
||||
func (w *wrappedSSHConn) Write(p []byte) (n int, err error) {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return 0, xerrors.Errorf("write: %w", w.ctx.Err())
|
||||
default:
|
||||
return w.stdin.Write(p)
|
||||
}
|
||||
return w.stdin.Write(p)
|
||||
}
|
||||
|
|
|
@ -13,7 +13,6 @@ import (
|
|||
|
||||
// countReadWriteCloser wraps an io.ReadWriteCloser and counts the number of bytes read and written.
|
||||
type countReadWriteCloser struct {
|
||||
ctx context.Context
|
||||
rwc io.ReadWriteCloser
|
||||
readMetrics ConnMetrics
|
||||
writeMetrics ConnMetrics
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
package workspacetraffic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
@ -15,7 +18,6 @@ import (
|
|||
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/cryptorand"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
||||
)
|
||||
|
@ -38,7 +40,7 @@ func NewRunner(client *codersdk.Client, cfg Config) *Runner {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
|
||||
func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) (err error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
|
@ -63,10 +65,12 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
|
|||
width uint16 = 80
|
||||
tickInterval = r.cfg.TickInterval
|
||||
bytesPerTick = r.cfg.BytesPerTick
|
||||
echo = r.cfg.Echo
|
||||
)
|
||||
|
||||
logger = logger.With(slog.F("agent_id", agentID))
|
||||
|
||||
logger.Debug(ctx, "config",
|
||||
slog.F("agent_id", agentID),
|
||||
slog.F("reconnecting_pty_id", reconnect),
|
||||
slog.F("height", height),
|
||||
slog.F("width", width),
|
||||
|
@ -78,34 +82,56 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
|
|||
start := time.Now()
|
||||
deadlineCtx, cancel := context.WithDeadline(ctx, start.Add(r.cfg.Duration))
|
||||
defer cancel()
|
||||
logger.Debug(ctx, "connect to workspace agent", slog.F("agent_id", agentID))
|
||||
logger.Debug(ctx, "connect to workspace agent")
|
||||
|
||||
output := "/dev/stdout"
|
||||
if !echo {
|
||||
output = "/dev/null"
|
||||
}
|
||||
command := fmt.Sprintf("dd if=/dev/stdin of=%s bs=%d status=none", output, bytesPerTick)
|
||||
|
||||
var conn *countReadWriteCloser
|
||||
var err error
|
||||
if r.cfg.SSH {
|
||||
logger.Info(ctx, "connecting to workspace agent", slog.F("agent_id", agentID), slog.F("method", "ssh"))
|
||||
conn, err = connectSSH(ctx, r.client, agentID)
|
||||
logger.Info(ctx, "connecting to workspace agent", slog.F("method", "ssh"))
|
||||
// If echo is enabled, disable PTY to avoid double echo and
|
||||
// reduce CPU usage.
|
||||
requestPTY := !r.cfg.Echo
|
||||
conn, err = connectSSH(ctx, r.client, agentID, command, requestPTY)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "connect to workspace agent via ssh", slog.F("agent_id", agentID), slog.Error(err))
|
||||
logger.Error(ctx, "connect to workspace agent via ssh", slog.Error(err))
|
||||
return xerrors.Errorf("connect to workspace via ssh: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.Info(ctx, "connecting to workspace agent", slog.F("agent_id", agentID), slog.F("method", "reconnectingpty"))
|
||||
conn, err = connectPTY(ctx, r.client, agentID, reconnect)
|
||||
logger.Info(ctx, "connecting to workspace agent", slog.F("method", "reconnectingpty"))
|
||||
conn, err = connectRPTY(ctx, r.client, agentID, reconnect, command)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "connect to workspace agent via reconnectingpty", slog.F("agent_id", agentID), slog.Error(err))
|
||||
logger.Error(ctx, "connect to workspace agent via reconnectingpty", slog.Error(err))
|
||||
return xerrors.Errorf("connect to workspace via reconnectingpty: %w", err)
|
||||
}
|
||||
}
|
||||
var closeErr error
|
||||
closeOnce := sync.Once{}
|
||||
closeConn := func() error {
|
||||
closeOnce.Do(func() {
|
||||
closeErr = conn.Close()
|
||||
if err != nil {
|
||||
logger.Error(ctx, "close agent connection", slog.Error(err))
|
||||
}
|
||||
})
|
||||
return closeErr
|
||||
}
|
||||
defer func() {
|
||||
if err2 := closeConn(); err2 != nil {
|
||||
// Allow close error to fail the test.
|
||||
if err == nil {
|
||||
err = err2
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
conn.readMetrics = r.cfg.ReadMetrics
|
||||
conn.writeMetrics = r.cfg.WriteMetrics
|
||||
|
||||
go func() {
|
||||
<-deadlineCtx.Done()
|
||||
logger.Debug(ctx, "close agent connection", slog.F("agent_id", agentID))
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
// Create a ticker for sending data to the conn.
|
||||
tick := time.NewTicker(tickInterval)
|
||||
defer tick.Stop()
|
||||
|
@ -114,58 +140,59 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
|
|||
rch := make(chan error, 1)
|
||||
wch := make(chan error, 1)
|
||||
|
||||
// Read until connection is closed.
|
||||
go func() {
|
||||
<-deadlineCtx.Done()
|
||||
logger.Debug(ctx, "closing agent connection")
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
// Read forever in the background.
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debug(ctx, "done reading from agent", slog.F("agent_id", agentID))
|
||||
default:
|
||||
logger.Debug(ctx, "reading from agent", slog.F("agent_id", agentID))
|
||||
rch <- drain(conn)
|
||||
close(rch)
|
||||
}
|
||||
}()
|
||||
|
||||
// To avoid hanging, close the conn when ctx is done
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = conn.Close()
|
||||
rch := rch // Shadowed for reassignment.
|
||||
logger.Debug(ctx, "reading from agent")
|
||||
rch <- drain(conn)
|
||||
logger.Debug(ctx, "done reading from agent")
|
||||
close(rch)
|
||||
}()
|
||||
|
||||
// Write random data to the conn every tick.
|
||||
go func() {
|
||||
logger.Debug(ctx, "writing to agent", slog.F("agent_id", agentID))
|
||||
if r.cfg.SSH {
|
||||
wch <- writeRandomDataSSH(conn, bytesPerTick, tick.C)
|
||||
} else {
|
||||
wch <- writeRandomDataPTY(conn, bytesPerTick, tick.C)
|
||||
}
|
||||
logger.Debug(ctx, "done writing to agent", slog.F("agent_id", agentID))
|
||||
wch := wch // Shadowed for reassignment.
|
||||
logger.Debug(ctx, "writing to agent")
|
||||
wch <- writeRandomData(conn, bytesPerTick, tick.C)
|
||||
logger.Debug(ctx, "done writing to agent")
|
||||
close(wch)
|
||||
}()
|
||||
|
||||
// Write until the context is canceled.
|
||||
if wErr := <-wch; wErr != nil {
|
||||
return xerrors.Errorf("write to agent: %w", wErr)
|
||||
}
|
||||
var waitCloseTimeoutCh <-chan struct{}
|
||||
deadlineCtxCh := deadlineCtx.Done()
|
||||
for {
|
||||
if wch == nil && rch == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Warn(ctx, "timed out reading from agent", slog.F("agent_id", agentID))
|
||||
case rErr := <-rch:
|
||||
logger.Debug(ctx, "done reading from agent", slog.F("agent_id", agentID))
|
||||
if rErr != nil {
|
||||
return xerrors.Errorf("read from agent: %w", rErr)
|
||||
select {
|
||||
case <-waitCloseTimeoutCh:
|
||||
logger.Warn(ctx, "timed out waiting for read/write to complete",
|
||||
slog.F("write_done", wch == nil),
|
||||
slog.F("read_done", rch == nil),
|
||||
)
|
||||
return xerrors.Errorf("timed out waiting for read/write to complete: %w", ctx.Err())
|
||||
case <-deadlineCtxCh:
|
||||
go func() {
|
||||
_ = closeConn()
|
||||
}()
|
||||
deadlineCtxCh = nil // Only trigger once.
|
||||
// Wait at most closeTimeout for the connection to close cleanly.
|
||||
waitCtx, cancel := context.WithTimeout(context.Background(), waitCloseTimeout)
|
||||
defer cancel() //nolint:revive // Only called once.
|
||||
waitCloseTimeoutCh = waitCtx.Done()
|
||||
case err = <-wch:
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write to agent: %w", err)
|
||||
}
|
||||
wch = nil
|
||||
case err = <-rch:
|
||||
if err != nil {
|
||||
return xerrors.Errorf("read from agent: %w", err)
|
||||
}
|
||||
rch = nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cleanup does nothing, successfully.
|
||||
|
@ -176,6 +203,12 @@ func (*Runner) Cleanup(context.Context, string, io.Writer) error {
|
|||
// drain drains from src until it returns io.EOF or ctx times out.
|
||||
func drain(src io.Reader) error {
|
||||
if _, err := io.Copy(io.Discard, src); err != nil {
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if xerrors.Is(err, io.ErrClosedPipe) {
|
||||
return nil
|
||||
}
|
||||
if xerrors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
|
@ -190,14 +223,27 @@ func drain(src io.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func writeRandomDataPTY(dst io.Writer, size int64, tick <-chan time.Time) error {
|
||||
var (
|
||||
enc = json.NewEncoder(dst)
|
||||
ptyReq = codersdk.ReconnectingPTYRequest{}
|
||||
)
|
||||
// Allowed characters for random strings, exclude most of the 0x00 - 0x1F range.
|
||||
var allowedChars = []byte("\t !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}")
|
||||
|
||||
func writeRandomData(dst io.Writer, size int64, tick <-chan time.Time) error {
|
||||
var b bytes.Buffer
|
||||
p := make([]byte, size-1)
|
||||
for range tick {
|
||||
ptyReq.Data = mustRandomComment(size - 1)
|
||||
if err := enc.Encode(ptyReq); err != nil {
|
||||
b.Reset()
|
||||
|
||||
p := mustRandom(p)
|
||||
for _, c := range p {
|
||||
_, _ = b.WriteRune(rune(allowedChars[c%byte(len(allowedChars))]))
|
||||
}
|
||||
_, _ = b.WriteString("\n")
|
||||
if _, err := b.WriteTo(dst); err != nil {
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if xerrors.Is(err, io.ErrClosedPipe) {
|
||||
return nil
|
||||
}
|
||||
if xerrors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
|
@ -213,36 +259,12 @@ func writeRandomDataPTY(dst io.Writer, size int64, tick <-chan time.Time) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func writeRandomDataSSH(dst io.Writer, size int64, tick <-chan time.Time) error {
|
||||
for range tick {
|
||||
payload := mustRandomComment(size - 1)
|
||||
if _, err := dst.Write([]byte(payload + "\r\n")); err != nil {
|
||||
if xerrors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
if xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
if xerrors.As(err, &websocket.CloseError{}) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mustRandomComment returns a random string prefixed by a #.
|
||||
// This allows us to send data both to and from a workspace agent
|
||||
// while placing minimal load upon the workspace itself.
|
||||
func mustRandomComment(l int64) string {
|
||||
if l < 1 {
|
||||
l = 1
|
||||
}
|
||||
randStr, err := cryptorand.String(int(l))
|
||||
// mustRandom writes pseudo random bytes to p and panics if it fails.
|
||||
func mustRandom(p []byte) []byte {
|
||||
n, err := rand.Read(p) //nolint:gosec // We want pseudorandomness here to avoid entropy issues.
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// THIS IS A LOAD-BEARING OCTOTHORPE. DO NOT REMOVE.
|
||||
return "#" + randStr
|
||||
|
||||
return p[:n]
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ func TestRun(t *testing.T) {
|
|||
}
|
||||
|
||||
//nolint:dupl
|
||||
t.Run("PTY", func(t *testing.T) {
|
||||
t.Run("RPTY", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// We need to stand up an in-memory coderd and run a fake workspace.
|
||||
var (
|
||||
|
@ -91,7 +91,6 @@ func TestRun(t *testing.T) {
|
|||
var (
|
||||
bytesPerTick = 1024
|
||||
tickInterval = 1000 * time.Millisecond
|
||||
fudgeWrite = 12 // The ReconnectingPTY payload incurs some overhead
|
||||
readMetrics = &testMetrics{}
|
||||
writeMetrics = &testMetrics{}
|
||||
)
|
||||
|
@ -103,6 +102,7 @@ func TestRun(t *testing.T) {
|
|||
ReadMetrics: readMetrics,
|
||||
WriteMetrics: writeMetrics,
|
||||
SSH: false,
|
||||
Echo: false,
|
||||
})
|
||||
|
||||
var logs strings.Builder
|
||||
|
@ -139,7 +139,7 @@ func TestRun(t *testing.T) {
|
|||
t.Logf("bytes written total: %.0f\n", writeMetrics.Total())
|
||||
|
||||
// We want to ensure the metrics are somewhat accurate.
|
||||
assert.InDelta(t, bytesPerTick+fudgeWrite, writeMetrics.Total(), 0.1)
|
||||
assert.InDelta(t, bytesPerTick, writeMetrics.Total(), 0.1)
|
||||
// Read is highly variable, depending on how far we read before stopping.
|
||||
// Just ensure it's not zero.
|
||||
assert.NotZero(t, readMetrics.Total())
|
||||
|
@ -211,7 +211,6 @@ func TestRun(t *testing.T) {
|
|||
var (
|
||||
bytesPerTick = 1024
|
||||
tickInterval = 1000 * time.Millisecond
|
||||
fudgeWrite = 2 // We send \r\n, which is two bytes
|
||||
readMetrics = &testMetrics{}
|
||||
writeMetrics = &testMetrics{}
|
||||
)
|
||||
|
@ -223,6 +222,7 @@ func TestRun(t *testing.T) {
|
|||
ReadMetrics: readMetrics,
|
||||
WriteMetrics: writeMetrics,
|
||||
SSH: true,
|
||||
Echo: true,
|
||||
})
|
||||
|
||||
var logs strings.Builder
|
||||
|
@ -259,7 +259,7 @@ func TestRun(t *testing.T) {
|
|||
t.Logf("bytes written total: %.0f\n", writeMetrics.Total())
|
||||
|
||||
// We want to ensure the metrics are somewhat accurate.
|
||||
assert.InDelta(t, bytesPerTick+fudgeWrite, writeMetrics.Total(), 0.1)
|
||||
assert.InDelta(t, bytesPerTick, writeMetrics.Total(), 0.1)
|
||||
// Read is highly variable, depending on how far we read before stopping.
|
||||
// Just ensure it's not zero.
|
||||
assert.NotZero(t, readMetrics.Total())
|
||||
|
|
Loading…
Reference in New Issue