fix: use header flags in wsproxy server (#10985)

This commit is contained in:
Dean Sheather 2023-12-05 02:13:42 -08:00 committed by GitHub
parent b07b40b346
commit 695f57f7ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 127 additions and 46 deletions

View File

@ -419,9 +419,9 @@ func (r *RootCmd) scaletestCleanup() *clibase.Cmd {
}
client.HTTPClient = &http.Client{
Transport: &headerTransport{
transport: http.DefaultTransport,
header: map[string][]string{
Transport: &codersdk.HeaderTransport{
Transport: http.DefaultTransport,
Header: map[string][]string{
codersdk.BypassRatelimitHeader: {"true"},
},
},
@ -570,9 +570,9 @@ func (r *RootCmd) scaletestCreateWorkspaces() *clibase.Cmd {
}
client.HTTPClient = &http.Client{
Transport: &headerTransport{
transport: http.DefaultTransport,
header: map[string][]string{
Transport: &codersdk.HeaderTransport{
Transport: http.DefaultTransport,
Header: map[string][]string{
codersdk.BypassRatelimitHeader: {"true"},
},
},
@ -896,9 +896,9 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *clibase.Cmd {
// Bypass rate limiting
client.HTTPClient = &http.Client{
Transport: &headerTransport{
transport: http.DefaultTransport,
header: map[string][]string{
Transport: &codersdk.HeaderTransport{
Transport: http.DefaultTransport,
Header: map[string][]string{
codersdk.BypassRatelimitHeader: {"true"},
},
},

View File

@ -471,11 +471,11 @@ type RootCmd struct {
}
func addTelemetryHeader(client *codersdk.Client, inv *clibase.Invocation) {
transport, ok := client.HTTPClient.Transport.(*headerTransport)
transport, ok := client.HTTPClient.Transport.(*codersdk.HeaderTransport)
if !ok {
transport = &headerTransport{
transport: client.HTTPClient.Transport,
header: http.Header{},
transport = &codersdk.HeaderTransport{
Transport: client.HTTPClient.Transport,
Header: http.Header{},
}
client.HTTPClient.Transport = transport
}
@ -509,7 +509,7 @@ func addTelemetryHeader(client *codersdk.Client, inv *clibase.Invocation) {
return
}
transport.header.Add(codersdk.CLITelemetryHeader, s)
transport.Header.Add(codersdk.CLITelemetryHeader, s)
}
// InitClient sets client to a new client.
@ -609,10 +609,10 @@ func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing
}
}
func (r *RootCmd) setClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL) error {
transport := &headerTransport{
transport: http.DefaultTransport,
header: http.Header{},
func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*codersdk.HeaderTransport, error) {
transport := &codersdk.HeaderTransport{
Transport: http.DefaultTransport,
Header: http.Header{},
}
headers := r.header
if r.headerCommand != "" {
@ -630,23 +630,32 @@ func (r *RootCmd) setClient(ctx context.Context, client *codersdk.Client, server
cmd.Stderr = io.Discard
err := cmd.Run()
if err != nil {
return xerrors.Errorf("failed to run %v: %w", cmd.Args, err)
return nil, xerrors.Errorf("failed to run %v: %w", cmd.Args, err)
}
scanner := bufio.NewScanner(&outBuf)
for scanner.Scan() {
headers = append(headers, scanner.Text())
}
if err := scanner.Err(); err != nil {
return xerrors.Errorf("scan %v: %w", cmd.Args, err)
return nil, xerrors.Errorf("scan %v: %w", cmd.Args, err)
}
}
for _, header := range headers {
parts := strings.SplitN(header, "=", 2)
if len(parts) < 2 {
return xerrors.Errorf("split header %q had less than two parts", header)
return nil, xerrors.Errorf("split header %q had less than two parts", header)
}
transport.header.Add(parts[0], parts[1])
transport.Header.Add(parts[0], parts[1])
}
return transport, nil
}
func (r *RootCmd) setClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL) error {
transport, err := r.HeaderTransport(ctx, serverURL)
if err != nil {
return xerrors.Errorf("create header transport: %w", err)
}
client.URL = serverURL
client.HTTPClient = &http.Client{
Transport: transport,
@ -853,24 +862,6 @@ func (r *RootCmd) Verbosef(inv *clibase.Invocation, fmtStr string, args ...inter
}
}
type headerTransport struct {
transport http.RoundTripper
header http.Header
}
func (h *headerTransport) Header() http.Header {
return h.header.Clone()
}
func (h *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
for k, v := range h.header {
for _, vv := range v {
req.Header.Add(k, vv)
}
}
return h.transport.RoundTrip(req)
}
// DumpHandler provides a custom SIGQUIT and SIGTRAP handler that dumps the
// stacktrace of all goroutines to stderr and a well-known file in the home
// directory. This is useful for debugging deadlock issues that may occur in

View File

@ -530,3 +530,33 @@ func WithQueryParam(key, value string) RequestOption {
r.URL.RawQuery = q.Encode()
}
}
// HeaderTransport is a http.RoundTripper that adds some headers to all requests.
// @typescript-ignore HeaderTransport
type HeaderTransport struct {
Transport http.RoundTripper
Header http.Header
}
var _ http.RoundTripper = &HeaderTransport{}
func (h *HeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
for k, v := range h.Header {
for _, vv := range v {
req.Header.Add(k, vv)
}
}
if h.Transport == nil {
h.Transport = http.DefaultTransport
}
return h.Transport.RoundTrip(req)
}
func (h *HeaderTransport) CloseIdleConnections() {
type closeIdler interface {
CloseIdleConnections()
}
if tr, ok := h.Transport.(closeIdler); ok {
tr.CloseIdleConnections()
}
}

View File

@ -273,11 +273,8 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
ip := tailnet.IP()
var header http.Header
headerTransport, ok := c.HTTPClient.Transport.(interface {
Header() http.Header
})
if ok {
header = headerTransport.Header()
if headerTransport, ok := c.HTTPClient.Transport.(*HeaderTransport); ok {
header = headerTransport.Header
}
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},

View File

@ -43,7 +43,7 @@ func (c *closers) Add(f func()) {
*c = append(*c, f)
}
func (*RootCmd) proxyServer() *clibase.Cmd {
func (r *RootCmd) proxyServer() *clibase.Cmd {
var (
cfg = new(codersdk.DeploymentValues)
// Filter options for only relevant ones.
@ -192,6 +192,15 @@ func (*RootCmd) proxyServer() *clibase.Cmd {
defer httpClient.CloseIdleConnections()
closers.Add(httpClient.CloseIdleConnections)
// Attach header transport so we process --header and
// --header-command flags
headerTransport, err := r.HeaderTransport(ctx, primaryAccessURL.Value())
if err != nil {
return xerrors.Errorf("configure header transport: %w", err)
}
headerTransport.Transport = httpClient.Transport
httpClient.Transport = headerTransport
// A newline is added before for visibility in terminal output.
cliui.Infof(inv.Stdout, "\nView the Web UI: %s", cfg.AccessURL.String())

View File

@ -0,0 +1,54 @@
package cli_test
import (
"fmt"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/pty/ptytest"
)
func Test_Headers(t *testing.T) {
t.Parallel()
const (
headerName1 = "X-Test-Header-1"
headerVal1 = "test-value-1"
headerName2 = "X-Test-Header-2"
headerVal2 = "test-value-2"
)
// We're not going to actually start a proxy, we're going to point it
// towards a fake server that returns an unexpected status code. This'll
// cause the proxy to exit with an error that we can check for.
var called int64
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt64(&called, 1)
assert.Equal(t, headerVal1, r.Header.Get(headerName1))
assert.Equal(t, headerVal2, r.Header.Get(headerName2))
w.WriteHeader(http.StatusTeapot) // lol
}))
defer srv.Close()
inv, _ := newCLI(t, "wsproxy", "server",
"--primary-access-url", srv.URL,
"--proxy-session-token", "test-token",
"--access-url", "http://localhost:8080",
"--header", fmt.Sprintf("%s=%s", headerName1, headerVal1),
"--header-command", fmt.Sprintf("printf %s=%s", headerName2, headerVal2),
)
pty := ptytest.New(t)
inv.Stdout = pty.Output()
err := inv.Run()
require.Error(t, err)
require.ErrorContains(t, err, "unexpected status code 418")
require.NoError(t, pty.Close())
assert.EqualValues(t, 1, atomic.LoadInt64(&called))
}