test: Refactor ptytest to use contexts and less duplication (#5740)

This commit is contained in:
Mathias Fredriksson 2023-01-17 16:02:38 +02:00 committed by GitHub
parent 77e71f3ca4
commit 145d101512
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 136 additions and 136 deletions

View File

@ -120,13 +120,15 @@ func TestServer(t *testing.T) {
}) })
t.Run("BuiltinPostgresURLRaw", func(t *testing.T) { t.Run("BuiltinPostgresURLRaw", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, _ := testutil.Context(t)
root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url") root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url")
pty := ptytest.New(t) pty := ptytest.New(t)
root.SetOutput(pty.Output()) root.SetOutput(pty.Output())
err := root.Execute() err := root.ExecuteContext(ctx)
require.NoError(t, err) require.NoError(t, err)
got := pty.ReadLine() got := pty.ReadLine(ctx)
if !strings.HasPrefix(got, "postgres://") { if !strings.HasPrefix(got, "postgres://") {
t.Fatalf("expected postgres URL to start with \"postgres://\", got %q", got) t.Fatalf("expected postgres URL to start with \"postgres://\", got %q", got)
} }
@ -491,12 +493,12 @@ func TestServer(t *testing.T) {
// We can't use waitAccessURL as it will only return the HTTP URL. // We can't use waitAccessURL as it will only return the HTTP URL.
const httpLinePrefix = "Started HTTP listener at " const httpLinePrefix = "Started HTTP listener at "
pty.ExpectMatch(httpLinePrefix) pty.ExpectMatch(httpLinePrefix)
httpLine := pty.ReadLine() httpLine := pty.ReadLine(ctx)
httpAddr := strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix)) httpAddr := strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
require.NotEmpty(t, httpAddr) require.NotEmpty(t, httpAddr)
const tlsLinePrefix = "Started TLS/HTTPS listener at " const tlsLinePrefix = "Started TLS/HTTPS listener at "
pty.ExpectMatch(tlsLinePrefix) pty.ExpectMatch(tlsLinePrefix)
tlsLine := pty.ReadLine() tlsLine := pty.ReadLine(ctx)
tlsAddr := strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix)) tlsAddr := strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
require.NotEmpty(t, tlsAddr) require.NotEmpty(t, tlsAddr)
@ -617,14 +619,14 @@ func TestServer(t *testing.T) {
if c.httpListener { if c.httpListener {
const httpLinePrefix = "Started HTTP listener at " const httpLinePrefix = "Started HTTP listener at "
pty.ExpectMatch(httpLinePrefix) pty.ExpectMatch(httpLinePrefix)
httpLine := pty.ReadLine() httpLine := pty.ReadLine(ctx)
httpAddr = strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix)) httpAddr = strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
require.NotEmpty(t, httpAddr) require.NotEmpty(t, httpAddr)
} }
if c.tlsListener { if c.tlsListener {
const tlsLinePrefix = "Started TLS/HTTPS listener at " const tlsLinePrefix = "Started TLS/HTTPS listener at "
pty.ExpectMatch(tlsLinePrefix) pty.ExpectMatch(tlsLinePrefix)
tlsLine := pty.ReadLine() tlsLine := pty.ReadLine(ctx)
tlsAddr = strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix)) tlsAddr = strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
require.NotEmpty(t, tlsAddr) require.NotEmpty(t, tlsAddr)
} }
@ -1212,7 +1214,7 @@ func TestServer(t *testing.T) {
t.Run("Stackdriver", func(t *testing.T) { t.Run("Stackdriver", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc() defer cancelFunc()
fi := testutil.TempFile(t, "", "coder-logging-test-*") fi := testutil.TempFile(t, "", "coder-logging-test-*")
@ -1240,10 +1242,9 @@ func TestServer(t *testing.T) {
<-serverErr <-serverErr
}() }()
require.Eventually(t, func() bool { // Wait for server to listen on HTTP, this is a good
line := pty.ReadLine() // starting point for expecting logs.
return strings.HasPrefix(line, "Started HTTP listener at ") _ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ")
}, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http")
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
stat, err := os.Stat(fi) stat, err := os.Stat(fi)
@ -1253,7 +1254,7 @@ func TestServer(t *testing.T) {
t.Run("Multiple", func(t *testing.T) { t.Run("Multiple", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc() defer cancelFunc()
fi1 := testutil.TempFile(t, "", "coder-logging-test-*") fi1 := testutil.TempFile(t, "", "coder-logging-test-*")
@ -1289,10 +1290,9 @@ func TestServer(t *testing.T) {
<-serverErr <-serverErr
}() }()
require.Eventually(t, func() bool { // Wait for server to listen on HTTP, this is a good
line := pty.ReadLine() // starting point for expecting logs.
return strings.HasPrefix(line, "Started HTTP listener at ") _ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ")
}, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http")
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
stat, err := os.Stat(fi1) stat, err := os.Stat(fi1)

View File

@ -477,7 +477,7 @@ Expire-Date: 0
// Wait for the prompt or any output really to indicate the command has // Wait for the prompt or any output really to indicate the command has
// started and accepting input on stdin. // started and accepting input on stdin.
_ = pty.ReadRune(ctx) _ = pty.Peek(ctx, 1)
pty.WriteLine("echo hello 'world'") pty.WriteLine("echo hello 'world'")
pty.ExpectMatch("hello world") pty.ExpectMatch("hello world")

View File

@ -15,6 +15,7 @@ import (
"unicode/utf8" "unicode/utf8"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/pty" "github.com/coder/coder/pty"
@ -143,151 +144,148 @@ func (p *PTY) ExpectMatch(str string) string {
timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel() defer cancel()
return p.ExpectMatchContext(timeout, str)
}
// TODO(mafredri): Rename this to ExpectMatch when refactoring.
func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string {
p.t.Helper()
var buffer bytes.Buffer var buffer bytes.Buffer
match := make(chan error, 1) err := p.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error {
go func() { for {
defer close(match) r, _, err := p.runeReader.ReadRune()
match <- func() error { if err != nil {
for { return err
r, _, err := p.runeReader.ReadRune() }
if err != nil { _, err = buffer.WriteRune(r)
return err if err != nil {
} return err
_, err = buffer.WriteRune(r) }
if err != nil { if strings.Contains(buffer.String(), str) {
return err return nil
}
if strings.Contains(buffer.String(), str) {
return nil
}
} }
}()
}()
select {
case err := <-match:
if err != nil {
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
return ""
} }
p.logf("matched %q = %q", str, buffer.String()) })
return buffer.String() if err != nil {
case <-timeout.Done(): p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
// Ensure goroutine is cleaned up before test exit.
_ = p.close("expect match timeout")
<-match
p.fatalf("match exceeded deadline", "wanted %q; got %q", str, buffer.String())
return "" return ""
} }
p.logf("matched %q = %q", str, buffer.String())
return buffer.String()
}
func (p *PTY) Peek(ctx context.Context, n int) []byte {
p.t.Helper()
var out []byte
err := p.doMatchWithDeadline(ctx, "Peek", func() error {
var err error
out, err = p.runeReader.Peek(n)
return err
})
if err != nil {
p.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out)
return nil
}
p.logf("peeked %d/%d bytes = %q", len(out), n, out)
return slices.Clone(out)
} }
func (p *PTY) ReadRune(ctx context.Context) rune { func (p *PTY) ReadRune(ctx context.Context) rune {
p.t.Helper() p.t.Helper()
var r rune
err := p.doMatchWithDeadline(ctx, "ReadRune", func() error {
var err error
r, _, err = p.runeReader.ReadRune()
return err
})
if err != nil {
p.fatalf("read error", "%v (wanted rune; got %q)", err, r)
return 0
}
p.logf("matched rune = %q", r)
return r
}
func (p *PTY) ReadLine(ctx context.Context) string {
p.t.Helper()
var buffer bytes.Buffer
err := p.doMatchWithDeadline(ctx, "ReadLine", func() error {
for {
r, _, err := p.runeReader.ReadRune()
if err != nil {
return err
}
if r == '\n' {
return nil
}
if r == '\r' {
// Peek the next rune to see if it's an LF and then consume
// it.
// Unicode code points can be up to 4 bytes, but the
// ones we're looking for are only 1 byte.
b, _ := p.runeReader.Peek(1)
if len(b) == 0 {
return nil
}
r, _ = utf8.DecodeRune(b)
if r == '\n' {
_, _, err = p.runeReader.ReadRune()
if err != nil {
return err
}
}
return nil
}
_, err = buffer.WriteRune(r)
if err != nil {
return err
}
}
})
if err != nil {
p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
return ""
}
p.logf("matched newline = %q", buffer.String())
return buffer.String()
}
func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error {
p.t.Helper()
// A timeout is mandatory, caller can decide by passing a context // A timeout is mandatory, caller can decide by passing a context
// that times out. // that times out.
if _, ok := ctx.Deadline(); !ok { if _, ok := ctx.Deadline(); !ok {
timeout := testutil.WaitMedium timeout := testutil.WaitMedium
p.logf("ReadRune ctx has no deadline, using %s", timeout) p.logf("%s ctx has no deadline, using %s", name, timeout)
var cancel context.CancelFunc var cancel context.CancelFunc
//nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*. //nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
ctx, cancel = context.WithTimeout(ctx, timeout) ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel() defer cancel()
} }
var r rune
match := make(chan error, 1) match := make(chan error, 1)
go func() { go func() {
defer close(match) defer close(match)
var err error match <- fn()
r, _, err = p.runeReader.ReadRune()
match <- err
}() }()
select { select {
case err := <-match: case err := <-match:
if err != nil { return err
p.fatalf("read error", "%v (wanted newline; got %q)", err, r)
return 0
}
p.logf("matched rune = %q", r)
return r
case <-ctx.Done(): case <-ctx.Done():
// Ensure goroutine is cleaned up before test exit. // Ensure goroutine is cleaned up before test exit.
_ = p.close("read rune context done: " + ctx.Err().Error()) _ = p.close("match deadline exceeded")
<-match <-match
p.fatalf("read rune context done", "wanted rune; got nothing") return xerrors.Errorf("match deadline exceeded: %w", ctx.Err())
return 0
}
}
func (p *PTY) ReadLine() string {
p.t.Helper()
// timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
timeout, cancel := context.WithCancel(context.Background())
defer cancel()
var buffer bytes.Buffer
match := make(chan error, 1)
go func() {
defer close(match)
match <- func() error {
for {
r, _, err := p.runeReader.ReadRune()
if err != nil {
return err
}
if r == '\n' {
return nil
}
if r == '\r' {
// Peek the next rune to see if it's an LF and then consume
// it.
// Unicode code points can be up to 4 bytes, but the
// ones we're looking for are only 1 byte.
b, _ := p.runeReader.Peek(1)
if len(b) == 0 {
return nil
}
r, _ = utf8.DecodeRune(b)
if r == '\n' {
_, _, err = p.runeReader.ReadRune()
if err != nil {
return err
}
}
return nil
}
_, err = buffer.WriteRune(r)
if err != nil {
return err
}
}
}()
}()
select {
case err := <-match:
if err != nil {
p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
return ""
}
p.logf("matched newline = %q", buffer.String())
return buffer.String()
case <-timeout.Done():
// Ensure goroutine is cleaned up before test exit.
_ = p.close("expect match timeout")
<-match
p.fatalf("match exceeded deadline", "wanted newline; got %q", buffer.String())
return ""
} }
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coder/coder/pty/ptytest" "github.com/coder/coder/pty/ptytest"
"github.com/coder/coder/testutil"
) )
func TestPtytest(t *testing.T) { func TestPtytest(t *testing.T) {
@ -28,14 +29,15 @@ func TestPtytest(t *testing.T) {
t.Skip("ReadLine is glitchy on windows when it comes to the final line of output it seems") t.Skip("ReadLine is glitchy on windows when it comes to the final line of output it seems")
} }
ctx, _ := testutil.Context(t)
pty := ptytest.New(t) pty := ptytest.New(t)
// The PTY expands these to \r\n (even on linux). // The PTY expands these to \r\n (even on linux).
pty.Output().Write([]byte("line 1\nline 2\nline 3\nline 4\nline 5")) pty.Output().Write([]byte("line 1\nline 2\nline 3\nline 4\nline 5"))
require.Equal(t, "line 1", pty.ReadLine()) require.Equal(t, "line 1", pty.ReadLine(ctx))
require.Equal(t, "line 2", pty.ReadLine()) require.Equal(t, "line 2", pty.ReadLine(ctx))
require.Equal(t, "line 3", pty.ReadLine()) require.Equal(t, "line 3", pty.ReadLine(ctx))
require.Equal(t, "line 4", pty.ReadLine()) require.Equal(t, "line 4", pty.ReadLine(ctx))
require.Equal(t, "line 5", pty.ExpectMatch("5")) require.Equal(t, "line 5", pty.ExpectMatch("5"))
}) })