mirror of https://github.com/coder/coder.git
This reverts commit 2d61d5332a
.
This commit is contained in:
parent
2aa3cbbd03
commit
b173195e0d
|
@ -214,59 +214,46 @@ func TestAgent_Stats_Magic(t *testing.T) {
|
||||||
_, b, _, ok := runtime.Caller(0)
|
_, b, _, ok := runtime.Caller(0)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
dir := filepath.Join(filepath.Dir(b), "../scripts/echoserver/main.go")
|
dir := filepath.Join(filepath.Dir(b), "../scripts/echoserver/main.go")
|
||||||
|
echoServerCmd := exec.Command("go", "run", dir,
|
||||||
|
"-D", agentssh.MagicProcessCmdlineJetBrains)
|
||||||
|
stdout, err := echoServerCmd.StdoutPipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = echoServerCmd.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer echoServerCmd.Process.Kill()
|
||||||
|
|
||||||
spawnServer := func(network string) (string, *exec.Cmd) {
|
// The echo server prints its port as the first line.
|
||||||
echoServerCmd := exec.Command("go", "run", dir,
|
sc := bufio.NewScanner(stdout)
|
||||||
network, "-D", agentssh.MagicProcessCmdlineJetBrains)
|
sc.Scan()
|
||||||
stdout, err := echoServerCmd.StdoutPipe()
|
remotePort := sc.Text()
|
||||||
require.NoError(t, err)
|
|
||||||
err = echoServerCmd.Start()
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
echoServerCmd.Process.Kill()
|
|
||||||
})
|
|
||||||
|
|
||||||
// The echo server prints its port as the first line.
|
|
||||||
sc := bufio.NewScanner(stdout)
|
|
||||||
sc.Scan()
|
|
||||||
return sc.Text(), echoServerCmd
|
|
||||||
}
|
|
||||||
|
|
||||||
port4, cmd4 := spawnServer("tcp4")
|
|
||||||
port6, cmd6 := spawnServer("tcp6")
|
|
||||||
|
|
||||||
//nolint:dogsled
|
//nolint:dogsled
|
||||||
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
sshClient, err := conn.SSHClient(ctx)
|
sshClient, err := conn.SSHClient(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tunnel4, err := sshClient.Dial("tcp4", fmt.Sprintf("127.0.0.1:%s", port4))
|
tunneledConn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", remotePort))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer tunnel4.Close()
|
t.Cleanup(func() {
|
||||||
|
// always close on failure of test
|
||||||
tunnel6, err := sshClient.Dial("tcp6", fmt.Sprintf("[::]:%s", port6))
|
_ = conn.Close()
|
||||||
require.NoError(t, err)
|
_ = tunneledConn.Close()
|
||||||
defer tunnel6.Close()
|
})
|
||||||
|
|
||||||
require.Eventuallyf(t, func() bool {
|
require.Eventuallyf(t, func() bool {
|
||||||
s, ok := <-stats
|
s, ok := <-stats
|
||||||
t.Logf("got stats with conn open: ok=%t, ConnectionCount=%d, SessionCountJetBrains=%d",
|
t.Logf("got stats with conn open: ok=%t, ConnectionCount=%d, SessionCountJetBrains=%d",
|
||||||
ok, s.ConnectionCount, s.SessionCountJetBrains)
|
ok, s.ConnectionCount, s.SessionCountJetBrains)
|
||||||
return ok && s.ConnectionCount > 0 &&
|
return ok && s.ConnectionCount > 0 &&
|
||||||
s.SessionCountJetBrains == 2
|
s.SessionCountJetBrains == 1
|
||||||
}, testutil.WaitLong, testutil.IntervalFast,
|
}, testutil.WaitLong, testutil.IntervalFast,
|
||||||
"never saw stats with conn open",
|
"never saw stats with conn open",
|
||||||
)
|
)
|
||||||
|
|
||||||
// Kill the server and connection after checking for the echo.
|
// Kill the server and connection after checking for the echo.
|
||||||
requireEcho(t, tunnel4)
|
requireEcho(t, tunneledConn)
|
||||||
requireEcho(t, tunnel6)
|
_ = echoServerCmd.Process.Kill()
|
||||||
_ = cmd4.Process.Kill()
|
_ = tunneledConn.Close()
|
||||||
_ = cmd6.Process.Kill()
|
|
||||||
_ = tunnel4.Close()
|
|
||||||
_ = tunnel6.Close()
|
|
||||||
|
|
||||||
require.Eventuallyf(t, func() bool {
|
require.Eventuallyf(t, func() bool {
|
||||||
s, ok := <-stats
|
s, ok := <-stats
|
||||||
|
|
|
@ -3,7 +3,6 @@
|
||||||
package agentssh
|
package agentssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
@ -12,33 +11,24 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func getListeningPortProcessCmdline(port uint32) (string, error) {
|
func getListeningPortProcessCmdline(port uint32) (string, error) {
|
||||||
acceptFn := func(s *netstat.SockTabEntry) bool {
|
tabs, err := netstat.TCPSocks(func(s *netstat.SockTabEntry) bool {
|
||||||
return s.LocalAddr != nil && uint32(s.LocalAddr.Port) == port
|
return s.LocalAddr != nil && uint32(s.LocalAddr.Port) == port
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", xerrors.Errorf("inspect port %d: %w", port, err)
|
||||||
}
|
}
|
||||||
tabs, err := netstat.TCPSocks(acceptFn)
|
if len(tabs) == 0 {
|
||||||
tabs6, err6 := netstat.TCP6Socks(acceptFn)
|
return "", nil
|
||||||
|
|
||||||
// Only return the error if the other method found nothing.
|
|
||||||
if (err != nil && len(tabs6) == 0) || (err6 != nil && len(tabs) == 0) {
|
|
||||||
return "", xerrors.Errorf("inspect port %d: %w", port, errors.Join(err, err6))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var proc *netstat.Process
|
// Defensive check.
|
||||||
if len(tabs) > 0 {
|
if tabs[0].Process == nil {
|
||||||
proc = tabs[0].Process
|
|
||||||
} else if len(tabs6) > 0 {
|
|
||||||
proc = tabs6[0].Process
|
|
||||||
}
|
|
||||||
if proc == nil {
|
|
||||||
// Either nothing is listening on this port or we were unable to read the
|
|
||||||
// process details (permission issues reading /proc/$pid/* potentially).
|
|
||||||
// Or, perhaps /proc/net/tcp{,6} is not listing the port for some reason.
|
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// The process name provided by go-netstat does not include the full command
|
// The process name provided by go-netstat does not include the full command
|
||||||
// line so grab that instead.
|
// line so grab that instead.
|
||||||
pid := proc.Pid
|
pid := tabs[0].Process.Pid
|
||||||
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
|
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", xerrors.Errorf("read /proc/%d/cmdline: %w", pid, err)
|
return "", xerrors.Errorf("read /proc/%d/cmdline: %w", pid, err)
|
||||||
|
|
|
@ -9,21 +9,10 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
network := os.Args[1]
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
var address string
|
|
||||||
switch network {
|
|
||||||
case "tcp4":
|
|
||||||
address = "127.0.0.1"
|
|
||||||
case "tcp6":
|
|
||||||
address = "[::]"
|
|
||||||
default:
|
|
||||||
log.Fatalf("invalid network: %s", network)
|
|
||||||
}
|
|
||||||
l, err := net.Listen(network, address+":0")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("listen error: err=%s", err)
|
log.Fatalf("listen error: err=%s", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue