mirror of https://github.com/coder/coder.git
fix: use a waitgroup to ensure all connections are cleaned up in agent (#5910)
* fix: use a waitgroup to ensure all connections are cleaned up in agent. There was a race where connections would be created at the same time as close. The `net.Conn` produced by Tailscale doesn't close when the listener does. * Remove accidental test
This commit is contained in:
parent
ce36a84dd5
commit
0d08065488
|
@ -398,24 +398,28 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
|
|||
}
|
||||
}()
|
||||
if err = a.trackConnGoroutine(func() {
|
||||
var wg sync.WaitGroup
|
||||
for {
|
||||
conn, err := sshListener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
break
|
||||
}
|
||||
wg.Add(1)
|
||||
closed := make(chan struct{})
|
||||
_ = a.trackConnGoroutine(func() {
|
||||
go func() {
|
||||
select {
|
||||
case <-network.Closed():
|
||||
case <-closed:
|
||||
case <-a.closed:
|
||||
_ = conn.Close()
|
||||
}
|
||||
_ = conn.Close()
|
||||
})
|
||||
_ = a.trackConnGoroutine(func() {
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
defer close(closed)
|
||||
a.sshServer.HandleConn(conn)
|
||||
})
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -431,35 +435,47 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
|
|||
}()
|
||||
if err = a.trackConnGoroutine(func() {
|
||||
logger := a.logger.Named("reconnecting-pty")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for {
|
||||
conn, err := reconnectingPTYListener.Accept()
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "accept pty failed", slog.Error(err))
|
||||
return
|
||||
}
|
||||
// This cannot use a JSON decoder, since that can
|
||||
// buffer additional data that is required for the PTY.
|
||||
rawLen := make([]byte, 2)
|
||||
_, err = conn.Read(rawLen)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
length := binary.LittleEndian.Uint16(rawLen)
|
||||
data := make([]byte, length)
|
||||
_, err = conn.Read(data)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var msg codersdk.WorkspaceAgentReconnectingPTYInit
|
||||
err = json.Unmarshal(data, &msg)
|
||||
if err != nil {
|
||||
continue
|
||||
break
|
||||
}
|
||||
wg.Add(1)
|
||||
closed := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-closed:
|
||||
case <-a.closed:
|
||||
_ = conn.Close()
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
defer close(closed)
|
||||
// This cannot use a JSON decoder, since that can
|
||||
// buffer additional data that is required for the PTY.
|
||||
rawLen := make([]byte, 2)
|
||||
_, err = conn.Read(rawLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
length := binary.LittleEndian.Uint16(rawLen)
|
||||
data := make([]byte, length)
|
||||
_, err = conn.Read(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var msg codersdk.WorkspaceAgentReconnectingPTYInit
|
||||
err = json.Unmarshal(data, &msg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = a.handleReconnectingPTY(ctx, logger, msg, conn)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -474,20 +490,29 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
|
|||
}
|
||||
}()
|
||||
if err = a.trackConnGoroutine(func() {
|
||||
var wg sync.WaitGroup
|
||||
for {
|
||||
conn, err := speedtestListener.Accept()
|
||||
if err != nil {
|
||||
a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
|
||||
return
|
||||
break
|
||||
}
|
||||
if err = a.trackConnGoroutine(func() {
|
||||
wg.Add(1)
|
||||
closed := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-closed:
|
||||
case <-a.closed:
|
||||
_ = conn.Close()
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
defer close(closed)
|
||||
_ = speedtest.ServeConn(conn)
|
||||
}); err != nil {
|
||||
a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -511,7 +536,10 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_
|
|||
ErrorLog: slog.Stdlib(ctx, a.logger.Named("statistics_http_server"), slog.LevelInfo),
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-a.closed:
|
||||
}
|
||||
_ = server.Close()
|
||||
}()
|
||||
|
||||
|
|
|
@ -23,10 +23,6 @@ import (
|
|||
|
||||
func Test_Runner(t *testing.T) {
|
||||
t.Parallel()
|
||||
// There's a race condition in agent/agent.go where connections
|
||||
// aren't closed when the Tailnet connection is. This causes the
|
||||
// goroutines to hang around and cause the test to fail.
|
||||
t.Skip("TODO: fix this test")
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
Loading…
Reference in New Issue