test: Refactor ptytest to use contexts and less duplication (#5740)

This commit is contained in:
Mathias Fredriksson 2023-01-17 16:02:38 +02:00 committed by GitHub
parent 77e71f3ca4
commit 145d101512
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 136 additions and 136 deletions

View File

@ -120,13 +120,15 @@ func TestServer(t *testing.T) {
})
t.Run("BuiltinPostgresURLRaw", func(t *testing.T) {
t.Parallel()
ctx, _ := testutil.Context(t)
root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url")
pty := ptytest.New(t)
root.SetOutput(pty.Output())
err := root.Execute()
err := root.ExecuteContext(ctx)
require.NoError(t, err)
got := pty.ReadLine()
got := pty.ReadLine(ctx)
if !strings.HasPrefix(got, "postgres://") {
t.Fatalf("expected postgres URL to start with \"postgres://\", got %q", got)
}
@ -491,12 +493,12 @@ func TestServer(t *testing.T) {
// We can't use waitAccessURL as it will only return the HTTP URL.
const httpLinePrefix = "Started HTTP listener at "
pty.ExpectMatch(httpLinePrefix)
httpLine := pty.ReadLine()
httpLine := pty.ReadLine(ctx)
httpAddr := strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
require.NotEmpty(t, httpAddr)
const tlsLinePrefix = "Started TLS/HTTPS listener at "
pty.ExpectMatch(tlsLinePrefix)
tlsLine := pty.ReadLine()
tlsLine := pty.ReadLine(ctx)
tlsAddr := strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
require.NotEmpty(t, tlsAddr)
@ -617,14 +619,14 @@ func TestServer(t *testing.T) {
if c.httpListener {
const httpLinePrefix = "Started HTTP listener at "
pty.ExpectMatch(httpLinePrefix)
httpLine := pty.ReadLine()
httpLine := pty.ReadLine(ctx)
httpAddr = strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
require.NotEmpty(t, httpAddr)
}
if c.tlsListener {
const tlsLinePrefix = "Started TLS/HTTPS listener at "
pty.ExpectMatch(tlsLinePrefix)
tlsLine := pty.ReadLine()
tlsLine := pty.ReadLine(ctx)
tlsAddr = strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
require.NotEmpty(t, tlsAddr)
}
@ -1212,7 +1214,7 @@ func TestServer(t *testing.T) {
t.Run("Stackdriver", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc()
fi := testutil.TempFile(t, "", "coder-logging-test-*")
@ -1240,10 +1242,9 @@ func TestServer(t *testing.T) {
<-serverErr
}()
require.Eventually(t, func() bool {
line := pty.ReadLine()
return strings.HasPrefix(line, "Started HTTP listener at ")
}, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http")
// Wait for server to listen on HTTP, this is a good
// starting point for expecting logs.
_ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ")
require.Eventually(t, func() bool {
stat, err := os.Stat(fi)
@ -1253,7 +1254,7 @@ func TestServer(t *testing.T) {
t.Run("Multiple", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc()
fi1 := testutil.TempFile(t, "", "coder-logging-test-*")
@ -1289,10 +1290,9 @@ func TestServer(t *testing.T) {
<-serverErr
}()
require.Eventually(t, func() bool {
line := pty.ReadLine()
return strings.HasPrefix(line, "Started HTTP listener at ")
}, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http")
// Wait for server to listen on HTTP, this is a good
// starting point for expecting logs.
_ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ")
require.Eventually(t, func() bool {
stat, err := os.Stat(fi1)

View File

@ -477,7 +477,7 @@ Expire-Date: 0
// Wait for the prompt or any output really to indicate the command has
// started and accepting input on stdin.
_ = pty.ReadRune(ctx)
_ = pty.Peek(ctx, 1)
pty.WriteLine("echo hello 'world'")
pty.ExpectMatch("hello world")

View File

@ -15,6 +15,7 @@ import (
"unicode/utf8"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"github.com/coder/coder/pty"
@ -143,151 +144,148 @@ func (p *PTY) ExpectMatch(str string) string {
timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
return p.ExpectMatchContext(timeout, str)
}
// TODO(mafredri): Rename this to ExpectMatch when refactoring.
func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string {
p.t.Helper()
var buffer bytes.Buffer
match := make(chan error, 1)
go func() {
defer close(match)
match <- func() error {
for {
r, _, err := p.runeReader.ReadRune()
if err != nil {
return err
}
_, err = buffer.WriteRune(r)
if err != nil {
return err
}
if strings.Contains(buffer.String(), str) {
return nil
}
err := p.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error {
for {
r, _, err := p.runeReader.ReadRune()
if err != nil {
return err
}
_, err = buffer.WriteRune(r)
if err != nil {
return err
}
if strings.Contains(buffer.String(), str) {
return nil
}
}()
}()
select {
case err := <-match:
if err != nil {
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
return ""
}
p.logf("matched %q = %q", str, buffer.String())
return buffer.String()
case <-timeout.Done():
// Ensure goroutine is cleaned up before test exit.
_ = p.close("expect match timeout")
<-match
p.fatalf("match exceeded deadline", "wanted %q; got %q", str, buffer.String())
})
if err != nil {
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
return ""
}
p.logf("matched %q = %q", str, buffer.String())
return buffer.String()
}
func (p *PTY) Peek(ctx context.Context, n int) []byte {
p.t.Helper()
var out []byte
err := p.doMatchWithDeadline(ctx, "Peek", func() error {
var err error
out, err = p.runeReader.Peek(n)
return err
})
if err != nil {
p.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out)
return nil
}
p.logf("peeked %d/%d bytes = %q", len(out), n, out)
return slices.Clone(out)
}
func (p *PTY) ReadRune(ctx context.Context) rune {
p.t.Helper()
var r rune
err := p.doMatchWithDeadline(ctx, "ReadRune", func() error {
var err error
r, _, err = p.runeReader.ReadRune()
return err
})
if err != nil {
p.fatalf("read error", "%v (wanted rune; got %q)", err, r)
return 0
}
p.logf("matched rune = %q", r)
return r
}
func (p *PTY) ReadLine(ctx context.Context) string {
p.t.Helper()
var buffer bytes.Buffer
err := p.doMatchWithDeadline(ctx, "ReadLine", func() error {
for {
r, _, err := p.runeReader.ReadRune()
if err != nil {
return err
}
if r == '\n' {
return nil
}
if r == '\r' {
// Peek the next rune to see if it's an LF and then consume
// it.
// Unicode code points can be up to 4 bytes, but the
// ones we're looking for are only 1 byte.
b, _ := p.runeReader.Peek(1)
if len(b) == 0 {
return nil
}
r, _ = utf8.DecodeRune(b)
if r == '\n' {
_, _, err = p.runeReader.ReadRune()
if err != nil {
return err
}
}
return nil
}
_, err = buffer.WriteRune(r)
if err != nil {
return err
}
}
})
if err != nil {
p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
return ""
}
p.logf("matched newline = %q", buffer.String())
return buffer.String()
}
func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error {
p.t.Helper()
// A timeout is mandatory, caller can decide by passing a context
// that times out.
if _, ok := ctx.Deadline(); !ok {
timeout := testutil.WaitMedium
p.logf("ReadRune ctx has no deadline, using %s", timeout)
p.logf("%s ctx has no deadline, using %s", name, timeout)
var cancel context.CancelFunc
//nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
var r rune
match := make(chan error, 1)
go func() {
defer close(match)
var err error
r, _, err = p.runeReader.ReadRune()
match <- err
match <- fn()
}()
select {
case err := <-match:
if err != nil {
p.fatalf("read error", "%v (wanted newline; got %q)", err, r)
return 0
}
p.logf("matched rune = %q", r)
return r
return err
case <-ctx.Done():
// Ensure goroutine is cleaned up before test exit.
_ = p.close("read rune context done: " + ctx.Err().Error())
_ = p.close("match deadline exceeded")
<-match
p.fatalf("read rune context done", "wanted rune; got nothing")
return 0
}
}
func (p *PTY) ReadLine() string {
p.t.Helper()
// timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
timeout, cancel := context.WithCancel(context.Background())
defer cancel()
var buffer bytes.Buffer
match := make(chan error, 1)
go func() {
defer close(match)
match <- func() error {
for {
r, _, err := p.runeReader.ReadRune()
if err != nil {
return err
}
if r == '\n' {
return nil
}
if r == '\r' {
// Peek the next rune to see if it's an LF and then consume
// it.
// Unicode code points can be up to 4 bytes, but the
// ones we're looking for are only 1 byte.
b, _ := p.runeReader.Peek(1)
if len(b) == 0 {
return nil
}
r, _ = utf8.DecodeRune(b)
if r == '\n' {
_, _, err = p.runeReader.ReadRune()
if err != nil {
return err
}
}
return nil
}
_, err = buffer.WriteRune(r)
if err != nil {
return err
}
}
}()
}()
select {
case err := <-match:
if err != nil {
p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
return ""
}
p.logf("matched newline = %q", buffer.String())
return buffer.String()
case <-timeout.Done():
// Ensure goroutine is cleaned up before test exit.
_ = p.close("expect match timeout")
<-match
p.fatalf("match exceeded deadline", "wanted newline; got %q", buffer.String())
return ""
return xerrors.Errorf("match deadline exceeded: %w", ctx.Err())
}
}

View File

@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/pty/ptytest"
"github.com/coder/coder/testutil"
)
func TestPtytest(t *testing.T) {
@ -28,14 +29,15 @@ func TestPtytest(t *testing.T) {
t.Skip("ReadLine is glitchy on windows when it comes to the final line of output it seems")
}
ctx, _ := testutil.Context(t)
pty := ptytest.New(t)
// The PTY expands these to \r\n (even on linux).
pty.Output().Write([]byte("line 1\nline 2\nline 3\nline 4\nline 5"))
require.Equal(t, "line 1", pty.ReadLine())
require.Equal(t, "line 2", pty.ReadLine())
require.Equal(t, "line 3", pty.ReadLine())
require.Equal(t, "line 4", pty.ReadLine())
require.Equal(t, "line 1", pty.ReadLine(ctx))
require.Equal(t, "line 2", pty.ReadLine(ctx))
require.Equal(t, "line 3", pty.ReadLine(ctx))
require.Equal(t, "line 4", pty.ReadLine(ctx))
require.Equal(t, "line 5", pty.ExpectMatch("5"))
})