diff --git a/constants.go b/constants.go index 5f93fe1cd4b..16fed939a09 100644 --- a/constants.go +++ b/constants.go @@ -336,3 +336,8 @@ const ( ConstraintDatetime = "datetime" ConstraintRegex = "regex" ) + +// OS identifiers +const ( + windowsOS = "windows" +) diff --git a/go.mod b/go.mod index 5c61e1574c1..5c228bb3f03 100644 --- a/go.mod +++ b/go.mod @@ -29,3 +29,5 @@ require ( golang.org/x/text v0.36.0 gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/valyala/fasthttp => github.com/ReneWerner87/fasthttp v0.0.0-20260413063825-be2ef67270a3 diff --git a/go.sum b/go.sum index c53ab0b9f86..39e747a7e0f 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/ReneWerner87/fasthttp v0.0.0-20260413063825-be2ef67270a3 h1:Xf2qNxpXT+kMFbqt6ZzNJOQ8n07zGFsteSa9DJgAYKY= +github.com/ReneWerner87/fasthttp v0.0.0-20260413063825-be2ef67270a3/go.mod h1:oDZEHHkJ/Buyklg6uURmYs19442zFSnCIfX3j1FY3pE= github.com/andybalholm/brotli v1.2.1 h1:R+f5xP285VArJDRgowrfb9DqL18yVK0gKAW/F+eTWro= github.com/andybalholm/brotli v1.2.1/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -28,8 +30,6 @@ github.com/tinylib/msgp v1.6.3 h1:bCSxiTz386UTgyT1i0MSCvdbWjVW+8sG3PjkGsZQt4s= github.com/tinylib/msgp v1.6.3/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.70.0 h1:LAhMGcWk13QZWm85+eg8ZBNbrq5mnkWFGbHMUJHIdXA= -github.com/valyala/fasthttp v1.70.0/go.mod h1:oDZEHHkJ/Buyklg6uURmYs19442zFSnCIfX3j1FY3pE= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= diff --git a/hooks_test.go b/hooks_test.go index 8eb209bf35a..bc0cf323376 100644 --- a/hooks_test.go +++ b/hooks_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/bytebufferpool" + "github.com/valyala/fasthttp/prefork" "github.com/gofiber/fiber/v3/log" ) @@ -519,7 +520,8 @@ func Test_ListenData_Hook_HelperFunctions(t *testing.T) { } func Test_Hook_OnListenPrefork(t *testing.T) { - t.Parallel() + testPreforkMaster = true + app := New() buf := bytebufferpool.Get() @@ -532,12 +534,12 @@ func Test_Hook_OnListenPrefork(t *testing.T) { return nil }) - go func() { - time.Sleep(1000 * time.Millisecond) - assert.NoError(t, app.Shutdown()) - }() - - require.NoError(t, app.Listen(":0", ListenConfig{DisableStartupMessage: true, EnablePrefork: true})) + err := app.Listen(":0", ListenConfig{ + DisableStartupMessage: true, + EnablePrefork: true, + PreforkRecoverThreshold: 1, + }) + require.ErrorIs(t, err, prefork.ErrOverRecovery) require.Equal(t, "ready", buf.String()) } @@ -548,17 +550,17 @@ func Test_Hook_OnHook(t *testing.T) { testPreforkMaster = true testOnPrefork = true - go func() { - time.Sleep(1000 * time.Millisecond) - assert.NoError(t, app.Shutdown()) - }() - app.Hooks().OnFork(func(pid int) error { require.Equal(t, 1, pid) return nil }) - require.NoError(t, app.prefork(":0", nil, &ListenConfig{DisableStartupMessage: true, EnablePrefork: true})) + err := app.prefork(":0", nil, &ListenConfig{ + DisableStartupMessage: true, + EnablePrefork: true, + PreforkRecoverThreshold: 1, + }) + require.ErrorIs(t, err, prefork.ErrOverRecovery) } func Test_Hook_OnMount(t *testing.T) { diff --git a/listen.go b/listen.go index cd085112905..1f9edafdb5f 100644 --- a/listen.go +++ b/listen.go @@ -126,6 +126,19 @@ type ListenConfig struct { // Default: false EnablePrefork bool `json:"enable_prefork"` + // PreforkRecoverThreshold defines the maximum number of times a child process + // can be restarted after crashing before the master process exits with an error. + // This only applies when EnablePrefork is true. + // + // Default: runtime.GOMAXPROCS(0) / 2 + PreforkRecoverThreshold int `json:"prefork_recover_threshold"` + + // PreforkLogger sets a custom logger for the prefork process manager. + // This only applies when EnablePrefork is true. + // + // Default: Fiber's built-in logger (log.Infof) + PreforkLogger PreforkLoggerInterface `json:"prefork_logger"` + // If set to true, will print all routes with their method, path and handler. // // Default: false diff --git a/listen_test.go b/listen_test.go index e02f8d0d1df..237e8d6754e 100644 --- a/listen_test.go +++ b/listen_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" + "github.com/valyala/fasthttp/prefork" "golang.org/x/crypto/acme/autocert" ) @@ -169,7 +170,12 @@ func Test_Listen_Prefork(t *testing.T) { app := New() - require.NoError(t, app.Listen(":0", ListenConfig{DisableStartupMessage: true, EnablePrefork: true})) + err := app.Listen(":0", ListenConfig{ + DisableStartupMessage: true, + EnablePrefork: true, + PreforkRecoverThreshold: 1, + }) + require.ErrorIs(t, err, prefork.ErrOverRecovery) } // go test -run Test_Listen_TLSMinVersion @@ -202,11 +208,13 @@ func Test_Listen_TLSMinVersion(t *testing.T) { require.NoError(t, app.Listen(":0", ListenConfig{TLSMinVersion: tls.VersionTLS13})) // Valid TLSMinVersion with Prefork - go func() { - time.Sleep(1000 * time.Millisecond) - assert.NoError(t, app.Shutdown()) - }() - require.NoError(t, app.Listen(":0", ListenConfig{DisableStartupMessage: true, EnablePrefork: true, TLSMinVersion: tls.VersionTLS13})) + err := app.Listen(":0", ListenConfig{ + DisableStartupMessage: true, + EnablePrefork: true, + TLSMinVersion: tls.VersionTLS13, + PreforkRecoverThreshold: 1, + }) + require.ErrorIs(t, err, prefork.ErrOverRecovery) } // go test -run Test_Listen_TLS @@ -244,17 +252,14 @@ func Test_Listen_TLS_Prefork(t *testing.T) { CertKeyFile: "./.github/testdata/template.tmpl", })) - go func() { - time.Sleep(1000 * time.Millisecond) - assert.NoError(t, app.Shutdown()) - }() - - require.NoError(t, app.Listen(":0", ListenConfig{ - DisableStartupMessage: true, - EnablePrefork: true, - CertFile: "./.github/testdata/ssl.pem", - CertKeyFile: "./.github/testdata/ssl.key", - })) + tlsErr := app.Listen(":0", ListenConfig{ + DisableStartupMessage: true, + EnablePrefork: true, + CertFile: "./.github/testdata/ssl.pem", + CertKeyFile: "./.github/testdata/ssl.key", + PreforkRecoverThreshold: 1, + }) + require.ErrorIs(t, tlsErr, prefork.ErrOverRecovery) } // go test -run Test_Listen_MutualTLS @@ -295,18 +300,15 @@ func Test_Listen_MutualTLS_Prefork(t *testing.T) { CertClientFile: "./.github/testdata/ca-chain.cert.pem", })) - go func() { - time.Sleep(1000 * time.Millisecond) - assert.NoError(t, app.Shutdown()) - }() - - require.NoError(t, app.Listen(":0", ListenConfig{ - DisableStartupMessage: true, - EnablePrefork: true, - CertFile: "./.github/testdata/ssl.pem", - CertKeyFile: "./.github/testdata/ssl.key", - CertClientFile: "./.github/testdata/ca-chain.cert.pem", - })) + mtlsErr := app.Listen(":0", ListenConfig{ + DisableStartupMessage: true, + EnablePrefork: true, + CertFile: "./.github/testdata/ssl.pem", + CertKeyFile: "./.github/testdata/ssl.key", + CertClientFile: "./.github/testdata/ca-chain.cert.pem", + PreforkRecoverThreshold: 1, + }) + require.ErrorIs(t, mtlsErr, prefork.ErrOverRecovery) } // go test -run Test_Listener diff --git a/prefork.go b/prefork.go index 4412c4f141e..0e8e1408127 100644 --- a/prefork.go +++ b/prefork.go @@ -2,66 +2,76 @@ package fiber import ( "crypto/tls" - "errors" "fmt" "net" "os" "os/exec" "runtime" "sync/atomic" - "time" - "github.com/valyala/fasthttp/reuseport" - - "github.com/gofiber/fiber/v3/log" -) - -const ( - envPreforkChildKey = "FIBER_PREFORK_CHILD" - envPreforkChildVal = "1" - sleepDuration = 100 * time.Millisecond - windowsOS = "windows" + "github.com/valyala/fasthttp/prefork" ) +// Test seams for prefork testing - allows injecting dummy commands var ( testPreforkMaster = false testOnPrefork = false + dummyPid = 1 + dummyChildCmd atomic.Value ) // IsChild determines if the current process is a child of Prefork func IsChild() bool { - return os.Getenv(envPreforkChildKey) == envPreforkChildVal + return prefork.IsChild() } -// prefork manages child processes to make use of the OS REUSEPORT or REUSEADDR feature +// prefork manages child processes to make use of the OS REUSEPORT feature. +// It delegates to fasthttp's prefork package to avoid duplicating process management logic. func (app *App) prefork(addr string, tlsConfig *tls.Config, cfg *ListenConfig) error { if cfg == nil { cfg = &ListenConfig{} } - var ln net.Listener - var err error - - // 👶 child process 👶 - if IsChild() { - // use 1 cpu core per child process - runtime.GOMAXPROCS(1) - // Linux will use SO_REUSEPORT and Windows falls back to SO_REUSEADDR - // Only tcp4 or tcp6 is supported when preforking, both are not supported - if ln, err = reuseport.Listen(cfg.ListenerNetwork, addr); err != nil { - if !cfg.DisableStartupMessage { - time.Sleep(sleepDuration) // avoid colliding with startup message + + // Determine RecoverThreshold + recoverThreshold := cfg.PreforkRecoverThreshold + if recoverThreshold == 0 { + recoverThreshold = max(1, runtime.GOMAXPROCS(0)/2) + } + + // Use configured logger or default to Fiber's log package + var logger prefork.Logger = preforkLogger{} //nolint:wastedassign // fallback default + if cfg.PreforkLogger != nil { + logger = cfg.PreforkLogger + } + + p := &prefork.Prefork{ + Network: cfg.ListenerNetwork, + Reuseport: true, + RecoverThreshold: recoverThreshold, + Logger: logger, + OnMasterDeath: func() { os.Exit(1) }, //nolint:revive // Exiting child process is intentional + } + + // Use test command producer if in test mode + if testPreforkMaster { + p.CommandProducer = func(_ []*os.File) (*exec.Cmd, error) { + cmd := dummyCmd() + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Start(); err != nil { + return cmd, fmt.Errorf("prefork: failed to start test command: %w", err) } - return fmt.Errorf("prefork: %w", err) + return cmd, nil } + } + + // Child process: serve function wraps TLS, starts up process, etc. + p.ServeFunc = func(ln net.Listener) error { // wrap a tls config around the listener if provided if tlsConfig != nil { ln = tls.NewListener(ln, tlsConfig) } - // kill current child proc when master exits - masterPID := os.Getppid() - go watchMaster(masterPID) - // prepare the server for the start app.startupProcess() @@ -73,57 +83,29 @@ func (app *App) prefork(addr string, tlsConfig *tls.Config, cfg *ListenConfig) e return app.server.Serve(ln) } - // 👮 master process 👮 - type child struct { - err error - pid int - } - // create variables - maxProcs := runtime.GOMAXPROCS(0) - children := make(map[int]*exec.Cmd) - channel := make(chan child, maxProcs) - - // kill child procs when master exits - defer func() { - for _, proc := range children { - if err = proc.Process.Kill(); err != nil { - if !errors.Is(err, os.ErrProcessDone) { - log.Errorf("prefork: failed to kill child: %v", err) - } + // Master callback: child spawned → execute OnFork hooks + p.OnChildSpawn = func(pid int) error { + if app.hooks != nil { + if testOnPrefork { + app.hooks.executeOnForkHooks(dummyPid) + } else { + app.hooks.executeOnForkHooks(pid) } } - }() - - // collect child pids - var childPIDs []int - - // launch child procs - for range maxProcs { - cmd := exec.Command(os.Args[0], os.Args[1:]...) //nolint:gosec // It's fine to launch the same process again - if testPreforkMaster { - // When test prefork master, - // just start the child process with a dummy cmd, - // which will exit soon - cmd = dummyCmd() - } - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - // add fiber prefork child flag into child proc env - cmd.Env = append(os.Environ(), - fmt.Sprintf("%s=%s", envPreforkChildKey, envPreforkChildVal), - ) - - if err = cmd.Start(); err != nil { - return fmt.Errorf("failed to start a child prefork process, error: %w", err) - } + return nil + } - // store child process - pid := cmd.Process.Pid - children[pid] = cmd - childPIDs = append(childPIDs, pid) + // Master callback: all children spawned → startup message & OnListen hooks + p.OnMasterReady = func(childPIDs []int) error { + listenData := app.prepareListenData(addr, tlsConfig != nil, cfg, childPIDs) + app.runOnListenHooks(listenData) + app.printMessages(cfg, listenData) + return nil + } - // execute fork hook + // Master callback: child recovered after crash + p.OnChildRecover = func(pid int) error { + logger.Printf("prefork: child process crashed, recovered with new PID %d", pid) if app.hooks != nil { if testOnPrefork { app.hooks.executeOnForkHooks(dummyPid) @@ -131,61 +113,16 @@ func (app *App) prefork(addr string, tlsConfig *tls.Config, cfg *ListenConfig) e app.hooks.executeOnForkHooks(pid) } } - - // notify master if child crashes - go func() { - channel <- child{pid: pid, err: cmd.Wait()} - }() + return nil } - // Run onListen hooks - // Hooks have to be run here as different as non-prefork mode due to they should run as child or master - listenData := app.prepareListenData(addr, tlsConfig != nil, cfg, childPIDs) - - app.runOnListenHooks(listenData) - - app.startupMessage(listenData, cfg) - - if cfg.EnablePrintRoutes { - app.printRoutesMessage() + if err := p.ListenAndServe(addr); err != nil { + return fmt.Errorf("prefork: %w", err) } - // return error if child crashes - return (<-channel).err + return nil } -// watchMaster watches the master process and exits if it dies. -// It detects master death by checking if the parent PID has changed, -// which happens when the master exits and the child is reparented to -// another process (often init/PID 1, but could be a subreaper). -func watchMaster(masterPID int) { - if runtime.GOOS == windowsOS { - // finds parent process, - // and waits for it to exit - p, err := os.FindProcess(masterPID) - if err == nil { - _, _ = p.Wait() //nolint:errcheck // It is fine to ignore the error here - } - os.Exit(1) //nolint:revive // Calling os.Exit is fine here in the prefork - } - // Watch for parent PID changes. When the master exits, the OS - // reparents the child to another process, causing Getppid() to change. - // Comparing against the original PID instead of hardcoding 1 ensures - // this works correctly when the master itself is PID 1 (e.g. in - // Docker containers). - const watchInterval = 500 * time.Millisecond - for range time.NewTicker(watchInterval).C { - if os.Getppid() != masterPID { - os.Exit(1) //nolint:revive // Calling os.Exit is fine here in the prefork - } - } -} - -var ( - dummyPid = 1 - dummyChildCmd atomic.Value -) - // dummyCmd is for internal prefork testing func dummyCmd() *exec.Cmd { command := "go" diff --git a/prefork_logger.go b/prefork_logger.go new file mode 100644 index 00000000000..51a5a8948f7 --- /dev/null +++ b/prefork_logger.go @@ -0,0 +1,19 @@ +package fiber + +import ( + "github.com/gofiber/fiber/v3/log" +) + +// PreforkLoggerInterface defines a logger for the prefork process manager. +// Compatible with fasthttp/prefork.Logger. +type PreforkLoggerInterface interface { + // Printf must have the same semantics as log.Printf. + Printf(format string, args ...any) +} + +// preforkLogger adapts Fiber's logger to the PreforkLoggerInterface. +type preforkLogger struct{} + +func (preforkLogger) Printf(format string, args ...any) { + log.Infof(format, args...) +} diff --git a/prefork_test.go b/prefork_test.go index 67910ec8925..b08b6dc1695 100644 --- a/prefork_test.go +++ b/prefork_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp/prefork" ) func Test_App_Prefork_Child_Process(t *testing.T) { @@ -58,18 +59,21 @@ func Test_App_Prefork_Master_Process(t *testing.T) { app := New() - go func() { - time.Sleep(1000 * time.Millisecond) - assert.NoError(t, app.Shutdown()) - }() - + // With dummy commands that exit immediately, fasthttp recovers children + // until RecoverThreshold is exceeded, then returns ErrOverRecovery. + // Use low threshold for fast test execution. cfg := listenConfigDefault() - require.NoError(t, app.prefork(":0", nil, &cfg)) + cfg.PreforkRecoverThreshold = 1 + err := app.prefork(":0", nil, &cfg) + require.ErrorIs(t, err, prefork.ErrOverRecovery) + // With invalid command, should get a start error immediately + // (error happens during initial spawning, before recovery loop) dummyChildCmd.Store("invalid") cfg = listenConfigDefault() - err := app.prefork("127.0.0.1:", nil, &cfg) + cfg.PreforkRecoverThreshold = 1 + err = app.prefork("127.0.0.1:", nil, &cfg) require.Error(t, err) dummyChildCmd.Store("go") @@ -99,8 +103,26 @@ func Test_App_Prefork_Child_Process_Never_Show_Startup_Message(t *testing.T) { require.Empty(t, out) } +func Test_IsChild(t *testing.T) { + // Without env var, should be false + require.False(t, IsChild()) + + // With env var, should be true + setupIsChild(t) + require.True(t, IsChild()) +} + +func Test_Prefork_Logger(t *testing.T) { + t.Parallel() + + l := preforkLogger{} + // Should not panic + l.Printf("test %s", "message") +} + func setupIsChild(t *testing.T) { t.Helper() - t.Setenv(envPreforkChildKey, envPreforkChildVal) + // Set the environment variable that fasthttp's prefork.IsChild() checks + t.Setenv("FASTHTTP_PREFORK_CHILD", "1") }