fix(agent): Allow signal propagation when running as PID 1 (#6141)

This commit is contained in:
Mathias Fredriksson 2023-02-09 23:07:21 +02:00 committed by GitHub
parent af59e2bcfa
commit 6f3f7f2937
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 11 deletions

View File

@ -1,6 +1,10 @@
package reaper
import "github.com/hashicorp/go-reap"
import (
"os"
"github.com/hashicorp/go-reap"
)
type Option func(o *options)
@ -22,7 +26,16 @@ func WithPIDCallback(ch reap.PidCh) Option {
}
}
type options struct {
ExecArgs []string
PIDs reap.PidCh
// WithCatchSignals sets the signals that are caught and forwarded to the
// child process. By default no signals are forwarded.
func WithCatchSignals(sigs ...os.Signal) Option {
return func(o *options) {
o.CatchSignals = sigs
}
}
type options struct {
ExecArgs []string
PIDs reap.PidCh
CatchSignals []os.Signal
}

View File

@ -3,8 +3,11 @@
package reaper_test
import (
"fmt"
"os"
"os/exec"
"os/signal"
"syscall"
"testing"
"time"
@ -15,9 +18,8 @@ import (
"github.com/coder/coder/testutil"
)
//nolint:paralleltest // Non-parallel subtest.
func TestReap(t *testing.T) {
t.Parallel()
// Don't run the reaper test in CI. It does weird
// things like forkexecing which may have unintended
// consequences in CI.
@ -28,8 +30,9 @@ func TestReap(t *testing.T) {
// OK checks that's the reaper is successfully reaping
// exited processes and passing the PIDs through the shared
// channel.
//nolint:paralleltest // Signal handling.
t.Run("OK", func(t *testing.T) {
t.Parallel()
pids := make(reap.PidCh, 1)
err := reaper.ForkReap(
reaper.WithPIDCallback(pids),
@ -64,3 +67,39 @@ func TestReap(t *testing.T) {
}
})
}
//nolint:paralleltest // Signal handling.
func TestReapInterrupt(t *testing.T) {
// Don't run the reaper test in CI. It does weird
// things like forkexecing which may have unintended
// consequences in CI.
if _, ok := os.LookupEnv("CI"); ok {
t.Skip("Detected CI, skipping reaper tests")
}
errC := make(chan error, 1)
pids := make(reap.PidCh, 1)
// Use signals to notify when the child process is ready for the
// next step of our test.
usrSig := make(chan os.Signal, 1)
signal.Notify(usrSig, syscall.SIGUSR1, syscall.SIGUSR2)
defer signal.Stop(usrSig)
go func() {
errC <- reaper.ForkReap(
reaper.WithPIDCallback(pids),
reaper.WithCatchSignals(os.Interrupt),
// Signal propagation does not extend to children of children, so
// we create a little bash script to ensure sleep is interrupted.
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf("pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait", os.Getpid(), os.Getpid())),
)
}()
require.Equal(t, <-usrSig, syscall.SIGUSR1)
err := syscall.Kill(os.Getpid(), syscall.SIGINT)
require.NoError(t, err)
require.Equal(t, <-usrSig, syscall.SIGUSR2)
require.NoError(t, <-errC)
}

View File

@ -4,6 +4,7 @@ package reaper
import (
"os"
"os/signal"
"syscall"
"github.com/hashicorp/go-reap"
@ -15,6 +16,24 @@ func IsInitProcess() bool {
return os.Getpid() == 1
}
func catchSignals(pid int, sigs []os.Signal) {
if len(sigs) == 0 {
return
}
sc := make(chan os.Signal, 1)
signal.Notify(sc, sigs...)
defer signal.Stop(sc)
for {
s := <-sc
sig, ok := s.(syscall.Signal)
if ok {
_ = syscall.Kill(pid, sig)
}
}
}
// ForkReap spawns a goroutine that reaps children. In order to avoid
// complications with spawning `exec.Commands` in the same process that
// is reaping, we forkexec a child process. This prevents a race between
@ -51,13 +70,17 @@ func ForkReap(opt ...Option) error {
}
//#nosec G204
pid, _ := syscall.ForkExec(opts.ExecArgs[0], opts.ExecArgs, pattrs)
pid, err := syscall.ForkExec(opts.ExecArgs[0], opts.ExecArgs, pattrs)
if err != nil {
return xerrors.Errorf("fork exec: %w", err)
}
go catchSignals(pid, opts.CatchSignals)
var wstatus syscall.WaitStatus
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
for xerrors.Is(err, syscall.EINTR) {
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
}
return nil
return err
}

View File

@ -68,7 +68,10 @@ func workspaceAgent() *cobra.Command {
// Do not start a reaper on the child process. It's important
// to do this else we fork bomb ourselves.
args := append(os.Args, "--no-reap")
err := reaper.ForkReap(reaper.WithExecArgs(args...))
err := reaper.ForkReap(
reaper.WithExecArgs(args...),
reaper.WithCatchSignals(InterruptSignals...),
)
if err != nil {
logger.Error(ctx, "failed to reap", slog.Error(err))
return xerrors.Errorf("fork reap: %w", err)