mirror of https://github.com/coder/coder.git
test: Refactor ptytest to use contexts and less duplication (#5740)
This commit is contained in:
parent
77e71f3ca4
commit
145d101512
|
@ -120,13 +120,15 @@ func TestServer(t *testing.T) {
|
|||
})
|
||||
t.Run("BuiltinPostgresURLRaw", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, _ := testutil.Context(t)
|
||||
|
||||
root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url")
|
||||
pty := ptytest.New(t)
|
||||
root.SetOutput(pty.Output())
|
||||
err := root.Execute()
|
||||
err := root.ExecuteContext(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
got := pty.ReadLine()
|
||||
got := pty.ReadLine(ctx)
|
||||
if !strings.HasPrefix(got, "postgres://") {
|
||||
t.Fatalf("expected postgres URL to start with \"postgres://\", got %q", got)
|
||||
}
|
||||
|
@ -491,12 +493,12 @@ func TestServer(t *testing.T) {
|
|||
// We can't use waitAccessURL as it will only return the HTTP URL.
|
||||
const httpLinePrefix = "Started HTTP listener at "
|
||||
pty.ExpectMatch(httpLinePrefix)
|
||||
httpLine := pty.ReadLine()
|
||||
httpLine := pty.ReadLine(ctx)
|
||||
httpAddr := strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
|
||||
require.NotEmpty(t, httpAddr)
|
||||
const tlsLinePrefix = "Started TLS/HTTPS listener at "
|
||||
pty.ExpectMatch(tlsLinePrefix)
|
||||
tlsLine := pty.ReadLine()
|
||||
tlsLine := pty.ReadLine(ctx)
|
||||
tlsAddr := strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
|
||||
require.NotEmpty(t, tlsAddr)
|
||||
|
||||
|
@ -617,14 +619,14 @@ func TestServer(t *testing.T) {
|
|||
if c.httpListener {
|
||||
const httpLinePrefix = "Started HTTP listener at "
|
||||
pty.ExpectMatch(httpLinePrefix)
|
||||
httpLine := pty.ReadLine()
|
||||
httpLine := pty.ReadLine(ctx)
|
||||
httpAddr = strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
|
||||
require.NotEmpty(t, httpAddr)
|
||||
}
|
||||
if c.tlsListener {
|
||||
const tlsLinePrefix = "Started TLS/HTTPS listener at "
|
||||
pty.ExpectMatch(tlsLinePrefix)
|
||||
tlsLine := pty.ReadLine()
|
||||
tlsLine := pty.ReadLine(ctx)
|
||||
tlsAddr = strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
|
||||
require.NotEmpty(t, tlsAddr)
|
||||
}
|
||||
|
@ -1212,7 +1214,7 @@ func TestServer(t *testing.T) {
|
|||
|
||||
t.Run("Stackdriver", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancelFunc()
|
||||
|
||||
fi := testutil.TempFile(t, "", "coder-logging-test-*")
|
||||
|
@ -1240,10 +1242,9 @@ func TestServer(t *testing.T) {
|
|||
<-serverErr
|
||||
}()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
line := pty.ReadLine()
|
||||
return strings.HasPrefix(line, "Started HTTP listener at ")
|
||||
}, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http")
|
||||
// Wait for server to listen on HTTP, this is a good
|
||||
// starting point for expecting logs.
|
||||
_ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ")
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
stat, err := os.Stat(fi)
|
||||
|
@ -1253,7 +1254,7 @@ func TestServer(t *testing.T) {
|
|||
|
||||
t.Run("Multiple", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancelFunc()
|
||||
|
||||
fi1 := testutil.TempFile(t, "", "coder-logging-test-*")
|
||||
|
@ -1289,10 +1290,9 @@ func TestServer(t *testing.T) {
|
|||
<-serverErr
|
||||
}()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
line := pty.ReadLine()
|
||||
return strings.HasPrefix(line, "Started HTTP listener at ")
|
||||
}, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http")
|
||||
// Wait for server to listen on HTTP, this is a good
|
||||
// starting point for expecting logs.
|
||||
_ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ")
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
stat, err := os.Stat(fi1)
|
||||
|
|
|
@ -477,7 +477,7 @@ Expire-Date: 0
|
|||
|
||||
// Wait for the prompt or any output really to indicate the command has
|
||||
// started and accepting input on stdin.
|
||||
_ = pty.ReadRune(ctx)
|
||||
_ = pty.Peek(ctx, 1)
|
||||
|
||||
pty.WriteLine("echo hello 'world'")
|
||||
pty.ExpectMatch("hello world")
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"unicode/utf8"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/pty"
|
||||
|
@ -143,151 +144,148 @@ func (p *PTY) ExpectMatch(str string) string {
|
|||
timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
return p.ExpectMatchContext(timeout, str)
|
||||
}
|
||||
|
||||
// TODO(mafredri): Rename this to ExpectMatch when refactoring.
|
||||
func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string {
|
||||
p.t.Helper()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
match := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(match)
|
||||
match <- func() error {
|
||||
for {
|
||||
r, _, err := p.runeReader.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = buffer.WriteRune(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.Contains(buffer.String(), str) {
|
||||
return nil
|
||||
}
|
||||
err := p.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error {
|
||||
for {
|
||||
r, _, err := p.runeReader.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = buffer.WriteRune(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.Contains(buffer.String(), str) {
|
||||
return nil
|
||||
}
|
||||
}()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-match:
|
||||
if err != nil {
|
||||
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
|
||||
return ""
|
||||
}
|
||||
p.logf("matched %q = %q", str, buffer.String())
|
||||
return buffer.String()
|
||||
case <-timeout.Done():
|
||||
// Ensure goroutine is cleaned up before test exit.
|
||||
_ = p.close("expect match timeout")
|
||||
<-match
|
||||
|
||||
p.fatalf("match exceeded deadline", "wanted %q; got %q", str, buffer.String())
|
||||
})
|
||||
if err != nil {
|
||||
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
|
||||
return ""
|
||||
}
|
||||
p.logf("matched %q = %q", str, buffer.String())
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
func (p *PTY) Peek(ctx context.Context, n int) []byte {
|
||||
p.t.Helper()
|
||||
|
||||
var out []byte
|
||||
err := p.doMatchWithDeadline(ctx, "Peek", func() error {
|
||||
var err error
|
||||
out, err = p.runeReader.Peek(n)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
p.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out)
|
||||
return nil
|
||||
}
|
||||
p.logf("peeked %d/%d bytes = %q", len(out), n, out)
|
||||
return slices.Clone(out)
|
||||
}
|
||||
|
||||
func (p *PTY) ReadRune(ctx context.Context) rune {
|
||||
p.t.Helper()
|
||||
|
||||
var r rune
|
||||
err := p.doMatchWithDeadline(ctx, "ReadRune", func() error {
|
||||
var err error
|
||||
r, _, err = p.runeReader.ReadRune()
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
p.fatalf("read error", "%v (wanted rune; got %q)", err, r)
|
||||
return 0
|
||||
}
|
||||
p.logf("matched rune = %q", r)
|
||||
return r
|
||||
}
|
||||
|
||||
func (p *PTY) ReadLine(ctx context.Context) string {
|
||||
p.t.Helper()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
err := p.doMatchWithDeadline(ctx, "ReadLine", func() error {
|
||||
for {
|
||||
r, _, err := p.runeReader.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r == '\n' {
|
||||
return nil
|
||||
}
|
||||
if r == '\r' {
|
||||
// Peek the next rune to see if it's an LF and then consume
|
||||
// it.
|
||||
|
||||
// Unicode code points can be up to 4 bytes, but the
|
||||
// ones we're looking for are only 1 byte.
|
||||
b, _ := p.runeReader.Peek(1)
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
r, _ = utf8.DecodeRune(b)
|
||||
if r == '\n' {
|
||||
_, _, err = p.runeReader.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = buffer.WriteRune(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
|
||||
return ""
|
||||
}
|
||||
p.logf("matched newline = %q", buffer.String())
|
||||
return buffer.String()
|
||||
}
|
||||
|
||||
func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error {
|
||||
p.t.Helper()
|
||||
|
||||
// A timeout is mandatory, caller can decide by passing a context
|
||||
// that times out.
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
timeout := testutil.WaitMedium
|
||||
p.logf("ReadRune ctx has no deadline, using %s", timeout)
|
||||
p.logf("%s ctx has no deadline, using %s", name, timeout)
|
||||
var cancel context.CancelFunc
|
||||
//nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
var r rune
|
||||
match := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(match)
|
||||
var err error
|
||||
r, _, err = p.runeReader.ReadRune()
|
||||
match <- err
|
||||
match <- fn()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-match:
|
||||
if err != nil {
|
||||
p.fatalf("read error", "%v (wanted newline; got %q)", err, r)
|
||||
return 0
|
||||
}
|
||||
p.logf("matched rune = %q", r)
|
||||
return r
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
// Ensure goroutine is cleaned up before test exit.
|
||||
_ = p.close("read rune context done: " + ctx.Err().Error())
|
||||
_ = p.close("match deadline exceeded")
|
||||
<-match
|
||||
|
||||
p.fatalf("read rune context done", "wanted rune; got nothing")
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PTY) ReadLine() string {
|
||||
p.t.Helper()
|
||||
|
||||
// timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
timeout, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var buffer bytes.Buffer
|
||||
match := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(match)
|
||||
match <- func() error {
|
||||
for {
|
||||
r, _, err := p.runeReader.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r == '\n' {
|
||||
return nil
|
||||
}
|
||||
if r == '\r' {
|
||||
// Peek the next rune to see if it's an LF and then consume
|
||||
// it.
|
||||
|
||||
// Unicode code points can be up to 4 bytes, but the
|
||||
// ones we're looking for are only 1 byte.
|
||||
b, _ := p.runeReader.Peek(1)
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
r, _ = utf8.DecodeRune(b)
|
||||
if r == '\n' {
|
||||
_, _, err = p.runeReader.ReadRune()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = buffer.WriteRune(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-match:
|
||||
if err != nil {
|
||||
p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
|
||||
return ""
|
||||
}
|
||||
p.logf("matched newline = %q", buffer.String())
|
||||
return buffer.String()
|
||||
case <-timeout.Done():
|
||||
// Ensure goroutine is cleaned up before test exit.
|
||||
_ = p.close("expect match timeout")
|
||||
<-match
|
||||
|
||||
p.fatalf("match exceeded deadline", "wanted newline; got %q", buffer.String())
|
||||
return ""
|
||||
return xerrors.Errorf("match deadline exceeded: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestPtytest(t *testing.T) {
|
||||
|
@ -28,14 +29,15 @@ func TestPtytest(t *testing.T) {
|
|||
t.Skip("ReadLine is glitchy on windows when it comes to the final line of output it seems")
|
||||
}
|
||||
|
||||
ctx, _ := testutil.Context(t)
|
||||
pty := ptytest.New(t)
|
||||
|
||||
// The PTY expands these to \r\n (even on linux).
|
||||
pty.Output().Write([]byte("line 1\nline 2\nline 3\nline 4\nline 5"))
|
||||
require.Equal(t, "line 1", pty.ReadLine())
|
||||
require.Equal(t, "line 2", pty.ReadLine())
|
||||
require.Equal(t, "line 3", pty.ReadLine())
|
||||
require.Equal(t, "line 4", pty.ReadLine())
|
||||
require.Equal(t, "line 1", pty.ReadLine(ctx))
|
||||
require.Equal(t, "line 2", pty.ReadLine(ctx))
|
||||
require.Equal(t, "line 3", pty.ReadLine(ctx))
|
||||
require.Equal(t, "line 4", pty.ReadLine(ctx))
|
||||
require.Equal(t, "line 5", pty.ExpectMatch("5"))
|
||||
})
|
||||
|
||||
|
|
Loading…
Reference in New Issue