fix: detect JetBrains running on local ipv6 (#11653)

This commit is contained in:
Asher 2024-01-16 15:53:41 -09:00 committed by GitHub
parent be43d6247d
commit 2d61d5332a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 65 additions and 31 deletions

View File

@ -214,46 +214,59 @@ func TestAgent_Stats_Magic(t *testing.T) {
_, b, _, ok := runtime.Caller(0) _, b, _, ok := runtime.Caller(0)
require.True(t, ok) require.True(t, ok)
dir := filepath.Join(filepath.Dir(b), "../scripts/echoserver/main.go") dir := filepath.Join(filepath.Dir(b), "../scripts/echoserver/main.go")
echoServerCmd := exec.Command("go", "run", dir,
"-D", agentssh.MagicProcessCmdlineJetBrains)
stdout, err := echoServerCmd.StdoutPipe()
require.NoError(t, err)
err = echoServerCmd.Start()
require.NoError(t, err)
defer echoServerCmd.Process.Kill()
// The echo server prints its port as the first line. spawnServer := func(network string) (string, *exec.Cmd) {
sc := bufio.NewScanner(stdout) echoServerCmd := exec.Command("go", "run", dir,
sc.Scan() network, "-D", agentssh.MagicProcessCmdlineJetBrains)
remotePort := sc.Text() stdout, err := echoServerCmd.StdoutPipe()
require.NoError(t, err)
err = echoServerCmd.Start()
require.NoError(t, err)
t.Cleanup(func() {
echoServerCmd.Process.Kill()
})
// The echo server prints its port as the first line.
sc := bufio.NewScanner(stdout)
sc.Scan()
return sc.Text(), echoServerCmd
}
port4, cmd4 := spawnServer("tcp4")
port6, cmd6 := spawnServer("tcp6")
//nolint:dogsled //nolint:dogsled
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
defer conn.Close()
sshClient, err := conn.SSHClient(ctx) sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err) require.NoError(t, err)
tunneledConn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", remotePort)) tunnel4, err := sshClient.Dial("tcp4", fmt.Sprintf("127.0.0.1:%s", port4))
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { defer tunnel4.Close()
// always close on failure of test
_ = conn.Close() tunnel6, err := sshClient.Dial("tcp6", fmt.Sprintf("[::]:%s", port6))
_ = tunneledConn.Close() require.NoError(t, err)
}) defer tunnel6.Close()
require.Eventuallyf(t, func() bool { require.Eventuallyf(t, func() bool {
s, ok := <-stats s, ok := <-stats
t.Logf("got stats with conn open: ok=%t, ConnectionCount=%d, SessionCountJetBrains=%d", t.Logf("got stats with conn open: ok=%t, ConnectionCount=%d, SessionCountJetBrains=%d",
ok, s.ConnectionCount, s.SessionCountJetBrains) ok, s.ConnectionCount, s.SessionCountJetBrains)
return ok && s.ConnectionCount > 0 && return ok && s.ConnectionCount > 0 &&
s.SessionCountJetBrains == 1 s.SessionCountJetBrains == 2
}, testutil.WaitLong, testutil.IntervalFast, }, testutil.WaitLong, testutil.IntervalFast,
"never saw stats with conn open", "never saw stats with conn open",
) )
// Kill the server and connection after checking for the echo. // Kill the server and connection after checking for the echo.
requireEcho(t, tunneledConn) requireEcho(t, tunnel4)
_ = echoServerCmd.Process.Kill() requireEcho(t, tunnel6)
_ = tunneledConn.Close() _ = cmd4.Process.Kill()
_ = cmd6.Process.Kill()
_ = tunnel4.Close()
_ = tunnel6.Close()
require.Eventuallyf(t, func() bool { require.Eventuallyf(t, func() bool {
s, ok := <-stats s, ok := <-stats

View File

@ -3,6 +3,7 @@
package agentssh package agentssh
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
@ -11,24 +12,33 @@ import (
) )
func getListeningPortProcessCmdline(port uint32) (string, error) { func getListeningPortProcessCmdline(port uint32) (string, error) {
tabs, err := netstat.TCPSocks(func(s *netstat.SockTabEntry) bool { acceptFn := func(s *netstat.SockTabEntry) bool {
return s.LocalAddr != nil && uint32(s.LocalAddr.Port) == port return s.LocalAddr != nil && uint32(s.LocalAddr.Port) == port
})
if err != nil {
return "", xerrors.Errorf("inspect port %d: %w", port, err)
} }
if len(tabs) == 0 { tabs, err := netstat.TCPSocks(acceptFn)
return "", nil tabs6, err6 := netstat.TCP6Socks(acceptFn)
// Only return the error if the other method found nothing.
if (err != nil && len(tabs6) == 0) || (err6 != nil && len(tabs) == 0) {
return "", xerrors.Errorf("inspect port %d: %w", port, errors.Join(err, err6))
} }
// Defensive check. var proc *netstat.Process
if tabs[0].Process == nil { if len(tabs) > 0 {
proc = tabs[0].Process
} else if len(tabs6) > 0 {
proc = tabs6[0].Process
}
if proc == nil {
// Either nothing is listening on this port or we were unable to read the
// process details (permission issues reading /proc/$pid/* potentially).
// Or, perhaps /proc/net/tcp{,6} is not listing the port for some reason.
return "", nil return "", nil
} }
// The process name provided by go-netstat does not include the full command // The process name provided by go-netstat does not include the full command
// line so grab that instead. // line so grab that instead.
pid := tabs[0].Process.Pid pid := proc.Pid
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid)) data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
if err != nil { if err != nil {
return "", xerrors.Errorf("read /proc/%d/cmdline: %w", pid, err) return "", xerrors.Errorf("read /proc/%d/cmdline: %w", pid, err)

View File

@ -9,10 +9,21 @@ import (
"io" "io"
"log" "log"
"net" "net"
"os"
) )
func main() { func main() {
l, err := net.Listen("tcp", "127.0.0.1:0") network := os.Args[1]
var address string
switch network {
case "tcp4":
address = "127.0.0.1"
case "tcp6":
address = "[::]"
default:
log.Fatalf("invalid network: %s", network)
}
l, err := net.Listen(network, address+":0")
if err != nil { if err != nil {
log.Fatalf("listen error: err=%s", err) log.Fatalf("listen error: err=%s", err)
} }