From 145d101512142822e3c66740676cae94ae77c95d Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 17 Jan 2023 16:02:38 +0200 Subject: [PATCH] test: Refactor ptytest to use contexts and less duplication (#5740) --- cli/server_test.go | 32 ++--- cli/ssh_test.go | 2 +- pty/ptytest/ptytest.go | 228 ++++++++++++++++++------------------ pty/ptytest/ptytest_test.go | 10 +- 4 files changed, 136 insertions(+), 136 deletions(-) diff --git a/cli/server_test.go b/cli/server_test.go index 1683ea6647..5e77ea113c 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -120,13 +120,15 @@ func TestServer(t *testing.T) { }) t.Run("BuiltinPostgresURLRaw", func(t *testing.T) { t.Parallel() + ctx, _ := testutil.Context(t) + root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url") pty := ptytest.New(t) root.SetOutput(pty.Output()) - err := root.Execute() + err := root.ExecuteContext(ctx) require.NoError(t, err) - got := pty.ReadLine() + got := pty.ReadLine(ctx) if !strings.HasPrefix(got, "postgres://") { t.Fatalf("expected postgres URL to start with \"postgres://\", got %q", got) } @@ -491,12 +493,12 @@ func TestServer(t *testing.T) { // We can't use waitAccessURL as it will only return the HTTP URL. const httpLinePrefix = "Started HTTP listener at " pty.ExpectMatch(httpLinePrefix) - httpLine := pty.ReadLine() + httpLine := pty.ReadLine(ctx) httpAddr := strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix)) require.NotEmpty(t, httpAddr) const tlsLinePrefix = "Started TLS/HTTPS listener at " pty.ExpectMatch(tlsLinePrefix) - tlsLine := pty.ReadLine() + tlsLine := pty.ReadLine(ctx) tlsAddr := strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix)) require.NotEmpty(t, tlsAddr) @@ -617,14 +619,14 @@ func TestServer(t *testing.T) { if c.httpListener { const httpLinePrefix = "Started HTTP listener at " pty.ExpectMatch(httpLinePrefix) - httpLine := pty.ReadLine() + httpLine := pty.ReadLine(ctx) httpAddr = strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix)) require.NotEmpty(t, httpAddr) } if c.tlsListener { const tlsLinePrefix = "Started TLS/HTTPS listener at " pty.ExpectMatch(tlsLinePrefix) - tlsLine := pty.ReadLine() + tlsLine := pty.ReadLine(ctx) tlsAddr = strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix)) require.NotEmpty(t, tlsAddr) } @@ -1212,7 +1214,7 @@ func TestServer(t *testing.T) { t.Run("Stackdriver", func(t *testing.T) { t.Parallel() - ctx, cancelFunc := context.WithCancel(context.Background()) + ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong) defer cancelFunc() fi := testutil.TempFile(t, "", "coder-logging-test-*") @@ -1240,10 +1242,9 @@ func TestServer(t *testing.T) { <-serverErr }() - require.Eventually(t, func() bool { - line := pty.ReadLine() - return strings.HasPrefix(line, "Started HTTP listener at ") - }, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http") + // Wait for server to listen on HTTP, this is a good + // starting point for expecting logs. + _ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ") require.Eventually(t, func() bool { stat, err := os.Stat(fi) @@ -1253,7 +1254,7 @@ func TestServer(t *testing.T) { t.Run("Multiple", func(t *testing.T) { t.Parallel() - ctx, cancelFunc := context.WithCancel(context.Background()) + ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong) defer cancelFunc() fi1 := testutil.TempFile(t, "", "coder-logging-test-*") @@ -1289,10 +1290,9 @@ func TestServer(t *testing.T) { <-serverErr }() - require.Eventually(t, func() bool { - line := pty.ReadLine() - return strings.HasPrefix(line, "Started HTTP listener at ") - }, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http") + // Wait for server to listen on HTTP, this is a good + // starting point for expecting logs. + _ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ") require.Eventually(t, func() bool { stat, err := os.Stat(fi1) diff --git a/cli/ssh_test.go b/cli/ssh_test.go index b5e3f725f6..dc09e0d84f 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -477,7 +477,7 @@ Expire-Date: 0 // Wait for the prompt or any output really to indicate the command has // started and accepting input on stdin. - _ = pty.ReadRune(ctx) + _ = pty.Peek(ctx, 1) pty.WriteLine("echo hello 'world'") pty.ExpectMatch("hello world") diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 6f0bae64fc..35525038f7 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -15,6 +15,7 @@ import ( "unicode/utf8" "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" "golang.org/x/xerrors" "github.com/coder/coder/pty" @@ -143,151 +144,148 @@ func (p *PTY) ExpectMatch(str string) string { timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) defer cancel() + return p.ExpectMatchContext(timeout, str) +} + +// TODO(mafredri): Rename this to ExpectMatch when refactoring. +func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string { + p.t.Helper() + var buffer bytes.Buffer - match := make(chan error, 1) - go func() { - defer close(match) - match <- func() error { - for { - r, _, err := p.runeReader.ReadRune() - if err != nil { - return err - } - _, err = buffer.WriteRune(r) - if err != nil { - return err - } - if strings.Contains(buffer.String(), str) { - return nil - } + err := p.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error { + for { + r, _, err := p.runeReader.ReadRune() + if err != nil { + return err + } + _, err = buffer.WriteRune(r) + if err != nil { + return err + } + if strings.Contains(buffer.String(), str) { + return nil } - }() - }() - - select { - case err := <-match: - if err != nil { - p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String()) - return "" } - p.logf("matched %q = %q", str, buffer.String()) - return buffer.String() - case <-timeout.Done(): - // Ensure goroutine is cleaned up before test exit. - _ = p.close("expect match timeout") - <-match - - p.fatalf("match exceeded deadline", "wanted %q; got %q", str, buffer.String()) + }) + if err != nil { + p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String()) return "" } + p.logf("matched %q = %q", str, buffer.String()) + return buffer.String() +} + +func (p *PTY) Peek(ctx context.Context, n int) []byte { + p.t.Helper() + + var out []byte + err := p.doMatchWithDeadline(ctx, "Peek", func() error { + var err error + out, err = p.runeReader.Peek(n) + return err + }) + if err != nil { + p.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out) + return nil + } + p.logf("peeked %d/%d bytes = %q", len(out), n, out) + return slices.Clone(out) } func (p *PTY) ReadRune(ctx context.Context) rune { p.t.Helper() + var r rune + err := p.doMatchWithDeadline(ctx, "ReadRune", func() error { + var err error + r, _, err = p.runeReader.ReadRune() + return err + }) + if err != nil { + p.fatalf("read error", "%v (wanted rune; got %q)", err, r) + return 0 + } + p.logf("matched rune = %q", r) + return r +} + +func (p *PTY) ReadLine(ctx context.Context) string { + p.t.Helper() + + var buffer bytes.Buffer + err := p.doMatchWithDeadline(ctx, "ReadLine", func() error { + for { + r, _, err := p.runeReader.ReadRune() + if err != nil { + return err + } + if r == '\n' { + return nil + } + if r == '\r' { + // Peek the next rune to see if it's an LF and then consume + // it. + + // Unicode code points can be up to 4 bytes, but the + // ones we're looking for are only 1 byte. + b, _ := p.runeReader.Peek(1) + if len(b) == 0 { + return nil + } + + r, _ = utf8.DecodeRune(b) + if r == '\n' { + _, _, err = p.runeReader.ReadRune() + if err != nil { + return err + } + } + + return nil + } + + _, err = buffer.WriteRune(r) + if err != nil { + return err + } + } + }) + if err != nil { + p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String()) + return "" + } + p.logf("matched newline = %q", buffer.String()) + return buffer.String() +} + +func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error { + p.t.Helper() + // A timeout is mandatory, caller can decide by passing a context // that times out. if _, ok := ctx.Deadline(); !ok { timeout := testutil.WaitMedium - p.logf("ReadRune ctx has no deadline, using %s", timeout) + p.logf("%s ctx has no deadline, using %s", name, timeout) var cancel context.CancelFunc //nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*. ctx, cancel = context.WithTimeout(ctx, timeout) defer cancel() } - var r rune match := make(chan error, 1) go func() { defer close(match) - var err error - r, _, err = p.runeReader.ReadRune() - match <- err + match <- fn() }() - select { case err := <-match: - if err != nil { - p.fatalf("read error", "%v (wanted newline; got %q)", err, r) - return 0 - } - p.logf("matched rune = %q", r) - return r + return err case <-ctx.Done(): // Ensure goroutine is cleaned up before test exit. - _ = p.close("read rune context done: " + ctx.Err().Error()) + _ = p.close("match deadline exceeded") <-match - p.fatalf("read rune context done", "wanted rune; got nothing") - return 0 - } -} - -func (p *PTY) ReadLine() string { - p.t.Helper() - - // timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) - timeout, cancel := context.WithCancel(context.Background()) - defer cancel() - - var buffer bytes.Buffer - match := make(chan error, 1) - go func() { - defer close(match) - match <- func() error { - for { - r, _, err := p.runeReader.ReadRune() - if err != nil { - return err - } - if r == '\n' { - return nil - } - if r == '\r' { - // Peek the next rune to see if it's an LF and then consume - // it. - - // Unicode code points can be up to 4 bytes, but the - // ones we're looking for are only 1 byte. - b, _ := p.runeReader.Peek(1) - if len(b) == 0 { - return nil - } - - r, _ = utf8.DecodeRune(b) - if r == '\n' { - _, _, err = p.runeReader.ReadRune() - if err != nil { - return err - } - } - - return nil - } - - _, err = buffer.WriteRune(r) - if err != nil { - return err - } - } - }() - }() - - select { - case err := <-match: - if err != nil { - p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String()) - return "" - } - p.logf("matched newline = %q", buffer.String()) - return buffer.String() - case <-timeout.Done(): - // Ensure goroutine is cleaned up before test exit. - _ = p.close("expect match timeout") - <-match - - p.fatalf("match exceeded deadline", "wanted newline; got %q", buffer.String()) - return "" + return xerrors.Errorf("match deadline exceeded: %w", ctx.Err()) } } diff --git a/pty/ptytest/ptytest_test.go b/pty/ptytest/ptytest_test.go index 42699c77f0..68ee9f9815 100644 --- a/pty/ptytest/ptytest_test.go +++ b/pty/ptytest/ptytest_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/coder/coder/pty/ptytest" + "github.com/coder/coder/testutil" ) func TestPtytest(t *testing.T) { @@ -28,14 +29,15 @@ func TestPtytest(t *testing.T) { t.Skip("ReadLine is glitchy on windows when it comes to the final line of output it seems") } + ctx, _ := testutil.Context(t) pty := ptytest.New(t) // The PTY expands these to \r\n (even on linux). pty.Output().Write([]byte("line 1\nline 2\nline 3\nline 4\nline 5")) - require.Equal(t, "line 1", pty.ReadLine()) - require.Equal(t, "line 2", pty.ReadLine()) - require.Equal(t, "line 3", pty.ReadLine()) - require.Equal(t, "line 4", pty.ReadLine()) + require.Equal(t, "line 1", pty.ReadLine(ctx)) + require.Equal(t, "line 2", pty.ReadLine(ctx)) + require.Equal(t, "line 3", pty.ReadLine(ctx)) + require.Equal(t, "line 4", pty.ReadLine(ctx)) require.Equal(t, "line 5", pty.ExpectMatch("5")) })