mirror of https://github.com/coder/coder.git
193 lines
5.2 KiB
Go
193 lines
5.2 KiB
Go
//go:build !windows
|
|
|
|
package agentssh
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"io"
|
|
"net"
|
|
"testing"
|
|
|
|
gliderssh "github.com/gliderlabs/ssh"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/spf13/afero"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/coder/coder/pty"
|
|
"github.com/coder/coder/testutil"
|
|
|
|
"cdr.dev/slog/sloggers/slogtest"
|
|
)
|
|
|
|
const longScript = `
|
|
echo "started"
|
|
sleep 30
|
|
echo "done"
|
|
`
|
|
|
|
// Test_sessionStart_orphan tests running a command that takes a long time to
|
|
// exit normally, and terminate the SSH session context early to verify that we
|
|
// return quickly and don't leave the command running as an "orphan" with no
|
|
// active SSH session.
|
|
func Test_sessionStart_orphan(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
|
defer cancel()
|
|
logger := slogtest.Make(t, nil)
|
|
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
|
|
require.NoError(t, err)
|
|
|
|
// Here we're going to call the handler directly with a faked SSH session
|
|
// that just uses io.Pipes instead of a network socket. There is a large
|
|
// variation in the time between closing the socket from the client side and
|
|
// the SSH server canceling the session Context, which would lead to a flaky
|
|
// test if we did it that way. So instead, we directly cancel the context
|
|
// in this test.
|
|
sessionCtx, sessionCancel := context.WithCancel(ctx)
|
|
toClient, fromClient, sess := newTestSession(sessionCtx)
|
|
ptyInfo := gliderssh.Pty{}
|
|
windowSize := make(chan gliderssh.Window)
|
|
close(windowSize)
|
|
// the command gets the session context so that Go will terminate it when
|
|
// the session expires.
|
|
cmd := pty.CommandContext(sessionCtx, "sh", "-c", longScript)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
|
|
// we don't really care what the error is here. In the larger scenario,
|
|
// the client has disconnected, so we can't return any error information
|
|
// to them.
|
|
_ = s.startPTYSession(sess, "ssh", cmd, ptyInfo, windowSize)
|
|
}()
|
|
|
|
readDone := make(chan struct{})
|
|
go func() {
|
|
defer close(readDone)
|
|
s := bufio.NewScanner(toClient)
|
|
assert.True(t, s.Scan())
|
|
txt := s.Text()
|
|
assert.Equal(t, "started", txt, "output corrupted")
|
|
}()
|
|
|
|
waitForChan(ctx, t, readDone, "read timeout")
|
|
// process is started, and should be sleeping for ~30 seconds
|
|
|
|
sessionCancel()
|
|
|
|
// now, we wait for the handler to complete. If it does so before the
|
|
// main test timeout, we consider this a pass. If not, it indicates
|
|
// that the server isn't properly shutting down sessions when they are
|
|
// disconnected client side, which could lead to processes hanging around
|
|
// indefinitely.
|
|
waitForChan(ctx, t, done, "handler timeout")
|
|
|
|
err = fromClient.Close()
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func waitForChan(ctx context.Context, t *testing.T, c <-chan struct{}, msg string) {
|
|
t.Helper()
|
|
select {
|
|
case <-c:
|
|
// OK!
|
|
case <-ctx.Done():
|
|
t.Fatal(msg)
|
|
}
|
|
}
|
|
|
|
type testSession struct {
|
|
ctx testSSHContext
|
|
|
|
// c2p is the client -> pty buffer
|
|
toPty *io.PipeReader
|
|
// p2c is the pty -> client buffer
|
|
fromPty *io.PipeWriter
|
|
}
|
|
|
|
type testSSHContext struct {
|
|
context.Context
|
|
}
|
|
|
|
func newTestSession(ctx context.Context) (toClient *io.PipeReader, fromClient *io.PipeWriter, s ptySession) {
|
|
toClient, fromPty := io.Pipe()
|
|
toPty, fromClient := io.Pipe()
|
|
|
|
return toClient, fromClient, &testSession{
|
|
ctx: testSSHContext{ctx},
|
|
toPty: toPty,
|
|
fromPty: fromPty,
|
|
}
|
|
}
|
|
|
|
func (s *testSession) Context() gliderssh.Context {
|
|
return s.ctx
|
|
}
|
|
|
|
func (*testSession) DisablePTYEmulation() {}
|
|
|
|
// RawCommand returns "quiet logon" so that the PTY handler doesn't attempt to
|
|
// write the message of the day, which will interfere with our tests. It writes
|
|
// the message of the day if it's a shell login (zero length RawCommand()).
|
|
func (*testSession) RawCommand() string { return "quiet logon" }
|
|
|
|
func (s *testSession) Read(p []byte) (n int, err error) {
|
|
return s.toPty.Read(p)
|
|
}
|
|
|
|
func (s *testSession) Write(p []byte) (n int, err error) {
|
|
return s.fromPty.Write(p)
|
|
}
|
|
|
|
func (testSSHContext) Lock() {
|
|
panic("not implemented")
|
|
}
|
|
|
|
func (testSSHContext) Unlock() {
|
|
panic("not implemented")
|
|
}
|
|
|
|
// User returns the username used when establishing the SSH connection.
|
|
func (testSSHContext) User() string {
|
|
panic("not implemented")
|
|
}
|
|
|
|
// SessionID returns the session hash.
|
|
func (testSSHContext) SessionID() string {
|
|
panic("not implemented")
|
|
}
|
|
|
|
// ClientVersion returns the version reported by the client.
|
|
func (testSSHContext) ClientVersion() string {
|
|
panic("not implemented")
|
|
}
|
|
|
|
// ServerVersion returns the version reported by the server.
|
|
func (testSSHContext) ServerVersion() string {
|
|
panic("not implemented")
|
|
}
|
|
|
|
// RemoteAddr returns the remote address for this connection.
|
|
func (testSSHContext) RemoteAddr() net.Addr {
|
|
panic("not implemented")
|
|
}
|
|
|
|
// LocalAddr returns the local address for this connection.
|
|
func (testSSHContext) LocalAddr() net.Addr {
|
|
panic("not implemented")
|
|
}
|
|
|
|
// Permissions returns the Permissions object used for this connection.
|
|
func (testSSHContext) Permissions() *gliderssh.Permissions {
|
|
panic("not implemented")
|
|
}
|
|
|
|
// SetValue allows you to easily write new values into the underlying context.
|
|
func (testSSHContext) SetValue(_, _ interface{}) {
|
|
panic("not implemented")
|
|
}
|