mirror of https://github.com/coder/coder.git
fix: use header flags in wsproxy server (#10985)
This commit is contained in:
parent
b07b40b346
commit
695f57f7ff
|
@ -419,9 +419,9 @@ func (r *RootCmd) scaletestCleanup() *clibase.Cmd {
|
|||
}
|
||||
|
||||
client.HTTPClient = &http.Client{
|
||||
Transport: &headerTransport{
|
||||
transport: http.DefaultTransport,
|
||||
header: map[string][]string{
|
||||
Transport: &codersdk.HeaderTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
Header: map[string][]string{
|
||||
codersdk.BypassRatelimitHeader: {"true"},
|
||||
},
|
||||
},
|
||||
|
@ -570,9 +570,9 @@ func (r *RootCmd) scaletestCreateWorkspaces() *clibase.Cmd {
|
|||
}
|
||||
|
||||
client.HTTPClient = &http.Client{
|
||||
Transport: &headerTransport{
|
||||
transport: http.DefaultTransport,
|
||||
header: map[string][]string{
|
||||
Transport: &codersdk.HeaderTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
Header: map[string][]string{
|
||||
codersdk.BypassRatelimitHeader: {"true"},
|
||||
},
|
||||
},
|
||||
|
@ -896,9 +896,9 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd {
|
|||
|
||||
// Bypass rate limiting
|
||||
client.HTTPClient = &http.Client{
|
||||
Transport: &headerTransport{
|
||||
transport: http.DefaultTransport,
|
||||
header: map[string][]string{
|
||||
Transport: &codersdk.HeaderTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
Header: map[string][]string{
|
||||
codersdk.BypassRatelimitHeader: {"true"},
|
||||
},
|
||||
},
|
||||
|
|
53
cli/root.go
53
cli/root.go
|
@ -471,11 +471,11 @@ type RootCmd struct {
|
|||
}
|
||||
|
||||
func addTelemetryHeader(client *codersdk.Client, inv *clibase.Invocation) {
|
||||
transport, ok := client.HTTPClient.Transport.(*headerTransport)
|
||||
transport, ok := client.HTTPClient.Transport.(*codersdk.HeaderTransport)
|
||||
if !ok {
|
||||
transport = &headerTransport{
|
||||
transport: client.HTTPClient.Transport,
|
||||
header: http.Header{},
|
||||
transport = &codersdk.HeaderTransport{
|
||||
Transport: client.HTTPClient.Transport,
|
||||
Header: http.Header{},
|
||||
}
|
||||
client.HTTPClient.Transport = transport
|
||||
}
|
||||
|
@ -509,7 +509,7 @@ func addTelemetryHeader(client *codersdk.Client, inv *clibase.Invocation) {
|
|||
return
|
||||
}
|
||||
|
||||
transport.header.Add(codersdk.CLITelemetryHeader, s)
|
||||
transport.Header.Add(codersdk.CLITelemetryHeader, s)
|
||||
}
|
||||
|
||||
// InitClient sets client to a new client.
|
||||
|
@ -609,10 +609,10 @@ func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing
|
|||
}
|
||||
}
|
||||
|
||||
func (r *RootCmd) setClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL) error {
|
||||
transport := &headerTransport{
|
||||
transport: http.DefaultTransport,
|
||||
header: http.Header{},
|
||||
func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*codersdk.HeaderTransport, error) {
|
||||
transport := &codersdk.HeaderTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
Header: http.Header{},
|
||||
}
|
||||
headers := r.header
|
||||
if r.headerCommand != "" {
|
||||
|
@ -630,23 +630,32 @@ func (r *RootCmd) setClient(ctx context.Context, client *codersdk.Client, server
|
|||
cmd.Stderr = io.Discard
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to run %v: %w", cmd.Args, err)
|
||||
return nil, xerrors.Errorf("failed to run %v: %w", cmd.Args, err)
|
||||
}
|
||||
scanner := bufio.NewScanner(&outBuf)
|
||||
for scanner.Scan() {
|
||||
headers = append(headers, scanner.Text())
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return xerrors.Errorf("scan %v: %w", cmd.Args, err)
|
||||
return nil, xerrors.Errorf("scan %v: %w", cmd.Args, err)
|
||||
}
|
||||
}
|
||||
for _, header := range headers {
|
||||
parts := strings.SplitN(header, "=", 2)
|
||||
if len(parts) < 2 {
|
||||
return xerrors.Errorf("split header %q had less than two parts", header)
|
||||
return nil, xerrors.Errorf("split header %q had less than two parts", header)
|
||||
}
|
||||
transport.header.Add(parts[0], parts[1])
|
||||
transport.Header.Add(parts[0], parts[1])
|
||||
}
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
func (r *RootCmd) setClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL) error {
|
||||
transport, err := r.HeaderTransport(ctx, serverURL)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create header transport: %w", err)
|
||||
}
|
||||
|
||||
client.URL = serverURL
|
||||
client.HTTPClient = &http.Client{
|
||||
Transport: transport,
|
||||
|
@ -853,24 +862,6 @@ func (r *RootCmd) Verbosef(inv *clibase.Invocation, fmtStr string, args ...inter
|
|||
}
|
||||
}
|
||||
|
||||
type headerTransport struct {
|
||||
transport http.RoundTripper
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func (h *headerTransport) Header() http.Header {
|
||||
return h.header.Clone()
|
||||
}
|
||||
|
||||
func (h *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
for k, v := range h.header {
|
||||
for _, vv := range v {
|
||||
req.Header.Add(k, vv)
|
||||
}
|
||||
}
|
||||
return h.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// DumpHandler provides a custom SIGQUIT and SIGTRAP handler that dumps the
|
||||
// stacktrace of all goroutines to stderr and a well-known file in the home
|
||||
// directory. This is useful for debugging deadlock issues that may occur in
|
||||
|
|
|
@ -530,3 +530,33 @@ func WithQueryParam(key, value string) RequestOption {
|
|||
r.URL.RawQuery = q.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
// HeaderTransport is a http.RoundTripper that adds some headers to all requests.
|
||||
// @typescript-ignore HeaderTransport
|
||||
type HeaderTransport struct {
|
||||
Transport http.RoundTripper
|
||||
Header http.Header
|
||||
}
|
||||
|
||||
var _ http.RoundTripper = &HeaderTransport{}
|
||||
|
||||
func (h *HeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
for k, v := range h.Header {
|
||||
for _, vv := range v {
|
||||
req.Header.Add(k, vv)
|
||||
}
|
||||
}
|
||||
if h.Transport == nil {
|
||||
h.Transport = http.DefaultTransport
|
||||
}
|
||||
return h.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (h *HeaderTransport) CloseIdleConnections() {
|
||||
type closeIdler interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
if tr, ok := h.Transport.(closeIdler); ok {
|
||||
tr.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -273,11 +273,8 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
|
|||
|
||||
ip := tailnet.IP()
|
||||
var header http.Header
|
||||
headerTransport, ok := c.HTTPClient.Transport.(interface {
|
||||
Header() http.Header
|
||||
})
|
||||
if ok {
|
||||
header = headerTransport.Header()
|
||||
if headerTransport, ok := c.HTTPClient.Transport.(*HeaderTransport); ok {
|
||||
header = headerTransport.Header
|
||||
}
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
|
||||
|
|
|
@ -43,7 +43,7 @@ func (c *closers) Add(f func()) {
|
|||
*c = append(*c, f)
|
||||
}
|
||||
|
||||
func (*RootCmd) proxyServer() *clibase.Cmd {
|
||||
func (r *RootCmd) proxyServer() *clibase.Cmd {
|
||||
var (
|
||||
cfg = new(codersdk.DeploymentValues)
|
||||
// Filter options for only relevant ones.
|
||||
|
@ -192,6 +192,15 @@ func (*RootCmd) proxyServer() *clibase.Cmd {
|
|||
defer httpClient.CloseIdleConnections()
|
||||
closers.Add(httpClient.CloseIdleConnections)
|
||||
|
||||
// Attach header transport so we process --header and
|
||||
// --header-command flags
|
||||
headerTransport, err := r.HeaderTransport(ctx, primaryAccessURL.Value())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("configure header transport: %w", err)
|
||||
}
|
||||
headerTransport.Transport = httpClient.Transport
|
||||
httpClient.Transport = headerTransport
|
||||
|
||||
// A newline is added before for visibility in terminal output.
|
||||
cliui.Infof(inv.Stdout, "\nView the Web UI: %s", cfg.AccessURL.String())
|
||||
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
package cli_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
)
|
||||
|
||||
func Test_Headers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const (
|
||||
headerName1 = "X-Test-Header-1"
|
||||
headerVal1 = "test-value-1"
|
||||
headerName2 = "X-Test-Header-2"
|
||||
headerVal2 = "test-value-2"
|
||||
)
|
||||
|
||||
// We're not going to actually start a proxy, we're going to point it
|
||||
// towards a fake server that returns an unexpected status code. This'll
|
||||
// cause the proxy to exit with an error that we can check for.
|
||||
var called int64
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt64(&called, 1)
|
||||
assert.Equal(t, headerVal1, r.Header.Get(headerName1))
|
||||
assert.Equal(t, headerVal2, r.Header.Get(headerName2))
|
||||
|
||||
w.WriteHeader(http.StatusTeapot) // lol
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
inv, _ := newCLI(t, "wsproxy", "server",
|
||||
"--primary-access-url", srv.URL,
|
||||
"--proxy-session-token", "test-token",
|
||||
"--access-url", "http://localhost:8080",
|
||||
"--header", fmt.Sprintf("%s=%s", headerName1, headerVal1),
|
||||
"--header-command", fmt.Sprintf("printf %s=%s", headerName2, headerVal2),
|
||||
)
|
||||
pty := ptytest.New(t)
|
||||
inv.Stdout = pty.Output()
|
||||
err := inv.Run()
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "unexpected status code 418")
|
||||
require.NoError(t, pty.Close())
|
||||
|
||||
assert.EqualValues(t, 1, atomic.LoadInt64(&called))
|
||||
}
|
Loading…
Reference in New Issue