mirror of https://github.com/coder/coder.git
feat: Ignore agent pprof port in listening ports (#6515)
* feat: Ignore agent pprof port in listening ports
This commit is contained in:
parent
3de29307b5
commit
2abae42cec
|
@ -77,6 +77,7 @@ type Options struct {
|
|||
ReconnectingPTYTimeout time.Duration
|
||||
EnvironmentVariables map[string]string
|
||||
Logger slog.Logger
|
||||
AgentPorts map[int]string
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
|
@ -123,6 +124,7 @@ func New(options Options) io.Closer {
|
|||
tempDir: options.TempDir,
|
||||
lifecycleUpdate: make(chan struct{}, 1),
|
||||
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
|
||||
ignorePorts: options.AgentPorts,
|
||||
connStatsChan: make(chan *agentsdk.Stats, 1),
|
||||
}
|
||||
a.init(ctx)
|
||||
|
@ -136,6 +138,10 @@ type agent struct {
|
|||
filesystem afero.Fs
|
||||
logDir string
|
||||
tempDir string
|
||||
// ignorePorts tells the api handler which ports to ignore when
|
||||
// listing all listening ports. This is helpful to hide ports that
|
||||
// are used by the agent, that the user does not care about.
|
||||
ignorePorts map[int]string
|
||||
|
||||
reconnectingPTYs sync.Map
|
||||
reconnectingPTYTimeout time.Duration
|
||||
|
|
18
agent/api.go
18
agent/api.go
|
@ -11,7 +11,7 @@ import (
|
|||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
func (*agent) apiHandler() http.Handler {
|
||||
func (a *agent) apiHandler() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
|
||||
|
@ -19,16 +19,24 @@ func (*agent) apiHandler() http.Handler {
|
|||
})
|
||||
})
|
||||
|
||||
lp := &listeningPortsHandler{}
|
||||
// Make a copy to ensure the map is not modified after the handler is
|
||||
// created.
|
||||
cpy := make(map[int]string)
|
||||
for k, b := range a.ignorePorts {
|
||||
cpy[k] = b
|
||||
}
|
||||
|
||||
lp := &listeningPortsHandler{ignorePorts: cpy}
|
||||
r.Get("/api/v0/listening-ports", lp.handler)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
type listeningPortsHandler struct {
|
||||
mut sync.Mutex
|
||||
ports []codersdk.WorkspaceAgentListeningPort
|
||||
mtime time.Time
|
||||
mut sync.Mutex
|
||||
ports []codersdk.WorkspaceAgentListeningPort
|
||||
mtime time.Time
|
||||
ignorePorts map[int]string
|
||||
}
|
||||
|
||||
// handler returns a list of listening ports. This is tested by coderd's
|
||||
|
|
|
@ -36,6 +36,11 @@ func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentL
|
|||
continue
|
||||
}
|
||||
|
||||
// Ignore ports that we've been told to ignore.
|
||||
if _, ok := lp.ignorePorts[int(tab.LocalAddr.Port)]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Don't include ports that we've already seen. This can happen on
|
||||
// Windows, and maybe on Linux if you're using a shared listener socket.
|
||||
if _, ok := seen[tab.LocalAddr.Port]; ok {
|
||||
|
|
39
cli/agent.go
39
cli/agent.go
|
@ -11,6 +11,7 @@ import (
|
|||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -51,6 +52,7 @@ func workspaceAgent() *cobra.Command {
|
|||
if err != nil {
|
||||
return xerrors.Errorf("parse %q: %w", rawURL, err)
|
||||
}
|
||||
agentPorts := map[int]string{}
|
||||
|
||||
isLinux := runtime.GOOS == "linux"
|
||||
|
||||
|
@ -122,6 +124,10 @@ func workspaceAgent() *cobra.Command {
|
|||
_ = pprof.Handler
|
||||
pprofSrvClose := serveHandler(ctx, logger, nil, pprofAddress, "pprof")
|
||||
defer pprofSrvClose()
|
||||
// Do a best effort here. If this fails, it's not a big deal.
|
||||
if port, err := urlPort(pprofAddress); err == nil {
|
||||
agentPorts[port] = "pprof"
|
||||
}
|
||||
|
||||
// exchangeToken returns a session token.
|
||||
// This is abstracted to allow for the same looping condition
|
||||
|
@ -202,6 +208,7 @@ func workspaceAgent() *cobra.Command {
|
|||
EnvironmentVariables: map[string]string{
|
||||
"GIT_ASKPASS": executablePath,
|
||||
},
|
||||
AgentPorts: agentPorts,
|
||||
})
|
||||
<-ctx.Done()
|
||||
return closer.Close()
|
||||
|
@ -264,3 +271,35 @@ func (c *closeWriter) Write(p []byte) (int, error) {
|
|||
}
|
||||
return c.w.Write(p)
|
||||
}
|
||||
|
||||
// extractPort handles different url strings.
|
||||
// - localhost:6060
|
||||
// - http://localhost:6060
|
||||
func extractPort(u string) (int, error) {
|
||||
port, firstError := urlPort(u)
|
||||
if firstError == nil {
|
||||
return port, nil
|
||||
}
|
||||
|
||||
// Try with a scheme
|
||||
port, err := urlPort("http://" + u)
|
||||
if err == nil {
|
||||
return port, nil
|
||||
}
|
||||
return -1, xerrors.Errorf("invalid url %q: %w", u, firstError)
|
||||
}
|
||||
|
||||
// urlPort extracts the port from a valid URL.
|
||||
func urlPort(u string) (int, error) {
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
return -1, xerrors.Errorf("invalid url %q: %w", u, err)
|
||||
}
|
||||
if parsed.Port() != "" {
|
||||
port, err := strconv.ParseInt(parsed.Port(), 10, 64)
|
||||
if err == nil && port > 0 {
|
||||
return int(port), nil
|
||||
}
|
||||
}
|
||||
return -1, xerrors.Errorf("invalid port: %s", u)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_extractPort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
urlString string
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Empty",
|
||||
urlString: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "NoScheme",
|
||||
urlString: "localhost:6060",
|
||||
want: 6060,
|
||||
},
|
||||
{
|
||||
name: "WithScheme",
|
||||
urlString: "http://localhost:6060",
|
||||
want: 6060,
|
||||
},
|
||||
{
|
||||
name: "NoPort",
|
||||
urlString: "http://localhost",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "NoPortNoScheme",
|
||||
urlString: "localhost",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "OnlyPort",
|
||||
urlString: "6060",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := extractPort(tt.urlString)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err, fmt.Sprintf("extractPort(%v)", tt.urlString))
|
||||
} else {
|
||||
require.NoError(t, err, fmt.Sprintf("extractPort(%v)", tt.urlString))
|
||||
require.Equal(t, tt.want, got, fmt.Sprintf("extractPort(%v)", tt.urlString))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue