Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 122 additions & 36 deletions prefork/prefork.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,32 @@ type Prefork struct {
// It is recommended to set this to func() { os.Exit(1) } if no custom
// cleanup is needed.
OnMasterDeath func()

// OnChildSpawn is called in the master process whenever a new child process is spawned.
// It receives the PID of the newly spawned child process.
//
// If this callback returns an error, the prefork operation will be aborted.
OnChildSpawn func(pid int) error

// OnMasterReady is called in the master process after all child processes have been spawned.
// It receives a slice of all child process PIDs.
//
// If this callback returns an error, the prefork operation will be aborted.
OnMasterReady func(childPIDs []int) error

// OnChildRecover is called in the master process when a child process is restarted
// after a crash. It receives the PID of the newly recovered child process.
//
// The callback's error return value is ignored.
OnChildRecover func(pid int) error

// CommandProducer is called to create child process commands.
// If nil, the default implementation using os.Executable() is used.
// This can be used for testing or customizing child process behavior.
//
// The function receives the files to be passed as ExtraFiles to the child process
// and must return a started command.
CommandProducer func(files []*os.File) (*exec.Cmd, error)
}

// IsChild checks if the current thread/process is a child.
Expand Down Expand Up @@ -104,6 +130,22 @@ func (p *Prefork) logger() Logger {
}

func (p *Prefork) watchMaster(masterPID int) {
if runtime.GOOS == "windows" {
// On Windows, os.Getppid() returns a static PID that doesn't change
// when the parent exits (no reparenting). Use FindProcess+Wait instead.
proc, err := os.FindProcess(masterPID)
if err == nil {
_, _ = proc.Wait()
}
p.logger().Printf("master process died\n")
p.OnMasterDeath()
return
}

// Unix/Linux/macOS: When the master exits, the OS reparents the child
// to another process, causing Getppid() to change. Comparing against
// the original masterPID (instead of hardcoding 1) ensures this works
// correctly when the master itself is PID 1 (e.g. in Docker containers).
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()

Expand All @@ -127,9 +169,27 @@ func (p *Prefork) listen(addr string) (net.Listener, error) {
return reuseport.Listen(p.Network, addr)
}

// File descriptor 3 is the first ExtraFiles entry passed by the master process.
return net.FileListener(os.NewFile(3, ""))
}

// listenAsChild performs the common child process setup: creates the listener
// and starts watching the master process if OnMasterDeath is configured.
func (p *Prefork) listenAsChild(addr string) (net.Listener, error) {
ln, err := p.listen(addr)
if err != nil {
return nil, err
}

p.ln = ln

if p.OnMasterDeath != nil {
go p.watchMaster(os.Getppid())
}

return ln, nil
}

func (p *Prefork) setTCPListenerFiles(addr string) error {
if p.Network == "" {
p.Network = defaultNetwork
Expand Down Expand Up @@ -158,6 +218,19 @@ func (p *Prefork) setTCPListenerFiles(addr string) error {
}

func (p *Prefork) doCommand() (*exec.Cmd, error) {
// Use custom CommandProducer if provided
if p.CommandProducer != nil {
cmd, err := p.CommandProducer(p.files)
if err != nil {
return nil, err
}
if cmd == nil || cmd.Process == nil {
return nil, errors.New("prefork: CommandProducer must return a started command")
}
return cmd, nil
}

// Default implementation using os.Executable() for reliable path resolution
executable, err := os.Executable()
if err != nil {
return nil, err
Expand Down Expand Up @@ -205,25 +278,47 @@ func (p *Prefork) prefork(addr string) (err error) {

goMaxProcs := runtime.GOMAXPROCS(0)
sigCh := make(chan procSig, goMaxProcs)
childProcs := make(map[int]*exec.Cmd)
childProcs := make(map[int]*exec.Cmd, goMaxProcs)

defer func() {
for _, proc := range childProcs {
_ = proc.Process.Kill()
}
}()

// Collect child PIDs for OnMasterReady callback
childPIDs := make([]int, 0, goMaxProcs)

for range goMaxProcs {
var cmd *exec.Cmd
if cmd, err = p.doCommand(); err != nil {
p.logger().Printf("failed to start a child prefork process, error: %v\n", err)
return err
}

childProcs[cmd.Process.Pid] = cmd
go func() {
sigCh <- procSig{pid: cmd.Process.Pid, err: cmd.Wait()}
}()
pid := cmd.Process.Pid
childProcs[pid] = cmd
childPIDs = append(childPIDs, pid)

// Call OnChildSpawn callback
if p.OnChildSpawn != nil {
if err = p.OnChildSpawn(pid); err != nil {
p.logger().Printf("OnChildSpawn callback failed for PID %d: %v\n", pid, err)
return err
}
}

go func(c *exec.Cmd, pid int) {
sigCh <- procSig{pid: pid, err: c.Wait()}
}(cmd, pid)
}

// Call OnMasterReady callback after all children are spawned
if p.OnMasterReady != nil {
if err = p.OnMasterReady(childPIDs); err != nil {
p.logger().Printf("OnMasterReady callback failed: %v\n", err)
return err
}
Comment thread
ReneWerner87 marked this conversation as resolved.
}

var exitedProcs int
Expand All @@ -237,19 +332,27 @@ func (p *Prefork) prefork(addr string) (err error) {
if exitedProcs > p.RecoverThreshold {
p.logger().Printf("child prefork processes exit too many times, "+
"which exceeds the value of RecoverThreshold(%d), "+
"exiting the master process.\n", exitedProcs)
"exiting the master process.\n", p.RecoverThreshold)
err = ErrOverRecovery
break
}

var cmd *exec.Cmd
if cmd, err = p.doCommand(); err != nil {
cmd, err = p.doCommand()
if err != nil {
break
}
childProcs[cmd.Process.Pid] = cmd
go func() {
sigCh <- procSig{pid: cmd.Process.Pid, err: cmd.Wait()}
}()
pid := cmd.Process.Pid
childProcs[pid] = cmd

// Call OnChildRecover callback and ignore its returned error.
if p.OnChildRecover != nil {
_ = p.OnChildRecover(pid)
}

go func(c *exec.Cmd, pid int) {
sigCh <- procSig{pid: pid, err: c.Wait()}
}(cmd, pid)
}

return err
Expand All @@ -258,17 +361,10 @@ func (p *Prefork) prefork(addr string) (err error) {
// ListenAndServe serves HTTP requests from the given TCP addr.
func (p *Prefork) ListenAndServe(addr string) error {
if IsChild() {
ln, err := p.listen(addr)
ln, err := p.listenAsChild(addr)
if err != nil {
return err
}

p.ln = ln

if p.OnMasterDeath != nil {
go p.watchMaster(os.Getppid())
}

return p.ServeFunc(ln)
}

Expand All @@ -277,20 +373,17 @@ func (p *Prefork) ListenAndServe(addr string) error {

// ListenAndServeTLS serves HTTPS requests from the given TCP addr.
//
// certFile and keyFile are paths to TLS certificate and key files.
// certKey is the path to the TLS private key file.
// certFile is the path to the TLS certificate file.
//
// Note: parameter order is (addr, certKey, certFile) — key before cert.
// Internally forwards to ServeTLSFunc as (certFile, certKey).
func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error {
Comment thread
ReneWerner87 marked this conversation as resolved.
if IsChild() {
ln, err := p.listen(addr)
ln, err := p.listenAsChild(addr)
if err != nil {
return err
}

p.ln = ln

if p.OnMasterDeath != nil {
go p.watchMaster(os.Getppid())
}

return p.ServeTLSFunc(ln, certFile, certKey)
}

Expand All @@ -302,17 +395,10 @@ func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error {
// certData and keyData must contain valid TLS certificate and key data.
func (p *Prefork) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) error {
if IsChild() {
ln, err := p.listen(addr)
ln, err := p.listenAsChild(addr)
if err != nil {
return err
}

p.ln = ln

if p.OnMasterDeath != nil {
go p.watchMaster(os.Getppid())
}

return p.ServeTLSEmbedFunc(ln, certData, keyData)
}

Expand Down
Loading