diff --git a/agent/agent_test.go b/agent/agent_test.go index 2db5e85574..f884918c83 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -214,59 +214,46 @@ func TestAgent_Stats_Magic(t *testing.T) { _, b, _, ok := runtime.Caller(0) require.True(t, ok) dir := filepath.Join(filepath.Dir(b), "../scripts/echoserver/main.go") + echoServerCmd := exec.Command("go", "run", dir, + "-D", agentssh.MagicProcessCmdlineJetBrains) + stdout, err := echoServerCmd.StdoutPipe() + require.NoError(t, err) + err = echoServerCmd.Start() + require.NoError(t, err) + defer echoServerCmd.Process.Kill() - spawnServer := func(network string) (string, *exec.Cmd) { - echoServerCmd := exec.Command("go", "run", dir, - network, "-D", agentssh.MagicProcessCmdlineJetBrains) - stdout, err := echoServerCmd.StdoutPipe() - require.NoError(t, err) - err = echoServerCmd.Start() - require.NoError(t, err) - t.Cleanup(func() { - echoServerCmd.Process.Kill() - }) - - // The echo server prints its port as the first line. - sc := bufio.NewScanner(stdout) - sc.Scan() - return sc.Text(), echoServerCmd - } - - port4, cmd4 := spawnServer("tcp4") - port6, cmd6 := spawnServer("tcp6") + // The echo server prints its port as the first line. + sc := bufio.NewScanner(stdout) + sc.Scan() + remotePort := sc.Text() //nolint:dogsled conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) - defer conn.Close() - sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) - tunnel4, err := sshClient.Dial("tcp4", fmt.Sprintf("127.0.0.1:%s", port4)) + tunneledConn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", remotePort)) require.NoError(t, err) - defer tunnel4.Close() - - tunnel6, err := sshClient.Dial("tcp6", fmt.Sprintf("[::]:%s", port6)) - require.NoError(t, err) - defer tunnel6.Close() + t.Cleanup(func() { + // always close on failure of test + _ = conn.Close() + _ = tunneledConn.Close() + }) require.Eventuallyf(t, func() bool { s, ok := <-stats t.Logf("got stats with conn open: ok=%t, ConnectionCount=%d, SessionCountJetBrains=%d", ok, s.ConnectionCount, s.SessionCountJetBrains) return ok && s.ConnectionCount > 0 && - s.SessionCountJetBrains == 2 + s.SessionCountJetBrains == 1 }, testutil.WaitLong, testutil.IntervalFast, "never saw stats with conn open", ) // Kill the server and connection after checking for the echo. - requireEcho(t, tunnel4) - requireEcho(t, tunnel6) - _ = cmd4.Process.Kill() - _ = cmd6.Process.Kill() - _ = tunnel4.Close() - _ = tunnel6.Close() + requireEcho(t, tunneledConn) + _ = echoServerCmd.Process.Kill() + _ = tunneledConn.Close() require.Eventuallyf(t, func() bool { s, ok := <-stats diff --git a/agent/agentssh/portinspection_supported.go b/agent/agentssh/portinspection_supported.go index 600651ab09..d45847bd6f 100644 --- a/agent/agentssh/portinspection_supported.go +++ b/agent/agentssh/portinspection_supported.go @@ -3,7 +3,6 @@ package agentssh import ( - "errors" "fmt" "os" @@ -12,33 +11,24 @@ import ( ) func getListeningPortProcessCmdline(port uint32) (string, error) { - acceptFn := func(s *netstat.SockTabEntry) bool { + tabs, err := netstat.TCPSocks(func(s *netstat.SockTabEntry) bool { return s.LocalAddr != nil && uint32(s.LocalAddr.Port) == port + }) + if err != nil { + return "", xerrors.Errorf("inspect port %d: %w", port, err) } - tabs, err := netstat.TCPSocks(acceptFn) - tabs6, err6 := netstat.TCP6Socks(acceptFn) - - // Only return the error if the other method found nothing. - if (err != nil && len(tabs6) == 0) || (err6 != nil && len(tabs) == 0) { - return "", xerrors.Errorf("inspect port %d: %w", port, errors.Join(err, err6)) + if len(tabs) == 0 { + return "", nil } - var proc *netstat.Process - if len(tabs) > 0 { - proc = tabs[0].Process - } else if len(tabs6) > 0 { - proc = tabs6[0].Process - } - if proc == nil { - // Either nothing is listening on this port or we were unable to read the - // process details (permission issues reading /proc/$pid/* potentially). - // Or, perhaps /proc/net/tcp{,6} is not listing the port for some reason. + // Defensive check. + if tabs[0].Process == nil { return "", nil } // The process name provided by go-netstat does not include the full command // line so grab that instead. - pid := proc.Pid + pid := tabs[0].Process.Pid data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid)) if err != nil { return "", xerrors.Errorf("read /proc/%d/cmdline: %w", pid, err) diff --git a/scripts/echoserver/main.go b/scripts/echoserver/main.go index 32c0766d6f..cb30a0b383 100644 --- a/scripts/echoserver/main.go +++ b/scripts/echoserver/main.go @@ -9,21 +9,10 @@ import ( "io" "log" "net" - "os" ) func main() { - network := os.Args[1] - var address string - switch network { - case "tcp4": - address = "127.0.0.1" - case "tcp6": - address = "[::]" - default: - log.Fatalf("invalid network: %s", network) - } - l, err := net.Listen(network, address+":0") + l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { log.Fatalf("listen error: err=%s", err) }