mirror of https://github.com/coder/coder.git
fix(cli/ssh): Avoid connection hang when workspace is stopped (#7201)
* fix(cli/ssh): Avoid connection hang when workspace is stopped. Two issues are addressed here: 1. We were not detecting disconnects because we were waiting for Stdin to close (a disconnect would only propagate after entering input and failing to write to the connection). 2. In other scenarios, where the connection drop is not detected, we now also watch the workspace status and drop the connection when the workspace reaches the stopped state. Fixes: https://github.com/coder/jetbrains-coder/issues/199 Refs: #6180, #6175
This commit is contained in:
parent
fff2b1dc90
commit
c2871e12aa
80
cli/ssh.go
80
cli/ssh.go
|
@ -30,6 +30,7 @@ import (
|
|||
"github.com/coder/coder/coderd/util/ptr"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/coder/retry"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -100,17 +101,82 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
|
||||
defer stopPolling()
|
||||
|
||||
// Ensure connection is closed if the context is canceled or
|
||||
// the workspace reaches the stopped state.
|
||||
//
|
||||
// Watching the stopped state is a work-around for cases
|
||||
// where the agent is not gracefully shut down and the
|
||||
// connection is left open. If, for instance, the networking
|
||||
// is stopped before the agent is shut down, the disconnect
|
||||
// will usually not propagate.
|
||||
//
|
||||
// See: https://github.com/coder/coder/issues/6180
|
||||
watchAndClose := func(closer func() error) {
|
||||
// Ensure session is ended on both context cancellation
|
||||
// and workspace stop.
|
||||
defer func() {
|
||||
_ = closer()
|
||||
}()
|
||||
|
||||
startWatchLoop:
|
||||
for {
|
||||
// (Re)connect to the coder server and watch workspace events.
|
||||
var wsWatch <-chan codersdk.Workspace
|
||||
var err error
|
||||
for r := retry.New(time.Second, 15*time.Second); r.Wait(ctx); {
|
||||
wsWatch, err = client.WatchWorkspace(ctx, workspace.ID)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case w, ok := <-wsWatch:
|
||||
if !ok {
|
||||
continue startWatchLoop
|
||||
}
|
||||
|
||||
// Transitioning to stop or delete could mean that
|
||||
// the agent will still gracefully stop. If a new
|
||||
// build is starting, there's no reason to wait for
|
||||
// the agent, it should be long gone.
|
||||
if workspace.LatestBuild.ID != w.LatestBuild.ID && w.LatestBuild.Transition == codersdk.WorkspaceTransitionStart {
|
||||
return
|
||||
}
|
||||
// Note, we only react to the stopped state here because we
|
||||
// want to give the agent a chance to gracefully shut down
|
||||
// during "stopping".
|
||||
if w.LatestBuild.Status == codersdk.WorkspaceStatusStopped {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if stdio {
|
||||
rawSSH, err := conn.SSH(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rawSSH.Close()
|
||||
go watchAndClose(rawSSH.Close)
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(inv.Stdout, rawSSH)
|
||||
// Ensure stdout copy closes in case stdin is closed
|
||||
// unexpectedly. Typically we wouldn't worry about
|
||||
// this since OpenSSH should kill the proxy command.
|
||||
defer rawSSH.Close()
|
||||
|
||||
_, _ = io.Copy(rawSSH, inv.Stdin)
|
||||
}()
|
||||
_, _ = io.Copy(rawSSH, inv.Stdin)
|
||||
_, _ = io.Copy(inv.Stdout, rawSSH)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -125,13 +191,11 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||
return err
|
||||
}
|
||||
defer sshSession.Close()
|
||||
|
||||
// Ensure context cancellation is propagated to the
|
||||
// SSH session, e.g. to cancel `Wait()` at the end.
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
go watchAndClose(func() error {
|
||||
_ = sshSession.Close()
|
||||
}()
|
||||
_ = sshClient.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
if identityAgent == "" {
|
||||
identityAgent = os.Getenv("SSH_AUTH_SOCK")
|
||||
|
|
115
cli/ssh_test.go
115
cli/ssh_test.go
|
@ -31,6 +31,7 @@ import (
|
|||
"github.com/coder/coder/cli/clitest"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/codersdk/agentsdk"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
|
@ -143,6 +144,50 @@ func TestSSH(t *testing.T) {
|
|||
cancel()
|
||||
<-cmdDone
|
||||
})
|
||||
|
||||
t.Run("ExitOnStop", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Windows doesn't seem to clean up the process, maybe #7100 will fix it")
|
||||
}
|
||||
|
||||
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
|
||||
inv, root := clitest.New(t, "ssh", workspace.Name)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
cmdDone := tGo(t, func() {
|
||||
err := inv.WithContext(ctx).Run()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
pty.ExpectMatch("Waiting")
|
||||
|
||||
agentClient := agentsdk.New(client.URL)
|
||||
agentClient.SetSessionToken(agentToken)
|
||||
agentCloser := agent.New(agent.Options{
|
||||
Client: agentClient,
|
||||
Logger: slogtest.Make(t, nil).Named("agent"),
|
||||
})
|
||||
defer func() {
|
||||
_ = agentCloser.Close()
|
||||
}()
|
||||
|
||||
// Ensure the agent is connected.
|
||||
pty.WriteLine("echo hell'o'")
|
||||
pty.ExpectMatchContext(ctx, "hello")
|
||||
|
||||
workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, database.WorkspaceTransitionStart, database.WorkspaceTransitionStop)
|
||||
|
||||
select {
|
||||
case <-cmdDone:
|
||||
case <-ctx.Done():
|
||||
require.Fail(t, "command did not exit in time")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Stdio", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
|
||||
|
@ -207,6 +252,76 @@ func TestSSH(t *testing.T) {
|
|||
|
||||
<-cmdDone
|
||||
})
|
||||
|
||||
t.Run("StdioExitOnStop", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Windows doesn't seem to clean up the process, maybe #7100 will fix it")
|
||||
}
|
||||
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
|
||||
_, _ = tGoContext(t, func(ctx context.Context) {
|
||||
// Run this async so the SSH command has to wait for
|
||||
// the build and agent to connect!
|
||||
agentClient := agentsdk.New(client.URL)
|
||||
agentClient.SetSessionToken(agentToken)
|
||||
agentCloser := agent.New(agent.Options{
|
||||
Client: agentClient,
|
||||
Logger: slogtest.Make(t, nil).Named("agent"),
|
||||
})
|
||||
<-ctx.Done()
|
||||
_ = agentCloser.Close()
|
||||
})
|
||||
|
||||
clientOutput, clientInput := io.Pipe()
|
||||
serverOutput, serverInput := io.Pipe()
|
||||
defer func() {
|
||||
for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
|
||||
_ = c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
inv.Stdin = clientOutput
|
||||
inv.Stdout = serverInput
|
||||
inv.Stderr = io.Discard
|
||||
cmdDone := tGo(t, func() {
|
||||
err := inv.WithContext(ctx).Run()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
|
||||
Reader: serverOutput,
|
||||
Writer: clientInput,
|
||||
}, "", &ssh.ClientConfig{
|
||||
// #nosec
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
sshClient := ssh.NewClient(conn, channels, requests)
|
||||
defer sshClient.Close()
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
err = session.Shell()
|
||||
require.NoError(t, err)
|
||||
|
||||
workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, database.WorkspaceTransitionStart, database.WorkspaceTransitionStop)
|
||||
|
||||
select {
|
||||
case <-cmdDone:
|
||||
case <-ctx.Done():
|
||||
require.Fail(t, "command did not exit in time")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ForwardAgent", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Test not supported on windows")
|
||||
|
|
Loading…
Reference in New Issue