coder/agent/agentssh/agentssh_internal_test.go

193 lines
5.2 KiB
Go

//go:build !windows
package agentssh
import (
"bufio"
"context"
"io"
"net"
"testing"
gliderssh "github.com/gliderlabs/ssh"
"github.com/prometheus/client_golang/prometheus"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/pty"
"github.com/coder/coder/testutil"
"cdr.dev/slog/sloggers/slogtest"
)
// longScript is a shell script that prints "started", sleeps for 30 seconds
// and then prints "done". Test_sessionStart_orphan runs it via `sh -c` as the
// long-running command whose session is canceled before it can finish.
const longScript = `
echo "started"
sleep 30
echo "done"
`
// Test_sessionStart_orphan starts a command that takes a long time to exit on
// its own, then cancels the SSH session context while the command is still
// running. The PTY handler must return promptly rather than leaving the
// command behind as an "orphan" with no active SSH session.
func Test_sessionStart_orphan(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
	defer cancel()

	server, err := NewServer(ctx, slogtest.Make(t, nil), prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
	require.NoError(t, err)

	// The handler is invoked directly with a fake SSH session backed by
	// io.Pipes rather than a network socket. The delay between a client
	// closing its socket and the SSH server canceling the session Context
	// varies widely, which would make a socket-based test flaky, so the test
	// cancels the session context itself.
	sessionCtx, sessionCancel := context.WithCancel(ctx)
	toClient, fromClient, sess := newTestSession(sessionCtx)

	// The window-size channel is closed up front, so nothing is ever sent
	// on it.
	windowSize := make(chan gliderssh.Window)
	close(windowSize)
	var ptyInfo gliderssh.Pty

	// The command is bound to the session context so that Go terminates it
	// once the session expires.
	cmd := pty.CommandContext(sessionCtx, "sh", "-c", longScript)

	handlerDone := make(chan struct{})
	go func() {
		defer close(handlerDone)
		// The error is deliberately dropped: in the real scenario the client
		// has already disconnected, so there is nobody to report it to.
		_ = server.startPTYSession(sess, "ssh", cmd, ptyInfo, windowSize)
	}()

	// Wait for the first line of output, which proves the script is running.
	firstLineRead := make(chan struct{})
	go func() {
		defer close(firstLineRead)
		scanner := bufio.NewScanner(toClient)
		assert.True(t, scanner.Scan())
		assert.Equal(t, "started", scanner.Text(), "output corrupted")
	}()
	waitForChan(ctx, t, firstLineRead, "read timeout")

	// The process has started and should now be sleeping for ~30 seconds;
	// simulate the session ending.
	sessionCancel()

	// If the handler returns before the overall test timeout, the test
	// passes. Otherwise the server is not shutting down sessions when they
	// are disconnected on the client side, which could leave processes
	// hanging around indefinitely.
	waitForChan(ctx, t, handlerDone, "handler timeout")

	require.NoError(t, fromClient.Close())
}
// waitForChan blocks until c is closed (or yields a value) or ctx is done.
// If ctx is the one that fires, the test is failed immediately with msg.
func waitForChan(ctx context.Context, t *testing.T, c <-chan struct{}, msg string) {
	t.Helper()
	select {
	case <-ctx.Done():
		t.Fatal(msg)
	case <-c:
		// The awaited event happened before the deadline.
		return
	}
}
// testSession is a fake SSH session built on in-memory pipes. newTestSession
// returns it as a ptySession so a test can drive the PTY handler without a
// real network connection.
type testSession struct {
// ctx is the value returned by Context().
ctx testSSHContext
// toPty is the read end of the client -> pty pipe; Read delegates to it.
toPty *io.PipeReader
// fromPty is the write end of the pty -> client pipe; Write delegates to it.
fromPty *io.PipeWriter
}
// testSSHContext wraps a plain context.Context so it can be returned as a
// gliderssh.Context from testSession.Context. The embedded Context supplies
// the standard context methods; the SSH-specific methods declared on this
// type all panic with "not implemented".
type testSSHContext struct {
context.Context
}
// newTestSession builds a fake session wired up with two in-memory pipes.
//
// The returned toClient reader yields everything the session writes, the
// returned fromClient writer feeds the session's Read, and s is the session
// itself, whose Context() wraps ctx.
func newTestSession(ctx context.Context) (toClient *io.PipeReader, fromClient *io.PipeWriter, s ptySession) {
	// pty -> client direction.
	outputReader, outputWriter := io.Pipe()
	// client -> pty direction.
	inputReader, inputWriter := io.Pipe()

	session := &testSession{
		ctx:     testSSHContext{ctx},
		toPty:   inputReader,
		fromPty: outputWriter,
	}
	return outputReader, inputWriter, session
}
// Context returns the session's testSSHContext as a gliderssh.Context.
func (s *testSession) Context() gliderssh.Context {
return s.ctx
}
// DisablePTYEmulation is a no-op; the fake session has no PTY emulation to
// turn off.
func (*testSession) DisablePTYEmulation() {}
// RawCommand returns the fixed, non-empty string "quiet logon". The PTY
// handler writes the message of the day only for a shell login, i.e. when
// RawCommand() is zero length, and that extra output would interfere with
// these tests.
func (*testSession) RawCommand() string { return "quiet logon" }
// Read fills p from the client -> pty pipe, i.e. with data the test wrote to
// the fromClient writer returned by newTestSession.
func (s *testSession) Read(p []byte) (n int, err error) {
	n, err = s.toPty.Read(p)
	return n, err
}
// Write sends p down the pty -> client pipe, where the test can read it from
// the toClient reader returned by newTestSession.
func (s *testSession) Write(p []byte) (n int, err error) {
	n, err = s.fromPty.Write(p)
	return n, err
}
// Lock is an unimplemented stub that lets testSSHContext be used as a
// gliderssh.Context (see testSession.Context). It panics if called.
func (testSSHContext) Lock() {
panic("not implemented")
}
// Unlock is an unimplemented stub that lets testSSHContext be used as a
// gliderssh.Context (see testSession.Context). It panics if called.
func (testSSHContext) Unlock() {
panic("not implemented")
}
// User would return the username used when establishing the SSH connection.
// This test stub is not implemented and panics if called.
func (testSSHContext) User() string {
panic("not implemented")
}
// SessionID would return the session hash. This test stub is not implemented
// and panics if called.
func (testSSHContext) SessionID() string {
panic("not implemented")
}
// ClientVersion would return the version reported by the client. This test
// stub is not implemented and panics if called.
func (testSSHContext) ClientVersion() string {
panic("not implemented")
}
// ServerVersion would return the version reported by the server. This test
// stub is not implemented and panics if called.
func (testSSHContext) ServerVersion() string {
panic("not implemented")
}
// RemoteAddr would return the remote address for this connection. This test
// stub is not implemented and panics if called.
func (testSSHContext) RemoteAddr() net.Addr {
panic("not implemented")
}
// LocalAddr would return the local address for this connection. This test
// stub is not implemented and panics if called.
func (testSSHContext) LocalAddr() net.Addr {
panic("not implemented")
}
// Permissions would return the Permissions object used for this connection.
// This test stub is not implemented and panics if called.
func (testSSHContext) Permissions() *gliderssh.Permissions {
panic("not implemented")
}
// SetValue would write a new value into the underlying context. This test
// stub ignores both arguments, is not implemented, and panics if called.
func (testSSHContext) SetValue(_, _ interface{}) {
panic("not implemented")
}