test: Refactor ptytest to use contexts and less duplication (#5740)

This commit is contained in:
Mathias Fredriksson 2023-01-17 16:02:38 +02:00 committed by GitHub
parent 77e71f3ca4
commit 145d101512
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 136 additions and 136 deletions

View File

@ -120,13 +120,15 @@ func TestServer(t *testing.T) {
})
t.Run("BuiltinPostgresURLRaw", func(t *testing.T) {
t.Parallel()
ctx, _ := testutil.Context(t)
root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url")
pty := ptytest.New(t)
root.SetOutput(pty.Output())
err := root.Execute()
err := root.ExecuteContext(ctx)
require.NoError(t, err)
got := pty.ReadLine()
got := pty.ReadLine(ctx)
if !strings.HasPrefix(got, "postgres://") {
t.Fatalf("expected postgres URL to start with \"postgres://\", got %q", got)
}
@ -491,12 +493,12 @@ func TestServer(t *testing.T) {
// We can't use waitAccessURL as it will only return the HTTP URL.
const httpLinePrefix = "Started HTTP listener at "
pty.ExpectMatch(httpLinePrefix)
httpLine := pty.ReadLine()
httpLine := pty.ReadLine(ctx)
httpAddr := strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
require.NotEmpty(t, httpAddr)
const tlsLinePrefix = "Started TLS/HTTPS listener at "
pty.ExpectMatch(tlsLinePrefix)
tlsLine := pty.ReadLine()
tlsLine := pty.ReadLine(ctx)
tlsAddr := strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
require.NotEmpty(t, tlsAddr)
@ -617,14 +619,14 @@ func TestServer(t *testing.T) {
if c.httpListener {
const httpLinePrefix = "Started HTTP listener at "
pty.ExpectMatch(httpLinePrefix)
httpLine := pty.ReadLine()
httpLine := pty.ReadLine(ctx)
httpAddr = strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
require.NotEmpty(t, httpAddr)
}
if c.tlsListener {
const tlsLinePrefix = "Started TLS/HTTPS listener at "
pty.ExpectMatch(tlsLinePrefix)
tlsLine := pty.ReadLine()
tlsLine := pty.ReadLine(ctx)
tlsAddr = strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
require.NotEmpty(t, tlsAddr)
}
@ -1212,7 +1214,7 @@ func TestServer(t *testing.T) {
t.Run("Stackdriver", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc()
fi := testutil.TempFile(t, "", "coder-logging-test-*")
@ -1240,10 +1242,9 @@ func TestServer(t *testing.T) {
<-serverErr
}()
require.Eventually(t, func() bool {
line := pty.ReadLine()
return strings.HasPrefix(line, "Started HTTP listener at ")
}, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http")
// Wait for server to listen on HTTP, this is a good
// starting point for expecting logs.
_ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ")
require.Eventually(t, func() bool {
stat, err := os.Stat(fi)
@ -1253,7 +1254,7 @@ func TestServer(t *testing.T) {
t.Run("Multiple", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc()
fi1 := testutil.TempFile(t, "", "coder-logging-test-*")
@ -1289,10 +1290,9 @@ func TestServer(t *testing.T) {
<-serverErr
}()
require.Eventually(t, func() bool {
line := pty.ReadLine()
return strings.HasPrefix(line, "Started HTTP listener at ")
}, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http")
// Wait for server to listen on HTTP, this is a good
// starting point for expecting logs.
_ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ")
require.Eventually(t, func() bool {
stat, err := os.Stat(fi1)

View File

@ -477,7 +477,7 @@ Expire-Date: 0
// Wait for the prompt or any output really to indicate the command has
// started and accepting input on stdin.
_ = pty.ReadRune(ctx)
_ = pty.Peek(ctx, 1)
pty.WriteLine("echo hello 'world'")
pty.ExpectMatch("hello world")

View File

@ -15,6 +15,7 @@ import (
"unicode/utf8"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"github.com/coder/coder/pty"
@ -143,11 +144,15 @@ func (p *PTY) ExpectMatch(str string) string {
timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
return p.ExpectMatchContext(timeout, str)
}
// TODO(mafredri): Rename this to ExpectMatch when refactoring.
func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string {
p.t.Helper()
var buffer bytes.Buffer
match := make(chan error, 1)
go func() {
defer close(match)
match <- func() error {
err := p.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error {
for {
r, _, err := p.runeReader.ReadRune()
if err != nil {
@ -161,80 +166,54 @@ func (p *PTY) ExpectMatch(str string) string {
return nil
}
}
}()
}()
select {
case err := <-match:
})
if err != nil {
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
return ""
}
p.logf("matched %q = %q", str, buffer.String())
return buffer.String()
case <-timeout.Done():
// Ensure goroutine is cleaned up before test exit.
_ = p.close("expect match timeout")
<-match
}
p.fatalf("match exceeded deadline", "wanted %q; got %q", str, buffer.String())
return ""
func (p *PTY) Peek(ctx context.Context, n int) []byte {
p.t.Helper()
var out []byte
err := p.doMatchWithDeadline(ctx, "Peek", func() error {
var err error
out, err = p.runeReader.Peek(n)
return err
})
if err != nil {
p.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out)
return nil
}
p.logf("peeked %d/%d bytes = %q", len(out), n, out)
return slices.Clone(out)
}
func (p *PTY) ReadRune(ctx context.Context) rune {
p.t.Helper()
// A timeout is mandatory, caller can decide by passing a context
// that times out.
if _, ok := ctx.Deadline(); !ok {
timeout := testutil.WaitMedium
p.logf("ReadRune ctx has no deadline, using %s", timeout)
var cancel context.CancelFunc
//nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
var r rune
match := make(chan error, 1)
go func() {
defer close(match)
err := p.doMatchWithDeadline(ctx, "ReadRune", func() error {
var err error
r, _, err = p.runeReader.ReadRune()
match <- err
}()
select {
case err := <-match:
return err
})
if err != nil {
p.fatalf("read error", "%v (wanted newline; got %q)", err, r)
p.fatalf("read error", "%v (wanted rune; got %q)", err, r)
return 0
}
p.logf("matched rune = %q", r)
return r
case <-ctx.Done():
// Ensure goroutine is cleaned up before test exit.
_ = p.close("read rune context done: " + ctx.Err().Error())
<-match
p.fatalf("read rune context done", "wanted rune; got nothing")
return 0
}
}
func (p *PTY) ReadLine() string {
func (p *PTY) ReadLine(ctx context.Context) string {
p.t.Helper()
// timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
timeout, cancel := context.WithCancel(context.Background())
defer cancel()
var buffer bytes.Buffer
match := make(chan error, 1)
go func() {
defer close(match)
match <- func() error {
err := p.doMatchWithDeadline(ctx, "ReadLine", func() error {
for {
r, _, err := p.runeReader.ReadRune()
if err != nil {
@ -270,24 +249,43 @@ func (p *PTY) ReadLine() string {
return err
}
}
}()
}()
select {
case err := <-match:
})
if err != nil {
p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
return ""
}
p.logf("matched newline = %q", buffer.String())
return buffer.String()
case <-timeout.Done():
}
func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error {
p.t.Helper()
// A timeout is mandatory, caller can decide by passing a context
// that times out.
if _, ok := ctx.Deadline(); !ok {
timeout := testutil.WaitMedium
p.logf("%s ctx has no deadline, using %s", name, timeout)
var cancel context.CancelFunc
//nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
match := make(chan error, 1)
go func() {
defer close(match)
match <- fn()
}()
select {
case err := <-match:
return err
case <-ctx.Done():
// Ensure goroutine is cleaned up before test exit.
_ = p.close("expect match timeout")
_ = p.close("match deadline exceeded")
<-match
p.fatalf("match exceeded deadline", "wanted newline; got %q", buffer.String())
return ""
return xerrors.Errorf("match deadline exceeded: %w", ctx.Err())
}
}

View File

@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/pty/ptytest"
"github.com/coder/coder/testutil"
)
func TestPtytest(t *testing.T) {
@ -28,14 +29,15 @@ func TestPtytest(t *testing.T) {
t.Skip("ReadLine is glitchy on windows when it comes to the final line of output it seems")
}
ctx, _ := testutil.Context(t)
pty := ptytest.New(t)
// The PTY expands these to \r\n (even on linux).
pty.Output().Write([]byte("line 1\nline 2\nline 3\nline 4\nline 5"))
require.Equal(t, "line 1", pty.ReadLine())
require.Equal(t, "line 2", pty.ReadLine())
require.Equal(t, "line 3", pty.ReadLine())
require.Equal(t, "line 4", pty.ReadLine())
require.Equal(t, "line 1", pty.ReadLine(ctx))
require.Equal(t, "line 2", pty.ReadLine(ctx))
require.Equal(t, "line 3", pty.ReadLine(ctx))
require.Equal(t, "line 4", pty.ReadLine(ctx))
require.Equal(t, "line 5", pty.ExpectMatch("5"))
})