mirror of https://github.com/coder/coder.git
chore: replace wsconncache with a single tailnet (#8176)
This commit is contained in:
parent
0a37dd20d6
commit
c47b78c44b
|
@ -64,6 +64,7 @@ type Options struct {
|
|||
SSHMaxTimeout time.Duration
|
||||
TailnetListenPort uint16
|
||||
Subsystem codersdk.AgentSubsystem
|
||||
Addresses []netip.Prefix
|
||||
|
||||
PrometheusRegistry *prometheus.Registry
|
||||
}
|
||||
|
@ -132,6 +133,7 @@ func New(options Options) Agent {
|
|||
connStatsChan: make(chan *agentsdk.Stats, 1),
|
||||
sshMaxTimeout: options.SSHMaxTimeout,
|
||||
subsystem: options.Subsystem,
|
||||
addresses: options.Addresses,
|
||||
|
||||
prometheusRegistry: prometheusRegistry,
|
||||
metrics: newAgentMetrics(prometheusRegistry),
|
||||
|
@ -177,6 +179,7 @@ type agent struct {
|
|||
lifecycleStates []agentsdk.PostLifecycleRequest
|
||||
|
||||
network *tailnet.Conn
|
||||
addresses []netip.Prefix
|
||||
connStatsChan chan *agentsdk.Stats
|
||||
latestStat atomic.Pointer[agentsdk.Stats]
|
||||
|
||||
|
@ -545,6 +548,10 @@ func (a *agent) run(ctx context.Context) error {
|
|||
}
|
||||
a.logger.Info(ctx, "fetched manifest", slog.F("manifest", manifest))
|
||||
|
||||
if manifest.AgentID == uuid.Nil {
|
||||
return xerrors.New("nil agentID returned by manifest")
|
||||
}
|
||||
|
||||
// Expand the directory and send it back to coderd so external
|
||||
// applications that rely on the directory can use it.
|
||||
//
|
||||
|
@ -630,7 +637,7 @@ func (a *agent) run(ctx context.Context) error {
|
|||
network := a.network
|
||||
a.closeMutex.Unlock()
|
||||
if network == nil {
|
||||
network, err = a.createTailnet(ctx, manifest.DERPMap, manifest.DisableDirectConnections)
|
||||
network, err = a.createTailnet(ctx, manifest.AgentID, manifest.DERPMap, manifest.DisableDirectConnections)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create tailnet: %w", err)
|
||||
}
|
||||
|
@ -648,6 +655,11 @@ func (a *agent) run(ctx context.Context) error {
|
|||
|
||||
a.startReportingConnectionStats(ctx)
|
||||
} else {
|
||||
// Update the wireguard IPs if the agent ID changed.
|
||||
err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID))
|
||||
if err != nil {
|
||||
a.logger.Error(ctx, "update tailnet addresses", slog.Error(err))
|
||||
}
|
||||
// Update the DERP map and allow/disallow direct connections.
|
||||
network.SetDERPMap(manifest.DERPMap)
|
||||
network.SetBlockEndpoints(manifest.DisableDirectConnections)
|
||||
|
@ -661,6 +673,20 @@ func (a *agent) run(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix {
|
||||
if len(a.addresses) == 0 {
|
||||
return []netip.Prefix{
|
||||
// This is the IP that should be used primarily.
|
||||
netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128),
|
||||
// We also listen on the legacy codersdk.WorkspaceAgentIP. This
|
||||
// allows for a transition away from wsconncache.
|
||||
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
|
||||
}
|
||||
}
|
||||
|
||||
return a.addresses
|
||||
}
|
||||
|
||||
func (a *agent) trackConnGoroutine(fn func()) error {
|
||||
a.closeMutex.Lock()
|
||||
defer a.closeMutex.Unlock()
|
||||
|
@ -675,9 +701,9 @@ func (a *agent) trackConnGoroutine(fn func()) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) {
|
||||
func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) {
|
||||
network, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
|
||||
Addresses: a.wireguardAddresses(agentID),
|
||||
DERPMap: derpMap,
|
||||
Logger: a.logger.Named("tailnet"),
|
||||
ListenPort: a.tailnetListenPort,
|
||||
|
|
|
@ -35,7 +35,6 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/net/speedtest"
|
||||
|
@ -45,6 +44,7 @@ import (
|
|||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/agent/agentssh"
|
||||
"github.com/coder/coder/agent/agenttest"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/codersdk/agentsdk"
|
||||
|
@ -67,7 +67,7 @@ func TestAgent_Stats_SSH(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
//nolint:dogsled
|
||||
conn, _, stats, _, _ := setupAgent(t, &client{}, 0)
|
||||
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
|
@ -100,7 +100,7 @@ func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
//nolint:dogsled
|
||||
conn, _, stats, _, _ := setupAgent(t, &client{}, 0)
|
||||
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
|
||||
ptyConn, err := conn.ReconnectingPTY(ctx, uuid.New(), 128, 128, "/bin/bash")
|
||||
require.NoError(t, err)
|
||||
|
@ -130,7 +130,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, &client{}, 0)
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
@ -157,7 +157,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
//nolint:dogsled
|
||||
conn, _, stats, _, _ := setupAgent(t, &client{}, 0)
|
||||
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
@ -425,20 +425,19 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
//nolint:dogsled // Allow the blank identifiers.
|
||||
conn, client, _, _, _ := setupAgent(t, &client{}, 0)
|
||||
conn, client, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
// Set new banner func and wait for the agent to call it to update the
|
||||
// banner.
|
||||
ready := make(chan struct{}, 2)
|
||||
client.mu.Lock()
|
||||
client.getServiceBanner = func() (codersdk.ServiceBannerConfig, error) {
|
||||
client.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) {
|
||||
select {
|
||||
case ready <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return test.banner, nil
|
||||
}
|
||||
client.mu.Unlock()
|
||||
})
|
||||
<-ready
|
||||
<-ready // Wait for two updates to ensure the value has propagated.
|
||||
|
||||
|
@ -542,7 +541,7 @@ func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, &client{}, 0)
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
@ -592,7 +591,7 @@ func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, &client{}, 0)
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
@ -922,7 +921,7 @@ func TestAgent_SFTP(t *testing.T) {
|
|||
home = "/" + strings.ReplaceAll(home, "\\", "/")
|
||||
}
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, &client{}, 0)
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
@ -954,7 +953,7 @@ func TestAgent_SCP(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, &client{}, 0)
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
@ -1062,16 +1061,15 @@ func TestAgent_StartupScript(t *testing.T) {
|
|||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
client := &client{
|
||||
t: t,
|
||||
agentID: uuid.New(),
|
||||
manifest: agentsdk.Manifest{
|
||||
client := agenttest.NewClient(t,
|
||||
uuid.New(),
|
||||
agentsdk.Manifest{
|
||||
StartupScript: command,
|
||||
DERPMap: &tailcfg.DERPMap{},
|
||||
},
|
||||
statsChan: make(chan *agentsdk.Stats),
|
||||
coordinator: tailnet.NewCoordinator(logger),
|
||||
}
|
||||
make(chan *agentsdk.Stats),
|
||||
tailnet.NewCoordinator(logger),
|
||||
)
|
||||
closer := agent.New(agent.Options{
|
||||
Client: client,
|
||||
Filesystem: afero.NewMemMapFs(),
|
||||
|
@ -1082,36 +1080,35 @@ func TestAgent_StartupScript(t *testing.T) {
|
|||
_ = closer.Close()
|
||||
})
|
||||
assert.Eventually(t, func() bool {
|
||||
got := client.getLifecycleStates()
|
||||
got := client.GetLifecycleStates()
|
||||
return len(got) > 0 && got[len(got)-1] == codersdk.WorkspaceAgentLifecycleReady
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
require.Len(t, client.getStartupLogs(), 1)
|
||||
require.Equal(t, output, client.getStartupLogs()[0].Output)
|
||||
require.Len(t, client.GetStartupLogs(), 1)
|
||||
require.Equal(t, output, client.GetStartupLogs()[0].Output)
|
||||
})
|
||||
// This ensures that even when coderd sends back that the startup
|
||||
// script has written too many lines it will still succeed!
|
||||
t.Run("OverflowsAndSkips", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
client := &client{
|
||||
t: t,
|
||||
agentID: uuid.New(),
|
||||
manifest: agentsdk.Manifest{
|
||||
client := agenttest.NewClient(t,
|
||||
uuid.New(),
|
||||
agentsdk.Manifest{
|
||||
StartupScript: command,
|
||||
DERPMap: &tailcfg.DERPMap{},
|
||||
},
|
||||
patchWorkspaceLogs: func() error {
|
||||
resp := httptest.NewRecorder()
|
||||
httpapi.Write(context.Background(), resp, http.StatusRequestEntityTooLarge, codersdk.Response{
|
||||
Message: "Too many lines!",
|
||||
})
|
||||
res := resp.Result()
|
||||
defer res.Body.Close()
|
||||
return codersdk.ReadBodyAsError(res)
|
||||
},
|
||||
statsChan: make(chan *agentsdk.Stats),
|
||||
coordinator: tailnet.NewCoordinator(logger),
|
||||
make(chan *agentsdk.Stats, 50),
|
||||
tailnet.NewCoordinator(logger),
|
||||
)
|
||||
client.PatchWorkspaceLogs = func() error {
|
||||
resp := httptest.NewRecorder()
|
||||
httpapi.Write(context.Background(), resp, http.StatusRequestEntityTooLarge, codersdk.Response{
|
||||
Message: "Too many lines!",
|
||||
})
|
||||
res := resp.Result()
|
||||
defer res.Body.Close()
|
||||
return codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
closer := agent.New(agent.Options{
|
||||
Client: client,
|
||||
|
@ -1123,10 +1120,10 @@ func TestAgent_StartupScript(t *testing.T) {
|
|||
_ = closer.Close()
|
||||
})
|
||||
assert.Eventually(t, func() bool {
|
||||
got := client.getLifecycleStates()
|
||||
got := client.GetLifecycleStates()
|
||||
return len(got) > 0 && got[len(got)-1] == codersdk.WorkspaceAgentLifecycleReady
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
require.Len(t, client.getStartupLogs(), 0)
|
||||
require.Len(t, client.GetStartupLogs(), 0)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1138,28 +1135,26 @@ func TestAgent_Metadata(t *testing.T) {
|
|||
t.Run("Once", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
//nolint:dogsled
|
||||
_, client, _, _, _ := setupAgent(t, &client{
|
||||
manifest: agentsdk.Manifest{
|
||||
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
|
||||
{
|
||||
Key: "greeting",
|
||||
Interval: 0,
|
||||
Script: echoHello,
|
||||
},
|
||||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||||
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
|
||||
{
|
||||
Key: "greeting",
|
||||
Interval: 0,
|
||||
Script: echoHello,
|
||||
},
|
||||
},
|
||||
}, 0)
|
||||
|
||||
var gotMd map[string]agentsdk.PostMetadataRequest
|
||||
require.Eventually(t, func() bool {
|
||||
gotMd = client.getMetadata()
|
||||
gotMd = client.GetMetadata()
|
||||
return len(gotMd) == 1
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
collectedAt := gotMd["greeting"].CollectedAt
|
||||
|
||||
require.Never(t, func() bool {
|
||||
gotMd = client.getMetadata()
|
||||
gotMd = client.GetMetadata()
|
||||
if len(gotMd) != 1 {
|
||||
panic("unexpected number of metadata")
|
||||
}
|
||||
|
@ -1170,22 +1165,20 @@ func TestAgent_Metadata(t *testing.T) {
|
|||
t.Run("Many", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
//nolint:dogsled
|
||||
_, client, _, _, _ := setupAgent(t, &client{
|
||||
manifest: agentsdk.Manifest{
|
||||
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
|
||||
{
|
||||
Key: "greeting",
|
||||
Interval: 1,
|
||||
Timeout: 100,
|
||||
Script: echoHello,
|
||||
},
|
||||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||||
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
|
||||
{
|
||||
Key: "greeting",
|
||||
Interval: 1,
|
||||
Timeout: 100,
|
||||
Script: echoHello,
|
||||
},
|
||||
},
|
||||
}, 0)
|
||||
|
||||
var gotMd map[string]agentsdk.PostMetadataRequest
|
||||
require.Eventually(t, func() bool {
|
||||
gotMd = client.getMetadata()
|
||||
gotMd = client.GetMetadata()
|
||||
return len(gotMd) == 1
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
|
@ -1195,7 +1188,7 @@ func TestAgent_Metadata(t *testing.T) {
|
|||
}
|
||||
|
||||
if !assert.Eventually(t, func() bool {
|
||||
gotMd = client.getMetadata()
|
||||
gotMd = client.GetMetadata()
|
||||
return gotMd["greeting"].CollectedAt.After(collectedAt1)
|
||||
}, testutil.WaitShort, testutil.IntervalMedium) {
|
||||
t.Fatalf("expected metadata to be collected again")
|
||||
|
@ -1221,29 +1214,27 @@ func TestAgentMetadata_Timing(t *testing.T) {
|
|||
script = "echo hello | tee -a " + greetingPath
|
||||
)
|
||||
//nolint:dogsled
|
||||
_, client, _, _, _ := setupAgent(t, &client{
|
||||
manifest: agentsdk.Manifest{
|
||||
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
|
||||
{
|
||||
Key: "greeting",
|
||||
Interval: reportInterval,
|
||||
Script: script,
|
||||
},
|
||||
{
|
||||
Key: "bad",
|
||||
Interval: reportInterval,
|
||||
Script: "exit 1",
|
||||
},
|
||||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||||
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
|
||||
{
|
||||
Key: "greeting",
|
||||
Interval: reportInterval,
|
||||
Script: script,
|
||||
},
|
||||
{
|
||||
Key: "bad",
|
||||
Interval: reportInterval,
|
||||
Script: "exit 1",
|
||||
},
|
||||
},
|
||||
}, 0)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return len(client.getMetadata()) == 2
|
||||
return len(client.GetMetadata()) == 2
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
for start := time.Now(); time.Since(start) < testutil.WaitMedium; time.Sleep(testutil.IntervalMedium) {
|
||||
md := client.getMetadata()
|
||||
md := client.GetMetadata()
|
||||
require.Len(t, md, 2, "got: %+v", md)
|
||||
|
||||
require.Equal(t, "hello\n", md["greeting"].Value)
|
||||
|
@ -1285,11 +1276,9 @@ func TestAgent_Lifecycle(t *testing.T) {
|
|||
t.Run("StartTimeout", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, client, _, _, _ := setupAgent(t, &client{
|
||||
manifest: agentsdk.Manifest{
|
||||
StartupScript: "sleep 3",
|
||||
StartupScriptTimeout: time.Nanosecond,
|
||||
},
|
||||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||||
StartupScript: "sleep 3",
|
||||
StartupScriptTimeout: time.Nanosecond,
|
||||
}, 0)
|
||||
|
||||
want := []codersdk.WorkspaceAgentLifecycle{
|
||||
|
@ -1299,7 +1288,7 @@ func TestAgent_Lifecycle(t *testing.T) {
|
|||
|
||||
var got []codersdk.WorkspaceAgentLifecycle
|
||||
assert.Eventually(t, func() bool {
|
||||
got = client.getLifecycleStates()
|
||||
got = client.GetLifecycleStates()
|
||||
return slices.Contains(got, want[len(want)-1])
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
|
@ -1309,11 +1298,9 @@ func TestAgent_Lifecycle(t *testing.T) {
|
|||
t.Run("StartError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, client, _, _, _ := setupAgent(t, &client{
|
||||
manifest: agentsdk.Manifest{
|
||||
StartupScript: "false",
|
||||
StartupScriptTimeout: 30 * time.Second,
|
||||
},
|
||||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||||
StartupScript: "false",
|
||||
StartupScriptTimeout: 30 * time.Second,
|
||||
}, 0)
|
||||
|
||||
want := []codersdk.WorkspaceAgentLifecycle{
|
||||
|
@ -1323,7 +1310,7 @@ func TestAgent_Lifecycle(t *testing.T) {
|
|||
|
||||
var got []codersdk.WorkspaceAgentLifecycle
|
||||
assert.Eventually(t, func() bool {
|
||||
got = client.getLifecycleStates()
|
||||
got = client.GetLifecycleStates()
|
||||
return slices.Contains(got, want[len(want)-1])
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
|
@ -1333,11 +1320,9 @@ func TestAgent_Lifecycle(t *testing.T) {
|
|||
t.Run("Ready", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, client, _, _, _ := setupAgent(t, &client{
|
||||
manifest: agentsdk.Manifest{
|
||||
StartupScript: "true",
|
||||
StartupScriptTimeout: 30 * time.Second,
|
||||
},
|
||||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||||
StartupScript: "true",
|
||||
StartupScriptTimeout: 30 * time.Second,
|
||||
}, 0)
|
||||
|
||||
want := []codersdk.WorkspaceAgentLifecycle{
|
||||
|
@ -1347,7 +1332,7 @@ func TestAgent_Lifecycle(t *testing.T) {
|
|||
|
||||
var got []codersdk.WorkspaceAgentLifecycle
|
||||
assert.Eventually(t, func() bool {
|
||||
got = client.getLifecycleStates()
|
||||
got = client.GetLifecycleStates()
|
||||
return len(got) > 0 && got[len(got)-1] == want[len(want)-1]
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
|
@ -1357,15 +1342,13 @@ func TestAgent_Lifecycle(t *testing.T) {
|
|||
t.Run("ShuttingDown", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, client, _, _, closer := setupAgent(t, &client{
|
||||
manifest: agentsdk.Manifest{
|
||||
ShutdownScript: "sleep 3",
|
||||
StartupScriptTimeout: 30 * time.Second,
|
||||
},
|
||||
_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
|
||||
ShutdownScript: "sleep 3",
|
||||
StartupScriptTimeout: 30 * time.Second,
|
||||
}, 0)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return slices.Contains(client.getLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
|
||||
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
// Start close asynchronously so that we an inspect the state.
|
||||
|
@ -1387,7 +1370,7 @@ func TestAgent_Lifecycle(t *testing.T) {
|
|||
|
||||
var got []codersdk.WorkspaceAgentLifecycle
|
||||
assert.Eventually(t, func() bool {
|
||||
got = client.getLifecycleStates()
|
||||
got = client.GetLifecycleStates()
|
||||
return slices.Contains(got, want[len(want)-1])
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
|
@ -1397,15 +1380,13 @@ func TestAgent_Lifecycle(t *testing.T) {
|
|||
t.Run("ShutdownTimeout", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, client, _, _, closer := setupAgent(t, &client{
|
||||
manifest: agentsdk.Manifest{
|
||||
ShutdownScript: "sleep 3",
|
||||
ShutdownScriptTimeout: time.Nanosecond,
|
||||
},
|
||||
_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
|
||||
ShutdownScript: "sleep 3",
|
||||
ShutdownScriptTimeout: time.Nanosecond,
|
||||
}, 0)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return slices.Contains(client.getLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
|
||||
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
// Start close asynchronously so that we an inspect the state.
|
||||
|
@ -1428,7 +1409,7 @@ func TestAgent_Lifecycle(t *testing.T) {
|
|||
|
||||
var got []codersdk.WorkspaceAgentLifecycle
|
||||
assert.Eventually(t, func() bool {
|
||||
got = client.getLifecycleStates()
|
||||
got = client.GetLifecycleStates()
|
||||
return slices.Contains(got, want[len(want)-1])
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
|
@ -1438,15 +1419,13 @@ func TestAgent_Lifecycle(t *testing.T) {
|
|||
t.Run("ShutdownError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, client, _, _, closer := setupAgent(t, &client{
|
||||
manifest: agentsdk.Manifest{
|
||||
ShutdownScript: "false",
|
||||
ShutdownScriptTimeout: 30 * time.Second,
|
||||
},
|
||||
_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
|
||||
ShutdownScript: "false",
|
||||
ShutdownScriptTimeout: 30 * time.Second,
|
||||
}, 0)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return slices.Contains(client.getLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
|
||||
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
// Start close asynchronously so that we an inspect the state.
|
||||
|
@ -1469,7 +1448,7 @@ func TestAgent_Lifecycle(t *testing.T) {
|
|||
|
||||
var got []codersdk.WorkspaceAgentLifecycle
|
||||
assert.Eventually(t, func() bool {
|
||||
got = client.getLifecycleStates()
|
||||
got = client.GetLifecycleStates()
|
||||
return slices.Contains(got, want[len(want)-1])
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
|
@ -1480,17 +1459,18 @@ func TestAgent_Lifecycle(t *testing.T) {
|
|||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
expected := "this-is-shutdown"
|
||||
client := &client{
|
||||
t: t,
|
||||
agentID: uuid.New(),
|
||||
manifest: agentsdk.Manifest{
|
||||
DERPMap: tailnettest.RunDERPAndSTUN(t),
|
||||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||||
|
||||
client := agenttest.NewClient(t,
|
||||
uuid.New(),
|
||||
agentsdk.Manifest{
|
||||
DERPMap: derpMap,
|
||||
StartupScript: "echo 1",
|
||||
ShutdownScript: "echo " + expected,
|
||||
},
|
||||
statsChan: make(chan *agentsdk.Stats),
|
||||
coordinator: tailnet.NewCoordinator(logger),
|
||||
}
|
||||
make(chan *agentsdk.Stats, 50),
|
||||
tailnet.NewCoordinator(logger),
|
||||
)
|
||||
|
||||
fs := afero.NewMemMapFs()
|
||||
agent := agent.New(agent.Options{
|
||||
|
@ -1536,71 +1516,63 @@ func TestAgent_Startup(t *testing.T) {
|
|||
t.Run("EmptyDirectory", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, client, _, _, _ := setupAgent(t, &client{
|
||||
manifest: agentsdk.Manifest{
|
||||
StartupScript: "true",
|
||||
StartupScriptTimeout: 30 * time.Second,
|
||||
Directory: "",
|
||||
},
|
||||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||||
StartupScript: "true",
|
||||
StartupScriptTimeout: 30 * time.Second,
|
||||
Directory: "",
|
||||
}, 0)
|
||||
assert.Eventually(t, func() bool {
|
||||
return client.getStartup().Version != ""
|
||||
return client.GetStartup().Version != ""
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
require.Equal(t, "", client.getStartup().ExpandedDirectory)
|
||||
require.Equal(t, "", client.GetStartup().ExpandedDirectory)
|
||||
})
|
||||
|
||||
t.Run("HomeDirectory", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, client, _, _, _ := setupAgent(t, &client{
|
||||
manifest: agentsdk.Manifest{
|
||||
StartupScript: "true",
|
||||
StartupScriptTimeout: 30 * time.Second,
|
||||
Directory: "~",
|
||||
},
|
||||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||||
StartupScript: "true",
|
||||
StartupScriptTimeout: 30 * time.Second,
|
||||
Directory: "~",
|
||||
}, 0)
|
||||
assert.Eventually(t, func() bool {
|
||||
return client.getStartup().Version != ""
|
||||
return client.GetStartup().Version != ""
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
homeDir, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, homeDir, client.getStartup().ExpandedDirectory)
|
||||
require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory)
|
||||
})
|
||||
|
||||
t.Run("NotAbsoluteDirectory", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, client, _, _, _ := setupAgent(t, &client{
|
||||
manifest: agentsdk.Manifest{
|
||||
StartupScript: "true",
|
||||
StartupScriptTimeout: 30 * time.Second,
|
||||
Directory: "coder/coder",
|
||||
},
|
||||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||||
StartupScript: "true",
|
||||
StartupScriptTimeout: 30 * time.Second,
|
||||
Directory: "coder/coder",
|
||||
}, 0)
|
||||
assert.Eventually(t, func() bool {
|
||||
return client.getStartup().Version != ""
|
||||
return client.GetStartup().Version != ""
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
homeDir, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.getStartup().ExpandedDirectory)
|
||||
require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.GetStartup().ExpandedDirectory)
|
||||
})
|
||||
|
||||
t.Run("HomeEnvironmentVariable", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, client, _, _, _ := setupAgent(t, &client{
|
||||
manifest: agentsdk.Manifest{
|
||||
StartupScript: "true",
|
||||
StartupScriptTimeout: 30 * time.Second,
|
||||
Directory: "$HOME",
|
||||
},
|
||||
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||||
StartupScript: "true",
|
||||
StartupScriptTimeout: 30 * time.Second,
|
||||
Directory: "$HOME",
|
||||
}, 0)
|
||||
assert.Eventually(t, func() bool {
|
||||
return client.getStartup().Version != ""
|
||||
return client.GetStartup().Version != ""
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
homeDir, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, homeDir, client.getStartup().ExpandedDirectory)
|
||||
require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1617,7 +1589,7 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, &client{}, 0)
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
id := uuid.New()
|
||||
netConn, err := conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash")
|
||||
require.NoError(t, err)
|
||||
|
@ -1719,7 +1691,7 @@ func TestAgent_Dial(t *testing.T) {
|
|||
}()
|
||||
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, &client{}, 0)
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
require.True(t, conn.AwaitReachable(context.Background()))
|
||||
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
@ -1739,12 +1711,10 @@ func TestAgent_Speedtest(t *testing.T) {
|
|||
t.Skip("This test is relatively flakey because of Tailscale's speedtest code...")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
derpMap := tailnettest.RunDERPAndSTUN(t)
|
||||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, &client{
|
||||
manifest: agentsdk.Manifest{
|
||||
DERPMap: derpMap,
|
||||
},
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||||
DERPMap: derpMap,
|
||||
}, 0)
|
||||
defer conn.Close()
|
||||
res, err := conn.Speedtest(ctx, speedtest.Upload, 250*time.Millisecond)
|
||||
|
@ -1761,17 +1731,16 @@ func TestAgent_Reconnect(t *testing.T) {
|
|||
defer coordinator.Close()
|
||||
|
||||
agentID := uuid.New()
|
||||
statsCh := make(chan *agentsdk.Stats)
|
||||
derpMap := tailnettest.RunDERPAndSTUN(t)
|
||||
client := &client{
|
||||
t: t,
|
||||
agentID: agentID,
|
||||
manifest: agentsdk.Manifest{
|
||||
statsCh := make(chan *agentsdk.Stats, 50)
|
||||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||||
client := agenttest.NewClient(t,
|
||||
agentID,
|
||||
agentsdk.Manifest{
|
||||
DERPMap: derpMap,
|
||||
},
|
||||
statsChan: statsCh,
|
||||
coordinator: coordinator,
|
||||
}
|
||||
statsCh,
|
||||
coordinator,
|
||||
)
|
||||
initialized := atomic.Int32{}
|
||||
closer := agent.New(agent.Options{
|
||||
ExchangeToken: func(ctx context.Context) (string, error) {
|
||||
|
@ -1786,7 +1755,7 @@ func TestAgent_Reconnect(t *testing.T) {
|
|||
require.Eventually(t, func() bool {
|
||||
return coordinator.Node(agentID) != nil
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
client.lastWorkspaceAgent()
|
||||
client.LastWorkspaceAgent()
|
||||
require.Eventually(t, func() bool {
|
||||
return initialized.Load() == 2
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
@ -1798,16 +1767,15 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
|
|||
coordinator := tailnet.NewCoordinator(logger)
|
||||
defer coordinator.Close()
|
||||
|
||||
client := &client{
|
||||
t: t,
|
||||
agentID: uuid.New(),
|
||||
manifest: agentsdk.Manifest{
|
||||
client := agenttest.NewClient(t,
|
||||
uuid.New(),
|
||||
agentsdk.Manifest{
|
||||
GitAuthConfigs: 1,
|
||||
DERPMap: &tailcfg.DERPMap{},
|
||||
},
|
||||
statsChan: make(chan *agentsdk.Stats),
|
||||
coordinator: coordinator,
|
||||
}
|
||||
make(chan *agentsdk.Stats, 50),
|
||||
coordinator,
|
||||
)
|
||||
filesystem := afero.NewMemMapFs()
|
||||
closer := agent.New(agent.Options{
|
||||
ExchangeToken: func(ctx context.Context) (string, error) {
|
||||
|
@ -1830,7 +1798,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
|
|||
|
||||
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) {
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, &client{}, 0)
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
waitGroup := sync.WaitGroup{}
|
||||
|
@ -1883,12 +1851,11 @@ func setupSSHSession(
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
//nolint:dogsled
|
||||
conn, _, _, fs, _ := setupAgent(t, &client{
|
||||
manifest: options,
|
||||
getServiceBanner: func() (codersdk.ServiceBannerConfig, error) {
|
||||
conn, _, _, fs, _ := setupAgent(t, options, 0, func(c *agenttest.Client, _ *agent.Options) {
|
||||
c.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) {
|
||||
return serviceBanner, nil
|
||||
},
|
||||
}, 0)
|
||||
})
|
||||
})
|
||||
if prepareFS != nil {
|
||||
prepareFS(fs)
|
||||
}
|
||||
|
@ -1905,31 +1872,28 @@ func setupSSHSession(
|
|||
return session
|
||||
}
|
||||
|
||||
type closeFunc func() error
|
||||
|
||||
func (c closeFunc) Close() error {
|
||||
return c()
|
||||
}
|
||||
|
||||
func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func(agent.Options) agent.Options) (
|
||||
func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) (
|
||||
*codersdk.WorkspaceAgentConn,
|
||||
*client,
|
||||
*agenttest.Client,
|
||||
<-chan *agentsdk.Stats,
|
||||
afero.Fs,
|
||||
io.Closer,
|
||||
) {
|
||||
c.t = t
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
if c.manifest.DERPMap == nil {
|
||||
c.manifest.DERPMap = tailnettest.RunDERPAndSTUN(t)
|
||||
if metadata.DERPMap == nil {
|
||||
metadata.DERPMap, _ = tailnettest.RunDERPAndSTUN(t)
|
||||
}
|
||||
c.coordinator = tailnet.NewCoordinator(logger)
|
||||
if metadata.AgentID == uuid.Nil {
|
||||
metadata.AgentID = uuid.New()
|
||||
}
|
||||
coordinator := tailnet.NewCoordinator(logger)
|
||||
t.Cleanup(func() {
|
||||
_ = c.coordinator.Close()
|
||||
_ = coordinator.Close()
|
||||
})
|
||||
c.agentID = uuid.New()
|
||||
c.statsChan = make(chan *agentsdk.Stats, 50)
|
||||
statsCh := make(chan *agentsdk.Stats, 50)
|
||||
fs := afero.NewMemMapFs()
|
||||
c := agenttest.NewClient(t, metadata.AgentID, metadata, statsCh, coordinator)
|
||||
|
||||
options := agent.Options{
|
||||
Client: c,
|
||||
Filesystem: fs,
|
||||
|
@ -1938,7 +1902,7 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func(
|
|||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
options = opt(options)
|
||||
opt(c, &options)
|
||||
}
|
||||
|
||||
closer := agent.New(options)
|
||||
|
@ -1947,7 +1911,7 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func(
|
|||
})
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
DERPMap: c.manifest.DERPMap,
|
||||
DERPMap: metadata.DERPMap,
|
||||
Logger: logger.Named("client"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -1961,15 +1925,15 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func(
|
|||
})
|
||||
go func() {
|
||||
defer close(serveClientDone)
|
||||
c.coordinator.ServeClient(serverConn, uuid.New(), c.agentID)
|
||||
coordinator.ServeClient(serverConn, uuid.New(), metadata.AgentID)
|
||||
}()
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(node, false)
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
agentConn := &codersdk.WorkspaceAgentConn{
|
||||
Conn: conn,
|
||||
}
|
||||
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
|
||||
AgentID: metadata.AgentID,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = agentConn.Close()
|
||||
})
|
||||
|
@ -1980,7 +1944,7 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func(
|
|||
if !agentConn.AwaitReachable(ctx) {
|
||||
t.Fatal("agent not reachable")
|
||||
}
|
||||
return agentConn, c, c.statsChan, fs, closer
|
||||
return agentConn, c, statsCh, fs, closer
|
||||
}
|
||||
|
||||
var dialTestPayload = []byte("dean-was-here123")
|
||||
|
@ -2043,146 +2007,6 @@ func testSessionOutput(t *testing.T, session *ssh.Session, expected, unexpected
|
|||
}
|
||||
}
|
||||
|
||||
type client struct {
|
||||
t *testing.T
|
||||
agentID uuid.UUID
|
||||
manifest agentsdk.Manifest
|
||||
metadata map[string]agentsdk.PostMetadataRequest
|
||||
statsChan chan *agentsdk.Stats
|
||||
coordinator tailnet.Coordinator
|
||||
lastWorkspaceAgent func()
|
||||
patchWorkspaceLogs func() error
|
||||
getServiceBanner func() (codersdk.ServiceBannerConfig, error)
|
||||
|
||||
mu sync.Mutex // Protects following.
|
||||
lifecycleStates []codersdk.WorkspaceAgentLifecycle
|
||||
startup agentsdk.PostStartupRequest
|
||||
logs []agentsdk.StartupLog
|
||||
}
|
||||
|
||||
func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
|
||||
return c.manifest, nil
|
||||
}
|
||||
|
||||
func (c *client) Listen(_ context.Context) (net.Conn, error) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
closed := make(chan struct{})
|
||||
c.lastWorkspaceAgent = func() {
|
||||
_ = serverConn.Close()
|
||||
_ = clientConn.Close()
|
||||
<-closed
|
||||
}
|
||||
c.t.Cleanup(c.lastWorkspaceAgent)
|
||||
go func() {
|
||||
_ = c.coordinator.ServeAgent(serverConn, c.agentID, "")
|
||||
close(closed)
|
||||
}()
|
||||
return clientConn, nil
|
||||
}
|
||||
|
||||
func (c *client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) {
|
||||
doneCh := make(chan struct{})
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
defer close(doneCh)
|
||||
|
||||
setInterval(500 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case stat := <-statsChan:
|
||||
select {
|
||||
case c.statsChan <- stat:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
// We don't want to send old stats.
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return closeFunc(func() error {
|
||||
cancel()
|
||||
<-doneCh
|
||||
close(c.statsChan)
|
||||
return nil
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (c *client) getLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.lifecycleStates
|
||||
}
|
||||
|
||||
func (c *client) PostLifecycle(_ context.Context, req agentsdk.PostLifecycleRequest) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.lifecycleStates = append(c.lifecycleStates, req.State)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) getStartup() agentsdk.PostStartupRequest {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.startup
|
||||
}
|
||||
|
||||
func (c *client) getMetadata() map[string]agentsdk.PostMetadataRequest {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return maps.Clone(c.metadata)
|
||||
}
|
||||
|
||||
func (c *client) PostMetadata(_ context.Context, key string, req agentsdk.PostMetadataRequest) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.metadata == nil {
|
||||
c.metadata = make(map[string]agentsdk.PostMetadataRequest)
|
||||
}
|
||||
c.metadata[key] = req
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) PostStartup(_ context.Context, startup agentsdk.PostStartupRequest) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.startup = startup
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) getStartupLogs() []agentsdk.StartupLog {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.logs
|
||||
}
|
||||
|
||||
func (c *client) PatchStartupLogs(_ context.Context, logs agentsdk.PatchStartupLogs) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.patchWorkspaceLogs != nil {
|
||||
return c.patchWorkspaceLogs()
|
||||
}
|
||||
c.logs = append(c.logs, logs.Logs...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.getServiceBanner != nil {
|
||||
return c.getServiceBanner()
|
||||
}
|
||||
return codersdk.ServiceBannerConfig{}, nil
|
||||
}
|
||||
|
||||
// tempDirUnixSocket returns a temporary directory that can safely hold unix
|
||||
// sockets (probably).
|
||||
//
|
||||
|
@ -2214,9 +2038,8 @@ func TestAgent_Metrics_SSH(t *testing.T) {
|
|||
registry := prometheus.NewRegistry()
|
||||
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, &client{}, 0, func(o agent.Options) agent.Options {
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.PrometheusRegistry = registry
|
||||
return o
|
||||
})
|
||||
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
|
|
|
@ -0,0 +1,189 @@
|
|||
package agenttest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/codersdk/agentsdk"
|
||||
"github.com/coder/coder/tailnet"
|
||||
)
|
||||
|
||||
func NewClient(t testing.TB,
|
||||
agentID uuid.UUID,
|
||||
manifest agentsdk.Manifest,
|
||||
statsChan chan *agentsdk.Stats,
|
||||
coordinator tailnet.Coordinator,
|
||||
) *Client {
|
||||
if manifest.AgentID == uuid.Nil {
|
||||
manifest.AgentID = agentID
|
||||
}
|
||||
return &Client{
|
||||
t: t,
|
||||
agentID: agentID,
|
||||
manifest: manifest,
|
||||
statsChan: statsChan,
|
||||
coordinator: coordinator,
|
||||
}
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
t testing.TB
|
||||
agentID uuid.UUID
|
||||
manifest agentsdk.Manifest
|
||||
metadata map[string]agentsdk.PostMetadataRequest
|
||||
statsChan chan *agentsdk.Stats
|
||||
coordinator tailnet.Coordinator
|
||||
LastWorkspaceAgent func()
|
||||
PatchWorkspaceLogs func() error
|
||||
GetServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
|
||||
|
||||
mu sync.Mutex // Protects following.
|
||||
lifecycleStates []codersdk.WorkspaceAgentLifecycle
|
||||
startup agentsdk.PostStartupRequest
|
||||
logs []agentsdk.StartupLog
|
||||
}
|
||||
|
||||
func (c *Client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
|
||||
return c.manifest, nil
|
||||
}
|
||||
|
||||
func (c *Client) Listen(_ context.Context) (net.Conn, error) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
closed := make(chan struct{})
|
||||
c.LastWorkspaceAgent = func() {
|
||||
_ = serverConn.Close()
|
||||
_ = clientConn.Close()
|
||||
<-closed
|
||||
}
|
||||
c.t.Cleanup(c.LastWorkspaceAgent)
|
||||
go func() {
|
||||
_ = c.coordinator.ServeAgent(serverConn, c.agentID, "")
|
||||
close(closed)
|
||||
}()
|
||||
return clientConn, nil
|
||||
}
|
||||
|
||||
func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) {
|
||||
doneCh := make(chan struct{})
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
defer close(doneCh)
|
||||
|
||||
setInterval(500 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case stat := <-statsChan:
|
||||
select {
|
||||
case c.statsChan <- stat:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
// We don't want to send old stats.
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return closeFunc(func() error {
|
||||
cancel()
|
||||
<-doneCh
|
||||
close(c.statsChan)
|
||||
return nil
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (c *Client) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.lifecycleStates
|
||||
}
|
||||
|
||||
func (c *Client) PostLifecycle(_ context.Context, req agentsdk.PostLifecycleRequest) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.lifecycleStates = append(c.lifecycleStates, req.State)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*Client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) GetStartup() agentsdk.PostStartupRequest {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.startup
|
||||
}
|
||||
|
||||
func (c *Client) GetMetadata() map[string]agentsdk.PostMetadataRequest {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return maps.Clone(c.metadata)
|
||||
}
|
||||
|
||||
func (c *Client) PostMetadata(_ context.Context, key string, req agentsdk.PostMetadataRequest) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.metadata == nil {
|
||||
c.metadata = make(map[string]agentsdk.PostMetadataRequest)
|
||||
}
|
||||
c.metadata[key] = req
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) PostStartup(_ context.Context, startup agentsdk.PostStartupRequest) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.startup = startup
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) GetStartupLogs() []agentsdk.StartupLog {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.logs
|
||||
}
|
||||
|
||||
func (c *Client) PatchStartupLogs(_ context.Context, logs agentsdk.PatchStartupLogs) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.PatchWorkspaceLogs != nil {
|
||||
return c.PatchWorkspaceLogs()
|
||||
}
|
||||
c.logs = append(c.logs, logs.Logs...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) SetServiceBannerFunc(f func() (codersdk.ServiceBannerConfig, error)) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.GetServiceBannerFunc = f
|
||||
}
|
||||
|
||||
func (c *Client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.GetServiceBannerFunc != nil {
|
||||
return c.GetServiceBannerFunc()
|
||||
}
|
||||
return codersdk.ServiceBannerConfig{}, nil
|
||||
}
|
||||
|
||||
type closeFunc func() error
|
||||
|
||||
func (c closeFunc) Close() error {
|
||||
return c()
|
||||
}
|
|
@ -5961,6 +5961,9 @@ const docTemplate = `{
|
|||
"agentsdk.Manifest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"apps": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
|
@ -7617,6 +7620,7 @@ const docTemplate = `{
|
|||
"workspace_actions",
|
||||
"tailnet_ha_coordinator",
|
||||
"convert-to-oidc",
|
||||
"single_tailnet",
|
||||
"workspace_build_logs_ui"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
|
@ -7624,6 +7628,7 @@ const docTemplate = `{
|
|||
"ExperimentWorkspaceActions",
|
||||
"ExperimentTailnetHACoordinator",
|
||||
"ExperimentConvertToOIDC",
|
||||
"ExperimentSingleTailnet",
|
||||
"ExperimentWorkspaceBuildLogsUI"
|
||||
]
|
||||
},
|
||||
|
|
|
@ -5251,6 +5251,9 @@
|
|||
"agentsdk.Manifest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"apps": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
|
@ -6818,6 +6821,7 @@
|
|||
"workspace_actions",
|
||||
"tailnet_ha_coordinator",
|
||||
"convert-to-oidc",
|
||||
"single_tailnet",
|
||||
"workspace_build_logs_ui"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
|
@ -6825,6 +6829,7 @@
|
|||
"ExperimentWorkspaceActions",
|
||||
"ExperimentTailnetHACoordinator",
|
||||
"ExperimentConvertToOIDC",
|
||||
"ExperimentSingleTailnet",
|
||||
"ExperimentWorkspaceBuildLogsUI"
|
||||
]
|
||||
},
|
||||
|
|
|
@ -364,8 +364,23 @@ func New(options *Options) *API {
|
|||
}
|
||||
|
||||
api.Auditor.Store(&options.Auditor)
|
||||
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
|
||||
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
|
||||
if api.Experiments.Enabled(codersdk.ExperimentSingleTailnet) {
|
||||
api.agentProvider, err = NewServerTailnet(api.ctx,
|
||||
options.Logger,
|
||||
options.DERPServer,
|
||||
options.DERPMap,
|
||||
&api.TailnetCoordinator,
|
||||
wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
|
||||
)
|
||||
if err != nil {
|
||||
panic("failed to setup server tailnet: " + err.Error())
|
||||
}
|
||||
} else {
|
||||
api.agentProvider = &wsconncache.AgentProvider{
|
||||
Cache: wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
|
||||
}
|
||||
}
|
||||
|
||||
api.workspaceAppServer = &workspaceapps.Server{
|
||||
Logger: options.Logger.Named("workspaceapps"),
|
||||
|
@ -377,7 +392,7 @@ func New(options *Options) *API {
|
|||
RealIPConfig: options.RealIPConfig,
|
||||
|
||||
SignedTokenProvider: api.WorkspaceAppsProvider,
|
||||
WorkspaceConnCache: api.workspaceAgentCache,
|
||||
AgentProvider: api.agentProvider,
|
||||
AppSecurityKey: options.AppSecurityKey,
|
||||
|
||||
DisablePathApps: options.DeploymentValues.DisablePathApps.Value(),
|
||||
|
@ -921,10 +936,10 @@ type API struct {
|
|||
derpCloseFunc func()
|
||||
|
||||
metricsCache *metricscache.Cache
|
||||
workspaceAgentCache *wsconncache.Cache
|
||||
updateChecker *updatecheck.Checker
|
||||
WorkspaceAppsProvider workspaceapps.SignedTokenProvider
|
||||
workspaceAppServer *workspaceapps.Server
|
||||
agentProvider workspaceapps.AgentProvider
|
||||
|
||||
// Experiments contains the list of experiments currently enabled.
|
||||
// This is used to gate features that are not yet ready for production.
|
||||
|
@ -951,7 +966,8 @@ func (api *API) Close() error {
|
|||
if coordinator != nil {
|
||||
_ = (*coordinator).Close()
|
||||
}
|
||||
return api.workspaceAgentCache.Close()
|
||||
_ = api.agentProvider.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func compressHandler(h http.Handler) http.Handler {
|
||||
|
|
|
@ -109,6 +109,7 @@ type Options struct {
|
|||
GitAuthConfigs []*gitauth.Config
|
||||
TrialGenerator func(context.Context, string) error
|
||||
TemplateScheduleStore schedule.TemplateScheduleStore
|
||||
Coordinator tailnet.Coordinator
|
||||
|
||||
HealthcheckFunc func(ctx context.Context, apiKey string) *healthcheck.Report
|
||||
HealthcheckTimeout time.Duration
|
||||
|
|
|
@ -302,7 +302,7 @@ func TestAgents(t *testing.T) {
|
|||
coordinator := tailnet.NewCoordinator(slogtest.Make(t, nil).Leveled(slog.LevelDebug))
|
||||
coordinatorPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordinatorPtr.Store(&coordinator)
|
||||
derpMap := tailnettest.RunDERPAndSTUN(t)
|
||||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||||
agentInactiveDisconnectTimeout := 1 * time.Hour // don't need to focus on this value in tests
|
||||
registry := prometheus.NewRegistry()
|
||||
|
||||
|
|
|
@ -0,0 +1,339 @@
|
|||
package coderd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/wsconncache"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/site"
|
||||
"github.com/coder/coder/tailnet"
|
||||
)
|
||||
|
||||
var tailnetTransport *http.Transport
|
||||
|
||||
func init() {
|
||||
var valid bool
|
||||
tailnetTransport, valid = http.DefaultTransport.(*http.Transport)
|
||||
if !valid {
|
||||
panic("dev error: default transport is the wrong type")
|
||||
}
|
||||
}
|
||||
|
||||
// NewServerTailnet creates a new tailnet intended for use by coderd. It
|
||||
// automatically falls back to wsconncache if a legacy agent is encountered.
|
||||
func NewServerTailnet(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
derpServer *derp.Server,
|
||||
derpMap *tailcfg.DERPMap,
|
||||
coord *atomic.Pointer[tailnet.Coordinator],
|
||||
cache *wsconncache.Cache,
|
||||
) (*ServerTailnet, error) {
|
||||
logger = logger.Named("servertailnet")
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
DERPMap: derpMap,
|
||||
Logger: logger,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create tailnet conn: %w", err)
|
||||
}
|
||||
|
||||
serverCtx, cancel := context.WithCancel(ctx)
|
||||
tn := &ServerTailnet{
|
||||
ctx: serverCtx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
conn: conn,
|
||||
coord: coord,
|
||||
cache: cache,
|
||||
agentNodes: map[uuid.UUID]time.Time{},
|
||||
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
|
||||
transport: tailnetTransport.Clone(),
|
||||
}
|
||||
tn.transport.DialContext = tn.dialContext
|
||||
tn.transport.MaxIdleConnsPerHost = 10
|
||||
tn.transport.MaxIdleConns = 0
|
||||
agentConn := (*coord.Load()).ServeMultiAgent(uuid.New())
|
||||
tn.agentConn.Store(&agentConn)
|
||||
|
||||
err = tn.getAgentConn().UpdateSelf(conn.Node())
|
||||
if err != nil {
|
||||
tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err))
|
||||
}
|
||||
conn.SetNodeCallback(func(node *tailnet.Node) {
|
||||
err := tn.getAgentConn().UpdateSelf(node)
|
||||
if err != nil {
|
||||
tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err))
|
||||
}
|
||||
})
|
||||
|
||||
// This is set to allow local DERP traffic to be proxied through memory
|
||||
// instead of needing to hit the external access URL. Don't use the ctx
|
||||
// given in this callback, it's only valid while connecting.
|
||||
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
|
||||
if !region.EmbeddedRelay {
|
||||
return nil
|
||||
}
|
||||
left, right := net.Pipe()
|
||||
go func() {
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
|
||||
derpServer.Accept(ctx, right, brw, "internal")
|
||||
}()
|
||||
return left
|
||||
})
|
||||
|
||||
go tn.watchAgentUpdates()
|
||||
go tn.expireOldAgents()
|
||||
return tn, nil
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) expireOldAgents() {
|
||||
const (
|
||||
tick = 5 * time.Minute
|
||||
cutoff = 30 * time.Minute
|
||||
)
|
||||
|
||||
ticker := time.NewTicker(tick)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
s.nodesMu.Lock()
|
||||
agentConn := s.getAgentConn()
|
||||
for agentID, lastConnection := range s.agentNodes {
|
||||
// If no one has connected since the cutoff and there are no active
|
||||
// connections, remove the agent.
|
||||
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
|
||||
_ = agentConn
|
||||
// err := agentConn.UnsubscribeAgent(agentID)
|
||||
// if err != nil {
|
||||
// s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
// }
|
||||
// delete(s.agentNodes, agentID)
|
||||
|
||||
// TODO(coadler): actually remove from the netmap, then reenable
|
||||
// the above
|
||||
}
|
||||
}
|
||||
s.nodesMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) watchAgentUpdates() {
|
||||
for {
|
||||
conn := s.getAgentConn()
|
||||
nodes, ok := conn.NextUpdate(s.ctx)
|
||||
if !ok {
|
||||
if conn.IsClosed() && s.ctx.Err() == nil {
|
||||
s.reinitCoordinator()
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err := s.conn.UpdateNodes(nodes, false)
|
||||
if err != nil {
|
||||
s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
|
||||
return *s.agentConn.Load()
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) reinitCoordinator() {
|
||||
s.nodesMu.Lock()
|
||||
agentConn := (*s.coord.Load()).ServeMultiAgent(uuid.New())
|
||||
s.agentConn.Store(&agentConn)
|
||||
|
||||
// Resubscribe to all of the agents we're tracking.
|
||||
for agentID := range s.agentNodes {
|
||||
err := agentConn.SubscribeAgent(agentID)
|
||||
if err != nil {
|
||||
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
}
|
||||
}
|
||||
s.nodesMu.Unlock()
|
||||
}
|
||||
|
||||
type ServerTailnet struct {
|
||||
ctx context.Context
|
||||
cancel func()
|
||||
|
||||
logger slog.Logger
|
||||
conn *tailnet.Conn
|
||||
coord *atomic.Pointer[tailnet.Coordinator]
|
||||
agentConn atomic.Pointer[tailnet.MultiAgentConn]
|
||||
cache *wsconncache.Cache
|
||||
nodesMu sync.Mutex
|
||||
// agentNodes is a map of agent tailnetNodes the server wants to keep a
|
||||
// connection to. It contains the last time the agent was connected to.
|
||||
agentNodes map[uuid.UUID]time.Time
|
||||
// agentTockets holds a map of all open connections to an agent.
|
||||
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
|
||||
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) {
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
|
||||
Status: http.StatusBadGateway,
|
||||
Title: "Bad Gateway",
|
||||
Description: "Failed to proxy request to application: " + err.Error(),
|
||||
RetryEnabled: true,
|
||||
DashboardURL: dashboardURL.String(),
|
||||
})
|
||||
}
|
||||
proxy.Director = s.director(agentID, proxy.Director)
|
||||
proxy.Transport = s.transport
|
||||
|
||||
return proxy, func() {}, nil
|
||||
}
|
||||
|
||||
type agentIDKey struct{}
|
||||
|
||||
// director makes sure agentIDKey is set on the context in the reverse proxy.
|
||||
// This allows the transport to correctly identify which agent to dial to.
|
||||
func (*ServerTailnet) director(agentID uuid.UUID, prev func(req *http.Request)) func(req *http.Request) {
|
||||
return func(req *http.Request) {
|
||||
ctx := context.WithValue(req.Context(), agentIDKey{}, agentID)
|
||||
*req = *req.WithContext(ctx)
|
||||
prev(req)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
agentID, ok := ctx.Value(agentIDKey{}).(uuid.UUID)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("no agent id attached")
|
||||
}
|
||||
|
||||
return s.DialAgentNetConn(ctx, agentID, network, addr)
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
|
||||
s.nodesMu.Lock()
|
||||
defer s.nodesMu.Unlock()
|
||||
|
||||
_, ok := s.agentNodes[agentID]
|
||||
// If we don't have the node, subscribe.
|
||||
if !ok {
|
||||
s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
|
||||
err := s.getAgentConn().SubscribeAgent(agentID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("subscribe agent: %w", err)
|
||||
}
|
||||
s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
|
||||
}
|
||||
|
||||
s.agentNodes[agentID] = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
|
||||
var (
|
||||
conn *codersdk.WorkspaceAgentConn
|
||||
ret = func() {}
|
||||
)
|
||||
|
||||
if s.getAgentConn().AgentIsLegacy(agentID) {
|
||||
s.logger.Debug(s.ctx, "acquiring legacy agent", slog.F("agent_id", agentID))
|
||||
cconn, release, err := s.cache.Acquire(agentID)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("acquire legacy agent conn: %w", err)
|
||||
}
|
||||
|
||||
conn = cconn.WorkspaceAgentConn
|
||||
ret = release
|
||||
} else {
|
||||
err := s.ensureAgent(agentID)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
|
||||
conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{
|
||||
AgentID: agentID,
|
||||
CloseFunc: func() error { return codersdk.ErrSkipClose },
|
||||
})
|
||||
}
|
||||
|
||||
// Since we now have an open conn, be careful to close it if we error
|
||||
// without returning it to the user.
|
||||
|
||||
reachable := conn.AwaitReachable(ctx)
|
||||
if !reachable {
|
||||
ret()
|
||||
conn.Close()
|
||||
return nil, nil, xerrors.New("agent is unreachable")
|
||||
}
|
||||
|
||||
return conn, ret, nil
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, network, addr string) (net.Conn, error) {
|
||||
conn, release, err := s.AgentConn(ctx, agentID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("acquire agent conn: %w", err)
|
||||
}
|
||||
|
||||
// Since we now have an open conn, be careful to close it if we error
|
||||
// without returning it to the user.
|
||||
|
||||
nc, err := conn.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
release()
|
||||
conn.Close()
|
||||
return nil, xerrors.Errorf("dial context: %w", err)
|
||||
}
|
||||
|
||||
return &netConnCloser{Conn: nc, close: func() {
|
||||
release()
|
||||
conn.Close()
|
||||
}}, err
|
||||
}
|
||||
|
||||
type netConnCloser struct {
|
||||
net.Conn
|
||||
close func()
|
||||
}
|
||||
|
||||
func (c *netConnCloser) Close() error {
|
||||
c.close()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) Close() error {
|
||||
s.cancel()
|
||||
_ = s.cache.Close()
|
||||
_ = s.conn.Close()
|
||||
s.transport.CloseIdleConnections()
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,207 @@
|
|||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/agent/agenttest"
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/wsconncache"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/codersdk/agentsdk"
|
||||
"github.com/coder/coder/tailnet"
|
||||
"github.com/coder/coder/tailnet/tailnettest"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestServerTailnet_AgentConn_OK(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
// Connect through the ServerTailnet
|
||||
agentID, _, serverTailnet := setupAgent(t, nil)
|
||||
|
||||
conn, release, err := serverTailnet.AgentConn(ctx, agentID)
|
||||
require.NoError(t, err)
|
||||
defer release()
|
||||
|
||||
assert.True(t, conn.AwaitReachable(ctx))
|
||||
}
|
||||
|
||||
func TestServerTailnet_AgentConn_Legacy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
// Force a connection through wsconncache using the legacy hardcoded ip.
|
||||
agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{
|
||||
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
|
||||
})
|
||||
|
||||
conn, release, err := serverTailnet.AgentConn(ctx, agentID)
|
||||
require.NoError(t, err)
|
||||
defer release()
|
||||
|
||||
assert.True(t, conn.AwaitReachable(ctx))
|
||||
}
|
||||
|
||||
func TestServerTailnet_ReverseProxy_OK(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// Force a connection through wsconncache using the legacy hardcoded ip.
|
||||
agentID, _, serverTailnet := setupAgent(t, nil)
|
||||
|
||||
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
|
||||
require.NoError(t, err)
|
||||
|
||||
rp, release, err := serverTailnet.ReverseProxy(u, u, agentID)
|
||||
require.NoError(t, err)
|
||||
defer release()
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
u.String(),
|
||||
nil,
|
||||
).WithContext(ctx)
|
||||
|
||||
rp.ServeHTTP(rw, req)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestServerTailnet_ReverseProxy_Legacy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// Force a connection through wsconncache using the legacy hardcoded ip.
|
||||
agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{
|
||||
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
|
||||
})
|
||||
|
||||
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
|
||||
require.NoError(t, err)
|
||||
|
||||
rp, release, err := serverTailnet.ReverseProxy(u, u, agentID)
|
||||
require.NoError(t, err)
|
||||
defer release()
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
u.String(),
|
||||
nil,
|
||||
).WithContext(ctx)
|
||||
|
||||
rp.ServeHTTP(rw, req)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.Agent, *coderd.ServerTailnet) {
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
derpMap, derpServer := tailnettest.RunDERPAndSTUN(t)
|
||||
manifest := agentsdk.Manifest{
|
||||
AgentID: uuid.New(),
|
||||
DERPMap: derpMap,
|
||||
}
|
||||
|
||||
var coordPtr atomic.Pointer[tailnet.Coordinator]
|
||||
coord := tailnet.NewCoordinator(logger)
|
||||
coordPtr.Store(&coord)
|
||||
t.Cleanup(func() {
|
||||
_ = coord.Close()
|
||||
})
|
||||
|
||||
c := agenttest.NewClient(t, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
|
||||
|
||||
options := agent.Options{
|
||||
Client: c,
|
||||
Filesystem: afero.NewMemMapFs(),
|
||||
Logger: logger.Named("agent"),
|
||||
Addresses: agentAddresses,
|
||||
}
|
||||
|
||||
ag := agent.New(options)
|
||||
t.Cleanup(func() {
|
||||
_ = ag.Close()
|
||||
})
|
||||
|
||||
// Wait for the agent to connect.
|
||||
require.Eventually(t, func() bool {
|
||||
return coord.Node(manifest.AgentID) != nil
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
DERPMap: manifest.DERPMap,
|
||||
Logger: logger.Named("client"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
clientConn, serverConn := net.Pipe()
|
||||
serveClientDone := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
_ = conn.Close()
|
||||
<-serveClientDone
|
||||
})
|
||||
go func() {
|
||||
defer close(serveClientDone)
|
||||
coord.ServeClient(serverConn, uuid.New(), manifest.AgentID)
|
||||
}()
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(node, false)
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
|
||||
AgentID: manifest.AgentID,
|
||||
AgentIP: codersdk.WorkspaceAgentIP,
|
||||
CloseFunc: func() error { return codersdk.ErrSkipClose },
|
||||
}), nil
|
||||
}, 0)
|
||||
|
||||
serverTailnet, err := coderd.NewServerTailnet(
|
||||
context.Background(),
|
||||
logger,
|
||||
derpServer,
|
||||
manifest.DERPMap,
|
||||
&coordPtr,
|
||||
cache,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = serverTailnet.Close()
|
||||
})
|
||||
|
||||
return manifest.AgentID, ag, serverTailnet
|
||||
}
|
|
@ -161,6 +161,7 @@ func (api *API) workspaceAgentManifest(rw http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.Manifest{
|
||||
AgentID: apiAgent.ID,
|
||||
Apps: convertApps(dbApps),
|
||||
DERPMap: api.DERPMap,
|
||||
GitAuthConfigs: len(api.GitAuthConfigs),
|
||||
|
@ -654,7 +655,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
|
|||
return
|
||||
}
|
||||
|
||||
agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID)
|
||||
agentConn, release, err := api.agentProvider.AgentConn(ctx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error dialing workspace agent.",
|
||||
|
@ -729,7 +730,9 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
|
|||
httpapi.Write(ctx, rw, http.StatusOK, portsResponse)
|
||||
}
|
||||
|
||||
func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||
// Deprecated: use api.tailnet.AgentConn instead.
|
||||
// See: https://github.com/coder/coder/issues/8218
|
||||
func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
|
@ -765,14 +768,16 @@ func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspac
|
|||
return nil
|
||||
})
|
||||
conn.SetNodeCallback(sendNodes)
|
||||
agentConn := &codersdk.WorkspaceAgentConn{
|
||||
Conn: conn,
|
||||
CloseFunc: func() {
|
||||
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
|
||||
AgentID: agentID,
|
||||
AgentIP: codersdk.WorkspaceAgentIP,
|
||||
CloseFunc: func() error {
|
||||
cancel()
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
})
|
||||
go func() {
|
||||
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
|
||||
if err != nil {
|
||||
|
|
|
@ -399,7 +399,8 @@ func doWithRetries(t require.TestingT, client *codersdk.Client, req *http.Reques
|
|||
return resp, err
|
||||
}
|
||||
|
||||
func requestWithRetries(ctx context.Context, t require.TestingT, client *codersdk.Client, method, urlOrPath string, body interface{}, opts ...codersdk.RequestOption) (*http.Response, error) {
|
||||
func requestWithRetries(ctx context.Context, t testing.TB, client *codersdk.Client, method, urlOrPath string, body interface{}, opts ...codersdk.RequestOption) (*http.Response, error) {
|
||||
t.Helper()
|
||||
var resp *http.Response
|
||||
var err error
|
||||
require.Eventually(t, func() bool {
|
||||
|
|
|
@ -23,7 +23,6 @@ import (
|
|||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
"github.com/coder/coder/coderd/util/slice"
|
||||
"github.com/coder/coder/coderd/wsconncache"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/site"
|
||||
)
|
||||
|
@ -61,6 +60,22 @@ var nonCanonicalHeaders = map[string]string{
|
|||
"Sec-Websocket-Version": "Sec-WebSocket-Version",
|
||||
}
|
||||
|
||||
type AgentProvider interface {
|
||||
// ReverseProxy returns an httputil.ReverseProxy for proxying HTTP requests
|
||||
// to the specified agent.
|
||||
//
|
||||
// TODO: after wsconncache is deleted this doesn't need to return an error.
|
||||
ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error)
|
||||
|
||||
// AgentConn returns a new connection to the specified agent.
|
||||
//
|
||||
// TODO: after wsconncache is deleted this doesn't need to return a release
|
||||
// func.
|
||||
AgentConn(ctx context.Context, agentID uuid.UUID) (_ *codersdk.WorkspaceAgentConn, release func(), _ error)
|
||||
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Server serves workspace apps endpoints, including:
|
||||
// - Path-based apps
|
||||
// - Subdomain app middleware
|
||||
|
@ -83,7 +98,6 @@ type Server struct {
|
|||
RealIPConfig *httpmw.RealIPConfig
|
||||
|
||||
SignedTokenProvider SignedTokenProvider
|
||||
WorkspaceConnCache *wsconncache.Cache
|
||||
AppSecurityKey SecurityKey
|
||||
|
||||
// DisablePathApps disables path-based apps. This is a security feature as path
|
||||
|
@ -95,6 +109,8 @@ type Server struct {
|
|||
DisablePathApps bool
|
||||
SecureAuthCookie bool
|
||||
|
||||
AgentProvider AgentProvider
|
||||
|
||||
websocketWaitMutex sync.Mutex
|
||||
websocketWaitGroup sync.WaitGroup
|
||||
}
|
||||
|
@ -106,8 +122,8 @@ func (s *Server) Close() error {
|
|||
s.websocketWaitGroup.Wait()
|
||||
s.websocketWaitMutex.Unlock()
|
||||
|
||||
// The caller must close the SignedTokenProvider (if necessary) and the
|
||||
// wsconncache.
|
||||
// The caller must close the SignedTokenProvider and the AgentProvider (if
|
||||
// necessary).
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -517,18 +533,7 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT
|
|||
r.URL.Path = path
|
||||
appURL.RawQuery = ""
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(appURL)
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
|
||||
Status: http.StatusBadGateway,
|
||||
Title: "Bad Gateway",
|
||||
Description: "Failed to proxy request to application: " + err.Error(),
|
||||
RetryEnabled: true,
|
||||
DashboardURL: s.DashboardURL.String(),
|
||||
})
|
||||
}
|
||||
|
||||
conn, release, err := s.WorkspaceConnCache.Acquire(appToken.AgentID)
|
||||
proxy, release, err := s.AgentProvider.ReverseProxy(appURL, s.DashboardURL, appToken.AgentID)
|
||||
if err != nil {
|
||||
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
|
||||
Status: http.StatusBadGateway,
|
||||
|
@ -540,7 +545,6 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT
|
|||
return
|
||||
}
|
||||
defer release()
|
||||
proxy.Transport = conn.HTTPTransport()
|
||||
|
||||
proxy.ModifyResponse = func(r *http.Response) error {
|
||||
r.Header.Del(httpmw.AccessControlAllowOriginHeader)
|
||||
|
@ -658,13 +662,14 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
|||
|
||||
go httpapi.Heartbeat(ctx, conn)
|
||||
|
||||
agentConn, release, err := s.WorkspaceConnCache.Acquire(appToken.AgentID)
|
||||
agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID)
|
||||
if err != nil {
|
||||
log.Debug(ctx, "dial workspace agent", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
|
||||
return
|
||||
}
|
||||
defer release()
|
||||
defer agentConn.Close()
|
||||
log.Debug(ctx, "dialed workspace agent")
|
||||
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"))
|
||||
if err != nil {
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
// Package wsconncache caches workspace agent connections by UUID.
|
||||
// Deprecated: Use ServerTailnet instead.
|
||||
package wsconncache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -13,13 +16,57 @@ import (
|
|||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/site"
|
||||
)
|
||||
|
||||
// New creates a new workspace connection cache that closes
|
||||
// connections after the inactive timeout provided.
|
||||
type AgentProvider struct {
|
||||
Cache *Cache
|
||||
}
|
||||
|
||||
func (a *AgentProvider) AgentConn(_ context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
|
||||
conn, rel, err := a.Cache.Acquire(agentID)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("acquire agent connection: %w", err)
|
||||
}
|
||||
|
||||
return conn.WorkspaceAgentConn, rel, nil
|
||||
}
|
||||
|
||||
func (a *AgentProvider) ReverseProxy(targetURL *url.URL, dashboardURL *url.URL, agentID uuid.UUID) (*httputil.ReverseProxy, func(), error) {
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
|
||||
Status: http.StatusBadGateway,
|
||||
Title: "Bad Gateway",
|
||||
Description: "Failed to proxy request to application: " + err.Error(),
|
||||
RetryEnabled: true,
|
||||
DashboardURL: dashboardURL.String(),
|
||||
})
|
||||
}
|
||||
|
||||
conn, release, err := a.Cache.Acquire(agentID)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("acquire agent connection: %w", err)
|
||||
}
|
||||
|
||||
proxy.Transport = conn.HTTPTransport()
|
||||
|
||||
return proxy, release, nil
|
||||
}
|
||||
|
||||
func (a *AgentProvider) Close() error {
|
||||
return a.Cache.Close()
|
||||
}
|
||||
|
||||
// New creates a new workspace connection cache that closes connections after
|
||||
// the inactive timeout provided.
|
||||
//
|
||||
// Agent connections are cached due to WebRTC negotiation
|
||||
// taking a few hundred milliseconds.
|
||||
// Agent connections are cached due to Wireguard negotiation taking a few
|
||||
// hundred milliseconds, depending on latency.
|
||||
//
|
||||
// Deprecated: Use coderd.NewServerTailnet instead. wsconncache is being phased
|
||||
// out because it creates a unique Tailnet for each agent.
|
||||
// See: https://github.com/coder/coder/issues/8218
|
||||
func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
|
||||
if inactiveTimeout == 0 {
|
||||
inactiveTimeout = 5 * time.Minute
|
||||
|
|
|
@ -157,22 +157,23 @@ func TestCache(t *testing.T) {
|
|||
func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn {
|
||||
t.Helper()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
manifest.DERPMap = tailnettest.RunDERPAndSTUN(t)
|
||||
manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t)
|
||||
|
||||
coordinator := tailnet.NewCoordinator(logger)
|
||||
t.Cleanup(func() {
|
||||
_ = coordinator.Close()
|
||||
})
|
||||
agentID := uuid.New()
|
||||
manifest.AgentID = uuid.New()
|
||||
closer := agent.New(agent.Options{
|
||||
Client: &client{
|
||||
t: t,
|
||||
agentID: agentID,
|
||||
agentID: manifest.AgentID,
|
||||
manifest: manifest,
|
||||
coordinator: coordinator,
|
||||
},
|
||||
Logger: logger.Named("agent"),
|
||||
ReconnectingPTYTimeout: ptyTimeout,
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = closer.Close()
|
||||
|
@ -189,14 +190,15 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
|
|||
_ = serverConn.Close()
|
||||
_ = conn.Close()
|
||||
})
|
||||
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
|
||||
go coordinator.ServeClient(serverConn, uuid.New(), manifest.AgentID)
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(node, false)
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
agentConn := &codersdk.WorkspaceAgentConn{
|
||||
Conn: conn,
|
||||
}
|
||||
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
|
||||
AgentID: manifest.AgentID,
|
||||
AgentIP: codersdk.WorkspaceAgentIP,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = agentConn.Close()
|
||||
})
|
||||
|
|
|
@ -84,6 +84,7 @@ func (c *Client) PostMetadata(ctx context.Context, key string, req PostMetadataR
|
|||
}
|
||||
|
||||
type Manifest struct {
|
||||
AgentID uuid.UUID `json:"agent_id"`
|
||||
// GitAuthConfigs stores the number of Git configurations
|
||||
// the Coder deployment has. If this number is >0, we
|
||||
// set up special configuration in the workspace.
|
||||
|
|
|
@ -1764,6 +1764,12 @@ const (
|
|||
// oidc.
|
||||
ExperimentConvertToOIDC Experiment = "convert-to-oidc"
|
||||
|
||||
// ExperimentSingleTailnet replaces workspace connections inside coderd to
|
||||
// all use a single tailnet, instead of the previous behavior of creating a
|
||||
// single tailnet for each agent.
|
||||
// WARNING: This cannot be enabled when using HA.
|
||||
ExperimentSingleTailnet Experiment = "single_tailnet"
|
||||
|
||||
ExperimentWorkspaceBuildLogsUI Experiment = "workspace_build_logs_ui"
|
||||
// Add new experiments here!
|
||||
// ExperimentExample Experiment = "example"
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
|
@ -27,8 +28,14 @@ import (
|
|||
// WorkspaceAgentIP is a static IPv6 address with the Tailscale prefix that is used to route
|
||||
// connections from clients to this node. A dynamic address is not required because a Tailnet
|
||||
// client only dials a single agent at a time.
|
||||
//
|
||||
// Deprecated: use tailnet.IP() instead. This is kept for backwards
|
||||
// compatibility with wsconncache.
|
||||
// See: https://github.com/coder/coder/issues/8218
|
||||
var WorkspaceAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")
|
||||
|
||||
var ErrSkipClose = xerrors.New("skip tailnet close")
|
||||
|
||||
const (
|
||||
WorkspaceAgentSSHPort = tailnet.WorkspaceAgentSSHPort
|
||||
WorkspaceAgentReconnectingPTYPort = tailnet.WorkspaceAgentReconnectingPTYPort
|
||||
|
@ -120,11 +127,38 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
// NewWorkspaceAgentConn creates a new WorkspaceAgentConn. `conn` may be unique
|
||||
// to the WorkspaceAgentConn, or it may be shared in the case of coderd. If the
|
||||
// conn is shared and closing it is undesirable, you may return ErrNoClose from
|
||||
// opts.CloseFunc. This will ensure the underlying conn is not closed.
|
||||
func NewWorkspaceAgentConn(conn *tailnet.Conn, opts WorkspaceAgentConnOptions) *WorkspaceAgentConn {
|
||||
return &WorkspaceAgentConn{
|
||||
Conn: conn,
|
||||
opts: opts,
|
||||
}
|
||||
}
|
||||
|
||||
// WorkspaceAgentConn represents a connection to a workspace agent.
|
||||
// @typescript-ignore WorkspaceAgentConn
|
||||
type WorkspaceAgentConn struct {
|
||||
*tailnet.Conn
|
||||
CloseFunc func()
|
||||
opts WorkspaceAgentConnOptions
|
||||
}
|
||||
|
||||
// @typescript-ignore WorkspaceAgentConnOptions
|
||||
type WorkspaceAgentConnOptions struct {
|
||||
AgentID uuid.UUID
|
||||
AgentIP netip.Addr
|
||||
CloseFunc func() error
|
||||
}
|
||||
|
||||
func (c *WorkspaceAgentConn) agentAddress() netip.Addr {
|
||||
var emptyIP netip.Addr
|
||||
if cmp := c.opts.AgentIP.Compare(emptyIP); cmp != 0 {
|
||||
return c.opts.AgentIP
|
||||
}
|
||||
|
||||
return tailnet.IPFromUUID(c.opts.AgentID)
|
||||
}
|
||||
|
||||
// AwaitReachable waits for the agent to be reachable.
|
||||
|
@ -132,7 +166,7 @@ func (c *WorkspaceAgentConn) AwaitReachable(ctx context.Context) bool {
|
|||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
return c.Conn.AwaitReachable(ctx, WorkspaceAgentIP)
|
||||
return c.Conn.AwaitReachable(ctx, c.agentAddress())
|
||||
}
|
||||
|
||||
// Ping pings the agent and returns the round-trip time.
|
||||
|
@ -141,13 +175,20 @@ func (c *WorkspaceAgentConn) Ping(ctx context.Context) (time.Duration, bool, *ip
|
|||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
return c.Conn.Ping(ctx, WorkspaceAgentIP)
|
||||
return c.Conn.Ping(ctx, c.agentAddress())
|
||||
}
|
||||
|
||||
// Close ends the connection to the workspace agent.
|
||||
func (c *WorkspaceAgentConn) Close() error {
|
||||
if c.CloseFunc != nil {
|
||||
c.CloseFunc()
|
||||
var cerr error
|
||||
if c.opts.CloseFunc != nil {
|
||||
cerr = c.opts.CloseFunc()
|
||||
if xerrors.Is(cerr, ErrSkipClose) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if cerr != nil {
|
||||
return multierror.Append(cerr, c.Conn.Close())
|
||||
}
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
@ -176,10 +217,12 @@ type ReconnectingPTYRequest struct {
|
|||
func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
if !c.AwaitReachable(ctx) {
|
||||
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
|
||||
}
|
||||
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentReconnectingPTYPort))
|
||||
|
||||
conn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentReconnectingPTYPort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -209,10 +252,12 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID,
|
|||
func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
if !c.AwaitReachable(ctx) {
|
||||
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
|
||||
}
|
||||
return c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSSHPort))
|
||||
|
||||
return c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentSSHPort))
|
||||
}
|
||||
|
||||
// SSHClient calls SSH to create a client that uses a weak cipher
|
||||
|
@ -220,10 +265,12 @@ func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) {
|
|||
func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
netConn, err := c.SSH(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("ssh: %w", err)
|
||||
}
|
||||
|
||||
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
|
||||
// SSH host validation isn't helpful, because obtaining a peer
|
||||
// connection already signifies user-intent to dial a workspace.
|
||||
|
@ -233,6 +280,7 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error)
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("ssh conn: %w", err)
|
||||
}
|
||||
|
||||
return ssh.NewClient(sshConn, channels, requests), nil
|
||||
}
|
||||
|
||||
|
@ -240,17 +288,21 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error)
|
|||
func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
if !c.AwaitReachable(ctx) {
|
||||
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
|
||||
}
|
||||
speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSpeedtestPort))
|
||||
|
||||
speedConn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentSpeedtestPort))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("dial speedtest: %w", err)
|
||||
}
|
||||
|
||||
results, err := speedtest.RunClientWithConn(direction, duration, speedConn)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("run speedtest: %w", err)
|
||||
}
|
||||
|
||||
return results, err
|
||||
}
|
||||
|
||||
|
@ -259,19 +311,23 @@ func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest.
|
|||
func (c *WorkspaceAgentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
if network == "unix" {
|
||||
return nil, xerrors.New("network must be tcp or udp")
|
||||
}
|
||||
_, rawPort, _ := net.SplitHostPort(addr)
|
||||
port, _ := strconv.ParseUint(rawPort, 10, 16)
|
||||
ipp := netip.AddrPortFrom(WorkspaceAgentIP, uint16(port))
|
||||
|
||||
if !c.AwaitReachable(ctx) {
|
||||
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
|
||||
}
|
||||
if network == "udp" {
|
||||
|
||||
_, rawPort, _ := net.SplitHostPort(addr)
|
||||
port, _ := strconv.ParseUint(rawPort, 10, 16)
|
||||
ipp := netip.AddrPortFrom(c.agentAddress(), uint16(port))
|
||||
|
||||
switch network {
|
||||
case "tcp":
|
||||
return c.Conn.DialContextTCP(ctx, ipp)
|
||||
case "udp":
|
||||
return c.Conn.DialContextUDP(ctx, ipp)
|
||||
default:
|
||||
return nil, xerrors.Errorf("unknown network %q", network)
|
||||
}
|
||||
return c.Conn.DialContextTCP(ctx, ipp)
|
||||
}
|
||||
|
||||
type WorkspaceAgentListeningPortsResponse struct {
|
||||
|
@ -309,7 +365,8 @@ func (c *WorkspaceAgentConn) ListeningPorts(ctx context.Context) (WorkspaceAgent
|
|||
func (c *WorkspaceAgentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
host := net.JoinHostPort(WorkspaceAgentIP.String(), strconv.Itoa(WorkspaceAgentHTTPAPIServerPort))
|
||||
|
||||
host := net.JoinHostPort(c.agentAddress().String(), strconv.Itoa(WorkspaceAgentHTTPAPIServerPort))
|
||||
url := fmt.Sprintf("http://%s%s", host, path)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
|
@ -332,13 +389,14 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client {
|
|||
if network != "tcp" {
|
||||
return nil, xerrors.Errorf("network must be tcp")
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("split host port %q: %w", addr, err)
|
||||
}
|
||||
// Verify that host is TailnetIP and port is
|
||||
// TailnetStatisticsPort.
|
||||
if host != WorkspaceAgentIP.String() || port != strconv.Itoa(WorkspaceAgentHTTPAPIServerPort) {
|
||||
|
||||
// Verify that the port is TailnetStatisticsPort.
|
||||
if port != strconv.Itoa(WorkspaceAgentHTTPAPIServerPort) {
|
||||
return nil, xerrors.Errorf("request %q does not appear to be for http api", addr)
|
||||
}
|
||||
|
||||
|
@ -346,7 +404,12 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client {
|
|||
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
|
||||
}
|
||||
|
||||
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort))
|
||||
ipAddr, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse host addr: %w", err)
|
||||
}
|
||||
|
||||
conn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(ipAddr, WorkspaceAgentHTTPAPIServerPort))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("dial http api: %w", err)
|
||||
}
|
||||
|
|
|
@ -307,8 +307,8 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
|
|||
options.Logger.Debug(ctx, "failed to dial", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(node, false)
|
||||
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(nodes []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(nodes, false)
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
options.Logger.Debug(ctx, "serving coordinator")
|
||||
|
@ -330,13 +330,15 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
|
|||
return nil, err
|
||||
}
|
||||
|
||||
agentConn = &WorkspaceAgentConn{
|
||||
Conn: conn,
|
||||
CloseFunc: func() {
|
||||
agentConn = NewWorkspaceAgentConn(conn, WorkspaceAgentConnOptions{
|
||||
AgentID: agentID,
|
||||
CloseFunc: func() error {
|
||||
cancel()
|
||||
<-closed
|
||||
return conn.Close()
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
if !agentConn.AwaitReachable(ctx) {
|
||||
_ = agentConn.Close()
|
||||
return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", ctx.Err())
|
||||
|
|
|
@ -292,6 +292,7 @@ curl -X GET http://coder-server:8080/api/v2/workspaceagents/me/manifest \
|
|||
|
||||
```json
|
||||
{
|
||||
"agent_id": "string",
|
||||
"apps": [
|
||||
{
|
||||
"command": "string",
|
||||
|
|
|
@ -161,6 +161,7 @@
|
|||
|
||||
```json
|
||||
{
|
||||
"agent_id": "string",
|
||||
"apps": [
|
||||
{
|
||||
"command": "string",
|
||||
|
@ -260,6 +261,7 @@
|
|||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
| ---------------------------- | ------------------------------------------------------------------------------------------------- | -------- | ------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `agent_id` | string | false | | |
|
||||
| `apps` | array of [codersdk.WorkspaceApp](#codersdkworkspaceapp) | false | | |
|
||||
| `derpmap` | [tailcfg.DERPMap](#tailcfgderpmap) | false | | |
|
||||
| `directory` | string | false | | |
|
||||
|
@ -2543,6 +2545,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
|
|||
| `workspace_actions` |
|
||||
| `tailnet_ha_coordinator` |
|
||||
| `convert-to-oidc` |
|
||||
| `single_tailnet` |
|
||||
| `workspace_build_logs_ui` |
|
||||
|
||||
## codersdk.Feature
|
||||
|
|
|
@ -6,9 +6,8 @@ import (
|
|||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/cli/clibase"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
"github.com/coder/coder/codersdk"
|
||||
agpl "github.com/coder/coder/tailnet"
|
||||
)
|
||||
|
||||
|
@ -37,9 +38,12 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err
|
|||
closeFunc: cancelFunc,
|
||||
close: make(chan struct{}),
|
||||
nodes: map[uuid.UUID]*agpl.Node{},
|
||||
agentSockets: map[uuid.UUID]*agpl.TrackedConn{},
|
||||
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*agpl.TrackedConn{},
|
||||
agentSockets: map[uuid.UUID]agpl.Queue{},
|
||||
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]agpl.Queue{},
|
||||
agentNameCache: nameCache,
|
||||
clients: map[uuid.UUID]agpl.Queue{},
|
||||
clientsToAgents: map[uuid.UUID]map[uuid.UUID]agpl.Queue{},
|
||||
legacyAgents: map[uuid.UUID]struct{}{},
|
||||
}
|
||||
|
||||
if err := coord.runPubsub(ctx); err != nil {
|
||||
|
@ -49,6 +53,56 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err
|
|||
return coord, nil
|
||||
}
|
||||
|
||||
func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
|
||||
m := (&agpl.MultiAgent{
|
||||
ID: id,
|
||||
Logger: c.log,
|
||||
AgentIsLegacyFunc: c.agentIsLegacy,
|
||||
OnSubscribe: c.clientSubscribeToAgent,
|
||||
OnNodeUpdate: c.clientNodeUpdate,
|
||||
OnRemove: c.clientDisconnected,
|
||||
}).Init()
|
||||
c.addClient(id, m)
|
||||
return m
|
||||
}
|
||||
|
||||
func (c *haCoordinator) addClient(id uuid.UUID, q agpl.Queue) {
|
||||
c.mutex.Lock()
|
||||
c.clients[id] = q
|
||||
c.clientsToAgents[id] = map[uuid.UUID]agpl.Queue{}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *haCoordinator) clientSubscribeToAgent(enq agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.initOrSetAgentConnectionSocketLocked(agentID, enq)
|
||||
|
||||
node := c.nodes[enq.UniqueID()]
|
||||
if node != nil {
|
||||
err := c.sendNodeToAgentLocked(agentID, node)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("handle client update: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
agentNode, ok := c.nodes[agentID]
|
||||
// If we have the node locally, give it back to the multiagent.
|
||||
if ok {
|
||||
return agentNode, nil
|
||||
}
|
||||
|
||||
// If we don't have the node locally, notify other coordinators.
|
||||
err := c.publishClientHello(agentID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish client hello: %w", err)
|
||||
}
|
||||
|
||||
// nolint:nilnil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type haCoordinator struct {
|
||||
id uuid.UUID
|
||||
log slog.Logger
|
||||
|
@ -60,14 +114,26 @@ type haCoordinator struct {
|
|||
// nodes maps agent and connection IDs their respective node.
|
||||
nodes map[uuid.UUID]*agpl.Node
|
||||
// agentSockets maps agent IDs to their open websocket.
|
||||
agentSockets map[uuid.UUID]*agpl.TrackedConn
|
||||
agentSockets map[uuid.UUID]agpl.Queue
|
||||
// agentToConnectionSockets maps agent IDs to connection IDs of conns that
|
||||
// are subscribed to updates for that agent.
|
||||
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]*agpl.TrackedConn
|
||||
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]agpl.Queue
|
||||
|
||||
// clients holds a map of all clients connected to the coordinator. This is
|
||||
// necessary because a client may not be subscribed into any agents.
|
||||
clients map[uuid.UUID]agpl.Queue
|
||||
// clientsToAgents is an index of clients to all of their subscribed agents.
|
||||
clientsToAgents map[uuid.UUID]map[uuid.UUID]agpl.Queue
|
||||
|
||||
// agentNameCache holds a cache of agent names. If one of them disappears,
|
||||
// it's helpful to have a name cached for debugging.
|
||||
agentNameCache *lru.Cache[uuid.UUID, string]
|
||||
|
||||
// legacyAgents holda a mapping of all agents detected as legacy, meaning
|
||||
// they only listen on codersdk.WorkspaceAgentIP. They aren't compatible
|
||||
// with the new ServerTailnet, so they must be connected through
|
||||
// wsconncache.
|
||||
legacyAgents map[uuid.UUID]struct{}
|
||||
}
|
||||
|
||||
// Node returns an in-memory node by ID.
|
||||
|
@ -88,61 +154,35 @@ func (c *haCoordinator) agentLogger(agent uuid.UUID) slog.Logger {
|
|||
|
||||
// ServeClient accepts a WebSocket connection that wants to connect to an agent
|
||||
// with the specified ID.
|
||||
func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
|
||||
func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
logger := c.clientLogger(id, agent)
|
||||
|
||||
c.mutex.Lock()
|
||||
connectionSockets, ok := c.agentToConnectionSockets[agent]
|
||||
if !ok {
|
||||
connectionSockets = map[uuid.UUID]*agpl.TrackedConn{}
|
||||
c.agentToConnectionSockets[agent] = connectionSockets
|
||||
}
|
||||
logger := c.clientLogger(id, agentID)
|
||||
|
||||
tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0)
|
||||
// Insert this connection into a map so the agent
|
||||
// can publish node updates.
|
||||
connectionSockets[id] = tc
|
||||
defer tc.Close()
|
||||
|
||||
// When a new connection is requested, we update it with the latest
|
||||
// node of the agent. This allows the connection to establish.
|
||||
node, ok := c.nodes[agent]
|
||||
if ok {
|
||||
err := tc.Enqueue([]*agpl.Node{node})
|
||||
c.mutex.Unlock()
|
||||
c.addClient(id, tc)
|
||||
defer c.clientDisconnected(id)
|
||||
|
||||
agentNode, err := c.clientSubscribeToAgent(tc, agentID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("subscribe agent: %w", err)
|
||||
}
|
||||
|
||||
if agentNode != nil {
|
||||
err := tc.Enqueue([]*agpl.Node{agentNode})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("enqueue node: %w", err)
|
||||
}
|
||||
} else {
|
||||
c.mutex.Unlock()
|
||||
err := c.publishClientHello(agent)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("publish client hello: %w", err)
|
||||
logger.Debug(ctx, "enqueue initial node", slog.Error(err))
|
||||
}
|
||||
}
|
||||
go tc.SendUpdates()
|
||||
|
||||
defer func() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
// Clean all traces of this connection from the map.
|
||||
delete(c.nodes, id)
|
||||
connectionSockets, ok := c.agentToConnectionSockets[agent]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(connectionSockets, id)
|
||||
if len(connectionSockets) != 0 {
|
||||
return
|
||||
}
|
||||
delete(c.agentToConnectionSockets, agent)
|
||||
}()
|
||||
go tc.SendUpdates()
|
||||
|
||||
decoder := json.NewDecoder(conn)
|
||||
// Indefinitely handle messages from the client websocket.
|
||||
for {
|
||||
err := c.handleNextClientMessage(id, agent, decoder)
|
||||
err := c.handleNextClientMessage(id, decoder)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
|
||||
return nil
|
||||
|
@ -152,35 +192,90 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID
|
|||
}
|
||||
}
|
||||
|
||||
func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
|
||||
func (c *haCoordinator) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq agpl.Queue) {
|
||||
connectionSockets, ok := c.agentToConnectionSockets[agentID]
|
||||
if !ok {
|
||||
connectionSockets = map[uuid.UUID]agpl.Queue{}
|
||||
c.agentToConnectionSockets[agentID] = connectionSockets
|
||||
}
|
||||
connectionSockets[enq.UniqueID()] = enq
|
||||
c.clientsToAgents[enq.UniqueID()][agentID] = c.agentSockets[agentID]
|
||||
}
|
||||
|
||||
func (c *haCoordinator) clientDisconnected(id uuid.UUID) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
for agentID := range c.clientsToAgents[id] {
|
||||
// Clean all traces of this connection from the map.
|
||||
delete(c.nodes, id)
|
||||
connectionSockets, ok := c.agentToConnectionSockets[agentID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(connectionSockets, id)
|
||||
if len(connectionSockets) != 0 {
|
||||
return
|
||||
}
|
||||
delete(c.agentToConnectionSockets, agentID)
|
||||
}
|
||||
|
||||
delete(c.clients, id)
|
||||
delete(c.clientsToAgents, id)
|
||||
}
|
||||
|
||||
func (c *haCoordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error {
|
||||
var node agpl.Node
|
||||
err := decoder.Decode(&node)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("read json: %w", err)
|
||||
}
|
||||
|
||||
return c.clientNodeUpdate(id, &node)
|
||||
}
|
||||
|
||||
func (c *haCoordinator) clientNodeUpdate(id uuid.UUID, node *agpl.Node) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
// Update the node of this client in our in-memory map. If an agent entirely
|
||||
// shuts down and reconnects, it needs to be aware of all clients attempting
|
||||
// to establish connections.
|
||||
c.nodes[id] = &node
|
||||
// Write the new node from this client to the actively connected agent.
|
||||
agentSocket, ok := c.agentSockets[agent]
|
||||
c.nodes[id] = node
|
||||
|
||||
for agentID, agentSocket := range c.clientsToAgents[id] {
|
||||
if agentSocket == nil {
|
||||
// If we don't own the agent locally, send it over pubsub to a node that
|
||||
// owns the agent.
|
||||
err := c.publishNodesToAgent(agentID, []*agpl.Node{node})
|
||||
if err != nil {
|
||||
c.log.Error(context.Background(), "publish node to agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
}
|
||||
} else {
|
||||
// Write the new node from this client to the actively connected agent.
|
||||
err := agentSocket.Enqueue([]*agpl.Node{node})
|
||||
if err != nil {
|
||||
c.log.Error(context.Background(), "enqueue node to agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *haCoordinator) sendNodeToAgentLocked(agentID uuid.UUID, node *agpl.Node) error {
|
||||
agentSocket, ok := c.agentSockets[agentID]
|
||||
if !ok {
|
||||
c.mutex.Unlock()
|
||||
// If we don't own the agent locally, send it over pubsub to a node that
|
||||
// owns the agent.
|
||||
err := c.publishNodesToAgent(agent, []*agpl.Node{&node})
|
||||
err := c.publishNodesToAgent(agentID, []*agpl.Node{node})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("publish node to agent")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
err = agentSocket.Enqueue([]*agpl.Node{&node})
|
||||
c.mutex.Unlock()
|
||||
err := agentSocket.Enqueue([]*agpl.Node{node})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("enqueu nodes: %w", err)
|
||||
return xerrors.Errorf("enqueue node: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -202,7 +297,7 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err
|
|||
// dead.
|
||||
oldAgentSocket, ok := c.agentSockets[id]
|
||||
if ok {
|
||||
overwrites = oldAgentSocket.Overwrites + 1
|
||||
overwrites = oldAgentSocket.Overwrites() + 1
|
||||
_ = oldAgentSocket.Close()
|
||||
}
|
||||
// This uniquely identifies a connection that belongs to this goroutine.
|
||||
|
@ -219,6 +314,9 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err
|
|||
}
|
||||
}
|
||||
c.agentSockets[id] = tc
|
||||
for clientID := range c.agentToConnectionSockets[id] {
|
||||
c.clientsToAgents[clientID][id] = tc
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
go tc.SendUpdates()
|
||||
|
||||
|
@ -234,10 +332,13 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err
|
|||
|
||||
// Only delete the connection if it's ours. It could have been
|
||||
// overwritten.
|
||||
if idConn, ok := c.agentSockets[id]; ok && idConn.ID == unique {
|
||||
if idConn, ok := c.agentSockets[id]; ok && idConn.UniqueID() == unique {
|
||||
delete(c.agentSockets, id)
|
||||
delete(c.nodes, id)
|
||||
}
|
||||
for clientID := range c.agentToConnectionSockets[id] {
|
||||
c.clientsToAgents[clientID][id] = nil
|
||||
}
|
||||
}()
|
||||
|
||||
decoder := json.NewDecoder(conn)
|
||||
|
@ -285,6 +386,13 @@ func (c *haCoordinator) handleClientHello(id uuid.UUID) error {
|
|||
return c.publishAgentToNodes(id, node)
|
||||
}
|
||||
|
||||
func (c *haCoordinator) agentIsLegacy(agentID uuid.UUID) bool {
|
||||
c.mutex.RLock()
|
||||
_, ok := c.legacyAgents[agentID]
|
||||
c.mutex.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (*agpl.Node, error) {
|
||||
var node agpl.Node
|
||||
err := decoder.Decode(&node)
|
||||
|
@ -293,6 +401,11 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (
|
|||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
// Keep a cache of all legacy agents.
|
||||
if len(node.Addresses) > 0 && node.Addresses[0].Addr() == codersdk.WorkspaceAgentIP {
|
||||
c.legacyAgents[id] = struct{}{}
|
||||
}
|
||||
|
||||
oldNode := c.nodes[id]
|
||||
if oldNode != nil {
|
||||
if oldNode.AsOf.After(node.AsOf) {
|
||||
|
@ -311,7 +424,9 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (
|
|||
for _, connectionSocket := range connectionSockets {
|
||||
_ = connectionSocket.Enqueue([]*agpl.Node{&node})
|
||||
}
|
||||
|
||||
c.mutex.Unlock()
|
||||
|
||||
return &node, nil
|
||||
}
|
||||
|
||||
|
@ -334,20 +449,18 @@ func (c *haCoordinator) Close() error {
|
|||
for _, socket := range c.agentSockets {
|
||||
socket := socket
|
||||
go func() {
|
||||
_ = socket.Close()
|
||||
_ = socket.CoordinatorClose()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
for _, connMap := range c.agentToConnectionSockets {
|
||||
wg.Add(len(connMap))
|
||||
for _, socket := range connMap {
|
||||
socket := socket
|
||||
go func() {
|
||||
_ = socket.Close()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Add(len(c.clients))
|
||||
for _, client := range c.clients {
|
||||
client := client
|
||||
go func() {
|
||||
_ = client.CoordinatorClose()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
@ -422,13 +535,12 @@ func (c *haCoordinator) runPubsub(ctx context.Context) error {
|
|||
}
|
||||
go func() {
|
||||
for {
|
||||
var message []byte
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case message = <-messageQueue:
|
||||
case message := <-messageQueue:
|
||||
c.handlePubsubMessage(ctx, message)
|
||||
}
|
||||
c.handlePubsubMessage(ctx, message)
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
|
@ -125,6 +125,11 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store
|
|||
return c, nil
|
||||
}
|
||||
|
||||
func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
|
||||
_, _ = c, id
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (*pgCoord) ServeHTTPDebug(w http.ResponseWriter, _ *http.Request) {
|
||||
// TODO(spikecurtis) I'd like to hold off implementing this until after the rest of this is code reviewed.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
|
|
@ -183,9 +183,12 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
|
|||
SecurityKey: secKey,
|
||||
Logger: s.Logger.Named("proxy_token_provider"),
|
||||
},
|
||||
WorkspaceConnCache: wsconncache.New(s.DialWorkspaceAgent, 0),
|
||||
AppSecurityKey: secKey,
|
||||
AppSecurityKey: secKey,
|
||||
|
||||
// TODO: Convert wsproxy to use coderd.ServerTailnet.
|
||||
AgentProvider: &wsconncache.AgentProvider{
|
||||
Cache: wsconncache.New(s.DialWorkspaceAgent, 0),
|
||||
},
|
||||
DisablePathApps: opts.DisablePathApps,
|
||||
SecureAuthCookie: opts.SecureAuthCookie,
|
||||
}
|
||||
|
@ -273,6 +276,7 @@ func (s *Server) Close() error {
|
|||
tmp, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
_ = s.SDKClient.WorkspaceProxyGoingAway(tmp)
|
||||
_ = s.AppServer.AgentProvider.Close()
|
||||
|
||||
return s.AppServer.Close()
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
@ -377,7 +376,10 @@ func agentHTTPClient(conn *codersdk.WorkspaceAgentConn) *http.Client {
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse port %q: %w", port, err)
|
||||
}
|
||||
return conn.DialContextTCP(ctx, netip.AddrPortFrom(codersdk.WorkspaceAgentIP, uint16(portUint)))
|
||||
|
||||
// Addr doesn't matter here, besides the port. DialContext will
|
||||
// automatically choose the right IP to dial.
|
||||
return conn.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", portUint))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
|
@ -243,7 +244,7 @@ func Test_Runner(t *testing.T) {
|
|||
func setupRunnerTest(t *testing.T) (client *codersdk.Client, agentID uuid.UUID) {
|
||||
t.Helper()
|
||||
|
||||
client = coderdtest.New(t, &coderdtest.Options{
|
||||
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
@ -282,12 +283,16 @@ func setupRunnerTest(t *testing.T) (client *codersdk.Client, agentID uuid.UUID)
|
|||
agentClient.SetSessionToken(authToken)
|
||||
agentCloser := agent.New(agent.Options{
|
||||
Client: agentClient,
|
||||
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("agent"),
|
||||
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("agent").Leveled(slog.LevelDebug),
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = agentCloser.Close()
|
||||
})
|
||||
|
||||
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
require.Eventually(t, func() bool {
|
||||
t.Log("agent id", resources[0].Agents[0].ID)
|
||||
return (*api.TailnetCoordinator.Load()).Node(resources[0].Agents[0].ID) != nil
|
||||
}, testutil.WaitLong, testutil.IntervalMedium, "agent never connected")
|
||||
return client, resources[0].Agents[0].ID
|
||||
}
|
||||
|
|
|
@ -68,6 +68,7 @@ func TestRun(t *testing.T) {
|
|||
agentCloser := agent.New(agent.Options{
|
||||
Client: agentClient,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
t.Cleanup(func() {
|
||||
|
|
|
@ -1431,12 +1431,14 @@ export const Entitlements: Entitlement[] = [
|
|||
export type Experiment =
|
||||
| "convert-to-oidc"
|
||||
| "moons"
|
||||
| "single_tailnet"
|
||||
| "tailnet_ha_coordinator"
|
||||
| "workspace_actions"
|
||||
| "workspace_build_logs_ui"
|
||||
export const Experiments: Experiment[] = [
|
||||
"convert-to-oidc",
|
||||
"moons",
|
||||
"single_tailnet",
|
||||
"tailnet_ha_coordinator",
|
||||
"workspace_actions",
|
||||
"workspace_build_logs_ui",
|
||||
|
|
106
tailnet/conn.go
106
tailnet/conn.go
|
@ -139,6 +139,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
|
|||
}
|
||||
}()
|
||||
|
||||
IP()
|
||||
dialer := &tsdial.Dialer{
|
||||
Logf: Logger(options.Logger.Named("tsdial")),
|
||||
}
|
||||
|
@ -182,10 +183,17 @@ func NewConn(options *Options) (conn *Conn, err error) {
|
|||
netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey()
|
||||
|
||||
netStack, err := netstack.Create(
|
||||
Logger(options.Logger.Named("netstack")), tunDevice, wireguardEngine, magicConn, dialer, dnsManager)
|
||||
Logger(options.Logger.Named("netstack")),
|
||||
tunDevice,
|
||||
wireguardEngine,
|
||||
magicConn,
|
||||
dialer,
|
||||
dnsManager,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create netstack: %w", err)
|
||||
}
|
||||
|
||||
dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
|
||||
return netStack.DialContextTCP(ctx, dst)
|
||||
}
|
||||
|
@ -203,7 +211,14 @@ func NewConn(options *Options) (conn *Conn, err error) {
|
|||
localIPs, _ := localIPSet.IPSet()
|
||||
logIPSet := netipx.IPSetBuilder{}
|
||||
logIPs, _ := logIPSet.IPSet()
|
||||
wireguardEngine.SetFilter(filter.New(netMap.PacketFilter, localIPs, logIPs, nil, Logger(options.Logger.Named("packet-filter"))))
|
||||
wireguardEngine.SetFilter(filter.New(
|
||||
netMap.PacketFilter,
|
||||
localIPs,
|
||||
logIPs,
|
||||
nil,
|
||||
Logger(options.Logger.Named("packet-filter")),
|
||||
))
|
||||
|
||||
dialContext, dialCancel := context.WithCancel(context.Background())
|
||||
server := &Conn{
|
||||
blockEndpoints: options.BlockEndpoints,
|
||||
|
@ -230,6 +245,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
|
|||
_ = server.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) {
|
||||
server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.Error(err))
|
||||
if err != nil {
|
||||
|
@ -251,6 +267,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
|
|||
server.lastMutex.Unlock()
|
||||
server.sendNode()
|
||||
})
|
||||
|
||||
wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) {
|
||||
server.logger.Debug(context.Background(), "netinfo callback", slog.F("netinfo", ni))
|
||||
server.lastMutex.Lock()
|
||||
|
@ -262,6 +279,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
|
|||
server.lastMutex.Unlock()
|
||||
server.sendNode()
|
||||
})
|
||||
|
||||
magicConn.SetDERPForcedWebsocketCallback(func(region int, reason string) {
|
||||
server.logger.Debug(context.Background(), "derp forced websocket", slog.F("region", region), slog.F("reason", reason))
|
||||
server.lastMutex.Lock()
|
||||
|
@ -273,6 +291,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
|
|||
server.lastMutex.Unlock()
|
||||
server.sendNode()
|
||||
})
|
||||
|
||||
netStack.ForwardTCPIn = server.forwardTCP
|
||||
netStack.ForwardTCPSockOpts = server.forwardTCPSockOpts
|
||||
|
||||
|
@ -284,22 +303,30 @@ func NewConn(options *Options) (conn *Conn, err error) {
|
|||
return server, nil
|
||||
}
|
||||
|
||||
// IP generates a new IP with a static service prefix.
|
||||
func IP() netip.Addr {
|
||||
// This is Tailscale's ephemeral service prefix.
|
||||
// This can be changed easily later-on, because
|
||||
// all of our nodes are ephemeral.
|
||||
func maskUUID(uid uuid.UUID) uuid.UUID {
|
||||
// This is Tailscale's ephemeral service prefix. This can be changed easily
|
||||
// later-on, because all of our nodes are ephemeral.
|
||||
// fd7a:115c:a1e0
|
||||
uid := uuid.New()
|
||||
uid[0] = 0xfd
|
||||
uid[1] = 0x7a
|
||||
uid[2] = 0x11
|
||||
uid[3] = 0x5c
|
||||
uid[4] = 0xa1
|
||||
uid[5] = 0xe0
|
||||
return uid
|
||||
}
|
||||
|
||||
// IP generates a random IP with a static service prefix.
|
||||
func IP() netip.Addr {
|
||||
uid := maskUUID(uuid.New())
|
||||
return netip.AddrFrom16(uid)
|
||||
}
|
||||
|
||||
// IP generates a new IP from a UUID.
|
||||
func IPFromUUID(uid uuid.UUID) netip.Addr {
|
||||
return netip.AddrFrom16(maskUUID(uid))
|
||||
}
|
||||
|
||||
// Conn is an actively listening Wireguard connection.
|
||||
type Conn struct {
|
||||
dialContext context.Context
|
||||
|
@ -334,6 +361,29 @@ type Conn struct {
|
|||
trafficStats *connstats.Statistics
|
||||
}
|
||||
|
||||
func (c *Conn) SetAddresses(ips []netip.Prefix) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.netMap.Addresses = ips
|
||||
|
||||
netMapCopy := *c.netMap
|
||||
c.logger.Debug(context.Background(), "updating network map")
|
||||
c.wireguardEngine.SetNetworkMap(&netMapCopy)
|
||||
err := c.reconfig()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("reconfig: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) Addresses() []netip.Prefix {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
return c.netMap.Addresses
|
||||
}
|
||||
|
||||
func (c *Conn) SetNodeCallback(callback func(node *Node)) {
|
||||
c.lastMutex.Lock()
|
||||
c.nodeCallback = callback
|
||||
|
@ -366,32 +416,6 @@ func (c *Conn) SetDERPRegionDialer(dialer func(ctx context.Context, region *tail
|
|||
c.magicConn.SetDERPRegionDialer(dialer)
|
||||
}
|
||||
|
||||
func (c *Conn) RemoveAllPeers() error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.netMap.Peers = []*tailcfg.Node{}
|
||||
c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{}
|
||||
netMapCopy := *c.netMap
|
||||
c.logger.Debug(context.Background(), "updating network map")
|
||||
c.wireguardEngine.SetNetworkMap(&netMapCopy)
|
||||
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "")
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update wireguard config: %w", err)
|
||||
}
|
||||
err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{})
|
||||
if err != nil {
|
||||
if c.isClosed() {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, wgengine.ErrNoChanges) {
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("reconfig: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateNodes connects with a set of peers. This can be constantly updated,
|
||||
// and peers will continually be reconnected as necessary. If replacePeers is
|
||||
// true, all peers will be removed before adding the new ones.
|
||||
|
@ -423,6 +447,7 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
|
|||
}
|
||||
delete(c.peerMap, peer.ID)
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
// If no preferred DERP is provided, we can't reach the node.
|
||||
if node.PreferredDERP == 0 {
|
||||
|
@ -452,17 +477,29 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
|
|||
}
|
||||
c.peerMap[node.ID] = peerNode
|
||||
}
|
||||
|
||||
c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
|
||||
for _, peer := range c.peerMap {
|
||||
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
|
||||
}
|
||||
|
||||
netMapCopy := *c.netMap
|
||||
c.logger.Debug(context.Background(), "updating network map")
|
||||
c.wireguardEngine.SetNetworkMap(&netMapCopy)
|
||||
err := c.reconfig()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("reconfig: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) reconfig() error {
|
||||
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "")
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update wireguard config: %w", err)
|
||||
}
|
||||
|
||||
err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{})
|
||||
if err != nil {
|
||||
if c.isClosed() {
|
||||
|
@ -473,6 +510,7 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
|
|||
}
|
||||
return xerrors.Errorf("reconfig: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ func TestMain(m *testing.M) {
|
|||
func TestTailnet(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
derpMap := tailnettest.RunDERPAndSTUN(t)
|
||||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||||
t.Run("InstantClose", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
|
@ -172,7 +172,7 @@ func TestConn_PreferredDERP(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
derpMap := tailnettest.RunDERPAndSTUN(t)
|
||||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
Logger: logger.Named("w1"),
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package tailnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
@ -11,17 +10,16 @@ import (
|
|||
"net/http"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
// Coordinator exchanges nodes with agents to establish connections.
|
||||
|
@ -44,6 +42,8 @@ type Coordinator interface {
|
|||
ServeAgent(conn net.Conn, id uuid.UUID, name string) error
|
||||
// Close closes the coordinator.
|
||||
Close() error
|
||||
|
||||
ServeMultiAgent(id uuid.UUID) MultiAgentConn
|
||||
}
|
||||
|
||||
// Node represents a node in the network.
|
||||
|
@ -54,10 +54,11 @@ type Node struct {
|
|||
AsOf time.Time `json:"as_of"`
|
||||
// Key is the Wireguard public key of the node.
|
||||
Key key.NodePublic `json:"key"`
|
||||
// DiscoKey is used for discovery messages over DERP to establish peer-to-peer connections.
|
||||
// DiscoKey is used for discovery messages over DERP to establish
|
||||
// peer-to-peer connections.
|
||||
DiscoKey key.DiscoPublic `json:"disco"`
|
||||
// PreferredDERP is the DERP server that peered connections
|
||||
// should meet at to establish.
|
||||
// PreferredDERP is the DERP server that peered connections should meet at
|
||||
// to establish.
|
||||
PreferredDERP int `json:"preferred_derp"`
|
||||
// DERPLatency is the latency in seconds to each DERP server.
|
||||
DERPLatency map[string]float64 `json:"derp_latency"`
|
||||
|
@ -68,8 +69,8 @@ type Node struct {
|
|||
DERPForcedWebsocket map[int]string `json:"derp_forced_websockets"`
|
||||
// Addresses are the IP address ranges this connection exposes.
|
||||
Addresses []netip.Prefix `json:"addresses"`
|
||||
// AllowedIPs specify what addresses can dial the connection.
|
||||
// We allow all by default.
|
||||
// AllowedIPs specify what addresses can dial the connection. We allow all
|
||||
// by default.
|
||||
AllowedIPs []netip.Prefix `json:"allowed_ips"`
|
||||
// Endpoints are ip:port combinations that can be used to establish
|
||||
// peer-to-peer connections.
|
||||
|
@ -130,12 +131,33 @@ func NewCoordinator(logger slog.Logger) Coordinator {
|
|||
// ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐
|
||||
// │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│
|
||||
// └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘
|
||||
// This coordinator is incompatible with multiple Coder
|
||||
// replicas as all node data is in-memory.
|
||||
// This coordinator is incompatible with multiple Coder replicas as all node
|
||||
// data is in-memory.
|
||||
type coordinator struct {
|
||||
core *core
|
||||
}
|
||||
|
||||
func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn {
|
||||
m := (&MultiAgent{
|
||||
ID: id,
|
||||
Logger: c.core.logger,
|
||||
AgentIsLegacyFunc: c.core.agentIsLegacy,
|
||||
OnSubscribe: c.core.clientSubscribeToAgent,
|
||||
OnUnsubscribe: c.core.clientUnsubscribeFromAgent,
|
||||
OnNodeUpdate: c.core.clientNodeUpdate,
|
||||
OnRemove: c.core.clientDisconnected,
|
||||
}).Init()
|
||||
c.core.addClient(id, m)
|
||||
return m
|
||||
}
|
||||
|
||||
func (c *core) addClient(id uuid.UUID, ma Queue) {
|
||||
c.mutex.Lock()
|
||||
c.clients[id] = ma
|
||||
c.clientsToAgents[id] = map[uuid.UUID]Queue{}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
// core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines;
|
||||
// it is protected by a mutex to ensure data stay consistent.
|
||||
type core struct {
|
||||
|
@ -146,14 +168,38 @@ type core struct {
|
|||
// nodes maps agent and connection IDs their respective node.
|
||||
nodes map[uuid.UUID]*Node
|
||||
// agentSockets maps agent IDs to their open websocket.
|
||||
agentSockets map[uuid.UUID]*TrackedConn
|
||||
agentSockets map[uuid.UUID]Queue
|
||||
// agentToConnectionSockets maps agent IDs to connection IDs of conns that
|
||||
// are subscribed to updates for that agent.
|
||||
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]*TrackedConn
|
||||
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]Queue
|
||||
|
||||
// clients holds a map of all clients connected to the coordinator. This is
|
||||
// necessary because a client may not be subscribed into any agents.
|
||||
clients map[uuid.UUID]Queue
|
||||
// clientsToAgents is an index of clients to all of their subscribed agents.
|
||||
clientsToAgents map[uuid.UUID]map[uuid.UUID]Queue
|
||||
|
||||
// agentNameCache holds a cache of agent names. If one of them disappears,
|
||||
// it's helpful to have a name cached for debugging.
|
||||
agentNameCache *lru.Cache[uuid.UUID, string]
|
||||
|
||||
// legacyAgents holda a mapping of all agents detected as legacy, meaning
|
||||
// they only listen on codersdk.WorkspaceAgentIP. They aren't compatible
|
||||
// with the new ServerTailnet, so they must be connected through
|
||||
// wsconncache.
|
||||
legacyAgents map[uuid.UUID]struct{}
|
||||
}
|
||||
|
||||
type Queue interface {
|
||||
UniqueID() uuid.UUID
|
||||
Enqueue(n []*Node) error
|
||||
Name() string
|
||||
Stats() (start, lastWrite int64)
|
||||
Overwrites() int64
|
||||
// CoordinatorClose is used by the coordinator when closing a Queue. It
|
||||
// should skip removing itself from the coordinator.
|
||||
CoordinatorClose() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
func newCore(logger slog.Logger) *core {
|
||||
|
@ -165,128 +211,18 @@ func newCore(logger slog.Logger) *core {
|
|||
return &core{
|
||||
logger: logger,
|
||||
closed: false,
|
||||
nodes: make(map[uuid.UUID]*Node),
|
||||
agentSockets: map[uuid.UUID]*TrackedConn{},
|
||||
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*TrackedConn{},
|
||||
nodes: map[uuid.UUID]*Node{},
|
||||
agentSockets: map[uuid.UUID]Queue{},
|
||||
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]Queue{},
|
||||
agentNameCache: nameCache,
|
||||
legacyAgents: map[uuid.UUID]struct{}{},
|
||||
clients: map[uuid.UUID]Queue{},
|
||||
clientsToAgents: map[uuid.UUID]map[uuid.UUID]Queue{},
|
||||
}
|
||||
}
|
||||
|
||||
var ErrWouldBlock = xerrors.New("would block")
|
||||
|
||||
type TrackedConn struct {
|
||||
ctx context.Context
|
||||
cancel func()
|
||||
conn net.Conn
|
||||
updates chan []*Node
|
||||
logger slog.Logger
|
||||
lastData []byte
|
||||
|
||||
// ID is an ephemeral UUID used to uniquely identify the owner of the
|
||||
// connection.
|
||||
ID uuid.UUID
|
||||
|
||||
Name string
|
||||
Start int64
|
||||
LastWrite int64
|
||||
Overwrites int64
|
||||
}
|
||||
|
||||
func (t *TrackedConn) Enqueue(n []*Node) (err error) {
|
||||
atomic.StoreInt64(&t.LastWrite, time.Now().Unix())
|
||||
select {
|
||||
case t.updates <- n:
|
||||
return nil
|
||||
default:
|
||||
return ErrWouldBlock
|
||||
}
|
||||
}
|
||||
|
||||
// Close the connection and cancel the context for reading node updates from the queue
|
||||
func (t *TrackedConn) Close() error {
|
||||
t.cancel()
|
||||
return t.conn.Close()
|
||||
}
|
||||
|
||||
// WriteTimeout is the amount of time we wait to write a node update to a connection before we declare it hung.
|
||||
// It is exported so that tests can use it.
|
||||
const WriteTimeout = time.Second * 5
|
||||
|
||||
// SendUpdates reads node updates and writes them to the connection. Ends when writes hit an error or context is
|
||||
// canceled.
|
||||
func (t *TrackedConn) SendUpdates() {
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
t.logger.Debug(t.ctx, "done sending updates")
|
||||
return
|
||||
case nodes := <-t.updates:
|
||||
data, err := json.Marshal(nodes)
|
||||
if err != nil {
|
||||
t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes))
|
||||
return
|
||||
}
|
||||
if bytes.Equal(t.lastData, data) {
|
||||
t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", string(data)))
|
||||
continue
|
||||
}
|
||||
|
||||
// Set a deadline so that hung connections don't put back pressure on the system.
|
||||
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
|
||||
err = t.conn.SetWriteDeadline(time.Now().Add(WriteTimeout))
|
||||
if err != nil {
|
||||
// often, this is just because the connection is closed/broken, so only log at debug.
|
||||
t.logger.Debug(t.ctx, "unable to set write deadline", slog.Error(err))
|
||||
_ = t.Close()
|
||||
return
|
||||
}
|
||||
_, err = t.conn.Write(data)
|
||||
if err != nil {
|
||||
// often, this is just because the connection is closed/broken, so only log at debug.
|
||||
t.logger.Debug(t.ctx, "could not write nodes to connection",
|
||||
slog.Error(err), slog.F("nodes", string(data)))
|
||||
_ = t.Close()
|
||||
return
|
||||
}
|
||||
t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", string(data)))
|
||||
|
||||
// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are
|
||||
// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write()
|
||||
// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context.
|
||||
// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after
|
||||
// our successful write, it is important that we reset the deadline before it fires.
|
||||
err = t.conn.SetWriteDeadline(time.Time{})
|
||||
if err != nil {
|
||||
// often, this is just because the connection is closed/broken, so only log at debug.
|
||||
t.logger.Debug(t.ctx, "unable to extend write deadline", slog.Error(err))
|
||||
_ = t.Close()
|
||||
return
|
||||
}
|
||||
t.lastData = data
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.UUID, logger slog.Logger, overwrites int64) *TrackedConn {
|
||||
// buffer updates so they don't block, since we hold the
|
||||
// coordinator mutex while queuing. Node updates don't
|
||||
// come quickly, so 512 should be plenty for all but
|
||||
// the most pathological cases.
|
||||
updates := make(chan []*Node, 512)
|
||||
now := time.Now().Unix()
|
||||
return &TrackedConn{
|
||||
ctx: ctx,
|
||||
conn: conn,
|
||||
cancel: cancel,
|
||||
updates: updates,
|
||||
logger: logger,
|
||||
ID: id,
|
||||
Start: now,
|
||||
LastWrite: now,
|
||||
Overwrites: overwrites,
|
||||
}
|
||||
}
|
||||
|
||||
// Node returns an in-memory node by ID.
|
||||
// If the node does not exist, nil is returned.
|
||||
func (c *coordinator) Node(id uuid.UUID) *Node {
|
||||
|
@ -321,16 +257,29 @@ func (c *core) agentCount() int {
|
|||
|
||||
// ServeClient accepts a WebSocket connection that wants to connect to an agent
|
||||
// with the specified ID.
|
||||
func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
|
||||
func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
logger := c.core.clientLogger(id, agent)
|
||||
logger := c.core.clientLogger(id, agentID)
|
||||
logger.Debug(ctx, "coordinating client")
|
||||
tc, err := c.core.initAndTrackClient(ctx, cancel, conn, id, agent)
|
||||
|
||||
tc := NewTrackedConn(ctx, cancel, conn, id, logger, 0)
|
||||
defer tc.Close()
|
||||
|
||||
c.core.addClient(id, tc)
|
||||
defer c.core.clientDisconnected(id)
|
||||
|
||||
agentNode, err := c.core.clientSubscribeToAgent(tc, agentID)
|
||||
if err != nil {
|
||||
return err
|
||||
return xerrors.Errorf("subscribe agent: %w", err)
|
||||
}
|
||||
|
||||
if agentNode != nil {
|
||||
err := tc.Enqueue([]*Node{agentNode})
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "enqueue initial node", slog.Error(err))
|
||||
}
|
||||
}
|
||||
defer c.core.clientDisconnected(id, agent)
|
||||
|
||||
// On this goroutine, we read updates from the client and publish them. We start a second goroutine
|
||||
// to write updates back to the client.
|
||||
|
@ -338,7 +287,7 @@ func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID)
|
|||
|
||||
decoder := json.NewDecoder(conn)
|
||||
for {
|
||||
err := c.handleNextClientMessage(id, agent, decoder)
|
||||
err := c.handleNextClientMessage(id, decoder)
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "unable to read client update, connection may be closed", slog.Error(err))
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) {
|
||||
|
@ -353,99 +302,133 @@ func (c *core) clientLogger(id, agent uuid.UUID) slog.Logger {
|
|||
return c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent))
|
||||
}
|
||||
|
||||
// initAndTrackClient creates a TrackedConn for the client, and sends any initial Node updates if we have any. It is
|
||||
// one function that does two things because it is critical that we hold the mutex for both things, lest we miss some
|
||||
// updates.
|
||||
func (c *core) initAndTrackClient(
|
||||
ctx context.Context, cancel func(), conn net.Conn, id, agent uuid.UUID,
|
||||
) (
|
||||
*TrackedConn, error,
|
||||
) {
|
||||
logger := c.clientLogger(id, agent)
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if c.closed {
|
||||
return nil, xerrors.New("coordinator is closed")
|
||||
}
|
||||
tc := NewTrackedConn(ctx, cancel, conn, id, logger, 0)
|
||||
|
||||
// When a new connection is requested, we update it with the latest
|
||||
// node of the agent. This allows the connection to establish.
|
||||
node, ok := c.nodes[agent]
|
||||
if ok {
|
||||
err := tc.Enqueue([]*Node{node})
|
||||
// this should never error since we're still the only goroutine that
|
||||
// knows about the TrackedConn. If we hit an error something really
|
||||
// wrong is happening
|
||||
if err != nil {
|
||||
logger.Critical(ctx, "unable to queue initial node", slog.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Insert this connection into a map so the agent
|
||||
// can publish node updates.
|
||||
connectionSockets, ok := c.agentToConnectionSockets[agent]
|
||||
func (c *core) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq Queue) {
|
||||
connectionSockets, ok := c.agentToConnectionSockets[agentID]
|
||||
if !ok {
|
||||
connectionSockets = map[uuid.UUID]*TrackedConn{}
|
||||
c.agentToConnectionSockets[agent] = connectionSockets
|
||||
connectionSockets = map[uuid.UUID]Queue{}
|
||||
c.agentToConnectionSockets[agentID] = connectionSockets
|
||||
}
|
||||
connectionSockets[id] = tc
|
||||
logger.Debug(ctx, "added tracked connection")
|
||||
return tc, nil
|
||||
connectionSockets[enq.UniqueID()] = enq
|
||||
|
||||
c.clientsToAgents[enq.UniqueID()][agentID] = c.agentSockets[agentID]
|
||||
}
|
||||
|
||||
func (c *core) clientDisconnected(id, agent uuid.UUID) {
|
||||
logger := c.clientLogger(id, agent)
|
||||
func (c *core) clientDisconnected(id uuid.UUID) {
|
||||
logger := c.clientLogger(id, uuid.Nil)
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
// Clean all traces of this connection from the map.
|
||||
delete(c.nodes, id)
|
||||
logger.Debug(context.Background(), "deleted client node")
|
||||
connectionSockets, ok := c.agentToConnectionSockets[agent]
|
||||
if !ok {
|
||||
return
|
||||
|
||||
for agentID := range c.clientsToAgents[id] {
|
||||
connectionSockets, ok := c.agentToConnectionSockets[agentID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
delete(connectionSockets, id)
|
||||
logger.Debug(context.Background(), "deleted client connectionSocket from map", slog.F("agent_id", agentID))
|
||||
|
||||
if len(connectionSockets) == 0 {
|
||||
delete(c.agentToConnectionSockets, agentID)
|
||||
logger.Debug(context.Background(), "deleted last client connectionSocket from map", slog.F("agent_id", agentID))
|
||||
}
|
||||
}
|
||||
delete(connectionSockets, id)
|
||||
logger.Debug(context.Background(), "deleted client connectionSocket from map")
|
||||
if len(connectionSockets) != 0 {
|
||||
return
|
||||
}
|
||||
delete(c.agentToConnectionSockets, agent)
|
||||
logger.Debug(context.Background(), "deleted last client connectionSocket from map")
|
||||
|
||||
delete(c.clients, id)
|
||||
delete(c.clientsToAgents, id)
|
||||
logger.Debug(context.Background(), "deleted client agents")
|
||||
}
|
||||
|
||||
func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
|
||||
logger := c.core.clientLogger(id, agent)
|
||||
func (c *coordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error {
|
||||
logger := c.core.clientLogger(id, uuid.Nil)
|
||||
|
||||
var node Node
|
||||
err := decoder.Decode(&node)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("read json: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug(context.Background(), "got client node update", slog.F("node", node))
|
||||
return c.core.clientNodeUpdate(id, agent, &node)
|
||||
return c.core.clientNodeUpdate(id, &node)
|
||||
}
|
||||
|
||||
func (c *core) clientNodeUpdate(id, agent uuid.UUID, node *Node) error {
|
||||
logger := c.clientLogger(id, agent)
|
||||
func (c *core) clientNodeUpdate(id uuid.UUID, node *Node) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Update the node of this client in our in-memory map. If an agent entirely
|
||||
// shuts down and reconnects, it needs to be aware of all clients attempting
|
||||
// to establish connections.
|
||||
c.nodes[id] = node
|
||||
|
||||
agentSocket, ok := c.agentSockets[agent]
|
||||
if !ok {
|
||||
logger.Debug(context.Background(), "no agent socket, unable to send node")
|
||||
return nil
|
||||
return c.clientNodeUpdateLocked(id, node)
|
||||
}
|
||||
|
||||
func (c *core) clientNodeUpdateLocked(id uuid.UUID, node *Node) error {
|
||||
logger := c.clientLogger(id, uuid.Nil)
|
||||
|
||||
agents := []uuid.UUID{}
|
||||
for agentID, agentSocket := range c.clientsToAgents[id] {
|
||||
if agentSocket == nil {
|
||||
logger.Debug(context.Background(), "enqueue node to agent; socket is nil", slog.F("agent_id", agentID))
|
||||
continue
|
||||
}
|
||||
|
||||
err := agentSocket.Enqueue([]*Node{node})
|
||||
if err != nil {
|
||||
logger.Debug(context.Background(), "unable to Enqueue node to agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
continue
|
||||
}
|
||||
agents = append(agents, agentID)
|
||||
}
|
||||
|
||||
err := agentSocket.Enqueue([]*Node{node})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("Enqueue node: %w", err)
|
||||
logger.Debug(context.Background(), "enqueued node to agents", slog.F("agent_ids", agents))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *core) clientSubscribeToAgent(enq Queue, agentID uuid.UUID) (*Node, error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
logger := c.clientLogger(enq.UniqueID(), agentID)
|
||||
|
||||
c.initOrSetAgentConnectionSocketLocked(agentID, enq)
|
||||
|
||||
node, ok := c.nodes[enq.UniqueID()]
|
||||
if ok {
|
||||
// If we have the client node, send it to the agent. If not, it will be
|
||||
// sent async.
|
||||
agentSocket, ok := c.agentSockets[agentID]
|
||||
if !ok {
|
||||
logger.Debug(context.Background(), "subscribe to agent; socket is nil")
|
||||
} else {
|
||||
err := agentSocket.Enqueue([]*Node{node})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("enqueue client to agent: %w", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.Debug(context.Background(), "multiagent node doesn't exist")
|
||||
}
|
||||
logger.Debug(context.Background(), "enqueued node to agent")
|
||||
|
||||
agentNode, ok := c.nodes[agentID]
|
||||
if !ok {
|
||||
// This is ok, once the agent connects the node will be sent over.
|
||||
logger.Debug(context.Background(), "agent node doesn't exist", slog.F("agent_id", agentID))
|
||||
}
|
||||
|
||||
// Send the subscribed agent back to the multi agent.
|
||||
return agentNode, nil
|
||||
}
|
||||
|
||||
func (c *core) clientUnsubscribeFromAgent(enq Queue, agentID uuid.UUID) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
delete(c.clientsToAgents[enq.UniqueID()], agentID)
|
||||
delete(c.agentToConnectionSockets[agentID], enq.UniqueID())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -493,11 +476,14 @@ func (c *core) agentDisconnected(id, unique uuid.UUID) {
|
|||
|
||||
// Only delete the connection if it's ours. It could have been
|
||||
// overwritten.
|
||||
if idConn, ok := c.agentSockets[id]; ok && idConn.ID == unique {
|
||||
if idConn, ok := c.agentSockets[id]; ok && idConn.UniqueID() == unique {
|
||||
delete(c.agentSockets, id)
|
||||
delete(c.nodes, id)
|
||||
logger.Debug(context.Background(), "deleted agent socket and node")
|
||||
}
|
||||
for clientID := range c.agentToConnectionSockets[id] {
|
||||
c.clientsToAgents[clientID][id] = nil
|
||||
}
|
||||
}
|
||||
|
||||
// initAndTrackAgent creates a TrackedConn for the agent, and sends any initial nodes updates if we have any. It is
|
||||
|
@ -519,7 +505,7 @@ func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Co
|
|||
// dead.
|
||||
oldAgentSocket, ok := c.agentSockets[id]
|
||||
if ok {
|
||||
overwrites = oldAgentSocket.Overwrites + 1
|
||||
overwrites = oldAgentSocket.Overwrites() + 1
|
||||
_ = oldAgentSocket.Close()
|
||||
}
|
||||
tc := NewTrackedConn(ctx, cancel, conn, unique, logger, overwrites)
|
||||
|
@ -549,6 +535,10 @@ func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Co
|
|||
}
|
||||
|
||||
c.agentSockets[id] = tc
|
||||
for clientID := range c.agentToConnectionSockets[id] {
|
||||
c.clientsToAgents[clientID][id] = tc
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "added agent socket")
|
||||
return tc, nil
|
||||
}
|
||||
|
@ -564,11 +554,31 @@ func (c *coordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.Decoder
|
|||
return c.core.agentNodeUpdate(id, &node)
|
||||
}
|
||||
|
||||
// This is copied from codersdk because importing it here would cause an import
|
||||
// cycle. This is just temporary until wsconncache is phased out.
|
||||
var legacyAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")
|
||||
|
||||
// This is temporary until we no longer need to detect for agent backwards
|
||||
// compatibility.
|
||||
// See: https://github.com/coder/coder/issues/8218
|
||||
func (c *core) agentIsLegacy(agentID uuid.UUID) bool {
|
||||
c.mutex.RLock()
|
||||
_, ok := c.legacyAgents[agentID]
|
||||
c.mutex.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error {
|
||||
logger := c.agentLogger(id)
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.nodes[id] = node
|
||||
|
||||
// Keep a cache of all legacy agents.
|
||||
if len(node.Addresses) > 0 && node.Addresses[0].Addr() == legacyAgentIP {
|
||||
c.legacyAgents[id] = struct{}{}
|
||||
}
|
||||
|
||||
connectionSockets, ok := c.agentToConnectionSockets[id]
|
||||
if !ok {
|
||||
logger.Debug(context.Background(), "no client sockets; unable to send node")
|
||||
|
@ -588,6 +598,7 @@ func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error {
|
|||
slog.F("client_id", clientID), slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -611,20 +622,18 @@ func (c *core) close() error {
|
|||
for _, socket := range c.agentSockets {
|
||||
socket := socket
|
||||
go func() {
|
||||
_ = socket.Close()
|
||||
_ = socket.CoordinatorClose()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
for _, connMap := range c.agentToConnectionSockets {
|
||||
wg.Add(len(connMap))
|
||||
for _, socket := range connMap {
|
||||
socket := socket
|
||||
go func() {
|
||||
_ = socket.Close()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Add(len(c.clients))
|
||||
for _, client := range c.clients {
|
||||
client := client
|
||||
go func() {
|
||||
_ = client.CoordinatorClose()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
c.mutex.Unlock()
|
||||
|
@ -649,8 +658,8 @@ func (c *core) serveHTTPDebug(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func CoordinatorHTTPDebug(
|
||||
agentSocketsMap map[uuid.UUID]*TrackedConn,
|
||||
agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]*TrackedConn,
|
||||
agentSocketsMap map[uuid.UUID]Queue,
|
||||
agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]Queue,
|
||||
agentNameCache *lru.Cache[uuid.UUID, string],
|
||||
) func(w http.ResponseWriter, _ *http.Request) {
|
||||
return func(w http.ResponseWriter, _ *http.Request) {
|
||||
|
@ -658,7 +667,7 @@ func CoordinatorHTTPDebug(
|
|||
|
||||
type idConn struct {
|
||||
id uuid.UUID
|
||||
conn *TrackedConn
|
||||
conn Queue
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -671,16 +680,17 @@ func CoordinatorHTTPDebug(
|
|||
}
|
||||
|
||||
slices.SortFunc(agentSockets, func(a, b idConn) bool {
|
||||
return a.conn.Name < b.conn.Name
|
||||
return a.conn.Name() < b.conn.Name()
|
||||
})
|
||||
|
||||
for _, agent := range agentSockets {
|
||||
start, lastWrite := agent.conn.Stats()
|
||||
_, _ = fmt.Fprintf(w, "<li style=\"margin-top:4px\"><b>%s</b> (<code>%s</code>): created %v ago, write %v ago, overwrites %d </li>\n",
|
||||
agent.conn.Name,
|
||||
agent.conn.Name(),
|
||||
agent.id.String(),
|
||||
now.Sub(time.Unix(agent.conn.Start, 0)).Round(time.Second),
|
||||
now.Sub(time.Unix(agent.conn.LastWrite, 0)).Round(time.Second),
|
||||
agent.conn.Overwrites,
|
||||
now.Sub(time.Unix(start, 0)).Round(time.Second),
|
||||
now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
|
||||
agent.conn.Overwrites(),
|
||||
)
|
||||
|
||||
if conns := agentToConnectionSocketsMap[agent.id]; len(conns) > 0 {
|
||||
|
@ -696,11 +706,12 @@ func CoordinatorHTTPDebug(
|
|||
|
||||
_, _ = fmt.Fprintln(w, "<ul>")
|
||||
for _, connSocket := range connSockets {
|
||||
start, lastWrite := connSocket.conn.Stats()
|
||||
_, _ = fmt.Fprintf(w, "<li><b>%s</b> (<code>%s</code>): created %v ago, write %v ago </li>\n",
|
||||
connSocket.conn.Name,
|
||||
connSocket.conn.Name(),
|
||||
connSocket.id.String(),
|
||||
now.Sub(time.Unix(connSocket.conn.Start, 0)).Round(time.Second),
|
||||
now.Sub(time.Unix(connSocket.conn.LastWrite, 0)).Round(time.Second),
|
||||
now.Sub(time.Unix(start, 0)).Round(time.Second),
|
||||
now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
|
||||
)
|
||||
}
|
||||
_, _ = fmt.Fprintln(w, "</ul>")
|
||||
|
@ -755,11 +766,12 @@ func CoordinatorHTTPDebug(
|
|||
_, _ = fmt.Fprintf(w, "<h3 style=\"margin:0px;font-size:16px;font-weight:400\">connections: total %d</h3>\n", len(agentConns.conns))
|
||||
_, _ = fmt.Fprintln(w, "<ul>")
|
||||
for _, agentConn := range agentConns.conns {
|
||||
start, lastWrite := agentConn.conn.Stats()
|
||||
_, _ = fmt.Fprintf(w, "<li><b>%s</b> (<code>%s</code>): created %v ago, write %v ago </li>\n",
|
||||
agentConn.conn.Name,
|
||||
agentConn.conn.Name(),
|
||||
agentConn.id.String(),
|
||||
now.Sub(time.Unix(agentConn.conn.Start, 0)).Round(time.Second),
|
||||
now.Sub(time.Unix(agentConn.conn.LastWrite, 0)).Round(time.Second),
|
||||
now.Sub(time.Unix(start, 0)).Round(time.Second),
|
||||
now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
|
||||
)
|
||||
}
|
||||
_, _ = fmt.Fprintln(w, "</ul>")
|
||||
|
|
|
@ -0,0 +1,167 @@
|
|||
package tailnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
type MultiAgentConn interface {
|
||||
UpdateSelf(node *Node) error
|
||||
SubscribeAgent(agentID uuid.UUID) error
|
||||
UnsubscribeAgent(agentID uuid.UUID) error
|
||||
NextUpdate(ctx context.Context) ([]*Node, bool)
|
||||
AgentIsLegacy(agentID uuid.UUID) bool
|
||||
Close() error
|
||||
IsClosed() bool
|
||||
}
|
||||
|
||||
type MultiAgent struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
closed bool
|
||||
|
||||
ID uuid.UUID
|
||||
Logger slog.Logger
|
||||
|
||||
AgentIsLegacyFunc func(agentID uuid.UUID) bool
|
||||
OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error)
|
||||
OnUnsubscribe func(enq Queue, agent uuid.UUID) error
|
||||
OnNodeUpdate func(id uuid.UUID, node *Node) error
|
||||
OnRemove func(id uuid.UUID)
|
||||
|
||||
updates chan []*Node
|
||||
closeOnce sync.Once
|
||||
start int64
|
||||
lastWrite int64
|
||||
// Client nodes normally generate a unique id for each connection so
|
||||
// overwrites are really not an issue, but is provided for compatibility.
|
||||
overwrites int64
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Init() *MultiAgent {
|
||||
m.updates = make(chan []*Node, 128)
|
||||
m.start = time.Now().Unix()
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *MultiAgent) UniqueID() uuid.UUID {
|
||||
return m.ID
|
||||
}
|
||||
|
||||
func (m *MultiAgent) AgentIsLegacy(agentID uuid.UUID) bool {
|
||||
return m.AgentIsLegacyFunc(agentID)
|
||||
}
|
||||
|
||||
var ErrMultiAgentClosed = xerrors.New("multiagent is closed")
|
||||
|
||||
func (m *MultiAgent) UpdateSelf(node *Node) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
if m.closed {
|
||||
return ErrMultiAgentClosed
|
||||
}
|
||||
|
||||
return m.OnNodeUpdate(m.ID, node)
|
||||
}
|
||||
|
||||
func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
if m.closed {
|
||||
return ErrMultiAgentClosed
|
||||
}
|
||||
|
||||
node, err := m.OnSubscribe(m, agentID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if node != nil {
|
||||
return m.enqueueLocked([]*Node{node})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
if m.closed {
|
||||
return ErrMultiAgentClosed
|
||||
}
|
||||
|
||||
return m.OnUnsubscribe(m, agentID)
|
||||
}
|
||||
|
||||
func (m *MultiAgent) NextUpdate(ctx context.Context) ([]*Node, bool) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, false
|
||||
|
||||
case nodes, ok := <-m.updates:
|
||||
return nodes, ok
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Enqueue(nodes []*Node) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.enqueueLocked(nodes)
|
||||
}
|
||||
|
||||
func (m *MultiAgent) enqueueLocked(nodes []*Node) error {
|
||||
atomic.StoreInt64(&m.lastWrite, time.Now().Unix())
|
||||
|
||||
select {
|
||||
case m.updates <- nodes:
|
||||
return nil
|
||||
default:
|
||||
return ErrWouldBlock
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Name() string {
|
||||
return m.ID.String()
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Stats() (start int64, lastWrite int64) {
|
||||
return m.start, atomic.LoadInt64(&m.lastWrite)
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Overwrites() int64 {
|
||||
return m.overwrites
|
||||
}
|
||||
|
||||
func (m *MultiAgent) IsClosed() bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.closed
|
||||
}
|
||||
|
||||
func (m *MultiAgent) CoordinatorClose() error {
|
||||
m.mu.Lock()
|
||||
if !m.closed {
|
||||
m.closed = true
|
||||
close(m.updates)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Close() error {
|
||||
_ = m.CoordinatorClose()
|
||||
m.closeOnce.Do(func() { m.OnRemove(m.ID) })
|
||||
return nil
|
||||
}
|
|
@ -22,7 +22,7 @@ import (
|
|||
)
|
||||
|
||||
// RunDERPAndSTUN creates a DERP mapping for tests.
|
||||
func RunDERPAndSTUN(t *testing.T) *tailcfg.DERPMap {
|
||||
func RunDERPAndSTUN(t *testing.T) (*tailcfg.DERPMap, *derp.Server) {
|
||||
logf := tailnet.Logger(slogtest.Make(t, nil))
|
||||
d := derp.NewServer(key.NewNode(), logf)
|
||||
server := httptest.NewUnstartedServer(derphttp.Handler(d))
|
||||
|
@ -61,7 +61,7 @@ func RunDERPAndSTUN(t *testing.T) *tailcfg.DERPMap {
|
|||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}, d
|
||||
}
|
||||
|
||||
// RunDERPOnlyWebSockets creates a DERP mapping for tests that
|
||||
|
|
|
@ -14,7 +14,7 @@ func TestMain(m *testing.M) {
|
|||
|
||||
func TestRunDERPAndSTUN(t *testing.T) {
|
||||
t.Parallel()
|
||||
_ = tailnettest.RunDERPAndSTUN(t)
|
||||
_, _ = tailnettest.RunDERPAndSTUN(t)
|
||||
}
|
||||
|
||||
func TestRunDERPOnlyWebSockets(t *testing.T) {
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
package tailnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
// WriteTimeout is the amount of time we wait to write a node update to a connection before we declare it hung.
|
||||
// It is exported so that tests can use it.
|
||||
const WriteTimeout = time.Second * 5
|
||||
|
||||
type TrackedConn struct {
|
||||
ctx context.Context
|
||||
cancel func()
|
||||
conn net.Conn
|
||||
updates chan []*Node
|
||||
logger slog.Logger
|
||||
lastData []byte
|
||||
|
||||
// ID is an ephemeral UUID used to uniquely identify the owner of the
|
||||
// connection.
|
||||
id uuid.UUID
|
||||
|
||||
name string
|
||||
start int64
|
||||
lastWrite int64
|
||||
overwrites int64
|
||||
}
|
||||
|
||||
func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.UUID, logger slog.Logger, overwrites int64) *TrackedConn {
|
||||
// buffer updates so they don't block, since we hold the
|
||||
// coordinator mutex while queuing. Node updates don't
|
||||
// come quickly, so 512 should be plenty for all but
|
||||
// the most pathological cases.
|
||||
updates := make(chan []*Node, 512)
|
||||
now := time.Now().Unix()
|
||||
return &TrackedConn{
|
||||
ctx: ctx,
|
||||
conn: conn,
|
||||
cancel: cancel,
|
||||
updates: updates,
|
||||
logger: logger,
|
||||
id: id,
|
||||
start: now,
|
||||
lastWrite: now,
|
||||
overwrites: overwrites,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TrackedConn) Enqueue(n []*Node) (err error) {
|
||||
atomic.StoreInt64(&t.lastWrite, time.Now().Unix())
|
||||
select {
|
||||
case t.updates <- n:
|
||||
return nil
|
||||
default:
|
||||
return ErrWouldBlock
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TrackedConn) UniqueID() uuid.UUID {
|
||||
return t.id
|
||||
}
|
||||
|
||||
func (t *TrackedConn) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *TrackedConn) Stats() (start, lastWrite int64) {
|
||||
return t.start, atomic.LoadInt64(&t.lastWrite)
|
||||
}
|
||||
|
||||
func (t *TrackedConn) Overwrites() int64 {
|
||||
return t.overwrites
|
||||
}
|
||||
|
||||
func (t *TrackedConn) CoordinatorClose() error {
|
||||
return t.Close()
|
||||
}
|
||||
|
||||
// Close the connection and cancel the context for reading node updates from the queue
|
||||
func (t *TrackedConn) Close() error {
|
||||
t.cancel()
|
||||
return t.conn.Close()
|
||||
}
|
||||
|
||||
// SendUpdates reads node updates and writes them to the connection. Ends when writes hit an error or context is
|
||||
// canceled.
|
||||
func (t *TrackedConn) SendUpdates() {
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
t.logger.Debug(t.ctx, "done sending updates")
|
||||
return
|
||||
case nodes := <-t.updates:
|
||||
data, err := json.Marshal(nodes)
|
||||
if err != nil {
|
||||
t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes))
|
||||
return
|
||||
}
|
||||
if bytes.Equal(t.lastData, data) {
|
||||
t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", string(data)))
|
||||
continue
|
||||
}
|
||||
|
||||
// Set a deadline so that hung connections don't put back pressure on the system.
|
||||
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
|
||||
err = t.conn.SetWriteDeadline(time.Now().Add(WriteTimeout))
|
||||
if err != nil {
|
||||
// often, this is just because the connection is closed/broken, so only log at debug.
|
||||
t.logger.Debug(t.ctx, "unable to set write deadline", slog.Error(err))
|
||||
_ = t.Close()
|
||||
return
|
||||
}
|
||||
_, err = t.conn.Write(data)
|
||||
if err != nil {
|
||||
// often, this is just because the connection is closed/broken, so only log at debug.
|
||||
t.logger.Debug(t.ctx, "could not write nodes to connection",
|
||||
slog.Error(err), slog.F("nodes", string(data)))
|
||||
_ = t.Close()
|
||||
return
|
||||
}
|
||||
t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", string(data)))
|
||||
|
||||
// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are
|
||||
// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write()
|
||||
// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context.
|
||||
// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after
|
||||
// our successful write, it is important that we reset the deadline before it fires.
|
||||
err = t.conn.SetWriteDeadline(time.Time{})
|
||||
if err != nil {
|
||||
// often, this is just because the connection is closed/broken, so only log at debug.
|
||||
t.logger.Debug(t.ctx, "unable to extend write deadline", slog.Error(err))
|
||||
_ = t.Close()
|
||||
return
|
||||
}
|
||||
t.lastData = data
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue