diff --git a/agent/agent_test.go b/agent/agent_test.go index fbb1432b72..49f6dba293 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -76,6 +76,11 @@ func TestAgent(t *testing.T) { session.Stdin = ptty.Input() err = session.Start(command) require.NoError(t, err) + caret := "$" + if runtime.GOOS == "windows" { + caret = ">" + } + ptty.ExpectMatch(caret) ptty.WriteLine("echo test") ptty.ExpectMatch("test") ptty.WriteLine("exit") diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 730f2a244e..e465a35d06 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -3,6 +3,7 @@ package ptytest import ( "bufio" "bytes" + "context" "io" "os" "os/exec" @@ -10,6 +11,7 @@ import ( "runtime" "strings" "testing" + "time" "unicode/utf8" "github.com/stretchr/testify/require" @@ -76,6 +78,19 @@ func (p *PTY) ExpectMatch(str string) string { var buffer bytes.Buffer multiWriter := io.MultiWriter(&buffer, p.outputWriter) runeWriter := bufio.NewWriterSize(multiWriter, utf8.UTFMax) + complete, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + go func() { + timer := time.NewTimer(10 * time.Second) + defer timer.Stop() + select { + case <-complete.Done(): + return + case <-timer.C: + } + _ = p.Close() + p.t.Errorf("match exceeded deadline: wanted %q; got %q", str, buffer.String()) + }() for { var r rune r, _, err := p.runeReader.ReadRune()