fix(cli): ensure `cliui.Agent` doesn't fetch infinitely (#8446)

This commit is contained in:
Colin Adler 2023-07-12 10:21:54 -05:00 committed by GitHub
parent 14caa9b7c1
commit 1c3bfacca3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 72 additions and 44 deletions

View File

@ -15,13 +15,16 @@ var errAgentShuttingDown = xerrors.New("agent is shutting down")
type AgentOptions struct {
FetchInterval time.Duration
Fetch func(context.Context) (codersdk.WorkspaceAgent, error)
Fetch func(ctx context.Context, agentID uuid.UUID) (codersdk.WorkspaceAgent, error)
FetchLogs func(ctx context.Context, agentID uuid.UUID, after int64, follow bool) (<-chan []codersdk.WorkspaceAgentStartupLog, io.Closer, error)
Wait bool // If true, wait for the agent to be ready (startup script).
}
// Agent displays a spinning indicator that waits for a workspace agent to connect.
func Agent(ctx context.Context, writer io.Writer, opts AgentOptions) error {
func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentOptions) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
if opts.FetchInterval == 0 {
opts.FetchInterval = 500 * time.Millisecond
}
@ -47,7 +50,7 @@ func Agent(ctx context.Context, writer io.Writer, opts AgentOptions) error {
case <-ctx.Done():
return
case <-t.C:
agent, err := opts.Fetch(ctx)
agent, err := opts.Fetch(ctx, agentID)
select {
case <-fetchedAgent:
default:

View File

@ -6,6 +6,7 @@ import (
"context"
"io"
"strings"
"sync/atomic"
"testing"
"time"
@ -16,6 +17,7 @@ import (
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/util/ptr"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/testutil"
)
@ -23,10 +25,6 @@ import (
func TestAgent(t *testing.T) {
t.Parallel()
ptrTime := func(t time.Time) *time.Time {
return &t
}
for _, tc := range []struct {
name string
iter []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error
@ -47,7 +45,7 @@ func TestAgent(t *testing.T) {
},
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentConnected
agent.FirstConnectedAt = ptrTime(time.Now())
agent.FirstConnectedAt = ptr.Ref(time.Now())
close(logs)
return nil
},
@ -69,7 +67,7 @@ func TestAgent(t *testing.T) {
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentConnecting
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting
agent.StartedAt = ptrTime(time.Now())
agent.StartedAt = ptr.Ref(time.Now())
return nil
},
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
@ -78,9 +76,9 @@ func TestAgent(t *testing.T) {
},
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentConnected
agent.FirstConnectedAt = ptrTime(time.Now())
agent.FirstConnectedAt = ptr.Ref(time.Now())
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady
agent.ReadyAt = ptrTime(time.Now())
agent.ReadyAt = ptr.Ref(time.Now())
close(logs)
return nil
},
@ -102,17 +100,17 @@ func TestAgent(t *testing.T) {
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentDisconnected
agent.FirstConnectedAt = ptrTime(time.Now().Add(-1 * time.Minute))
agent.LastConnectedAt = ptrTime(time.Now().Add(-1 * time.Minute))
agent.DisconnectedAt = ptrTime(time.Now())
agent.FirstConnectedAt = ptr.Ref(time.Now().Add(-1 * time.Minute))
agent.LastConnectedAt = ptr.Ref(time.Now().Add(-1 * time.Minute))
agent.DisconnectedAt = ptr.Ref(time.Now())
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady
agent.StartedAt = ptrTime(time.Now().Add(-1 * time.Minute))
agent.ReadyAt = ptrTime(time.Now())
agent.StartedAt = ptr.Ref(time.Now().Add(-1 * time.Minute))
agent.ReadyAt = ptr.Ref(time.Now())
return nil
},
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentConnected
agent.LastConnectedAt = ptrTime(time.Now())
agent.LastConnectedAt = ptr.Ref(time.Now())
return nil
},
func(_ context.Context, _ *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
@ -136,9 +134,9 @@ func TestAgent(t *testing.T) {
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentConnected
agent.FirstConnectedAt = ptrTime(time.Now())
agent.FirstConnectedAt = ptr.Ref(time.Now())
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting
agent.StartedAt = ptrTime(time.Now())
agent.StartedAt = ptr.Ref(time.Now())
logs <- []codersdk.WorkspaceAgentStartupLog{
{
CreatedAt: time.Now(),
@ -149,7 +147,7 @@ func TestAgent(t *testing.T) {
},
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady
agent.ReadyAt = ptrTime(time.Now())
agent.ReadyAt = ptr.Ref(time.Now())
logs <- []codersdk.WorkspaceAgentStartupLog{
{
CreatedAt: time.Now(),
@ -176,10 +174,10 @@ func TestAgent(t *testing.T) {
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentConnected
agent.FirstConnectedAt = ptrTime(time.Now())
agent.StartedAt = ptrTime(time.Now())
agent.FirstConnectedAt = ptr.Ref(time.Now())
agent.StartedAt = ptr.Ref(time.Now())
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStartError
agent.ReadyAt = ptrTime(time.Now())
agent.ReadyAt = ptr.Ref(time.Now())
logs <- []codersdk.WorkspaceAgentStartupLog{
{
CreatedAt: time.Now(),
@ -222,9 +220,9 @@ func TestAgent(t *testing.T) {
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
agent.Status = codersdk.WorkspaceAgentConnected
agent.FirstConnectedAt = ptrTime(time.Now())
agent.FirstConnectedAt = ptr.Ref(time.Now())
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting
agent.StartedAt = ptrTime(time.Now())
agent.StartedAt = ptr.Ref(time.Now())
logs <- []codersdk.WorkspaceAgentStartupLog{
{
CreatedAt: time.Now(),
@ -234,7 +232,7 @@ func TestAgent(t *testing.T) {
return nil
},
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
agent.ReadyAt = ptrTime(time.Now())
agent.ReadyAt = ptr.Ref(time.Now())
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleShuttingDown
close(logs)
return nil
@ -310,7 +308,7 @@ func TestAgent(t *testing.T) {
cmd := &clibase.Cmd{
Handler: func(inv *clibase.Invocation) error {
tc.opts.Fetch = func(_ context.Context) (codersdk.WorkspaceAgent, error) {
tc.opts.Fetch = func(_ context.Context, _ uuid.UUID) (codersdk.WorkspaceAgent, error) {
var err error
if len(tc.iter) > 0 {
err = tc.iter[0](ctx, &agent, logs)
@ -321,7 +319,7 @@ func TestAgent(t *testing.T) {
tc.opts.FetchLogs = func(_ context.Context, _ uuid.UUID, _ int64, _ bool) (<-chan []codersdk.WorkspaceAgentStartupLog, io.Closer, error) {
return logs, closeFunc(func() error { return nil }), nil
}
err := cliui.Agent(inv.Context(), &buf, tc.opts)
err := cliui.Agent(inv.Context(), &buf, uuid.Nil, tc.opts)
return err
},
}
@ -350,4 +348,37 @@ func TestAgent(t *testing.T) {
}
})
}
t.Run("NotInfinite", func(t *testing.T) {
t.Parallel()
var fetchCalled uint64
cmd := &clibase.Cmd{
Handler: func(inv *clibase.Invocation) error {
buf := bytes.Buffer{}
err := cliui.Agent(inv.Context(), &buf, uuid.Nil, cliui.AgentOptions{
FetchInterval: 10 * time.Millisecond,
Fetch: func(ctx context.Context, agentID uuid.UUID) (codersdk.WorkspaceAgent, error) {
atomic.AddUint64(&fetchCalled, 1)
return codersdk.WorkspaceAgent{
Status: codersdk.WorkspaceAgentConnected,
LifecycleState: codersdk.WorkspaceAgentLifecycleReady,
}, nil
},
})
if err != nil {
return err
}
require.Never(t, func() bool {
called := atomic.LoadUint64(&fetchCalled)
return called > 5 || called == 0
}, time.Second, 100*time.Millisecond)
return nil
},
}
require.NoError(t, cmd.Invoke().Run())
})
}

View File

@ -90,11 +90,9 @@ func (r *RootCmd) portForward() *clibase.Cmd {
}
}
err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
},
Wait: false,
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
Fetch: client.WorkspaceAgent,
Wait: false,
})
if err != nil {
return xerrors.Errorf("await agent: %w", err)

View File

@ -40,11 +40,9 @@ func (r *RootCmd) speedtest() *clibase.Cmd {
return err
}
err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
},
Wait: false,
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
Fetch: client.WorkspaceAgent,
Wait: false,
})
if err != nil {
return xerrors.Errorf("await agent: %w", err)

View File

@ -175,10 +175,8 @@ func (r *RootCmd) ssh() *clibase.Cmd {
// OpenSSH passes stderr directly to the calling TTY.
// This is required in "stdio" mode so a connecting indicator can be displayed.
err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
},
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
Fetch: client.WorkspaceAgent,
FetchLogs: client.WorkspaceAgentStartupLogsAfter,
Wait: wait,
})

View File

@ -214,10 +214,10 @@ func main() {
agent.LastConnectedAt = &lastConnectedAt
},
}
err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{
err := cliui.Agent(inv.Context(), inv.Stdout, uuid.Nil, cliui.AgentOptions{
FetchInterval: 100 * time.Millisecond,
Wait: true,
Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) {
Fetch: func(_ context.Context, _ uuid.UUID) (codersdk.WorkspaceAgent, error) {
if len(fetchSteps) == 0 {
return agent, nil
}