diff --git a/prefork/prefork.go b/prefork/prefork.go index d883640147..4dc00f6b47 100644 --- a/prefork/prefork.go +++ b/prefork/prefork.go @@ -77,6 +77,32 @@ type Prefork struct { // It is recommended to set this to func() { os.Exit(1) } if no custom // cleanup is needed. OnMasterDeath func() + + // OnChildSpawn is called in the master process whenever a new child process is spawned. + // It receives the PID of the newly spawned child process. + // + // If this callback returns an error, the prefork operation will be aborted. + OnChildSpawn func(pid int) error + + // OnMasterReady is called in the master process after all child processes have been spawned. + // It receives a slice of all child process PIDs. + // + // If this callback returns an error, the prefork operation will be aborted. + OnMasterReady func(childPIDs []int) error + + // OnChildRecover is called in the master process when a child process is restarted + // after a crash. It receives the PID of the newly recovered child process. + // + // The callback's error return value is ignored. + OnChildRecover func(pid int) error + + // CommandProducer is called to create child process commands. + // If nil, the default implementation using os.Executable() is used. + // This can be used for testing or customizing child process behavior. + // + // The function receives the files to be passed as ExtraFiles to the child process + // and must return a started command. + CommandProducer func(files []*os.File) (*exec.Cmd, error) } // IsChild checks if the current thread/process is a child. @@ -104,6 +130,22 @@ func (p *Prefork) logger() Logger { } func (p *Prefork) watchMaster(masterPID int) { + if runtime.GOOS == "windows" { + // On Windows, os.Getppid() returns a static PID that doesn't change + // when the parent exits (no reparenting). Use FindProcess+Wait instead. + proc, err := os.FindProcess(masterPID) + if err == nil { + _, _ = proc.Wait() + } + p.logger().Printf("master process died\n") + p.OnMasterDeath() + return + } + + // Unix/Linux/macOS: When the master exits, the OS reparents the child + // to another process, causing Getppid() to change. Comparing against + // the original masterPID (instead of hardcoding 1) ensures this works + // correctly when the master itself is PID 1 (e.g. in Docker containers). ticker := time.NewTicker(500 * time.Millisecond) defer ticker.Stop() @@ -127,9 +169,27 @@ func (p *Prefork) listen(addr string) (net.Listener, error) { return reuseport.Listen(p.Network, addr) } + // File descriptor 3 is the first ExtraFiles entry passed by the master process. return net.FileListener(os.NewFile(3, "")) } +// listenAsChild performs the common child process setup: creates the listener +// and starts watching the master process if OnMasterDeath is configured. +func (p *Prefork) listenAsChild(addr string) (net.Listener, error) { + ln, err := p.listen(addr) + if err != nil { + return nil, err + } + + p.ln = ln + + if p.OnMasterDeath != nil { + go p.watchMaster(os.Getppid()) + } + + return ln, nil +} + func (p *Prefork) setTCPListenerFiles(addr string) error { if p.Network == "" { p.Network = defaultNetwork @@ -158,6 +218,19 @@ func (p *Prefork) setTCPListenerFiles(addr string) error { } func (p *Prefork) doCommand() (*exec.Cmd, error) { + // Use custom CommandProducer if provided + if p.CommandProducer != nil { + cmd, err := p.CommandProducer(p.files) + if err != nil { + return nil, err + } + if cmd == nil || cmd.Process == nil { + return nil, errors.New("prefork: CommandProducer must return a started command") + } + return cmd, nil + } + + // Default implementation using os.Executable() for reliable path resolution executable, err := os.Executable() if err != nil { return nil, err @@ -205,7 +278,7 @@ func (p *Prefork) prefork(addr string) (err error) { goMaxProcs := runtime.GOMAXPROCS(0) sigCh := make(chan procSig, goMaxProcs) - childProcs := make(map[int]*exec.Cmd) + childProcs := make(map[int]*exec.Cmd, goMaxProcs) defer func() { for _, proc := range childProcs { @@ -213,6 +286,9 @@ func (p *Prefork) prefork(addr string) (err error) { } }() + // Collect child PIDs for OnMasterReady callback + childPIDs := make([]int, 0, goMaxProcs) + for range goMaxProcs { var cmd *exec.Cmd if cmd, err = p.doCommand(); err != nil { @@ -220,10 +296,29 @@ func (p *Prefork) prefork(addr string) (err error) { return err } - childProcs[cmd.Process.Pid] = cmd - go func() { - sigCh <- procSig{pid: cmd.Process.Pid, err: cmd.Wait()} - }() + pid := cmd.Process.Pid + childProcs[pid] = cmd + childPIDs = append(childPIDs, pid) + + // Call OnChildSpawn callback + if p.OnChildSpawn != nil { + if err = p.OnChildSpawn(pid); err != nil { + p.logger().Printf("OnChildSpawn callback failed for PID %d: %v\n", pid, err) + return err + } + } + + go func(c *exec.Cmd, pid int) { + sigCh <- procSig{pid: pid, err: c.Wait()} + }(cmd, pid) + } + + // Call OnMasterReady callback after all children are spawned + if p.OnMasterReady != nil { + if err = p.OnMasterReady(childPIDs); err != nil { + p.logger().Printf("OnMasterReady callback failed: %v\n", err) + return err + } } var exitedProcs int @@ -237,19 +332,27 @@ func (p *Prefork) prefork(addr string) (err error) { if exitedProcs > p.RecoverThreshold { p.logger().Printf("child prefork processes exit too many times, "+ "which exceeds the value of RecoverThreshold(%d), "+ - "exiting the master process.\n", exitedProcs) + "exiting the master process.\n", p.RecoverThreshold) err = ErrOverRecovery break } var cmd *exec.Cmd - if cmd, err = p.doCommand(); err != nil { + cmd, err = p.doCommand() + if err != nil { break } - childProcs[cmd.Process.Pid] = cmd - go func() { - sigCh <- procSig{pid: cmd.Process.Pid, err: cmd.Wait()} - }() + pid := cmd.Process.Pid + childProcs[pid] = cmd + + // Call OnChildRecover callback and ignore its returned error. + if p.OnChildRecover != nil { + _ = p.OnChildRecover(pid) + } + + go func(c *exec.Cmd, pid int) { + sigCh <- procSig{pid: pid, err: c.Wait()} + }(cmd, pid) } return err @@ -258,17 +361,10 @@ func (p *Prefork) prefork(addr string) (err error) { // ListenAndServe serves HTTP requests from the given TCP addr. func (p *Prefork) ListenAndServe(addr string) error { if IsChild() { - ln, err := p.listen(addr) + ln, err := p.listenAsChild(addr) if err != nil { return err } - - p.ln = ln - - if p.OnMasterDeath != nil { - go p.watchMaster(os.Getppid()) - } - return p.ServeFunc(ln) } @@ -277,20 +373,17 @@ func (p *Prefork) ListenAndServe(addr string) error { // ListenAndServeTLS serves HTTPS requests from the given TCP addr. // -// certFile and keyFile are paths to TLS certificate and key files. +// certKey is the path to the TLS private key file. +// certFile is the path to the TLS certificate file. +// +// Note: parameter order is (addr, certKey, certFile) — key before cert. +// Internally forwards to ServeTLSFunc as (certFile, certKey). func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error { if IsChild() { - ln, err := p.listen(addr) + ln, err := p.listenAsChild(addr) if err != nil { return err } - - p.ln = ln - - if p.OnMasterDeath != nil { - go p.watchMaster(os.Getppid()) - } - return p.ServeTLSFunc(ln, certFile, certKey) } @@ -302,17 +395,10 @@ func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error { // certData and keyData must contain valid TLS certificate and key data. func (p *Prefork) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) error { if IsChild() { - ln, err := p.listen(addr) + ln, err := p.listenAsChild(addr) if err != nil { return err } - - p.ln = ln - - if p.OnMasterDeath != nil { - go p.watchMaster(os.Getppid()) - } - return p.ServeTLSEmbedFunc(ln, certData, keyData) } diff --git a/prefork/prefork_test.go b/prefork/prefork_test.go index 8236e124e4..9240f49f7f 100644 --- a/prefork/prefork_test.go +++ b/prefork/prefork_test.go @@ -1,10 +1,12 @@ package prefork import ( + "errors" "fmt" "math/rand" "net" "os" + "os/exec" "reflect" "runtime" "testing" @@ -224,3 +226,282 @@ func Test_ListenAndServeTLSEmbed(t *testing.T) { t.Error("Prefork.ln is nil") } } + +func Test_Prefork_Logger(t *testing.T) { + t.Parallel() + + s := &fasthttp.Server{} + p := New(s) + + // Test default logger + logger := p.logger() + if logger == nil { + t.Error("Default logger should not be nil") + } + + // Test custom logger + customLogger := &testLogger{} + p.Logger = customLogger + if p.logger() != customLogger { + t.Error("Custom logger should be returned") + } +} + +type testLogger struct { + messages []string +} + +func (l *testLogger) Printf(format string, args ...any) { + l.messages = append(l.messages, fmt.Sprintf(format, args...)) +} + +func Test_Prefork_OnMasterDeath(t *testing.T) { + t.Parallel() + + var called bool + p := &Prefork{ + OnMasterDeath: func() { + called = true + }, + } + + if p.OnMasterDeath == nil { + t.Error("OnMasterDeath should not be nil") + } + + p.OnMasterDeath() + if !called { + t.Error("OnMasterDeath was not called") + } +} + +func Test_Prefork_Callbacks_NotNil(t *testing.T) { + t.Parallel() + + var spawnCalled bool + var readyCalled bool + var recoverCalled bool + + p := &Prefork{ + OnChildSpawn: func(pid int) error { + spawnCalled = true + return nil + }, + OnMasterReady: func(childPIDs []int) error { + readyCalled = true + return nil + }, + OnChildRecover: func(pid int) error { + recoverCalled = true + return nil + }, + } + + // Test that callbacks are set + if p.OnChildSpawn == nil { + t.Error("OnChildSpawn should not be nil") + } + if p.OnMasterReady == nil { + t.Error("OnMasterReady should not be nil") + } + if p.OnChildRecover == nil { + t.Error("OnChildRecover should not be nil") + } + + // Test that callbacks can be called + _ = p.OnChildSpawn(1234) + _ = p.OnMasterReady([]int{1234, 5678}) + _ = p.OnChildRecover(9999) + + if !spawnCalled { + t.Error("OnChildSpawn was not called") + } + if !readyCalled { + t.Error("OnMasterReady was not called") + } + if !recoverCalled { + t.Error("OnChildRecover was not called") + } +} + +func Test_Prefork_Callbacks_Nil(t *testing.T) { + t.Parallel() + + // Test that nil callbacks don't panic when checked + p := &Prefork{} + + if p.OnChildSpawn != nil { + t.Error("OnChildSpawn should be nil by default") + } + if p.OnMasterReady != nil { + t.Error("OnMasterReady should be nil by default") + } + if p.OnChildRecover != nil { + t.Error("OnChildRecover should be nil by default") + } +} + +func Test_Prefork_RecoverThreshold(t *testing.T) { + t.Parallel() + + s := &fasthttp.Server{} + p := New(s) + + // Default should be GOMAXPROCS/2 + expected := runtime.GOMAXPROCS(0) / 2 + if p.RecoverThreshold != expected { + t.Errorf("RecoverThreshold == %d, want %d", p.RecoverThreshold, expected) + } + + // Test custom threshold + p.RecoverThreshold = 10 + if p.RecoverThreshold != 10 { + t.Errorf("RecoverThreshold == %d, want %d", p.RecoverThreshold, 10) + } +} + +func Test_ErrOverRecovery(t *testing.T) { + t.Parallel() + + if ErrOverRecovery == nil { + t.Error("ErrOverRecovery should not be nil") + } + if ErrOverRecovery.Error() != "exceeding the value of RecoverThreshold" { + t.Errorf("ErrOverRecovery message incorrect: %s", ErrOverRecovery.Error()) + } +} + +func Test_ErrOnlyReuseportOnWindows(t *testing.T) { + t.Parallel() + + if ErrOnlyReuseportOnWindows == nil { + t.Error("ErrOnlyReuseportOnWindows should not be nil") + } + if ErrOnlyReuseportOnWindows.Error() != "windows only supports Reuseport = true" { + t.Errorf("ErrOnlyReuseportOnWindows message incorrect: %s", ErrOnlyReuseportOnWindows.Error()) + } +} + +func Test_Listen_ChildCreatesListener(t *testing.T) { + // This test can't run parallel as it modifies env. + + setUp() + defer tearDown() + + p := &Prefork{ + Reuseport: true, + } + addr := getAddr() + + ln, err := p.listen(addr) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer ln.Close() + + if ln == nil { + t.Error("Listener should not be nil") + } +} + +func Test_OnChildSpawn_Error(t *testing.T) { + t.Parallel() + + errExpected := errors.New("spawn callback error") + p := &Prefork{ + OnChildSpawn: func(pid int) error { + return errExpected + }, + } + + // Test that error is returned correctly + err := p.OnChildSpawn(1234) + if err != errExpected { + t.Errorf("OnChildSpawn error == %v, want %v", err, errExpected) + } +} + +func Test_OnMasterReady_Error(t *testing.T) { + t.Parallel() + + errExpected := errors.New("master ready callback error") + p := &Prefork{ + OnMasterReady: func(childPIDs []int) error { + return errExpected + }, + } + + // Test that error is returned correctly + err := p.OnMasterReady([]int{1, 2, 3}) + if err != errExpected { + t.Errorf("OnMasterReady error == %v, want %v", err, errExpected) + } +} + +func Test_OnMasterReady_ReceivesPIDs(t *testing.T) { + t.Parallel() + + var receivedPIDs []int + p := &Prefork{ + OnMasterReady: func(childPIDs []int) error { + receivedPIDs = childPIDs + return nil + }, + } + + expectedPIDs := []int{100, 200, 300} + _ = p.OnMasterReady(expectedPIDs) + + if len(receivedPIDs) != len(expectedPIDs) { + t.Errorf("Received %d PIDs, want %d", len(receivedPIDs), len(expectedPIDs)) + } + + for i, pid := range expectedPIDs { + if receivedPIDs[i] != pid { + t.Errorf("PID[%d] == %d, want %d", i, receivedPIDs[i], pid) + } + } +} + +func Test_CommandProducer(t *testing.T) { + t.Parallel() + + var producerCalled bool + p := &Prefork{ + CommandProducer: func(files []*os.File) (*exec.Cmd, error) { + producerCalled = true + // Re-exec the test binary with a no-op flag for hermetic testing + cmd := exec.Command(os.Args[0], "-test.run=^$") + cmd.ExtraFiles = files + cmd.Env = append(os.Environ(), preforkChildEnvVariable+"=1") + err := cmd.Start() + return cmd, err + }, + } + + if p.CommandProducer == nil { + t.Error("CommandProducer should not be nil") + } + + cmd, err := p.doCommand() + if err != nil { + t.Fatalf("doCommand failed: %v", err) + } + + _ = cmd.Wait() + + if !producerCalled { + t.Error("CommandProducer was not called") + } +} + +func Test_CommandProducer_Nil_UsesDefault(t *testing.T) { + t.Parallel() + + p := &Prefork{} + + // Verify default CommandProducer is nil + if p.CommandProducer != nil { + t.Error("CommandProducer should be nil by default") + } +}