fix(tailnet): Improve tailnet setup and agentconn stability (#6292)

* fix(tailnet): Improve start and close to detect connection races

* fix: Prevent agentConn use before ready via AwaitReachable

* fix(tailnet): Ensure connstats are closed on conn close

* fix(codersdk): Use AwaitReachable in DialWorkspaceAgent

* fix(tailnet): Improve logging via slog.Helper()
This commit is contained in:
Mathias Fredriksson 2023-02-24 13:11:28 +02:00 committed by GitHub
parent 473ab208af
commit a414de9e81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 107 additions and 25 deletions

View File

@ -6,6 +6,7 @@ import (
"net/http"
"net/netip"
"strconv"
"sync"
"testing"
"github.com/stretchr/testify/assert"
@ -78,8 +79,14 @@ func TestDERP(t *testing.T) {
DERPMap: derpMap,
})
require.NoError(t, err)
w2Ready := make(chan struct{}, 1)
w2ReadyOnce := sync.Once{}
w1.SetNodeCallback(func(node *tailnet.Node) {
w2.UpdateNodes([]*tailnet.Node{node})
w2ReadyOnce.Do(func() {
close(w2Ready)
})
})
w2.SetNodeCallback(func(node *tailnet.Node) {
w1.UpdateNodes([]*tailnet.Node{node})
@ -98,6 +105,7 @@ func TestDERP(t *testing.T) {
}()
<-conn
<-w2Ready
nc, err := w2.DialContextTCP(context.Background(), netip.AddrPortFrom(w1IP, 35565))
require.NoError(t, err)
_ = nc.Close()

View File

@ -469,6 +469,8 @@ func TestWorkspaceAgentListeningPorts(t *testing.T) {
t.Parallel()
setup := func(t *testing.T, apps []*proto.App) (*codersdk.Client, uint16, uuid.UUID) {
t.Helper()
client := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
})

View File

@ -29,6 +29,7 @@ import (
"github.com/coder/coder/codersdk/agentsdk"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/tailnet/tailnettest"
"github.com/coder/coder/testutil"
)
func TestMain(m *testing.M) {
@ -131,6 +132,14 @@ func TestCache(t *testing.T) {
return
}
defer release()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
if !conn.AwaitReachable(ctx) {
t.Error("agent not reachable")
return
}
transport := conn.HTTPTransport()
defer transport.CloseIdleConnections()
proxy.Transport = transport
@ -146,6 +155,8 @@ func TestCache(t *testing.T) {
}
func setupAgent(t *testing.T, metadata agentsdk.Metadata, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn {
t.Helper()
metadata.DERPMap = tailnettest.RunDERPAndSTUN(t)
coordinator := tailnet.NewCoordinator()

View File

@ -176,7 +176,9 @@ type ReconnectingPTYRequest struct {
func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
if !c.AwaitReachable(ctx) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentReconnectingPTYPort))
if err != nil {
return nil, err
@ -207,6 +209,9 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID,
func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
if !c.AwaitReachable(ctx) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
return c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSSHPort))
}
@ -235,6 +240,9 @@ func (c *WorkspaceAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error)
func (c *WorkspaceAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
if !c.AwaitReachable(ctx) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentSpeedtestPort))
if err != nil {
return nil, xerrors.Errorf("dial speedtest: %w", err)
@ -257,6 +265,9 @@ func (c *WorkspaceAgentConn) DialContext(ctx context.Context, network string, ad
_, rawPort, _ := net.SplitHostPort(addr)
port, _ := strconv.ParseUint(rawPort, 10, 16)
ipp := netip.AddrPortFrom(WorkspaceAgentIP, uint16(port))
if !c.AwaitReachable(ctx) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
if network == "udp" {
return c.Conn.DialContextUDP(ctx, ipp)
}
@ -317,7 +328,7 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client {
// Disable keep alives as we're usually only making a single
// request, and this triggers goleak in tests
DisableKeepAlives: true,
DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if network != "tcp" {
return nil, xerrors.Errorf("network must be tcp")
}
@ -331,7 +342,11 @@ func (c *WorkspaceAgentConn) apiClient() *http.Client {
return nil, xerrors.Errorf("request %q does not appear to be for http api", addr)
}
conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort))
if !c.AwaitReachable(ctx) {
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
}
conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(WorkspaceAgentIP, WorkspaceAgentHTTPAPIServerPort))
if err != nil {
return nil, xerrors.Errorf("dial http api: %w", err)
}

View File

@ -199,13 +199,19 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
return nil, err
}
return &WorkspaceAgentConn{
agentConn := &WorkspaceAgentConn{
Conn: conn,
CloseFunc: func() {
cancelFunc()
<-closed
},
}, nil
}
if !agentConn.AwaitReachable(ctx) {
_ = agentConn.Close()
return nil, xerrors.Errorf("timed out waiting for agent to become reachable: %w", ctx.Err())
}
return agentConn, nil
}
// WorkspaceAgent returns an agent by ID.

View File

@ -60,7 +60,7 @@ type Options struct {
}
// NewConn constructs a new Wireguard server that will accept connections from the addresses provided.
func NewConn(options *Options) (*Conn, error) {
func NewConn(options *Options) (conn *Conn, err error) {
if options == nil {
options = &Options{}
}
@ -123,6 +123,11 @@ func NewConn(options *Options) (*Conn, error) {
if err != nil {
return nil, xerrors.Errorf("create wireguard link monitor: %w", err)
}
defer func() {
if err != nil {
wireguardMonitor.Close()
}
}()
dialer := &tsdial.Dialer{
Logf: Logger(options.Logger),
@ -134,6 +139,11 @@ func NewConn(options *Options) (*Conn, error) {
if err != nil {
return nil, xerrors.Errorf("create wgengine: %w", err)
}
defer func() {
if err != nil {
wireguardEngine.Close()
}
}()
dialer.UseNetstackForIP = func(ip netip.Addr) bool {
_, ok := wireguardEngine.PeerForIP(ip)
return ok
@ -166,10 +176,6 @@ func NewConn(options *Options) (*Conn, error) {
return netStack.DialContextTCP(ctx, dst)
}
netStack.ProcessLocalIPs = true
err = netStack.Start(nil)
if err != nil {
return nil, xerrors.Errorf("start netstack: %w", err)
}
wireguardEngine = wgengine.NewWatchdog(wireguardEngine)
wireguardEngine.SetDERPMap(options.DERPMap)
netMapCopy := *netMap
@ -203,6 +209,11 @@ func NewConn(options *Options) (*Conn, error) {
},
wireguardEngine: wireguardEngine,
}
defer func() {
if err != nil {
_ = server.Close()
}
}()
wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) {
server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.F("err", err))
if err != nil {
@ -236,6 +247,12 @@ func NewConn(options *Options) (*Conn, error) {
server.sendNode()
})
netStack.ForwardTCPIn = server.forwardTCP
err = netStack.Start(nil)
if err != nil {
return nil, xerrors.Errorf("start netstack: %w", err)
}
return server, nil
}
@ -519,22 +536,35 @@ func (c *Conn) Close() error {
default:
}
close(c.closed)
c.mutex.Unlock()
var wg sync.WaitGroup
defer wg.Wait()
if c.trafficStats != nil {
wg.Add(1)
go func() {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = c.trafficStats.Shutdown(ctx)
}()
}
_ = c.netStack.Close()
c.dialCancel()
_ = c.wireguardMonitor.Close()
_ = c.dialer.Close()
// Stops internals, e.g. tunDevice, magicConn and dnsManager.
c.wireguardEngine.Close()
c.mutex.Lock()
for _, l := range c.listeners {
_ = l.closeNoLock()
}
c.listeners = nil
c.mutex.Unlock()
c.dialCancel()
_ = c.dialer.Close()
_ = c.magicConn.Close()
_ = c.netStack.Close()
_ = c.wireguardMonitor.Close()
_ = c.tunDevice.Close()
c.wireguardEngine.Close()
if c.trafficStats != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = c.trafficStats.Shutdown(ctx)
}
return nil
}
@ -714,16 +744,25 @@ func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) {
func (c *Conn) SetConnStatsCallback(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) {
connStats := connstats.NewStatistics(maxPeriod, maxConns, dump)
shutdown := func(s *connstats.Statistics) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.Shutdown(ctx)
}
c.mutex.Lock()
if c.isClosed() {
c.mutex.Unlock()
shutdown(connStats)
return
}
old := c.trafficStats
c.trafficStats = connStats
c.mutex.Unlock()
// Make sure to shutdown the old callback.
if old != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = old.Shutdown(ctx)
shutdown(old)
}
c.tunDevice.SetStatistics(connStats)
@ -776,6 +815,7 @@ func (a addr) String() string { return a.ln.addr }
// Logger converts the Tailscale logging function to use slog.
func Logger(logger slog.Logger) tslogger.Logf {
return tslogger.Logf(func(format string, args ...any) {
slog.Helper()
logger.Debug(context.Background(), fmt.Sprintf(format, args...))
})
}