From 5a8a4f1139160aee1a7248d991a2b91163c1d60d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Sun, 26 Apr 2026 19:08:20 +0200 Subject: [PATCH 1/9] =?UTF-8?q?=F0=9F=94=A5=20feat:=20add=20lightweight=20?= =?UTF-8?q?SSE=20middleware?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a small Fiber-native Server-Sent Events middleware focused on the transport layer: SSE headers, frame formatting, flushing, heartbeat comments, Last-Event-ID access, stream lifecycle context, and disconnect detection through write/flush errors. The implementation intentionally avoids bundling hub, topic routing, replay storage, auth helpers, metrics, or pub/sub bridges into core. Those remain application-level concerns and can be covered by recipes or separate packages. Also add middleware docs, focused tests, and an implementation plan that captures the review history and scope decisions. --- docs/middleware/sse.md | 119 +++++++++++++++ docs/whats_new.md | 8 + middleware/sse/config.go | 70 +++++++++ middleware/sse/event.go | 137 +++++++++++++++++ middleware/sse/sse.go | 220 +++++++++++++++++++++++++++ middleware/sse/sse_test.go | 294 +++++++++++++++++++++++++++++++++++++ 6 files changed, 848 insertions(+) create mode 100644 docs/middleware/sse.md create mode 100644 middleware/sse/config.go create mode 100644 middleware/sse/event.go create mode 100644 middleware/sse/sse.go create mode 100644 middleware/sse/sse_test.go diff --git a/docs/middleware/sse.md b/docs/middleware/sse.md new file mode 100644 index 00000000000..983a7c1c4e7 --- /dev/null +++ b/docs/middleware/sse.md @@ -0,0 +1,119 @@ +--- +id: sse +--- + +# SSE + +The SSE middleware provides the transport pieces for Server-Sent Events: response headers, event formatting, flushing, heartbeat comments, and disconnect detection through `Flush` errors. + +It intentionally does not include a hub, topics, authentication, replay storage, metrics, or external pub/sub bridges. Those are application concerns that can be composed around the stream handler. + +## Signatures + +```go +func New(config ...Config) fiber.Handler +``` + +## Examples + +Import the middleware package: + +```go +import ( + "time" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/sse" +) +``` + +Once your Fiber app is initialized, mount an SSE endpoint like this: + +```go +app.Get("/events", sse.New(sse.Config{ + Retry: 5 * time.Second, + Handler: func(c fiber.Ctx, stream *sse.Stream) error { + return stream.Event(sse.Event{ + Name: "message", + Data: fiber.Map{"message": "hello"}, + }) + }, +})) +``` + +For long-running streams, wait on your own event source and stop when the client disconnects: + +```go +events := make(chan string) + +app.Get("/events", sse.New(sse.Config{ + Handler: func(c fiber.Ctx, stream *sse.Stream) error { + for { + select { + case msg := <-events: + if err := stream.Event(sse.Event{Name: "message", Data: msg}); err != nil { + return err + } + case <-stream.Done(): + return stream.Err() + } + } + }, +})) +``` + +`stream.Context()` is canceled when the stream ends or a write fails, which makes it convenient to pass into database, broker, or gRPC calls: + +```go +app.Get("/events", sse.New(sse.Config{ + Handler: func(c fiber.Ctx, stream *sse.Stream) error { + rows, err := db.QueryContext(stream.Context(), "SELECT id FROM jobs") + if err != nil { + return err + } + defer rows.Close() + + return stream.Comment("connected") + }, +})) +``` + +## Config + +| Property | Type | Description | Default | +|:------------------|:-----------------------------|:----------------------------------------------|:--------------------| +| Next | `func(fiber.Ctx) bool` | Skip when the function returns `true`. | `nil` | +| Handler | `sse.Handler` | Writes events to the stream. | `nil` | +| OnClose | `func(fiber.Ctx, error)` | Called when the stream ends, with `nil` when the handler returned successfully and no stream write failed. | `nil` | +| Retry | `time.Duration` | Initial EventSource reconnect delay. | `0` | +| HeartbeatInterval | `time.Duration` | Interval for SSE comment heartbeats. | `15 * time.Second` | +| DisableHeartbeat | `bool` | Disable automatic heartbeat comments. | `false` | + +## Default Config + +```go +var ConfigDefault = Config{ + Next: nil, + Handler: nil, + OnClose: nil, + Retry: 0, + HeartbeatInterval: 15 * time.Second, + DisableHeartbeat: false, +} +``` + +## Stream + +```go +func (s *Stream) Event(event Event) error +func (s *Stream) Comment(comment string) error +func (s *Stream) Retry(retry time.Duration) error +func (s *Stream) Context() context.Context +func (s *Stream) Done() <-chan struct{} +func (s *Stream) Err() error +func (s *Stream) LastEventID() string +``` + +Every write is flushed. A failed flush closes `Done`, stores the error returned by `Err`, and lets the handler stop without relying on `fasthttp.RequestCtx.Done`, which is not a per-client disconnect signal. After a normal handler return, `Done` and `Context()` are closed while `Err()` remains `nil`; writes after that return `sse: stream closed`. + +`Config.Retry` sends the initial reconnect delay when the stream opens. `Event.Retry` changes the reconnect delay for a specific event, following the SSE wire format. diff --git a/docs/whats_new.md b/docs/whats_new.md index 3ee2d2b2981..a07231c8712 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -57,6 +57,7 @@ Here's a quick overview of the changes in Fiber `v3`: - [Proxy](#proxy) - [Recover](#recover) - [Session](#session) + - [SSE](#sse) - [🔌 Addons](#-addons) - [📋 Migration guide](#-migration-guide) @@ -1680,6 +1681,13 @@ The session middleware has undergone significant improvements in v3, focusing on For more details on these changes and migration instructions, check the [Session Middleware Migration Guide](./middleware/session.md#migration-guide). +### SSE + +Fiber now includes a small [SSE middleware](./middleware/sse.md) for Server-Sent Events. It handles native +`SendStreamWriter` setup, SSE response headers, event formatting, flushing, heartbeat comments, and +disconnect detection through flush errors while leaving application-level hubs, topics, replay stores, and +pub/sub bridges to user code or recipes. + ### Timeout The timeout middleware is now configurable. A new `Config` struct allows customizing the timeout duration, defining a handler that runs when a timeout occurs, and specifying errors to treat as timeouts. The `New` function now accepts a `Config` value instead of a duration. diff --git a/middleware/sse/config.go b/middleware/sse/config.go new file mode 100644 index 00000000000..8c1ed2511ef --- /dev/null +++ b/middleware/sse/config.go @@ -0,0 +1,70 @@ +package sse + +import ( + "time" + + "github.com/gofiber/fiber/v3" +) + +// Handler writes events to a single SSE stream. +type Handler func(c fiber.Ctx, stream *Stream) error + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c fiber.Ctx) bool + + // Handler writes events to the stream. + // + // Required. + Handler Handler + + // OnClose is called after the stream handler returns or the client disconnects. + // + // Optional. Default: nil + OnClose func(c fiber.Ctx, err error) + + // Retry controls the reconnection delay sent to clients. + // Values less than or equal to zero disable the initial retry field. + // + // Optional. Default: 0 + Retry time.Duration + + // HeartbeatInterval controls comment heartbeats used to keep intermediaries + // from closing idle streams and to detect disconnected clients. + // When DisableHeartbeat is false, values less than or equal to zero are + // replaced by the default interval. + // + // Optional. Default: 15 * time.Second + HeartbeatInterval time.Duration + + // DisableHeartbeat disables automatic comment heartbeats. + // + // Optional. Default: false + DisableHeartbeat bool +} + +// ConfigDefault is the default config. +var ConfigDefault = Config{ + Next: nil, + Handler: nil, + OnClose: nil, + Retry: 0, + HeartbeatInterval: 15 * time.Second, + DisableHeartbeat: false, +} + +// Helper function to set default values. +func configDefault(config ...Config) Config { + if len(config) < 1 { + return ConfigDefault + } + + cfg := config[0] + if !cfg.DisableHeartbeat && cfg.HeartbeatInterval <= 0 { + cfg.HeartbeatInterval = ConfigDefault.HeartbeatInterval + } + return cfg +} diff --git a/middleware/sse/event.go b/middleware/sse/event.go new file mode 100644 index 00000000000..e8ea49adcab --- /dev/null +++ b/middleware/sse/event.go @@ -0,0 +1,137 @@ +package sse + +import ( + "bufio" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/gofiber/utils/v2" +) + +var errInvalidField = errors.New("field must not contain CR or LF") + +// Event defines a single Server-Sent Event frame. +type Event struct { + // Data is written as one or more data fields. Strings and byte slices are + // written as-is; other values are JSON encoded. + Data any + + // ID sets the SSE id field. + ID string + + // Name sets the SSE event field. + Name string + + // Retry sets the SSE retry field for this event. + Retry time.Duration +} + +func writeEvent(w *bufio.Writer, event Event) error { + if event.ID != "" { + id, err := sanitizeField(event.ID) + if err != nil { + return fmt.Errorf("sse: invalid id: %w", err) + } + if _, err := fmt.Fprintf(w, "id: %s\n", id); err != nil { + return fmt.Errorf("sse: write id: %w", err) + } + } + if event.Name != "" { + name, err := sanitizeField(event.Name) + if err != nil { + return fmt.Errorf("sse: invalid event: %w", err) + } + if _, err := fmt.Fprintf(w, "event: %s\n", name); err != nil { + return fmt.Errorf("sse: write event: %w", err) + } + } + if event.Retry > 0 { + if _, err := fmt.Fprintf(w, "retry: %d\n", event.Retry.Milliseconds()); err != nil { + return fmt.Errorf("sse: write retry: %w", err) + } + } + + data, err := eventData(event.Data) + if err != nil { + return err + } + if err := writeData(w, data); err != nil { + return err + } + if _, err := w.WriteString("\n"); err != nil { + return fmt.Errorf("sse: finish event: %w", err) + } + return nil +} + +func writeComment(w *bufio.Writer, comment string) error { + comment = sanitizeComment(comment) + if comment == "" { + if _, err := w.WriteString(":\n\n"); err != nil { + return fmt.Errorf("sse: write heartbeat: %w", err) + } + return nil + } + for line := range strings.SplitSeq(comment, "\n") { + if _, err := fmt.Fprintf(w, ": %s\n", line); err != nil { + return fmt.Errorf("sse: write comment: %w", err) + } + } + if _, err := w.WriteString("\n"); err != nil { + return fmt.Errorf("sse: finish comment: %w", err) + } + return nil +} + +func eventData(data any) (string, error) { + switch value := data.(type) { + case nil: + return "", nil + case string: + return value, nil + case []byte: + return string(value), nil + case json.RawMessage: + return string(value), nil + default: + encoded, err := json.Marshal(value) + if err != nil { + return "", fmt.Errorf("sse: marshal data: %w", err) + } + return string(encoded), nil + } +} + +func writeData(w *bufio.Writer, data string) error { + data = normalizeNewlines(data) + for line := range strings.SplitSeq(data, "\n") { + if _, err := fmt.Fprintf(w, "data: %s\n", line); err != nil { + return fmt.Errorf("sse: write data: %w", err) + } + } + return nil +} + +func sanitizeField(value string) (string, error) { + if strings.ContainsAny(value, "\r\n") { + return "", errInvalidField + } + return utils.Trim(value, ' '), nil +} + +func sanitizeComment(value string) string { + value = normalizeNewlines(value) + lines := make([]string, 0, strings.Count(value, "\n")+1) + for line := range strings.SplitSeq(value, "\n") { + lines = append(lines, utils.Trim(line, ' ')) + } + return strings.Join(lines, "\n") +} + +func normalizeNewlines(value string) string { + value = strings.ReplaceAll(value, "\r\n", "\n") + return strings.ReplaceAll(value, "\r", "\n") +} diff --git a/middleware/sse/sse.go b/middleware/sse/sse.go new file mode 100644 index 00000000000..8d4ab0129c1 --- /dev/null +++ b/middleware/sse/sse.go @@ -0,0 +1,220 @@ +// Package sse provides small Server-Sent Events middleware for Fiber. +// +// The package focuses on the SSE transport: response headers, wire formatting, +// flushing, heartbeat comments, and disconnect detection via flush errors. +// Application-specific concerns such as topics, replay storage, authentication, +// and pub/sub fan-out intentionally stay outside the core middleware. +package sse + +import ( + "bufio" + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/gofiber/fiber/v3" +) + +const mimeTextEventStream = "text/event-stream" + +var errStreamClosed = errors.New("sse: stream closed") + +// New creates a new middleware handler. +func New(config ...Config) fiber.Handler { + cfg := configDefault(config...) + if cfg.Handler == nil { + panic("sse: Handler must not be nil") + } + + return func(c fiber.Ctx) error { + if cfg.Next != nil && cfg.Next(c) { + return c.Next() + } + + c.Set(fiber.HeaderContentType, mimeTextEventStream) + c.Set(fiber.HeaderCacheControl, "no-cache") + c.Set(fiber.HeaderConnection, "keep-alive") + c.Set("X-Accel-Buffering", "no") + + c.Abandon() + + streamContext := c.Context() + lastEventID := c.Get(fiber.HeaderLastEventID) + + return c.SendStreamWriter(func(w *bufio.Writer) { + stream := newStream(streamContext, w, lastEventID) + var streamErr error + defer func() { + if cfg.OnClose != nil { + cfg.OnClose(c, streamErr) + } + }() + defer stream.closeStream() + + if cfg.Retry > 0 { + streamErr = stream.Retry(cfg.Retry) + if streamErr != nil { + return + } + } + + if !cfg.DisableHeartbeat { + stopHeartbeat := stream.startHeartbeat(cfg.HeartbeatInterval) + if stopHeartbeat != nil { + defer stopHeartbeat() + } + } + + streamErr = cfg.Handler(c, stream) + if streamErr == nil { + streamErr = stream.Err() + } + }) + } +} + +// Stream is an active SSE response stream. +type Stream struct { + ctx context.Context //nolint:containedctx // Stream exposes a per-stream context canceled with the stream lifecycle. + cancel context.CancelFunc + err error + w *bufio.Writer + done chan struct{} + lastEventID string + closed bool + once sync.Once + mu sync.Mutex +} + +func newStream(ctx context.Context, w *bufio.Writer, lastEventID string) *Stream { //nolint:contextcheck // ctx is the parent for the derived stream lifecycle context. + if ctx == nil { + ctx = context.Background() + } + ctx, cancel := context.WithCancel(ctx) + return &Stream{ + ctx: ctx, + cancel: cancel, + w: w, + done: make(chan struct{}), + lastEventID: lastEventID, + } +} + +// Context returns a context canceled when the stream ends or a write fails. +func (s *Stream) Context() context.Context { + return s.ctx +} + +// Done returns a channel closed when a write fails or the handler returns. +func (s *Stream) Done() <-chan struct{} { + return s.done +} + +// LastEventID returns the Last-Event-ID header value sent by the client. +func (s *Stream) LastEventID() string { + return s.lastEventID +} + +// Err returns the first stream write error. +func (s *Stream) Err() error { + s.mu.Lock() + defer s.mu.Unlock() + return s.err +} + +// Event writes one SSE event and flushes it to the client. +func (s *Stream) Event(event Event) error { + return s.write(func(w *bufio.Writer) error { + return writeEvent(w, event) + }) +} + +// Comment writes one SSE comment and flushes it to the client. +func (s *Stream) Comment(comment string) error { + return s.write(func(w *bufio.Writer) error { + return writeComment(w, comment) + }) +} + +// Retry writes an SSE retry field and flushes it to the client. +func (s *Stream) Retry(retry time.Duration) error { + if retry <= 0 { + return nil + } + return s.write(func(w *bufio.Writer) error { + _, err := fmt.Fprintf(w, "retry: %d\n\n", retry.Milliseconds()) + if err != nil { + return fmt.Errorf("sse: write retry: %w", err) + } + return nil + }) +} + +func (s *Stream) write(fn func(w *bufio.Writer) error) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.err != nil { + return s.err + } + if s.closed { + return errStreamClosed + } + if err := fn(s.w); err != nil { + return s.failLocked(err) + } + if err := s.w.Flush(); err != nil { + return s.failLocked(err) + } + return nil +} + +func (s *Stream) failLocked(err error) error { + s.err = err + s.closed = true + s.once.Do(func() { + s.cancel() + close(s.done) + }) + return err +} + +func (s *Stream) closeStream() { + s.mu.Lock() + s.closed = true + s.mu.Unlock() + s.once.Do(func() { + s.cancel() + close(s.done) + }) +} + +func (s *Stream) startHeartbeat(interval time.Duration) func() { + if interval <= 0 { + return nil + } + + stop := make(chan struct{}) + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := s.Comment(""); err != nil { + return + } + case <-stop: + return + case <-s.Done(): + return + } + } + }() + + return func() { + close(stop) + } +} diff --git a/middleware/sse/sse_test.go b/middleware/sse/sse_test.go new file mode 100644 index 00000000000..5fd27d7ca19 --- /dev/null +++ b/middleware/sse/sse_test.go @@ -0,0 +1,294 @@ +package sse + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" +) + +type panicStringer struct{} + +func (*panicStringer) String() string { + panic("String must not be called for SSE data") +} + +func Test_SSE_EventWritesFrame(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{ + ID: " 42 ", + Name: "update", + Data: "one\r\ntwo", + Retry: 2500 * time.Millisecond, + })) + require.NoError(t, w.Flush()) + + require.Equal(t, "id: 42\nevent: update\nretry: 2500\ndata: one\ndata: two\n\n", buf.String()) +} + +func Test_SSE_EventRejectsFieldInjection(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.ErrorIs(t, writeEvent(w, Event{ + ID: "42\nretry: 1", + Data: "ignored", + }), errInvalidField) +} + +func Test_SSE_EventJSONEncodesData(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{ + Name: "message", + Data: map[string]string{"hello": "world"}, + })) + require.NoError(t, w.Flush()) + + require.JSONEq(t, `{"hello":"world"}`, stringsTrimData(buf.String())) +} + +func Test_SSE_EventJSONEncodesTypedNilStringer(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + var value *panicStringer + + require.NoError(t, writeEvent(w, Event{Data: value})) + require.NoError(t, w.Flush()) + + require.Equal(t, "data: null\n\n", buf.String()) +} + +func Test_SSE_CommentSanitizesLines(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeComment(w, " first\r\nsecond ")) + require.NoError(t, w.Flush()) + + require.Equal(t, ": first\n: second\n\n", buf.String()) +} + +func Test_SSE_NewWritesHeadersAndEvents(t *testing.T) { + t.Parallel() + + app := fiber.New() + var capturedLastEventID string + app.Get("/events", New(Config{ + Retry: time.Second, + DisableHeartbeat: true, + Handler: func(_ fiber.Ctx, stream *Stream) error { + capturedLastEventID = stream.LastEventID() + return stream.Event(Event{Name: "ready", Data: "ok"}) + }, + })) + + req := httptest.NewRequest(fiber.MethodGet, "/events", http.NoBody) + req.Header.Set("Last-Event-ID", "last-1") + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, "last-1", capturedLastEventID) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + require.Equal(t, mimeTextEventStream, resp.Header.Get(fiber.HeaderContentType)) + require.Equal(t, "no-cache", resp.Header.Get(fiber.HeaderCacheControl)) + require.Equal(t, "keep-alive", resp.Header.Get(fiber.HeaderConnection)) + require.Equal(t, "no", resp.Header.Get("X-Accel-Buffering")) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "retry: 1000\n\nevent: ready\ndata: ok\n\n", string(body)) +} + +func Test_SSE_NewWritesHeartbeat(t *testing.T) { + t.Parallel() + + app := fiber.New() + app.Get("/events", New(Config{ + HeartbeatInterval: 10 * time.Millisecond, + Handler: func(_ fiber.Ctx, stream *Stream) error { + select { + case <-time.After(30 * time.Millisecond): + return nil + case <-stream.Done(): + return stream.Err() + } + }, + })) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/events", http.NoBody)) + require.NoError(t, err) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Contains(t, string(body), ":\n\n") +} + +func Test_SSE_StreamComment(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(context.Background(), bufio.NewWriter(&buf), "") + defer stream.closeStream() + + require.NoError(t, stream.Comment("hello")) + require.Equal(t, ": hello\n\n", buf.String()) +} + +func Test_SSE_StreamConcurrentWrites(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(context.Background(), bufio.NewWriter(&buf), "") + defer stream.closeStream() + + const writers = 16 + errs := make(chan error, writers) + var wg sync.WaitGroup + wg.Add(writers) + for i := range writers { + go func(data int) { + defer wg.Done() + errs <- stream.Event(Event{Name: "message", Data: data}) + }(i) + } + wg.Wait() + close(errs) + + for err := range errs { + require.NoError(t, err) + } + + require.Equal(t, writers, strings.Count(buf.String(), "event: message\n")) + require.Equal(t, writers, strings.Count(buf.String(), "data: ")) +} + +func Test_SSE_StreamContextCanceledOnClose(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(context.Background(), bufio.NewWriter(&buf), "") + stream.closeStream() + + select { + case <-stream.Context().Done(): + case <-time.After(time.Second): + t.Fatal("stream context was not canceled") + } +} + +func Test_SSE_StreamErrAfterNormalClose(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(context.Background(), bufio.NewWriter(&buf), "") + stream.closeStream() + + require.NoError(t, stream.Err()) + require.ErrorIs(t, stream.Event(Event{Data: "late"}), errStreamClosed) +} + +func Test_SSE_StreamWriteError(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failed") + stream := newStream(context.Background(), bufio.NewWriter(errWriter{err: writeErr}), "") + + require.ErrorIs(t, stream.Event(Event{Data: "hello"}), writeErr) + require.ErrorIs(t, stream.Err(), writeErr) + select { + case <-stream.Done(): + case <-time.After(time.Second): + t.Fatal("stream was not closed after write error") + } +} + +func Test_SSE_NextSkipsMiddleware(t *testing.T) { + t.Parallel() + + app := fiber.New() + app.Use(New(Config{ + Next: func(fiber.Ctx) bool { + return true + }, + Handler: func(_ fiber.Ctx, stream *Stream) error { + return stream.Event(Event{Data: "ignored"}) + }, + })) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("next") + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "next", string(body)) +} + +func Test_SSE_HandlerErrorCallsOnClose(t *testing.T) { + t.Parallel() + + handlerErr := errors.New("boom") + closed := make(chan error, 1) + + app := fiber.New() + app.Get("/events", New(Config{ + DisableHeartbeat: true, + Handler: func(fiber.Ctx, *Stream) error { + return handlerErr + }, + OnClose: func(_ fiber.Ctx, err error) { + closed <- err + }, + })) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/events", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + require.ErrorIs(t, <-closed, handlerErr) +} + +func Test_SSE_NewPanicsWithoutHandler(t *testing.T) { + t.Parallel() + + require.PanicsWithValue(t, "sse: Handler must not be nil", func() { + New() + }) +} + +func stringsTrimData(frame string) string { + const prefix = "event: message\ndata: " + frame = strings.TrimPrefix(frame, prefix) + return strings.TrimSuffix(frame, "\n\n") +} + +type errWriter struct { + err error +} + +func (w errWriter) Write([]byte) (int, error) { + return 0, fmt.Errorf("test writer: %w", w.err) +} From 1aacd75fdde5da19971a701542f2d09e4ee1dd06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Sun, 26 Apr 2026 19:51:59 +0200 Subject: [PATCH 2/9] fix: address SSE review feedback --- ctx_interface_gen.go | 9 +++- docs/middleware/sse.md | 2 +- middleware/limiter/limiter_test.go | 2 +- middleware/sse/event.go | 43 +++++++++++++------ middleware/sse/sse.go | 5 ++- middleware/sse/sse_test.go | 67 ++++++++++++++++++++++++++++-- req_interface_gen.go | 4 ++ res_interface_gen.go | 5 ++- 8 files changed, 113 insertions(+), 24 deletions(-) diff --git a/ctx_interface_gen.go b/ctx_interface_gen.go index 99fc5da1f52..e735861aa98 100644 --- a/ctx_interface_gen.go +++ b/ctx_interface_gen.go @@ -256,6 +256,8 @@ type Ctx interface { // Make copies or use the Immutable setting to use the value outside the Handler. Cookies(key string, defaultValue ...string) string // FormFile returns the first file by key from a MultipartForm. + // The multipart form is parsed using the application's BodyLimit to prevent + // unbounded memory usage. FormFile(key string) (*multipart.FileHeader, error) // FormValue returns the first value by key from a MultipartForm. // Search is performed in QueryArgs, PostArgs, MultipartForm and FormFile in this particular order. @@ -263,6 +265,8 @@ type Ctx interface { // If a default value is given, it will return that value if the form value does not exist. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. + // When the request is a multipart form, it is parsed using the application's + // BodyLimit so the configured limit is consistently enforced. FormValue(key string, defaultValue ...string) string // Fresh returns true when the response is still “fresh” in the client's cache, // otherwise false is returned to indicate that the client cache is now stale @@ -398,7 +402,7 @@ type Ctx interface { // AutoFormat performs content-negotiation on the Accept HTTP header. // It uses Accepts to select a proper format. // The supported content types are text/html, text/plain, application/json, application/xml, application/vnd.msgpack, and application/cbor. - // When text/html is selected, the body is treated as plain text and HTML-escaped before being wrapped in a

element. + // When text/html is selected, the body is treated as plain text and HTML-escaped before being wrapped in a `

` element. // For more flexible content negotiation, use Format. // If the header is not specified or there is no proper format, text/plain is used. AutoFormat(body any) error @@ -431,7 +435,8 @@ type Ctx interface { Links(link ...string) // Location sets the response Location HTTP header to the specified path parameter. Location(path string) - // getLocationFromRoute get URL location from route using parameters + // getLocationFromRoute gets the URL location from a route using parameters. + // Nil receivers and missing routes return ErrNotFound to match Route.URL semantics. getLocationFromRoute(route *Route, params Map) (string, error) // GetRouteURL generates URLs to named routes, with parameters. URLs are relative, for example: "/user/1831" GetRouteURL(routeName string, params Map) (string, error) diff --git a/docs/middleware/sse.md b/docs/middleware/sse.md index 983a7c1c4e7..fbffc19ae83 100644 --- a/docs/middleware/sse.md +++ b/docs/middleware/sse.md @@ -114,6 +114,6 @@ func (s *Stream) Err() error func (s *Stream) LastEventID() string ``` -Every write is flushed. A failed flush closes `Done`, stores the error returned by `Err`, and lets the handler stop without relying on `fasthttp.RequestCtx.Done`, which is not a per-client disconnect signal. After a normal handler return, `Done` and `Context()` are closed while `Err()` remains `nil`; writes after that return `sse: stream closed`. +Every write is flushed. A failed flush closes `Done`, stores the error returned by `Err`, and lets the handler stop without relying on `fasthttp.RequestCtx.Done`, which is not a per-client disconnect signal. After a normal handler return, `Done` is closed and `Context()` is canceled while `Err()` remains `nil`; writes after that return `sse: stream closed`. `Config.Retry` sends the initial reconnect delay when the stream opens. `Event.Retry` changes the reconnect delay for a specific event, following the SSE wire format. diff --git a/middleware/limiter/limiter_test.go b/middleware/limiter/limiter_test.go index a6ec1af7efe..cbb746875a3 100644 --- a/middleware/limiter/limiter_test.go +++ b/middleware/limiter/limiter_test.go @@ -1080,7 +1080,7 @@ func Test_Limiter_Sliding_Window_RecalculatesAfterHandlerDelay(t *testing.T) { require.Equal(t, fiber.StatusOK, resp.StatusCode) } - time.Sleep(time.Second + 100*time.Millisecond) + time.Sleep(2*time.Second + 100*time.Millisecond) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) diff --git a/middleware/sse/event.go b/middleware/sse/event.go index e8ea49adcab..dac2435428a 100644 --- a/middleware/sse/event.go +++ b/middleware/sse/event.go @@ -2,6 +2,7 @@ package sse import ( "bufio" + "bytes" "encoding/json" "errors" "fmt" @@ -30,12 +31,15 @@ type Event struct { } func writeEvent(w *bufio.Writer, event Event) error { + var frame bytes.Buffer + fw := bufio.NewWriter(&frame) + if event.ID != "" { id, err := sanitizeField(event.ID) if err != nil { return fmt.Errorf("sse: invalid id: %w", err) } - if _, err := fmt.Fprintf(w, "id: %s\n", id); err != nil { + if _, err := fmt.Fprintf(fw, "id: %s\n", id); err != nil { return fmt.Errorf("sse: write id: %w", err) } } @@ -44,12 +48,12 @@ func writeEvent(w *bufio.Writer, event Event) error { if err != nil { return fmt.Errorf("sse: invalid event: %w", err) } - if _, err := fmt.Fprintf(w, "event: %s\n", name); err != nil { + if _, err := fmt.Fprintf(fw, "event: %s\n", name); err != nil { return fmt.Errorf("sse: write event: %w", err) } } if event.Retry > 0 { - if _, err := fmt.Fprintf(w, "retry: %d\n", event.Retry.Milliseconds()); err != nil { + if _, err := fmt.Fprintf(fw, "retry: %d\n", event.Retry.Milliseconds()); err != nil { return fmt.Errorf("sse: write retry: %w", err) } } @@ -58,12 +62,20 @@ func writeEvent(w *bufio.Writer, event Event) error { if err != nil { return err } - if err := writeData(w, data); err != nil { - return err + if data.hasData { + if err := writeData(fw, data.data); err != nil { + return err + } } - if _, err := w.WriteString("\n"); err != nil { + if _, err := fw.WriteString("\n"); err != nil { return fmt.Errorf("sse: finish event: %w", err) } + if err := fw.Flush(); err != nil { + return fmt.Errorf("sse: flush event frame: %w", err) + } + if _, err := w.Write(frame.Bytes()); err != nil { + return fmt.Errorf("sse: write event: %w", err) + } return nil } @@ -86,22 +98,27 @@ func writeComment(w *bufio.Writer, comment string) error { return nil } -func eventData(data any) (string, error) { +type eventPayload struct { + data string + hasData bool +} + +func eventData(data any) (eventPayload, error) { switch value := data.(type) { case nil: - return "", nil + return eventPayload{}, nil case string: - return value, nil + return eventPayload{data: value, hasData: true}, nil case []byte: - return string(value), nil + return eventPayload{data: string(value), hasData: true}, nil case json.RawMessage: - return string(value), nil + return eventPayload{data: string(value), hasData: true}, nil default: encoded, err := json.Marshal(value) if err != nil { - return "", fmt.Errorf("sse: marshal data: %w", err) + return eventPayload{}, fmt.Errorf("sse: marshal data: %w", err) } - return string(encoded), nil + return eventPayload{data: string(encoded), hasData: true}, nil } } diff --git a/middleware/sse/sse.go b/middleware/sse/sse.go index 8d4ab0129c1..f4914b993d7 100644 --- a/middleware/sse/sse.go +++ b/middleware/sse/sse.go @@ -197,6 +197,7 @@ func (s *Stream) startHeartbeat(interval time.Duration) func() { } stop := make(chan struct{}) + var stopOnce sync.Once go func() { ticker := time.NewTicker(interval) defer ticker.Stop() @@ -215,6 +216,8 @@ func (s *Stream) startHeartbeat(interval time.Duration) func() { }() return func() { - close(stop) + stopOnce.Do(func() { + close(stop) + }) } } diff --git a/middleware/sse/sse_test.go b/middleware/sse/sse_test.go index 5fd27d7ca19..8d9695405b9 100644 --- a/middleware/sse/sse_test.go +++ b/middleware/sse/sse_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -81,6 +82,46 @@ func Test_SSE_EventJSONEncodesTypedNilStringer(t *testing.T) { require.Equal(t, "data: null\n\n", buf.String()) } +func Test_SSE_EventOmitsDataForUntypedNil(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{ID: "42"})) + require.NoError(t, w.Flush()) + + require.Equal(t, "id: 42\n\n", buf.String()) +} + +func Test_SSE_EventDoesNotWritePartialFrameWhenDataMarshalFails(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.Error(t, writeEvent(w, Event{ + ID: "42", + Name: "broken", + Data: func() {}, + })) + require.NoError(t, w.Flush()) + + require.Empty(t, buf.String()) +} + +func Test_SSE_EventWritesRawJSONData(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{Data: json.RawMessage(`{"hello":"world"}`)})) + require.NoError(t, w.Flush()) + + require.Equal(t, "data: {\"hello\":\"world\"}\n\n", buf.String()) +} + func Test_SSE_CommentSanitizesLines(t *testing.T) { t.Parallel() @@ -128,10 +169,10 @@ func Test_SSE_NewWritesHeartbeat(t *testing.T) { app := fiber.New() app.Get("/events", New(Config{ - HeartbeatInterval: 10 * time.Millisecond, + HeartbeatInterval: 5 * time.Millisecond, Handler: func(_ fiber.Ctx, stream *Stream) error { select { - case <-time.After(30 * time.Millisecond): + case <-time.After(150 * time.Millisecond): return nil case <-stream.Done(): return stream.Err() @@ -168,7 +209,7 @@ func Test_SSE_StreamConcurrentWrites(t *testing.T) { errs := make(chan error, writers) var wg sync.WaitGroup wg.Add(writers) - for i := range writers { + for i := 0; i < writers; i++ { go func(data int) { defer wg.Done() errs <- stream.Event(Event{Name: "message", Data: data}) @@ -268,7 +309,12 @@ func Test_SSE_HandlerErrorCallsOnClose(t *testing.T) { resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/events", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) - require.ErrorIs(t, <-closed, handlerErr) + select { + case err := <-closed: + require.ErrorIs(t, err, handlerErr) + case <-time.After(time.Second): + t.Fatal("OnClose was not called") + } } func Test_SSE_NewPanicsWithoutHandler(t *testing.T) { @@ -279,6 +325,19 @@ func Test_SSE_NewPanicsWithoutHandler(t *testing.T) { }) } +func Test_SSE_StopHeartbeatIsIdempotent(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(context.Background(), bufio.NewWriter(&buf), "") + defer stream.closeStream() + + stop := stream.startHeartbeat(time.Hour) + require.NotNil(t, stop) + require.NotPanics(t, stop) + require.NotPanics(t, stop) +} + func stringsTrimData(frame string) string { const prefix = "event: message\ndata: " frame = strings.TrimPrefix(frame, prefix) diff --git a/req_interface_gen.go b/req_interface_gen.go index d150c0ec9e7..59841e7a689 100644 --- a/req_interface_gen.go +++ b/req_interface_gen.go @@ -52,6 +52,8 @@ type Req interface { // https://godoc.org/github.com/valyala/fasthttp#Request Request() *fasthttp.Request // FormFile returns the first file by key from a MultipartForm. + // The multipart form is parsed using the application's BodyLimit to prevent + // unbounded memory usage. FormFile(key string) (*multipart.FileHeader, error) // FormValue returns the first value by key from a MultipartForm. // Search is performed in QueryArgs, PostArgs, MultipartForm and FormFile in this particular order. @@ -59,6 +61,8 @@ type Req interface { // If a default value is given, it will return that value if the form value does not exist. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. + // When the request is a multipart form, it is parsed using the application's + // BodyLimit so the configured limit is consistently enforced. FormValue(key string, defaultValue ...string) string // Fresh returns true when the response is still “fresh” in the client's cache, // otherwise false is returned to indicate that the client cache is now stale diff --git a/res_interface_gen.go b/res_interface_gen.go index ebe1caefe2f..0d4aa12e889 100644 --- a/res_interface_gen.go +++ b/res_interface_gen.go @@ -45,7 +45,7 @@ type Res interface { // AutoFormat performs content-negotiation on the Accept HTTP header. // It uses Accepts to select a proper format. // The supported content types are text/html, text/plain, application/json, application/xml, application/vnd.msgpack, and application/cbor. - // When text/html is selected, the body is treated as plain text and HTML-escaped before being wrapped in a

element. + // When text/html is selected, the body is treated as plain text and HTML-escaped before being wrapped in a `

` element. // For more flexible content negotiation, use Format. // If the header is not specified or there is no proper format, text/plain is used. AutoFormat(body any) error @@ -99,7 +99,8 @@ type Res interface { // ViewBind Add vars to default view var map binding to template engine. // Variables are read by the Render method and may be overwritten. ViewBind(vars Map) error - // getLocationFromRoute get URL location from route using parameters + // getLocationFromRoute gets the URL location from a route using parameters. + // Nil receivers and missing routes return ErrNotFound to match Route.URL semantics. getLocationFromRoute(route *Route, params Map) (string, error) // GetRouteURL generates URLs to named routes, with parameters. URLs are relative, for example: "/user/1831" GetRouteURL(routeName string, params Map) (string, error) From cfe765929c014567bd740bd421b02846bdc10cb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Sun, 26 Apr 2026 19:57:06 +0200 Subject: [PATCH 3/9] test: cover SSE edge cases --- middleware/sse/sse_test.go | 93 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/middleware/sse/sse_test.go b/middleware/sse/sse_test.go index 8d9695405b9..4c25f52033e 100644 --- a/middleware/sse/sse_test.go +++ b/middleware/sse/sse_test.go @@ -122,6 +122,27 @@ func Test_SSE_EventWritesRawJSONData(t *testing.T) { require.Equal(t, "data: {\"hello\":\"world\"}\n\n", buf.String()) } +func Test_SSE_EventWritesByteSliceData(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{Data: []byte("hello\nworld")})) + require.NoError(t, w.Flush()) + + require.Equal(t, "data: hello\ndata: world\n\n", buf.String()) +} + +func Test_SSE_EventReturnsWriterError(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failed") + w := bufio.NewWriterSize(errWriter{err: writeErr}, 1) + + require.ErrorIs(t, writeEvent(w, Event{Data: "hello"}), writeErr) +} + func Test_SSE_CommentSanitizesLines(t *testing.T) { t.Parallel() @@ -134,6 +155,24 @@ func Test_SSE_CommentSanitizesLines(t *testing.T) { require.Equal(t, ": first\n: second\n\n", buf.String()) } +func Test_SSE_CommentReturnsWriterError(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failed") + w := bufio.NewWriterSize(errWriter{err: writeErr}, 1) + + require.ErrorIs(t, writeComment(w, ""), writeErr) +} + +func Test_SSE_WriteDataReturnsWriterError(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failed") + w := bufio.NewWriterSize(errWriter{err: writeErr}, 1) + + require.ErrorIs(t, writeData(w, "hello"), writeErr) +} + func Test_SSE_NewWritesHeadersAndEvents(t *testing.T) { t.Parallel() @@ -198,6 +237,28 @@ func Test_SSE_StreamComment(t *testing.T) { require.Equal(t, ": hello\n\n", buf.String()) } +func Test_SSE_StreamRetryIgnoresNonPositiveDuration(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(context.Background(), bufio.NewWriter(&buf), "") + defer stream.closeStream() + + require.NoError(t, stream.Retry(0)) + require.NoError(t, stream.Retry(-time.Second)) + require.Empty(t, buf.String()) +} + +func Test_SSE_StreamRetryReturnsWriterError(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failed") + stream := newStream(context.Background(), bufio.NewWriter(errWriter{err: writeErr}), "") + + require.ErrorIs(t, stream.Retry(time.Second), writeErr) + require.ErrorIs(t, stream.Err(), writeErr) +} + func Test_SSE_StreamConcurrentWrites(t *testing.T) { t.Parallel() @@ -240,6 +301,17 @@ func Test_SSE_StreamContextCanceledOnClose(t *testing.T) { } } +func Test_SSE_NewStreamUsesBackgroundContextWhenNil(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(nil, bufio.NewWriter(&buf), "") //nolint:staticcheck // Covers the nil fallback branch in newStream. + defer stream.closeStream() + + require.NotNil(t, stream.Context()) + require.NoError(t, stream.Context().Err()) +} + func Test_SSE_StreamErrAfterNormalClose(t *testing.T) { t.Parallel() @@ -251,6 +323,16 @@ func Test_SSE_StreamErrAfterNormalClose(t *testing.T) { require.ErrorIs(t, stream.Event(Event{Data: "late"}), errStreamClosed) } +func Test_SSE_StreamReturnsLatchedError(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failed") + stream := newStream(context.Background(), bufio.NewWriter(errWriter{err: writeErr}), "") + + require.ErrorIs(t, stream.Event(Event{Data: "hello"}), writeErr) + require.ErrorIs(t, stream.Event(Event{Data: "again"}), writeErr) +} + func Test_SSE_StreamWriteError(t *testing.T) { t.Parallel() @@ -338,6 +420,17 @@ func Test_SSE_StopHeartbeatIsIdempotent(t *testing.T) { require.NotPanics(t, stop) } +func Test_SSE_StartHeartbeatReturnsNilForNonPositiveInterval(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(context.Background(), bufio.NewWriter(&buf), "") + defer stream.closeStream() + + require.Nil(t, stream.startHeartbeat(0)) + require.Nil(t, stream.startHeartbeat(-time.Second)) +} + func stringsTrimData(frame string) string { const prefix = "event: message\ndata: " frame = strings.TrimPrefix(frame, prefix) From a039e4ab9695ee38fce626dd4138622368660617 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Sun, 26 Apr 2026 19:57:50 +0200 Subject: [PATCH 4/9] fix: preserve late SSE stream errors --- middleware/sse/sse.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/middleware/sse/sse.go b/middleware/sse/sse.go index f4914b993d7..77d97feee41 100644 --- a/middleware/sse/sse.go +++ b/middleware/sse/sse.go @@ -48,7 +48,11 @@ func New(config ...Config) fiber.Handler { var streamErr error defer func() { if cfg.OnClose != nil { - cfg.OnClose(c, streamErr) + finalErr := streamErr + if finalErr == nil { + finalErr = stream.Err() + } + cfg.OnClose(c, finalErr) } }() defer stream.closeStream() From 301ca876fdadbe7566a48665c9b5cc9dcf847756 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Sun, 26 Apr 2026 20:01:25 +0200 Subject: [PATCH 5/9] test: cover SSE close and writer errors --- middleware/sse/sse_test.go | 39 +++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/middleware/sse/sse_test.go b/middleware/sse/sse_test.go index 4c25f52033e..e89f0d92cbf 100644 --- a/middleware/sse/sse_test.go +++ b/middleware/sse/sse_test.go @@ -226,6 +226,33 @@ func Test_SSE_NewWritesHeartbeat(t *testing.T) { require.Contains(t, string(body), ":\n\n") } +func Test_SSE_OnCloseReceivesNilAfterNormalClose(t *testing.T) { + t.Parallel() + + closed := make(chan error, 1) + + app := fiber.New() + app.Get("/events", New(Config{ + DisableHeartbeat: true, + Handler: func(_ fiber.Ctx, stream *Stream) error { + return stream.Event(Event{Data: "ok"}) + }, + OnClose: func(_ fiber.Ctx, err error) { + closed <- err + }, + })) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/events", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + select { + case err := <-closed: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("OnClose was not called") + } +} + func Test_SSE_StreamComment(t *testing.T) { t.Parallel() @@ -237,6 +264,16 @@ func Test_SSE_StreamComment(t *testing.T) { require.Equal(t, ": hello\n\n", buf.String()) } +func Test_SSE_StreamCommentReturnsWriterError(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failed") + stream := newStream(context.Background(), bufio.NewWriterSize(errWriter{err: writeErr}, 1), "") + + require.ErrorIs(t, stream.Comment("hello"), writeErr) + require.ErrorIs(t, stream.Err(), writeErr) +} + func Test_SSE_StreamRetryIgnoresNonPositiveDuration(t *testing.T) { t.Parallel() @@ -253,7 +290,7 @@ func Test_SSE_StreamRetryReturnsWriterError(t *testing.T) { t.Parallel() writeErr := errors.New("write failed") - stream := newStream(context.Background(), bufio.NewWriter(errWriter{err: writeErr}), "") + stream := newStream(context.Background(), bufio.NewWriterSize(errWriter{err: writeErr}, 1), "") require.ErrorIs(t, stream.Retry(time.Second), writeErr) require.ErrorIs(t, stream.Err(), writeErr) From 1ff8f6d3e1d0092a22d040648a9924f18efb2373 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Mon, 27 Apr 2026 08:28:38 +0200 Subject: [PATCH 6/9] fix: address SSE coverage review --- docs/middleware/sse.md | 18 ++++++++++++--- middleware/sse/event.go | 49 +++++++++++++++++++++-------------------- middleware/sse/sse.go | 3 +++ 3 files changed, 43 insertions(+), 27 deletions(-) diff --git a/docs/middleware/sse.md b/docs/middleware/sse.md index fbffc19ae83..a74ae225d91 100644 --- a/docs/middleware/sse.md +++ b/docs/middleware/sse.md @@ -20,6 +20,7 @@ Import the middleware package: ```go import ( + "context" "time" "github.com/gofiber/fiber/v3" @@ -41,16 +42,27 @@ app.Get("/events", sse.New(sse.Config{ })) ``` -For long-running streams, wait on your own event source and stop when the client disconnects: +For long-running streams, subscribe each client to its own event channel and stop when the client disconnects. +A single shared channel load-balances messages across clients; use a fan-out source when every client must receive every event: ```go -events := make(chan string) +type Broker interface { + Subscribe(ctx context.Context) (<-chan string, error) +} app.Get("/events", sse.New(sse.Config{ Handler: func(c fiber.Ctx, stream *sse.Stream) error { + events, err := broker.Subscribe(stream.Context()) + if err != nil { + return err + } + for { select { - case msg := <-events: + case msg, ok := <-events: + if !ok { + return nil + } if err := stream.Event(sse.Event{Name: "message", Data: msg}); err != nil { return err } diff --git a/middleware/sse/event.go b/middleware/sse/event.go index dac2435428a..2cfcc38a032 100644 --- a/middleware/sse/event.go +++ b/middleware/sse/event.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "strconv" "strings" "time" @@ -31,48 +32,34 @@ type Event struct { } func writeEvent(w *bufio.Writer, event Event) error { + data, err := eventData(event.Data) + if err != nil { + return err + } + var frame bytes.Buffer - fw := bufio.NewWriter(&frame) if event.ID != "" { id, err := sanitizeField(event.ID) if err != nil { return fmt.Errorf("sse: invalid id: %w", err) } - if _, err := fmt.Fprintf(fw, "id: %s\n", id); err != nil { - return fmt.Errorf("sse: write id: %w", err) - } + appendField(&frame, "id", id) } if event.Name != "" { name, err := sanitizeField(event.Name) if err != nil { return fmt.Errorf("sse: invalid event: %w", err) } - if _, err := fmt.Fprintf(fw, "event: %s\n", name); err != nil { - return fmt.Errorf("sse: write event: %w", err) - } + appendField(&frame, "event", name) } if event.Retry > 0 { - if _, err := fmt.Fprintf(fw, "retry: %d\n", event.Retry.Milliseconds()); err != nil { - return fmt.Errorf("sse: write retry: %w", err) - } - } - - data, err := eventData(event.Data) - if err != nil { - return err + appendField(&frame, "retry", strconv.FormatInt(event.Retry.Milliseconds(), 10)) } if data.hasData { - if err := writeData(fw, data.data); err != nil { - return err - } - } - if _, err := fw.WriteString("\n"); err != nil { - return fmt.Errorf("sse: finish event: %w", err) - } - if err := fw.Flush(); err != nil { - return fmt.Errorf("sse: flush event frame: %w", err) + appendData(&frame, data.data) } + frame.WriteByte('\n') //nolint:errcheck // bytes.Buffer writes never fail. if _, err := w.Write(frame.Bytes()); err != nil { return fmt.Errorf("sse: write event: %w", err) } @@ -132,6 +119,20 @@ func writeData(w *bufio.Writer, data string) error { return nil } +func appendField(w *bytes.Buffer, field, value string) { + w.WriteString(field) //nolint:errcheck // bytes.Buffer writes never fail. + w.WriteString(": ") //nolint:errcheck // bytes.Buffer writes never fail. + w.WriteString(value) //nolint:errcheck // bytes.Buffer writes never fail. + w.WriteByte('\n') //nolint:errcheck // bytes.Buffer writes never fail. +} + +func appendData(w *bytes.Buffer, data string) { + data = normalizeNewlines(data) + for line := range strings.SplitSeq(data, "\n") { + appendField(w, "data", line) + } +} + func sanitizeField(value string) (string, error) { if strings.ContainsAny(value, "\r\n") { return "", errInvalidField diff --git a/middleware/sse/sse.go b/middleware/sse/sse.go index 77d97feee41..a4e7e2979df 100644 --- a/middleware/sse/sse.go +++ b/middleware/sse/sse.go @@ -202,7 +202,9 @@ func (s *Stream) startHeartbeat(interval time.Duration) func() { stop := make(chan struct{}) var stopOnce sync.Once + done := make(chan struct{}) go func() { + defer close(done) ticker := time.NewTicker(interval) defer ticker.Stop() for { @@ -223,5 +225,6 @@ func (s *Stream) startHeartbeat(interval time.Duration) func() { stopOnce.Do(func() { close(stop) }) + <-done } } From dcd817d41304c2992633955f4fe803595f1593f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Mon, 27 Apr 2026 08:35:02 +0200 Subject: [PATCH 7/9] fix: refine SSE event field handling --- middleware/sse/event.go | 16 +++++++++++---- middleware/sse/sse_test.go | 40 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/middleware/sse/event.go b/middleware/sse/event.go index 2cfcc38a032..e53974254b8 100644 --- a/middleware/sse/event.go +++ b/middleware/sse/event.go @@ -44,14 +44,18 @@ func writeEvent(w *bufio.Writer, event Event) error { if err != nil { return fmt.Errorf("sse: invalid id: %w", err) } - appendField(&frame, "id", id) + if id != "" { + appendField(&frame, "id", id) + } } if event.Name != "" { name, err := sanitizeField(event.Name) if err != nil { return fmt.Errorf("sse: invalid event: %w", err) } - appendField(&frame, "event", name) + if name != "" { + appendField(&frame, "event", name) + } } if event.Retry > 0 { appendField(&frame, "retry", strconv.FormatInt(event.Retry.Milliseconds(), 10)) @@ -110,7 +114,7 @@ func eventData(data any) (eventPayload, error) { } func writeData(w *bufio.Writer, data string) error { - data = normalizeNewlines(data) + data = trimSingleTrailingNewline(normalizeNewlines(data)) for line := range strings.SplitSeq(data, "\n") { if _, err := fmt.Fprintf(w, "data: %s\n", line); err != nil { return fmt.Errorf("sse: write data: %w", err) @@ -127,12 +131,16 @@ func appendField(w *bytes.Buffer, field, value string) { } func appendData(w *bytes.Buffer, data string) { - data = normalizeNewlines(data) + data = trimSingleTrailingNewline(normalizeNewlines(data)) for line := range strings.SplitSeq(data, "\n") { appendField(w, "data", line) } } +func trimSingleTrailingNewline(value string) string { + return strings.TrimSuffix(value, "\n") +} + func sanitizeField(value string) (string, error) { if strings.ContainsAny(value, "\r\n") { return "", errInvalidField diff --git a/middleware/sse/sse_test.go b/middleware/sse/sse_test.go index e89f0d92cbf..fa98638af53 100644 --- a/middleware/sse/sse_test.go +++ b/middleware/sse/sse_test.go @@ -54,6 +54,22 @@ func Test_SSE_EventRejectsFieldInjection(t *testing.T) { }), errInvalidField) } +func Test_SSE_EventOmitsWhitespaceOnlyFields(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{ + ID: " ", + Name: " ", + Data: "ok", + })) + require.NoError(t, w.Flush()) + + require.Equal(t, "data: ok\n\n", buf.String()) +} + func Test_SSE_EventJSONEncodesData(t *testing.T) { t.Parallel() @@ -134,6 +150,30 @@ func Test_SSE_EventWritesByteSliceData(t *testing.T) { require.Equal(t, "data: hello\ndata: world\n\n", buf.String()) } +func Test_SSE_EventTrimsSingleTrailingDataNewline(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{Data: "hello\n"})) + require.NoError(t, w.Flush()) + + require.Equal(t, "data: hello\n\n", buf.String()) +} + +func Test_SSE_EventPreservesIntentionalBlankDataLine(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{Data: "hello\n\n"})) + require.NoError(t, w.Flush()) + + require.Equal(t, "data: hello\ndata: \n\n", buf.String()) +} + func Test_SSE_EventReturnsWriterError(t *testing.T) { t.Parallel() From e71b08f76fa4bf9f7645fef6f1902ce6dd6ad323 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Wed, 29 Apr 2026 09:01:10 +0200 Subject: [PATCH 8/9] fix: address SSE review comments --- docs/whats_new.md | 2 +- middleware/sse/constants.go | 3 +++ middleware/sse/event.go | 15 +++++++++++---- middleware/sse/sse.go | 11 ++++++----- middleware/sse/sse_test.go | 24 ++++++++++++++++++++++++ 5 files changed, 45 insertions(+), 10 deletions(-) create mode 100644 middleware/sse/constants.go diff --git a/docs/whats_new.md b/docs/whats_new.md index a07231c8712..eb50b3f82ee 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -1683,7 +1683,7 @@ For more details on these changes and migration instructions, check the [Session ### SSE -Fiber now includes a small [SSE middleware](./middleware/sse.md) for Server-Sent Events. It handles native +Fiber now includes an [SSE middleware](./middleware/sse.md) for Server-Sent Events. It handles native `SendStreamWriter` setup, SSE response headers, event formatting, flushing, heartbeat comments, and disconnect detection through flush errors while leaving application-level hubs, topics, replay stores, and pub/sub bridges to user code or recipes. diff --git a/middleware/sse/constants.go b/middleware/sse/constants.go new file mode 100644 index 00000000000..e9c7d4e5266 --- /dev/null +++ b/middleware/sse/constants.go @@ -0,0 +1,3 @@ +package sse + +const mimeTextEventStream = "text/event-stream" diff --git a/middleware/sse/event.go b/middleware/sse/event.go index e53974254b8..cf17e4877a2 100644 --- a/middleware/sse/event.go +++ b/middleware/sse/event.go @@ -31,8 +31,8 @@ type Event struct { Retry time.Duration } -func writeEvent(w *bufio.Writer, event Event) error { - data, err := eventData(event.Data) +func writeEvent(w *bufio.Writer, event Event, jsonMarshal ...utils.JSONMarshal) error { + data, err := eventData(event.Data, jsonMarshalOrDefault(jsonMarshal)) if err != nil { return err } @@ -94,7 +94,7 @@ type eventPayload struct { hasData bool } -func eventData(data any) (eventPayload, error) { +func eventData(data any, jsonMarshal utils.JSONMarshal) (eventPayload, error) { switch value := data.(type) { case nil: return eventPayload{}, nil @@ -105,7 +105,7 @@ func eventData(data any) (eventPayload, error) { case json.RawMessage: return eventPayload{data: string(value), hasData: true}, nil default: - encoded, err := json.Marshal(value) + encoded, err := jsonMarshal(value) if err != nil { return eventPayload{}, fmt.Errorf("sse: marshal data: %w", err) } @@ -113,6 +113,13 @@ func eventData(data any) (eventPayload, error) { } } +func jsonMarshalOrDefault(jsonMarshal []utils.JSONMarshal) utils.JSONMarshal { + if len(jsonMarshal) > 0 && jsonMarshal[0] != nil { + return jsonMarshal[0] + } + return json.Marshal +} + func writeData(w *bufio.Writer, data string) error { data = trimSingleTrailingNewline(normalizeNewlines(data)) for line := range strings.SplitSeq(data, "\n") { diff --git a/middleware/sse/sse.go b/middleware/sse/sse.go index a4e7e2979df..aae56d1bedb 100644 --- a/middleware/sse/sse.go +++ b/middleware/sse/sse.go @@ -15,10 +15,9 @@ import ( "time" "github.com/gofiber/fiber/v3" + "github.com/gofiber/utils/v2" ) -const mimeTextEventStream = "text/event-stream" - var errStreamClosed = errors.New("sse: stream closed") // New creates a new middleware handler. @@ -44,7 +43,7 @@ func New(config ...Config) fiber.Handler { lastEventID := c.Get(fiber.HeaderLastEventID) return c.SendStreamWriter(func(w *bufio.Writer) { - stream := newStream(streamContext, w, lastEventID) + stream := newStream(streamContext, w, lastEventID, c.App().Config().JSONEncoder) var streamErr error defer func() { if cfg.OnClose != nil { @@ -86,13 +85,14 @@ type Stream struct { err error w *bufio.Writer done chan struct{} + jsonMarshal utils.JSONMarshal lastEventID string closed bool once sync.Once mu sync.Mutex } -func newStream(ctx context.Context, w *bufio.Writer, lastEventID string) *Stream { //nolint:contextcheck // ctx is the parent for the derived stream lifecycle context. +func newStream(ctx context.Context, w *bufio.Writer, lastEventID string, jsonMarshal ...utils.JSONMarshal) *Stream { //nolint:contextcheck // ctx is the parent for the derived stream lifecycle context. if ctx == nil { ctx = context.Background() } @@ -102,6 +102,7 @@ func newStream(ctx context.Context, w *bufio.Writer, lastEventID string) *Stream cancel: cancel, w: w, done: make(chan struct{}), + jsonMarshal: jsonMarshalOrDefault(jsonMarshal), lastEventID: lastEventID, } } @@ -131,7 +132,7 @@ func (s *Stream) Err() error { // Event writes one SSE event and flushes it to the client. func (s *Stream) Event(event Event) error { return s.write(func(w *bufio.Writer) error { - return writeEvent(w, event) + return writeEvent(w, event, s.jsonMarshal) }) } diff --git a/middleware/sse/sse_test.go b/middleware/sse/sse_test.go index fa98638af53..edd56d1feef 100644 --- a/middleware/sse/sse_test.go +++ b/middleware/sse/sse_test.go @@ -85,6 +85,30 @@ func Test_SSE_EventJSONEncodesData(t *testing.T) { require.JSONEq(t, `{"hello":"world"}`, stringsTrimData(buf.String())) } +func Test_SSE_NewUsesAppJSONEncoder(t *testing.T) { + t.Parallel() + + app := fiber.New(fiber.Config{ + JSONEncoder: func(any) ([]byte, error) { + return []byte(`{"encoded":"custom"}`), nil + }, + }) + app.Get("/events", New(Config{ + DisableHeartbeat: true, + Handler: func(_ fiber.Ctx, stream *Stream) error { + return stream.Event(Event{Data: map[string]string{"hello": "world"}}) + }, + })) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/events", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "data: {\"encoded\":\"custom\"}\n\n", string(body)) +} + func Test_SSE_EventJSONEncodesTypedNilStringer(t *testing.T) { t.Parallel() From f0c3543dbc64a3c61e0287a2544089b7dbf8026c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Wed, 29 Apr 2026 10:53:55 +0200 Subject: [PATCH 9/9] fix: refine SSE stream lifecycle handling --- docs/middleware/sse.md | 4 +++- middleware/sse/sse.go | 9 +++++++-- middleware/sse/sse_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/docs/middleware/sse.md b/docs/middleware/sse.md index a74ae225d91..d1f07075956 100644 --- a/docs/middleware/sse.md +++ b/docs/middleware/sse.md @@ -99,7 +99,7 @@ app.Get("/events", sse.New(sse.Config{ | OnClose | `func(fiber.Ctx, error)` | Called when the stream ends, with `nil` when the handler returned successfully and no stream write failed. | `nil` | | Retry | `time.Duration` | Initial EventSource reconnect delay. | `0` | | HeartbeatInterval | `time.Duration` | Interval for SSE comment heartbeats. | `15 * time.Second` | -| DisableHeartbeat | `bool` | Disable automatic heartbeat comments. | `false` | +| DisableHeartbeat | `bool` | Disable automatic heartbeat comments. When disabled, disconnected clients may not be detected until the next write. | `false` | ## Default Config @@ -128,4 +128,6 @@ func (s *Stream) LastEventID() string Every write is flushed. A failed flush closes `Done`, stores the error returned by `Err`, and lets the handler stop without relying on `fasthttp.RequestCtx.Done`, which is not a per-client disconnect signal. After a normal handler return, `Done` is closed and `Context()` is canceled while `Err()` remains `nil`; writes after that return `sse: stream closed`. +Automatic heartbeat comments keep idle streams active and make silent client disconnects observable through the next flush error. If heartbeats are disabled, a handler waiting on an external source might not notice a disconnected client until it writes again. Stopping a stream waits for an in-flight heartbeat write to finish, so a very slow client can delay shutdown until the underlying write unblocks. + `Config.Retry` sends the initial reconnect delay when the stream opens. `Event.Retry` changes the reconnect delay for a specific event, following the SSE wire format. diff --git a/middleware/sse/sse.go b/middleware/sse/sse.go index aae56d1bedb..34236d9498f 100644 --- a/middleware/sse/sse.go +++ b/middleware/sse/sse.go @@ -37,11 +37,11 @@ func New(config ...Config) fiber.Handler { c.Set(fiber.HeaderConnection, "keep-alive") c.Set("X-Accel-Buffering", "no") - c.Abandon() - streamContext := c.Context() lastEventID := c.Get(fiber.HeaderLastEventID) + c.Abandon() + return c.SendStreamWriter(func(w *bufio.Writer) { stream := newStream(streamContext, w, lastEventID, c.App().Config().JSONEncoder) var streamErr error @@ -54,6 +54,11 @@ func New(config ...Config) fiber.Handler { cfg.OnClose(c, finalErr) } }() + defer func() { + if recovered := recover(); recovered != nil { + streamErr = fmt.Errorf("sse: handler panic: %v", recovered) + } + }() defer stream.closeStream() if cfg.Retry > 0 { diff --git a/middleware/sse/sse_test.go b/middleware/sse/sse_test.go index edd56d1feef..eb2b3212214 100644 --- a/middleware/sse/sse_test.go +++ b/middleware/sse/sse_test.go @@ -500,6 +500,34 @@ func Test_SSE_HandlerErrorCallsOnClose(t *testing.T) { } } +func Test_SSE_HandlerPanicCallsOnClose(t *testing.T) { + t.Parallel() + + closed := make(chan error, 1) + + app := fiber.New() + app.Get("/events", New(Config{ + DisableHeartbeat: true, + Handler: func(fiber.Ctx, *Stream) error { + panic("boom") + }, + OnClose: func(_ fiber.Ctx, err error) { + closed <- err + }, + })) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/events", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + select { + case err := <-closed: + require.EqualError(t, err, "sse: handler panic: boom") + case <-time.After(time.Second): + t.Fatal("OnClose was not called") + } +} + func Test_SSE_NewPanicsWithoutHandler(t *testing.T) { t.Parallel()