chore: replace wsconncache with a single tailnet (#8176)

Author: Colin Adler (committed by GitHub)
Date: 2023-07-12 17:37:31 -05:00
Commit: c47b78c44b
Parent: 0a37dd20d6
36 changed files with 2004 additions and 763 deletions


@ -64,6 +64,7 @@ type Options struct {
SSHMaxTimeout time.Duration
TailnetListenPort uint16
Subsystem codersdk.AgentSubsystem
Addresses []netip.Prefix
PrometheusRegistry *prometheus.Registry
}
@ -132,6 +133,7 @@ func New(options Options) Agent {
connStatsChan: make(chan *agentsdk.Stats, 1),
sshMaxTimeout: options.SSHMaxTimeout,
subsystem: options.Subsystem,
addresses: options.Addresses,
prometheusRegistry: prometheusRegistry,
metrics: newAgentMetrics(prometheusRegistry),
@ -177,6 +179,7 @@ type agent struct {
lifecycleStates []agentsdk.PostLifecycleRequest
network *tailnet.Conn
addresses []netip.Prefix
connStatsChan chan *agentsdk.Stats
latestStat atomic.Pointer[agentsdk.Stats]
@ -545,6 +548,10 @@ func (a *agent) run(ctx context.Context) error {
}
a.logger.Info(ctx, "fetched manifest", slog.F("manifest", manifest))
if manifest.AgentID == uuid.Nil {
return xerrors.New("nil agentID returned by manifest")
}
// Expand the directory and send it back to coderd so external
// applications that rely on the directory can use it.
//
@ -630,7 +637,7 @@ func (a *agent) run(ctx context.Context) error {
network := a.network
a.closeMutex.Unlock()
if network == nil {
network, err = a.createTailnet(ctx, manifest.DERPMap, manifest.DisableDirectConnections)
network, err = a.createTailnet(ctx, manifest.AgentID, manifest.DERPMap, manifest.DisableDirectConnections)
if err != nil {
return xerrors.Errorf("create tailnet: %w", err)
}
@ -648,6 +655,11 @@ func (a *agent) run(ctx context.Context) error {
a.startReportingConnectionStats(ctx)
} else {
// Update the wireguard IPs if the agent ID changed.
err := network.SetAddresses(a.wireguardAddresses(manifest.AgentID))
if err != nil {
a.logger.Error(ctx, "update tailnet addresses", slog.Error(err))
}
// Update the DERP map and allow/disallow direct connections.
network.SetDERPMap(manifest.DERPMap)
network.SetBlockEndpoints(manifest.DisableDirectConnections)
@ -661,6 +673,20 @@ func (a *agent) run(ctx context.Context) error {
return nil
}
func (a *agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix {
if len(a.addresses) == 0 {
return []netip.Prefix{
// This is the IP that should be used primarily.
netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128),
// We also listen on the legacy codersdk.WorkspaceAgentIP. This
// allows for a transition away from wsconncache.
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
}
}
return a.addresses
}
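
For orientation, a minimal illustrative sketch (not part of the commit) of the address set an agent now advertises by default. It assumes the coder module is importable and uses only identifiers that appear in the diff above (tailnet.IPFromUUID, codersdk.WorkspaceAgentIP):

package main

import (
	"fmt"
	"net/netip"

	"github.com/google/uuid"

	"github.com/coder/coder/codersdk"
	"github.com/coder/coder/tailnet"
)

func main() {
	agentID := uuid.New()
	addrs := []netip.Prefix{
		// New: a stable /128 derived from the agent ID, so coderd's single
		// tailnet can address every agent uniquely.
		netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128),
		// Legacy: the fixed WorkspaceAgentIP, kept so wsconncache-based
		// connections continue to work during the transition.
		netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
	}
	for _, prefix := range addrs {
		fmt.Println(prefix)
	}
}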
func (a *agent) trackConnGoroutine(fn func()) error {
a.closeMutex.Lock()
defer a.closeMutex.Unlock()
@ -675,9 +701,9 @@ func (a *agent) trackConnGoroutine(fn func()) error {
return nil
}
func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) {
func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap, disableDirectConnections bool) (_ *tailnet.Conn, err error) {
network, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
Addresses: a.wireguardAddresses(agentID),
DERPMap: derpMap,
Logger: a.logger.Named("tailnet"),
ListenPort: a.tailnetListenPort,


@ -35,7 +35,6 @@ import (
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"golang.org/x/crypto/ssh"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"tailscale.com/net/speedtest"
@ -45,6 +44,7 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/agent/agentssh"
"github.com/coder/coder/agent/agenttest"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/codersdk/agentsdk"
@ -67,7 +67,7 @@ func TestAgent_Stats_SSH(t *testing.T) {
defer cancel()
//nolint:dogsled
conn, _, stats, _, _ := setupAgent(t, &client{}, 0)
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
@ -100,7 +100,7 @@ func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
defer cancel()
//nolint:dogsled
conn, _, stats, _, _ := setupAgent(t, &client{}, 0)
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
ptyConn, err := conn.ReconnectingPTY(ctx, uuid.New(), 128, 128, "/bin/bash")
require.NoError(t, err)
@ -130,7 +130,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, &client{}, 0)
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
@ -157,7 +157,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, stats, _, _ := setupAgent(t, &client{}, 0)
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
@ -425,20 +425,19 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled // Allow the blank identifiers.
conn, client, _, _, _ := setupAgent(t, &client{}, 0)
conn, client, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
for _, test := range tests {
test := test
// Set new banner func and wait for the agent to call it to update the
// banner.
ready := make(chan struct{}, 2)
client.mu.Lock()
client.getServiceBanner = func() (codersdk.ServiceBannerConfig, error) {
client.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) {
select {
case ready <- struct{}{}:
default:
}
return test.banner, nil
}
client.mu.Unlock()
})
<-ready
<-ready // Wait for two updates to ensure the value has propagated.
@ -542,7 +541,7 @@ func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, &client{}, 0)
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
@ -592,7 +591,7 @@ func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, &client{}, 0)
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
@ -922,7 +921,7 @@ func TestAgent_SFTP(t *testing.T) {
home = "/" + strings.ReplaceAll(home, "\\", "/")
}
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, &client{}, 0)
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
@ -954,7 +953,7 @@ func TestAgent_SCP(t *testing.T) {
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, &client{}, 0)
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
@ -1062,16 +1061,15 @@ func TestAgent_StartupScript(t *testing.T) {
t.Run("Success", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
client := &client{
t: t,
agentID: uuid.New(),
manifest: agentsdk.Manifest{
client := agenttest.NewClient(t,
uuid.New(),
agentsdk.Manifest{
StartupScript: command,
DERPMap: &tailcfg.DERPMap{},
},
statsChan: make(chan *agentsdk.Stats),
coordinator: tailnet.NewCoordinator(logger),
}
make(chan *agentsdk.Stats),
tailnet.NewCoordinator(logger),
)
closer := agent.New(agent.Options{
Client: client,
Filesystem: afero.NewMemMapFs(),
@ -1082,36 +1080,35 @@ func TestAgent_StartupScript(t *testing.T) {
_ = closer.Close()
})
assert.Eventually(t, func() bool {
got := client.getLifecycleStates()
got := client.GetLifecycleStates()
return len(got) > 0 && got[len(got)-1] == codersdk.WorkspaceAgentLifecycleReady
}, testutil.WaitShort, testutil.IntervalMedium)
require.Len(t, client.getStartupLogs(), 1)
require.Equal(t, output, client.getStartupLogs()[0].Output)
require.Len(t, client.GetStartupLogs(), 1)
require.Equal(t, output, client.GetStartupLogs()[0].Output)
})
// This ensures that even when coderd reports that the startup script
// has written too many lines, the agent still succeeds.
t.Run("OverflowsAndSkips", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
client := &client{
t: t,
agentID: uuid.New(),
manifest: agentsdk.Manifest{
client := agenttest.NewClient(t,
uuid.New(),
agentsdk.Manifest{
StartupScript: command,
DERPMap: &tailcfg.DERPMap{},
},
patchWorkspaceLogs: func() error {
resp := httptest.NewRecorder()
httpapi.Write(context.Background(), resp, http.StatusRequestEntityTooLarge, codersdk.Response{
Message: "Too many lines!",
})
res := resp.Result()
defer res.Body.Close()
return codersdk.ReadBodyAsError(res)
},
statsChan: make(chan *agentsdk.Stats),
coordinator: tailnet.NewCoordinator(logger),
make(chan *agentsdk.Stats, 50),
tailnet.NewCoordinator(logger),
)
client.PatchWorkspaceLogs = func() error {
resp := httptest.NewRecorder()
httpapi.Write(context.Background(), resp, http.StatusRequestEntityTooLarge, codersdk.Response{
Message: "Too many lines!",
})
res := resp.Result()
defer res.Body.Close()
return codersdk.ReadBodyAsError(res)
}
closer := agent.New(agent.Options{
Client: client,
@ -1123,10 +1120,10 @@ func TestAgent_StartupScript(t *testing.T) {
_ = closer.Close()
})
assert.Eventually(t, func() bool {
got := client.getLifecycleStates()
got := client.GetLifecycleStates()
return len(got) > 0 && got[len(got)-1] == codersdk.WorkspaceAgentLifecycleReady
}, testutil.WaitShort, testutil.IntervalMedium)
require.Len(t, client.getStartupLogs(), 0)
require.Len(t, client.GetStartupLogs(), 0)
})
}
@ -1138,28 +1135,26 @@ func TestAgent_Metadata(t *testing.T) {
t.Run("Once", func(t *testing.T) {
t.Parallel()
//nolint:dogsled
_, client, _, _, _ := setupAgent(t, &client{
manifest: agentsdk.Manifest{
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
{
Key: "greeting",
Interval: 0,
Script: echoHello,
},
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
{
Key: "greeting",
Interval: 0,
Script: echoHello,
},
},
}, 0)
var gotMd map[string]agentsdk.PostMetadataRequest
require.Eventually(t, func() bool {
gotMd = client.getMetadata()
gotMd = client.GetMetadata()
return len(gotMd) == 1
}, testutil.WaitShort, testutil.IntervalMedium)
collectedAt := gotMd["greeting"].CollectedAt
require.Never(t, func() bool {
gotMd = client.getMetadata()
gotMd = client.GetMetadata()
if len(gotMd) != 1 {
panic("unexpected number of metadata")
}
@ -1170,22 +1165,20 @@ func TestAgent_Metadata(t *testing.T) {
t.Run("Many", func(t *testing.T) {
t.Parallel()
//nolint:dogsled
_, client, _, _, _ := setupAgent(t, &client{
manifest: agentsdk.Manifest{
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
{
Key: "greeting",
Interval: 1,
Timeout: 100,
Script: echoHello,
},
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
{
Key: "greeting",
Interval: 1,
Timeout: 100,
Script: echoHello,
},
},
}, 0)
var gotMd map[string]agentsdk.PostMetadataRequest
require.Eventually(t, func() bool {
gotMd = client.getMetadata()
gotMd = client.GetMetadata()
return len(gotMd) == 1
}, testutil.WaitShort, testutil.IntervalMedium)
@ -1195,7 +1188,7 @@ func TestAgent_Metadata(t *testing.T) {
}
if !assert.Eventually(t, func() bool {
gotMd = client.getMetadata()
gotMd = client.GetMetadata()
return gotMd["greeting"].CollectedAt.After(collectedAt1)
}, testutil.WaitShort, testutil.IntervalMedium) {
t.Fatalf("expected metadata to be collected again")
@ -1221,29 +1214,27 @@ func TestAgentMetadata_Timing(t *testing.T) {
script = "echo hello | tee -a " + greetingPath
)
//nolint:dogsled
_, client, _, _, _ := setupAgent(t, &client{
manifest: agentsdk.Manifest{
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
{
Key: "greeting",
Interval: reportInterval,
Script: script,
},
{
Key: "bad",
Interval: reportInterval,
Script: "exit 1",
},
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
Metadata: []codersdk.WorkspaceAgentMetadataDescription{
{
Key: "greeting",
Interval: reportInterval,
Script: script,
},
{
Key: "bad",
Interval: reportInterval,
Script: "exit 1",
},
},
}, 0)
require.Eventually(t, func() bool {
return len(client.getMetadata()) == 2
return len(client.GetMetadata()) == 2
}, testutil.WaitShort, testutil.IntervalMedium)
for start := time.Now(); time.Since(start) < testutil.WaitMedium; time.Sleep(testutil.IntervalMedium) {
md := client.getMetadata()
md := client.GetMetadata()
require.Len(t, md, 2, "got: %+v", md)
require.Equal(t, "hello\n", md["greeting"].Value)
@ -1285,11 +1276,9 @@ func TestAgent_Lifecycle(t *testing.T) {
t.Run("StartTimeout", func(t *testing.T) {
t.Parallel()
_, client, _, _, _ := setupAgent(t, &client{
manifest: agentsdk.Manifest{
StartupScript: "sleep 3",
StartupScriptTimeout: time.Nanosecond,
},
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
StartupScript: "sleep 3",
StartupScriptTimeout: time.Nanosecond,
}, 0)
want := []codersdk.WorkspaceAgentLifecycle{
@ -1299,7 +1288,7 @@ func TestAgent_Lifecycle(t *testing.T) {
var got []codersdk.WorkspaceAgentLifecycle
assert.Eventually(t, func() bool {
got = client.getLifecycleStates()
got = client.GetLifecycleStates()
return slices.Contains(got, want[len(want)-1])
}, testutil.WaitShort, testutil.IntervalMedium)
@ -1309,11 +1298,9 @@ func TestAgent_Lifecycle(t *testing.T) {
t.Run("StartError", func(t *testing.T) {
t.Parallel()
_, client, _, _, _ := setupAgent(t, &client{
manifest: agentsdk.Manifest{
StartupScript: "false",
StartupScriptTimeout: 30 * time.Second,
},
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
StartupScript: "false",
StartupScriptTimeout: 30 * time.Second,
}, 0)
want := []codersdk.WorkspaceAgentLifecycle{
@ -1323,7 +1310,7 @@ func TestAgent_Lifecycle(t *testing.T) {
var got []codersdk.WorkspaceAgentLifecycle
assert.Eventually(t, func() bool {
got = client.getLifecycleStates()
got = client.GetLifecycleStates()
return slices.Contains(got, want[len(want)-1])
}, testutil.WaitShort, testutil.IntervalMedium)
@ -1333,11 +1320,9 @@ func TestAgent_Lifecycle(t *testing.T) {
t.Run("Ready", func(t *testing.T) {
t.Parallel()
_, client, _, _, _ := setupAgent(t, &client{
manifest: agentsdk.Manifest{
StartupScript: "true",
StartupScriptTimeout: 30 * time.Second,
},
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
StartupScript: "true",
StartupScriptTimeout: 30 * time.Second,
}, 0)
want := []codersdk.WorkspaceAgentLifecycle{
@ -1347,7 +1332,7 @@ func TestAgent_Lifecycle(t *testing.T) {
var got []codersdk.WorkspaceAgentLifecycle
assert.Eventually(t, func() bool {
got = client.getLifecycleStates()
got = client.GetLifecycleStates()
return len(got) > 0 && got[len(got)-1] == want[len(want)-1]
}, testutil.WaitShort, testutil.IntervalMedium)
@ -1357,15 +1342,13 @@ func TestAgent_Lifecycle(t *testing.T) {
t.Run("ShuttingDown", func(t *testing.T) {
t.Parallel()
_, client, _, _, closer := setupAgent(t, &client{
manifest: agentsdk.Manifest{
ShutdownScript: "sleep 3",
StartupScriptTimeout: 30 * time.Second,
},
_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
ShutdownScript: "sleep 3",
StartupScriptTimeout: 30 * time.Second,
}, 0)
assert.Eventually(t, func() bool {
return slices.Contains(client.getLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
}, testutil.WaitShort, testutil.IntervalMedium)
// Start close asynchronously so that we can inspect the state.
@ -1387,7 +1370,7 @@ func TestAgent_Lifecycle(t *testing.T) {
var got []codersdk.WorkspaceAgentLifecycle
assert.Eventually(t, func() bool {
got = client.getLifecycleStates()
got = client.GetLifecycleStates()
return slices.Contains(got, want[len(want)-1])
}, testutil.WaitShort, testutil.IntervalMedium)
@ -1397,15 +1380,13 @@ func TestAgent_Lifecycle(t *testing.T) {
t.Run("ShutdownTimeout", func(t *testing.T) {
t.Parallel()
_, client, _, _, closer := setupAgent(t, &client{
manifest: agentsdk.Manifest{
ShutdownScript: "sleep 3",
ShutdownScriptTimeout: time.Nanosecond,
},
_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
ShutdownScript: "sleep 3",
ShutdownScriptTimeout: time.Nanosecond,
}, 0)
assert.Eventually(t, func() bool {
return slices.Contains(client.getLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
}, testutil.WaitShort, testutil.IntervalMedium)
// Start close asynchronously so that we can inspect the state.
@ -1428,7 +1409,7 @@ func TestAgent_Lifecycle(t *testing.T) {
var got []codersdk.WorkspaceAgentLifecycle
assert.Eventually(t, func() bool {
got = client.getLifecycleStates()
got = client.GetLifecycleStates()
return slices.Contains(got, want[len(want)-1])
}, testutil.WaitShort, testutil.IntervalMedium)
@ -1438,15 +1419,13 @@ func TestAgent_Lifecycle(t *testing.T) {
t.Run("ShutdownError", func(t *testing.T) {
t.Parallel()
_, client, _, _, closer := setupAgent(t, &client{
manifest: agentsdk.Manifest{
ShutdownScript: "false",
ShutdownScriptTimeout: 30 * time.Second,
},
_, client, _, _, closer := setupAgent(t, agentsdk.Manifest{
ShutdownScript: "false",
ShutdownScriptTimeout: 30 * time.Second,
}, 0)
assert.Eventually(t, func() bool {
return slices.Contains(client.getLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
}, testutil.WaitShort, testutil.IntervalMedium)
// Start close asynchronously so that we can inspect the state.
@ -1469,7 +1448,7 @@ func TestAgent_Lifecycle(t *testing.T) {
var got []codersdk.WorkspaceAgentLifecycle
assert.Eventually(t, func() bool {
got = client.getLifecycleStates()
got = client.GetLifecycleStates()
return slices.Contains(got, want[len(want)-1])
}, testutil.WaitShort, testutil.IntervalMedium)
@ -1480,17 +1459,18 @@ func TestAgent_Lifecycle(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
expected := "this-is-shutdown"
client := &client{
t: t,
agentID: uuid.New(),
manifest: agentsdk.Manifest{
DERPMap: tailnettest.RunDERPAndSTUN(t),
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
client := agenttest.NewClient(t,
uuid.New(),
agentsdk.Manifest{
DERPMap: derpMap,
StartupScript: "echo 1",
ShutdownScript: "echo " + expected,
},
statsChan: make(chan *agentsdk.Stats),
coordinator: tailnet.NewCoordinator(logger),
}
make(chan *agentsdk.Stats, 50),
tailnet.NewCoordinator(logger),
)
fs := afero.NewMemMapFs()
agent := agent.New(agent.Options{
@ -1536,71 +1516,63 @@ func TestAgent_Startup(t *testing.T) {
t.Run("EmptyDirectory", func(t *testing.T) {
t.Parallel()
_, client, _, _, _ := setupAgent(t, &client{
manifest: agentsdk.Manifest{
StartupScript: "true",
StartupScriptTimeout: 30 * time.Second,
Directory: "",
},
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
StartupScript: "true",
StartupScriptTimeout: 30 * time.Second,
Directory: "",
}, 0)
assert.Eventually(t, func() bool {
return client.getStartup().Version != ""
return client.GetStartup().Version != ""
}, testutil.WaitShort, testutil.IntervalFast)
require.Equal(t, "", client.getStartup().ExpandedDirectory)
require.Equal(t, "", client.GetStartup().ExpandedDirectory)
})
t.Run("HomeDirectory", func(t *testing.T) {
t.Parallel()
_, client, _, _, _ := setupAgent(t, &client{
manifest: agentsdk.Manifest{
StartupScript: "true",
StartupScriptTimeout: 30 * time.Second,
Directory: "~",
},
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
StartupScript: "true",
StartupScriptTimeout: 30 * time.Second,
Directory: "~",
}, 0)
assert.Eventually(t, func() bool {
return client.getStartup().Version != ""
return client.GetStartup().Version != ""
}, testutil.WaitShort, testutil.IntervalFast)
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
require.Equal(t, homeDir, client.getStartup().ExpandedDirectory)
require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory)
})
t.Run("NotAbsoluteDirectory", func(t *testing.T) {
t.Parallel()
_, client, _, _, _ := setupAgent(t, &client{
manifest: agentsdk.Manifest{
StartupScript: "true",
StartupScriptTimeout: 30 * time.Second,
Directory: "coder/coder",
},
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
StartupScript: "true",
StartupScriptTimeout: 30 * time.Second,
Directory: "coder/coder",
}, 0)
assert.Eventually(t, func() bool {
return client.getStartup().Version != ""
return client.GetStartup().Version != ""
}, testutil.WaitShort, testutil.IntervalFast)
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.getStartup().ExpandedDirectory)
require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.GetStartup().ExpandedDirectory)
})
t.Run("HomeEnvironmentVariable", func(t *testing.T) {
t.Parallel()
_, client, _, _, _ := setupAgent(t, &client{
manifest: agentsdk.Manifest{
StartupScript: "true",
StartupScriptTimeout: 30 * time.Second,
Directory: "$HOME",
},
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
StartupScript: "true",
StartupScriptTimeout: 30 * time.Second,
Directory: "$HOME",
}, 0)
assert.Eventually(t, func() bool {
return client.getStartup().Version != ""
return client.GetStartup().Version != ""
}, testutil.WaitShort, testutil.IntervalFast)
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
require.Equal(t, homeDir, client.getStartup().ExpandedDirectory)
require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory)
})
}
@ -1617,7 +1589,7 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, &client{}, 0)
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
id := uuid.New()
netConn, err := conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash")
require.NoError(t, err)
@ -1719,7 +1691,7 @@ func TestAgent_Dial(t *testing.T) {
}()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, &client{}, 0)
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
require.True(t, conn.AwaitReachable(context.Background()))
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
@ -1739,12 +1711,10 @@ func TestAgent_Speedtest(t *testing.T) {
t.Skip("This test is relatively flakey because of Tailscale's speedtest code...")
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
derpMap := tailnettest.RunDERPAndSTUN(t)
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, &client{
manifest: agentsdk.Manifest{
DERPMap: derpMap,
},
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{
DERPMap: derpMap,
}, 0)
defer conn.Close()
res, err := conn.Speedtest(ctx, speedtest.Upload, 250*time.Millisecond)
@ -1761,17 +1731,16 @@ func TestAgent_Reconnect(t *testing.T) {
defer coordinator.Close()
agentID := uuid.New()
statsCh := make(chan *agentsdk.Stats)
derpMap := tailnettest.RunDERPAndSTUN(t)
client := &client{
t: t,
agentID: agentID,
manifest: agentsdk.Manifest{
statsCh := make(chan *agentsdk.Stats, 50)
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
client := agenttest.NewClient(t,
agentID,
agentsdk.Manifest{
DERPMap: derpMap,
},
statsChan: statsCh,
coordinator: coordinator,
}
statsCh,
coordinator,
)
initialized := atomic.Int32{}
closer := agent.New(agent.Options{
ExchangeToken: func(ctx context.Context) (string, error) {
@ -1786,7 +1755,7 @@ func TestAgent_Reconnect(t *testing.T) {
require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
client.lastWorkspaceAgent()
client.LastWorkspaceAgent()
require.Eventually(t, func() bool {
return initialized.Load() == 2
}, testutil.WaitShort, testutil.IntervalFast)
@ -1798,16 +1767,15 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
coordinator := tailnet.NewCoordinator(logger)
defer coordinator.Close()
client := &client{
t: t,
agentID: uuid.New(),
manifest: agentsdk.Manifest{
client := agenttest.NewClient(t,
uuid.New(),
agentsdk.Manifest{
GitAuthConfigs: 1,
DERPMap: &tailcfg.DERPMap{},
},
statsChan: make(chan *agentsdk.Stats),
coordinator: coordinator,
}
make(chan *agentsdk.Stats, 50),
coordinator,
)
filesystem := afero.NewMemMapFs()
closer := agent.New(agent.Options{
ExchangeToken: func(ctx context.Context) (string, error) {
@ -1830,7 +1798,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) {
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, &client{}, 0)
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
waitGroup := sync.WaitGroup{}
@ -1883,12 +1851,11 @@ func setupSSHSession(
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, _, fs, _ := setupAgent(t, &client{
manifest: options,
getServiceBanner: func() (codersdk.ServiceBannerConfig, error) {
conn, _, _, fs, _ := setupAgent(t, options, 0, func(c *agenttest.Client, _ *agent.Options) {
c.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) {
return serviceBanner, nil
},
}, 0)
})
})
if prepareFS != nil {
prepareFS(fs)
}
@ -1905,31 +1872,28 @@ func setupSSHSession(
return session
}
type closeFunc func() error
func (c closeFunc) Close() error {
return c()
}
func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func(agent.Options) agent.Options) (
func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) (
*codersdk.WorkspaceAgentConn,
*client,
*agenttest.Client,
<-chan *agentsdk.Stats,
afero.Fs,
io.Closer,
) {
c.t = t
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
if c.manifest.DERPMap == nil {
c.manifest.DERPMap = tailnettest.RunDERPAndSTUN(t)
if metadata.DERPMap == nil {
metadata.DERPMap, _ = tailnettest.RunDERPAndSTUN(t)
}
c.coordinator = tailnet.NewCoordinator(logger)
if metadata.AgentID == uuid.Nil {
metadata.AgentID = uuid.New()
}
coordinator := tailnet.NewCoordinator(logger)
t.Cleanup(func() {
_ = c.coordinator.Close()
_ = coordinator.Close()
})
c.agentID = uuid.New()
c.statsChan = make(chan *agentsdk.Stats, 50)
statsCh := make(chan *agentsdk.Stats, 50)
fs := afero.NewMemMapFs()
c := agenttest.NewClient(t, metadata.AgentID, metadata, statsCh, coordinator)
options := agent.Options{
Client: c,
Filesystem: fs,
@ -1938,7 +1902,7 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func(
}
for _, opt := range opts {
options = opt(options)
opt(c, &options)
}
closer := agent.New(options)
@ -1947,7 +1911,7 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func(
})
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: c.manifest.DERPMap,
DERPMap: metadata.DERPMap,
Logger: logger.Named("client"),
})
require.NoError(t, err)
@ -1961,15 +1925,15 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func(
})
go func() {
defer close(serveClientDone)
c.coordinator.ServeClient(serverConn, uuid.New(), c.agentID)
coordinator.ServeClient(serverConn, uuid.New(), metadata.AgentID)
}()
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
return conn.UpdateNodes(node, false)
})
conn.SetNodeCallback(sendNode)
agentConn := &codersdk.WorkspaceAgentConn{
Conn: conn,
}
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: metadata.AgentID,
})
t.Cleanup(func() {
_ = agentConn.Close()
})
@ -1980,7 +1944,7 @@ func setupAgent(t *testing.T, c *client, ptyTimeout time.Duration, opts ...func(
if !agentConn.AwaitReachable(ctx) {
t.Fatal("agent not reachable")
}
return agentConn, c, c.statsChan, fs, closer
return agentConn, c, statsCh, fs, closer
}
var dialTestPayload = []byte("dean-was-here123")
@ -2043,146 +2007,6 @@ func testSessionOutput(t *testing.T, session *ssh.Session, expected, unexpected
}
}
type client struct {
t *testing.T
agentID uuid.UUID
manifest agentsdk.Manifest
metadata map[string]agentsdk.PostMetadataRequest
statsChan chan *agentsdk.Stats
coordinator tailnet.Coordinator
lastWorkspaceAgent func()
patchWorkspaceLogs func() error
getServiceBanner func() (codersdk.ServiceBannerConfig, error)
mu sync.Mutex // Protects following.
lifecycleStates []codersdk.WorkspaceAgentLifecycle
startup agentsdk.PostStartupRequest
logs []agentsdk.StartupLog
}
func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
return c.manifest, nil
}
func (c *client) Listen(_ context.Context) (net.Conn, error) {
clientConn, serverConn := net.Pipe()
closed := make(chan struct{})
c.lastWorkspaceAgent = func() {
_ = serverConn.Close()
_ = clientConn.Close()
<-closed
}
c.t.Cleanup(c.lastWorkspaceAgent)
go func() {
_ = c.coordinator.ServeAgent(serverConn, c.agentID, "")
close(closed)
}()
return clientConn, nil
}
func (c *client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) {
doneCh := make(chan struct{})
ctx, cancel := context.WithCancel(ctx)
go func() {
defer close(doneCh)
setInterval(500 * time.Millisecond)
for {
select {
case <-ctx.Done():
return
case stat := <-statsChan:
select {
case c.statsChan <- stat:
case <-ctx.Done():
return
default:
// We don't want to send old stats.
continue
}
}
}
}()
return closeFunc(func() error {
cancel()
<-doneCh
close(c.statsChan)
return nil
}), nil
}
func (c *client) getLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
c.mu.Lock()
defer c.mu.Unlock()
return c.lifecycleStates
}
func (c *client) PostLifecycle(_ context.Context, req agentsdk.PostLifecycleRequest) error {
c.mu.Lock()
defer c.mu.Unlock()
c.lifecycleStates = append(c.lifecycleStates, req.State)
return nil
}
func (*client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error {
return nil
}
func (c *client) getStartup() agentsdk.PostStartupRequest {
c.mu.Lock()
defer c.mu.Unlock()
return c.startup
}
func (c *client) getMetadata() map[string]agentsdk.PostMetadataRequest {
c.mu.Lock()
defer c.mu.Unlock()
return maps.Clone(c.metadata)
}
func (c *client) PostMetadata(_ context.Context, key string, req agentsdk.PostMetadataRequest) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.metadata == nil {
c.metadata = make(map[string]agentsdk.PostMetadataRequest)
}
c.metadata[key] = req
return nil
}
func (c *client) PostStartup(_ context.Context, startup agentsdk.PostStartupRequest) error {
c.mu.Lock()
defer c.mu.Unlock()
c.startup = startup
return nil
}
func (c *client) getStartupLogs() []agentsdk.StartupLog {
c.mu.Lock()
defer c.mu.Unlock()
return c.logs
}
func (c *client) PatchStartupLogs(_ context.Context, logs agentsdk.PatchStartupLogs) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.patchWorkspaceLogs != nil {
return c.patchWorkspaceLogs()
}
c.logs = append(c.logs, logs.Logs...)
return nil
}
func (c *client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.getServiceBanner != nil {
return c.getServiceBanner()
}
return codersdk.ServiceBannerConfig{}, nil
}
// tempDirUnixSocket returns a temporary directory that can safely hold unix
// sockets (probably).
//
@ -2214,9 +2038,8 @@ func TestAgent_Metrics_SSH(t *testing.T) {
registry := prometheus.NewRegistry()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, &client{}, 0, func(o agent.Options) agent.Options {
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.PrometheusRegistry = registry
return o
})
sshClient, err := conn.SSHClient(ctx)

agent/agenttest/client.go (new file, 189 lines)

@ -0,0 +1,189 @@
package agenttest
import (
"context"
"io"
"net"
"sync"
"testing"
"time"
"github.com/google/uuid"
"golang.org/x/exp/maps"
"cdr.dev/slog"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/codersdk/agentsdk"
"github.com/coder/coder/tailnet"
)
func NewClient(t testing.TB,
agentID uuid.UUID,
manifest agentsdk.Manifest,
statsChan chan *agentsdk.Stats,
coordinator tailnet.Coordinator,
) *Client {
if manifest.AgentID == uuid.Nil {
manifest.AgentID = agentID
}
return &Client{
t: t,
agentID: agentID,
manifest: manifest,
statsChan: statsChan,
coordinator: coordinator,
}
}
type Client struct {
t testing.TB
agentID uuid.UUID
manifest agentsdk.Manifest
metadata map[string]agentsdk.PostMetadataRequest
statsChan chan *agentsdk.Stats
coordinator tailnet.Coordinator
LastWorkspaceAgent func()
PatchWorkspaceLogs func() error
GetServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
mu sync.Mutex // Protects following.
lifecycleStates []codersdk.WorkspaceAgentLifecycle
startup agentsdk.PostStartupRequest
logs []agentsdk.StartupLog
}
func (c *Client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
return c.manifest, nil
}
func (c *Client) Listen(_ context.Context) (net.Conn, error) {
clientConn, serverConn := net.Pipe()
closed := make(chan struct{})
c.LastWorkspaceAgent = func() {
_ = serverConn.Close()
_ = clientConn.Close()
<-closed
}
c.t.Cleanup(c.LastWorkspaceAgent)
go func() {
_ = c.coordinator.ServeAgent(serverConn, c.agentID, "")
close(closed)
}()
return clientConn, nil
}
func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) {
doneCh := make(chan struct{})
ctx, cancel := context.WithCancel(ctx)
go func() {
defer close(doneCh)
setInterval(500 * time.Millisecond)
for {
select {
case <-ctx.Done():
return
case stat := <-statsChan:
select {
case c.statsChan <- stat:
case <-ctx.Done():
return
default:
// We don't want to send old stats.
continue
}
}
}
}()
return closeFunc(func() error {
cancel()
<-doneCh
close(c.statsChan)
return nil
}), nil
}
func (c *Client) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
c.mu.Lock()
defer c.mu.Unlock()
return c.lifecycleStates
}
func (c *Client) PostLifecycle(_ context.Context, req agentsdk.PostLifecycleRequest) error {
c.mu.Lock()
defer c.mu.Unlock()
c.lifecycleStates = append(c.lifecycleStates, req.State)
return nil
}
func (*Client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error {
return nil
}
func (c *Client) GetStartup() agentsdk.PostStartupRequest {
c.mu.Lock()
defer c.mu.Unlock()
return c.startup
}
func (c *Client) GetMetadata() map[string]agentsdk.PostMetadataRequest {
c.mu.Lock()
defer c.mu.Unlock()
return maps.Clone(c.metadata)
}
func (c *Client) PostMetadata(_ context.Context, key string, req agentsdk.PostMetadataRequest) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.metadata == nil {
c.metadata = make(map[string]agentsdk.PostMetadataRequest)
}
c.metadata[key] = req
return nil
}
func (c *Client) PostStartup(_ context.Context, startup agentsdk.PostStartupRequest) error {
c.mu.Lock()
defer c.mu.Unlock()
c.startup = startup
return nil
}
func (c *Client) GetStartupLogs() []agentsdk.StartupLog {
c.mu.Lock()
defer c.mu.Unlock()
return c.logs
}
func (c *Client) PatchStartupLogs(_ context.Context, logs agentsdk.PatchStartupLogs) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.PatchWorkspaceLogs != nil {
return c.PatchWorkspaceLogs()
}
c.logs = append(c.logs, logs.Logs...)
return nil
}
func (c *Client) SetServiceBannerFunc(f func() (codersdk.ServiceBannerConfig, error)) {
c.mu.Lock()
defer c.mu.Unlock()
c.GetServiceBannerFunc = f
}
func (c *Client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.GetServiceBannerFunc != nil {
return c.GetServiceBannerFunc()
}
return codersdk.ServiceBannerConfig{}, nil
}
type closeFunc func() error
func (c closeFunc) Close() error {
return c()
}

coderd/apidoc/docs.go (generated, 5 changed lines)

@ -5961,6 +5961,9 @@ const docTemplate = `{
"agentsdk.Manifest": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"apps": {
"type": "array",
"items": {
@ -7617,6 +7620,7 @@ const docTemplate = `{
"workspace_actions",
"tailnet_ha_coordinator",
"convert-to-oidc",
"single_tailnet",
"workspace_build_logs_ui"
],
"x-enum-varnames": [
@ -7624,6 +7628,7 @@ const docTemplate = `{
"ExperimentWorkspaceActions",
"ExperimentTailnetHACoordinator",
"ExperimentConvertToOIDC",
"ExperimentSingleTailnet",
"ExperimentWorkspaceBuildLogsUI"
]
},


@ -5251,6 +5251,9 @@
"agentsdk.Manifest": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"apps": {
"type": "array",
"items": {
@ -6818,6 +6821,7 @@
"workspace_actions",
"tailnet_ha_coordinator",
"convert-to-oidc",
"single_tailnet",
"workspace_build_logs_ui"
],
"x-enum-varnames": [
@ -6825,6 +6829,7 @@
"ExperimentWorkspaceActions",
"ExperimentTailnetHACoordinator",
"ExperimentConvertToOIDC",
"ExperimentSingleTailnet",
"ExperimentWorkspaceBuildLogsUI"
]
},


@ -364,8 +364,23 @@ func New(options *Options) *API {
}
api.Auditor.Store(&options.Auditor)
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
if api.Experiments.Enabled(codersdk.ExperimentSingleTailnet) {
api.agentProvider, err = NewServerTailnet(api.ctx,
options.Logger,
options.DERPServer,
options.DERPMap,
&api.TailnetCoordinator,
wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
)
if err != nil {
panic("failed to setup server tailnet: " + err.Error())
}
} else {
api.agentProvider = &wsconncache.AgentProvider{
Cache: wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
}
}
api.workspaceAppServer = &workspaceapps.Server{
Logger: options.Logger.Named("workspaceapps"),
@ -377,7 +392,7 @@ func New(options *Options) *API {
RealIPConfig: options.RealIPConfig,
SignedTokenProvider: api.WorkspaceAppsProvider,
WorkspaceConnCache: api.workspaceAgentCache,
AgentProvider: api.agentProvider,
AppSecurityKey: options.AppSecurityKey,
DisablePathApps: options.DeploymentValues.DisablePathApps.Value(),
@ -921,10 +936,10 @@ type API struct {
derpCloseFunc func()
metricsCache *metricscache.Cache
workspaceAgentCache *wsconncache.Cache
updateChecker *updatecheck.Checker
WorkspaceAppsProvider workspaceapps.SignedTokenProvider
workspaceAppServer *workspaceapps.Server
agentProvider workspaceapps.AgentProvider
// Experiments contains the list of experiments currently enabled.
// This is used to gate features that are not yet ready for production.
@ -951,7 +966,8 @@ func (api *API) Close() error {
if coordinator != nil {
_ = (*coordinator).Close()
}
return api.workspaceAgentCache.Close()
_ = api.agentProvider.Close()
return nil
}
func compressHandler(h http.Handler) http.Handler {


@ -109,6 +109,7 @@ type Options struct {
GitAuthConfigs []*gitauth.Config
TrialGenerator func(context.Context, string) error
TemplateScheduleStore schedule.TemplateScheduleStore
Coordinator tailnet.Coordinator
HealthcheckFunc func(ctx context.Context, apiKey string) *healthcheck.Report
HealthcheckTimeout time.Duration


@ -302,7 +302,7 @@ func TestAgents(t *testing.T) {
coordinator := tailnet.NewCoordinator(slogtest.Make(t, nil).Leveled(slog.LevelDebug))
coordinatorPtr := atomic.Pointer[tailnet.Coordinator]{}
coordinatorPtr.Store(&coordinator)
derpMap := tailnettest.RunDERPAndSTUN(t)
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
agentInactiveDisconnectTimeout := 1 * time.Hour // don't need to focus on this value in tests
registry := prometheus.NewRegistry()

coderd/tailnet.go (new file, 339 lines)

@ -0,0 +1,339 @@
package coderd
import (
"bufio"
"context"
"net"
"net/http"
"net/http/httputil"
"net/netip"
"net/url"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"tailscale.com/derp"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/site"
"github.com/coder/coder/tailnet"
)
var tailnetTransport *http.Transport
func init() {
var valid bool
tailnetTransport, valid = http.DefaultTransport.(*http.Transport)
if !valid {
panic("dev error: default transport is the wrong type")
}
}
// NewServerTailnet creates a new tailnet intended for use by coderd. It
// automatically falls back to wsconncache if a legacy agent is encountered.
func NewServerTailnet(
ctx context.Context,
logger slog.Logger,
derpServer *derp.Server,
derpMap *tailcfg.DERPMap,
coord *atomic.Pointer[tailnet.Coordinator],
cache *wsconncache.Cache,
) (*ServerTailnet, error) {
logger = logger.Named("servertailnet")
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: derpMap,
Logger: logger,
})
if err != nil {
return nil, xerrors.Errorf("create tailnet conn: %w", err)
}
serverCtx, cancel := context.WithCancel(ctx)
tn := &ServerTailnet{
ctx: serverCtx,
cancel: cancel,
logger: logger,
conn: conn,
coord: coord,
cache: cache,
agentNodes: map[uuid.UUID]time.Time{},
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
transport: tailnetTransport.Clone(),
}
tn.transport.DialContext = tn.dialContext
tn.transport.MaxIdleConnsPerHost = 10
tn.transport.MaxIdleConns = 0
agentConn := (*coord.Load()).ServeMultiAgent(uuid.New())
tn.agentConn.Store(&agentConn)
err = tn.getAgentConn().UpdateSelf(conn.Node())
if err != nil {
tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err))
}
conn.SetNodeCallback(func(node *tailnet.Node) {
err := tn.getAgentConn().UpdateSelf(node)
if err != nil {
tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err))
}
})
// This is set to allow local DERP traffic to be proxied through memory
// instead of needing to hit the external access URL. Don't use the ctx
// given in this callback, it's only valid while connecting.
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
if !region.EmbeddedRelay {
return nil
}
left, right := net.Pipe()
go func() {
defer left.Close()
defer right.Close()
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
derpServer.Accept(ctx, right, brw, "internal")
}()
return left
})
go tn.watchAgentUpdates()
go tn.expireOldAgents()
return tn, nil
}
func (s *ServerTailnet) expireOldAgents() {
const (
tick = 5 * time.Minute
cutoff = 30 * time.Minute
)
ticker := time.NewTicker(tick)
defer ticker.Stop()
for {
select {
case <-s.ctx.Done():
return
case <-ticker.C:
}
s.nodesMu.Lock()
agentConn := s.getAgentConn()
for agentID, lastConnection := range s.agentNodes {
// If no one has connected since the cutoff and there are no active
// connections, remove the agent.
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
_ = agentConn
// err := agentConn.UnsubscribeAgent(agentID)
// if err != nil {
// s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
// }
// delete(s.agentNodes, agentID)
// TODO(coadler): actually remove from the netmap, then reenable
// the above
}
}
s.nodesMu.Unlock()
}
}
func (s *ServerTailnet) watchAgentUpdates() {
for {
conn := s.getAgentConn()
nodes, ok := conn.NextUpdate(s.ctx)
if !ok {
if conn.IsClosed() && s.ctx.Err() == nil {
s.reinitCoordinator()
continue
}
return
}
err := s.conn.UpdateNodes(nodes, false)
if err != nil {
s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err))
return
}
}
}
func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
return *s.agentConn.Load()
}
func (s *ServerTailnet) reinitCoordinator() {
s.nodesMu.Lock()
agentConn := (*s.coord.Load()).ServeMultiAgent(uuid.New())
s.agentConn.Store(&agentConn)
// Resubscribe to all of the agents we're tracking.
for agentID := range s.agentNodes {
err := agentConn.SubscribeAgent(agentID)
if err != nil {
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
}
}
s.nodesMu.Unlock()
}
type ServerTailnet struct {
ctx context.Context
cancel func()
logger slog.Logger
conn *tailnet.Conn
coord *atomic.Pointer[tailnet.Coordinator]
agentConn atomic.Pointer[tailnet.MultiAgentConn]
cache *wsconncache.Cache
nodesMu sync.Mutex
// agentNodes is a map of agents the server wants to keep a connection to.
// It stores the last time each agent was connected to.
agentNodes map[uuid.UUID]time.Time
// agentTickets holds a map of all open connections to an agent.
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
transport *http.Transport
}
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) {
proxy := httputil.NewSingleHostReverseProxy(targetURL)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
Status: http.StatusBadGateway,
Title: "Bad Gateway",
Description: "Failed to proxy request to application: " + err.Error(),
RetryEnabled: true,
DashboardURL: dashboardURL.String(),
})
}
proxy.Director = s.director(agentID, proxy.Director)
proxy.Transport = s.transport
return proxy, func() {}, nil
}
type agentIDKey struct{}
// director makes sure agentIDKey is set on the context in the reverse proxy.
// This allows the transport to correctly identify which agent to dial to.
func (*ServerTailnet) director(agentID uuid.UUID, prev func(req *http.Request)) func(req *http.Request) {
return func(req *http.Request) {
ctx := context.WithValue(req.Context(), agentIDKey{}, agentID)
*req = *req.WithContext(ctx)
prev(req)
}
}
func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
agentID, ok := ctx.Value(agentIDKey{}).(uuid.UUID)
if !ok {
return nil, xerrors.Errorf("no agent id attached")
}
return s.DialAgentNetConn(ctx, agentID, network, addr)
}
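
The director/dialContext pair above is what lets a single shared http.Transport serve every agent: the proxy director stashes the agent ID on the request context, and the transport's DialContext reads it back to decide which peer to dial. A generic, self-contained sketch of the same pattern follows; all names in it are hypothetical and not part of this commit:

package example

import (
	"context"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
)

// targetKey plays the role of agentIDKey above.
type targetKey struct{}

// newRoutedProxy builds a reverse proxy whose shared transport picks its
// backend from a value the director placed on the request context.
func newRoutedProxy(target *url.URL, dial func(ctx context.Context, backend string) (net.Conn, error)) *httputil.ReverseProxy {
	proxy := httputil.NewSingleHostReverseProxy(target)

	prev := proxy.Director
	proxy.Director = func(req *http.Request) {
		// Attach routing info before the transport runs.
		ctx := context.WithValue(req.Context(), targetKey{}, target.Host)
		*req = *req.WithContext(ctx)
		prev(req)
	}

	transport := http.DefaultTransport.(*http.Transport).Clone()
	transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
		backend, _ := ctx.Value(targetKey{}).(string)
		return dial(ctx, backend)
	}
	proxy.Transport = transport

	return proxy
}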
func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
s.nodesMu.Lock()
defer s.nodesMu.Unlock()
_, ok := s.agentNodes[agentID]
// If we don't have the node, subscribe.
if !ok {
s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
err := s.getAgentConn().SubscribeAgent(agentID)
if err != nil {
return xerrors.Errorf("subscribe agent: %w", err)
}
s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
}
s.agentNodes[agentID] = time.Now()
return nil
}
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
var (
conn *codersdk.WorkspaceAgentConn
ret = func() {}
)
if s.getAgentConn().AgentIsLegacy(agentID) {
s.logger.Debug(s.ctx, "acquiring legacy agent", slog.F("agent_id", agentID))
cconn, release, err := s.cache.Acquire(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("acquire legacy agent conn: %w", err)
}
conn = cconn.WorkspaceAgentConn
ret = release
} else {
err := s.ensureAgent(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
}
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{
AgentID: agentID,
CloseFunc: func() error { return codersdk.ErrSkipClose },
})
}
// Since we now have an open conn, be careful to close it if we error
// without returning it to the user.
reachable := conn.AwaitReachable(ctx)
if !reachable {
ret()
conn.Close()
return nil, nil, xerrors.New("agent is unreachable")
}
return conn, ret, nil
}
func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, network, addr string) (net.Conn, error) {
conn, release, err := s.AgentConn(ctx, agentID)
if err != nil {
return nil, xerrors.Errorf("acquire agent conn: %w", err)
}
// Since we now have an open conn, be careful to close it if we error
// without returning it to the user.
nc, err := conn.DialContext(ctx, network, addr)
if err != nil {
release()
conn.Close()
return nil, xerrors.Errorf("dial context: %w", err)
}
return &netConnCloser{Conn: nc, close: func() {
release()
conn.Close()
}}, err
}
type netConnCloser struct {
net.Conn
close func()
}
func (c *netConnCloser) Close() error {
c.close()
return c.Conn.Close()
}
func (s *ServerTailnet) Close() error {
s.cancel()
_ = s.cache.Close()
_ = s.conn.Close()
s.transport.CloseIdleConnections()
return nil
}
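
For reference, a minimal sketch of how callers are expected to consume the new type, mirroring the AgentConn/release/Close pattern used in workspaceagents.go and workspaceapps below. The helper name and target address are hypothetical; error handling is trimmed:

package example

import (
	"context"

	"github.com/google/uuid"

	"github.com/coder/coder/coderd"
)

// dialThroughServerTailnet shows the intended call pattern: acquire a conn,
// use it, then release and close it.
func dialThroughServerTailnet(ctx context.Context, tn *coderd.ServerTailnet, agentID uuid.UUID) error {
	conn, release, err := tn.AgentConn(ctx, agentID)
	if err != nil {
		return err
	}
	// release frees the wsconncache reference when the legacy path was used;
	// it is a no-op on the single-tailnet path.
	defer release()
	// The returned conn's CloseFunc returns codersdk.ErrSkipClose on the
	// single-tailnet path, so closing it does not tear down the shared conn.
	defer conn.Close()

	nc, err := conn.DialContext(ctx, "tcp", "127.0.0.1:8080")
	if err != nil {
		return err
	}
	return nc.Close()
}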

coderd/tailnet_test.go (new file, 207 lines)

@ -0,0 +1,207 @@
package coderd_test
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"sync/atomic"
"testing"
"github.com/google/uuid"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/agent/agenttest"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/codersdk/agentsdk"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/tailnet/tailnettest"
"github.com/coder/coder/testutil"
)
func TestServerTailnet_AgentConn_OK(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
// Connect through the ServerTailnet
agentID, _, serverTailnet := setupAgent(t, nil)
conn, release, err := serverTailnet.AgentConn(ctx, agentID)
require.NoError(t, err)
defer release()
assert.True(t, conn.AwaitReachable(ctx))
}
func TestServerTailnet_AgentConn_Legacy(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
// Force a connection through wsconncache using the legacy hardcoded ip.
agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
})
conn, release, err := serverTailnet.AgentConn(ctx, agentID)
require.NoError(t, err)
defer release()
assert.True(t, conn.AwaitReachable(ctx))
}
func TestServerTailnet_ReverseProxy_OK(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// Connect through the ServerTailnet.
agentID, _, serverTailnet := setupAgent(t, nil)
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
require.NoError(t, err)
rp, release, err := serverTailnet.ReverseProxy(u, u, agentID)
require.NoError(t, err)
defer release()
rw := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
u.String(),
nil,
).WithContext(ctx)
rp.ServeHTTP(rw, req)
res := rw.Result()
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func TestServerTailnet_ReverseProxy_Legacy(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// Force a connection through wsconncache using the legacy hardcoded ip.
agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
})
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
require.NoError(t, err)
rp, release, err := serverTailnet.ReverseProxy(u, u, agentID)
require.NoError(t, err)
defer release()
rw := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
u.String(),
nil,
).WithContext(ctx)
rp.ServeHTTP(rw, req)
res := rw.Result()
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.Agent, *coderd.ServerTailnet) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
derpMap, derpServer := tailnettest.RunDERPAndSTUN(t)
manifest := agentsdk.Manifest{
AgentID: uuid.New(),
DERPMap: derpMap,
}
var coordPtr atomic.Pointer[tailnet.Coordinator]
coord := tailnet.NewCoordinator(logger)
coordPtr.Store(&coord)
t.Cleanup(func() {
_ = coord.Close()
})
c := agenttest.NewClient(t, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
options := agent.Options{
Client: c,
Filesystem: afero.NewMemMapFs(),
Logger: logger.Named("agent"),
Addresses: agentAddresses,
}
ag := agent.New(options)
t.Cleanup(func() {
_ = ag.Close()
})
// Wait for the agent to connect.
require.Eventually(t, func() bool {
return coord.Node(manifest.AgentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: manifest.DERPMap,
Logger: logger.Named("client"),
})
require.NoError(t, err)
clientConn, serverConn := net.Pipe()
serveClientDone := make(chan struct{})
t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
_ = conn.Close()
<-serveClientDone
})
go func() {
defer close(serveClientDone)
coord.ServeClient(serverConn, uuid.New(), manifest.AgentID)
}()
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
return conn.UpdateNodes(node, false)
})
conn.SetNodeCallback(sendNode)
return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: manifest.AgentID,
AgentIP: codersdk.WorkspaceAgentIP,
CloseFunc: func() error { return codersdk.ErrSkipClose },
}), nil
}, 0)
serverTailnet, err := coderd.NewServerTailnet(
context.Background(),
logger,
derpServer,
manifest.DERPMap,
&coordPtr,
cache,
)
require.NoError(t, err)
t.Cleanup(func() {
_ = serverTailnet.Close()
})
return manifest.AgentID, ag, serverTailnet
}


@ -161,6 +161,7 @@ func (api *API) workspaceAgentManifest(rw http.ResponseWriter, r *http.Request)
}
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.Manifest{
AgentID: apiAgent.ID,
Apps: convertApps(dbApps),
DERPMap: api.DERPMap,
GitAuthConfigs: len(api.GitAuthConfigs),
@ -654,7 +655,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
return
}
agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID)
agentConn, release, err := api.agentProvider.AgentConn(ctx, workspaceAgent.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error dialing workspace agent.",
@ -729,7 +730,9 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
httpapi.Write(ctx, rw, http.StatusOK, portsResponse)
}
func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
// Deprecated: use api.tailnet.AgentConn instead.
// See: https://github.com/coder/coder/issues/8218
func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
clientConn, serverConn := net.Pipe()
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
@ -765,14 +768,16 @@ func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspac
return nil
})
conn.SetNodeCallback(sendNodes)
agentConn := &codersdk.WorkspaceAgentConn{
Conn: conn,
CloseFunc: func() {
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: agentID,
AgentIP: codersdk.WorkspaceAgentIP,
CloseFunc: func() error {
cancel()
_ = clientConn.Close()
_ = serverConn.Close()
return nil
},
}
})
go func() {
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
if err != nil {


@ -399,7 +399,8 @@ func doWithRetries(t require.TestingT, client *codersdk.Client, req *http.Reques
return resp, err
}
func requestWithRetries(ctx context.Context, t require.TestingT, client *codersdk.Client, method, urlOrPath string, body interface{}, opts ...codersdk.RequestOption) (*http.Response, error) {
func requestWithRetries(ctx context.Context, t testing.TB, client *codersdk.Client, method, urlOrPath string, body interface{}, opts ...codersdk.RequestOption) (*http.Response, error) {
t.Helper()
var resp *http.Response
var err error
require.Eventually(t, func() bool {


@ -23,7 +23,6 @@ import (
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/util/slice"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/site"
)
@ -61,6 +60,22 @@ var nonCanonicalHeaders = map[string]string{
"Sec-Websocket-Version": "Sec-WebSocket-Version",
}
type AgentProvider interface {
// ReverseProxy returns an httputil.ReverseProxy for proxying HTTP requests
// to the specified agent.
//
// TODO: after wsconncache is deleted this doesn't need to return an error.
ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error)
// AgentConn returns a new connection to the specified agent.
//
// TODO: after wsconncache is deleted this doesn't need to return a release
// func.
AgentConn(ctx context.Context, agentID uuid.UUID) (_ *codersdk.WorkspaceAgentConn, release func(), _ error)
Close() error
}
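
For orientation, here is a minimal, hypothetical consumer of the AgentProvider interface above. The handler name, the `target`/`dashboard` URLs, and the error text are illustrative assumptions; only the ReverseProxy/release contract comes from this diff.

```go
package workspaceapps_example

import (
	"net/http"
	"net/url"

	"github.com/google/uuid"

	"github.com/coder/coder/coderd/workspaceapps"
)

// proxyToAgent is a sketch of a handler that borrows a reverse proxy for an
// agent, forwards one request through it, and releases it afterwards.
func proxyToAgent(
	w http.ResponseWriter,
	r *http.Request,
	provider workspaceapps.AgentProvider,
	target, dashboard *url.URL,
	agentID uuid.UUID,
) {
	proxy, release, err := provider.ReverseProxy(target, dashboard, agentID)
	if err != nil {
		http.Error(w, "dial workspace agent: "+err.Error(), http.StatusBadGateway)
		return
	}
	// Release must be called once the request finishes so a cached
	// connection (wsconncache) can be returned to its pool.
	defer release()
	proxy.ServeHTTP(w, r)
}
```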
// Server serves workspace apps endpoints, including:
// - Path-based apps
// - Subdomain app middleware
@ -83,7 +98,6 @@ type Server struct {
RealIPConfig *httpmw.RealIPConfig
SignedTokenProvider SignedTokenProvider
WorkspaceConnCache *wsconncache.Cache
AppSecurityKey SecurityKey
// DisablePathApps disables path-based apps. This is a security feature as path
@ -95,6 +109,8 @@ type Server struct {
DisablePathApps bool
SecureAuthCookie bool
AgentProvider AgentProvider
websocketWaitMutex sync.Mutex
websocketWaitGroup sync.WaitGroup
}
@ -106,8 +122,8 @@ func (s *Server) Close() error {
s.websocketWaitGroup.Wait()
s.websocketWaitMutex.Unlock()
// The caller must close the SignedTokenProvider (if necessary) and the
// wsconncache.
// The caller must close the SignedTokenProvider and the AgentProvider (if
// necessary).
return nil
}
@ -517,18 +533,7 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT
r.URL.Path = path
appURL.RawQuery = ""
proxy := httputil.NewSingleHostReverseProxy(appURL)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
Status: http.StatusBadGateway,
Title: "Bad Gateway",
Description: "Failed to proxy request to application: " + err.Error(),
RetryEnabled: true,
DashboardURL: s.DashboardURL.String(),
})
}
conn, release, err := s.WorkspaceConnCache.Acquire(appToken.AgentID)
proxy, release, err := s.AgentProvider.ReverseProxy(appURL, s.DashboardURL, appToken.AgentID)
if err != nil {
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
Status: http.StatusBadGateway,
@ -540,7 +545,6 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT
return
}
defer release()
proxy.Transport = conn.HTTPTransport()
proxy.ModifyResponse = func(r *http.Response) error {
r.Header.Del(httpmw.AccessControlAllowOriginHeader)
@ -658,13 +662,14 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
go httpapi.Heartbeat(ctx, conn)
agentConn, release, err := s.WorkspaceConnCache.Acquire(appToken.AgentID)
agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID)
if err != nil {
log.Debug(ctx, "dial workspace agent", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
return
}
defer release()
defer agentConn.Close()
log.Debug(ctx, "dialed workspace agent")
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"))
if err != nil {

View File

@ -1,9 +1,12 @@
// Package wsconncache caches workspace agent connections by UUID.
// Deprecated: Use ServerTailnet instead.
package wsconncache
import (
"context"
"net/http"
"net/http/httputil"
"net/url"
"sync"
"time"
@ -13,13 +16,57 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/site"
)
// New creates a new workspace connection cache that closes
// connections after the inactive timeout provided.
type AgentProvider struct {
Cache *Cache
}
func (a *AgentProvider) AgentConn(_ context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
conn, rel, err := a.Cache.Acquire(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("acquire agent connection: %w", err)
}
return conn.WorkspaceAgentConn, rel, nil
}
func (a *AgentProvider) ReverseProxy(targetURL *url.URL, dashboardURL *url.URL, agentID uuid.UUID) (*httputil.ReverseProxy, func(), error) {
proxy := httputil.NewSingleHostReverseProxy(targetURL)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
Status: http.StatusBadGateway,
Title: "Bad Gateway",
Description: "Failed to proxy request to application: " + err.Error(),
RetryEnabled: true,
DashboardURL: dashboardURL.String(),
})
}
conn, release, err := a.Cache.Acquire(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("acquire agent connection: %w", err)
}
proxy.Transport = conn.HTTPTransport()
return proxy, release, nil
}
func (a *AgentProvider) Close() error {
return a.Cache.Close()
}
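
A short usage sketch of this adapter, mirroring the AgentConn acquire/release pattern used by the handlers elsewhere in this diff; the function name and the Ping call are illustrative.

```go
package wsconncache_example

import (
	"context"

	"github.com/google/uuid"
	"golang.org/x/xerrors"

	"github.com/coder/coder/coderd/wsconncache"
)

// pingAgent borrows a cached agent connection, pings the agent, then closes
// the per-call wrapper and returns the connection to the cache.
func pingAgent(ctx context.Context, provider *wsconncache.AgentProvider, agentID uuid.UUID) error {
	conn, release, err := provider.AgentConn(ctx, agentID)
	if err != nil {
		return xerrors.Errorf("acquire agent conn: %w", err)
	}
	defer release()
	defer conn.Close()

	_, _, _, err = conn.Ping(ctx)
	return err
}
```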
// New creates a new workspace connection cache that closes connections after
// the inactive timeout provided.
//
// Agent connections are cached due to WebRTC negotiation
// taking a few hundred milliseconds.
// Agent connections are cached due to Wireguard negotiation taking a few
// hundred milliseconds, depending on latency.
//
// Deprecated: Use coderd.NewServerTailnet instead. wsconncache is being phased
// out because it creates a unique Tailnet for each agent.
// See: https://github.com/coder/coder/issues/8218
func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
if inactiveTimeout == 0 {
inactiveTimeout = 5 * time.Minute

View File

@ -157,22 +157,23 @@ func TestCache(t *testing.T) {
func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn {
t.Helper()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
manifest.DERPMap = tailnettest.RunDERPAndSTUN(t)
manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t)
coordinator := tailnet.NewCoordinator(logger)
t.Cleanup(func() {
_ = coordinator.Close()
})
agentID := uuid.New()
manifest.AgentID = uuid.New()
closer := agent.New(agent.Options{
Client: &client{
t: t,
agentID: agentID,
agentID: manifest.AgentID,
manifest: manifest,
coordinator: coordinator,
},
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: ptyTimeout,
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
})
t.Cleanup(func() {
_ = closer.Close()
@ -189,14 +190,15 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
_ = serverConn.Close()
_ = conn.Close()
})
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
go coordinator.ServeClient(serverConn, uuid.New(), manifest.AgentID)
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
return conn.UpdateNodes(node, false)
})
conn.SetNodeCallback(sendNode)
agentConn := &codersdk.WorkspaceAgentConn{
Conn: conn,
}
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: manifest.AgentID,
AgentIP: codersdk.WorkspaceAgentIP,
})
t.Cleanup(func() {
_ = agentConn.Close()
})

View File

@ -84,6 +84,7 @@ func (c *Client) PostMetadata(ctx context.Context, key string, req PostMetadataR
}
type Manifest struct {
AgentID uuid.UUID `json:"agent_id"`
// GitAuthConfigs stores the number of Git configurations
// the Coder deployment has. If this number is >0, we
// set up special configuration in the workspace.

View File

@ -1764,6 +1764,12 @@ const (
// oidc.
ExperimentConvertToOIDC Experiment = "convert-to-oidc"
// ExperimentSingleTailnet replaces workspace connections inside coderd to
// all use a single tailnet, instead of the previous behavior of creating a
// single tailnet for each agent.
// WARNING: This cannot be enabled when using HA.
ExperimentSingleTailnet Experiment = "single_tailnet"
ExperimentWorkspaceBuildLogsUI Experiment = "workspace_build_logs_ui"
// Add new experiments here!
// ExperimentExample Experiment = "example"
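
Downstream code can gate the new connection path on this flag. A minimal sketch, assuming the existing codersdk.Experiments.Enabled helper; the function and its caller are illustrative only.

```go
package coderd_example

import "github.com/coder/coder/codersdk"

// useSingleTailnet reports whether the deployment opted into the shared
// tailnet via --experiments / CODER_EXPERIMENTS.
func useSingleTailnet(experiments codersdk.Experiments) bool {
	return experiments.Enabled(codersdk.ExperimentSingleTailnet)
}
```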

View File

@ -15,6 +15,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"tailscale.com/ipn/ipnstate"
@ -27,8 +28,14 @@ import (
// WorkspaceAgentIP is a static IPv6 address with the Tailscale prefix that is used to route
// connections from clients to this node. A dynamic address is not required because a Tailnet
// client only dials a single agent at a time.
//
// Deprecated: use tailnet.IP() instead. This is kept for backwards
// compatibility with wsconncache.
// See: https://github.com/coder/coder/issues/8218
var WorkspaceAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")
var ErrSkipClose = xerrors.New("skip tailnet close")
const (
WorkspaceAgentSSHPort = tailnet.WorkspaceAgentSSHPort
WorkspaceAgentReconnectingPTYPort = tailnet.WorkspaceAgentReconnectingPTYPort
@ -120,11 +127,38 @@ func init() {
}
}
// NewWorkspaceAgentConn creates a new WorkspaceAgentConn. `conn` may be unique
// to the WorkspaceAgentConn, or it may be shared in the case of coderd. If the
// conn is shared and closing it is undesirable, you may return ErrSkipClose from
// opts.CloseFunc. This will ensure the underlying conn is not closed.
func NewWorkspaceAgentConn(conn *tailnet.Conn, opts WorkspaceAgentConnOptions) *WorkspaceAgentConn {
return &WorkspaceAgentConn{
Conn: conn,
opts: opts,
}
}
// WorkspaceAgentConn represents a connection to a workspace agent.
// @typescript-ignore WorkspaceAgentConn
type WorkspaceAgentConn struct {
*tailnet.Conn
CloseFunc func()
opts WorkspaceAgentConnOptions
}
// @typescript-ignore WorkspaceAgentConnOptions
type WorkspaceAgentConnOptions struct {
AgentID uuid.UUID
AgentIP netip.Addr
CloseFunc func() error
}
func (c *WorkspaceAgentConn) agentAddress() netip.Addr {
var emptyIP netip.Addr
if cmp := c.opts.AgentIP.Compare(emptyIP); cmp != 0 {
return c.opts.AgentIP
}
return tailnet.IPFromUUID(c.opts.AgentID)
}
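
Tying the above together: a caller that shares one tailnet.Conn across many agents can build a per-agent view like this. The helper is a sketch; only NewWorkspaceAgentConn, ErrSkipClose, and the IPFromUUID fallback come from the diff.

```go
package codersdk_example

import (
	"github.com/google/uuid"

	"github.com/coder/coder/codersdk"
	"github.com/coder/coder/tailnet"
)

// wrapShared returns a per-agent wrapper over a shared tailnet.Conn. Leaving
// AgentIP unset makes agentAddress() fall back to tailnet.IPFromUUID(agentID);
// returning ErrSkipClose keeps the shared conn open when the wrapper closes.
func wrapShared(conn *tailnet.Conn, agentID uuid.UUID) *codersdk.WorkspaceAgentConn {
	return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
		AgentID:   agentID,
		CloseFunc: func() error { return codersdk.ErrSkipClose },
	})
}
```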
// AwaitReachable waits for the agent to be reachable.
@ -132,7 +166,7 @@ func (c *WorkspaceAgentConn) AwaitReachable(ctx context.Context) bool {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
return c.Conn.AwaitReachable(ctx, WorkspaceAgentIP)
return c.Conn.AwaitReachable(ctx, c.agentAddress())
}
// Ping pings the agent and returns the round-trip time.
@ -141,13 +175,20 @@ func (c *WorkspaceAgentConn) Ping(ctx context.Context) (time.Duration, bool, *ip
ctx, span := tracing.StartSpan(ctx)
defer span.End()
return c.Conn.Ping(ctx, WorkspaceAgentIP)
return c.Conn.Ping(ctx, c.agentAddress())
}
// Close ends the connection to the workspace agent.
func (c *WorkspaceAgentConn) Close() error {
if c.CloseFunc != nil {
c.CloseFunc()
var cerr error
if c.opts.CloseFunc != nil {
cerr = c.opts.CloseFunc()
if xerrors.Is(cerr, ErrSkipClose) {
return nil
}
}
if cerr != nil {
return multierror.Append(cerr, c.Conn.Close())
}
return c.Conn.Close()
}
@ -176,10 +217,12 @@ type ReconnectingPTYRequest struct {
func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
if !c.AwaitReachable(ctx) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentReconnectingPTYPort))
conn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentReconnectingPTYPort))
if err != nil {
return nil, err
}
@ -209,10 +252,12 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID,
func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
if !c.AwaitReachable(ctx) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
return c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSSHPort))
return c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentSSHPort))
}
// SSHClient calls SSH to create a client that uses a weak cipher
@ -220,10 +265,12 @@ func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) {
func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
netConn, err := c.SSH(ctx)
if err != nil {
return nil, xerrors.Errorf("ssh: %w", err)
}
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
// SSH host validation isn't helpful, because obtaining a peer
// connection already signifies user-intent to dial a workspace.
@ -233,6 +280,7 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error)
if err != nil {
return nil, xerrors.Errorf("ssh conn: %w", err)
}
return ssh.NewClient(sshConn, channels, requests), nil
}
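
As a usage sketch of the client above, running a single command over SSH might look like the following; the command and the error handling are illustrative.

```go
package codersdk_example

import (
	"context"
	"fmt"

	"github.com/coder/coder/codersdk"
)

// runWhoami opens an SSH session through the agent connection and runs one
// command. Purely illustrative.
func runWhoami(ctx context.Context, agentConn *codersdk.WorkspaceAgentConn) error {
	sshClient, err := agentConn.SSHClient(ctx)
	if err != nil {
		return err
	}
	defer sshClient.Close()

	session, err := sshClient.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()

	out, err := session.CombinedOutput("whoami")
	if err != nil {
		return err
	}
	fmt.Printf("%s", out)
	return nil
}
```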
@ -240,17 +288,21 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error)
func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
if !c.AwaitReachable(ctx) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSpeedtestPort))
speedConn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), WorkspaceAgentSpeedtestPort))
if err != nil {
return nil, xerrors.Errorf("dial speedtest: %w", err)
}
results, err := speedtest.RunClientWithConn(direction, duration, speedConn)
if err != nil {
return nil, xerrors.Errorf("run speedtest: %w", err)
}
return results, err
}
@ -259,19 +311,23 @@ func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest.
func (c *WorkspaceAgentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
if network == "unix" {
return nil, xerrors.New("network must be tcp or udp")
}
_, rawPort, _ := net.SplitHostPort(addr)
port, _ := strconv.ParseUint(rawPort, 10, 16)
ipp := netip.AddrPortFrom(WorkspaceAgentIP, uint16(port))
if !c.AwaitReachable(ctx) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
if network == "udp" {
_, rawPort, _ := net.SplitHostPort(addr)
port, _ := strconv.ParseUint(rawPort, 10, 16)
ipp := netip.AddrPortFrom(c.agentAddress(), uint16(port))
switch network {
case "tcp":
return c.Conn.DialContextTCP(ctx, ipp)
case "udp":
return c.Conn.DialContextUDP(ctx, ipp)
default:
return nil, xerrors.Errorf("unknown network %q", network)
}
return c.Conn.DialContextTCP(ctx, ipp)
}
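
Since DialContext resolves the agent's address itself and only uses the port from `addr`, callers can pass any host. A small sketch, with the helper name and port argument as assumptions:

```go
package codersdk_example

import (
	"context"
	"fmt"
	"net"

	"github.com/coder/coder/codersdk"
)

// dialWorkspacePort reaches a TCP service inside the workspace. The host part
// of the address is ignored by DialContext; only the port matters.
func dialWorkspacePort(ctx context.Context, agentConn *codersdk.WorkspaceAgentConn, port uint16) (net.Conn, error) {
	return agentConn.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
}
```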
type WorkspaceAgentListeningPortsResponse struct {
@ -309,7 +365,8 @@ func (c *WorkspaceAgentConn) ListeningPorts(ctx context.Context) (WorkspaceAgent
func (c *WorkspaceAgentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
host := net.JoinHostPort(WorkspaceAgentIP.String(), strconv.Itoa(WorkspaceAgentHTTPAPIServerPort))
host := net.JoinHostPort(c.agentAddress().String(), strconv.Itoa(WorkspaceAgentHTTPAPIServerPort))
url := fmt.Sprintf("http://%s%s", host, path)
req, err := http.NewRequestWithContext(ctx, method, url, body)
@ -332,13 +389,14 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client {
if network != "tcp" {
return nil, xerrors.Errorf("network must be tcp")
}
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, xerrors.Errorf("split host port %q: %w", addr, err)
}
// Verify that host is TailnetIP and port is
// TailnetStatisticsPort.
if host != WorkspaceAgentIP.String() || port != strconv.Itoa(WorkspaceAgentHTTPAPIServerPort) {
// Verify that the port is WorkspaceAgentHTTPAPIServerPort.
if port != strconv.Itoa(WorkspaceAgentHTTPAPIServerPort) {
return nil, xerrors.Errorf("request %q does not appear to be for http api", addr)
}
@ -346,7 +404,12 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort))
ipAddr, err := netip.ParseAddr(host)
if err != nil {
return nil, xerrors.Errorf("parse host addr: %w", err)
}
conn, err := c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(ipAddr, WorkspaceAgentHTTPAPIServerPort))
if err != nil {
return nil, xerrors.Errorf("dial http api: %w", err)
}

View File

@ -307,8 +307,8 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
options.Logger.Debug(ctx, "failed to dial", slog.Error(err))
continue
}
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(node []*tailnet.Node) error {
return conn.UpdateNodes(node, false)
sendNode, errChan := tailnet.ServeCoordinator(websocket.NetConn(ctx, ws, websocket.MessageBinary), func(nodes []*tailnet.Node) error {
return conn.UpdateNodes(nodes, false)
})
conn.SetNodeCallback(sendNode)
options.Logger.Debug(ctx, "serving coordinator")
@ -330,13 +330,15 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
return nil, err
}
agentConn = &WorkspaceAgentConn{
Conn: conn,
CloseFunc: func() {
agentConn = NewWorkspaceAgentConn(conn, WorkspaceAgentConnOptions{
AgentID: agentID,
CloseFunc: func() error {
cancel()
<-closed
return conn.Close()
},
}
})
if !agentConn.AwaitReachable(ctx) {
_ = agentConn.Close()
return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", ctx.Err())

View File

@ -292,6 +292,7 @@ curl -X GET http://coder-server:8080/api/v2/workspaceagents/me/manifest \
```json
{
"agent_id": "string",
"apps": [
{
"command": "string",

View File

@ -161,6 +161,7 @@
```json
{
"agent_id": "string",
"apps": [
{
"command": "string",
@ -260,6 +261,7 @@
| Name | Type | Required | Restrictions | Description |
| ---------------------------- | ------------------------------------------------------------------------------------------------- | -------- | ------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `agent_id` | string | false | | |
| `apps` | array of [codersdk.WorkspaceApp](#codersdkworkspaceapp) | false | | |
| `derpmap` | [tailcfg.DERPMap](#tailcfgderpmap) | false | | |
| `directory` | string | false | | |
@ -2543,6 +2545,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
| `workspace_actions` |
| `tailnet_ha_coordinator` |
| `convert-to-oidc` |
| `single_tailnet` |
| `workspace_build_logs_ui` |
## codersdk.Feature

View File

@ -6,9 +6,8 @@ import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cli/clibase"
"github.com/coder/coder/coderd/coderdtest"

View File

@ -17,6 +17,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/coderd/database/pubsub"
"github.com/coder/coder/codersdk"
agpl "github.com/coder/coder/tailnet"
)
@ -37,9 +38,12 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err
closeFunc: cancelFunc,
close: make(chan struct{}),
nodes: map[uuid.UUID]*agpl.Node{},
agentSockets: map[uuid.UUID]*agpl.TrackedConn{},
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*agpl.TrackedConn{},
agentSockets: map[uuid.UUID]agpl.Queue{},
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]agpl.Queue{},
agentNameCache: nameCache,
clients: map[uuid.UUID]agpl.Queue{},
clientsToAgents: map[uuid.UUID]map[uuid.UUID]agpl.Queue{},
legacyAgents: map[uuid.UUID]struct{}{},
}
if err := coord.runPubsub(ctx); err != nil {
@ -49,6 +53,56 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err
return coord, nil
}
func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
m := (&agpl.MultiAgent{
ID: id,
Logger: c.log,
AgentIsLegacyFunc: c.agentIsLegacy,
OnSubscribe: c.clientSubscribeToAgent,
OnNodeUpdate: c.clientNodeUpdate,
OnRemove: c.clientDisconnected,
}).Init()
c.addClient(id, m)
return m
}
func (c *haCoordinator) addClient(id uuid.UUID, q agpl.Queue) {
c.mutex.Lock()
c.clients[id] = q
c.clientsToAgents[id] = map[uuid.UUID]agpl.Queue{}
c.mutex.Unlock()
}
func (c *haCoordinator) clientSubscribeToAgent(enq agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.initOrSetAgentConnectionSocketLocked(agentID, enq)
node := c.nodes[enq.UniqueID()]
if node != nil {
err := c.sendNodeToAgentLocked(agentID, node)
if err != nil {
return nil, xerrors.Errorf("handle client update: %w", err)
}
}
agentNode, ok := c.nodes[agentID]
// If we have the node locally, give it back to the multiagent.
if ok {
return agentNode, nil
}
// If we don't have the node locally, notify other coordinators.
err := c.publishClientHello(agentID)
if err != nil {
return nil, xerrors.Errorf("publish client hello: %w", err)
}
// nolint:nilnil
return nil, nil
}
type haCoordinator struct {
id uuid.UUID
log slog.Logger
@ -60,14 +114,26 @@ type haCoordinator struct {
// nodes maps agent and connection IDs to their respective node.
nodes map[uuid.UUID]*agpl.Node
// agentSockets maps agent IDs to their open websocket.
agentSockets map[uuid.UUID]*agpl.TrackedConn
agentSockets map[uuid.UUID]agpl.Queue
// agentToConnectionSockets maps agent IDs to connection IDs of conns that
// are subscribed to updates for that agent.
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]*agpl.TrackedConn
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]agpl.Queue
// clients holds a map of all clients connected to the coordinator. This is
// necessary because a client may not be subscribed to any agents.
clients map[uuid.UUID]agpl.Queue
// clientsToAgents is an index of clients to all of their subscribed agents.
clientsToAgents map[uuid.UUID]map[uuid.UUID]agpl.Queue
// agentNameCache holds a cache of agent names. If one of them disappears,
// it's helpful to have a name cached for debugging.
agentNameCache *lru.Cache[uuid.UUID, string]
// legacyAgents holds a mapping of all agents detected as legacy, meaning
// they only listen on codersdk.WorkspaceAgentIP. They aren't compatible
// with the new ServerTailnet, so they must be connected through
// wsconncache.
legacyAgents map[uuid.UUID]struct{}
}
// Node returns an in-memory node by ID.
@ -88,61 +154,35 @@ func (c *haCoordinator) agentLogger(agent uuid.UUID) slog.Logger {
// ServeClient accepts a WebSocket connection that wants to connect to an agent
// with the specified ID.
func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
logger := c.clientLogger(id, agent)
c.mutex.Lock()
connectionSockets, ok := c.agentToConnectionSockets[agent]
if !ok {
connectionSockets = map[uuid.UUID]*agpl.TrackedConn{}
c.agentToConnectionSockets[agent] = connectionSockets
}
logger := c.clientLogger(id, agentID)
tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0)
// Insert this connection into a map so the agent
// can publish node updates.
connectionSockets[id] = tc
defer tc.Close()
// When a new connection is requested, we update it with the latest
// node of the agent. This allows the connection to establish.
node, ok := c.nodes[agent]
if ok {
err := tc.Enqueue([]*agpl.Node{node})
c.mutex.Unlock()
c.addClient(id, tc)
defer c.clientDisconnected(id)
agentNode, err := c.clientSubscribeToAgent(tc, agentID)
if err != nil {
return xerrors.Errorf("subscribe agent: %w", err)
}
if agentNode != nil {
err := tc.Enqueue([]*agpl.Node{agentNode})
if err != nil {
return xerrors.Errorf("enqueue node: %w", err)
}
} else {
c.mutex.Unlock()
err := c.publishClientHello(agent)
if err != nil {
return xerrors.Errorf("publish client hello: %w", err)
logger.Debug(ctx, "enqueue initial node", slog.Error(err))
}
}
go tc.SendUpdates()
defer func() {
c.mutex.Lock()
defer c.mutex.Unlock()
// Clean all traces of this connection from the map.
delete(c.nodes, id)
connectionSockets, ok := c.agentToConnectionSockets[agent]
if !ok {
return
}
delete(connectionSockets, id)
if len(connectionSockets) != 0 {
return
}
delete(c.agentToConnectionSockets, agent)
}()
go tc.SendUpdates()
decoder := json.NewDecoder(conn)
// Indefinitely handle messages from the client websocket.
for {
err := c.handleNextClientMessage(id, agent, decoder)
err := c.handleNextClientMessage(id, decoder)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
return nil
@ -152,35 +192,90 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID
}
}
func (c *haCoordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
func (c *haCoordinator) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq agpl.Queue) {
connectionSockets, ok := c.agentToConnectionSockets[agentID]
if !ok {
connectionSockets = map[uuid.UUID]agpl.Queue{}
c.agentToConnectionSockets[agentID] = connectionSockets
}
connectionSockets[enq.UniqueID()] = enq
c.clientsToAgents[enq.UniqueID()][agentID] = c.agentSockets[agentID]
}
func (c *haCoordinator) clientDisconnected(id uuid.UUID) {
c.mutex.Lock()
defer c.mutex.Unlock()
for agentID := range c.clientsToAgents[id] {
// Clean all traces of this connection from the map.
delete(c.nodes, id)
connectionSockets, ok := c.agentToConnectionSockets[agentID]
if !ok {
return
}
delete(connectionSockets, id)
if len(connectionSockets) != 0 {
return
}
delete(c.agentToConnectionSockets, agentID)
}
delete(c.clients, id)
delete(c.clientsToAgents, id)
}
func (c *haCoordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error {
var node agpl.Node
err := decoder.Decode(&node)
if err != nil {
return xerrors.Errorf("read json: %w", err)
}
return c.clientNodeUpdate(id, &node)
}
func (c *haCoordinator) clientNodeUpdate(id uuid.UUID, node *agpl.Node) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// Update the node of this client in our in-memory map. If an agent entirely
// shuts down and reconnects, it needs to be aware of all clients attempting
// to establish connections.
c.nodes[id] = &node
// Write the new node from this client to the actively connected agent.
agentSocket, ok := c.agentSockets[agent]
c.nodes[id] = node
for agentID, agentSocket := range c.clientsToAgents[id] {
if agentSocket == nil {
// If we don't own the agent locally, send it over pubsub to a node that
// owns the agent.
err := c.publishNodesToAgent(agentID, []*agpl.Node{node})
if err != nil {
c.log.Error(context.Background(), "publish node to agent", slog.Error(err), slog.F("agent_id", agentID))
}
} else {
// Write the new node from this client to the actively connected agent.
err := agentSocket.Enqueue([]*agpl.Node{node})
if err != nil {
c.log.Error(context.Background(), "enqueue node to agent", slog.Error(err), slog.F("agent_id", agentID))
}
}
}
return nil
}
func (c *haCoordinator) sendNodeToAgentLocked(agentID uuid.UUID, node *agpl.Node) error {
agentSocket, ok := c.agentSockets[agentID]
if !ok {
c.mutex.Unlock()
// If we don't own the agent locally, send it over pubsub to a node that
// owns the agent.
err := c.publishNodesToAgent(agent, []*agpl.Node{&node})
err := c.publishNodesToAgent(agentID, []*agpl.Node{node})
if err != nil {
return xerrors.Errorf("publish node to agent")
}
return nil
}
err = agentSocket.Enqueue([]*agpl.Node{&node})
c.mutex.Unlock()
err := agentSocket.Enqueue([]*agpl.Node{node})
if err != nil {
return xerrors.Errorf("enqueu nodes: %w", err)
return xerrors.Errorf("enqueue node: %w", err)
}
return nil
}
@ -202,7 +297,7 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err
// dead.
oldAgentSocket, ok := c.agentSockets[id]
if ok {
overwrites = oldAgentSocket.Overwrites + 1
overwrites = oldAgentSocket.Overwrites() + 1
_ = oldAgentSocket.Close()
}
// This uniquely identifies a connection that belongs to this goroutine.
@ -219,6 +314,9 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err
}
}
c.agentSockets[id] = tc
for clientID := range c.agentToConnectionSockets[id] {
c.clientsToAgents[clientID][id] = tc
}
c.mutex.Unlock()
go tc.SendUpdates()
@ -234,10 +332,13 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err
// Only delete the connection if it's ours. It could have been
// overwritten.
if idConn, ok := c.agentSockets[id]; ok && idConn.ID == unique {
if idConn, ok := c.agentSockets[id]; ok && idConn.UniqueID() == unique {
delete(c.agentSockets, id)
delete(c.nodes, id)
}
for clientID := range c.agentToConnectionSockets[id] {
c.clientsToAgents[clientID][id] = nil
}
}()
decoder := json.NewDecoder(conn)
@ -285,6 +386,13 @@ func (c *haCoordinator) handleClientHello(id uuid.UUID) error {
return c.publishAgentToNodes(id, node)
}
func (c *haCoordinator) agentIsLegacy(agentID uuid.UUID) bool {
c.mutex.RLock()
_, ok := c.legacyAgents[agentID]
c.mutex.RUnlock()
return ok
}
func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (*agpl.Node, error) {
var node agpl.Node
err := decoder.Decode(&node)
@ -293,6 +401,11 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (
}
c.mutex.Lock()
// Keep a cache of all legacy agents.
if len(node.Addresses) > 0 && node.Addresses[0].Addr() == codersdk.WorkspaceAgentIP {
c.legacyAgents[id] = struct{}{}
}
oldNode := c.nodes[id]
if oldNode != nil {
if oldNode.AsOf.After(node.AsOf) {
@ -311,7 +424,9 @@ func (c *haCoordinator) handleAgentUpdate(id uuid.UUID, decoder *json.Decoder) (
for _, connectionSocket := range connectionSockets {
_ = connectionSocket.Enqueue([]*agpl.Node{&node})
}
c.mutex.Unlock()
return &node, nil
}
@ -334,20 +449,18 @@ func (c *haCoordinator) Close() error {
for _, socket := range c.agentSockets {
socket := socket
go func() {
_ = socket.Close()
_ = socket.CoordinatorClose()
wg.Done()
}()
}
for _, connMap := range c.agentToConnectionSockets {
wg.Add(len(connMap))
for _, socket := range connMap {
socket := socket
go func() {
_ = socket.Close()
wg.Done()
}()
}
wg.Add(len(c.clients))
for _, client := range c.clients {
client := client
go func() {
_ = client.CoordinatorClose()
wg.Done()
}()
}
wg.Wait()
@ -422,13 +535,12 @@ func (c *haCoordinator) runPubsub(ctx context.Context) error {
}
go func() {
for {
var message []byte
select {
case <-ctx.Done():
return
case message = <-messageQueue:
case message := <-messageQueue:
c.handlePubsubMessage(ctx, message)
}
c.handlePubsubMessage(ctx, message)
}
}()

View File

@ -125,6 +125,11 @@ func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store
return c, nil
}
func (c *pgCoord) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
_, _ = c, id
panic("not implemented") // TODO: Implement
}
func (*pgCoord) ServeHTTPDebug(w http.ResponseWriter, _ *http.Request) {
// TODO(spikecurtis) I'd like to hold off implementing this until after the rest of this is code reviewed.
w.WriteHeader(http.StatusOK)

View File

@ -183,9 +183,12 @@ func New(ctx context.Context, opts *Options) (*Server, error) {
SecurityKey: secKey,
Logger: s.Logger.Named("proxy_token_provider"),
},
WorkspaceConnCache: wsconncache.New(s.DialWorkspaceAgent, 0),
AppSecurityKey: secKey,
AppSecurityKey: secKey,
// TODO: Convert wsproxy to use coderd.ServerTailnet.
AgentProvider: &wsconncache.AgentProvider{
Cache: wsconncache.New(s.DialWorkspaceAgent, 0),
},
DisablePathApps: opts.DisablePathApps,
SecureAuthCookie: opts.SecureAuthCookie,
}
@ -273,6 +276,7 @@ func (s *Server) Close() error {
tmp, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_ = s.SDKClient.WorkspaceProxyGoingAway(tmp)
_ = s.AppServer.AgentProvider.Close()
return s.AppServer.Close()
}

View File

@ -6,7 +6,6 @@ import (
"io"
"net"
"net/http"
"net/netip"
"net/url"
"strconv"
"time"
@ -377,7 +376,10 @@ func agentHTTPClient(conn *codersdk.WorkspaceAgentConn) *http.Client {
if err != nil {
return nil, xerrors.Errorf("parse port %q: %w", port, err)
}
return conn.DialContextTCP(ctx, netip.AddrPortFrom(codersdk.WorkspaceAgentIP, uint16(portUint)))
// Addr doesn't matter here, besides the port. DialContext will
// automatically choose the right IP to dial.
return conn.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", portUint))
},
},
}

View File

@ -9,6 +9,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/coderdtest"
@ -243,7 +244,7 @@ func Test_Runner(t *testing.T) {
func setupRunnerTest(t *testing.T) (client *codersdk.Client, agentID uuid.UUID) {
t.Helper()
client = coderdtest.New(t, &coderdtest.Options{
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
})
user := coderdtest.CreateFirstUser(t, client)
@ -282,12 +283,16 @@ func setupRunnerTest(t *testing.T) (client *codersdk.Client, agentID uuid.UUID)
agentClient.SetSessionToken(authToken)
agentCloser := agent.New(agent.Options{
Client: agentClient,
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("agent"),
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named("agent").Leveled(slog.LevelDebug),
})
t.Cleanup(func() {
_ = agentCloser.Close()
})
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
require.Eventually(t, func() bool {
t.Log("agent id", resources[0].Agents[0].ID)
return (*api.TailnetCoordinator.Load()).Node(resources[0].Agents[0].ID) != nil
}, testutil.WaitLong, testutil.IntervalMedium, "agent never connected")
return client, resources[0].Agents[0].ID
}

View File

@ -68,6 +68,7 @@ func TestRun(t *testing.T) {
agentCloser := agent.New(agent.Options{
Client: agentClient,
})
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
t.Cleanup(func() {

View File

@ -1431,12 +1431,14 @@ export const Entitlements: Entitlement[] = [
export type Experiment =
| "convert-to-oidc"
| "moons"
| "single_tailnet"
| "tailnet_ha_coordinator"
| "workspace_actions"
| "workspace_build_logs_ui"
export const Experiments: Experiment[] = [
"convert-to-oidc",
"moons",
"single_tailnet",
"tailnet_ha_coordinator",
"workspace_actions",
"workspace_build_logs_ui",

View File

@ -139,6 +139,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
}
}()
IP()
dialer := &tsdial.Dialer{
Logf: Logger(options.Logger.Named("tsdial")),
}
@ -182,10 +183,17 @@ func NewConn(options *Options) (conn *Conn, err error) {
netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey()
netStack, err := netstack.Create(
Logger(options.Logger.Named("netstack")), tunDevice, wireguardEngine, magicConn, dialer, dnsManager)
Logger(options.Logger.Named("netstack")),
tunDevice,
wireguardEngine,
magicConn,
dialer,
dnsManager,
)
if err != nil {
return nil, xerrors.Errorf("create netstack: %w", err)
}
dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
return netStack.DialContextTCP(ctx, dst)
}
@ -203,7 +211,14 @@ func NewConn(options *Options) (conn *Conn, err error) {
localIPs, _ := localIPSet.IPSet()
logIPSet := netipx.IPSetBuilder{}
logIPs, _ := logIPSet.IPSet()
wireguardEngine.SetFilter(filter.New(netMap.PacketFilter, localIPs, logIPs, nil, Logger(options.Logger.Named("packet-filter"))))
wireguardEngine.SetFilter(filter.New(
netMap.PacketFilter,
localIPs,
logIPs,
nil,
Logger(options.Logger.Named("packet-filter")),
))
dialContext, dialCancel := context.WithCancel(context.Background())
server := &Conn{
blockEndpoints: options.BlockEndpoints,
@ -230,6 +245,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
_ = server.Close()
}
}()
wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) {
server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.Error(err))
if err != nil {
@ -251,6 +267,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
server.lastMutex.Unlock()
server.sendNode()
})
wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) {
server.logger.Debug(context.Background(), "netinfo callback", slog.F("netinfo", ni))
server.lastMutex.Lock()
@ -262,6 +279,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
server.lastMutex.Unlock()
server.sendNode()
})
magicConn.SetDERPForcedWebsocketCallback(func(region int, reason string) {
server.logger.Debug(context.Background(), "derp forced websocket", slog.F("region", region), slog.F("reason", reason))
server.lastMutex.Lock()
@ -273,6 +291,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
server.lastMutex.Unlock()
server.sendNode()
})
netStack.ForwardTCPIn = server.forwardTCP
netStack.ForwardTCPSockOpts = server.forwardTCPSockOpts
@ -284,22 +303,30 @@ func NewConn(options *Options) (conn *Conn, err error) {
return server, nil
}
// IP generates a new IP with a static service prefix.
func IP() netip.Addr {
// This is Tailscale's ephemeral service prefix.
// This can be changed easily later-on, because
// all of our nodes are ephemeral.
func maskUUID(uid uuid.UUID) uuid.UUID {
// This is Tailscale's ephemeral service prefix. This can be changed easily
// later-on, because all of our nodes are ephemeral.
// fd7a:115c:a1e0
uid := uuid.New()
uid[0] = 0xfd
uid[1] = 0x7a
uid[2] = 0x11
uid[3] = 0x5c
uid[4] = 0xa1
uid[5] = 0xe0
return uid
}
// IP generates a random IP with a static service prefix.
func IP() netip.Addr {
uid := maskUUID(uuid.New())
return netip.AddrFrom16(uid)
}
// IPFromUUID generates an IP from a UUID.
func IPFromUUID(uid uuid.UUID) netip.Addr {
return netip.AddrFrom16(maskUUID(uid))
}
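
The masking makes agent addresses deterministic: a given UUID always maps to the same IPv6 address under the fd7a:115c:a1e0 prefix, which is what lets coderd derive an agent's address without a handshake. A quick sketch:

```go
package main

import (
	"fmt"

	"github.com/google/uuid"

	"github.com/coder/coder/tailnet"
)

func main() {
	id := uuid.New()
	// Both calls return the same address; only the first six bytes of the
	// UUID are overwritten with the fd7a:115c:a1e0 service prefix.
	a := tailnet.IPFromUUID(id)
	b := tailnet.IPFromUUID(id)
	fmt.Println(a == b, a)
}
```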
// Conn is an actively listening Wireguard connection.
type Conn struct {
dialContext context.Context
@ -334,6 +361,29 @@ type Conn struct {
trafficStats *connstats.Statistics
}
func (c *Conn) SetAddresses(ips []netip.Prefix) error {
c.mutex.Lock()
defer c.mutex.Unlock()
c.netMap.Addresses = ips
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
err := c.reconfig()
if err != nil {
return xerrors.Errorf("reconfig: %w", err)
}
return nil
}
func (c *Conn) Addresses() []netip.Prefix {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.netMap.Addresses
}
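
A sketch of how these setters can be used: when a node learns its agent ID, it can re-announce a UUID-derived address while keeping the legacy static address that older, wsconncache-era clients still dial. The helper and the inlined legacy IP are illustrative (the literal matches codersdk.WorkspaceAgentIP referenced elsewhere in this diff).

```go
package tailnet_example

import (
	"net/netip"

	"github.com/google/uuid"

	"github.com/coder/coder/tailnet"
)

// legacyAgentIP mirrors codersdk.WorkspaceAgentIP, inlined here only to keep
// the sketch free of extra imports.
var legacyAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")

// updateAddresses swaps the conn's announced addresses to the UUID-derived IP
// plus the legacy static IP.
func updateAddresses(conn *tailnet.Conn, agentID uuid.UUID) error {
	return conn.SetAddresses([]netip.Prefix{
		netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128),
		netip.PrefixFrom(legacyAgentIP, 128),
	})
}
```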
func (c *Conn) SetNodeCallback(callback func(node *Node)) {
c.lastMutex.Lock()
c.nodeCallback = callback
@ -366,32 +416,6 @@ func (c *Conn) SetDERPRegionDialer(dialer func(ctx context.Context, region *tail
c.magicConn.SetDERPRegionDialer(dialer)
}
func (c *Conn) RemoveAllPeers() error {
c.mutex.Lock()
defer c.mutex.Unlock()
c.netMap.Peers = []*tailcfg.Node{}
c.peerMap = map[tailcfg.NodeID]*tailcfg.Node{}
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "")
if err != nil {
return xerrors.Errorf("update wireguard config: %w", err)
}
err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{})
if err != nil {
if c.isClosed() {
return nil
}
if errors.Is(err, wgengine.ErrNoChanges) {
return nil
}
return xerrors.Errorf("reconfig: %w", err)
}
return nil
}
// UpdateNodes connects with a set of peers. This can be constantly updated,
// and peers will continually be reconnected as necessary. If replacePeers is
// true, all peers will be removed before adding the new ones.
@ -423,6 +447,7 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
}
delete(c.peerMap, peer.ID)
}
for _, node := range nodes {
// If no preferred DERP is provided, we can't reach the node.
if node.PreferredDERP == 0 {
@ -452,17 +477,29 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
}
c.peerMap[node.ID] = peerNode
}
c.netMap.Peers = make([]*tailcfg.Node, 0, len(c.peerMap))
for _, peer := range c.peerMap {
c.netMap.Peers = append(c.netMap.Peers, peer.Clone())
}
netMapCopy := *c.netMap
c.logger.Debug(context.Background(), "updating network map")
c.wireguardEngine.SetNetworkMap(&netMapCopy)
err := c.reconfig()
if err != nil {
return xerrors.Errorf("reconfig: %w", err)
}
return nil
}
func (c *Conn) reconfig() error {
cfg, err := nmcfg.WGCfg(c.netMap, Logger(c.logger.Named("wgconfig")), netmap.AllowSingleHosts, "")
if err != nil {
return xerrors.Errorf("update wireguard config: %w", err)
}
err = c.wireguardEngine.Reconfig(cfg, c.wireguardRouter, &dns.Config{}, &tailcfg.Debug{})
if err != nil {
if c.isClosed() {
@ -473,6 +510,7 @@ func (c *Conn) UpdateNodes(nodes []*Node, replacePeers bool) error {
}
return xerrors.Errorf("reconfig: %w", err)
}
return nil
}

View File

@ -23,7 +23,7 @@ func TestMain(m *testing.M) {
func TestTailnet(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
derpMap := tailnettest.RunDERPAndSTUN(t)
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
t.Run("InstantClose", func(t *testing.T) {
t.Parallel()
conn, err := tailnet.NewConn(&tailnet.Options{
@ -172,7 +172,7 @@ func TestConn_PreferredDERP(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
derpMap := tailnettest.RunDERPAndSTUN(t)
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
Logger: logger.Named("w1"),

View File

@ -1,7 +1,6 @@
package tailnet
import (
"bytes"
"context"
"encoding/json"
"errors"
@ -11,17 +10,16 @@ import (
"net/http"
"net/netip"
"sync"
"sync/atomic"
"time"
"cdr.dev/slog"
"github.com/google/uuid"
lru "github.com/hashicorp/golang-lru/v2"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"cdr.dev/slog"
)
// Coordinator exchanges nodes with agents to establish connections.
@ -44,6 +42,8 @@ type Coordinator interface {
ServeAgent(conn net.Conn, id uuid.UUID, name string) error
// Close closes the coordinator.
Close() error
ServeMultiAgent(id uuid.UUID) MultiAgentConn
}
// Node represents a node in the network.
@ -54,10 +54,11 @@ type Node struct {
AsOf time.Time `json:"as_of"`
// Key is the Wireguard public key of the node.
Key key.NodePublic `json:"key"`
// DiscoKey is used for discovery messages over DERP to establish peer-to-peer connections.
// DiscoKey is used for discovery messages over DERP to establish
// peer-to-peer connections.
DiscoKey key.DiscoPublic `json:"disco"`
// PreferredDERP is the DERP server that peered connections
// should meet at to establish.
// PreferredDERP is the DERP server that peered connections should meet at
// to establish.
PreferredDERP int `json:"preferred_derp"`
// DERPLatency is the latency in seconds to each DERP server.
DERPLatency map[string]float64 `json:"derp_latency"`
@ -68,8 +69,8 @@ type Node struct {
DERPForcedWebsocket map[int]string `json:"derp_forced_websockets"`
// Addresses are the IP address ranges this connection exposes.
Addresses []netip.Prefix `json:"addresses"`
// AllowedIPs specify what addresses can dial the connection.
// We allow all by default.
// AllowedIPs specify what addresses can dial the connection. We allow all
// by default.
AllowedIPs []netip.Prefix `json:"allowed_ips"`
// Endpoints are ip:port combinations that can be used to establish
// peer-to-peer connections.
@ -130,12 +131,33 @@ func NewCoordinator(logger slog.Logger) Coordinator {
// ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐
// │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│
// └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘
// This coordinator is incompatible with multiple Coder
// replicas as all node data is in-memory.
// This coordinator is incompatible with multiple Coder replicas as all node
// data is in-memory.
type coordinator struct {
core *core
}
func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn {
m := (&MultiAgent{
ID: id,
Logger: c.core.logger,
AgentIsLegacyFunc: c.core.agentIsLegacy,
OnSubscribe: c.core.clientSubscribeToAgent,
OnUnsubscribe: c.core.clientUnsubscribeFromAgent,
OnNodeUpdate: c.core.clientNodeUpdate,
OnRemove: c.core.clientDisconnected,
}).Init()
c.core.addClient(id, m)
return m
}
func (c *core) addClient(id uuid.UUID, ma Queue) {
c.mutex.Lock()
c.clients[id] = ma
c.clientsToAgents[id] = map[uuid.UUID]Queue{}
c.mutex.Unlock()
}
// core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines;
// it is protected by a mutex to ensure data stays consistent.
type core struct {
@ -146,14 +168,38 @@ type core struct {
// nodes maps agent and connection IDs to their respective node.
nodes map[uuid.UUID]*Node
// agentSockets maps agent IDs to their open websocket.
agentSockets map[uuid.UUID]*TrackedConn
agentSockets map[uuid.UUID]Queue
// agentToConnectionSockets maps agent IDs to connection IDs of conns that
// are subscribed to updates for that agent.
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]*TrackedConn
agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]Queue
// clients holds a map of all clients connected to the coordinator. This is
// necessary because a client may not be subscribed into any agents.
clients map[uuid.UUID]Queue
// clientsToAgents is an index of clients to all of their subscribed agents.
clientsToAgents map[uuid.UUID]map[uuid.UUID]Queue
// agentNameCache holds a cache of agent names. If one of them disappears,
// it's helpful to have a name cached for debugging.
agentNameCache *lru.Cache[uuid.UUID, string]
// legacyAgents holds a mapping of all agents detected as legacy, meaning
// they only listen on codersdk.WorkspaceAgentIP. They aren't compatible
// with the new ServerTailnet, so they must be connected through
// wsconncache.
legacyAgents map[uuid.UUID]struct{}
}
type Queue interface {
UniqueID() uuid.UUID
Enqueue(n []*Node) error
Name() string
Stats() (start, lastWrite int64)
Overwrites() int64
// CoordinatorClose is used by the coordinator when closing a Queue. It
// should skip removing itself from the coordinator.
CoordinatorClose() error
Close() error
}
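
Because TrackedConn and MultiAgent both satisfy Queue, tests can also drop in a trivial in-memory implementation. The sketch below is written as if it lived inside the tailnet package (so uuid and xerrors are already imported); every detail is illustrative.

```go
// fakeQueue is a minimal Queue for tests; the buffer size and error text are
// arbitrary.
type fakeQueue struct {
	id      uuid.UUID
	updates chan []*Node
}

func newFakeQueue(id uuid.UUID) *fakeQueue {
	return &fakeQueue{id: id, updates: make(chan []*Node, 8)}
}

func (q *fakeQueue) UniqueID() uuid.UUID { return q.id }

func (q *fakeQueue) Enqueue(n []*Node) error {
	select {
	case q.updates <- n:
		return nil
	default:
		return xerrors.New("queue full")
	}
}

func (q *fakeQueue) Name() string                    { return "fake" }
func (q *fakeQueue) Stats() (start, lastWrite int64) { return 0, 0 }
func (q *fakeQueue) Overwrites() int64               { return 0 }
func (q *fakeQueue) CoordinatorClose() error         { return nil }
func (q *fakeQueue) Close() error                    { return nil }
```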
func newCore(logger slog.Logger) *core {
@ -165,128 +211,18 @@ func newCore(logger slog.Logger) *core {
return &core{
logger: logger,
closed: false,
nodes: make(map[uuid.UUID]*Node),
agentSockets: map[uuid.UUID]*TrackedConn{},
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]*TrackedConn{},
nodes: map[uuid.UUID]*Node{},
agentSockets: map[uuid.UUID]Queue{},
agentToConnectionSockets: map[uuid.UUID]map[uuid.UUID]Queue{},
agentNameCache: nameCache,
legacyAgents: map[uuid.UUID]struct{}{},
clients: map[uuid.UUID]Queue{},
clientsToAgents: map[uuid.UUID]map[uuid.UUID]Queue{},
}
}
var ErrWouldBlock = xerrors.New("would block")
type TrackedConn struct {
ctx context.Context
cancel func()
conn net.Conn
updates chan []*Node
logger slog.Logger
lastData []byte
// ID is an ephemeral UUID used to uniquely identify the owner of the
// connection.
ID uuid.UUID
Name string
Start int64
LastWrite int64
Overwrites int64
}
func (t *TrackedConn) Enqueue(n []*Node) (err error) {
atomic.StoreInt64(&t.LastWrite, time.Now().Unix())
select {
case t.updates <- n:
return nil
default:
return ErrWouldBlock
}
}
// Close the connection and cancel the context for reading node updates from the queue
func (t *TrackedConn) Close() error {
t.cancel()
return t.conn.Close()
}
// WriteTimeout is the amount of time we wait to write a node update to a connection before we declare it hung.
// It is exported so that tests can use it.
const WriteTimeout = time.Second * 5
// SendUpdates reads node updates and writes them to the connection. Ends when writes hit an error or context is
// canceled.
func (t *TrackedConn) SendUpdates() {
for {
select {
case <-t.ctx.Done():
t.logger.Debug(t.ctx, "done sending updates")
return
case nodes := <-t.updates:
data, err := json.Marshal(nodes)
if err != nil {
t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes))
return
}
if bytes.Equal(t.lastData, data) {
t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", string(data)))
continue
}
// Set a deadline so that hung connections don't put back pressure on the system.
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
err = t.conn.SetWriteDeadline(time.Now().Add(WriteTimeout))
if err != nil {
// often, this is just because the connection is closed/broken, so only log at debug.
t.logger.Debug(t.ctx, "unable to set write deadline", slog.Error(err))
_ = t.Close()
return
}
_, err = t.conn.Write(data)
if err != nil {
// often, this is just because the connection is closed/broken, so only log at debug.
t.logger.Debug(t.ctx, "could not write nodes to connection",
slog.Error(err), slog.F("nodes", string(data)))
_ = t.Close()
return
}
t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", string(data)))
// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are
// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write()
// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context.
// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after
// our successful write, it is important that we reset the deadline before it fires.
err = t.conn.SetWriteDeadline(time.Time{})
if err != nil {
// often, this is just because the connection is closed/broken, so only log at debug.
t.logger.Debug(t.ctx, "unable to extend write deadline", slog.Error(err))
_ = t.Close()
return
}
t.lastData = data
}
}
}
func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.UUID, logger slog.Logger, overwrites int64) *TrackedConn {
// buffer updates so they don't block, since we hold the
// coordinator mutex while queuing. Node updates don't
// come quickly, so 512 should be plenty for all but
// the most pathological cases.
updates := make(chan []*Node, 512)
now := time.Now().Unix()
return &TrackedConn{
ctx: ctx,
conn: conn,
cancel: cancel,
updates: updates,
logger: logger,
ID: id,
Start: now,
LastWrite: now,
Overwrites: overwrites,
}
}
// Node returns an in-memory node by ID.
// If the node does not exist, nil is returned.
func (c *coordinator) Node(id uuid.UUID) *Node {
@ -321,16 +257,29 @@ func (c *core) agentCount() int {
// ServeClient accepts a WebSocket connection that wants to connect to an agent
// with the specified ID.
func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
logger := c.core.clientLogger(id, agent)
logger := c.core.clientLogger(id, agentID)
logger.Debug(ctx, "coordinating client")
tc, err := c.core.initAndTrackClient(ctx, cancel, conn, id, agent)
tc := NewTrackedConn(ctx, cancel, conn, id, logger, 0)
defer tc.Close()
c.core.addClient(id, tc)
defer c.core.clientDisconnected(id)
agentNode, err := c.core.clientSubscribeToAgent(tc, agentID)
if err != nil {
return err
return xerrors.Errorf("subscribe agent: %w", err)
}
if agentNode != nil {
err := tc.Enqueue([]*Node{agentNode})
if err != nil {
logger.Debug(ctx, "enqueue initial node", slog.Error(err))
}
}
defer c.core.clientDisconnected(id, agent)
// On this goroutine, we read updates from the client and publish them. We start a second goroutine
// to write updates back to the client.
@ -338,7 +287,7 @@ func (c *coordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID)
decoder := json.NewDecoder(conn)
for {
err := c.handleNextClientMessage(id, agent, decoder)
err := c.handleNextClientMessage(id, decoder)
if err != nil {
logger.Debug(ctx, "unable to read client update, connection may be closed", slog.Error(err))
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) {
@ -353,99 +302,133 @@ func (c *core) clientLogger(id, agent uuid.UUID) slog.Logger {
return c.logger.With(slog.F("client_id", id), slog.F("agent_id", agent))
}
// initAndTrackClient creates a TrackedConn for the client, and sends any initial Node updates if we have any. It is
// one function that does two things because it is critical that we hold the mutex for both things, lest we miss some
// updates.
func (c *core) initAndTrackClient(
ctx context.Context, cancel func(), conn net.Conn, id, agent uuid.UUID,
) (
*TrackedConn, error,
) {
logger := c.clientLogger(id, agent)
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
return nil, xerrors.New("coordinator is closed")
}
tc := NewTrackedConn(ctx, cancel, conn, id, logger, 0)
// When a new connection is requested, we update it with the latest
// node of the agent. This allows the connection to establish.
node, ok := c.nodes[agent]
if ok {
err := tc.Enqueue([]*Node{node})
// this should never error since we're still the only goroutine that
// knows about the TrackedConn. If we hit an error something really
// wrong is happening
if err != nil {
logger.Critical(ctx, "unable to queue initial node", slog.Error(err))
return nil, err
}
}
// Insert this connection into a map so the agent
// can publish node updates.
connectionSockets, ok := c.agentToConnectionSockets[agent]
func (c *core) initOrSetAgentConnectionSocketLocked(agentID uuid.UUID, enq Queue) {
connectionSockets, ok := c.agentToConnectionSockets[agentID]
if !ok {
connectionSockets = map[uuid.UUID]*TrackedConn{}
c.agentToConnectionSockets[agent] = connectionSockets
connectionSockets = map[uuid.UUID]Queue{}
c.agentToConnectionSockets[agentID] = connectionSockets
}
connectionSockets[id] = tc
logger.Debug(ctx, "added tracked connection")
return tc, nil
connectionSockets[enq.UniqueID()] = enq
c.clientsToAgents[enq.UniqueID()][agentID] = c.agentSockets[agentID]
}
func (c *core) clientDisconnected(id, agent uuid.UUID) {
logger := c.clientLogger(id, agent)
func (c *core) clientDisconnected(id uuid.UUID) {
logger := c.clientLogger(id, uuid.Nil)
c.mutex.Lock()
defer c.mutex.Unlock()
// Clean all traces of this connection from the map.
delete(c.nodes, id)
logger.Debug(context.Background(), "deleted client node")
connectionSockets, ok := c.agentToConnectionSockets[agent]
if !ok {
return
for agentID := range c.clientsToAgents[id] {
connectionSockets, ok := c.agentToConnectionSockets[agentID]
if !ok {
continue
}
delete(connectionSockets, id)
logger.Debug(context.Background(), "deleted client connectionSocket from map", slog.F("agent_id", agentID))
if len(connectionSockets) == 0 {
delete(c.agentToConnectionSockets, agentID)
logger.Debug(context.Background(), "deleted last client connectionSocket from map", slog.F("agent_id", agentID))
}
}
delete(connectionSockets, id)
logger.Debug(context.Background(), "deleted client connectionSocket from map")
if len(connectionSockets) != 0 {
return
}
delete(c.agentToConnectionSockets, agent)
logger.Debug(context.Background(), "deleted last client connectionSocket from map")
delete(c.clients, id)
delete(c.clientsToAgents, id)
logger.Debug(context.Background(), "deleted client agents")
}
func (c *coordinator) handleNextClientMessage(id, agent uuid.UUID, decoder *json.Decoder) error {
logger := c.core.clientLogger(id, agent)
func (c *coordinator) handleNextClientMessage(id uuid.UUID, decoder *json.Decoder) error {
logger := c.core.clientLogger(id, uuid.Nil)
var node Node
err := decoder.Decode(&node)
if err != nil {
return xerrors.Errorf("read json: %w", err)
}
logger.Debug(context.Background(), "got client node update", slog.F("node", node))
return c.core.clientNodeUpdate(id, agent, &node)
return c.core.clientNodeUpdate(id, &node)
}
func (c *core) clientNodeUpdate(id uuid.UUID, node *Node) error {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	// Update the node of this client in our in-memory map. If an agent entirely
	// shuts down and reconnects, it needs to be aware of all clients attempting
	// to establish connections.
	c.nodes[id] = node

	return c.clientNodeUpdateLocked(id, node)
}

func (c *core) clientNodeUpdateLocked(id uuid.UUID, node *Node) error {
	logger := c.clientLogger(id, uuid.Nil)

	agents := []uuid.UUID{}
	for agentID, agentSocket := range c.clientsToAgents[id] {
		if agentSocket == nil {
			logger.Debug(context.Background(), "enqueue node to agent; socket is nil", slog.F("agent_id", agentID))
			continue
		}

		err := agentSocket.Enqueue([]*Node{node})
		if err != nil {
			logger.Debug(context.Background(), "unable to Enqueue node to agent", slog.Error(err), slog.F("agent_id", agentID))
			continue
		}
		agents = append(agents, agentID)
	}

	logger.Debug(context.Background(), "enqueued node to agents", slog.F("agent_ids", agents))
	return nil
}
func (c *core) clientSubscribeToAgent(enq Queue, agentID uuid.UUID) (*Node, error) {
c.mutex.Lock()
defer c.mutex.Unlock()
logger := c.clientLogger(enq.UniqueID(), agentID)
c.initOrSetAgentConnectionSocketLocked(agentID, enq)
node, ok := c.nodes[enq.UniqueID()]
if ok {
// If we have the client node, send it to the agent. If not, it will be
// sent async.
agentSocket, ok := c.agentSockets[agentID]
if !ok {
logger.Debug(context.Background(), "subscribe to agent; socket is nil")
} else {
err := agentSocket.Enqueue([]*Node{node})
if err != nil {
return nil, xerrors.Errorf("enqueue client to agent: %w", err)
}
}
} else {
logger.Debug(context.Background(), "multiagent node doesn't exist")
}
logger.Debug(context.Background(), "enqueued node to agent")
agentNode, ok := c.nodes[agentID]
if !ok {
// This is ok, once the agent connects the node will be sent over.
logger.Debug(context.Background(), "agent node doesn't exist", slog.F("agent_id", agentID))
}
// Send the subscribed agent back to the multi agent.
return agentNode, nil
}
func (c *core) clientUnsubscribeFromAgent(enq Queue, agentID uuid.UUID) error {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.clientsToAgents[enq.UniqueID()], agentID)
delete(c.agentToConnectionSockets[agentID], enq.UniqueID())
return nil
}
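// The core struct itself is not shown in these hunks. Based purely on how its
// fields are used above, the bookkeeping looks roughly like the following
// sketch (field names match the code above; exact types are inferred, not
// copied from the source):
//
//	type core struct {
//		mutex sync.RWMutex
//		// nodes holds the latest node from every connected client and agent.
//		nodes map[uuid.UUID]*Node
//		// agentSockets maps agent IDs to their connection queues.
//		agentSockets map[uuid.UUID]Queue
//		// agentToConnectionSockets fans agent node updates out to the clients
//		// subscribed to that agent: agentID -> clientID -> client queue.
//		agentToConnectionSockets map[uuid.UUID]map[uuid.UUID]Queue
//		// clientsToAgents fans client node updates out to subscribed agents:
//		// clientID -> agentID -> agent queue (nil until the agent connects).
//		clientsToAgents map[uuid.UUID]map[uuid.UUID]Queue
//		// clients holds every connected client queue, including multi-agents.
//		clients map[uuid.UUID]Queue
//		// legacyAgents caches agents still advertising the legacy IP.
//		legacyAgents map[uuid.UUID]struct{}
//	}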
@ -493,11 +476,14 @@ func (c *core) agentDisconnected(id, unique uuid.UUID) {
// Only delete the connection if it's ours. It could have been
// overwritten.
	if idConn, ok := c.agentSockets[id]; ok && idConn.UniqueID() == unique {
delete(c.agentSockets, id)
delete(c.nodes, id)
logger.Debug(context.Background(), "deleted agent socket and node")
}
for clientID := range c.agentToConnectionSockets[id] {
c.clientsToAgents[clientID][id] = nil
}
}
// initAndTrackAgent creates a TrackedConn for the agent, and sends any initial nodes updates if we have any. It is
@ -519,7 +505,7 @@ func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Co
// dead.
oldAgentSocket, ok := c.agentSockets[id]
if ok {
		overwrites = oldAgentSocket.Overwrites() + 1
_ = oldAgentSocket.Close()
}
tc := NewTrackedConn(ctx, cancel, conn, unique, logger, overwrites)
@ -549,6 +535,10 @@ func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Co
}
c.agentSockets[id] = tc
for clientID := range c.agentToConnectionSockets[id] {
c.clientsToAgents[clientID][id] = tc
}
logger.Debug(ctx, "added agent socket")
return tc, nil
}
@ -564,11 +554,31 @@ func (c *coordinator) handleNextAgentMessage(id uuid.UUID, decoder *json.Decoder
return c.core.agentNodeUpdate(id, &node)
}
// This is copied from codersdk because importing it here would cause an import
// cycle. This is just temporary until wsconncache is phased out.
var legacyAgentIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")
// This is temporary until we no longer need to detect legacy agents for
// backwards compatibility.
// See: https://github.com/coder/coder/issues/8218
func (c *core) agentIsLegacy(agentID uuid.UUID) bool {
c.mutex.RLock()
_, ok := c.legacyAgents[agentID]
c.mutex.RUnlock()
return ok
}
func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error {
logger := c.agentLogger(id)
c.mutex.Lock()
defer c.mutex.Unlock()
c.nodes[id] = node
// Keep a cache of all legacy agents.
if len(node.Addresses) > 0 && node.Addresses[0].Addr() == legacyAgentIP {
c.legacyAgents[id] = struct{}{}
}
connectionSockets, ok := c.agentToConnectionSockets[id]
if !ok {
logger.Debug(context.Background(), "no client sockets; unable to send node")
@ -588,6 +598,7 @@ func (c *core) agentNodeUpdate(id uuid.UUID, node *Node) error {
slog.F("client_id", clientID), slog.Error(err))
}
}
return nil
}
@ -611,20 +622,18 @@ func (c *core) close() error {
for _, socket := range c.agentSockets {
socket := socket
go func() {
			_ = socket.CoordinatorClose()
wg.Done()
}()
}

	wg.Add(len(c.clients))
	for _, client := range c.clients {
		client := client
		go func() {
			_ = client.CoordinatorClose()
			wg.Done()
		}()
	}
c.mutex.Unlock()
@ -649,8 +658,8 @@ func (c *core) serveHTTPDebug(w http.ResponseWriter, r *http.Request) {
}
func CoordinatorHTTPDebug(
	agentSocketsMap map[uuid.UUID]Queue,
	agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]Queue,
agentNameCache *lru.Cache[uuid.UUID, string],
) func(w http.ResponseWriter, _ *http.Request) {
return func(w http.ResponseWriter, _ *http.Request) {
@ -658,7 +667,7 @@ func CoordinatorHTTPDebug(
type idConn struct {
id uuid.UUID
			conn Queue
}
{
@ -671,16 +680,17 @@ func CoordinatorHTTPDebug(
}
slices.SortFunc(agentSockets, func(a, b idConn) bool {
				return a.conn.Name() < b.conn.Name()
})
for _, agent := range agentSockets {
start, lastWrite := agent.conn.Stats()
_, _ = fmt.Fprintf(w, "<li style=\"margin-top:4px\"><b>%s</b> (<code>%s</code>): created %v ago, write %v ago, overwrites %d </li>\n",
				agent.conn.Name(),
agent.id.String(),
				now.Sub(time.Unix(start, 0)).Round(time.Second),
				now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
				agent.conn.Overwrites(),
)
if conns := agentToConnectionSocketsMap[agent.id]; len(conns) > 0 {
@ -696,11 +706,12 @@ func CoordinatorHTTPDebug(
_, _ = fmt.Fprintln(w, "<ul>")
for _, connSocket := range connSockets {
start, lastWrite := connSocket.conn.Stats()
_, _ = fmt.Fprintf(w, "<li><b>%s</b> (<code>%s</code>): created %v ago, write %v ago </li>\n",
					connSocket.conn.Name(),
connSocket.id.String(),
					now.Sub(time.Unix(start, 0)).Round(time.Second),
					now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
)
}
_, _ = fmt.Fprintln(w, "</ul>")
@ -755,11 +766,12 @@ func CoordinatorHTTPDebug(
_, _ = fmt.Fprintf(w, "<h3 style=\"margin:0px;font-size:16px;font-weight:400\">connections: total %d</h3>\n", len(agentConns.conns))
_, _ = fmt.Fprintln(w, "<ul>")
for _, agentConn := range agentConns.conns {
start, lastWrite := agentConn.conn.Stats()
_, _ = fmt.Fprintf(w, "<li><b>%s</b> (<code>%s</code>): created %v ago, write %v ago </li>\n",
				agentConn.conn.Name(),
agentConn.id.String(),
				now.Sub(time.Unix(start, 0)).Round(time.Second),
				now.Sub(time.Unix(lastWrite, 0)).Round(time.Second),
)
}
_, _ = fmt.Fprintln(w, "</ul>")

167
tailnet/multiagent.go Normal file
View File

@ -0,0 +1,167 @@
package tailnet
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
)
type MultiAgentConn interface {
UpdateSelf(node *Node) error
SubscribeAgent(agentID uuid.UUID) error
UnsubscribeAgent(agentID uuid.UUID) error
NextUpdate(ctx context.Context) ([]*Node, bool)
AgentIsLegacy(agentID uuid.UUID) bool
Close() error
IsClosed() bool
}
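// A hedged consumer-side sketch of this interface (the caller shown here is
// hypothetical; only the MultiAgentConn methods are taken from the code
// above): subscribe to an agent, then pump NextUpdate until the connection or
// the context ends.
//
//	func pumpAgentNodes(ctx context.Context, conn MultiAgentConn, agentID uuid.UUID, apply func([]*Node)) error {
//		if err := conn.SubscribeAgent(agentID); err != nil {
//			return err
//		}
//		defer func() { _ = conn.UnsubscribeAgent(agentID) }()
//		for {
//			nodes, ok := conn.NextUpdate(ctx)
//			if !ok {
//				// Either the context was canceled or the connection closed.
//				return ctx.Err()
//			}
//			apply(nodes)
//		}
//	}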
type MultiAgent struct {
mu sync.RWMutex
closed bool
ID uuid.UUID
Logger slog.Logger
AgentIsLegacyFunc func(agentID uuid.UUID) bool
OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error)
OnUnsubscribe func(enq Queue, agent uuid.UUID) error
OnNodeUpdate func(id uuid.UUID, node *Node) error
OnRemove func(id uuid.UUID)
updates chan []*Node
closeOnce sync.Once
start int64
lastWrite int64
	// Client nodes normally generate a unique ID for each connection, so
	// overwrites are rarely an issue, but the counter is kept for compatibility.
overwrites int64
}
func (m *MultiAgent) Init() *MultiAgent {
m.updates = make(chan []*Node, 128)
m.start = time.Now().Unix()
return m
}
func (m *MultiAgent) UniqueID() uuid.UUID {
return m.ID
}
func (m *MultiAgent) AgentIsLegacy(agentID uuid.UUID) bool {
return m.AgentIsLegacyFunc(agentID)
}
var ErrMultiAgentClosed = xerrors.New("multiagent is closed")
func (m *MultiAgent) UpdateSelf(node *Node) error {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
return ErrMultiAgentClosed
}
return m.OnNodeUpdate(m.ID, node)
}
func (m *MultiAgent) SubscribeAgent(agentID uuid.UUID) error {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
return ErrMultiAgentClosed
}
node, err := m.OnSubscribe(m, agentID)
if err != nil {
return err
}
if node != nil {
return m.enqueueLocked([]*Node{node})
}
return nil
}
func (m *MultiAgent) UnsubscribeAgent(agentID uuid.UUID) error {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
return ErrMultiAgentClosed
}
return m.OnUnsubscribe(m, agentID)
}
func (m *MultiAgent) NextUpdate(ctx context.Context) ([]*Node, bool) {
select {
case <-ctx.Done():
return nil, false
case nodes, ok := <-m.updates:
return nodes, ok
}
}
func (m *MultiAgent) Enqueue(nodes []*Node) error {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
return nil
}
return m.enqueueLocked(nodes)
}
func (m *MultiAgent) enqueueLocked(nodes []*Node) error {
atomic.StoreInt64(&m.lastWrite, time.Now().Unix())
select {
case m.updates <- nodes:
return nil
default:
return ErrWouldBlock
}
}
func (m *MultiAgent) Name() string {
return m.ID.String()
}
func (m *MultiAgent) Stats() (start int64, lastWrite int64) {
return m.start, atomic.LoadInt64(&m.lastWrite)
}
func (m *MultiAgent) Overwrites() int64 {
return m.overwrites
}
func (m *MultiAgent) IsClosed() bool {
m.mu.RLock()
defer m.mu.RUnlock()
return m.closed
}
func (m *MultiAgent) CoordinatorClose() error {
m.mu.Lock()
if !m.closed {
m.closed = true
close(m.updates)
}
m.mu.Unlock()
return nil
}
func (m *MultiAgent) Close() error {
_ = m.CoordinatorClose()
m.closeOnce.Do(func() { m.OnRemove(m.ID) })
return nil
}
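// A hedged sketch of how a coordinator could hand one of these out. The
// method name ServeMultiAgent and the OnRemove cleanup are assumptions, but
// the callback signatures line up with clientSubscribeToAgent,
// clientUnsubscribeFromAgent, clientNodeUpdate, and agentIsLegacy in
// coordinator.go above:
//
//	func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn {
//		return (&MultiAgent{
//			ID:                id,
//			AgentIsLegacyFunc: c.core.agentIsLegacy,
//			OnSubscribe:       c.core.clientSubscribeToAgent,
//			OnUnsubscribe:     c.core.clientUnsubscribeFromAgent,
//			OnNodeUpdate:      c.core.clientNodeUpdate,
//			OnRemove:          func(id uuid.UUID) { c.core.clientDisconnected(id) },
//		}).Init()
//	}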

View File

@ -22,7 +22,7 @@ import (
)
// RunDERPAndSTUN creates a DERP mapping for tests.
func RunDERPAndSTUN(t *testing.T) (*tailcfg.DERPMap, *derp.Server) {
logf := tailnet.Logger(slogtest.Make(t, nil))
d := derp.NewServer(key.NewNode(), logf)
server := httptest.NewUnstartedServer(derphttp.Handler(d))
@ -61,7 +61,7 @@ func RunDERPAndSTUN(t *testing.T) *tailcfg.DERPMap {
},
},
},
	}, d
}
// RunDERPOnlyWebSockets creates a DERP mapping for tests that

View File

@ -14,7 +14,7 @@ func TestMain(m *testing.M) {
func TestRunDERPAndSTUN(t *testing.T) {
t.Parallel()
	_, _ = tailnettest.RunDERPAndSTUN(t)
}
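// A hedged example of what the extra return value enables (the test above
// simply discards it): tailscale's derp.Server exposes Close, so a test could
// shut the relay down mid-test to exercise failure or reconnect paths.
//
//	derpMap, derpServer := tailnettest.RunDERPAndSTUN(t)
//	_ = derpMap
//	defer derpServer.Close()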
func TestRunDERPOnlyWebSockets(t *testing.T) {

147
tailnet/trackedconn.go Normal file
View File

@ -0,0 +1,147 @@
package tailnet
import (
"bytes"
"context"
"encoding/json"
"net"
"sync/atomic"
"time"
"github.com/google/uuid"
"cdr.dev/slog"
)
// WriteTimeout is the amount of time we wait to write a node update to a connection before we declare it hung.
// It is exported so that tests can use it.
const WriteTimeout = time.Second * 5
type TrackedConn struct {
ctx context.Context
cancel func()
conn net.Conn
updates chan []*Node
logger slog.Logger
lastData []byte
// ID is an ephemeral UUID used to uniquely identify the owner of the
// connection.
id uuid.UUID
name string
start int64
lastWrite int64
overwrites int64
}
func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.UUID, logger slog.Logger, overwrites int64) *TrackedConn {
// buffer updates so they don't block, since we hold the
// coordinator mutex while queuing. Node updates don't
// come quickly, so 512 should be plenty for all but
// the most pathological cases.
updates := make(chan []*Node, 512)
now := time.Now().Unix()
return &TrackedConn{
ctx: ctx,
conn: conn,
cancel: cancel,
updates: updates,
logger: logger,
id: id,
start: now,
lastWrite: now,
overwrites: overwrites,
}
}
func (t *TrackedConn) Enqueue(n []*Node) (err error) {
atomic.StoreInt64(&t.lastWrite, time.Now().Unix())
select {
case t.updates <- n:
return nil
default:
return ErrWouldBlock
}
}
func (t *TrackedConn) UniqueID() uuid.UUID {
return t.id
}
func (t *TrackedConn) Name() string {
return t.name
}
func (t *TrackedConn) Stats() (start, lastWrite int64) {
return t.start, atomic.LoadInt64(&t.lastWrite)
}
func (t *TrackedConn) Overwrites() int64 {
return t.overwrites
}
func (t *TrackedConn) CoordinatorClose() error {
return t.Close()
}
// Close the connection and cancel the context for reading node updates from the queue
func (t *TrackedConn) Close() error {
t.cancel()
return t.conn.Close()
}
// SendUpdates reads node updates and writes them to the connection. Ends when writes hit an error or context is
// canceled.
func (t *TrackedConn) SendUpdates() {
for {
select {
case <-t.ctx.Done():
t.logger.Debug(t.ctx, "done sending updates")
return
case nodes := <-t.updates:
data, err := json.Marshal(nodes)
if err != nil {
t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes))
return
}
if bytes.Equal(t.lastData, data) {
t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", string(data)))
continue
}
// Set a deadline so that hung connections don't put back pressure on the system.
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
err = t.conn.SetWriteDeadline(time.Now().Add(WriteTimeout))
if err != nil {
// often, this is just because the connection is closed/broken, so only log at debug.
t.logger.Debug(t.ctx, "unable to set write deadline", slog.Error(err))
_ = t.Close()
return
}
_, err = t.conn.Write(data)
if err != nil {
// often, this is just because the connection is closed/broken, so only log at debug.
t.logger.Debug(t.ctx, "could not write nodes to connection",
slog.Error(err), slog.F("nodes", string(data)))
_ = t.Close()
return
}
t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", string(data)))
// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are
// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write()
// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context.
// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after
// our successful write, it is important that we reset the deadline before it fires.
err = t.conn.SetWriteDeadline(time.Time{})
if err != nil {
// often, this is just because the connection is closed/broken, so only log at debug.
t.logger.Debug(t.ctx, "unable to extend write deadline", slog.Error(err))
_ = t.Close()
return
}
t.lastData = data
}
}
}
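// A hedged sketch of the intended lifecycle, pieced together from the methods
// above (the coordinator's actual wiring lives outside this file):
//
//	tc := NewTrackedConn(ctx, cancel, conn, uuid.New(), logger, 0)
//	go tc.SendUpdates()           // drains the buffered channel onto the conn
//	_ = tc.Enqueue([]*Node{node}) // non-blocking; returns ErrWouldBlock when the buffer is full
//	defer tc.Close()              // cancels ctx and closes the underlying conn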