coder/agent/reconnectingpty/buffered.go

package reconnectingpty

import (
	"context"
	"errors"
	"io"
	"net"
	"time"

	"github.com/armon/circbuf"
	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/exp/slices"
	"golang.org/x/xerrors"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/pty"
)

// bufferedReconnectingPTY provides a reconnectable PTY by using a ring buffer
// to store scrollback.
type bufferedReconnectingPTY struct {
	command        *pty.Cmd
	activeConns    map[string]net.Conn
	circularBuffer *circbuf.Buffer
	ptty           pty.PTYCmd
	process        pty.Process
	metrics        *prometheus.CounterVec
	state          *ptyState
	// timer will close the reconnecting pty when it expires. The timer will be
	// reset as long as there are active connections.
	timer   *time.Timer
	timeout time.Duration
}

// newBuffered starts the buffered pty. If the context ends, the process will
// be killed.
func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) *bufferedReconnectingPTY {
	rpty := &bufferedReconnectingPTY{
		activeConns: map[string]net.Conn{},
		command:     cmd,
		metrics:     options.Metrics,
		state:       newState(),
		timeout:     options.Timeout,
	}

	// Default to a 64KiB buffer.
	circularBuffer, err := circbuf.NewBuffer(64 << 10)
	if err != nil {
		rpty.state.setState(StateDone, xerrors.Errorf("create circular buffer: %w", err))
		return rpty
	}
	rpty.circularBuffer = circularBuffer

	// Add TERM then start the command with a pty. pty.Cmd duplicates Path as the
	// first argument so remove it.
	cmdWithEnv := pty.CommandContext(ctx, cmd.Path, cmd.Args[1:]...)
	cmdWithEnv.Env = append(rpty.command.Env, "TERM=xterm-256color")
	cmdWithEnv.Dir = rpty.command.Dir
	ptty, process, err := pty.Start(cmdWithEnv)
	if err != nil {
		rpty.state.setState(StateDone, xerrors.Errorf("start pty: %w", err))
		return rpty
	}
	rpty.ptty = ptty
	rpty.process = process

	go rpty.lifecycle(ctx, logger)

	// Multiplex the output onto the circular buffer and each active connection.
	// We do not need to separately monitor for the process exiting. When it
	// exits, our ptty.OutputReader() will return EOF after reading all process
	// output.
	go func() {
		buffer := make([]byte, 1024)
		for {
			read, err := ptty.OutputReader().Read(buffer)
			if err != nil {
				// When the PTY is closed, this is triggered.
				// Error is typically a benign EOF, so only log for debugging.
				if errors.Is(err, io.EOF) {
					logger.Debug(ctx, "unable to read pty output, command might have exited", slog.Error(err))
				} else {
					logger.Warn(ctx, "unable to read pty output, command might have exited", slog.Error(err))
					rpty.metrics.WithLabelValues("output_reader").Add(1)
				}
				// Could have been killed externally or failed to start at all (command
				// not found for example).
				// TODO: Should we check the process's exit code in case the command was
				// invalid?
				rpty.Close(nil)
				break
			}
			part := buffer[:read]
			rpty.state.cond.L.Lock()
			_, err = rpty.circularBuffer.Write(part)
			if err != nil {
				logger.Error(ctx, "write to circular buffer", slog.Error(err))
				rpty.metrics.WithLabelValues("write_buffer").Add(1)
			}
			// TODO: Instead of ranging over a map, could we send the output to a
			// channel and have each individual Attach read from that?
			for cid, conn := range rpty.activeConns {
				_, err = conn.Write(part)
				if err != nil {
					logger.Warn(ctx,
						"error writing to active connection",
						slog.F("connection_id", cid),
						slog.Error(err),
					)
					rpty.metrics.WithLabelValues("write").Add(1)
				}
			}
			rpty.state.cond.L.Unlock()
		}
	}()

	return rpty
}

// lifecycle manages the lifecycle of the reconnecting pty. If the context ends
// or the reconnecting pty closes, the pty will be shut down.
func (rpty *bufferedReconnectingPTY) lifecycle(ctx context.Context, logger slog.Logger) {
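	// The timer starts with attachTimeout so the pty is closed if nothing
	// attaches in time; while connections are active, heartbeat keeps resetting
	// it to rpty.timeout.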
	rpty.timer = time.AfterFunc(attachTimeout, func() {
		rpty.Close(xerrors.New("reconnecting pty timeout"))
	})

	logger.Debug(ctx, "reconnecting pty ready")
	rpty.state.setState(StateReady, nil)

	state, reasonErr := rpty.state.waitForStateOrContext(ctx, StateClosing)
	if state < StateClosing {
		// If we have not closed yet then the context is what unblocked us (which
		// means the agent is shutting down) so move into the closing phase.
		rpty.Close(reasonErr)
	}
	rpty.timer.Stop()

	rpty.state.cond.L.Lock()
	// Log these closes only for debugging since the connections or processes
	// might have already closed on their own.
	for _, conn := range rpty.activeConns {
		err := conn.Close()
		if err != nil {
			logger.Debug(ctx, "closed conn with error", slog.Error(err))
		}
	}
	// Connections get removed once they close but it is possible there is still
	// some data that will be written before that happens so clear the map now to
	// avoid writing to closed connections.
	rpty.activeConns = map[string]net.Conn{}
	rpty.state.cond.L.Unlock()

	// Log close/kill only for debugging since the process might have already
	// closed on its own.
	err := rpty.ptty.Close()
	if err != nil {
		logger.Debug(ctx, "closed ptty with error", slog.Error(err))
	}
	err = rpty.process.Kill()
	if err != nil {
		logger.Debug(ctx, "killed process with error", slog.Error(err))
	}

	logger.Info(ctx, "closed reconnecting pty")
	rpty.state.setState(StateDone, reasonErr)
}
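
// Attach registers the connection, replays the buffered scrollback to it, and
// then blocks piping input from the connection to the pty until the context
// ends or the pty closes.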
func (rpty *bufferedReconnectingPTY) Attach(ctx context.Context, connID string, conn net.Conn, height, width uint16, logger slog.Logger) error {
	logger.Info(ctx, "attach to reconnecting pty")

	// This will kill the heartbeat once we hit EOF or an error.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	err := rpty.doAttach(connID, conn)
	if err != nil {
		return err
	}

	defer func() {
		rpty.state.cond.L.Lock()
		defer rpty.state.cond.L.Unlock()
		delete(rpty.activeConns, connID)
	}()

	state, err := rpty.state.waitForStateOrContext(ctx, StateReady)
	if state != StateReady {
		return err
	}
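	// Keep the inactivity timer from expiring for as long as this connection's
	// context remains active.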
	go heartbeat(ctx, rpty.timer, rpty.timeout)

	// Resize the PTY to the initial height and width.
	err = rpty.ptty.Resize(height, width)
	if err != nil {
		// We can continue after this; it is not fatal!
		logger.Warn(ctx, "reconnecting PTY initial resize failed, but will continue", slog.Error(err))
		rpty.metrics.WithLabelValues("resize").Add(1)
	}

	// Pipe conn -> pty and block. pty -> conn is handled in newBuffered().
	readConnLoop(ctx, conn, rpty.ptty, rpty.metrics, logger)
	return nil
}

// doAttach adds the connection to the map and replays the buffer. It exists
// separately only for convenience, to defer the mutex unlock, which is not
// possible in Attach since it blocks.
func (rpty *bufferedReconnectingPTY) doAttach(connID string, conn net.Conn) error {
	rpty.state.cond.L.Lock()
	defer rpty.state.cond.L.Unlock()

	// Write any previously stored data for the TTY. Since the command might be
	// short-lived and have already exited, make sure we always at least output
	// the buffer before returning, mostly just so tests pass.
	prevBuf := slices.Clone(rpty.circularBuffer.Bytes())
	_, err := conn.Write(prevBuf)
	if err != nil {
		rpty.metrics.WithLabelValues("write").Add(1)
		return xerrors.Errorf("write buffer to conn: %w", err)
	}

	rpty.activeConns[connID] = conn
	return nil
}
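
// Wait blocks until the reconnecting pty has at least started closing; it does
// not wait for the underlying process to exit.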
func (rpty *bufferedReconnectingPTY) Wait() {
	_, _ = rpty.state.waitForState(StateClosing)
}
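
// Close moves the reconnecting pty into the closing state with the provided
// reason; the lifecycle goroutine performs the actual teardown.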
func (rpty *bufferedReconnectingPTY) Close(err error) {
	// The closing state change will be handled by the lifecycle.
	rpty.state.setState(StateClosing, err)
}