mirror of https://github.com/coder/coder.git
fix(cli): ensure `cliui.Agent` doesn't fetch infinitely (#8446)
This commit is contained in:
parent
14caa9b7c1
commit
1c3bfacca3
|
@ -15,13 +15,16 @@ var errAgentShuttingDown = xerrors.New("agent is shutting down")
|
|||
|
||||
type AgentOptions struct {
|
||||
FetchInterval time.Duration
|
||||
Fetch func(context.Context) (codersdk.WorkspaceAgent, error)
|
||||
Fetch func(ctx context.Context, agentID uuid.UUID) (codersdk.WorkspaceAgent, error)
|
||||
FetchLogs func(ctx context.Context, agentID uuid.UUID, after int64, follow bool) (<-chan []codersdk.WorkspaceAgentStartupLog, io.Closer, error)
|
||||
Wait bool // If true, wait for the agent to be ready (startup script).
|
||||
}
|
||||
|
||||
// Agent displays a spinning indicator that waits for a workspace agent to connect.
|
||||
func Agent(ctx context.Context, writer io.Writer, opts AgentOptions) error {
|
||||
func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentOptions) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
if opts.FetchInterval == 0 {
|
||||
opts.FetchInterval = 500 * time.Millisecond
|
||||
}
|
||||
|
@ -47,7 +50,7 @@ func Agent(ctx context.Context, writer io.Writer, opts AgentOptions) error {
|
|||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
agent, err := opts.Fetch(ctx)
|
||||
agent, err := opts.Fetch(ctx, agentID)
|
||||
select {
|
||||
case <-fetchedAgent:
|
||||
default:
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -16,6 +17,7 @@ import (
|
|||
"github.com/coder/coder/cli/clibase"
|
||||
"github.com/coder/coder/cli/clitest"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/coderd/util/ptr"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
@ -23,10 +25,6 @@ import (
|
|||
func TestAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ptrTime := func(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
iter []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error
|
||||
|
@ -47,7 +45,7 @@ func TestAgent(t *testing.T) {
|
|||
},
|
||||
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
|
||||
agent.Status = codersdk.WorkspaceAgentConnected
|
||||
agent.FirstConnectedAt = ptrTime(time.Now())
|
||||
agent.FirstConnectedAt = ptr.Ref(time.Now())
|
||||
close(logs)
|
||||
return nil
|
||||
},
|
||||
|
@ -69,7 +67,7 @@ func TestAgent(t *testing.T) {
|
|||
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
|
||||
agent.Status = codersdk.WorkspaceAgentConnecting
|
||||
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting
|
||||
agent.StartedAt = ptrTime(time.Now())
|
||||
agent.StartedAt = ptr.Ref(time.Now())
|
||||
return nil
|
||||
},
|
||||
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
|
||||
|
@ -78,9 +76,9 @@ func TestAgent(t *testing.T) {
|
|||
},
|
||||
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
|
||||
agent.Status = codersdk.WorkspaceAgentConnected
|
||||
agent.FirstConnectedAt = ptrTime(time.Now())
|
||||
agent.FirstConnectedAt = ptr.Ref(time.Now())
|
||||
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady
|
||||
agent.ReadyAt = ptrTime(time.Now())
|
||||
agent.ReadyAt = ptr.Ref(time.Now())
|
||||
close(logs)
|
||||
return nil
|
||||
},
|
||||
|
@ -102,17 +100,17 @@ func TestAgent(t *testing.T) {
|
|||
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
|
||||
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
|
||||
agent.Status = codersdk.WorkspaceAgentDisconnected
|
||||
agent.FirstConnectedAt = ptrTime(time.Now().Add(-1 * time.Minute))
|
||||
agent.LastConnectedAt = ptrTime(time.Now().Add(-1 * time.Minute))
|
||||
agent.DisconnectedAt = ptrTime(time.Now())
|
||||
agent.FirstConnectedAt = ptr.Ref(time.Now().Add(-1 * time.Minute))
|
||||
agent.LastConnectedAt = ptr.Ref(time.Now().Add(-1 * time.Minute))
|
||||
agent.DisconnectedAt = ptr.Ref(time.Now())
|
||||
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady
|
||||
agent.StartedAt = ptrTime(time.Now().Add(-1 * time.Minute))
|
||||
agent.ReadyAt = ptrTime(time.Now())
|
||||
agent.StartedAt = ptr.Ref(time.Now().Add(-1 * time.Minute))
|
||||
agent.ReadyAt = ptr.Ref(time.Now())
|
||||
return nil
|
||||
},
|
||||
func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error {
|
||||
agent.Status = codersdk.WorkspaceAgentConnected
|
||||
agent.LastConnectedAt = ptrTime(time.Now())
|
||||
agent.LastConnectedAt = ptr.Ref(time.Now())
|
||||
return nil
|
||||
},
|
||||
func(_ context.Context, _ *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
|
||||
|
@ -136,9 +134,9 @@ func TestAgent(t *testing.T) {
|
|||
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
|
||||
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
|
||||
agent.Status = codersdk.WorkspaceAgentConnected
|
||||
agent.FirstConnectedAt = ptrTime(time.Now())
|
||||
agent.FirstConnectedAt = ptr.Ref(time.Now())
|
||||
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting
|
||||
agent.StartedAt = ptrTime(time.Now())
|
||||
agent.StartedAt = ptr.Ref(time.Now())
|
||||
logs <- []codersdk.WorkspaceAgentStartupLog{
|
||||
{
|
||||
CreatedAt: time.Now(),
|
||||
|
@ -149,7 +147,7 @@ func TestAgent(t *testing.T) {
|
|||
},
|
||||
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
|
||||
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady
|
||||
agent.ReadyAt = ptrTime(time.Now())
|
||||
agent.ReadyAt = ptr.Ref(time.Now())
|
||||
logs <- []codersdk.WorkspaceAgentStartupLog{
|
||||
{
|
||||
CreatedAt: time.Now(),
|
||||
|
@ -176,10 +174,10 @@ func TestAgent(t *testing.T) {
|
|||
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
|
||||
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
|
||||
agent.Status = codersdk.WorkspaceAgentConnected
|
||||
agent.FirstConnectedAt = ptrTime(time.Now())
|
||||
agent.StartedAt = ptrTime(time.Now())
|
||||
agent.FirstConnectedAt = ptr.Ref(time.Now())
|
||||
agent.StartedAt = ptr.Ref(time.Now())
|
||||
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStartError
|
||||
agent.ReadyAt = ptrTime(time.Now())
|
||||
agent.ReadyAt = ptr.Ref(time.Now())
|
||||
logs <- []codersdk.WorkspaceAgentStartupLog{
|
||||
{
|
||||
CreatedAt: time.Now(),
|
||||
|
@ -222,9 +220,9 @@ func TestAgent(t *testing.T) {
|
|||
iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{
|
||||
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
|
||||
agent.Status = codersdk.WorkspaceAgentConnected
|
||||
agent.FirstConnectedAt = ptrTime(time.Now())
|
||||
agent.FirstConnectedAt = ptr.Ref(time.Now())
|
||||
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting
|
||||
agent.StartedAt = ptrTime(time.Now())
|
||||
agent.StartedAt = ptr.Ref(time.Now())
|
||||
logs <- []codersdk.WorkspaceAgentStartupLog{
|
||||
{
|
||||
CreatedAt: time.Now(),
|
||||
|
@ -234,7 +232,7 @@ func TestAgent(t *testing.T) {
|
|||
return nil
|
||||
},
|
||||
func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error {
|
||||
agent.ReadyAt = ptrTime(time.Now())
|
||||
agent.ReadyAt = ptr.Ref(time.Now())
|
||||
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleShuttingDown
|
||||
close(logs)
|
||||
return nil
|
||||
|
@ -310,7 +308,7 @@ func TestAgent(t *testing.T) {
|
|||
|
||||
cmd := &clibase.Cmd{
|
||||
Handler: func(inv *clibase.Invocation) error {
|
||||
tc.opts.Fetch = func(_ context.Context) (codersdk.WorkspaceAgent, error) {
|
||||
tc.opts.Fetch = func(_ context.Context, _ uuid.UUID) (codersdk.WorkspaceAgent, error) {
|
||||
var err error
|
||||
if len(tc.iter) > 0 {
|
||||
err = tc.iter[0](ctx, &agent, logs)
|
||||
|
@ -321,7 +319,7 @@ func TestAgent(t *testing.T) {
|
|||
tc.opts.FetchLogs = func(_ context.Context, _ uuid.UUID, _ int64, _ bool) (<-chan []codersdk.WorkspaceAgentStartupLog, io.Closer, error) {
|
||||
return logs, closeFunc(func() error { return nil }), nil
|
||||
}
|
||||
err := cliui.Agent(inv.Context(), &buf, tc.opts)
|
||||
err := cliui.Agent(inv.Context(), &buf, uuid.Nil, tc.opts)
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
@ -350,4 +348,37 @@ func TestAgent(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("NotInfinite", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var fetchCalled uint64
|
||||
|
||||
cmd := &clibase.Cmd{
|
||||
Handler: func(inv *clibase.Invocation) error {
|
||||
buf := bytes.Buffer{}
|
||||
err := cliui.Agent(inv.Context(), &buf, uuid.Nil, cliui.AgentOptions{
|
||||
FetchInterval: 10 * time.Millisecond,
|
||||
Fetch: func(ctx context.Context, agentID uuid.UUID) (codersdk.WorkspaceAgent, error) {
|
||||
atomic.AddUint64(&fetchCalled, 1)
|
||||
|
||||
return codersdk.WorkspaceAgent{
|
||||
Status: codersdk.WorkspaceAgentConnected,
|
||||
LifecycleState: codersdk.WorkspaceAgentLifecycleReady,
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
require.Never(t, func() bool {
|
||||
called := atomic.LoadUint64(&fetchCalled)
|
||||
return called > 5 || called == 0
|
||||
}, time.Second, 100*time.Millisecond)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, cmd.Invoke().Run())
|
||||
})
|
||||
}
|
||||
|
|
|
@ -90,11 +90,9 @@ func (r *RootCmd) portForward() *clibase.Cmd {
|
|||
}
|
||||
}
|
||||
|
||||
err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
|
||||
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
|
||||
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
|
||||
},
|
||||
Wait: false,
|
||||
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
|
||||
Fetch: client.WorkspaceAgent,
|
||||
Wait: false,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("await agent: %w", err)
|
||||
|
|
|
@ -40,11 +40,9 @@ func (r *RootCmd) speedtest() *clibase.Cmd {
|
|||
return err
|
||||
}
|
||||
|
||||
err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
|
||||
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
|
||||
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
|
||||
},
|
||||
Wait: false,
|
||||
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
|
||||
Fetch: client.WorkspaceAgent,
|
||||
Wait: false,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("await agent: %w", err)
|
||||
|
|
|
@ -175,10 +175,8 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||
|
||||
// OpenSSH passes stderr directly to the calling TTY.
|
||||
// This is required in "stdio" mode so a connecting indicator can be displayed.
|
||||
err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{
|
||||
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
|
||||
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
|
||||
},
|
||||
err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{
|
||||
Fetch: client.WorkspaceAgent,
|
||||
FetchLogs: client.WorkspaceAgentStartupLogsAfter,
|
||||
Wait: wait,
|
||||
})
|
||||
|
|
|
@ -214,10 +214,10 @@ func main() {
|
|||
agent.LastConnectedAt = &lastConnectedAt
|
||||
},
|
||||
}
|
||||
err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{
|
||||
err := cliui.Agent(inv.Context(), inv.Stdout, uuid.Nil, cliui.AgentOptions{
|
||||
FetchInterval: 100 * time.Millisecond,
|
||||
Wait: true,
|
||||
Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) {
|
||||
Fetch: func(_ context.Context, _ uuid.UUID) (codersdk.WorkspaceAgent, error) {
|
||||
if len(fetchSteps) == 0 {
|
||||
return agent, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue