diff --git a/agent/agent_test.go b/agent/agent_test.go index f884918c83..2db5e85574 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -214,46 +214,59 @@ func TestAgent_Stats_Magic(t *testing.T) { _, b, _, ok := runtime.Caller(0) require.True(t, ok) dir := filepath.Join(filepath.Dir(b), "../scripts/echoserver/main.go") - echoServerCmd := exec.Command("go", "run", dir, - "-D", agentssh.MagicProcessCmdlineJetBrains) - stdout, err := echoServerCmd.StdoutPipe() - require.NoError(t, err) - err = echoServerCmd.Start() - require.NoError(t, err) - defer echoServerCmd.Process.Kill() - // The echo server prints its port as the first line. - sc := bufio.NewScanner(stdout) - sc.Scan() - remotePort := sc.Text() + spawnServer := func(network string) (string, *exec.Cmd) { + echoServerCmd := exec.Command("go", "run", dir, + network, "-D", agentssh.MagicProcessCmdlineJetBrains) + stdout, err := echoServerCmd.StdoutPipe() + require.NoError(t, err) + err = echoServerCmd.Start() + require.NoError(t, err) + t.Cleanup(func() { + echoServerCmd.Process.Kill() + }) + + // The echo server prints its port as the first line. + sc := bufio.NewScanner(stdout) + sc.Scan() + return sc.Text(), echoServerCmd + } + + port4, cmd4 := spawnServer("tcp4") + port6, cmd6 := spawnServer("tcp6") //nolint:dogsled conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + defer conn.Close() + sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) - tunneledConn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", remotePort)) + tunnel4, err := sshClient.Dial("tcp4", fmt.Sprintf("127.0.0.1:%s", port4)) require.NoError(t, err) - t.Cleanup(func() { - // always close on failure of test - _ = conn.Close() - _ = tunneledConn.Close() - }) + defer tunnel4.Close() + + tunnel6, err := sshClient.Dial("tcp6", fmt.Sprintf("[::]:%s", port6)) + require.NoError(t, err) + defer tunnel6.Close() require.Eventuallyf(t, func() bool { s, ok := <-stats t.Logf("got stats with conn open: ok=%t, ConnectionCount=%d, SessionCountJetBrains=%d", ok, s.ConnectionCount, s.SessionCountJetBrains) return ok && s.ConnectionCount > 0 && - s.SessionCountJetBrains == 1 + s.SessionCountJetBrains == 2 }, testutil.WaitLong, testutil.IntervalFast, "never saw stats with conn open", ) // Kill the server and connection after checking for the echo. - requireEcho(t, tunneledConn) - _ = echoServerCmd.Process.Kill() - _ = tunneledConn.Close() + requireEcho(t, tunnel4) + requireEcho(t, tunnel6) + _ = cmd4.Process.Kill() + _ = cmd6.Process.Kill() + _ = tunnel4.Close() + _ = tunnel6.Close() require.Eventuallyf(t, func() bool { s, ok := <-stats diff --git a/agent/agentssh/portinspection_supported.go b/agent/agentssh/portinspection_supported.go index d45847bd6f..600651ab09 100644 --- a/agent/agentssh/portinspection_supported.go +++ b/agent/agentssh/portinspection_supported.go @@ -3,6 +3,7 @@ package agentssh import ( + "errors" "fmt" "os" @@ -11,24 +12,33 @@ import ( ) func getListeningPortProcessCmdline(port uint32) (string, error) { - tabs, err := netstat.TCPSocks(func(s *netstat.SockTabEntry) bool { + acceptFn := func(s *netstat.SockTabEntry) bool { return s.LocalAddr != nil && uint32(s.LocalAddr.Port) == port - }) - if err != nil { - return "", xerrors.Errorf("inspect port %d: %w", port, err) } - if len(tabs) == 0 { - return "", nil + tabs, err := netstat.TCPSocks(acceptFn) + tabs6, err6 := netstat.TCP6Socks(acceptFn) + + // Only return the error if the other method found nothing. + if (err != nil && len(tabs6) == 0) || (err6 != nil && len(tabs) == 0) { + return "", xerrors.Errorf("inspect port %d: %w", port, errors.Join(err, err6)) } - // Defensive check. - if tabs[0].Process == nil { + var proc *netstat.Process + if len(tabs) > 0 { + proc = tabs[0].Process + } else if len(tabs6) > 0 { + proc = tabs6[0].Process + } + if proc == nil { + // Either nothing is listening on this port or we were unable to read the + // process details (permission issues reading /proc/$pid/* potentially). + // Or, perhaps /proc/net/tcp{,6} is not listing the port for some reason. return "", nil } // The process name provided by go-netstat does not include the full command // line so grab that instead. - pid := tabs[0].Process.Pid + pid := proc.Pid data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid)) if err != nil { return "", xerrors.Errorf("read /proc/%d/cmdline: %w", pid, err) diff --git a/scripts/echoserver/main.go b/scripts/echoserver/main.go index cb30a0b383..32c0766d6f 100644 --- a/scripts/echoserver/main.go +++ b/scripts/echoserver/main.go @@ -9,10 +9,21 @@ import ( "io" "log" "net" + "os" ) func main() { - l, err := net.Listen("tcp", "127.0.0.1:0") + network := os.Args[1] + var address string + switch network { + case "tcp4": + address = "127.0.0.1" + case "tcp6": + address = "[::]" + default: + log.Fatalf("invalid network: %s", network) + } + l, err := net.Listen(network, address+":0") if err != nil { log.Fatalf("listen error: err=%s", err) }