coder/pty/start_test.go

177 lines
4.0 KiB
Go

package pty_test
import (
"bytes"
"context"
"fmt"
"io"
"strings"
"testing"
"time"
"github.com/hinshun/vt10x"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/pty"
"github.com/coder/coder/testutil"
)
// Test_Start_copy tests that we can use io.Copy() on command output
// without deadlocking.
func Test_Start_copy(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
	defer cancel()

	pc, cmd, err := pty.Start(pty.CommandContext(ctx, cmdEcho, argEcho...))
	require.NoError(t, err)
	// Release the PTY when the test ends, as Test_Start_cancel_context does.
	defer func() {
		_ = pc.Close()
	}()

	b := &bytes.Buffer{}
	readDone := make(chan error, 1)
	go func() {
		_, err := io.Copy(b, pc.OutputReader())
		readDone <- err
	}()

	select {
	case err := <-readDone:
		require.NoError(t, err)
	case <-ctx.Done():
		// Fatal rather than Error: the copy goroutine may still be writing
		// to b, so continuing on to read b below would be a data race.
		t.Fatal("read timed out")
	}
	// Reading b is safe here: the receive from readDone above happens after
	// io.Copy has returned.
	assert.Contains(t, b.String(), "test")

	cmdDone := make(chan error, 1)
	go func() {
		cmdDone <- cmd.Wait()
	}()

	select {
	case err := <-cmdDone:
		require.NoError(t, err)
	case <-ctx.Done():
		t.Fatal("cmd.Wait() timed out")
	}
}
// Test_Start_truncation tests that we can read command output without truncation
// even after the command has exited.
func Test_Start_truncation(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()

	pc, cmd, err := pty.Start(pty.CommandContext(ctx, cmdCount, argCount...))
	require.NoError(t, err)
	// Release the PTY when the test ends. On the success path this runs only
	// after readDone has been received below, so it cannot cut the output
	// short; on the failure paths the test has already failed.
	defer func() {
		_ = pc.Close()
	}()

	readDone := make(chan struct{})
	go func() {
		defer close(readDone)
		// readUntil reads a single byte at a time (no buffered IO), so we
		// precisely control how many bytes are consumed from the PTY.
		n := 1
		for n <= countEnd {
			want := fmt.Sprintf("%d", n)
			err := readUntil(ctx, t, want, pc.OutputReader())
			// assert, not require: FailNow must not be called from a
			// goroutine other than the test's own.
			assert.NoError(t, err, "want: %s", want)
			if err != nil {
				return
			}
			n++
			if (countEnd - n) < 100 {
				// If the OS buffers the output, the process can exit even if
				// we're not done reading. We want to slow our reads so that
				// if there is a race between reading the data and it being
				// truncated, we will lose and fail the test.
				time.Sleep(testutil.IntervalFast)
			}
		}
		// ensure we still get to EOF
		endB := &bytes.Buffer{}
		_, err := io.Copy(endB, pc.OutputReader())
		assert.NoError(t, err)
	}()

	cmdDone := make(chan error, 1)
	go func() {
		cmdDone <- cmd.Wait()
	}()

	select {
	case err := <-cmdDone:
		require.NoError(t, err)
	case <-ctx.Done():
		t.Fatal("cmd.Wait() timed out")
	}

	select {
	case <-readDone:
		// OK!
	case <-ctx.Done():
		t.Fatal("read timed out")
	}
}
// Test_Start_cancel_context tests that we can cancel the command context and kill the process.
func Test_Start_cancel_context(t *testing.T) {
	t.Parallel()

	// ctx bounds the whole test; cmdCtx is derived from it so that the
	// command's context can be cancelled on its own.
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
	defer cancel()
	cmdCtx, cmdCancel := context.WithCancel(ctx)

	ptty, proc, err := pty.Start(pty.CommandContext(cmdCtx, cmdSleep, argSleep...))
	require.NoError(t, err)
	defer func() {
		_ = ptty.Close()
	}()

	cmdCancel()

	exited := make(chan struct{})
	go func() {
		defer close(exited)
		// The exit status of the cancelled process is deliberately not
		// checked; the test only requires that Wait returns.
		_ = proc.Wait()
	}()

	select {
	case <-ctx.Done():
		t.Error("cmd.Wait() timed out")
	case <-exited:
		// Wait returned before the deadline.
	}
}
// readUntil reads r one byte at a time until a line of the rendered terminal
// output, trimmed of surrounding whitespace, equals want. It returns nil once
// such a line is seen, the read error if r fails first, or ctx.Err() if ctx
// is done while a read is still pending.
//
// Output can contain virtual terminal sequences, so bytes are fed through a
// vt10x terminal emulator and the rendered screen is searched, rather than
// the raw byte stream.
//
// Each Read runs in its own goroutine so readUntil can return when ctx is done
// even if a Read is blocked; such an abandoned goroutine keeps running until
// its Read returns.
func readUntil(ctx context.Context, t *testing.T, want string, r io.Reader) error {
	t.Helper()

	term := vt10x.New(vt10x.WithSize(80, 80))

	type readResult struct {
		n   int
		err error
	}
	// Buffered so that a goroutine abandoned on ctx expiry can still deliver
	// its result and exit instead of blocking forever.
	results := make(chan readResult, 1)
	for {
		b := make([]byte, 1)
		go func() {
			n, err := r.Read(b)
			results <- readResult{n: n, err: err}
		}()

		var res readResult
		select {
		case res = <-results:
		case <-ctx.Done():
			return ctx.Err()
		}

		// Per the io.Reader contract, consume the n bytes returned before
		// considering the error: a Read may return the final byte together
		// with io.EOF, and a (0, nil) Read carries no data at all.
		if res.n > 0 {
			term.Write(b[:res.n])
			for _, line := range strings.Split(term.String(), "\n") {
				if strings.TrimSpace(line) == want {
					t.Logf("want: %v\n got:%v", want, line)
					return nil
				}
			}
		}
		if res.err != nil {
			t.Logf("err: %v\ngot: %v", res.err, term)
			return res.err
		}
	}
}