test: Refactor ptytest to use contexts and less duplication (#5740)

This commit is contained in:
Mathias Fredriksson 2023-01-17 16:02:38 +02:00 committed by GitHub
parent 77e71f3ca4
commit 145d101512
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 136 additions and 136 deletions

View File

@ -120,13 +120,15 @@ func TestServer(t *testing.T) {
}) })
t.Run("BuiltinPostgresURLRaw", func(t *testing.T) { t.Run("BuiltinPostgresURLRaw", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, _ := testutil.Context(t)
root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url") root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url")
pty := ptytest.New(t) pty := ptytest.New(t)
root.SetOutput(pty.Output()) root.SetOutput(pty.Output())
err := root.Execute() err := root.ExecuteContext(ctx)
require.NoError(t, err) require.NoError(t, err)
got := pty.ReadLine() got := pty.ReadLine(ctx)
if !strings.HasPrefix(got, "postgres://") { if !strings.HasPrefix(got, "postgres://") {
t.Fatalf("expected postgres URL to start with \"postgres://\", got %q", got) t.Fatalf("expected postgres URL to start with \"postgres://\", got %q", got)
} }
@ -491,12 +493,12 @@ func TestServer(t *testing.T) {
// We can't use waitAccessURL as it will only return the HTTP URL. // We can't use waitAccessURL as it will only return the HTTP URL.
const httpLinePrefix = "Started HTTP listener at " const httpLinePrefix = "Started HTTP listener at "
pty.ExpectMatch(httpLinePrefix) pty.ExpectMatch(httpLinePrefix)
httpLine := pty.ReadLine() httpLine := pty.ReadLine(ctx)
httpAddr := strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix)) httpAddr := strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
require.NotEmpty(t, httpAddr) require.NotEmpty(t, httpAddr)
const tlsLinePrefix = "Started TLS/HTTPS listener at " const tlsLinePrefix = "Started TLS/HTTPS listener at "
pty.ExpectMatch(tlsLinePrefix) pty.ExpectMatch(tlsLinePrefix)
tlsLine := pty.ReadLine() tlsLine := pty.ReadLine(ctx)
tlsAddr := strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix)) tlsAddr := strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
require.NotEmpty(t, tlsAddr) require.NotEmpty(t, tlsAddr)
@ -617,14 +619,14 @@ func TestServer(t *testing.T) {
if c.httpListener { if c.httpListener {
const httpLinePrefix = "Started HTTP listener at " const httpLinePrefix = "Started HTTP listener at "
pty.ExpectMatch(httpLinePrefix) pty.ExpectMatch(httpLinePrefix)
httpLine := pty.ReadLine() httpLine := pty.ReadLine(ctx)
httpAddr = strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix)) httpAddr = strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
require.NotEmpty(t, httpAddr) require.NotEmpty(t, httpAddr)
} }
if c.tlsListener { if c.tlsListener {
const tlsLinePrefix = "Started TLS/HTTPS listener at " const tlsLinePrefix = "Started TLS/HTTPS listener at "
pty.ExpectMatch(tlsLinePrefix) pty.ExpectMatch(tlsLinePrefix)
tlsLine := pty.ReadLine() tlsLine := pty.ReadLine(ctx)
tlsAddr = strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix)) tlsAddr = strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
require.NotEmpty(t, tlsAddr) require.NotEmpty(t, tlsAddr)
} }
@ -1212,7 +1214,7 @@ func TestServer(t *testing.T) {
t.Run("Stackdriver", func(t *testing.T) { t.Run("Stackdriver", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc() defer cancelFunc()
fi := testutil.TempFile(t, "", "coder-logging-test-*") fi := testutil.TempFile(t, "", "coder-logging-test-*")
@ -1240,10 +1242,9 @@ func TestServer(t *testing.T) {
<-serverErr <-serverErr
}() }()
require.Eventually(t, func() bool { // Wait for server to listen on HTTP, this is a good
line := pty.ReadLine() // starting point for expecting logs.
return strings.HasPrefix(line, "Started HTTP listener at ") _ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ")
}, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http")
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
stat, err := os.Stat(fi) stat, err := os.Stat(fi)
@ -1253,7 +1254,7 @@ func TestServer(t *testing.T) {
t.Run("Multiple", func(t *testing.T) { t.Run("Multiple", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc() defer cancelFunc()
fi1 := testutil.TempFile(t, "", "coder-logging-test-*") fi1 := testutil.TempFile(t, "", "coder-logging-test-*")
@ -1289,10 +1290,9 @@ func TestServer(t *testing.T) {
<-serverErr <-serverErr
}() }()
require.Eventually(t, func() bool { // Wait for server to listen on HTTP, this is a good
line := pty.ReadLine() // starting point for expecting logs.
return strings.HasPrefix(line, "Started HTTP listener at ") _ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ")
}, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http")
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
stat, err := os.Stat(fi1) stat, err := os.Stat(fi1)

View File

@ -477,7 +477,7 @@ Expire-Date: 0
// Wait for the prompt or any output really to indicate the command has // Wait for the prompt or any output really to indicate the command has
// started and accepting input on stdin. // started and accepting input on stdin.
_ = pty.ReadRune(ctx) _ = pty.Peek(ctx, 1)
pty.WriteLine("echo hello 'world'") pty.WriteLine("echo hello 'world'")
pty.ExpectMatch("hello world") pty.ExpectMatch("hello world")

View File

@ -15,6 +15,7 @@ import (
"unicode/utf8" "unicode/utf8"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/pty" "github.com/coder/coder/pty"
@ -143,11 +144,15 @@ func (p *PTY) ExpectMatch(str string) string {
timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel() defer cancel()
return p.ExpectMatchContext(timeout, str)
}
// TODO(mafredri): Rename this to ExpectMatch when refactoring.
func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string {
p.t.Helper()
var buffer bytes.Buffer var buffer bytes.Buffer
match := make(chan error, 1) err := p.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error {
go func() {
defer close(match)
match <- func() error {
for { for {
r, _, err := p.runeReader.ReadRune() r, _, err := p.runeReader.ReadRune()
if err != nil { if err != nil {
@ -161,80 +166,54 @@ func (p *PTY) ExpectMatch(str string) string {
return nil return nil
} }
} }
}() })
}()
select {
case err := <-match:
if err != nil { if err != nil {
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String()) p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
return "" return ""
} }
p.logf("matched %q = %q", str, buffer.String()) p.logf("matched %q = %q", str, buffer.String())
return buffer.String() return buffer.String()
case <-timeout.Done(): }
// Ensure goroutine is cleaned up before test exit.
_ = p.close("expect match timeout")
<-match
p.fatalf("match exceeded deadline", "wanted %q; got %q", str, buffer.String()) func (p *PTY) Peek(ctx context.Context, n int) []byte {
return "" p.t.Helper()
var out []byte
err := p.doMatchWithDeadline(ctx, "Peek", func() error {
var err error
out, err = p.runeReader.Peek(n)
return err
})
if err != nil {
p.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out)
return nil
} }
p.logf("peeked %d/%d bytes = %q", len(out), n, out)
return slices.Clone(out)
} }
func (p *PTY) ReadRune(ctx context.Context) rune { func (p *PTY) ReadRune(ctx context.Context) rune {
p.t.Helper() p.t.Helper()
// A timeout is mandatory, caller can decide by passing a context
// that times out.
if _, ok := ctx.Deadline(); !ok {
timeout := testutil.WaitMedium
p.logf("ReadRune ctx has no deadline, using %s", timeout)
var cancel context.CancelFunc
//nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
var r rune var r rune
match := make(chan error, 1) err := p.doMatchWithDeadline(ctx, "ReadRune", func() error {
go func() {
defer close(match)
var err error var err error
r, _, err = p.runeReader.ReadRune() r, _, err = p.runeReader.ReadRune()
match <- err return err
}() })
select {
case err := <-match:
if err != nil { if err != nil {
p.fatalf("read error", "%v (wanted newline; got %q)", err, r) p.fatalf("read error", "%v (wanted rune; got %q)", err, r)
return 0 return 0
} }
p.logf("matched rune = %q", r) p.logf("matched rune = %q", r)
return r return r
case <-ctx.Done():
// Ensure goroutine is cleaned up before test exit.
_ = p.close("read rune context done: " + ctx.Err().Error())
<-match
p.fatalf("read rune context done", "wanted rune; got nothing")
return 0
}
} }
func (p *PTY) ReadLine() string { func (p *PTY) ReadLine(ctx context.Context) string {
p.t.Helper() p.t.Helper()
// timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
timeout, cancel := context.WithCancel(context.Background())
defer cancel()
var buffer bytes.Buffer var buffer bytes.Buffer
match := make(chan error, 1) err := p.doMatchWithDeadline(ctx, "ReadLine", func() error {
go func() {
defer close(match)
match <- func() error {
for { for {
r, _, err := p.runeReader.ReadRune() r, _, err := p.runeReader.ReadRune()
if err != nil { if err != nil {
@ -270,24 +249,43 @@ func (p *PTY) ReadLine() string {
return err return err
} }
} }
}() })
}()
select {
case err := <-match:
if err != nil { if err != nil {
p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String()) p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
return "" return ""
} }
p.logf("matched newline = %q", buffer.String()) p.logf("matched newline = %q", buffer.String())
return buffer.String() return buffer.String()
case <-timeout.Done(): }
func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error {
p.t.Helper()
// A timeout is mandatory, caller can decide by passing a context
// that times out.
if _, ok := ctx.Deadline(); !ok {
timeout := testutil.WaitMedium
p.logf("%s ctx has no deadline, using %s", name, timeout)
var cancel context.CancelFunc
//nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
match := make(chan error, 1)
go func() {
defer close(match)
match <- fn()
}()
select {
case err := <-match:
return err
case <-ctx.Done():
// Ensure goroutine is cleaned up before test exit. // Ensure goroutine is cleaned up before test exit.
_ = p.close("expect match timeout") _ = p.close("match deadline exceeded")
<-match <-match
p.fatalf("match exceeded deadline", "wanted newline; got %q", buffer.String()) return xerrors.Errorf("match deadline exceeded: %w", ctx.Err())
return ""
} }
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coder/coder/pty/ptytest" "github.com/coder/coder/pty/ptytest"
"github.com/coder/coder/testutil"
) )
func TestPtytest(t *testing.T) { func TestPtytest(t *testing.T) {
@ -28,14 +29,15 @@ func TestPtytest(t *testing.T) {
t.Skip("ReadLine is glitchy on windows when it comes to the final line of output it seems") t.Skip("ReadLine is glitchy on windows when it comes to the final line of output it seems")
} }
ctx, _ := testutil.Context(t)
pty := ptytest.New(t) pty := ptytest.New(t)
// The PTY expands these to \r\n (even on linux). // The PTY expands these to \r\n (even on linux).
pty.Output().Write([]byte("line 1\nline 2\nline 3\nline 4\nline 5")) pty.Output().Write([]byte("line 1\nline 2\nline 3\nline 4\nline 5"))
require.Equal(t, "line 1", pty.ReadLine()) require.Equal(t, "line 1", pty.ReadLine(ctx))
require.Equal(t, "line 2", pty.ReadLine()) require.Equal(t, "line 2", pty.ReadLine(ctx))
require.Equal(t, "line 3", pty.ReadLine()) require.Equal(t, "line 3", pty.ReadLine(ctx))
require.Equal(t, "line 4", pty.ReadLine()) require.Equal(t, "line 4", pty.ReadLine(ctx))
require.Equal(t, "line 5", pty.ExpectMatch("5")) require.Equal(t, "line 5", pty.ExpectMatch("5"))
}) })