coder/testutil/pty.go

102 lines
2.7 KiB
Go

package testutil
import (
"context"
"io"
"strings"
"testing"
"golang.org/x/xerrors"
"github.com/hinshun/vt10x"
)
// TerminalReader emulates a terminal and allows matching output. It's important in cases where we
// can get control sequences to parse them correctly, and keep the state of the terminal across the
// lifespan of the PTY, since some control sequences are relative to the current cursor position.
type TerminalReader struct {
t *testing.T
r io.Reader
term vt10x.Terminal
}
func NewTerminalReader(t *testing.T, r io.Reader) *TerminalReader {
return &TerminalReader{
t: t,
r: r,
term: vt10x.New(vt10x.WithSize(80, 80)),
}
}
// ReadUntilString emulates a terminal and reads one byte at a time until we
// either see the string we want, or the context expires. The PTY must be sized
// to 80x80 or there could be unexpected results.
func (tr *TerminalReader) ReadUntilString(ctx context.Context, want string) error {
return tr.ReadUntil(ctx, func(line string) bool {
return strings.TrimSpace(line) == want
})
}
// ReadUntil emulates a terminal and reads one byte at a time until the matcher
// returns true or the context expires. If the matcher is nil, read until EOF.
// The PTY must be sized to 80x80 or there could be unexpected results.
func (tr *TerminalReader) ReadUntil(ctx context.Context, matcher func(line string) bool) (retErr error) {
readBytes := make([]byte, 0)
readErrs := make(chan error, 1)
defer func() {
// Dump the terminal contents since they can be helpful for debugging, but
// trim empty lines since much of the terminal will usually be blank.
got := tr.term.String()
lines := strings.Split(got, "\n")
for i := range lines {
if strings.TrimSpace(lines[i]) != "" {
lines = lines[i:]
break
}
}
for i := len(lines) - 1; i >= 0; i-- {
if strings.TrimSpace(lines[i]) != "" {
lines = lines[:i+1]
break
}
}
gotTrimmed := strings.Join(lines, "\n")
tr.t.Logf("Terminal contents:\n%s", gotTrimmed)
// EOF is expected when matcher == nil
if retErr != nil && !(xerrors.Is(retErr, io.EOF) && matcher == nil) {
tr.t.Logf("Bytes Read: %q", string(readBytes))
}
}()
for {
b := make([]byte, 1)
go func() {
_, err := tr.r.Read(b)
readErrs <- err
}()
select {
case err := <-readErrs:
if err != nil {
return err
}
readBytes = append(readBytes, b...)
_, err = tr.term.Write(b)
if err != nil {
return err
}
case <-ctx.Done():
return ctx.Err()
}
if matcher == nil {
// A nil matcher means to read until EOF.
continue
}
got := tr.term.String()
lines := strings.Split(got, "\n")
for _, line := range lines {
if matcher(line) {
return nil
}
}
}
}