fix: use a waitgroup to ensure all connections are cleaned up in agent (#5910)

* fix: use a waitgroup to ensure all connections are cleaned up in agent

There was a race where connections could be created at the same time the agent was closing.
The `net.Conn` produced by Tailscale doesn't close when the listener does.
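The diff below applies one pattern to each of the three listeners (SSH, reconnecting PTY, speedtest): every accepted connection is registered with a sync.WaitGroup, a watchdog goroutine force-closes the connection once the agent shuts down, and the accept loop waits for all connections before its tracked goroutine returns. Here is a minimal, self-contained sketch of that pattern; the names (serveListener, agentClosed, handle) are illustrative stand-ins, not the agent's actual API.

package main

import (
	"io"
	"log"
	"net"
	"sync"
)

// serveListener accepts connections until ln is closed and, before
// returning, waits until every accepted connection has been closed or
// fully handled. agentClosed is closed when the whole agent shuts down.
func serveListener(ln net.Listener, agentClosed <-chan struct{}, handle func(net.Conn)) {
	var wg sync.WaitGroup
	for {
		conn, err := ln.Accept()
		if err != nil {
			break // listener closed; fall through to wg.Wait()
		}
		wg.Add(1)
		closed := make(chan struct{})
		// Watchdog: force-close the conn on agent shutdown, because the
		// net.Conn produced by Tailscale doesn't close when the listener
		// does.
		go func() {
			select {
			case <-closed: // handler finished on its own
			case <-agentClosed:
				_ = conn.Close() // unblocks any reads in the handler
			}
			wg.Done()
		}()
		go func() {
			defer close(closed)
			handle(conn)
		}()
	}
	wg.Wait() // reached via break, never skipped by an early return
}

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	agentClosed := make(chan struct{})
	done := make(chan struct{})
	go func() {
		defer close(done)
		serveListener(ln, agentClosed, func(c net.Conn) {
			_, _ = io.Copy(c, c) // echo until the conn is closed
		})
	}()
	// Shut down: signal the watchdogs, unblock Accept, then wait for
	// every connection to be cleaned up.
	close(agentClosed)
	_ = ln.Close()
	<-done
}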

* Remove accidental test
Author: Kyle Carberry, 2023-01-29 17:20:30 -06:00 (committed by GitHub)
Parent: ce36a84dd5
Commit: 0d08065488
2 changed files with 63 additions and 39 deletions


@@ -398,24 +398,28 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
 		}
 	}()
 	if err = a.trackConnGoroutine(func() {
+		var wg sync.WaitGroup
 		for {
 			conn, err := sshListener.Accept()
 			if err != nil {
-				return
+				break
 			}
+			wg.Add(1)
 			closed := make(chan struct{})
-			_ = a.trackConnGoroutine(func() {
+			go func() {
 				select {
-				case <-network.Closed():
 				case <-closed:
+				case <-a.closed:
+					_ = conn.Close()
 				}
-				_ = conn.Close()
-			})
-			_ = a.trackConnGoroutine(func() {
+				wg.Done()
+			}()
+			go func() {
 				defer close(closed)
 				a.sshServer.HandleConn(conn)
-			})
+			}()
 		}
+		wg.Wait()
 	}); err != nil {
 		return nil, err
 	}
@@ -431,35 +435,47 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
 	}()
 	if err = a.trackConnGoroutine(func() {
 		logger := a.logger.Named("reconnecting-pty")
+		var wg sync.WaitGroup
 		for {
 			conn, err := reconnectingPTYListener.Accept()
 			if err != nil {
 				logger.Debug(ctx, "accept pty failed", slog.Error(err))
-				return
-			}
-			// This cannot use a JSON decoder, since that can
-			// buffer additional data that is required for the PTY.
-			rawLen := make([]byte, 2)
-			_, err = conn.Read(rawLen)
-			if err != nil {
-				continue
-			}
-			length := binary.LittleEndian.Uint16(rawLen)
-			data := make([]byte, length)
-			_, err = conn.Read(data)
-			if err != nil {
-				continue
-			}
-			var msg codersdk.WorkspaceAgentReconnectingPTYInit
-			err = json.Unmarshal(data, &msg)
-			if err != nil {
-				continue
+				break
 			}
+			wg.Add(1)
+			closed := make(chan struct{})
+			go func() {
+				select {
+				case <-closed:
+				case <-a.closed:
+					_ = conn.Close()
+				}
+				wg.Done()
+			}()
 			go func() {
+				defer close(closed)
+				// This cannot use a JSON decoder, since that can
+				// buffer additional data that is required for the PTY.
+				rawLen := make([]byte, 2)
+				_, err = conn.Read(rawLen)
+				if err != nil {
+					return
+				}
+				length := binary.LittleEndian.Uint16(rawLen)
+				data := make([]byte, length)
+				_, err = conn.Read(data)
+				if err != nil {
+					return
+				}
+				var msg codersdk.WorkspaceAgentReconnectingPTYInit
+				err = json.Unmarshal(data, &msg)
+				if err != nil {
+					return
+				}
 				_ = a.handleReconnectingPTY(ctx, logger, msg, conn)
 			}()
 		}
+		wg.Wait()
 	}); err != nil {
 		return nil, err
 	}
@@ -474,20 +490,29 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
 		}
 	}()
 	if err = a.trackConnGoroutine(func() {
+		var wg sync.WaitGroup
 		for {
 			conn, err := speedtestListener.Accept()
 			if err != nil {
 				a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
-				return
+				break
 			}
-			if err = a.trackConnGoroutine(func() {
+			wg.Add(1)
+			closed := make(chan struct{})
+			go func() {
+				select {
+				case <-closed:
+				case <-a.closed:
+					_ = conn.Close()
+				}
+				wg.Done()
+			}()
+			go func() {
+				defer close(closed)
 				_ = speedtest.ServeConn(conn)
-			}); err != nil {
-				a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
-				_ = conn.Close()
-				return
-			}
+			}()
 		}
+		wg.Wait()
 	}); err != nil {
 		return nil, err
 	}
@@ -511,7 +536,10 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
 		ErrorLog: slog.Stdlib(ctx, a.logger.Named("statistics_http_server"), slog.LevelInfo),
 	}
 	go func() {
-		<-ctx.Done()
+		select {
+		case <-ctx.Done():
+		case <-a.closed:
+		}
 		_ = server.Close()
 	}()
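The last agent.go hunk applies the same idea to the statistics HTTP server: shut it down on whichever comes first, context cancellation or agent close, instead of waiting on the context alone. A small self-contained sketch of that watcher follows; server and agentClosed are stand-ins for the agent's actual fields, not the agent's API.

package main

import (
	"context"
	"errors"
	"log"
	"net/http"
	"time"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	server := &http.Server{
		Addr:              "127.0.0.1:0",
		ReadHeaderTimeout: time.Second,
	}
	agentClosed := make(chan struct{}) // stands in for a.closed

	// Close the server when either signal fires; waiting on ctx alone
	// would leave the server running if the agent closed without
	// canceling the context.
	go func() {
		select {
		case <-ctx.Done():
		case <-agentClosed:
		}
		_ = server.Close()
	}()

	close(agentClosed) // simulate the agent shutting down
	if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
		log.Fatal(err)
	}
}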


@@ -23,10 +23,6 @@ import (
 func Test_Runner(t *testing.T) {
 	t.Parallel()
-	// There's a race condition in agent/agent.go where connections
-	// aren't closed when the Tailnet connection is. This causes the
-	// goroutines to hang around and cause the test to fail.
-	t.Skip("TODO: fix this test")
 	t.Run("OK", func(t *testing.T) {
 		t.Parallel()