mirror of https://github.com/coder/coder.git
fix: rewrite url to agent ip in single tailnet (#11810)
This restores previous behavior of being able to cache connections across agents in single tailnet.
This commit is contained in:
parent
073d1f7078
commit
3ace7982aa
|
@ -99,15 +99,14 @@ func NewServerTailnet(
|
|||
transport: tailnetTransport.Clone(),
|
||||
}
|
||||
tn.transport.DialContext = tn.dialContext
|
||||
|
||||
// Bugfix: for some reason all calls to tn.dialContext come from
|
||||
// "localhost", causing connections to be cached and requests to go to the
|
||||
// wrong workspaces. This disables keepalives for now until the root cause
|
||||
// can be found.
|
||||
tn.transport.MaxIdleConnsPerHost = -1
|
||||
tn.transport.DisableKeepAlives = true
|
||||
|
||||
// These options are mostly just picked at random, and they can likely be
|
||||
// fine tuned further. Generally, users are running applications in dev mode
|
||||
// which can generate hundreds of requests per page load, so we increased
|
||||
// MaxIdleConnsPerHost from 2 to 6 and removed the limit of total idle
|
||||
// conns.
|
||||
tn.transport.MaxIdleConnsPerHost = 6
|
||||
tn.transport.MaxIdleConns = 0
|
||||
tn.transport.IdleConnTimeout = 10 * time.Minute
|
||||
// We intentionally don't verify the certificate chain here.
|
||||
// The connection to the workspace is already established and most
|
||||
// apps are already going to be accessed over plain HTTP, this config
|
||||
|
@ -308,7 +307,15 @@ type ServerTailnet struct {
|
|||
}
|
||||
|
||||
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) *httputil.ReverseProxy {
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
// Rewrite the targetURL's Host to point to the agent's IP. This is
|
||||
// necessary because due to TCP connection caching, each agent needs to be
|
||||
// addressed invidivually. Otherwise, all connections get dialed as
|
||||
// "localhost:port", causing connections to be shared across agents.
|
||||
tgt := *targetURL
|
||||
_, port, _ := net.SplitHostPort(tgt.Host)
|
||||
tgt.Host = net.JoinHostPort(tailnet.IPFromUUID(agentID).String(), port)
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(&tgt)
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
|
||||
Status: http.StatusBadGateway,
|
||||
|
|
|
@ -3,10 +3,13 @@ package coderd_test
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
@ -35,9 +38,10 @@ func TestServerTailnet_AgentConn_OK(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
// Connect through the ServerTailnet
|
||||
agentID, _, serverTailnet := setupAgent(t, nil)
|
||||
agents, serverTailnet := setupServerTailnetAgent(t, 1)
|
||||
a := agents[0]
|
||||
|
||||
conn, release, err := serverTailnet.AgentConn(ctx, agentID)
|
||||
conn, release, err := serverTailnet.AgentConn(ctx, a.id)
|
||||
require.NoError(t, err)
|
||||
defer release()
|
||||
|
||||
|
@ -53,12 +57,13 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
agentID, _, serverTailnet := setupAgent(t, nil)
|
||||
agents, serverTailnet := setupServerTailnetAgent(t, 1)
|
||||
a := agents[0]
|
||||
|
||||
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
|
||||
require.NoError(t, err)
|
||||
|
||||
rp := serverTailnet.ReverseProxy(u, u, agentID)
|
||||
rp := serverTailnet.ReverseProxy(u, u, a.id)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
|
@ -74,13 +79,147 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
|
|||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("HostRewrite", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
agents, serverTailnet := setupServerTailnetAgent(t, 1)
|
||||
a := agents[0]
|
||||
|
||||
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
|
||||
require.NoError(t, err)
|
||||
|
||||
rp := serverTailnet.ReverseProxy(u, u, a.id)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure the reverse proxy director rewrites the url host to the agent's IP.
|
||||
rp.Director(req)
|
||||
assert.Equal(t,
|
||||
fmt.Sprintf("[%s]:%d", tailnet.IPFromUUID(a.id).String(), codersdk.WorkspaceAgentHTTPAPIServerPort),
|
||||
req.URL.Host,
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("CachesConnection", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
agents, serverTailnet := setupServerTailnetAgent(t, 1)
|
||||
a := agents[0]
|
||||
port := ":4444"
|
||||
ln, err := a.TailnetConn().Listen("tcp", port)
|
||||
require.NoError(t, err)
|
||||
wln := &wrappedListener{Listener: ln}
|
||||
|
||||
serverClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer close(serverClosed)
|
||||
//nolint:gosec
|
||||
_ = http.Serve(wln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("hello from agent"))
|
||||
}))
|
||||
}()
|
||||
defer func() {
|
||||
// wait for server to close
|
||||
<-serverClosed
|
||||
}()
|
||||
|
||||
defer ln.Close()
|
||||
|
||||
u, err := url.Parse("http://127.0.0.1" + port)
|
||||
require.NoError(t, err)
|
||||
|
||||
rp := serverTailnet.ReverseProxy(u, u, a.id)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
rw := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
u.String(),
|
||||
nil,
|
||||
).WithContext(ctx)
|
||||
|
||||
rp.ServeHTTP(rw, req)
|
||||
res := rw.Result()
|
||||
|
||||
_, _ = io.Copy(io.Discard, res.Body)
|
||||
res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, wln.getDials())
|
||||
})
|
||||
|
||||
t.Run("NotReusedBetweenAgents", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
agents, serverTailnet := setupServerTailnetAgent(t, 2)
|
||||
port := ":4444"
|
||||
|
||||
for i, ag := range agents {
|
||||
i := i
|
||||
ln, err := ag.TailnetConn().Listen("tcp", port)
|
||||
require.NoError(t, err)
|
||||
wln := &wrappedListener{Listener: ln}
|
||||
|
||||
serverClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer close(serverClosed)
|
||||
//nolint:gosec
|
||||
_ = http.Serve(wln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(strconv.Itoa(i)))
|
||||
}))
|
||||
}()
|
||||
defer func() { //nolint:revive
|
||||
// wait for server to close
|
||||
<-serverClosed
|
||||
}()
|
||||
|
||||
defer ln.Close() //nolint:revive
|
||||
}
|
||||
|
||||
u, err := url.Parse("http://127.0.0.1" + port)
|
||||
require.NoError(t, err)
|
||||
|
||||
for i, ag := range agents {
|
||||
rp := serverTailnet.ReverseProxy(u, u, ag.id)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
u.String(),
|
||||
nil,
|
||||
).WithContext(ctx)
|
||||
|
||||
rp.ServeHTTP(rw, req)
|
||||
res := rw.Result()
|
||||
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, strconv.Itoa(i), string(body))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HTTPSProxy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
agentID, _, serverTailnet := setupAgent(t, nil)
|
||||
agents, serverTailnet := setupServerTailnetAgent(t, 1)
|
||||
a := agents[0]
|
||||
|
||||
const expectedResponseCode = 209
|
||||
// Test that we can proxy HTTPS traffic.
|
||||
|
@ -92,7 +231,7 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
|
|||
uri, err := url.Parse(s.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
rp := serverTailnet.ReverseProxy(uri, uri, agentID)
|
||||
rp := serverTailnet.ReverseProxy(uri, uri, a.id)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
|
@ -109,44 +248,74 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.Agent, *coderd.ServerTailnet) {
|
||||
type wrappedListener struct {
|
||||
net.Listener
|
||||
dials int32
|
||||
}
|
||||
|
||||
func (w *wrappedListener) Accept() (net.Conn, error) {
|
||||
conn, err := w.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
atomic.AddInt32(&w.dials, 1)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (w *wrappedListener) getDials() int {
|
||||
return int(atomic.LoadInt32(&w.dials))
|
||||
}
|
||||
|
||||
type agentWithID struct {
|
||||
id uuid.UUID
|
||||
agent.Agent
|
||||
}
|
||||
|
||||
func setupServerTailnetAgent(t *testing.T, agentNum int) ([]agentWithID, *coderd.ServerTailnet) {
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
derpMap, derpServer := tailnettest.RunDERPAndSTUN(t)
|
||||
manifest := agentsdk.Manifest{
|
||||
AgentID: uuid.New(),
|
||||
DERPMap: derpMap,
|
||||
}
|
||||
|
||||
coord := tailnet.NewCoordinator(logger)
|
||||
t.Cleanup(func() {
|
||||
_ = coord.Close()
|
||||
})
|
||||
|
||||
c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
|
||||
t.Cleanup(c.Close)
|
||||
agents := []agentWithID{}
|
||||
|
||||
options := agent.Options{
|
||||
Client: c,
|
||||
Filesystem: afero.NewMemMapFs(),
|
||||
Logger: logger.Named("agent"),
|
||||
Addresses: agentAddresses,
|
||||
for i := 0; i < agentNum; i++ {
|
||||
manifest := agentsdk.Manifest{
|
||||
AgentID: uuid.New(),
|
||||
DERPMap: derpMap,
|
||||
}
|
||||
|
||||
c := agenttest.NewClient(t, logger, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
|
||||
t.Cleanup(c.Close)
|
||||
|
||||
options := agent.Options{
|
||||
Client: c,
|
||||
Filesystem: afero.NewMemMapFs(),
|
||||
Logger: logger.Named("agent"),
|
||||
}
|
||||
|
||||
ag := agent.New(options)
|
||||
t.Cleanup(func() {
|
||||
_ = ag.Close()
|
||||
})
|
||||
|
||||
// Wait for the agent to connect.
|
||||
require.Eventually(t, func() bool {
|
||||
return coord.Node(manifest.AgentID) != nil
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
agents = append(agents, agentWithID{id: manifest.AgentID, Agent: ag})
|
||||
}
|
||||
|
||||
ag := agent.New(options)
|
||||
t.Cleanup(func() {
|
||||
_ = ag.Close()
|
||||
})
|
||||
|
||||
// Wait for the agent to connect.
|
||||
require.Eventually(t, func() bool {
|
||||
return coord.Node(manifest.AgentID) != nil
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
serverTailnet, err := coderd.NewServerTailnet(
|
||||
context.Background(),
|
||||
logger,
|
||||
derpServer,
|
||||
func() *tailcfg.DERPMap { return manifest.DERPMap },
|
||||
func() *tailcfg.DERPMap { return derpMap },
|
||||
false,
|
||||
func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil },
|
||||
trace.NewNoopTracerProvider(),
|
||||
|
@ -157,5 +326,5 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
|
|||
_ = serverTailnet.Close()
|
||||
})
|
||||
|
||||
return manifest.AgentID, ag, serverTailnet
|
||||
return agents, serverTailnet
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue