fix(devtunnel): close `http.Server` before wireguard interface (#2263)

This commit is contained in:
Colin Adler 2022-06-10 18:40:33 -05:00 committed by GitHub
parent de6f86bf7a
commit 8415022bf9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 9 deletions

View File

@ -51,7 +51,7 @@ type configExt struct {
// NewWithConfig calls New with the given config. For documentation, see New.
func NewWithConfig(ctx context.Context, logger slog.Logger, cfg Config) (*Tunnel, <-chan error, error) {
err := startUpdateRoutine(ctx, logger, cfg)
routineEnd, err := startUpdateRoutine(ctx, logger, cfg)
if err != nil {
return nil, nil, xerrors.Errorf("start update routine: %w", err)
}
@ -101,6 +101,7 @@ allowed_ip=%s/128`,
case <-ctx.Done():
_ = wgListen.Close()
dev.Close()
<-routineEnd
close(ch)
case <-dev.Wait():
@ -128,21 +129,24 @@ func New(ctx context.Context, logger slog.Logger) (*Tunnel, <-chan error, error)
return NewWithConfig(ctx, logger, cfg)
}
func startUpdateRoutine(ctx context.Context, logger slog.Logger, cfg Config) error {
func startUpdateRoutine(ctx context.Context, logger slog.Logger, cfg Config) (<-chan struct{}, error) {
// Ensure we send the first config before spawning in the background.
_, err := sendConfigToServer(ctx, cfg)
if err != nil {
return xerrors.Errorf("send config to server: %w", err)
return nil, xerrors.Errorf("send config to server: %w", err)
}
endCh := make(chan struct{})
go func() {
defer close(endCh)
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
break
return
case <-ticker.C:
}
@ -152,19 +156,25 @@ func startUpdateRoutine(ctx context.Context, logger slog.Logger, cfg Config) err
}
}
}()
return nil
return endCh, nil
}
func sendConfigToServer(_ context.Context, cfg Config) (created bool, err error) {
func sendConfigToServer(ctx context.Context, cfg Config) (created bool, err error) {
raw, err := json.Marshal(configExt(cfg))
if err != nil {
return false, xerrors.Errorf("marshal config: %w", err)
}
res, err := http.Post("https://"+EndpointHTTPS+"/tun", "application/json", bytes.NewReader(raw))
req, err := http.NewRequestWithContext(ctx, "POST", "https://"+EndpointHTTPS+"/tun", bytes.NewReader(raw))
if err != nil {
return false, xerrors.Errorf("send request: %w", err)
return false, xerrors.Errorf("new request: %w", err)
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return false, xerrors.Errorf("do request: %w", err)
}
_, _ = io.Copy(io.Discard, res.Body)
_ = res.Body.Close()

View File

@ -2,11 +2,13 @@ package devtunnel_test
import (
"context"
"io"
"net"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
@ -47,17 +49,26 @@ func TestTunnel(t *testing.T) {
go server.Serve(tun.Listener)
defer tun.Listener.Close()
httpClient := &http.Client{
Timeout: 10 * time.Second,
}
require.Eventually(t, func() bool {
req, err := http.NewRequestWithContext(ctx, "GET", tun.URL, nil)
require.NoError(t, err)
res, err := http.DefaultClient.Do(req)
res, err := httpClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()
_, _ = io.Copy(io.Discard, res.Body)
return res.StatusCode == http.StatusOK
}, time.Minute, time.Second)
httpClient.CloseIdleConnections()
assert.NoError(t, server.Close())
cancelTun()
select {
case <-errCh:
case <-time.After(10 * time.Second):