mirror of https://github.com/coder/coder.git
parent
d7dee2c069
commit
92a95fbd5f
|
@ -7,112 +7,137 @@ import (
|
|||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/pty"
|
||||
)
|
||||
|
||||
var (
|
||||
// Used to ensure terminal output doesn't have anything crazy!
|
||||
// See: https://stackoverflow.com/a/29497680
|
||||
stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))")
|
||||
)
|
||||
|
||||
func New(t *testing.T) *PTY {
|
||||
ptty, err := pty.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
return create(t, ptty)
|
||||
return create(t, ptty, "cmd")
|
||||
}
|
||||
|
||||
func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) {
|
||||
ptty, ps, err := pty.Start(cmd)
|
||||
require.NoError(t, err)
|
||||
return create(t, ptty), ps
|
||||
return create(t, ptty, cmd.Args[0]), ps
|
||||
}
|
||||
|
||||
func create(t *testing.T, ptty pty.PTY) *PTY {
|
||||
reader, writer := io.Pipe()
|
||||
scanner := bufio.NewScanner(reader)
|
||||
func create(t *testing.T, ptty pty.PTY, name string) *PTY {
|
||||
// Use pipe for logging.
|
||||
logDone := make(chan struct{})
|
||||
logr, logw := io.Pipe()
|
||||
t.Cleanup(func() {
|
||||
_ = reader.Close()
|
||||
_ = writer.Close()
|
||||
_ = logw.Close()
|
||||
_ = logr.Close()
|
||||
<-logDone // Guard against logging after test.
|
||||
})
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
if scanner.Err() != nil {
|
||||
return
|
||||
}
|
||||
t.Log(stripAnsi.ReplaceAllString(scanner.Text(), ""))
|
||||
defer close(logDone)
|
||||
s := bufio.NewScanner(logr)
|
||||
for s.Scan() {
|
||||
// Quote output to avoid terminal escape codes, e.g. bell.
|
||||
t.Logf("%s: stdout: %q", name, s.Text())
|
||||
}
|
||||
}()
|
||||
|
||||
// Write to log and output buffer.
|
||||
copyDone := make(chan struct{})
|
||||
out := newStdbuf()
|
||||
w := io.MultiWriter(logw, out)
|
||||
go func() {
|
||||
defer close(copyDone)
|
||||
_, err := io.Copy(w, ptty.Output())
|
||||
_ = out.closeErr(err)
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
_ = out.Close
|
||||
_ = ptty.Close()
|
||||
<-copyDone
|
||||
})
|
||||
|
||||
return &PTY{
|
||||
t: t,
|
||||
PTY: ptty,
|
||||
out: out,
|
||||
|
||||
outputWriter: writer,
|
||||
runeReader: bufio.NewReaderSize(ptty.Output(), utf8.UTFMax),
|
||||
runeReader: bufio.NewReaderSize(out, utf8.UTFMax),
|
||||
}
|
||||
}
|
||||
|
||||
type PTY struct {
|
||||
t *testing.T
|
||||
pty.PTY
|
||||
out *stdbuf
|
||||
|
||||
outputWriter io.Writer
|
||||
runeReader *bufio.Reader
|
||||
runeReader *bufio.Reader
|
||||
}
|
||||
|
||||
func (p *PTY) ExpectMatch(str string) string {
|
||||
p.t.Helper()
|
||||
|
||||
timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
multiWriter := io.MultiWriter(&buffer, p.outputWriter)
|
||||
runeWriter := bufio.NewWriterSize(multiWriter, utf8.UTFMax)
|
||||
complete, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
match := make(chan error, 1)
|
||||
go func() {
|
||||
timer := time.NewTimer(10 * time.Second)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-complete.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
_ = p.Close()
|
||||
p.t.Errorf("%s match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String())
|
||||
defer close(match)
|
||||
match <- func() error {
|
||||
for {
|
||||
r, _, err := p.runeReader.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = buffer.WriteRune(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.Contains(buffer.String(), str) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}()
|
||||
}()
|
||||
for {
|
||||
var r rune
|
||||
r, _, err := p.runeReader.ReadRune()
|
||||
require.NoError(p.t, err)
|
||||
_, err = runeWriter.WriteRune(r)
|
||||
require.NoError(p.t, err)
|
||||
err = runeWriter.Flush()
|
||||
require.NoError(p.t, err)
|
||||
if strings.Contains(buffer.String(), str) {
|
||||
break
|
||||
|
||||
select {
|
||||
case err := <-match:
|
||||
if err != nil {
|
||||
p.t.Fatalf("%s: read error: %v (wanted %q; got %q)", time.Now(), err, str, buffer.String())
|
||||
return ""
|
||||
}
|
||||
p.t.Logf("%s: matched %q = %q", time.Now(), str, buffer.String())
|
||||
return buffer.String()
|
||||
case <-timeout.Done():
|
||||
// Ensure goroutine is cleaned up before test exit.
|
||||
_ = p.out.closeErr(p.Close())
|
||||
<-match
|
||||
|
||||
p.t.Fatalf("%s: match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String())
|
||||
return ""
|
||||
}
|
||||
p.t.Logf("matched %q = %q", str, stripAnsi.ReplaceAllString(buffer.String(), ""))
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
func (p *PTY) Write(r rune) {
|
||||
p.t.Helper()
|
||||
|
||||
_, err := p.Input().Write([]byte{byte(r)})
|
||||
require.NoError(p.t, err)
|
||||
}
|
||||
|
||||
func (p *PTY) WriteLine(str string) {
|
||||
p.t.Helper()
|
||||
|
||||
newline := []byte{'\r'}
|
||||
if runtime.GOOS == "windows" {
|
||||
newline = append(newline, '\n')
|
||||
|
@ -120,3 +145,101 @@ func (p *PTY) WriteLine(str string) {
|
|||
_, err := p.Input().Write(append([]byte(str), newline...))
|
||||
require.NoError(p.t, err)
|
||||
}
|
||||
|
||||
// stdbuf is like a buffered stdout, it buffers writes until read.
|
||||
type stdbuf struct {
|
||||
r io.Reader
|
||||
|
||||
mu sync.Mutex // Protects following.
|
||||
b []byte
|
||||
more chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func newStdbuf() *stdbuf {
|
||||
return &stdbuf{more: make(chan struct{}, 1)}
|
||||
}
|
||||
|
||||
func (b *stdbuf) Read(p []byte) (int, error) {
|
||||
if b.r == nil {
|
||||
return b.readOrWaitForMore(p)
|
||||
}
|
||||
|
||||
n, err := b.r.Read(p)
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
b.r = nil
|
||||
err = nil
|
||||
if n == 0 {
|
||||
return b.readOrWaitForMore(p)
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (b *stdbuf) readOrWaitForMore(p []byte) (int, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
// Deplete channel so that more check
|
||||
// is for future input into buffer.
|
||||
select {
|
||||
case <-b.more:
|
||||
default:
|
||||
}
|
||||
|
||||
if len(b.b) == 0 {
|
||||
if b.err != nil {
|
||||
return 0, b.err
|
||||
}
|
||||
|
||||
b.mu.Unlock()
|
||||
<-b.more
|
||||
b.mu.Lock()
|
||||
}
|
||||
|
||||
b.r = bytes.NewReader(b.b)
|
||||
b.b = b.b[len(b.b):]
|
||||
|
||||
return b.r.Read(p)
|
||||
}
|
||||
|
||||
func (b *stdbuf) Write(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.err != nil {
|
||||
return 0, b.err
|
||||
}
|
||||
|
||||
b.b = append(b.b, p...)
|
||||
|
||||
select {
|
||||
case b.more <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (b *stdbuf) Close() error {
|
||||
return b.closeErr(nil)
|
||||
}
|
||||
|
||||
func (b *stdbuf) closeErr(err error) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.err != nil {
|
||||
return err
|
||||
}
|
||||
if err == nil {
|
||||
b.err = io.EOF
|
||||
} else {
|
||||
b.err = err
|
||||
}
|
||||
close(b.more)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
package ptytest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStdbuf(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var got bytes.Buffer
|
||||
|
||||
b := newStdbuf()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
_, err := io.Copy(&got, b)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
_, err := b.Write([]byte("hello "))
|
||||
require.NoError(t, err)
|
||||
_, err = b.Write([]byte("world\n"))
|
||||
require.NoError(t, err)
|
||||
_, err = b.Write([]byte("bye\n"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = b.Close()
|
||||
require.NoError(t, err)
|
||||
<-done
|
||||
|
||||
assert.Equal(t, "hello world\nbye\n", got.String())
|
||||
}
|
|
@ -2,7 +2,6 @@ package ptytest_test
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -22,26 +21,24 @@ func TestPtytest(t *testing.T) {
|
|||
pty.WriteLine("read")
|
||||
})
|
||||
|
||||
// See https://github.com/coder/coder/issues/2122 for the motivation
|
||||
// behind this test.
|
||||
t.Run("Cobra ptytest should not hang when output is not consumed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
output string
|
||||
isPlatformBug bool // See https://github.com/coder/coder/issues/2122 for more info.
|
||||
isPlatformBug bool
|
||||
}{
|
||||
{name: "1024 is safe (does not exceed macOS buffer)", output: strings.Repeat(".", 1024)},
|
||||
{name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025), isPlatformBug: true},
|
||||
{name: "10241 large output", output: strings.Repeat(".", 10241), isPlatformBug: true}, // 1024 * 10 + 1
|
||||
{name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025)},
|
||||
{name: "10241 large output", output: strings.Repeat(".", 10241)}, // 1024 * 10 + 1
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
// nolint:paralleltest // Avoid parallel test to more easily identify the issue.
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.isPlatformBug && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") {
|
||||
t.Skip("This test hangs on macOS and Windows, see https://github.com/coder/coder/issues/2122")
|
||||
}
|
||||
|
||||
cmd := cobra.Command{
|
||||
Use: "test",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
|
|
Loading…
Reference in New Issue