fix: Use sync.WaitGroup to await hijacked HTTP connections (#337)

WebSockets hijack the HTTP connection from the server, causing
server.Close() to not wait for these connections to fully cleanup.

This adds a global wait-group to the coderd API, which ensures all
WebSocket HTTP handlers have properly exited before returning.
This commit is contained in:
Kyle Carberry 2022-02-20 16:29:16 -06:00 committed by GitHub
parent 3c04c7f3e6
commit d04570ad29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 23 additions and 13 deletions

View File

@ -33,7 +33,7 @@ func Root() *cobra.Command {
Use: "coderd",
RunE: func(cmd *cobra.Command, args []string) error {
logger := slog.Make(sloghuman.Sink(os.Stderr))
handler := coderd.New(&coderd.Options{
handler, closeCoderd := coderd.New(&coderd.Options{
Logger: logger,
Database: databasefake.New(),
Pubsub: database.NewPubsubInMemory(),
@ -49,11 +49,11 @@ func Root() *cobra.Command {
Scheme: "http",
Host: address,
})
closer, err := newProvisionerDaemon(cmd.Context(), client, logger)
daemonClose, err := newProvisionerDaemon(cmd.Context(), client, logger)
if err != nil {
return xerrors.Errorf("create provisioner daemon: %w", err)
}
defer closer.Close()
defer daemonClose.Close()
errCh := make(chan error)
go func() {
@ -61,6 +61,7 @@ func Root() *cobra.Command {
errCh <- http.Serve(listener, handler)
}()
closeCoderd()
select {
case <-cmd.Context().Done():
return cmd.Context().Err()

View File

@ -2,6 +2,7 @@ package coderd
import (
"net/http"
"sync"
"github.com/go-chi/chi/v5"
@ -20,11 +21,12 @@ type Options struct {
}
// New constructs the Coder API into an HTTP handler.
func New(options *Options) http.Handler {
//
// A wait function is returned to handle awaiting closure
// of hijacked HTTP requests.
func New(options *Options) (http.Handler, func()) {
api := &api{
Database: options.Database,
Logger: options.Logger,
Pubsub: options.Pubsub,
Options: options,
}
r := chi.NewRouter()
@ -144,13 +146,13 @@ func New(options *Options) http.Handler {
})
})
r.NotFound(site.Handler(options.Logger).ServeHTTP)
return r
return r, api.websocketWaitGroup.Wait
}
// API contains all route handlers. Only HTTP handlers should
// be added to this struct for code clarity.
type api struct {
Database database.Store
Logger slog.Logger
Pubsub database.Pubsub
*Options
websocketWaitGroup sync.WaitGroup
}

View File

@ -55,7 +55,7 @@ func New(t *testing.T) *codersdk.Client {
})
}
handler := coderd.New(&coderd.Options{
handler, closeWait := coderd.New(&coderd.Options{
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
Database: db,
Pubsub: pubsub,
@ -69,7 +69,10 @@ func New(t *testing.T) *codersdk.Client {
srv.Start()
serverURL, err := url.Parse(srv.URL)
require.NoError(t, err)
t.Cleanup(srv.Close)
t.Cleanup(func() {
srv.Close()
closeWait()
})
return codersdk.New(serverURL)
}

View File

@ -62,6 +62,8 @@ func (api *api) provisionerDaemonsServe(rw http.ResponseWriter, r *http.Request)
})
return
}
api.websocketWaitGroup.Add(1)
defer api.websocketWaitGroup.Done()
daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{
ID: uuid.New(),
@ -100,7 +102,9 @@ func (api *api) provisionerDaemonsServe(rw http.ResponseWriter, r *http.Request)
err = server.Serve(r.Context(), session)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, fmt.Sprintf("serve: %s", err))
return
}
_ = conn.Close(websocket.StatusGoingAway, "")
}
// The input for a "workspace_provision" job.