mirror of https://github.com/coder/coder.git
102 lines
2.7 KiB
Go
102 lines
2.7 KiB
Go
package testutil
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"strings"
|
|
"testing"
|
|
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/hinshun/vt10x"
|
|
)
|
|
|
|
// TerminalReader emulates a terminal and allows matching output. It's important in cases where we
|
|
// can get control sequences to parse them correctly, and keep the state of the terminal across the
|
|
// lifespan of the PTY, since some control sequences are relative to the current cursor position.
|
|
type TerminalReader struct {
|
|
t *testing.T
|
|
r io.Reader
|
|
term vt10x.Terminal
|
|
}
|
|
|
|
func NewTerminalReader(t *testing.T, r io.Reader) *TerminalReader {
|
|
return &TerminalReader{
|
|
t: t,
|
|
r: r,
|
|
term: vt10x.New(vt10x.WithSize(80, 80)),
|
|
}
|
|
}
|
|
|
|
// ReadUntilString emulates a terminal and reads one byte at a time until we
|
|
// either see the string we want, or the context expires. The PTY must be sized
|
|
// to 80x80 or there could be unexpected results.
|
|
func (tr *TerminalReader) ReadUntilString(ctx context.Context, want string) error {
|
|
return tr.ReadUntil(ctx, func(line string) bool {
|
|
return strings.TrimSpace(line) == want
|
|
})
|
|
}
|
|
|
|
// ReadUntil emulates a terminal and reads one byte at a time until the matcher
|
|
// returns true or the context expires. If the matcher is nil, read until EOF.
|
|
// The PTY must be sized to 80x80 or there could be unexpected results.
|
|
func (tr *TerminalReader) ReadUntil(ctx context.Context, matcher func(line string) bool) (retErr error) {
|
|
readBytes := make([]byte, 0)
|
|
readErrs := make(chan error, 1)
|
|
defer func() {
|
|
// Dump the terminal contents since they can be helpful for debugging, but
|
|
// trim empty lines since much of the terminal will usually be blank.
|
|
got := tr.term.String()
|
|
lines := strings.Split(got, "\n")
|
|
for i := range lines {
|
|
if strings.TrimSpace(lines[i]) != "" {
|
|
lines = lines[i:]
|
|
break
|
|
}
|
|
}
|
|
for i := len(lines) - 1; i >= 0; i-- {
|
|
if strings.TrimSpace(lines[i]) != "" {
|
|
lines = lines[:i+1]
|
|
break
|
|
}
|
|
}
|
|
gotTrimmed := strings.Join(lines, "\n")
|
|
tr.t.Logf("Terminal contents:\n%s", gotTrimmed)
|
|
// EOF is expected when matcher == nil
|
|
if retErr != nil && !(xerrors.Is(retErr, io.EOF) && matcher == nil) {
|
|
tr.t.Logf("Bytes Read: %q", string(readBytes))
|
|
}
|
|
}()
|
|
for {
|
|
b := make([]byte, 1)
|
|
go func() {
|
|
_, err := tr.r.Read(b)
|
|
readErrs <- err
|
|
}()
|
|
select {
|
|
case err := <-readErrs:
|
|
if err != nil {
|
|
return err
|
|
}
|
|
readBytes = append(readBytes, b...)
|
|
_, err = tr.term.Write(b)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
if matcher == nil {
|
|
// A nil matcher means to read until EOF.
|
|
continue
|
|
}
|
|
got := tr.term.String()
|
|
lines := strings.Split(got, "\n")
|
|
for _, line := range lines {
|
|
if matcher(line) {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
}
|