feat: Ignore agent pprof port in listening ports (#6515)

* feat: Ignore agent pprof port in listening ports
This commit is contained in:
Steven Masley 2023-03-09 10:53:00 -06:00 committed by GitHub
parent 3de29307b5
commit 2abae42cec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 126 additions and 5 deletions

View File

@ -77,6 +77,7 @@ type Options struct {
ReconnectingPTYTimeout time.Duration
EnvironmentVariables map[string]string
Logger slog.Logger
AgentPorts map[int]string
}
type Client interface {
@ -123,6 +124,7 @@ func New(options Options) io.Closer {
tempDir: options.TempDir,
lifecycleUpdate: make(chan struct{}, 1),
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
ignorePorts: options.AgentPorts,
connStatsChan: make(chan *agentsdk.Stats, 1),
}
a.init(ctx)
@ -136,6 +138,10 @@ type agent struct {
filesystem afero.Fs
logDir string
tempDir string
// ignorePorts tells the api handler which ports to ignore when
// listing all listening ports. This is helpful to hide ports that
// are used by the agent, that the user does not care about.
ignorePorts map[int]string
reconnectingPTYs sync.Map
reconnectingPTYTimeout time.Duration

View File

@ -11,7 +11,7 @@ import (
"github.com/coder/coder/codersdk"
)
func (*agent) apiHandler() http.Handler {
func (a *agent) apiHandler() http.Handler {
r := chi.NewRouter()
r.Get("/", func(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
@ -19,16 +19,24 @@ func (*agent) apiHandler() http.Handler {
})
})
lp := &listeningPortsHandler{}
// Make a copy to ensure the map is not modified after the handler is
// created.
cpy := make(map[int]string)
for k, b := range a.ignorePorts {
cpy[k] = b
}
lp := &listeningPortsHandler{ignorePorts: cpy}
r.Get("/api/v0/listening-ports", lp.handler)
return r
}
type listeningPortsHandler struct {
mut sync.Mutex
ports []codersdk.WorkspaceAgentListeningPort
mtime time.Time
mut sync.Mutex
ports []codersdk.WorkspaceAgentListeningPort
mtime time.Time
ignorePorts map[int]string
}
// handler returns a list of listening ports. This is tested by coderd's

View File

@ -36,6 +36,11 @@ func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentL
continue
}
// Ignore ports that we've been told to ignore.
if _, ok := lp.ignorePorts[int(tab.LocalAddr.Port)]; ok {
continue
}
// Don't include ports that we've already seen. This can happen on
// Windows, and maybe on Linux if you're using a shared listener socket.
if _, ok := seen[tab.LocalAddr.Port]; ok {

View File

@ -11,6 +11,7 @@ import (
"os/signal"
"path/filepath"
"runtime"
"strconv"
"sync"
"time"
@ -51,6 +52,7 @@ func workspaceAgent() *cobra.Command {
if err != nil {
return xerrors.Errorf("parse %q: %w", rawURL, err)
}
agentPorts := map[int]string{}
isLinux := runtime.GOOS == "linux"
@ -122,6 +124,10 @@ func workspaceAgent() *cobra.Command {
_ = pprof.Handler
pprofSrvClose := serveHandler(ctx, logger, nil, pprofAddress, "pprof")
defer pprofSrvClose()
// Do a best effort here. If this fails, it's not a big deal.
if port, err := urlPort(pprofAddress); err == nil {
agentPorts[port] = "pprof"
}
// exchangeToken returns a session token.
// This is abstracted to allow for the same looping condition
@ -202,6 +208,7 @@ func workspaceAgent() *cobra.Command {
EnvironmentVariables: map[string]string{
"GIT_ASKPASS": executablePath,
},
AgentPorts: agentPorts,
})
<-ctx.Done()
return closer.Close()
@ -264,3 +271,35 @@ func (c *closeWriter) Write(p []byte) (int, error) {
}
return c.w.Write(p)
}
// extractPort handles different url strings.
// - localhost:6060
// - http://localhost:6060
func extractPort(u string) (int, error) {
port, firstError := urlPort(u)
if firstError == nil {
return port, nil
}
// Try with a scheme
port, err := urlPort("http://" + u)
if err == nil {
return port, nil
}
return -1, xerrors.Errorf("invalid url %q: %w", u, firstError)
}
// urlPort extracts the port from a valid URL.
func urlPort(u string) (int, error) {
parsed, err := url.Parse(u)
if err != nil {
return -1, xerrors.Errorf("invalid url %q: %w", u, err)
}
if parsed.Port() != "" {
port, err := strconv.ParseInt(parsed.Port(), 10, 64)
if err == nil && port > 0 {
return int(port), nil
}
}
return -1, xerrors.Errorf("invalid port: %s", u)
}

View File

@ -0,0 +1,63 @@
package cli
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func Test_extractPort(t *testing.T) {
t.Parallel()
tests := []struct {
name string
urlString string
want int
wantErr bool
}{
{
name: "Empty",
urlString: "",
wantErr: true,
},
{
name: "NoScheme",
urlString: "localhost:6060",
want: 6060,
},
{
name: "WithScheme",
urlString: "http://localhost:6060",
want: 6060,
},
{
name: "NoPort",
urlString: "http://localhost",
wantErr: true,
},
{
name: "NoPortNoScheme",
urlString: "localhost",
wantErr: true,
},
{
name: "OnlyPort",
urlString: "6060",
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := extractPort(tt.urlString)
if tt.wantErr {
require.Error(t, err, fmt.Sprintf("extractPort(%v)", tt.urlString))
} else {
require.NoError(t, err, fmt.Sprintf("extractPort(%v)", tt.urlString))
require.Equal(t, tt.want, got, fmt.Sprintf("extractPort(%v)", tt.urlString))
}
})
}
}