diff --git a/ctx_interface_gen.go b/ctx_interface_gen.go index 99fc5da1f52..e735861aa98 100644 --- a/ctx_interface_gen.go +++ b/ctx_interface_gen.go @@ -256,6 +256,8 @@ type Ctx interface { // Make copies or use the Immutable setting to use the value outside the Handler. Cookies(key string, defaultValue ...string) string // FormFile returns the first file by key from a MultipartForm. + // The multipart form is parsed using the application's BodyLimit to prevent + // unbounded memory usage. FormFile(key string) (*multipart.FileHeader, error) // FormValue returns the first value by key from a MultipartForm. // Search is performed in QueryArgs, PostArgs, MultipartForm and FormFile in this particular order. @@ -263,6 +265,8 @@ type Ctx interface { // If a default value is given, it will return that value if the form value does not exist. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. + // When the request is a multipart form, it is parsed using the application's + // BodyLimit so the configured limit is consistently enforced. FormValue(key string, defaultValue ...string) string // Fresh returns true when the response is still “fresh” in the client's cache, // otherwise false is returned to indicate that the client cache is now stale @@ -398,7 +402,7 @@ type Ctx interface { // AutoFormat performs content-negotiation on the Accept HTTP header. // It uses Accepts to select a proper format. // The supported content types are text/html, text/plain, application/json, application/xml, application/vnd.msgpack, and application/cbor. - // When text/html is selected, the body is treated as plain text and HTML-escaped before being wrapped in a

element. + // When text/html is selected, the body is treated as plain text and HTML-escaped before being wrapped in a `

` element. // For more flexible content negotiation, use Format. // If the header is not specified or there is no proper format, text/plain is used. AutoFormat(body any) error @@ -431,7 +435,8 @@ type Ctx interface { Links(link ...string) // Location sets the response Location HTTP header to the specified path parameter. Location(path string) - // getLocationFromRoute get URL location from route using parameters + // getLocationFromRoute gets the URL location from a route using parameters. + // Nil receivers and missing routes return ErrNotFound to match Route.URL semantics. getLocationFromRoute(route *Route, params Map) (string, error) // GetRouteURL generates URLs to named routes, with parameters. URLs are relative, for example: "/user/1831" GetRouteURL(routeName string, params Map) (string, error) diff --git a/docs/middleware/sse.md b/docs/middleware/sse.md new file mode 100644 index 00000000000..d1f07075956 --- /dev/null +++ b/docs/middleware/sse.md @@ -0,0 +1,133 @@ +--- +id: sse +--- + +# SSE + +The SSE middleware provides the transport pieces for Server-Sent Events: response headers, event formatting, flushing, heartbeat comments, and disconnect detection through `Flush` errors. + +It intentionally does not include a hub, topics, authentication, replay storage, metrics, or external pub/sub bridges. Those are application concerns that can be composed around the stream handler. + +## Signatures + +```go +func New(config ...Config) fiber.Handler +``` + +## Examples + +Import the middleware package: + +```go +import ( + "context" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/sse" +) +``` + +Once your Fiber app is initialized, mount an SSE endpoint like this: + +```go +app.Get("/events", sse.New(sse.Config{ + Retry: 5 * time.Second, + Handler: func(c fiber.Ctx, stream *sse.Stream) error { + return stream.Event(sse.Event{ + Name: "message", + Data: fiber.Map{"message": "hello"}, + }) + }, +})) +``` + +For long-running streams, subscribe each client to its own event channel and stop when the client disconnects. +A single shared channel load-balances messages across clients; use a fan-out source when every client must receive every event: + +```go +type Broker interface { + Subscribe(ctx context.Context) (<-chan string, error) +} + +app.Get("/events", sse.New(sse.Config{ + Handler: func(c fiber.Ctx, stream *sse.Stream) error { + events, err := broker.Subscribe(stream.Context()) + if err != nil { + return err + } + + for { + select { + case msg, ok := <-events: + if !ok { + return nil + } + if err := stream.Event(sse.Event{Name: "message", Data: msg}); err != nil { + return err + } + case <-stream.Done(): + return stream.Err() + } + } + }, +})) +``` + +`stream.Context()` is canceled when the stream ends or a write fails, which makes it convenient to pass into database, broker, or gRPC calls: + +```go +app.Get("/events", sse.New(sse.Config{ + Handler: func(c fiber.Ctx, stream *sse.Stream) error { + rows, err := db.QueryContext(stream.Context(), "SELECT id FROM jobs") + if err != nil { + return err + } + defer rows.Close() + + return stream.Comment("connected") + }, +})) +``` + +## Config + +| Property | Type | Description | Default | +|:------------------|:-----------------------------|:----------------------------------------------|:--------------------| +| Next | `func(fiber.Ctx) bool` | Skip when the function returns `true`. | `nil` | +| Handler | `sse.Handler` | Writes events to the stream. | `nil` | +| OnClose | `func(fiber.Ctx, error)` | Called when the stream ends, with `nil` when the handler returned successfully and no stream write failed. | `nil` | +| Retry | `time.Duration` | Initial EventSource reconnect delay. | `0` | +| HeartbeatInterval | `time.Duration` | Interval for SSE comment heartbeats. | `15 * time.Second` | +| DisableHeartbeat | `bool` | Disable automatic heartbeat comments. When disabled, disconnected clients may not be detected until the next write. | `false` | + +## Default Config + +```go +var ConfigDefault = Config{ + Next: nil, + Handler: nil, + OnClose: nil, + Retry: 0, + HeartbeatInterval: 15 * time.Second, + DisableHeartbeat: false, +} +``` + +## Stream + +```go +func (s *Stream) Event(event Event) error +func (s *Stream) Comment(comment string) error +func (s *Stream) Retry(retry time.Duration) error +func (s *Stream) Context() context.Context +func (s *Stream) Done() <-chan struct{} +func (s *Stream) Err() error +func (s *Stream) LastEventID() string +``` + +Every write is flushed. A failed flush closes `Done`, stores the error returned by `Err`, and lets the handler stop without relying on `fasthttp.RequestCtx.Done`, which is not a per-client disconnect signal. After a normal handler return, `Done` is closed and `Context()` is canceled while `Err()` remains `nil`; writes after that return `sse: stream closed`. + +Automatic heartbeat comments keep idle streams active and make silent client disconnects observable through the next flush error. If heartbeats are disabled, a handler waiting on an external source might not notice a disconnected client until it writes again. Stopping a stream waits for an in-flight heartbeat write to finish, so a very slow client can delay shutdown until the underlying write unblocks. + +`Config.Retry` sends the initial reconnect delay when the stream opens. `Event.Retry` changes the reconnect delay for a specific event, following the SSE wire format. diff --git a/docs/whats_new.md b/docs/whats_new.md index 3ee2d2b2981..eb50b3f82ee 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -57,6 +57,7 @@ Here's a quick overview of the changes in Fiber `v3`: - [Proxy](#proxy) - [Recover](#recover) - [Session](#session) + - [SSE](#sse) - [🔌 Addons](#-addons) - [📋 Migration guide](#-migration-guide) @@ -1680,6 +1681,13 @@ The session middleware has undergone significant improvements in v3, focusing on For more details on these changes and migration instructions, check the [Session Middleware Migration Guide](./middleware/session.md#migration-guide). +### SSE + +Fiber now includes an [SSE middleware](./middleware/sse.md) for Server-Sent Events. It handles native +`SendStreamWriter` setup, SSE response headers, event formatting, flushing, heartbeat comments, and +disconnect detection through flush errors while leaving application-level hubs, topics, replay stores, and +pub/sub bridges to user code or recipes. + ### Timeout The timeout middleware is now configurable. A new `Config` struct allows customizing the timeout duration, defining a handler that runs when a timeout occurs, and specifying errors to treat as timeouts. The `New` function now accepts a `Config` value instead of a duration. diff --git a/middleware/limiter/limiter_test.go b/middleware/limiter/limiter_test.go index a6ec1af7efe..cbb746875a3 100644 --- a/middleware/limiter/limiter_test.go +++ b/middleware/limiter/limiter_test.go @@ -1080,7 +1080,7 @@ func Test_Limiter_Sliding_Window_RecalculatesAfterHandlerDelay(t *testing.T) { require.Equal(t, fiber.StatusOK, resp.StatusCode) } - time.Sleep(time.Second + 100*time.Millisecond) + time.Sleep(2*time.Second + 100*time.Millisecond) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) diff --git a/middleware/sse/config.go b/middleware/sse/config.go new file mode 100644 index 00000000000..8c1ed2511ef --- /dev/null +++ b/middleware/sse/config.go @@ -0,0 +1,70 @@ +package sse + +import ( + "time" + + "github.com/gofiber/fiber/v3" +) + +// Handler writes events to a single SSE stream. +type Handler func(c fiber.Ctx, stream *Stream) error + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c fiber.Ctx) bool + + // Handler writes events to the stream. + // + // Required. + Handler Handler + + // OnClose is called after the stream handler returns or the client disconnects. + // + // Optional. Default: nil + OnClose func(c fiber.Ctx, err error) + + // Retry controls the reconnection delay sent to clients. + // Values less than or equal to zero disable the initial retry field. + // + // Optional. Default: 0 + Retry time.Duration + + // HeartbeatInterval controls comment heartbeats used to keep intermediaries + // from closing idle streams and to detect disconnected clients. + // When DisableHeartbeat is false, values less than or equal to zero are + // replaced by the default interval. + // + // Optional. Default: 15 * time.Second + HeartbeatInterval time.Duration + + // DisableHeartbeat disables automatic comment heartbeats. + // + // Optional. Default: false + DisableHeartbeat bool +} + +// ConfigDefault is the default config. +var ConfigDefault = Config{ + Next: nil, + Handler: nil, + OnClose: nil, + Retry: 0, + HeartbeatInterval: 15 * time.Second, + DisableHeartbeat: false, +} + +// Helper function to set default values. +func configDefault(config ...Config) Config { + if len(config) < 1 { + return ConfigDefault + } + + cfg := config[0] + if !cfg.DisableHeartbeat && cfg.HeartbeatInterval <= 0 { + cfg.HeartbeatInterval = ConfigDefault.HeartbeatInterval + } + return cfg +} diff --git a/middleware/sse/constants.go b/middleware/sse/constants.go new file mode 100644 index 00000000000..e9c7d4e5266 --- /dev/null +++ b/middleware/sse/constants.go @@ -0,0 +1,3 @@ +package sse + +const mimeTextEventStream = "text/event-stream" diff --git a/middleware/sse/event.go b/middleware/sse/event.go new file mode 100644 index 00000000000..cf17e4877a2 --- /dev/null +++ b/middleware/sse/event.go @@ -0,0 +1,170 @@ +package sse + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/gofiber/utils/v2" +) + +var errInvalidField = errors.New("field must not contain CR or LF") + +// Event defines a single Server-Sent Event frame. +type Event struct { + // Data is written as one or more data fields. Strings and byte slices are + // written as-is; other values are JSON encoded. + Data any + + // ID sets the SSE id field. + ID string + + // Name sets the SSE event field. + Name string + + // Retry sets the SSE retry field for this event. + Retry time.Duration +} + +func writeEvent(w *bufio.Writer, event Event, jsonMarshal ...utils.JSONMarshal) error { + data, err := eventData(event.Data, jsonMarshalOrDefault(jsonMarshal)) + if err != nil { + return err + } + + var frame bytes.Buffer + + if event.ID != "" { + id, err := sanitizeField(event.ID) + if err != nil { + return fmt.Errorf("sse: invalid id: %w", err) + } + if id != "" { + appendField(&frame, "id", id) + } + } + if event.Name != "" { + name, err := sanitizeField(event.Name) + if err != nil { + return fmt.Errorf("sse: invalid event: %w", err) + } + if name != "" { + appendField(&frame, "event", name) + } + } + if event.Retry > 0 { + appendField(&frame, "retry", strconv.FormatInt(event.Retry.Milliseconds(), 10)) + } + if data.hasData { + appendData(&frame, data.data) + } + frame.WriteByte('\n') //nolint:errcheck // bytes.Buffer writes never fail. + if _, err := w.Write(frame.Bytes()); err != nil { + return fmt.Errorf("sse: write event: %w", err) + } + return nil +} + +func writeComment(w *bufio.Writer, comment string) error { + comment = sanitizeComment(comment) + if comment == "" { + if _, err := w.WriteString(":\n\n"); err != nil { + return fmt.Errorf("sse: write heartbeat: %w", err) + } + return nil + } + for line := range strings.SplitSeq(comment, "\n") { + if _, err := fmt.Fprintf(w, ": %s\n", line); err != nil { + return fmt.Errorf("sse: write comment: %w", err) + } + } + if _, err := w.WriteString("\n"); err != nil { + return fmt.Errorf("sse: finish comment: %w", err) + } + return nil +} + +type eventPayload struct { + data string + hasData bool +} + +func eventData(data any, jsonMarshal utils.JSONMarshal) (eventPayload, error) { + switch value := data.(type) { + case nil: + return eventPayload{}, nil + case string: + return eventPayload{data: value, hasData: true}, nil + case []byte: + return eventPayload{data: string(value), hasData: true}, nil + case json.RawMessage: + return eventPayload{data: string(value), hasData: true}, nil + default: + encoded, err := jsonMarshal(value) + if err != nil { + return eventPayload{}, fmt.Errorf("sse: marshal data: %w", err) + } + return eventPayload{data: string(encoded), hasData: true}, nil + } +} + +func jsonMarshalOrDefault(jsonMarshal []utils.JSONMarshal) utils.JSONMarshal { + if len(jsonMarshal) > 0 && jsonMarshal[0] != nil { + return jsonMarshal[0] + } + return json.Marshal +} + +func writeData(w *bufio.Writer, data string) error { + data = trimSingleTrailingNewline(normalizeNewlines(data)) + for line := range strings.SplitSeq(data, "\n") { + if _, err := fmt.Fprintf(w, "data: %s\n", line); err != nil { + return fmt.Errorf("sse: write data: %w", err) + } + } + return nil +} + +func appendField(w *bytes.Buffer, field, value string) { + w.WriteString(field) //nolint:errcheck // bytes.Buffer writes never fail. + w.WriteString(": ") //nolint:errcheck // bytes.Buffer writes never fail. + w.WriteString(value) //nolint:errcheck // bytes.Buffer writes never fail. + w.WriteByte('\n') //nolint:errcheck // bytes.Buffer writes never fail. +} + +func appendData(w *bytes.Buffer, data string) { + data = trimSingleTrailingNewline(normalizeNewlines(data)) + for line := range strings.SplitSeq(data, "\n") { + appendField(w, "data", line) + } +} + +func trimSingleTrailingNewline(value string) string { + return strings.TrimSuffix(value, "\n") +} + +func sanitizeField(value string) (string, error) { + if strings.ContainsAny(value, "\r\n") { + return "", errInvalidField + } + return utils.Trim(value, ' '), nil +} + +func sanitizeComment(value string) string { + value = normalizeNewlines(value) + lines := make([]string, 0, strings.Count(value, "\n")+1) + for line := range strings.SplitSeq(value, "\n") { + lines = append(lines, utils.Trim(line, ' ')) + } + return strings.Join(lines, "\n") +} + +func normalizeNewlines(value string) string { + value = strings.ReplaceAll(value, "\r\n", "\n") + return strings.ReplaceAll(value, "\r", "\n") +} diff --git a/middleware/sse/sse.go b/middleware/sse/sse.go new file mode 100644 index 00000000000..34236d9498f --- /dev/null +++ b/middleware/sse/sse.go @@ -0,0 +1,236 @@ +// Package sse provides small Server-Sent Events middleware for Fiber. +// +// The package focuses on the SSE transport: response headers, wire formatting, +// flushing, heartbeat comments, and disconnect detection via flush errors. +// Application-specific concerns such as topics, replay storage, authentication, +// and pub/sub fan-out intentionally stay outside the core middleware. +package sse + +import ( + "bufio" + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/utils/v2" +) + +var errStreamClosed = errors.New("sse: stream closed") + +// New creates a new middleware handler. +func New(config ...Config) fiber.Handler { + cfg := configDefault(config...) + if cfg.Handler == nil { + panic("sse: Handler must not be nil") + } + + return func(c fiber.Ctx) error { + if cfg.Next != nil && cfg.Next(c) { + return c.Next() + } + + c.Set(fiber.HeaderContentType, mimeTextEventStream) + c.Set(fiber.HeaderCacheControl, "no-cache") + c.Set(fiber.HeaderConnection, "keep-alive") + c.Set("X-Accel-Buffering", "no") + + streamContext := c.Context() + lastEventID := c.Get(fiber.HeaderLastEventID) + + c.Abandon() + + return c.SendStreamWriter(func(w *bufio.Writer) { + stream := newStream(streamContext, w, lastEventID, c.App().Config().JSONEncoder) + var streamErr error + defer func() { + if cfg.OnClose != nil { + finalErr := streamErr + if finalErr == nil { + finalErr = stream.Err() + } + cfg.OnClose(c, finalErr) + } + }() + defer func() { + if recovered := recover(); recovered != nil { + streamErr = fmt.Errorf("sse: handler panic: %v", recovered) + } + }() + defer stream.closeStream() + + if cfg.Retry > 0 { + streamErr = stream.Retry(cfg.Retry) + if streamErr != nil { + return + } + } + + if !cfg.DisableHeartbeat { + stopHeartbeat := stream.startHeartbeat(cfg.HeartbeatInterval) + if stopHeartbeat != nil { + defer stopHeartbeat() + } + } + + streamErr = cfg.Handler(c, stream) + if streamErr == nil { + streamErr = stream.Err() + } + }) + } +} + +// Stream is an active SSE response stream. +type Stream struct { + ctx context.Context //nolint:containedctx // Stream exposes a per-stream context canceled with the stream lifecycle. + cancel context.CancelFunc + err error + w *bufio.Writer + done chan struct{} + jsonMarshal utils.JSONMarshal + lastEventID string + closed bool + once sync.Once + mu sync.Mutex +} + +func newStream(ctx context.Context, w *bufio.Writer, lastEventID string, jsonMarshal ...utils.JSONMarshal) *Stream { //nolint:contextcheck // ctx is the parent for the derived stream lifecycle context. + if ctx == nil { + ctx = context.Background() + } + ctx, cancel := context.WithCancel(ctx) + return &Stream{ + ctx: ctx, + cancel: cancel, + w: w, + done: make(chan struct{}), + jsonMarshal: jsonMarshalOrDefault(jsonMarshal), + lastEventID: lastEventID, + } +} + +// Context returns a context canceled when the stream ends or a write fails. +func (s *Stream) Context() context.Context { + return s.ctx +} + +// Done returns a channel closed when a write fails or the handler returns. +func (s *Stream) Done() <-chan struct{} { + return s.done +} + +// LastEventID returns the Last-Event-ID header value sent by the client. +func (s *Stream) LastEventID() string { + return s.lastEventID +} + +// Err returns the first stream write error. +func (s *Stream) Err() error { + s.mu.Lock() + defer s.mu.Unlock() + return s.err +} + +// Event writes one SSE event and flushes it to the client. +func (s *Stream) Event(event Event) error { + return s.write(func(w *bufio.Writer) error { + return writeEvent(w, event, s.jsonMarshal) + }) +} + +// Comment writes one SSE comment and flushes it to the client. +func (s *Stream) Comment(comment string) error { + return s.write(func(w *bufio.Writer) error { + return writeComment(w, comment) + }) +} + +// Retry writes an SSE retry field and flushes it to the client. +func (s *Stream) Retry(retry time.Duration) error { + if retry <= 0 { + return nil + } + return s.write(func(w *bufio.Writer) error { + _, err := fmt.Fprintf(w, "retry: %d\n\n", retry.Milliseconds()) + if err != nil { + return fmt.Errorf("sse: write retry: %w", err) + } + return nil + }) +} + +func (s *Stream) write(fn func(w *bufio.Writer) error) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.err != nil { + return s.err + } + if s.closed { + return errStreamClosed + } + if err := fn(s.w); err != nil { + return s.failLocked(err) + } + if err := s.w.Flush(); err != nil { + return s.failLocked(err) + } + return nil +} + +func (s *Stream) failLocked(err error) error { + s.err = err + s.closed = true + s.once.Do(func() { + s.cancel() + close(s.done) + }) + return err +} + +func (s *Stream) closeStream() { + s.mu.Lock() + s.closed = true + s.mu.Unlock() + s.once.Do(func() { + s.cancel() + close(s.done) + }) +} + +func (s *Stream) startHeartbeat(interval time.Duration) func() { + if interval <= 0 { + return nil + } + + stop := make(chan struct{}) + var stopOnce sync.Once + done := make(chan struct{}) + go func() { + defer close(done) + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := s.Comment(""); err != nil { + return + } + case <-stop: + return + case <-s.Done(): + return + } + } + }() + + return func() { + stopOnce.Do(func() { + close(stop) + }) + <-done + } +} diff --git a/middleware/sse/sse_test.go b/middleware/sse/sse_test.go new file mode 100644 index 00000000000..eb2b3212214 --- /dev/null +++ b/middleware/sse/sse_test.go @@ -0,0 +1,575 @@ +package sse + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" +) + +type panicStringer struct{} + +func (*panicStringer) String() string { + panic("String must not be called for SSE data") +} + +func Test_SSE_EventWritesFrame(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{ + ID: " 42 ", + Name: "update", + Data: "one\r\ntwo", + Retry: 2500 * time.Millisecond, + })) + require.NoError(t, w.Flush()) + + require.Equal(t, "id: 42\nevent: update\nretry: 2500\ndata: one\ndata: two\n\n", buf.String()) +} + +func Test_SSE_EventRejectsFieldInjection(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.ErrorIs(t, writeEvent(w, Event{ + ID: "42\nretry: 1", + Data: "ignored", + }), errInvalidField) +} + +func Test_SSE_EventOmitsWhitespaceOnlyFields(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{ + ID: " ", + Name: " ", + Data: "ok", + })) + require.NoError(t, w.Flush()) + + require.Equal(t, "data: ok\n\n", buf.String()) +} + +func Test_SSE_EventJSONEncodesData(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{ + Name: "message", + Data: map[string]string{"hello": "world"}, + })) + require.NoError(t, w.Flush()) + + require.JSONEq(t, `{"hello":"world"}`, stringsTrimData(buf.String())) +} + +func Test_SSE_NewUsesAppJSONEncoder(t *testing.T) { + t.Parallel() + + app := fiber.New(fiber.Config{ + JSONEncoder: func(any) ([]byte, error) { + return []byte(`{"encoded":"custom"}`), nil + }, + }) + app.Get("/events", New(Config{ + DisableHeartbeat: true, + Handler: func(_ fiber.Ctx, stream *Stream) error { + return stream.Event(Event{Data: map[string]string{"hello": "world"}}) + }, + })) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/events", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "data: {\"encoded\":\"custom\"}\n\n", string(body)) +} + +func Test_SSE_EventJSONEncodesTypedNilStringer(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + var value *panicStringer + + require.NoError(t, writeEvent(w, Event{Data: value})) + require.NoError(t, w.Flush()) + + require.Equal(t, "data: null\n\n", buf.String()) +} + +func Test_SSE_EventOmitsDataForUntypedNil(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{ID: "42"})) + require.NoError(t, w.Flush()) + + require.Equal(t, "id: 42\n\n", buf.String()) +} + +func Test_SSE_EventDoesNotWritePartialFrameWhenDataMarshalFails(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.Error(t, writeEvent(w, Event{ + ID: "42", + Name: "broken", + Data: func() {}, + })) + require.NoError(t, w.Flush()) + + require.Empty(t, buf.String()) +} + +func Test_SSE_EventWritesRawJSONData(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{Data: json.RawMessage(`{"hello":"world"}`)})) + require.NoError(t, w.Flush()) + + require.Equal(t, "data: {\"hello\":\"world\"}\n\n", buf.String()) +} + +func Test_SSE_EventWritesByteSliceData(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{Data: []byte("hello\nworld")})) + require.NoError(t, w.Flush()) + + require.Equal(t, "data: hello\ndata: world\n\n", buf.String()) +} + +func Test_SSE_EventTrimsSingleTrailingDataNewline(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{Data: "hello\n"})) + require.NoError(t, w.Flush()) + + require.Equal(t, "data: hello\n\n", buf.String()) +} + +func Test_SSE_EventPreservesIntentionalBlankDataLine(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeEvent(w, Event{Data: "hello\n\n"})) + require.NoError(t, w.Flush()) + + require.Equal(t, "data: hello\ndata: \n\n", buf.String()) +} + +func Test_SSE_EventReturnsWriterError(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failed") + w := bufio.NewWriterSize(errWriter{err: writeErr}, 1) + + require.ErrorIs(t, writeEvent(w, Event{Data: "hello"}), writeErr) +} + +func Test_SSE_CommentSanitizesLines(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + require.NoError(t, writeComment(w, " first\r\nsecond ")) + require.NoError(t, w.Flush()) + + require.Equal(t, ": first\n: second\n\n", buf.String()) +} + +func Test_SSE_CommentReturnsWriterError(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failed") + w := bufio.NewWriterSize(errWriter{err: writeErr}, 1) + + require.ErrorIs(t, writeComment(w, ""), writeErr) +} + +func Test_SSE_WriteDataReturnsWriterError(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failed") + w := bufio.NewWriterSize(errWriter{err: writeErr}, 1) + + require.ErrorIs(t, writeData(w, "hello"), writeErr) +} + +func Test_SSE_NewWritesHeadersAndEvents(t *testing.T) { + t.Parallel() + + app := fiber.New() + var capturedLastEventID string + app.Get("/events", New(Config{ + Retry: time.Second, + DisableHeartbeat: true, + Handler: func(_ fiber.Ctx, stream *Stream) error { + capturedLastEventID = stream.LastEventID() + return stream.Event(Event{Name: "ready", Data: "ok"}) + }, + })) + + req := httptest.NewRequest(fiber.MethodGet, "/events", http.NoBody) + req.Header.Set("Last-Event-ID", "last-1") + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, "last-1", capturedLastEventID) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + require.Equal(t, mimeTextEventStream, resp.Header.Get(fiber.HeaderContentType)) + require.Equal(t, "no-cache", resp.Header.Get(fiber.HeaderCacheControl)) + require.Equal(t, "keep-alive", resp.Header.Get(fiber.HeaderConnection)) + require.Equal(t, "no", resp.Header.Get("X-Accel-Buffering")) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "retry: 1000\n\nevent: ready\ndata: ok\n\n", string(body)) +} + +func Test_SSE_NewWritesHeartbeat(t *testing.T) { + t.Parallel() + + app := fiber.New() + app.Get("/events", New(Config{ + HeartbeatInterval: 5 * time.Millisecond, + Handler: func(_ fiber.Ctx, stream *Stream) error { + select { + case <-time.After(150 * time.Millisecond): + return nil + case <-stream.Done(): + return stream.Err() + } + }, + })) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/events", http.NoBody)) + require.NoError(t, err) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Contains(t, string(body), ":\n\n") +} + +func Test_SSE_OnCloseReceivesNilAfterNormalClose(t *testing.T) { + t.Parallel() + + closed := make(chan error, 1) + + app := fiber.New() + app.Get("/events", New(Config{ + DisableHeartbeat: true, + Handler: func(_ fiber.Ctx, stream *Stream) error { + return stream.Event(Event{Data: "ok"}) + }, + OnClose: func(_ fiber.Ctx, err error) { + closed <- err + }, + })) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/events", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + select { + case err := <-closed: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("OnClose was not called") + } +} + +func Test_SSE_StreamComment(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(context.Background(), bufio.NewWriter(&buf), "") + defer stream.closeStream() + + require.NoError(t, stream.Comment("hello")) + require.Equal(t, ": hello\n\n", buf.String()) +} + +func Test_SSE_StreamCommentReturnsWriterError(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failed") + stream := newStream(context.Background(), bufio.NewWriterSize(errWriter{err: writeErr}, 1), "") + + require.ErrorIs(t, stream.Comment("hello"), writeErr) + require.ErrorIs(t, stream.Err(), writeErr) +} + +func Test_SSE_StreamRetryIgnoresNonPositiveDuration(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(context.Background(), bufio.NewWriter(&buf), "") + defer stream.closeStream() + + require.NoError(t, stream.Retry(0)) + require.NoError(t, stream.Retry(-time.Second)) + require.Empty(t, buf.String()) +} + +func Test_SSE_StreamRetryReturnsWriterError(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failed") + stream := newStream(context.Background(), bufio.NewWriterSize(errWriter{err: writeErr}, 1), "") + + require.ErrorIs(t, stream.Retry(time.Second), writeErr) + require.ErrorIs(t, stream.Err(), writeErr) +} + +func Test_SSE_StreamConcurrentWrites(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(context.Background(), bufio.NewWriter(&buf), "") + defer stream.closeStream() + + const writers = 16 + errs := make(chan error, writers) + var wg sync.WaitGroup + wg.Add(writers) + for i := 0; i < writers; i++ { + go func(data int) { + defer wg.Done() + errs <- stream.Event(Event{Name: "message", Data: data}) + }(i) + } + wg.Wait() + close(errs) + + for err := range errs { + require.NoError(t, err) + } + + require.Equal(t, writers, strings.Count(buf.String(), "event: message\n")) + require.Equal(t, writers, strings.Count(buf.String(), "data: ")) +} + +func Test_SSE_StreamContextCanceledOnClose(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(context.Background(), bufio.NewWriter(&buf), "") + stream.closeStream() + + select { + case <-stream.Context().Done(): + case <-time.After(time.Second): + t.Fatal("stream context was not canceled") + } +} + +func Test_SSE_NewStreamUsesBackgroundContextWhenNil(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(nil, bufio.NewWriter(&buf), "") //nolint:staticcheck // Covers the nil fallback branch in newStream. + defer stream.closeStream() + + require.NotNil(t, stream.Context()) + require.NoError(t, stream.Context().Err()) +} + +func Test_SSE_StreamErrAfterNormalClose(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(context.Background(), bufio.NewWriter(&buf), "") + stream.closeStream() + + require.NoError(t, stream.Err()) + require.ErrorIs(t, stream.Event(Event{Data: "late"}), errStreamClosed) +} + +func Test_SSE_StreamReturnsLatchedError(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failed") + stream := newStream(context.Background(), bufio.NewWriter(errWriter{err: writeErr}), "") + + require.ErrorIs(t, stream.Event(Event{Data: "hello"}), writeErr) + require.ErrorIs(t, stream.Event(Event{Data: "again"}), writeErr) +} + +func Test_SSE_StreamWriteError(t *testing.T) { + t.Parallel() + + writeErr := errors.New("write failed") + stream := newStream(context.Background(), bufio.NewWriter(errWriter{err: writeErr}), "") + + require.ErrorIs(t, stream.Event(Event{Data: "hello"}), writeErr) + require.ErrorIs(t, stream.Err(), writeErr) + select { + case <-stream.Done(): + case <-time.After(time.Second): + t.Fatal("stream was not closed after write error") + } +} + +func Test_SSE_NextSkipsMiddleware(t *testing.T) { + t.Parallel() + + app := fiber.New() + app.Use(New(Config{ + Next: func(fiber.Ctx) bool { + return true + }, + Handler: func(_ fiber.Ctx, stream *Stream) error { + return stream.Event(Event{Data: "ignored"}) + }, + })) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("next") + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "next", string(body)) +} + +func Test_SSE_HandlerErrorCallsOnClose(t *testing.T) { + t.Parallel() + + handlerErr := errors.New("boom") + closed := make(chan error, 1) + + app := fiber.New() + app.Get("/events", New(Config{ + DisableHeartbeat: true, + Handler: func(fiber.Ctx, *Stream) error { + return handlerErr + }, + OnClose: func(_ fiber.Ctx, err error) { + closed <- err + }, + })) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/events", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + select { + case err := <-closed: + require.ErrorIs(t, err, handlerErr) + case <-time.After(time.Second): + t.Fatal("OnClose was not called") + } +} + +func Test_SSE_HandlerPanicCallsOnClose(t *testing.T) { + t.Parallel() + + closed := make(chan error, 1) + + app := fiber.New() + app.Get("/events", New(Config{ + DisableHeartbeat: true, + Handler: func(fiber.Ctx, *Stream) error { + panic("boom") + }, + OnClose: func(_ fiber.Ctx, err error) { + closed <- err + }, + })) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/events", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + select { + case err := <-closed: + require.EqualError(t, err, "sse: handler panic: boom") + case <-time.After(time.Second): + t.Fatal("OnClose was not called") + } +} + +func Test_SSE_NewPanicsWithoutHandler(t *testing.T) { + t.Parallel() + + require.PanicsWithValue(t, "sse: Handler must not be nil", func() { + New() + }) +} + +func Test_SSE_StopHeartbeatIsIdempotent(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(context.Background(), bufio.NewWriter(&buf), "") + defer stream.closeStream() + + stop := stream.startHeartbeat(time.Hour) + require.NotNil(t, stop) + require.NotPanics(t, stop) + require.NotPanics(t, stop) +} + +func Test_SSE_StartHeartbeatReturnsNilForNonPositiveInterval(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + stream := newStream(context.Background(), bufio.NewWriter(&buf), "") + defer stream.closeStream() + + require.Nil(t, stream.startHeartbeat(0)) + require.Nil(t, stream.startHeartbeat(-time.Second)) +} + +func stringsTrimData(frame string) string { + const prefix = "event: message\ndata: " + frame = strings.TrimPrefix(frame, prefix) + return strings.TrimSuffix(frame, "\n\n") +} + +type errWriter struct { + err error +} + +func (w errWriter) Write([]byte) (int, error) { + return 0, fmt.Errorf("test writer: %w", w.err) +} diff --git a/req_interface_gen.go b/req_interface_gen.go index d150c0ec9e7..59841e7a689 100644 --- a/req_interface_gen.go +++ b/req_interface_gen.go @@ -52,6 +52,8 @@ type Req interface { // https://godoc.org/github.com/valyala/fasthttp#Request Request() *fasthttp.Request // FormFile returns the first file by key from a MultipartForm. + // The multipart form is parsed using the application's BodyLimit to prevent + // unbounded memory usage. FormFile(key string) (*multipart.FileHeader, error) // FormValue returns the first value by key from a MultipartForm. // Search is performed in QueryArgs, PostArgs, MultipartForm and FormFile in this particular order. @@ -59,6 +61,8 @@ type Req interface { // If a default value is given, it will return that value if the form value does not exist. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. + // When the request is a multipart form, it is parsed using the application's + // BodyLimit so the configured limit is consistently enforced. FormValue(key string, defaultValue ...string) string // Fresh returns true when the response is still “fresh” in the client's cache, // otherwise false is returned to indicate that the client cache is now stale diff --git a/res_interface_gen.go b/res_interface_gen.go index ebe1caefe2f..0d4aa12e889 100644 --- a/res_interface_gen.go +++ b/res_interface_gen.go @@ -45,7 +45,7 @@ type Res interface { // AutoFormat performs content-negotiation on the Accept HTTP header. // It uses Accepts to select a proper format. // The supported content types are text/html, text/plain, application/json, application/xml, application/vnd.msgpack, and application/cbor. - // When text/html is selected, the body is treated as plain text and HTML-escaped before being wrapped in a

element. + // When text/html is selected, the body is treated as plain text and HTML-escaped before being wrapped in a `

` element. // For more flexible content negotiation, use Format. // If the header is not specified or there is no proper format, text/plain is used. AutoFormat(body any) error @@ -99,7 +99,8 @@ type Res interface { // ViewBind Add vars to default view var map binding to template engine. // Variables are read by the Render method and may be overwritten. ViewBind(vars Map) error - // getLocationFromRoute get URL location from route using parameters + // getLocationFromRoute gets the URL location from a route using parameters. + // Nil receivers and missing routes return ErrNotFound to match Route.URL semantics. getLocationFromRoute(route *Route, params Map) (string, error) // GetRouteURL generates URLs to named routes, with parameters. URLs are relative, for example: "/user/1831" GetRouteURL(routeName string, params Map) (string, error)