mirror of https://github.com/coder/coder.git
fix(tailnet): Improve tailnet setup and agentconn stability (#6292)
* fix(tailnet): Improve start and close to detect connection races
* fix: Prevent agentConn use before ready via AwaitReachable
* fix(tailnet): Ensure connstats are closed on conn close
* fix(codersdk): Use AwaitReachable in DialWorkspaceAgent
* fix(tailnet): Improve logging via slog.Helper()
This commit is contained in:
parent
473ab208af
commit
a414de9e81
|
@ -6,6 +6,7 @@ import (
|
|||
"net/http"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -78,8 +79,14 @@ func TestDERP(t *testing.T) {
|
|||
DERPMap: derpMap,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w2Ready := make(chan struct{}, 1)
|
||||
w2ReadyOnce := sync.Once{}
|
||||
w1.SetNodeCallback(func(node *tailnet.Node) {
|
||||
w2.UpdateNodes([]*tailnet.Node{node})
|
||||
w2ReadyOnce.Do(func() {
|
||||
close(w2Ready)
|
||||
})
|
||||
})
|
||||
w2.SetNodeCallback(func(node *tailnet.Node) {
|
||||
w1.UpdateNodes([]*tailnet.Node{node})
|
||||
|
@ -98,6 +105,7 @@ func TestDERP(t *testing.T) {
|
|||
}()
|
||||
|
||||
<-conn
|
||||
<-w2Ready
|
||||
nc, err := w2.DialContextTCP(context.Background(), netip.AddrPortFrom(w1IP, 35565))
|
||||
require.NoError(t, err)
|
||||
_ = nc.Close()
|
||||
|
|
|
@ -469,6 +469,8 @@ func TestWorkspaceAgentListeningPorts(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
setup := func(t *testing.T, apps []*proto.App) (*codersdk.Client, uint16, uuid.UUID) {
|
||||
t.Helper()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
})
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/coder/coder/codersdk/agentsdk"
|
||||
"github.com/coder/coder/tailnet"
|
||||
"github.com/coder/coder/tailnet/tailnettest"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
@ -131,6 +132,14 @@ func TestCache(t *testing.T) {
|
|||
return
|
||||
}
|
||||
defer release()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
if !conn.AwaitReachable(ctx) {
|
||||
t.Error("agent not reachable")
|
||||
return
|
||||
}
|
||||
|
||||
transport := conn.HTTPTransport()
|
||||
defer transport.CloseIdleConnections()
|
||||
proxy.Transport = transport
|
||||
|
@ -146,6 +155,8 @@ func TestCache(t *testing.T) {
|
|||
}
|
||||
|
||||
func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn {
|
||||
t.Helper()
|
||||
|
||||
metadata.DERPMap = tailnettest.RunDERPAndSTUN(t)
|
||||
|
||||
coordinator := tailnet.NewCoordinator()
|
||||
|
|
|
@ -176,7 +176,9 @@ type ReconnectingPTYRequest struct {
|
|||
func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
if !c.AwaitReachable(ctx) {
|
||||
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
|
||||
}
|
||||
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentReconnectingPTYPort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -207,6 +209,9 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID,
|
|||
func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
if !c.AwaitReachable(ctx) {
|
||||
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
|
||||
}
|
||||
return c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSSHPort))
|
||||
}
|
||||
|
||||
|
@ -235,6 +240,9 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error)
|
|||
func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
if !c.AwaitReachable(ctx) {
|
||||
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
|
||||
}
|
||||
speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSpeedtestPort))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("dial speedtest: %w", err)
|
||||
|
@ -257,6 +265,9 @@ func (c *WorkspaceAgentConn) DialContext(ctx context.Context, network string, ad
|
|||
_, rawPort, _ := net.SplitHostPort(addr)
|
||||
port, _ := strconv.ParseUint(rawPort, 10, 16)
|
||||
ipp := netip.AddrPortFrom(WorkspaceAgentIP, uint16(port))
|
||||
if !c.AwaitReachable(ctx) {
|
||||
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
|
||||
}
|
||||
if network == "udp" {
|
||||
return c.Conn.DialContextUDP(ctx, ipp)
|
||||
}
|
||||
|
@ -317,7 +328,7 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client {
|
|||
// Disable keep alives as we're usually only making a single
|
||||
// request, and this triggers goleak in tests
|
||||
DisableKeepAlives: true,
|
||||
DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if network != "tcp" {
|
||||
return nil, xerrors.Errorf("network must be tcp")
|
||||
}
|
||||
|
@ -331,7 +342,11 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client {
|
|||
return nil, xerrors.Errorf("request %q does not appear to be for http api", addr)
|
||||
}
|
||||
|
||||
conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort))
|
||||
if !c.AwaitReachable(ctx) {
|
||||
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
|
||||
}
|
||||
|
||||
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("dial http api: %w", err)
|
||||
}
|
||||
|
|
|
@ -199,13 +199,19 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return &WorkspaceAgentConn{
|
||||
agentConn := &WorkspaceAgentConn{
|
||||
Conn: conn,
|
||||
CloseFunc: func() {
|
||||
cancelFunc()
|
||||
<-closed
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
if !agentConn.AwaitReachable(ctx) {
|
||||
_ = agentConn.Close()
|
||||
return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", ctx.Err())
|
||||
}
|
||||
|
||||
return agentConn, nil
|
||||
}
|
||||
|
||||
// WorkspaceAgent returns an agent by ID.
|
||||
|
|
|
@ -60,7 +60,7 @@ type Options struct {
|
|||
}
|
||||
|
||||
// NewConn constructs a new Wireguard server that will accept connections from the addresses provided.
|
||||
func NewConn(options *Options) (*Conn, error) {
|
||||
func NewConn(options *Options) (conn *Conn, err error) {
|
||||
if options == nil {
|
||||
options = &Options{}
|
||||
}
|
||||
|
@ -123,6 +123,11 @@ func NewConn(options *Options) (*Conn, error) {
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("create wireguard link monitor: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
wireguardMonitor.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
dialer := &tsdial.Dialer{
|
||||
Logf: Logger(options.Logger),
|
||||
|
@ -134,6 +139,11 @@ func NewConn(options *Options) (*Conn, error) {
|
|||
if err != nil {
|
||||
return nil, xerrors.Errorf("create wgengine: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
wireguardEngine.Close()
|
||||
}
|
||||
}()
|
||||
dialer.UseNetstackForIP = func(ip netip.Addr) bool {
|
||||
_, ok := wireguardEngine.PeerForIP(ip)
|
||||
return ok
|
||||
|
@ -166,10 +176,6 @@ func NewConn(options *Options) (*Conn, error) {
|
|||
return netStack.DialContextTCP(ctx, dst)
|
||||
}
|
||||
netStack.ProcessLocalIPs = true
|
||||
err = netStack.Start(nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("start netstack: %w", err)
|
||||
}
|
||||
wireguardEngine = wgengine.NewWatchdog(wireguardEngine)
|
||||
wireguardEngine.SetDERPMap(options.DERPMap)
|
||||
netMapCopy := *netMap
|
||||
|
@ -203,6 +209,11 @@ func NewConn(options *Options) (*Conn, error) {
|
|||
},
|
||||
wireguardEngine: wireguardEngine,
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = server.Close()
|
||||
}
|
||||
}()
|
||||
wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) {
|
||||
server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.F("err", err))
|
||||
if err != nil {
|
||||
|
@ -236,6 +247,12 @@ func NewConn(options *Options) (*Conn, error) {
|
|||
server.sendNode()
|
||||
})
|
||||
netStack.ForwardTCPIn = server.forwardTCP
|
||||
|
||||
err = netStack.Start(nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("start netstack: %w", err)
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
|
@ -519,22 +536,35 @@ func (c *Conn) Close() error {
|
|||
default:
|
||||
}
|
||||
close(c.closed)
|
||||
c.mutex.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
if c.trafficStats != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = c.trafficStats.Shutdown(ctx)
|
||||
}()
|
||||
}
|
||||
|
||||
_ = c.netStack.Close()
|
||||
c.dialCancel()
|
||||
_ = c.wireguardMonitor.Close()
|
||||
_ = c.dialer.Close()
|
||||
// Stops internals, e.g. tunDevice, magicConn and dnsManager.
|
||||
c.wireguardEngine.Close()
|
||||
|
||||
c.mutex.Lock()
|
||||
for _, l := range c.listeners {
|
||||
_ = l.closeNoLock()
|
||||
}
|
||||
c.listeners = nil
|
||||
c.mutex.Unlock()
|
||||
c.dialCancel()
|
||||
_ = c.dialer.Close()
|
||||
_ = c.magicConn.Close()
|
||||
_ = c.netStack.Close()
|
||||
_ = c.wireguardMonitor.Close()
|
||||
_ = c.tunDevice.Close()
|
||||
c.wireguardEngine.Close()
|
||||
if c.trafficStats != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = c.trafficStats.Shutdown(ctx)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -714,16 +744,25 @@ func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) {
|
|||
func (c *Conn) SetConnStatsCallback(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) {
|
||||
connStats := connstats.NewStatistics(maxPeriod, maxConns, dump)
|
||||
|
||||
shutdown := func(s *connstats.Statistics) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.Shutdown(ctx)
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
if c.isClosed() {
|
||||
c.mutex.Unlock()
|
||||
shutdown(connStats)
|
||||
return
|
||||
}
|
||||
old := c.trafficStats
|
||||
c.trafficStats = connStats
|
||||
c.mutex.Unlock()
|
||||
|
||||
// Make sure to shutdown the old callback.
|
||||
if old != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = old.Shutdown(ctx)
|
||||
shutdown(old)
|
||||
}
|
||||
|
||||
c.tunDevice.SetStatistics(connStats)
|
||||
|
@ -776,6 +815,7 @@ func (a addr) String() string { return a.ln.addr }
|
|||
// Logger converts the Tailscale logging function to use slog.
|
||||
func Logger(logger slog.Logger) tslogger.Logf {
|
||||
return tslogger.Logf(func(format string, args ...any) {
|
||||
slog.Helper()
|
||||
logger.Debug(context.Background(), fmt.Sprintf(format, args...))
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue