coder/pty/ptytest/ptytest.go

280 lines
5.0 KiB
Go

package ptytest
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"os/exec"
"runtime"
"strings"
"sync"
"testing"
"time"
"unicode/utf8"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/pty"
"github.com/coder/coder/testutil"
)
func New(t *testing.T, opts ...pty.Option) *PTY {
t.Helper()
ptty, err := pty.New(opts...)
require.NoError(t, err)
return create(t, ptty, "cmd")
}
func Start(t *testing.T, cmd *exec.Cmd, opts ...pty.StartOption) (*PTY, pty.Process) {
t.Helper()
ptty, ps, err := pty.Start(cmd, opts...)
require.NoError(t, err)
t.Cleanup(func() {
_ = ps.Kill()
_ = ps.Wait()
})
return create(t, ptty, cmd.Args[0]), ps
}
func create(t *testing.T, ptty pty.PTY, name string) *PTY {
// Use pipe for logging.
logDone := make(chan struct{})
logr, logw := io.Pipe()
t.Cleanup(func() {
_ = logw.Close()
_ = logr.Close()
<-logDone // Guard against logging after test.
})
// Write to log and output buffer.
copyDone := make(chan struct{})
out := newStdbuf()
w := io.MultiWriter(logw, out)
go func() {
defer close(copyDone)
_, err := io.Copy(w, ptty.Output())
_ = out.closeErr(err)
}()
t.Cleanup(func() {
_ = out.Close()
_ = ptty.Close()
<-copyDone
})
tpty := &PTY{
t: t,
PTY: ptty,
out: out,
name: name,
runeReader: bufio.NewReaderSize(out, utf8.UTFMax),
}
go func() {
defer close(logDone)
s := bufio.NewScanner(logr)
for s.Scan() {
// Quote output to avoid terminal escape codes, e.g. bell.
tpty.logf("stdout: %q", s.Text())
}
}()
return tpty
}
type PTY struct {
pty.PTY
t *testing.T
out *stdbuf
name string
runeReader *bufio.Reader
}
func (p *PTY) ExpectMatch(str string) string {
p.t.Helper()
timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
var buffer bytes.Buffer
match := make(chan error, 1)
go func() {
defer close(match)
match <- func() error {
for {
r, _, err := p.runeReader.ReadRune()
if err != nil {
return err
}
_, err = buffer.WriteRune(r)
if err != nil {
return err
}
if strings.Contains(buffer.String(), str) {
return nil
}
}
}()
}()
select {
case err := <-match:
if err != nil {
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
return ""
}
p.logf("matched %q = %q", str, buffer.String())
return buffer.String()
case <-timeout.Done():
// Ensure gorouine is cleaned up before test exit.
_ = p.out.closeErr(p.Close())
<-match
p.fatalf("match exceeded deadline", "wanted %q; got %q", str, buffer.String())
return ""
}
}
func (p *PTY) Write(r rune) {
p.t.Helper()
p.logf("stdin: %q", r)
_, err := p.Input().Write([]byte{byte(r)})
require.NoError(p.t, err, "write failed")
}
func (p *PTY) WriteLine(str string) {
p.t.Helper()
newline := []byte{'\r'}
if runtime.GOOS == "windows" {
newline = append(newline, '\n')
}
p.logf("stdin: %q", str+string(newline))
_, err := p.Input().Write(append([]byte(str), newline...))
require.NoError(p.t, err, "write line failed")
}
func (p *PTY) logf(format string, args ...interface{}) {
p.t.Helper()
// Match regular logger timestamp format, we seem to be logging in
// UTC in other places as well, so match here.
p.t.Logf("%s: %s: %s", time.Now().UTC().Format("2006-01-02 15:04:05.000"), p.name, fmt.Sprintf(format, args...))
}
func (p *PTY) fatalf(reason string, format string, args ...interface{}) {
p.t.Helper()
// Ensure the message is part of the normal log stream before
// failing the test.
p.logf("%s: %s", reason, fmt.Sprintf(format, args...))
require.FailNowf(p.t, reason, format, args...)
}
// stdbuf is like a buffered stdout, it buffers writes until read.
type stdbuf struct {
r io.Reader
mu sync.Mutex // Protects following.
b []byte
more chan struct{}
err error
}
func newStdbuf() *stdbuf {
return &stdbuf{more: make(chan struct{}, 1)}
}
func (b *stdbuf) Read(p []byte) (int, error) {
if b.r == nil {
return b.readOrWaitForMore(p)
}
n, err := b.r.Read(p)
if xerrors.Is(err, io.EOF) {
b.r = nil
err = nil
if n == 0 {
return b.readOrWaitForMore(p)
}
}
return n, err
}
func (b *stdbuf) readOrWaitForMore(p []byte) (int, error) {
b.mu.Lock()
defer b.mu.Unlock()
// Deplete channel so that more check
// is for future input into buffer.
select {
case <-b.more:
default:
}
if len(b.b) == 0 {
if b.err != nil {
return 0, b.err
}
b.mu.Unlock()
<-b.more
b.mu.Lock()
}
b.r = bytes.NewReader(b.b)
b.b = b.b[len(b.b):]
return b.r.Read(p)
}
func (b *stdbuf) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
b.mu.Lock()
defer b.mu.Unlock()
if b.err != nil {
return 0, b.err
}
b.b = append(b.b, p...)
select {
case b.more <- struct{}{}:
default:
}
return len(p), nil
}
func (b *stdbuf) Close() error {
return b.closeErr(nil)
}
func (b *stdbuf) closeErr(err error) error {
b.mu.Lock()
defer b.mu.Unlock()
if b.err != nil {
return err
}
if err == nil {
b.err = io.EOF
} else {
b.err = err
}
close(b.more)
return err
}