mirror of https://github.com/coder/coder.git
fix: lock log sink against concurrent write and close (#10668)
fixes #10663
This commit is contained in:
parent
530be2f96a
commit
dc4b1ef406
|
@ -0,0 +1,38 @@
|
|||
package cliutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type discardAfterClose struct {
|
||||
sync.Mutex
|
||||
wc io.WriteCloser
|
||||
closed bool
|
||||
}
|
||||
|
||||
// DiscardAfterClose is an io.WriteCloser that discards writes after it is closed without errors.
|
||||
// It is useful as a target for a slog.Sink such that an underlying WriteCloser, like a file, can
|
||||
// be cleaned up without race conditions from still-active loggers.
|
||||
func DiscardAfterClose(wc io.WriteCloser) io.WriteCloser {
|
||||
return &discardAfterClose{wc: wc}
|
||||
}
|
||||
|
||||
func (d *discardAfterClose) Write(p []byte) (n int, err error) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
if d.closed {
|
||||
return len(p), nil
|
||||
}
|
||||
return d.wc.Write(p)
|
||||
}
|
||||
|
||||
func (d *discardAfterClose) Close() error {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
if d.closed {
|
||||
return nil
|
||||
}
|
||||
d.closed = true
|
||||
return d.wc.Close()
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package cliutil_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliutil"
|
||||
)
|
||||
|
||||
func TestDiscardAfterClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
exErr := errors.New("test")
|
||||
fwc := &fakeWriteCloser{err: exErr}
|
||||
uut := cliutil.DiscardAfterClose(fwc)
|
||||
|
||||
n, err := uut.Write([]byte("one"))
|
||||
require.Equal(t, 3, n)
|
||||
require.NoError(t, err)
|
||||
|
||||
n, err = uut.Write([]byte("two"))
|
||||
require.Equal(t, 3, n)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = uut.Close()
|
||||
require.Equal(t, exErr, err)
|
||||
|
||||
n, err = uut.Write([]byte("three"))
|
||||
require.Equal(t, 5, n)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, fwc.writes, 2)
|
||||
require.EqualValues(t, "one", fwc.writes[0])
|
||||
require.EqualValues(t, "two", fwc.writes[1])
|
||||
}
|
||||
|
||||
type fakeWriteCloser struct {
|
||||
writes [][]byte
|
||||
closed bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeWriteCloser) Write(p []byte) (n int, err error) {
|
||||
q := make([]byte, len(p))
|
||||
copy(q, p)
|
||||
f.writes = append(f.writes, q)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (f *fakeWriteCloser) Close() error {
|
||||
f.closed = true
|
||||
return f.err
|
||||
}
|
|
@ -28,6 +28,7 @@ import (
|
|||
|
||||
"github.com/coder/coder/v2/cli/clibase"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/cli/cliutil"
|
||||
"github.com/coder/coder/v2/coderd/autobuild/notify"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
|
@ -114,12 +115,13 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||
if err != nil {
|
||||
return xerrors.Errorf("error opening %s for logging: %w", logDirPath, err)
|
||||
}
|
||||
dc := cliutil.DiscardAfterClose(logFile)
|
||||
go func() {
|
||||
wg.Wait()
|
||||
_ = logFile.Close()
|
||||
_ = dc.Close()
|
||||
}()
|
||||
|
||||
logger = slog.Make(sloghuman.Sink(logFile))
|
||||
logger = logger.AppendSinks(sloghuman.Sink(dc))
|
||||
if r.verbose {
|
||||
logger = logger.Leveled(slog.LevelDebug)
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clibase"
|
||||
"github.com/coder/coder/v2/cli/cliutil"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
|
@ -137,15 +138,16 @@ func (r *RootCmd) vscodeSSH() *clibase.Cmd {
|
|||
// command via the ProxyCommand SSH option.
|
||||
pid := os.Getppid()
|
||||
|
||||
var logger slog.Logger
|
||||
logger := slog.Make()
|
||||
if logDir != "" {
|
||||
logFilePath := filepath.Join(logDir, fmt.Sprintf("%d.log", pid))
|
||||
logFile, err := fs.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY, 0o600)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("open log file %q: %w", logFilePath, err)
|
||||
}
|
||||
defer logFile.Close()
|
||||
logger = slog.Make(sloghuman.Sink(logFile)).Leveled(slog.LevelDebug)
|
||||
dc := cliutil.DiscardAfterClose(logFile)
|
||||
defer dc.Close()
|
||||
logger = logger.AppendSinks(sloghuman.Sink(dc)).Leveled(slog.LevelDebug)
|
||||
}
|
||||
if r.disableDirect {
|
||||
logger.Info(ctx, "direct connections disabled")
|
||||
|
|
Loading…
Reference in New Issue