diff --git a/agent/agent_test.go b/agent/agent_test.go index 6cef939d47..19e28346ad 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -1,6 +1,7 @@ package agent_test import ( + "bufio" "bytes" "context" "encoding/json" @@ -152,7 +153,7 @@ func TestAgent_Stats_Magic(t *testing.T) { require.NoError(t, err) require.Equal(t, expected, strings.TrimSpace(string(output))) }) - t.Run("Tracks", func(t *testing.T) { + t.Run("TracksVSCode", func(t *testing.T) { t.Parallel() if runtime.GOOS == "window" { t.Skip("Sleeping for infinity doesn't work on Windows") @@ -191,6 +192,77 @@ func TestAgent_Stats_Magic(t *testing.T) { err = session.Wait() require.NoError(t, err) }) + + t.Run("TracksJetBrains", func(t *testing.T) { + t.Parallel() + if runtime.GOOS != "linux" { + t.Skip("JetBrains tracking is only supported on Linux") + } + + ctx := testutil.Context(t, testutil.WaitLong) + + // JetBrains tracking works by looking at the process name listening on the + // forwarded port. If the process's command line includes the magic string + // we are looking for, then we assume it is a JetBrains editor. So when we + // connect to the port we must ensure the process includes that magic string + // to fool the agent into thinking this is JetBrains. To do this we need to + // spawn an external process (in this case a simple echo server) so we can + // control the process name. The -D here is just to mimic how Java options + // are set but is not necessary as the agent looks only for the magic + // string itself anywhere in the command. + _, b, _, ok := runtime.Caller(0) + require.True(t, ok) + dir := filepath.Join(filepath.Dir(b), "../scripts/echoserver/main.go") + echoServerCmd := exec.Command("go", "run", dir, + "-D", agentssh.MagicProcessCmdlineJetBrains) + stdout, err := echoServerCmd.StdoutPipe() + require.NoError(t, err) + err = echoServerCmd.Start() + require.NoError(t, err) + defer echoServerCmd.Process.Kill() + + // The echo server prints its port as the first line. + sc := bufio.NewScanner(stdout) + sc.Scan() + remotePort := sc.Text() + + //nolint:dogsled + conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + sshClient, err := conn.SSHClient(ctx) + require.NoError(t, err) + + tunneledConn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", remotePort)) + require.NoError(t, err) + t.Cleanup(func() { + // always close on failure of test + _ = conn.Close() + _ = tunneledConn.Close() + }) + + var s *agentsdk.Stats + require.Eventuallyf(t, func() bool { + var ok bool + s, ok = <-stats + return ok && s.ConnectionCount > 0 && + s.SessionCountJetBrains == 1 + }, testutil.WaitLong, testutil.IntervalFast, + "never saw stats with conn open: %+v", s, + ) + + // Kill the server and connection after checking for the echo. + requireEcho(t, tunneledConn) + _ = echoServerCmd.Process.Kill() + _ = tunneledConn.Close() + + require.Eventuallyf(t, func() bool { + var ok bool + s, ok = <-stats + return ok && s.ConnectionCount == 0 && + s.SessionCountJetBrains == 0 + }, testutil.WaitLong, testutil.IntervalFast, + "never saw stats after conn closes: %+v", s, + ) + }) } func TestAgent_SessionExec(t *testing.T) { diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index f88446ecf3..1021d04592 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -47,8 +47,12 @@ const ( MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE" // MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself. MagicSessionTypeVSCode = "vscode" - // MagicSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself. + // MagicSessionTypeJetBrains is set in the SSH config by the JetBrains + // extension to identify itself. MagicSessionTypeJetBrains = "jetbrains" + // MagicProcessCmdlineJetBrains is a string in a process's command line that + // uniquely identifies it as JetBrains software. + MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains" ) type Server struct { @@ -111,7 +115,11 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom srv := &ssh.Server{ ChannelHandlers: map[string]ssh.ChannelHandler{ - "direct-tcpip": ssh.DirectTCPIPHandler, + "direct-tcpip": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { + // Wrapper is designed to find and track JetBrains Gateway connections. + wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, newChan, &s.connCountJetBrains) + ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx) + }, "direct-streamlocal@openssh.com": directStreamLocalHandler, "session": ssh.DefaultSessionHandler, }, @@ -291,8 +299,8 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv s.connCountVSCode.Add(1) defer s.connCountVSCode.Add(-1) case MagicSessionTypeJetBrains: - s.connCountJetBrains.Add(1) - defer s.connCountJetBrains.Add(-1) + // Do nothing here because JetBrains launches hundreds of ssh sessions. + // We instead track JetBrains in the single persistent tcp forwarding channel. case "": s.connCountSSHSession.Add(1) defer s.connCountSSHSession.Add(-1) diff --git a/agent/agentssh/jetbrainstrack.go b/agent/agentssh/jetbrainstrack.go new file mode 100644 index 0000000000..25c8f04dd6 --- /dev/null +++ b/agent/agentssh/jetbrainstrack.go @@ -0,0 +1,89 @@ +package agentssh + +import ( + "strings" + "sync" + + "cdr.dev/slog" + "github.com/gliderlabs/ssh" + "go.uber.org/atomic" + gossh "golang.org/x/crypto/ssh" +) + +// localForwardChannelData is copied from the ssh package. +type localForwardChannelData struct { + DestAddr string + DestPort uint32 + + OriginAddr string + OriginPort uint32 +} + +// JetbrainsChannelWatcher is used to track JetBrains port forwarded (Gateway) +// channels. If the port forward is something other than JetBrains, this struct +// is a noop. +type JetbrainsChannelWatcher struct { + gossh.NewChannel + jetbrainsCounter *atomic.Int64 +} + +func NewJetbrainsChannelWatcher(ctx ssh.Context, logger slog.Logger, newChannel gossh.NewChannel, counter *atomic.Int64) gossh.NewChannel { + d := localForwardChannelData{} + if err := gossh.Unmarshal(newChannel.ExtraData(), &d); err != nil { + // If the data fails to unmarshal, do nothing. + logger.Warn(ctx, "failed to unmarshal port forward data", slog.Error(err)) + return newChannel + } + + // If we do get a port, we should be able to get the matching PID and from + // there look up the invocation. + cmdline, err := getListeningPortProcessCmdline(d.DestPort) + if err != nil { + logger.Warn(ctx, "failed to inspect port", + slog.F("destination_port", d.DestPort), + slog.Error(err)) + return newChannel + } + + // If this is not JetBrains, then we do not need to do anything special. We + // attempt to match on something that appears unique to JetBrains software. + if !strings.Contains(strings.ToLower(cmdline), strings.ToLower(MagicProcessCmdlineJetBrains)) { + return newChannel + } + + logger.Debug(ctx, "discovered forwarded JetBrains process", + slog.F("destination_port", d.DestPort)) + + return &JetbrainsChannelWatcher{ + NewChannel: newChannel, + jetbrainsCounter: counter, + } +} + +func (w *JetbrainsChannelWatcher) Accept() (gossh.Channel, <-chan *gossh.Request, error) { + c, r, err := w.NewChannel.Accept() + if err != nil { + return c, r, err + } + w.jetbrainsCounter.Add(1) + + return &ChannelOnClose{ + Channel: c, + done: func() { + w.jetbrainsCounter.Add(-1) + }, + }, r, err +} + +type ChannelOnClose struct { + gossh.Channel + // once ensures close only decrements the counter once. + // Because close can be called multiple times. + once sync.Once + done func() +} + +func (c *ChannelOnClose) Close() error { + c.once.Do(c.done) + return c.Channel.Close() +} diff --git a/agent/agentssh/portinspection_supported.go b/agent/agentssh/portinspection_supported.go new file mode 100644 index 0000000000..45f59accc4 --- /dev/null +++ b/agent/agentssh/portinspection_supported.go @@ -0,0 +1,31 @@ +//go:build linux + +package agentssh + +import ( + "fmt" + "os" + + "github.com/cakturk/go-netstat/netstat" + "golang.org/x/xerrors" +) + +func getListeningPortProcessCmdline(port uint32) (string, error) { + tabs, err := netstat.TCPSocks(func(s *netstat.SockTabEntry) bool { + return s.LocalAddr != nil && uint32(s.LocalAddr.Port) == port + }) + if err != nil { + return "", xerrors.Errorf("inspect port %d: %w", port, err) + } + if len(tabs) == 0 { + return "", nil + } + // The process name provided by go-netstat does not include the full command + // line so grab that instead. + pid := tabs[0].Process.Pid + data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid)) + if err != nil { + return "", xerrors.Errorf("read /proc/%d/cmdline: %w", pid, err) + } + return string(data), nil +} diff --git a/agent/agentssh/portinspection_unsupported.go b/agent/agentssh/portinspection_unsupported.go new file mode 100644 index 0000000000..f010d03858 --- /dev/null +++ b/agent/agentssh/portinspection_unsupported.go @@ -0,0 +1,9 @@ +//go:build !linux + +package agentssh + +func getListeningPortProcessCmdline(port uint32) (string, error) { + // We are not worrying about other platforms at the moment because Gateway + // only supports Linux anyway. + return "", nil +} diff --git a/scripts/echoserver/main.go b/scripts/echoserver/main.go new file mode 100644 index 0000000000..cb30a0b383 --- /dev/null +++ b/scripts/echoserver/main.go @@ -0,0 +1,50 @@ +package main + +// A simple echo server. It listens on a random port, prints that port, then +// echos back anything sent to it. + +import ( + "errors" + "fmt" + "io" + "log" + "net" +) + +func main() { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + log.Fatalf("listen error: err=%s", err) + } + + defer l.Close() + tcpAddr, valid := l.Addr().(*net.TCPAddr) + if !valid { + log.Fatal("address is not valid") + } + + remotePort := tcpAddr.Port + _, err = fmt.Println(remotePort) + if err != nil { + log.Fatalf("print error: err=%s", err) + } + + for { + conn, err := l.Accept() + if err != nil { + log.Fatalf("accept error, err=%s", err) + return + } + + go func() { + defer conn.Close() + _, err := io.Copy(conn, conn) + + if errors.Is(err, io.EOF) { + return + } else if err != nil { + log.Fatalf("copy error, err=%s", err) + } + }() + } +} diff --git a/site/src/components/Dashboard/DeploymentBanner/DeploymentBannerView.tsx b/site/src/components/Dashboard/DeploymentBanner/DeploymentBannerView.tsx index a302780468..8c5b154b4e 100644 --- a/site/src/components/Dashboard/DeploymentBanner/DeploymentBannerView.tsx +++ b/site/src/components/Dashboard/DeploymentBanner/DeploymentBannerView.tsx @@ -15,6 +15,7 @@ import BuildingIcon from "@mui/icons-material/Build"; import Tooltip from "@mui/material/Tooltip"; import { Link as RouterLink } from "react-router-dom"; import Link from "@mui/material/Link"; +import { JetBrainsIcon } from "components/Icons/JetBrainsIcon"; import { VSCodeIcon } from "components/Icons/VSCodeIcon"; import DownloadIcon from "@mui/icons-material/CloudDownload"; import UploadIcon from "@mui/icons-material/CloudUpload"; @@ -248,6 +249,21 @@ export const DeploymentBannerView: FC = ({ + +
+ + {typeof stats?.session_count.jetbrains === "undefined" + ? "-" + : stats?.session_count.jetbrains} +
+
+
diff --git a/site/src/components/Icons/JetBrainsIcon.tsx b/site/src/components/Icons/JetBrainsIcon.tsx new file mode 100644 index 0000000000..fb551a7e52 --- /dev/null +++ b/site/src/components/Icons/JetBrainsIcon.tsx @@ -0,0 +1,67 @@ +import SvgIcon, { SvgIconProps } from "@mui/material/SvgIcon"; + +export const JetBrainsIcon = (props: SvgIconProps) => ( + + + + + + + + + + + + + + + + + + + + + +);