fix: rewrite url to agent ip in single tailnet (#11810)

This restores the previous behavior of being able to cache connections
across agents in a single tailnet.
This commit is contained in:
Colin Adler 2024-02-01 00:25:52 -06:00 committed by GitHub
parent 073d1f7078
commit 3ace7982aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 216 additions and 40 deletions

View File

@ -99,15 +99,14 @@ func NewServerTailnet(
transport: tailnetTransport.Clone(),
}
tn.transport.DialContext = tn.dialContext
// Bugfix: for some reason all calls to tn.dialContext come from
// "localhost", causing connections to be cached and requests to go to the
// wrong workspaces. This disables keepalives for now until the root cause
// can be found.
tn.transport.MaxIdleConnsPerHost = -1
tn.transport.DisableKeepAlives = true
// These options are mostly just picked at random, and they can likely be
// fine tuned further. Generally, users are running applications in dev mode
// which can generate hundreds of requests per page load, so we increased
// MaxIdleConnsPerHost from 2 to 6 and removed the limit of total idle
// conns.
tn.transport.MaxIdleConnsPerHost = 6
tn.transport.MaxIdleConns = 0
tn.transport.IdleConnTimeout = 10 * time.Minute
// We intentionally don't verify the certificate chain here.
// The connection to the workspace is already established and most
// apps are already going to be accessed over plain HTTP, this config
@ -308,7 +307,15 @@ type ServerTailnet struct {
}
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) *httputil.ReverseProxy {
proxy := httputil.NewSingleHostReverseProxy(targetURL)
// Rewrite the targetURL's Host to point to the agent's IP. This is
// necessary because, with TCP connection caching, each agent needs to be
// addressed individually. Otherwise, all connections get dialed as
// "localhost:port", causing connections to be shared across agents.
tgt := *targetURL
_, port, _ := net.SplitHostPort(tgt.Host)
tgt.Host = net.JoinHostPort(tailnet.IPFromUUID(agentID).String(), port)
proxy := httputil.NewSingleHostReverseProxy(&tgt)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
Status: http.StatusBadGateway,

View File

@ -3,10 +3,13 @@ package coderd_test
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"strconv"
"sync/atomic"
"testing"
"github.com/google/uuid"
@ -35,9 +38,10 @@ func TestServerTailnet_AgentConn_OK(t *testing.T) {
defer cancel()
// Connect through the ServerTailnet
agentID, _, serverTailnet := setupAgent(t, nil)
agents, serverTailnet := setupServerTailnetAgent(t, 1)
a := agents[0]
conn, release, err := serverTailnet.AgentConn(ctx, agentID)
conn, release, err := serverTailnet.AgentConn(ctx, a.id)
require.NoError(t, err)
defer release()
@ -53,12 +57,13 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
agentID, _, serverTailnet := setupAgent(t, nil)
agents, serverTailnet := setupServerTailnetAgent(t, 1)
a := agents[0]
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
require.NoError(t, err)
rp := serverTailnet.ReverseProxy(u, u, agentID)
rp := serverTailnet.ReverseProxy(u, u, a.id)
rw := httptest.NewRecorder()
req := httptest.NewRequest(
@ -74,13 +79,147 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
assert.Equal(t, http.StatusOK, res.StatusCode)
})
t.Run("HostRewrite", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
agents, serverTailnet := setupServerTailnetAgent(t, 1)
a := agents[0]
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
require.NoError(t, err)
rp := serverTailnet.ReverseProxy(u, u, a.id)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
require.NoError(t, err)
// Ensure the reverse proxy director rewrites the url host to the agent's IP.
rp.Director(req)
assert.Equal(t,
fmt.Sprintf("[%s]:%d", tailnet.IPFromUUID(a.id).String(), codersdk.WorkspaceAgentHTTPAPIServerPort),
req.URL.Host,
)
})
t.Run("CachesConnection", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
agents, serverTailnet := setupServerTailnetAgent(t, 1)
a := agents[0]
port := ":4444"
ln, err := a.TailnetConn().Listen("tcp", port)
require.NoError(t, err)
wln := &wrappedListener{Listener: ln}
serverClosed := make(chan struct{})
go func() {
defer close(serverClosed)
//nolint:gosec
_ = http.Serve(wln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("hello from agent"))
}))
}()
defer func() {
// wait for server to close
<-serverClosed
}()
defer ln.Close()
u, err := url.Parse("http://127.0.0.1" + port)
require.NoError(t, err)
rp := serverTailnet.ReverseProxy(u, u, a.id)
for i := 0; i < 5; i++ {
rw := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
u.String(),
nil,
).WithContext(ctx)
rp.ServeHTTP(rw, req)
res := rw.Result()
_, _ = io.Copy(io.Discard, res.Body)
res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}
assert.Equal(t, 1, wln.getDials())
})
t.Run("NotReusedBetweenAgents", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
agents, serverTailnet := setupServerTailnetAgent(t, 2)
port := ":4444"
for i, ag := range agents {
i := i
ln, err := ag.TailnetConn().Listen("tcp", port)
require.NoError(t, err)
wln := &wrappedListener{Listener: ln}
serverClosed := make(chan struct{})
go func() {
defer close(serverClosed)
//nolint:gosec
_ = http.Serve(wln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(strconv.Itoa(i)))
}))
}()
defer func() { //nolint:revive
// wait for server to close
<-serverClosed
}()
defer ln.Close() //nolint:revive
}
u, err := url.Parse("http://127.0.0.1" + port)
require.NoError(t, err)
for i, ag := range agents {
rp := serverTailnet.ReverseProxy(u, u, ag.id)
rw := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
u.String(),
nil,
).WithContext(ctx)
rp.ServeHTTP(rw, req)
res := rw.Result()
body, _ := io.ReadAll(res.Body)
res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, strconv.Itoa(i), string(body))
}
})
t.Run("HTTPSProxy", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
agentID, _, serverTailnet := setupAgent(t, nil)
agents, serverTailnet := setupServerTailnetAgent(t, 1)
a := agents[0]
const expectedResponseCode = 209
// Test that we can proxy HTTPS traffic.
@ -92,7 +231,7 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
uri, err := url.Parse(s.URL)
require.NoError(t, err)
rp := serverTailnet.ReverseProxy(uri, uri, agentID)
rp := serverTailnet.ReverseProxy(uri, uri, a.id)
rw := httptest.NewRecorder()
req := httptest.NewRequest(
@ -109,44 +248,74 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
})
}
func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.Agent, *coderd.ServerTailnet) {
// wrappedListener decorates a net.Listener and keeps a running count of the
// connections it has successfully accepted. Tests read the count to check
// how many distinct connections were opened against the listener.
type wrappedListener struct {
	net.Listener

	// dials is the number of successful Accept calls. It is only touched
	// through sync/atomic so it can be read while the server goroutine is
	// still accepting.
	dials int32
}

// Accept forwards to the embedded listener. The counter is bumped only when
// a connection is actually returned; on failure the error is passed through
// unchanged together with a nil connection.
func (w *wrappedListener) Accept() (net.Conn, error) {
	accepted, acceptErr := w.Listener.Accept()
	if acceptErr == nil {
		atomic.AddInt32(&w.dials, 1)
		return accepted, nil
	}
	return nil, acceptErr
}

// getDials reports how many connections have been accepted so far.
func (w *wrappedListener) getDials() int {
	count := atomic.LoadInt32(&w.dials)
	return int(count)
}
type agentWithID struct {
id uuid.UUID
agent.Agent
}
func setupServerTailnetAgent(t *testing.T, agentNum int) ([]agentWithID, *coderd.ServerTailnet) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
derpMap, derpServer := tailnettest.RunDERPAndSTUN(t)
manifest := agentsdk.Manifest{
AgentID: uuid.New(),
DERPMap: derpMap,
}
coord := tailnet.NewCoordinator(logger)
t.Cleanup(func() {
_ = coord.Close()
})
c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
t.Cleanup(c.Close)
agents := []agentWithID{}
options := agent.Options{
Client: c,
Filesystem: afero.NewMemMapFs(),
Logger: logger.Named("agent"),
Addresses: agentAddresses,
for i := 0; i < agentNum; i++ {
manifest := agentsdk.Manifest{
AgentID: uuid.New(),
DERPMap: derpMap,
}
c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
t.Cleanup(c.Close)
options := agent.Options{
Client: c,
Filesystem: afero.NewMemMapFs(),
Logger: logger.Named("agent"),
}
ag := agent.New(options)
t.Cleanup(func() {
_ = ag.Close()
})
// Wait for the agent to connect.
require.Eventually(t, func() bool {
return coord.Node(manifest.AgentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
agents = append(agents, agentWithID{id: manifest.AgentID, Agent: ag})
}
ag := agent.New(options)
t.Cleanup(func() {
_ = ag.Close()
})
// Wait for the agent to connect.
require.Eventually(t, func() bool {
return coord.Node(manifest.AgentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
serverTailnet, err := coderd.NewServerTailnet(
context.Background(),
logger,
derpServer,
func() *tailcfg.DERPMap { return manifest.DERPMap },
func() *tailcfg.DERPMap { return derpMap },
false,
func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil },
trace.NewNoopTracerProvider(),
@ -157,5 +326,5 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
_ = serverTailnet.Close()
})
return manifest.AgentID, ag, serverTailnet
return agents, serverTailnet
}