diff --git a/cli/ssh_test.go b/cli/ssh_test.go index ceb34cc7f6..b5e3f725f6 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -475,6 +475,10 @@ Expire-Date: 0 // real error from being printed. t.Cleanup(cancel) + // Wait for the prompt or any output really to indicate the command has + // started and accepting input on stdin. + _ = pty.ReadRune(ctx) + pty.WriteLine("echo hello 'world'") pty.ExpectMatch("hello world") diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 40d7018118..6f0bae64fc 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -182,6 +182,47 @@ func (p *PTY) ExpectMatch(str string) string { } } +func (p *PTY) ReadRune(ctx context.Context) rune { + p.t.Helper() + + // A timeout is mandatory, caller can decide by passing a context + // that times out. + if _, ok := ctx.Deadline(); !ok { + timeout := testutil.WaitMedium + p.logf("ReadRune ctx has no deadline, using %s", timeout) + var cancel context.CancelFunc + //nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*. + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + + var r rune + match := make(chan error, 1) + go func() { + defer close(match) + var err error + r, _, err = p.runeReader.ReadRune() + match <- err + }() + + select { + case err := <-match: + if err != nil { + p.fatalf("read error", "%v (wanted newline; got %q)", err, r) + return 0 + } + p.logf("matched rune = %q", r) + return r + case <-ctx.Done(): + // Ensure goroutine is cleaned up before test exit. + _ = p.close("read rune context done: " + ctx.Err().Error()) + <-match + + p.fatalf("read rune context done", "wanted rune; got nothing") + return 0 + } +} + func (p *PTY) ReadLine() string { p.t.Helper()