mirror of https://github.com/coder/coder.git
fix(agent): Allow signal propagation when running as PID 1 (#6141)
This commit is contained in:
parent
af59e2bcfa
commit
6f3f7f2937
|
@ -1,6 +1,10 @@
|
|||
package reaper
|
||||
|
||||
import "github.com/hashicorp/go-reap"
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/go-reap"
|
||||
)
|
||||
|
||||
type Option func(o *options)
|
||||
|
||||
|
@ -22,7 +26,16 @@ func WithPIDCallback(ch reap.PidCh) Option {
|
|||
}
|
||||
}
|
||||
|
||||
type options struct {
|
||||
ExecArgs []string
|
||||
PIDs reap.PidCh
|
||||
// WithCatchSignals sets the signals that are caught and forwarded to the
|
||||
// child process. By default no signals are forwarded.
|
||||
func WithCatchSignals(sigs ...os.Signal) Option {
|
||||
return func(o *options) {
|
||||
o.CatchSignals = sigs
|
||||
}
|
||||
}
|
||||
|
||||
type options struct {
|
||||
ExecArgs []string
|
||||
PIDs reap.PidCh
|
||||
CatchSignals []os.Signal
|
||||
}
|
||||
|
|
|
@ -3,8 +3,11 @@
|
|||
package reaper_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -15,9 +18,8 @@ import (
|
|||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
//nolint:paralleltest // Non-parallel subtest.
|
||||
func TestReap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Don't run the reaper test in CI. It does weird
|
||||
// things like forkexecing which may have unintended
|
||||
// consequences in CI.
|
||||
|
@ -28,8 +30,9 @@ func TestReap(t *testing.T) {
|
|||
// OK checks that's the reaper is successfully reaping
|
||||
// exited processes and passing the PIDs through the shared
|
||||
// channel.
|
||||
|
||||
//nolint:paralleltest // Signal handling.
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
pids := make(reap.PidCh, 1)
|
||||
err := reaper.ForkReap(
|
||||
reaper.WithPIDCallback(pids),
|
||||
|
@ -64,3 +67,39 @@ func TestReap(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:paralleltest // Signal handling.
|
||||
func TestReapInterrupt(t *testing.T) {
|
||||
// Don't run the reaper test in CI. It does weird
|
||||
// things like forkexecing which may have unintended
|
||||
// consequences in CI.
|
||||
if _, ok := os.LookupEnv("CI"); ok {
|
||||
t.Skip("Detected CI, skipping reaper tests")
|
||||
}
|
||||
|
||||
errC := make(chan error, 1)
|
||||
pids := make(reap.PidCh, 1)
|
||||
|
||||
// Use signals to notify when the child process is ready for the
|
||||
// next step of our test.
|
||||
usrSig := make(chan os.Signal, 1)
|
||||
signal.Notify(usrSig, syscall.SIGUSR1, syscall.SIGUSR2)
|
||||
defer signal.Stop(usrSig)
|
||||
|
||||
go func() {
|
||||
errC <- reaper.ForkReap(
|
||||
reaper.WithPIDCallback(pids),
|
||||
reaper.WithCatchSignals(os.Interrupt),
|
||||
// Signal propagation does not extend to children of children, so
|
||||
// we create a little bash script to ensure sleep is interrupted.
|
||||
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf("pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait", os.Getpid(), os.Getpid())),
|
||||
)
|
||||
}()
|
||||
|
||||
require.Equal(t, <-usrSig, syscall.SIGUSR1)
|
||||
err := syscall.Kill(os.Getpid(), syscall.SIGINT)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, <-usrSig, syscall.SIGUSR2)
|
||||
|
||||
require.NoError(t, <-errC)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ package reaper
|
|||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/hashicorp/go-reap"
|
||||
|
@ -15,6 +16,24 @@ func IsInitProcess() bool {
|
|||
return os.Getpid() == 1
|
||||
}
|
||||
|
||||
func catchSignals(pid int, sigs []os.Signal) {
|
||||
if len(sigs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sc := make(chan os.Signal, 1)
|
||||
signal.Notify(sc, sigs...)
|
||||
defer signal.Stop(sc)
|
||||
|
||||
for {
|
||||
s := <-sc
|
||||
sig, ok := s.(syscall.Signal)
|
||||
if ok {
|
||||
_ = syscall.Kill(pid, sig)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ForkReap spawns a goroutine that reaps children. In order to avoid
|
||||
// complications with spawning `exec.Commands` in the same process that
|
||||
// is reaping, we forkexec a child process. This prevents a race between
|
||||
|
@ -51,13 +70,17 @@ func ForkReap(opt ...Option) error {
|
|||
}
|
||||
|
||||
//#nosec G204
|
||||
pid, _ := syscall.ForkExec(opts.ExecArgs[0], opts.ExecArgs, pattrs)
|
||||
pid, err := syscall.ForkExec(opts.ExecArgs[0], opts.ExecArgs, pattrs)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("fork exec: %w", err)
|
||||
}
|
||||
|
||||
go catchSignals(pid, opts.CatchSignals)
|
||||
|
||||
var wstatus syscall.WaitStatus
|
||||
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
|
||||
for xerrors.Is(err, syscall.EINTR) {
|
||||
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -68,7 +68,10 @@ func workspaceAgent() *cobra.Command {
|
|||
// Do not start a reaper on the child process. It's important
|
||||
// to do this else we fork bomb ourselves.
|
||||
args := append(os.Args, "--no-reap")
|
||||
err := reaper.ForkReap(reaper.WithExecArgs(args...))
|
||||
err := reaper.ForkReap(
|
||||
reaper.WithExecArgs(args...),
|
||||
reaper.WithCatchSignals(InterruptSignals...),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to reap", slog.Error(err))
|
||||
return xerrors.Errorf("fork reap: %w", err)
|
||||
|
|
Loading…
Reference in New Issue