coder/coderd/wsconncache/wsconncache.go

166 lines
4.2 KiB
Go

// Package wsconncache caches workspace agent connections by UUID.
package wsconncache
import (
"context"
"net/http"
"sync"
"time"
"github.com/google/uuid"
"go.uber.org/atomic"
"golang.org/x/sync/singleflight"
"golang.org/x/xerrors"
"github.com/coder/coder/codersdk"
)
// New creates a new workspace connection cache that closes
// connections after the inactive timeout provided.
//
// Agent connections are cached due to WebRTC negotiation
// taking a few hundred milliseconds.
func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
if inactiveTimeout == 0 {
inactiveTimeout = 5 * time.Minute
}
return &Cache{
closed: make(chan struct{}),
dialer: dialer,
inactiveTimeout: inactiveTimeout,
}
}
// Dialer creates a new agent connection by ID.
type Dialer func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error)
// Conn wraps an agent connection with a reusable HTTP transport.
type Conn struct {
*codersdk.WorkspaceAgentConn
locks atomic.Uint64
timeoutMutex sync.Mutex
timeout *time.Timer
timeoutCancel context.CancelFunc
transport *http.Transport
}
func (c *Conn) HTTPTransport() *http.Transport {
return c.transport
}
// Close ends the HTTP transport if exists, and closes the agent.
func (c *Conn) Close() error {
if c.transport != nil {
c.transport.CloseIdleConnections()
}
c.timeoutMutex.Lock()
defer c.timeoutMutex.Unlock()
if c.timeout != nil {
c.timeout.Stop()
}
return c.WorkspaceAgentConn.Close()
}
type Cache struct {
closed chan struct{}
closeMutex sync.Mutex
closeGroup sync.WaitGroup
connGroup singleflight.Group
connMap sync.Map
dialer Dialer
inactiveTimeout time.Duration
}
// Acquire gets or establishes a connection with the dialer using the ID provided.
// If a connection is in-progress, that connection or error will be returned.
//
// The returned function is used to release a lock on the connection. Once zero
// locks exist on a connection, the inactive timeout will begin to tick down.
// After the time expires, the connection will be cleared from the cache.
func (c *Cache) Acquire(r *http.Request, id uuid.UUID) (*Conn, func(), error) {
rawConn, found := c.connMap.Load(id.String())
// If the connection isn't found, establish a new one!
if !found {
var err error
// A singleflight group is used to allow for concurrent requests to the
// same identifier to resolve.
rawConn, err, _ = c.connGroup.Do(id.String(), func() (interface{}, error) {
c.closeMutex.Lock()
select {
case <-c.closed:
c.closeMutex.Unlock()
return nil, xerrors.New("closed")
default:
}
c.closeGroup.Add(1)
c.closeMutex.Unlock()
agentConn, err := c.dialer(r, id)
if err != nil {
c.closeGroup.Done()
return nil, xerrors.Errorf("dial: %w", err)
}
timeoutCtx, timeoutCancelFunc := context.WithCancel(context.Background())
defaultTransport, valid := http.DefaultTransport.(*http.Transport)
if !valid {
panic("dev error: default transport is the wrong type")
}
transport := defaultTransport.Clone()
transport.DialContext = agentConn.DialContext
conn := &Conn{
WorkspaceAgentConn: agentConn,
timeoutCancel: timeoutCancelFunc,
transport: transport,
}
go func() {
defer c.closeGroup.Done()
select {
case <-timeoutCtx.Done():
case <-c.closed:
case <-conn.Closed():
}
c.connMap.Delete(id.String())
c.connGroup.Forget(id.String())
_ = conn.Close()
}()
return conn, nil
})
if err != nil {
return nil, nil, err
}
c.connMap.Store(id.String(), rawConn)
}
conn, _ := rawConn.(*Conn)
conn.timeoutMutex.Lock()
defer conn.timeoutMutex.Unlock()
if conn.timeout != nil {
conn.timeout.Stop()
}
conn.locks.Inc()
return conn, func() {
conn.timeoutMutex.Lock()
defer conn.timeoutMutex.Unlock()
if conn.timeout != nil {
conn.timeout.Stop()
}
conn.locks.Dec()
if conn.locks.Load() == 0 {
conn.timeout = time.AfterFunc(c.inactiveTimeout, conn.timeoutCancel)
}
}, nil
}
func (c *Cache) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
return nil
default:
}
close(c.closed)
c.closeGroup.Wait()
return nil
}