diff --git a/docs/middleware/sse.md b/docs/middleware/sse.md new file mode 100644 index 00000000000..38dff0a0bd4 --- /dev/null +++ b/docs/middleware/sse.md @@ -0,0 +1,165 @@ +--- +id: sse +--- + +# SSE + +Server-Sent Events middleware for [Fiber](https://github.com/gofiber/fiber) built natively on Fiber's fasthttp architecture. It provides a Hub-based event broker with topic routing, three priority lanes (instant/batched/coalesced), NATS-style topic wildcards, adaptive per-connection throttling, connection groups, graceful drain, and pluggable Last-Event-ID replay. + +The middleware is fully compatible with the standard SSE wire format — any client that speaks Server-Sent Events (browser `EventSource`, `curl -N`, or any HTTP client that reads `text/event-stream`) works with it. + +## Signatures + +```go +func New(config ...Config) fiber.Handler +func NewWithHub(config ...Config) (fiber.Handler, *Hub) +``` + +`New` returns just the handler; use `NewWithHub` when you need access to the hub for publishing events. + +## Examples + +Import the middleware package: + +```go +import ( + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/sse" +) +``` + +Once your Fiber app is initialized, create an SSE handler and hub: + +```go +// Basic usage — subscribe all clients to "notifications" +handler, hub := sse.NewWithHub(sse.Config{ + OnConnect: func(c fiber.Ctx, conn *sse.Connection) error { + conn.Topics = []string{"notifications"} + return nil + }, +}) +app.Get("/events", handler) + +// Publish an event from any goroutine +hub.Publish(sse.Event{ + Type: "update", + Data: "hello", + Topics: []string{"notifications"}, +}) +``` + +Use NATS-style wildcards to subscribe to multiple related topics: + +```go +handler, hub := sse.NewWithHub(sse.Config{ + OnConnect: func(c fiber.Ctx, conn *sse.Connection) error { + // Match orders.created, orders.updated, orders.deleted + conn.Topics = []string{"orders.*"} + return nil + }, +}) +``` + +Use connection groups (metadata-based filtering) for multi-tenant isolation: + +```go +handler, hub := sse.NewWithHub(sse.Config{ + OnConnect: func(c fiber.Ctx, conn *sse.Connection) error { + tenantID := c.Locals("tenant_id").(string) + conn.Metadata["tenant_id"] = tenantID + conn.Topics = []string{"orders"} + return nil + }, +}) + +// Publish only to connections in tenant "t_123" +hub.Publish(sse.Event{ + Type: "order-created", + Data: orderJSON, + Topics: []string{"orders"}, + Group: map[string]string{"tenant_id": "t_123"}, +}) +``` + +Use event coalescing to reduce traffic for high-frequency updates: + +```go +// Coalesced: if progress goes 5%→8% in one flush window, +// only the latest value is sent. +for i := 1; i <= 100; i++ { + hub.Publish(sse.Event{ + Type: "progress", + Data: fmt.Sprintf(`{"pct":%d}`, i), + Topics: []string{"import"}, + Priority: sse.PriorityCoalesced, + CoalesceKey: "import-progress", + }) +} +``` + +Fan out from an external pub/sub system (Redis, NATS, etc.) into the hub. Implement the `SubscriberBridge` interface and declare it on `Config.Bridges` — the middleware auto-starts each bridge and cancels/awaits them on `hub.Shutdown`, so there are no `CancelFunc`s for the caller to track. 
+ +```go +type redisSubscriber struct{ client *redis.Client } + +func (r *redisSubscriber) Subscribe(ctx context.Context, channel string, onMessage func(string)) error { + sub := r.client.Subscribe(ctx, channel) + defer sub.Close() + for msg := range sub.Channel() { + onMessage(msg.Payload) + } + return ctx.Err() +} + +handler, hub := sse.NewWithHub(sse.Config{ + Bridges: []sse.BridgeConfig{{ + Subscriber: &redisSubscriber{client: rdb}, + Channel: "notifications", + Topic: "notifications", + EventType: "notification", + }}, +}) +app.Get("/events", handler) +``` + +Graceful shutdown with deadline: + +```go +ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) +defer cancel() +if err := hub.Shutdown(ctx); err != nil { + log.Errorf("sse drain failed: %v", err) +} +``` + +Authentication is left to the user via `OnConnect`. Note that browser `EventSource` cannot send custom headers, so if you need token authentication, consider passing the token via a query parameter or a short-lived ticket exchanged on a separate endpoint. + +## Config + +| Property | Type | Description | Default | +| :---------------- | :------------------------------------------------ | :------------------------------------------------------------------------------------------------------------------- | :------------- | +| OnConnect | `func(fiber.Ctx, *Connection) error` | Called when a new client connects. Set `conn.Topics` and `conn.Metadata` here. Return error to reject (sends 403). | `nil` | +| OnDisconnect | `func(*Connection)` | Called after a client disconnects. | `nil` | +| OnPause | `func(*Connection)` | Called when a connection is paused (browser tab hidden). | `nil` | +| OnResume | `func(*Connection)` | Called when a connection is resumed (browser tab visible). | `nil` | +| Replayer | `Replayer` | Pluggable Last-Event-ID replay backend. If nil, replay is disabled. | `nil` | +| Bridges | `[]BridgeConfig` | Auto-started bridges from external pub/sub systems. Each implements `SubscriberBridge`. Canceled on `hub.Shutdown`. | `nil` | +| FlushInterval | `time.Duration` | How often batched (P1) and coalesced (P2) events are flushed to clients. Instant (P0) events bypass this. | `2s` | +| HeartbeatInterval | `time.Duration` | How often a comment is sent to idle connections to detect disconnects and prevent proxy timeouts. | `30s` | +| MaxLifetime | `time.Duration` | Maximum duration a single SSE connection can stay open. Set to -1 for unlimited. | `30m` | +| SendBufferSize | `int` | Per-connection channel buffer. If full, events are dropped. | `256` | +| RetryMS | `int` | Reconnection interval hint sent to clients via the `retry:` directive on connect. | `3000` | + +The SSE middleware is **terminal** — the returned handler hijacks the response stream and never calls `c.Next()`. For the same reason `Config` does not include a `Next` field: placing handlers after the SSE middleware has no defined effect. 
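+A custom `Replayer` (see the Config table above) can be backed by any store. The hub calls `Store` for every published event that has no `Group` (group-scoped events are never replayed), and `Replay` when a client reconnects with a `Last-Event-ID`. The sketch below is a minimal in-memory example, not part of the package: `memoryReplayer` and `stored` are illustrative names, the slice grows without bound, and a real backend would cap or expire entries (assumes `sync` and `slices` imports).
+
+```go
+type memoryReplayer struct {
+	mu     sync.Mutex
+	events []stored
+}
+
+type stored struct {
+	event  sse.MarshaledEvent
+	topics []string
+}
+
+func (r *memoryReplayer) Store(event sse.MarshaledEvent, topics []string) error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	r.events = append(r.events, stored{event: event, topics: topics})
+	return nil
+}
+
+func (r *memoryReplayer) Replay(lastEventID string, topics []string) ([]sse.MarshaledEvent, error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	// Find the last event the client saw; an unknown ID means a fresh connection.
+	start := -1
+	for i, s := range r.events {
+		if s.event.ID == lastEventID {
+			start = i + 1
+			break
+		}
+	}
+	if start < 0 {
+		return nil, nil
+	}
+
+	// Return everything after it that matches one of the client's topics.
+	var out []sse.MarshaledEvent
+	for _, s := range r.events[start:] {
+		for _, t := range s.topics {
+			if slices.Contains(topics, t) {
+				out = append(out, s.event)
+				break
+			}
+		}
+	}
+	return out, nil
+}
+
+handler, hub := sse.NewWithHub(sse.Config{
+	Replayer: &memoryReplayer{},
+	OnConnect: func(c fiber.Ctx, conn *sse.Connection) error {
+		conn.Topics = []string{"notifications"}
+		return nil
+	},
+})
+```
+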
+ +## Default Config + +```go +var ConfigDefault = Config{ + FlushInterval: 2 * time.Second, + SendBufferSize: 256, + HeartbeatInterval: 30 * time.Second, + MaxLifetime: 30 * time.Minute, + RetryMS: 3000, +} +``` diff --git a/docs/whats_new.md b/docs/whats_new.md index 3ee2d2b2981..2272510726e 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -57,6 +57,7 @@ Here's a quick overview of the changes in Fiber `v3`: - [Proxy](#proxy) - [Recover](#recover) - [Session](#session) + - [SSE](#sse) - [🔌 Addons](#-addons) - [📋 Migration guide](#-migration-guide) @@ -3155,3 +3156,24 @@ app.Use(session.New(session.Config{ See the [Session Middleware Migration Guide](./middleware/session.md#migration-guide) for complete details. + +#### SSE + +The new SSE middleware provides Server-Sent Events for Fiber, built natively on the fasthttp `SendStreamWriter` API. It includes a Hub-based broker with topic routing, three priority lanes (instant/batched/coalesced), NATS-style topic wildcards, connection groups for metadata-based filtering, adaptive throttling, graceful drain, and pluggable Last-Event-ID replay. Fully compatible with the standard SSE wire format and any `EventSource`-style client. + +```go +handler, hub := sse.NewWithHub(sse.Config{ + OnConnect: func(c fiber.Ctx, conn *sse.Connection) error { + conn.Topics = []string{"notifications"} + return nil + }, +}) +app.Get("/events", handler) + +// Publish from any handler or worker +hub.Publish(sse.Event{ + Type: "update", + Data: "hello", + Topics: []string{"notifications"}, +}) +``` diff --git a/middleware/sse/bridge.go b/middleware/sse/bridge.go new file mode 100644 index 00000000000..39a6097d541 --- /dev/null +++ b/middleware/sse/bridge.go @@ -0,0 +1,143 @@ +package sse + +import ( + "context" + "time" + + "github.com/gofiber/fiber/v3/log" +) + +// bridgeRetryDelay is how long the hub waits before retrying a failed +// SubscriberBridge.Subscribe call. Package-level var (not const) so tests +// can shorten it to observe retry behavior deterministically. +var bridgeRetryDelay = 3 * time.Second + +// SubscriberBridge adapts an external pub/sub system (Redis, NATS, Kafka, +// etc.) so incoming messages can be forwarded into the hub as SSE events. +// +// Implementations must block until ctx is canceled and return ctx.Err() +// so the hub can distinguish intentional shutdown from subscriber failure. +type SubscriberBridge interface { + // Subscribe listens on channel and invokes onMessage for each received + // payload. It must return when ctx is canceled. + Subscribe(ctx context.Context, channel string, onMessage func(payload string)) error +} + +// BridgeConfig wires a SubscriberBridge into the hub. Populate one of these +// for each external channel you want to forward events from. +type BridgeConfig struct { + // Subscriber is the pub/sub implementation. Required. + Subscriber SubscriberBridge + + // Transform optionally transforms the raw payload into a fully-formed + // Event. Return nil to skip the message. If Transform is nil, the + // payload is used as Event.Data with the defaults below. + Transform func(payload string) *Event + + // Channel is the pub/sub channel to subscribe to. Required. + Channel string + + // Topic is the SSE topic forwarded events are tagged with. + // Defaults to Channel if empty. + Topic string + + // EventType is the SSE event: field set on forwarded events. + EventType string + + // CoalesceKey for PriorityCoalesced events. + CoalesceKey string + + // TTL for forwarded events. Zero means no expiration. 
+ TTL time.Duration + + // Priority for forwarded events. PriorityInstant (0) is the default. + Priority Priority +} + +// runBridge consumes a single BridgeConfig, publishing incoming payloads +// until ctx is canceled. Retries on Subscribe errors with bridgeRetryDelay. +// +// A nil Subscriber is a programming error caught at hub startup (see +// NewWithHub) so runBridge assumes cfg.Subscriber is non-nil. +func (h *Hub) runBridge(ctx context.Context, cfg BridgeConfig) { //nolint:gocritic // hugeParam: value semantics preferred + topic := cfg.Topic + if topic == "" { + topic = cfg.Channel + } + + for { + select { + case <-ctx.Done(): + return + default: + } + + // Wrap the callback in a recover so a panic inside the caller- + // supplied Transform can't tear down the bridge goroutine (which + // would leak h.bridges.Done() and hang Shutdown forever). + err := cfg.Subscriber.Subscribe(ctx, cfg.Channel, func(payload string) { + defer func() { + if r := recover(); r != nil { + log.Errorf("sse: bridge transform panic, message dropped channel=%s panic=%v", + cfg.Channel, r) + } + }() + if event := h.buildBridgeEvent(&cfg, topic, payload); event != nil { + h.Publish(*event) + } + }) + + if ctx.Err() != nil { + return + } + + // Any early return — error or unexpected nil from a well-behaved + // subscriber — is treated as retryable. Without the backoff on + // nil, a misbehaving subscriber that returns immediately would + // spin this loop hot. + if err != nil { + logBridgeError(cfg.Channel, err) + } + select { + case <-time.After(bridgeRetryDelay): + case <-ctx.Done(): + return + } + } +} + +// buildBridgeEvent creates an Event from a raw pub/sub payload. +// When Transform is set, the transform function controls all event fields; +// only missing Topics and Type are filled from config defaults. +// When Transform is not set, the event is built entirely from config defaults. +func (*Hub) buildBridgeEvent(cfg *BridgeConfig, topic, payload string) *Event { + if cfg.Transform != nil { + transformed := cfg.Transform(payload) + if transformed == nil { + return nil + } + event := *transformed + if len(event.Topics) == 0 { + event.Topics = []string{topic} + } + if event.Type == "" { + event.Type = cfg.EventType + } + return &event + } + + return &Event{ + Type: cfg.EventType, + Data: payload, + Topics: []string{topic}, + Priority: cfg.Priority, + CoalesceKey: cfg.CoalesceKey, + TTL: cfg.TTL, + } +} + +// logBridgeError logs a bridge subscriber error. Retries continue after +// bridgeRetryDelay regardless of error type. +func logBridgeError(channel string, err error) { + log.Warnf("sse: bridge subscriber error, retrying channel=%s error=%v", channel, err) +} diff --git a/middleware/sse/config.go b/middleware/sse/config.go new file mode 100644 index 00000000000..82be4ef1442 --- /dev/null +++ b/middleware/sse/config.go @@ -0,0 +1,117 @@ +package sse + +import ( + "time" + + "github.com/gofiber/fiber/v3" +) + +// Config defines the configuration for the SSE middleware. +// +// The SSE middleware is terminal: it hijacks the response stream and +// never calls c.Next(). Placing handlers after sse.New() in the chain +// results in undefined behavior because Fiber releases the fiber.Ctx +// before the stream writer runs. +type Config struct { + // OnConnect is called when a new client connects, before the SSE + // stream begins. Use it for authentication, topic selection, and + // connection limits. Set conn.Topics and conn.Metadata here. + // Return a non-nil error to reject the connection (sends 403). 
+ // + // Optional. Default: nil + OnConnect func(c fiber.Ctx, conn *Connection) error + + // OnDisconnect is called after a client disconnects. + // + // Optional. Default: nil + OnDisconnect func(conn *Connection) + + // OnPause is called when a connection is paused (browser tab hidden). + // + // Optional. Default: nil + OnPause func(conn *Connection) + + // OnResume is called when a connection is resumed (browser tab visible). + // + // Optional. Default: nil + OnResume func(conn *Connection) + + // Replayer enables Last-Event-ID replay. If nil, replay is disabled. + // + // Optional. Default: nil + Replayer Replayer + + // Bridges declares external pub/sub sources (Redis, NATS, etc.) that + // feed events into the hub. Bridges start automatically when the SSE + // middleware/hub is created (for example, via NewWithHub), not lazily + // when a handler is mounted, and stop on Shutdown. + // + // Optional. Default: nil + Bridges []BridgeConfig + + // FlushInterval is how often batched (P1) and coalesced (P2) events + // are flushed to clients. Instant (P0) events bypass this. + // + // Optional. Default: 2s + FlushInterval time.Duration + + // HeartbeatInterval is how often a comment is sent to idle connections + // to detect disconnects and prevent proxy timeouts. + // + // Optional. Default: 30s + HeartbeatInterval time.Duration + + // MaxLifetime is the maximum duration a single SSE connection can + // stay open. After this, the connection is closed gracefully. + // Set to -1 for unlimited. + // + // Optional. Default: 30m + MaxLifetime time.Duration + + // SendBufferSize is the per-connection channel buffer. If full, + // events are dropped and the client should reconnect. + // + // Optional. Default: 256 + SendBufferSize int + + // RetryMS is the reconnection interval hint sent to clients via the + // retry: directive on connect. + // + // Optional. Default: 3000 + RetryMS int +} + +// ConfigDefault is the default config. +var ConfigDefault = Config{ + FlushInterval: 2 * time.Second, + SendBufferSize: 256, + HeartbeatInterval: 30 * time.Second, + MaxLifetime: 30 * time.Minute, + RetryMS: 3000, +} + +func configDefault(config ...Config) Config { + if len(config) < 1 { + return ConfigDefault + } + + cfg := config[0] + + if cfg.FlushInterval <= 0 { + cfg.FlushInterval = ConfigDefault.FlushInterval + } + if cfg.SendBufferSize <= 0 { + cfg.SendBufferSize = ConfigDefault.SendBufferSize + } + if cfg.HeartbeatInterval <= 0 { + cfg.HeartbeatInterval = ConfigDefault.HeartbeatInterval + } + if cfg.MaxLifetime == 0 { + cfg.MaxLifetime = ConfigDefault.MaxLifetime + } + if cfg.RetryMS <= 0 { + cfg.RetryMS = ConfigDefault.RetryMS + } + + return cfg +} diff --git a/middleware/sse/connection.go b/middleware/sse/connection.go new file mode 100644 index 00000000000..f49812adddd --- /dev/null +++ b/middleware/sse/connection.go @@ -0,0 +1,132 @@ +package sse + +import ( + "bufio" + "sync" + "sync/atomic" + "time" +) + +// Connection represents a single SSE client connection managed by the hub. +type Connection struct { + CreatedAt time.Time + LastEventID atomic.Value + lastWrite atomic.Value + send chan MarshaledEvent + heartbeat chan struct{} + done chan struct{} + dispatcher *dispatcher + // Metadata holds connection metadata set during OnConnect. + // It is frozen (defensive-copied) after OnConnect returns -- do not + // mutate it from other goroutines after the connection is registered. 
+ Metadata map[string]string + ID string + Topics []string + MessagesSent atomic.Int64 + MessagesDropped atomic.Int64 + once sync.Once + paused atomic.Bool +} + +// newConnection creates a Connection with the given buffer size. +func newConnection(id string, topics []string, bufferSize int, flushInterval time.Duration) *Connection { + c := &Connection{ + ID: id, + Topics: topics, + Metadata: make(map[string]string), + CreatedAt: time.Now(), + send: make(chan MarshaledEvent, bufferSize), + heartbeat: make(chan struct{}, 1), + done: make(chan struct{}), + } + c.lastWrite.Store(time.Now()) + c.LastEventID.Store("") + c.dispatcher = newDispatcher(flushInterval) + return c +} + +// Close terminates the connection. Safe to call multiple times. +func (c *Connection) Close() { + c.once.Do(func() { + close(c.done) + }) +} + +// IsClosed returns true if the connection has been terminated. +func (c *Connection) IsClosed() bool { + select { + case <-c.done: + return true + default: + return false + } +} + +// trySend attempts to deliver an event to the connection's send channel. +// Returns false if the buffer is full (backpressure). +func (c *Connection) trySend(me MarshaledEvent) bool { //nolint:gocritic // hugeParam: value semantics for channel send + select { + case c.send <- me: + return true + default: + c.MessagesDropped.Add(1) + return false + } +} + +// sendHeartbeat sends a heartbeat signal to the connection. +// Non-blocking — if a heartbeat is already pending it is silently dropped. +func (c *Connection) sendHeartbeat() { + select { + case c.heartbeat <- struct{}{}: + default: + } +} + +// writeLoop runs inside Fiber's SendStreamWriter. It reads from the send +// and heartbeat channels, writing SSE-formatted events to the bufio.Writer. +func (c *Connection) writeLoop(w *bufio.Writer) { + for { + select { + case <-c.done: + return + case <-c.heartbeat: + if err := writeComment(w, "heartbeat"); err != nil { + c.Close() + return + } + if err := w.Flush(); err != nil { + c.Close() + return + } + case me, ok := <-c.send: + if !ok { + return + } + if _, err := me.WriteTo(w); err != nil { + c.Close() + return + } + if err := w.Flush(); err != nil { + c.Close() + return + } + c.MessagesSent.Add(1) + c.lastWrite.Store(time.Now()) + if me.ID != "" { + c.LastEventID.Store(me.ID) + } + } + } +} + +// connMatchesGroup returns true if ALL key-value pairs in the group +// match the connection's metadata. +func connMatchesGroup(conn *Connection, group map[string]string) bool { + for k, v := range group { + if conn.Metadata[k] != v { + return false + } + } + return true +} diff --git a/middleware/sse/dispatcher.go b/middleware/sse/dispatcher.go new file mode 100644 index 00000000000..28968ed7a87 --- /dev/null +++ b/middleware/sse/dispatcher.go @@ -0,0 +1,104 @@ +package sse + +import ( + "sync" + "time" +) + +// dispatcher is a per-connection two-lane queue feeding the write loop. +// +// Lane 1 (events): FIFO buffer of batched P1 events. Each AddEvent call +// appends a distinct event; all are emitted on flush in insertion order. +// +// Lane 2 (state): last-writer-wins map keyed by CoalesceKey for P2 events. +// Duplicate keys overwrite the prior value, so only the latest state is +// delivered. First-seen order is preserved across keys for deterministic +// output. 
+// +// This split mirrors SSE usage patterns: events are discrete happenings +// (notifications, log lines, messages) that must all reach the client, +// while state is the current value of something (progress %, online +// users count, cursor position) where only the newest snapshot matters. +type dispatcher struct { + // state holds P2 events keyed by CoalesceKey. + state map[string]MarshaledEvent + + // events holds P1 events in insertion order. + events []MarshaledEvent + + // stateOrder preserves first-seen order of coalesce keys. + stateOrder []string + + mu sync.Mutex + + // flushInterval is the target flush cadence (informational). + flushInterval time.Duration +} + +// newDispatcher creates a dispatcher with the given flush interval hint. +func newDispatcher(flushInterval time.Duration) *dispatcher { + return &dispatcher{ + state: make(map[string]MarshaledEvent), + events: make([]MarshaledEvent, 0, 16), + flushInterval: flushInterval, + } +} + +// AddEvent appends a P1 event to the events lane. All added events are +// sent on the next flush in insertion order. +func (d *dispatcher) AddEvent(me MarshaledEvent) { //nolint:gocritic // hugeParam: value semantics match WriteTo() return type + d.mu.Lock() + d.events = append(d.events, me) + d.mu.Unlock() +} + +// AddState upserts a P2 state event keyed by CoalesceKey. If the key +// already has a pending value, the previous value is overwritten +// (last-writer-wins). +func (d *dispatcher) AddState(key string, me MarshaledEvent) { //nolint:gocritic // hugeParam: value semantics match WriteTo() return type + d.mu.Lock() + if _, exists := d.state[key]; !exists { + d.stateOrder = append(d.stateOrder, key) + } + d.state[key] = me + d.mu.Unlock() +} + +// WriteTo drains both lanes and returns the events to write, in order: +// queued events first, then state values in first-seen key order. +func (d *dispatcher) WriteTo() []MarshaledEvent { + d.mu.Lock() + defer d.mu.Unlock() + + eventsLen := len(d.events) + stateLen := len(d.stateOrder) + + if eventsLen == 0 && stateLen == 0 { + return nil + } + + result := make([]MarshaledEvent, 0, eventsLen+stateLen) + + if eventsLen > 0 { + result = append(result, d.events...) + d.events = d.events[:0] + } + + if stateLen > 0 { + for _, key := range d.stateOrder { + result = append(result, d.state[key]) + } + d.state = make(map[string]MarshaledEvent, stateLen) + d.stateOrder = d.stateOrder[:0] + } + + return result +} + +// pending returns the total number of queued events and state updates. +func (d *dispatcher) pending() int { + d.mu.Lock() + n := len(d.events) + len(d.stateOrder) + d.mu.Unlock() + return n +} diff --git a/middleware/sse/event.go b/middleware/sse/event.go new file mode 100644 index 00000000000..8b6fc244330 --- /dev/null +++ b/middleware/sse/event.go @@ -0,0 +1,210 @@ +package sse + +import ( + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/valyala/bytebufferpool" +) + +// Priority controls how an event is delivered to clients. +type Priority int + +const ( + // PriorityInstant bypasses all buffering — the event is written to the + // client connection immediately. Use for errors, auth revocations, + // force-refresh commands, and chat messages. + PriorityInstant Priority = 0 + + // PriorityBatched collects events in a time window (FlushInterval) and + // sends them all at once. Use for status changes, media updates. + PriorityBatched Priority = 1 + + // PriorityCoalesced uses last-writer-wins per CoalesceKey. 
Multiple + // events with the same key within a flush window are merged — only the + // latest is sent. Use for progress bars, live counters, typing indicators. + PriorityCoalesced Priority = 2 +) + +// Event represents a single SSE event to be published through the hub. +type Event struct { + CreatedAt time.Time + Data any + Group map[string]string + Type string + ID string + CoalesceKey string + Topics []string + TTL time.Duration + Priority Priority +} + +// globalEventID is an auto-incrementing counter for event IDs. +var globalEventID atomic.Uint64 + +// nextEventID returns a monotonically increasing event ID string. +func nextEventID() string { + return "evt_" + strconv.FormatUint(globalEventID.Add(1), 10) +} + +// MarshaledEvent is the wire-ready representation of an SSE event. +// External Replayer implementations receive and return this type. +type MarshaledEvent struct { + // CreatedAt is the timestamp of the source Event (zero if unset). + CreatedAt time.Time + ID string + Type string + Data string + // TTL is the maximum age for this event. Zero means no expiry. + TTL time.Duration + // Retry is the reconnection hint (milliseconds) sent to clients. Zero + // or negative values are omitted from the wire frame — per the SSE + // spec `retry: 0` instructs clients to reconnect immediately, which + // could trigger reconnect storms, so only strictly positive values + // are emitted. + Retry int +} + +// sanitizeSSEField strips carriage returns and newlines from SSE control +// fields (id, event) to prevent SSE injection attacks. An attacker-controlled +// value containing \r or \n could break SSE framing and inject fake events. +func sanitizeSSEField(s string) string { + return strings.NewReplacer("\r\n", "", "\r", "", "\n", "").Replace(s) +} + +// normalizeSSEDataTerminators is used on the data field to convert any CR or +// CRLF sequence into LF before we split on line boundaries. The HTML SSE spec +// treats all three as valid line terminators, so we must emit one "data:" per +// logical line regardless of which terminator the caller used. +// Order matters: CRLF must be replaced first so we don't double-split. +var normalizeSSEDataTerminators = strings.NewReplacer("\r\n", "\n", "\r", "\n") + +// marshalEvent converts an Event into wire-ready format. +func marshalEvent(e *Event) MarshaledEvent { + me := MarshaledEvent{ + ID: sanitizeSSEField(e.ID), + Type: sanitizeSSEField(e.Type), + CreatedAt: e.CreatedAt, + TTL: e.TTL, + Retry: -1, + } + + if me.ID == "" { + me.ID = nextEventID() + } + + switch v := e.Data.(type) { + case nil: + me.Data = "" + case string: + me.Data = v + case []byte: + me.Data = string(v) + default: + // All other types flow through json.Marshal. The previous explicit + // json.Marshaler branch panicked on typed-nil pointers (e.g. + // `var p *Foo = nil` where *Foo has a MarshalJSON that dereferences + // the receiver) because the type switch matches the interface. + // json.Marshal handles typed-nil safely — it checks before invoking + // the method and emits `null` for nil pointers. + b, err := json.Marshal(v) + if err != nil { + errJSON, _ := json.Marshal(err.Error()) //nolint:errcheck,errchkjson // encoding a string never fails + me.Data = fmt.Sprintf(`{"error":%s}`, string(errJSON)) + } else { + me.Data = string(b) + } + } + + return me +} + +// WriteTo writes the SSE-formatted event to w following the Server-Sent +// Events specification. 
It assembles the frame in a pooled buffer so the +// hot path performs a single Write syscall with zero fmt allocations. +func (me *MarshaledEvent) WriteTo(w io.Writer) (int64, error) { + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + + // Sanitize control-sequence fields at the write boundary — not just in + // marshalEvent — so external Replayer implementations returning raw + // MarshaledEvent values can't inject extra SSE fields via embedded + // \r/\n in ID or Type. Defense in depth: WriteTo is the last line + // between an event and the client. + if id := sanitizeSSEField(me.ID); id != "" { + buf.WriteString("id: ") + buf.WriteString(id) + buf.WriteByte('\n') + } + if evtType := sanitizeSSEField(me.Type); evtType != "" { + buf.WriteString("event: ") + buf.WriteString(evtType) + buf.WriteByte('\n') + } + // Retry must be strictly positive to be emitted. Per the SSE spec a + // `retry: 0` directive tells clients to reconnect immediately, which can + // trigger reconnect storms if a Replayer implementation accidentally + // constructs MarshaledEvent without setting Retry (its zero value is 0). + // Treating 0 as "unset" matches the internal marshalEvent default. + if me.Retry > 0 { + buf.WriteString("retry: ") + buf.WriteString(strconv.Itoa(me.Retry)) + buf.WriteByte('\n') + } + + // Normalise CR and CRLF to LF so a caller-supplied "\r" or "\r\n" + // produces one data line per logical line rather than a single line + // containing raw control characters (the HTML SSE parser treats all + // three as line terminators and would mis-frame the client). + // strings.SplitSeq("", "\n") yields "", correctly writing "data: \n" + // for empty data. + data := normalizeSSEDataTerminators.Replace(me.Data) + for line := range strings.SplitSeq(data, "\n") { + buf.WriteString("data: ") + buf.WriteString(line) + buf.WriteByte('\n') + } + buf.WriteByte('\n') + + n, err := w.Write(buf.B) + if err != nil { + return int64(n), fmt.Errorf("sse: write frame: %w", err) + } + return int64(n), nil +} + +// writeComment writes an SSE comment line. +func writeComment(w io.Writer, text string) error { + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + buf.WriteString(": ") + buf.WriteString(text) + buf.WriteString("\n\n") + if _, err := w.Write(buf.B); err != nil { + return fmt.Errorf("sse: write comment: %w", err) + } + return nil +} + +// writeRetry writes the retry directive. Non-positive ms values are +// silently skipped — per the SSE spec `retry: 0` tells clients to +// reconnect immediately, matching the MarshaledEvent.WriteTo semantics. 
+func writeRetry(w io.Writer, ms int) error { + if ms <= 0 { + return nil + } + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + buf.WriteString("retry: ") + buf.WriteString(strconv.Itoa(ms)) + buf.WriteString("\n\n") + if _, err := w.Write(buf.B); err != nil { + return fmt.Errorf("sse: write retry: %w", err) + } + return nil +} diff --git a/middleware/sse/example_test.go b/middleware/sse/example_test.go new file mode 100644 index 00000000000..0260ba2ebcc --- /dev/null +++ b/middleware/sse/example_test.go @@ -0,0 +1,96 @@ +package sse + +import ( + "context" + "fmt" + "time" + + "github.com/gofiber/fiber/v3" +) + +func Example() { + app := fiber.New() + + handler, hub := NewWithHub(Config{ + OnConnect: func(_ fiber.Ctx, conn *Connection) error { + conn.Topics = []string{"notifications"} + return nil + }, + }) + + app.Get("/events", handler) + + // Publish from any handler or worker + hub.Publish(Event{ + Type: "update", + Data: map[string]string{"message": "hello"}, + Topics: []string{"notifications"}, + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + if err := hub.Shutdown(ctx); err != nil { + panic(err) + } + + fmt.Println("Hub created and shut down successfully") //nolint:errcheck // example test output + // Output: Hub created and shut down successfully +} + +func Example_priorities() { + _, hub := NewWithHub() + + // Instant: delivered immediately, bypasses buffering + hub.Publish(Event{ + Type: "alert", + Data: "critical", + Topics: []string{"alerts"}, + Priority: PriorityInstant, + }) + + // Coalesced: last-writer-wins per CoalesceKey within flush window + for i := 1; i <= 100; i++ { + hub.Publish(Event{ + Type: "progress", + Data: fmt.Sprintf(`{"pct":%d}`, i), + Topics: []string{"progress"}, + Priority: PriorityCoalesced, + CoalesceKey: "import", + }) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + if err := hub.Shutdown(ctx); err != nil { + panic(err) + } + + fmt.Println("Events published") //nolint:errcheck // example test output + // Output: Events published +} + +func Example_topicWildcards() { + _, hub := NewWithHub(Config{ + OnConnect: func(_ fiber.Ctx, conn *Connection) error { + // Subscribe to all order events using NATS-style wildcard + conn.Topics = []string{"orders.*"} + return nil + }, + }) + + // These will all match orders.* + hub.Publish(Event{Type: "created", Topics: []string{"orders.created"}}) + hub.Publish(Event{Type: "updated", Topics: []string{"orders.updated"}}) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + if err := hub.Shutdown(ctx); err != nil { + panic(err) + } + + fmt.Println("Wildcard subscription example") //nolint:errcheck // example test output + // Output: Wildcard subscription example +} diff --git a/middleware/sse/hub.go b/middleware/sse/hub.go new file mode 100644 index 00000000000..1ee8836847e --- /dev/null +++ b/middleware/sse/hub.go @@ -0,0 +1,586 @@ +package sse + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/gofiber/fiber/v3/log" +) + +// Hub is the central SSE event broker. It manages client connections, +// event routing, coalescing, and delivery. All methods are goroutine-safe. 
+type Hub struct { + throttler *adaptiveThrottler + connections map[string]*Connection + topicIndex map[string]map[string]struct{} + wildcardConns map[string]struct{} + register chan *Connection + unregister chan *Connection + events chan Event + shutdown chan struct{} + stopped chan struct{} + metrics hubMetrics + cfg Config + bridges sync.WaitGroup + mu sync.RWMutex + shutdownOnce sync.Once + draining atomic.Bool +} + +// newHub constructs a Hub from resolved Config and starts the run loop. +// Bridges (if any) are started and tracked in hub.bridges; Shutdown waits +// for all of them to finish before reporting stopped. +func newHub(cfg Config) *Hub { //nolint:gocritic // hugeParam: internal constructor, single call site, avoids pointer escape + hub := &Hub{ + cfg: cfg, + register: make(chan *Connection, 64), + unregister: make(chan *Connection, 64), + events: make(chan Event, 1024), + shutdown: make(chan struct{}), + connections: make(map[string]*Connection), + topicIndex: make(map[string]map[string]struct{}), + wildcardConns: make(map[string]struct{}), + throttler: newAdaptiveThrottler(cfg.FlushInterval), + metrics: hubMetrics{eventsByType: make(map[string]*atomic.Int64)}, + stopped: make(chan struct{}), + } + + // Validate every BridgeConfig BEFORE launching ANY goroutine — + // including hub.run(). A panic from here must leak nothing: no run + // loop, no bridge workers, no cancel function. Starting hub.run() + // first would leave a zombie goroutine alive after the panic, + // because NewWithHub never returns the hub to its caller for + // Shutdown. + for i, bc := range cfg.Bridges { + if bc.Subscriber == nil { + panic(fmt.Sprintf("sse: BridgeConfig.Subscriber at index %d must not be nil", i)) + } + if bc.Channel == "" { + panic(fmt.Sprintf("sse: BridgeConfig.Channel at index %d must not be empty", i)) + } + } + + go hub.run() + + if len(cfg.Bridges) > 0 { + ctx, cancel := context.WithCancel(context.Background()) + // Tie cancel to hub.shutdown so a single close(h.shutdown) during + // Shutdown also cancels the bridges' context. Keeping cancel + // visible in a goroutine (rather than just storing a CancelFunc + // on the struct) lets gosec G118 see that it's always invoked. + go func() { + <-hub.shutdown + cancel() + }() + for _, bc := range cfg.Bridges { + hub.bridges.Add(1) + go func(cfg BridgeConfig) { + defer hub.bridges.Done() + hub.runBridge(ctx, cfg) + }(bc) + } + } + + return hub +} + +// Publish sends an event to all connections subscribed to the event's topics. +// This method is goroutine-safe and non-blocking. If the internal event buffer +// is full, the event is dropped and eventsDropped is incremented. +func (h *Hub) Publish(event Event) { //nolint:gocritic // hugeParam: public API, value semantics preferred + // Reject early if the hub is draining. Without this, a concurrent + // Shutdown() can race with Publish() and enqueue an event the run + // loop will never dispatch — inflating EventsPublished and leaving + // the caller under the false impression the event was delivered. + if h.draining.Load() { + h.metrics.eventsDropped.Add(1) + return + } + if event.TTL > 0 && event.CreatedAt.IsZero() { + event.CreatedAt = time.Now() + } + select { + case h.events <- event: + h.metrics.eventsPublished.Add(1) + case <-h.shutdown: + // Hub is shutting down, discard + h.metrics.eventsDropped.Add(1) + default: + // Buffer full — drop event to avoid blocking callers. + h.metrics.eventsDropped.Add(1) + } +} + +// SetPaused pauses or resumes a connection by ID. 
Paused connections +// skip P1/P2 events (visibility hint for hidden browser tabs). +// P0 (instant) events are always delivered regardless. +func (h *Hub) SetPaused(connID string, paused bool) { //nolint:revive // flag-parameter: public API toggle + h.mu.RLock() + conn, ok := h.connections[connID] + h.mu.RUnlock() + if ok { + wasPaused := conn.paused.Swap(paused) + if paused && !wasPaused && h.cfg.OnPause != nil { + h.cfg.OnPause(conn) + } + if !paused && wasPaused && h.cfg.OnResume != nil { + h.cfg.OnResume(conn) + } + } +} + +// Shutdown gracefully drains all connections and stops the hub. Any +// configured bridges are canceled and awaited before the hub reports stopped. +// Safe to call multiple times — subsequent calls are no-ops. +// Pass context.Background() for an unbounded wait. +func (h *Hub) Shutdown(ctx context.Context) error { + h.draining.Store(true) + h.shutdownOnce.Do(func() { + // Closing h.shutdown fans out to the bridge-cancel goroutine + // registered in newHub, the run loop, and watchers. + close(h.shutdown) + }) + + // Bridges must finish before we report stopped so their in-flight + // Publish calls don't race with a re-used hub — but wait must still + // honor the caller's deadline so a wedged bridge can't hang Shutdown. + bridgesDone := make(chan struct{}) + go func() { + h.bridges.Wait() + close(bridgesDone) + }() + + select { + case <-bridgesDone: + case <-ctx.Done(): + return fmt.Errorf("sse: shutdown: %w", ctx.Err()) + } + + select { + case <-h.stopped: + return nil + case <-ctx.Done(): + return fmt.Errorf("sse: shutdown: %w", ctx.Err()) + } +} + +// Stats returns a snapshot of the hub's current state. +func (h *Hub) Stats() HubStats { + h.mu.RLock() + defer h.mu.RUnlock() + + byTopic := make(map[string]int, len(h.topicIndex)) + for topic, conns := range h.topicIndex { + byTopic[topic] = len(conns) + } + + return HubStats{ + ActiveConnections: len(h.connections), + TotalTopics: len(h.topicIndex), + EventsPublished: h.metrics.eventsPublished.Load(), + EventsDropped: h.metrics.eventsDropped.Load(), + ConnectionsByTopic: byTopic, + EventsByType: h.metrics.snapshotEventsByType(), + } +} + +// initStream writes the initial SSE preamble: retry hint, replayed events, +// and the connected event. +func (h *Hub) initStream(w *bufio.Writer, conn *Connection, lastEventID string) error { + if err := writeRetry(w, h.cfg.RetryMS); err != nil { + return err + } + + if err := h.replayEvents(w, conn, lastEventID); err != nil { + return err + } + + return sendConnectedEvent(w, conn) +} + +// replayEvents replays missed events if the client sent a Last-Event-ID. +func (h *Hub) replayEvents(w *bufio.Writer, conn *Connection, lastEventID string) error { + if lastEventID == "" || h.cfg.Replayer == nil { + return nil + } + events, err := h.cfg.Replayer.Replay(lastEventID, conn.Topics) + if err != nil { + // Replay is best-effort; log and continue without replayed events. + log.Warnf("sse: replayer error, continuing without replay: %v", err) + return nil + } + if len(events) == 0 { + return nil + } + for _, me := range events { + if _, werr := me.WriteTo(w); werr != nil { + return werr + } + } + if err := w.Flush(); err != nil { + return fmt.Errorf("sse: flush replay: %w", err) + } + return nil +} + +// sendConnectedEvent writes the connected event with the connection ID +// and subscribed topics. 
+func sendConnectedEvent(w *bufio.Writer, conn *Connection) error { + topicsJSON, err := json.Marshal(conn.Topics) + if err != nil { + topicsJSON = []byte("[]") + } + connected := MarshaledEvent{ + ID: nextEventID(), + Type: "connected", + Data: fmt.Sprintf(`{"connection_id":%q,"topics":%s}`, conn.ID, string(topicsJSON)), + Retry: -1, + } + if _, err := connected.WriteTo(w); err != nil { + return err + } + if err := w.Flush(); err != nil { + return fmt.Errorf("sse: flush connected event: %w", err) + } + return nil +} + +// watchLifetime starts a goroutine that closes the connection after +// MaxLifetime has elapsed. +func (h *Hub) watchLifetime(conn *Connection) { + if h.cfg.MaxLifetime <= 0 { + return + } + go func() { + timer := time.NewTimer(h.cfg.MaxLifetime) + defer timer.Stop() + select { + case <-timer.C: + conn.Close() + case <-conn.done: + } + }() +} + +// shutdownDrainDelay is the time between sending the server-shutdown event +// and closing the connection, allowing the client to process the event. +const shutdownDrainDelay = 200 * time.Millisecond + +// broadcastShutdown queues a server-shutdown event on every live connection. +// Called from the run loop on the shutdown signal BEFORE any Close() so that +// writeLoop has a chance to flush the event to the network. A short drain +// delay afterwards gives writers time to complete the flush before close. +func (h *Hub) broadcastShutdown() { + h.mu.RLock() + conns := make([]*Connection, 0, len(h.connections)) + for _, conn := range h.connections { + if !conn.IsClosed() { + conns = append(conns, conn) + } + } + h.mu.RUnlock() + + for _, conn := range conns { + conn.trySend(MarshaledEvent{ + ID: nextEventID(), + Type: "server-shutdown", + Data: "{}", + Retry: -1, + }) + } +} + +// run is the hub's main event loop. +func (h *Hub) run() { + defer close(h.stopped) + + flushTicker := time.NewTicker(h.cfg.FlushInterval) + defer flushTicker.Stop() + + heartbeatTicker := time.NewTicker(h.cfg.HeartbeatInterval) + defer heartbeatTicker.Stop() + + cleanupTicker := time.NewTicker(5 * time.Minute) + defer cleanupTicker.Stop() + + for { + select { + case conn := <-h.register: + h.addConnection(conn) + + case conn := <-h.unregister: + h.removeConnection(conn) + + case event := <-h.events: + h.routeEvent(&event) + + case <-flushTicker.C: + h.flushAll() + + case <-heartbeatTicker.C: + h.sendHeartbeats() + + case <-cleanupTicker.C: + h.throttler.cleanup(time.Now().Add(-10 * time.Minute)) + + case <-h.shutdown: + // Notify clients first, wait briefly for writeLoops to flush, + // then close. Prevents a race where Close() beats the + // server-shutdown event to the network. + h.broadcastShutdown() + time.Sleep(shutdownDrainDelay) + h.mu.Lock() + for _, conn := range h.connections { + conn.Close() + } + h.mu.Unlock() + return + } + } +} + +// addConnection registers a new connection and indexes it by topic. +func (h *Hub) addConnection(conn *Connection) { + h.mu.Lock() + defer h.mu.Unlock() + + h.connections[conn.ID] = conn + + hasWildcard := false + for _, topic := range conn.Topics { + if strings.ContainsAny(topic, "*>") { + hasWildcard = true + } else { + if h.topicIndex[topic] == nil { + h.topicIndex[topic] = make(map[string]struct{}) + } + h.topicIndex[topic][conn.ID] = struct{}{} + } + } + if hasWildcard { + h.wildcardConns[conn.ID] = struct{}{} + } + + log.Infof("sse: connection opened conn_id=%s topics=%v total=%d", + conn.ID, conn.Topics, len(h.connections)) +} + +// removeConnection unregisters a connection and removes it from topic indexes. 
+func (h *Hub) removeConnection(conn *Connection) { + h.mu.Lock() + defer h.mu.Unlock() + + if _, exists := h.connections[conn.ID]; !exists { + return + } + + for _, topic := range conn.Topics { + if idx, ok := h.topicIndex[topic]; ok { + delete(idx, conn.ID) + if len(idx) == 0 { + delete(h.topicIndex, topic) + } + } + } + + delete(h.wildcardConns, conn.ID) + delete(h.connections, conn.ID) + h.throttler.remove(conn.ID) + + log.Infof("sse: connection closed conn_id=%s sent=%d dropped=%d total=%d", + conn.ID, conn.MessagesSent.Load(), conn.MessagesDropped.Load(), len(h.connections)) +} + +// routeEvent delivers an event to all connections subscribed to its topics. +func (h *Hub) routeEvent(event *Event) { + if event.TTL > 0 && !event.CreatedAt.IsZero() { + if time.Since(event.CreatedAt) > event.TTL { + h.metrics.eventsDropped.Add(1) + return + } + } + + me := marshalEvent(event) + h.metrics.trackEventType(event.Type) + + // Skip replayer for group-scoped events to avoid cross-tenant leaks + // on reconnect. Store errors are logged but non-fatal — replay is a + // best-effort feature and one missing event shouldn't break delivery. + if h.cfg.Replayer != nil && len(event.Group) == 0 { + if err := h.cfg.Replayer.Store(me, event.Topics); err != nil { + log.Warnf("sse: replayer store error, continuing: %v", err) + } + } + + h.mu.RLock() + defer h.mu.RUnlock() + + seen := h.matchConnections(event) + + for connID := range seen { + conn, ok := h.connections[connID] + if !ok || conn.IsClosed() { + continue + } + if conn.paused.Load() && event.Priority != PriorityInstant { + continue + } + h.deliverToConn(conn, event, me) + } +} + +// matchConnections collects all connection IDs that should receive the event. +// When both Topics and Group are set, only connections matching BOTH are +// included (intersection semantics) to prevent tenant/topic leaks (CRITICAL-1). +func (h *Hub) matchConnections(event *Event) map[string]struct{} { + seen := make(map[string]struct{}) + + for _, topic := range event.Topics { + if idx, ok := h.topicIndex[topic]; ok { + for connID := range idx { + seen[connID] = struct{}{} + } + } + } + + h.matchWildcardConns(event, seen) + + // If event has a Group, filter seen to intersection with group-matching conns. + if len(event.Group) > 0 { + if len(event.Topics) > 0 { + // Intersection: keep only topic-matched conns that also match group. + for connID := range seen { + conn := h.connections[connID] + if conn == nil || !connMatchesGroup(conn, event.Group) { + delete(seen, connID) + } + } + } else { + // Group-only event: match by group alone. + h.matchGroupConns(event, seen) + } + } + + return seen +} + +// matchWildcardConns adds wildcard-subscribed connections that match the event topics. +func (h *Hub) matchWildcardConns(event *Event, seen map[string]struct{}) { + for connID := range h.wildcardConns { + if _, already := seen[connID]; already { + continue + } + conn, ok := h.connections[connID] + if !ok { + continue + } + for _, eventTopic := range event.Topics { + if connMatchesTopic(conn, eventTopic) { + seen[connID] = struct{}{} + break + } + } + } +} + +// matchGroupConns adds connections that match the event's group metadata. 
+func (h *Hub) matchGroupConns(event *Event, seen map[string]struct{}) { + if len(event.Group) == 0 { + return + } + for connID, conn := range h.connections { + if _, already := seen[connID]; already { + continue + } + if connMatchesGroup(conn, event.Group) { + seen[connID] = struct{}{} + } + } +} + +// deliverToConn routes an event to a connection based on priority. +func (h *Hub) deliverToConn(conn *Connection, event *Event, me MarshaledEvent) { //nolint:gocritic // hugeParam: internal, copy is cheap + switch event.Priority { + case PriorityInstant: + if !conn.trySend(me) { + h.metrics.eventsDropped.Add(1) + } + case PriorityBatched: + conn.dispatcher.AddEvent(me) + case PriorityCoalesced: + key := event.CoalesceKey + if key == "" { + key = event.Type + } + conn.dispatcher.AddState(key, me) + default: + // Unknown priority: drop to avoid misrouting. + h.metrics.eventsDropped.Add(1) + } +} + +// flushAll drains each connection's dispatcher and sends buffered events. +func (h *Hub) flushAll() { + h.mu.RLock() + conns := make([]*Connection, 0, len(h.connections)) + for _, conn := range h.connections { + if !conn.IsClosed() && !conn.paused.Load() { + conns = append(conns, conn) + } + } + h.mu.RUnlock() + + now := time.Now() + for _, conn := range conns { + if conn.IsClosed() { + continue + } + + bufCap := cap(conn.send) + saturation := float64(0) + if bufCap > 0 { + saturation = float64(len(conn.send)) / float64(bufCap) + } + + if !h.throttler.shouldFlush(conn.ID, saturation) { + continue + } + + events := conn.dispatcher.WriteTo() + for _, me := range events { + // Re-check TTL after dispatching delay: coalesced events may + // sit in the queue past their deadline (MAJOR-6). + if me.TTL > 0 && !me.CreatedAt.IsZero() && now.Sub(me.CreatedAt) > me.TTL { + h.metrics.eventsDropped.Add(1) + continue + } + if !conn.trySend(me) { + h.metrics.eventsDropped.Add(1) + } + } + } +} + +// sendHeartbeats sends a comment to connections that haven't received +// real data recently. +func (h *Hub) sendHeartbeats() { + h.mu.RLock() + defer h.mu.RUnlock() + + now := time.Now() + for _, conn := range h.connections { + if conn.IsClosed() { + continue + } + lastWrite, _ := conn.lastWrite.Load().(time.Time) //nolint:errcheck // type assertion on atomic.Value + if now.Sub(lastWrite) >= h.cfg.HeartbeatInterval { + conn.sendHeartbeat() + } + } +} diff --git a/middleware/sse/replayer.go b/middleware/sse/replayer.go new file mode 100644 index 00000000000..570f700f641 --- /dev/null +++ b/middleware/sse/replayer.go @@ -0,0 +1,14 @@ +package sse + +// Replayer stores events for replay when a client reconnects with Last-Event-ID. +// Implement this interface to plug in any storage backend (Redis Streams, +// database, in-memory ring buffer, etc.). +type Replayer interface { + // Store persists an event for potential future replay. + Store(event MarshaledEvent, topics []string) error + + // Replay returns all events after lastEventID that match any of the + // given topics, in chronological order. Returns nil if lastEventID + // is unknown (caller should treat as a fresh connection). + Replay(lastEventID string, topics []string) ([]MarshaledEvent, error) +} diff --git a/middleware/sse/sse.go b/middleware/sse/sse.go new file mode 100644 index 00000000000..2a243e17307 --- /dev/null +++ b/middleware/sse/sse.go @@ -0,0 +1,163 @@ +// Package sse provides Server-Sent Events middleware for Fiber. 
+// +// It is the only SSE implementation built natively for Fiber's +// fasthttp architecture — no net/http adapters, no broken disconnect +// detection. +// +// Features: event coalescing (last-writer-wins), three priority lanes +// (instant/batched/coalesced), NATS-style topic wildcards, adaptive +// per-connection throttling, connection groups (publish by metadata), +// graceful Kubernetes-style drain, pluggable Last-Event-ID replay, +// and a SubscriberBridge adapter for external pub/sub sources such as +// Redis and NATS. +// +// Quick start: +// +// handler, hub := sse.NewWithHub(sse.Config{ +// OnConnect: func(c fiber.Ctx, conn *sse.Connection) error { +// conn.Topics = []string{"notifications"} +// return nil +// }, +// }) +// app.Get("/events", handler) +// hub.Publish(sse.Event{Type: "ping", Data: "hello", Topics: []string{"notifications"}}) +// +// The middleware is terminal: the returned handler hijacks the response +// stream via Fiber's SendStreamWriter and never calls c.Next(). Do not +// chain additional handlers after it. +package sse + +import ( + "bufio" + "crypto/rand" + "encoding/hex" + "maps" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/log" +) + +// New creates a new SSE middleware handler. Use this when you don't need +// direct access to the Hub (e.g., simple streaming without Publish). +// +// For most use cases, prefer [NewWithHub] instead. +func New(config ...Config) fiber.Handler { + handler, _ := NewWithHub(config...) + return handler +} + +// NewWithHub creates a new SSE middleware handler and returns it along +// with the Hub for publishing events. This is the primary entry point. +// +// handler, hub := sse.NewWithHub(sse.Config{ +// OnConnect: func(c fiber.Ctx, conn *sse.Connection) error { +// conn.Topics = []string{"notifications", "live"} +// conn.Metadata["tenant_id"] = c.Locals("tenant_id").(string) +// return nil +// }, +// }) +// app.Get("/events", handler) +// +// // From any handler or worker: +// hub.Publish(sse.Event{Type: "update", Data: "hello", Topics: []string{"live"}}) +func NewWithHub(config ...Config) (fiber.Handler, *Hub) { + cfg := configDefault(config...) + hub := newHub(cfg) + + handler := func(c fiber.Ctx) error { + // Reject during graceful drain. + if hub.draining.Load() { + c.Set("Retry-After", "5") + return c.Status(fiber.StatusServiceUnavailable).SendString("server draining, please reconnect") + } + + conn := newConnection( + generateID(), + nil, + cfg.SendBufferSize, + cfg.FlushInterval, + ) + + // Let the application authenticate and configure the connection. + // The returned error is logged server-side (so operators can tell + // auth-fail from rate-limit from tenant-mismatch, etc.) but never + // exposed to the client — callers may include user / tenant + // identifiers or internal policy reasons that would leak + // information to an unauthenticated peer. + if cfg.OnConnect != nil { + if err := cfg.OnConnect(c, conn); err != nil { + log.Warnf("sse: OnConnect rejected connection: %v", err) + return fiber.NewError(fiber.StatusForbidden, "forbidden") + } + } + + // Freeze metadata — defensive copy to prevent concurrent mutation + // after the connection is registered with the hub. + frozen := make(map[string]string, len(conn.Metadata)) + maps.Copy(frozen, conn.Metadata) + conn.Metadata = frozen + + if len(conn.Topics) == 0 { + return c.Status(fiber.StatusBadRequest).SendString("no topics subscribed") + } + + // SSE response headers. 
+ c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + c.Set("X-Accel-Buffering", "no") + + // Capture Last-Event-ID before entering the stream writer. + lastEventID := c.Get("Last-Event-ID") + if lastEventID == "" { + lastEventID = c.Query("lastEventID") + } + + // Abandon the ctx so Fiber does not return it to the pool while + // fasthttp is still invoking the stream writer in a background + // goroutine. + c.Abandon() + + return c.SendStreamWriter(func(w *bufio.Writer) { + defer func() { + select { + case hub.unregister <- conn: + case <-hub.shutdown: + } + conn.Close() + if cfg.OnDisconnect != nil { + cfg.OnDisconnect(conn) + } + }() + + // Register BEFORE writing the preamble / replay so events + // published during replay buffer in conn.send instead of being + // dropped. Event IDs are monotonic, so live events always have + // higher IDs than any replayed event — no duplicates are + // possible with a Last-Event-ID strictly-after replayer. + select { + case hub.register <- conn: + case <-hub.shutdown: + return + } + + if err := hub.initStream(w, conn, lastEventID); err != nil { + return + } + + hub.watchLifetime(conn) + conn.writeLoop(w) + }) + } + + return handler, hub +} + +// generateID produces a random 32-character hex string for connection IDs. +func generateID() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + panic("sse: failed to generate connection ID: " + err.Error()) + } + return hex.EncodeToString(b) +} diff --git a/middleware/sse/sse_test.go b/middleware/sse/sse_test.go new file mode 100644 index 00000000000..9f6cabd373e --- /dev/null +++ b/middleware/sse/sse_test.go @@ -0,0 +1,2256 @@ +package sse + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "net" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" +) + +// startSSEServer spins up a real TCP listener serving the given handler and +// returns the base URL + cleanup. Use for end-to-end response-body checks +// since app.Test() blocks on SSE streams that never terminate. +func startSSEServer(t *testing.T, handler fiber.Handler) (string, func()) { //nolint:gocritic // unnamedResult: nonamedreturns rule forbids names; types are self-explanatory + t.Helper() + app := fiber.New() + app.Get("/events", handler) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.NoError(t, err) + + go func() { + _ = app.Listener(ln) //nolint:errcheck // best-effort test listener + }() + + baseURL := "http://" + ln.Addr().String() + cleanup := func() { + // Close the listener first to force-abort in-flight SSE writers; + // app.Shutdown would otherwise wait for the long-lived handler. + _ = ln.Close() //nolint:errcheck // may already be closed + shutdownCtx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + _ = app.ShutdownWithContext(shutdownCtx) //nolint:errcheck // best-effort test shutdown + } + return baseURL, cleanup +} + +// sseFrameTimeout is the deadline for reading a single SSE frame in tests. +const sseFrameTimeout = 2 * time.Second + +// readSSEFrame reads one SSE frame (ending in blank line) from r. +// Fails the test if no complete frame arrives within sseFrameTimeout. 
+func readSSEFrame(t *testing.T, r *bufio.Reader) string { + t.Helper() + type result struct { + err error + frame string + } + done := make(chan result, 1) + go func() { + var buf bytes.Buffer + for { + line, err := r.ReadString('\n') + if err != nil { + done <- result{err: err} + return + } + _, _ = buf.WriteString(line) //nolint:errcheck // bytes.Buffer.WriteString never fails + if line == "\n" { + done <- result{frame: buf.String()} + return + } + } + }() + select { + case res := <-done: + require.NoError(t, res.err) + return res.frame + case <-time.After(sseFrameTimeout): + t.Fatal("timed out waiting for SSE frame") + return "" + } +} + +func Test_SSE_E2E_HeadersAndConnectedFrame(t *testing.T) { + t.Parallel() + + handler, hub := NewWithHub(Config{ + MaxLifetime: 500 * time.Millisecond, + HeartbeatInterval: 100 * time.Millisecond, + FlushInterval: 50 * time.Millisecond, + OnConnect: func(_ fiber.Ctx, conn *Connection) error { + conn.Topics = []string{"updates"} + return nil + }, + }) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + base, cleanup := startSSEServer(t, handler) + defer cleanup() + + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, base+"/events", http.NoBody) + require.NoError(t, err) + + resp, err := (&http.Client{Timeout: 5 * time.Second}).Do(req) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, resp.Body.Close()) }) + + // RFC 8895 + W3C SSE: Content-Type must be text/event-stream. + require.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) + require.Equal(t, "no-cache", resp.Header.Get("Cache-Control")) + require.Equal(t, "keep-alive", resp.Header.Get("Connection")) + require.Equal(t, "no", resp.Header.Get("X-Accel-Buffering")) + require.Equal(t, http.StatusOK, resp.StatusCode) + + br := bufio.NewReader(resp.Body) + + // First: retry directive frame from writeRetry. + retryFrame := readSSEFrame(t, br) + require.Contains(t, retryFrame, "retry: 3000") + + // Second: connected event with connection_id and topics. + connectedFrame := readSSEFrame(t, br) + require.Contains(t, connectedFrame, "event: connected") + require.Contains(t, connectedFrame, "connection_id") + require.Contains(t, connectedFrame, `"topics":["updates"]`) +} + +func Test_SSE_E2E_PublishedEventDeliveredToClient(t *testing.T) { + t.Parallel() + + handler, hub := NewWithHub(Config{ + OnConnect: func(_ fiber.Ctx, conn *Connection) error { + conn.Topics = []string{"orders"} + return nil + }, + }) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + base, cleanup := startSSEServer(t, handler) + defer cleanup() + + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, base+"/events", http.NoBody) + require.NoError(t, err) + resp, err := (&http.Client{Timeout: 5 * time.Second}).Do(req) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, resp.Body.Close()) }) + + br := bufio.NewReader(resp.Body) + _ = readSSEFrame(t, br) // retry + _ = readSSEFrame(t, br) // connected + + // Give the hub time to register the connection before publishing. 
+ time.Sleep(50 * time.Millisecond) + + hub.Publish(Event{ + Type: "order-created", + Data: `{"id":"ord_123","total":99}`, + Topics: []string{"orders"}, + Priority: PriorityInstant, + }) + + frame := readSSEFrame(t, br) + require.Contains(t, frame, "event: order-created") + require.Contains(t, frame, `data: {"id":"ord_123","total":99}`) + require.Contains(t, frame, "id: evt_") +} + +func Test_SSE_E2E_MultilineDataProducesMultipleDataLines(t *testing.T) { + t.Parallel() + + handler, hub := NewWithHub(Config{ + OnConnect: func(_ fiber.Ctx, conn *Connection) error { + conn.Topics = []string{"logs"} + return nil + }, + }) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + base, cleanup := startSSEServer(t, handler) + defer cleanup() + + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, base+"/events", http.NoBody) + require.NoError(t, err) + resp, err := (&http.Client{Timeout: 5 * time.Second}).Do(req) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, resp.Body.Close()) }) + + br := bufio.NewReader(resp.Body) + _ = readSSEFrame(t, br) + _ = readSSEFrame(t, br) + time.Sleep(50 * time.Millisecond) + + hub.Publish(Event{ + Type: "log", + Data: "line1\nline2\nline3", + Topics: []string{"logs"}, + Priority: PriorityInstant, + }) + + frame := readSSEFrame(t, br) + require.Contains(t, frame, "data: line1\n") + require.Contains(t, frame, "data: line2\n") + require.Contains(t, frame, "data: line3\n") +} + +func Test_SSE_E2E_IDAndTypeSanitizedAgainstInjection(t *testing.T) { + t.Parallel() + + handler, hub := NewWithHub(Config{ + OnConnect: func(_ fiber.Ctx, conn *Connection) error { + conn.Topics = []string{"t"} + return nil + }, + }) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + base, cleanup := startSSEServer(t, handler) + defer cleanup() + + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, base+"/events", http.NoBody) + require.NoError(t, err) + resp, err := (&http.Client{Timeout: 5 * time.Second}).Do(req) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, resp.Body.Close()) }) + + br := bufio.NewReader(resp.Body) + _ = readSSEFrame(t, br) + _ = readSSEFrame(t, br) + time.Sleep(50 * time.Millisecond) + + hub.Publish(Event{ + ID: "bad\nid: injected", + Type: "evt\nevent: sneaky", + Data: "x", + Topics: []string{"t"}, + Priority: PriorityInstant, + }) + + frame := readSSEFrame(t, br) + + // Split frame into logical lines; id: and event: must each appear on + // exactly one line. Injection attempts that embed `\nid: injected` or + // `\nevent: sneaky` would create extra lines the SSE parser would + // interpret as additional fields — we must collapse them away. 
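+	// With sanitization the embedded newlines are stripped outright (see the
+	// sanitizeSSEField tests below), so the frame carries single lines such
+	// as "id: badid: injected" and "event: evtevent: sneaky": the payload
+	// survives inside the field value but cannot open a new field line.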
+ lines := strings.Split(frame, "\n") + var idLines, eventLines int + for _, line := range lines { + if strings.HasPrefix(line, "id: ") { + idLines++ + } + if strings.HasPrefix(line, "event: ") { + eventLines++ + } + } + require.Equal(t, 1, idLines, "id: injection must be sanitized") + require.Equal(t, 1, eventLines, "event: injection must be sanitized") +} + +func Test_SSE_New(t *testing.T) { + t.Parallel() + + handler, hub := NewWithHub() + require.NotNil(t, handler) + require.NotNil(t, hub) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) +} + +func Test_SSE_New_DefaultConfig(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + require.Equal(t, 2*time.Second, hub.cfg.FlushInterval) + require.Equal(t, 256, hub.cfg.SendBufferSize) + require.Equal(t, 30*time.Second, hub.cfg.HeartbeatInterval) + require.Equal(t, 30*time.Minute, hub.cfg.MaxLifetime) + require.Equal(t, 3000, hub.cfg.RetryMS) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) +} + +func Test_SSE_New_CustomConfig(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub(Config{ + FlushInterval: 5 * time.Second, + SendBufferSize: 128, + HeartbeatInterval: 10 * time.Second, + MaxLifetime: time.Hour, + RetryMS: 5000, + }) + require.Equal(t, 5*time.Second, hub.cfg.FlushInterval) + require.Equal(t, 128, hub.cfg.SendBufferSize) + require.Equal(t, 10*time.Second, hub.cfg.HeartbeatInterval) + require.Equal(t, time.Hour, hub.cfg.MaxLifetime) + require.Equal(t, 5000, hub.cfg.RetryMS) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) +} + +func Test_SSE_NoTopics(t *testing.T) { + t.Parallel() + + app := fiber.New() + handler, hub := NewWithHub(Config{ + OnConnect: func(_ fiber.Ctx, _ *Connection) error { + // Don't set any topics + return nil + }, + }) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + app.Get("/events", handler) + + req, err := http.NewRequest(fiber.MethodGet, "/events", http.NoBody) + require.NoError(t, err) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusBadRequest, resp.StatusCode) +} + +func Test_SSE_OnConnectReject(t *testing.T) { + t.Parallel() + + app := fiber.New() + handler, hub := NewWithHub(Config{ + OnConnect: func(_ fiber.Ctx, _ *Connection) error { + return errors.New("unauthorized") + }, + }) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + app.Get("/events", handler) + + req, err := http.NewRequest(fiber.MethodGet, "/events", http.NoBody) + require.NoError(t, err) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusForbidden, resp.StatusCode) +} + +func Test_SSE_GenerateID(t *testing.T) { + t.Parallel() + + ids := make(map[string]struct{}) + for range 1000 { + id := generateID() + require.Len(t, id, 32) + _, exists := ids[id] + require.False(t, exists, "duplicate ID generated") + ids[id] = struct{}{} + } +} + +func Test_SSE_TopicMatch(t *testing.T) { + t.Parallel() + + tests := []struct { + pattern string + topic string + want bool + }{ + {"events", "events", true}, + {"events", "events.sub", false}, + {"notifications.*", "notifications.orders", true}, + {"notifications.*", 
"notifications.orders.new", false}, + {"analytics.>", "analytics.live", true}, + {"analytics.>", "analytics.live.visitors", true}, + {"analytics.>", "analytics", false}, + {"*", "anything", true}, + {">", "anything", true}, + {">", "a.b.c", true}, + // > must be last token — invalid patterns should not match + {"a.>.c", "a.b.c", false}, + {">.b", "a.b", false}, + } + + for _, tt := range tests { + got := topicMatch(tt.pattern, tt.topic) + require.Equal(t, tt.want, got, "topicMatch(%q, %q)", tt.pattern, tt.topic) + } +} + +func Test_SSE_MarshalEvent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + data any + want string + }{ + {"nil", nil, ""}, + {"string", "hello", "hello"}, + {"bytes", []byte("world"), "world"}, + {"struct", map[string]string{"key": "val"}, `{"key":"val"}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + me := marshalEvent(&Event{Data: tt.data}) + require.Equal(t, tt.want, me.Data) + require.NotEmpty(t, me.ID) // auto-generated + }) + } +} + +func Test_SSE_MarshaledEvent_WriteTo(t *testing.T) { + t.Parallel() + + me := MarshaledEvent{ + ID: "evt_1", + Type: "test", + Data: "hello world", + } + + var buf bytes.Buffer + n, err := me.WriteTo(&buf) + require.NoError(t, err) + require.Positive(t, n) + + output := buf.String() + require.Contains(t, output, "id: evt_1\n") + require.Contains(t, output, "event: test\n") + require.Contains(t, output, "data: hello world\n") + // A zero-value Retry field must not produce `retry: 0` (which would + // tell clients to reconnect immediately per the SSE spec). + require.NotContains(t, output, "retry:") + require.True(t, strings.HasSuffix(output, "\n\n")) +} + +func Test_SSE_MarshaledEvent_WriteTo_Multiline(t *testing.T) { + t.Parallel() + + me := MarshaledEvent{ + ID: "evt_2", + Type: "test", + Data: "line1\nline2\nline3", + } + + var buf bytes.Buffer + _, err := me.WriteTo(&buf) + require.NoError(t, err) + + output := buf.String() + require.Contains(t, output, "data: line1\n") + require.Contains(t, output, "data: line2\n") + require.Contains(t, output, "data: line3\n") +} + +func Test_SSE_MarshaledEvent_WriteTo_RetryZeroOmitted(t *testing.T) { + t.Parallel() + + // Retry: 0 (the zero value) must NOT emit `retry: 0\n`. Per the SSE + // spec that directive tells clients to reconnect immediately — + // emitting it for an unset field could cascade into a reconnect storm. + me := MarshaledEvent{ID: "evt_zero", Type: "test", Data: "x", Retry: 0} + var buf bytes.Buffer + _, err := me.WriteTo(&buf) + require.NoError(t, err) + require.NotContains(t, buf.String(), "retry:") +} + +func Test_SSE_MarshaledEvent_WriteTo_SanitizesInjectionAtBoundary(t *testing.T) { + t.Parallel() + + // An external Replayer can construct MarshaledEvent directly, bypassing + // marshalEvent's sanitization. WriteTo is the last line of defense — + // control sequences in ID or Type must be stripped so an attacker can't + // inject additional SSE fields onto the wire by embedding \n. + me := MarshaledEvent{ + ID: "evt_1\nevent: injected\nid: fake", + Type: "custom\ndata: also_injected", + Data: "payload", + } + var buf bytes.Buffer + _, err := me.WriteTo(&buf) + require.NoError(t, err) + + output := buf.String() + + // Invariant: exactly one event frame, exactly one id header, exactly + // one event header. 
The SSE parser only recognizes `id:` / `event:` + // at the start of a line — so count lines starting with them, not raw + // substring occurrences (the attacker's text survives INSIDE the + // field value but cannot start a new line). + var idLines, eventLines int + for line := range strings.SplitSeq(output, "\n") { + switch { + case strings.HasPrefix(line, "id: "): + idLines++ + case strings.HasPrefix(line, "event: "): + eventLines++ + default: + } + } + require.Equal(t, 1, idLines, "exactly one id line") + require.Equal(t, 1, eventLines, "exactly one event line") + + // Frame terminator: ends with exactly one blank line separator. + require.True(t, strings.HasSuffix(output, "\n\n")) + // No partial frames — only one `\n\n` separator total. + require.Equal(t, 1, strings.Count(output, "\n\n")) +} + +func Test_SSE_MarshaledEvent_WriteTo_TypedNilJSONMarshaler(t *testing.T) { + t.Parallel() + + // A typed-nil pointer whose type implements json.Marshaler used to panic + // in the explicit `case json.Marshaler:` branch because the receiver + // was dereferenced without a nil check. All values now flow through + // json.Marshal in the default branch which is nil-safe (emits "null"). + var nilMarshaler *panicOnMarshal + evt := Event{ID: "evt_tn", Type: "test", Data: nilMarshaler, Topics: []string{"t"}} + me := marshalEvent(&evt) + + var buf bytes.Buffer + _, err := me.WriteTo(&buf) + require.NoError(t, err) + require.Contains(t, buf.String(), "data: null\n") +} + +// panicOnMarshal dereferences its receiver in MarshalJSON — a typed-nil +// pointer to this type would panic if invoked directly. Used to prove +// marshalEvent routes through json.Marshal's nil-safe path. +type panicOnMarshal struct{ Name string } + +func (p *panicOnMarshal) MarshalJSON() ([]byte, error) { + // Intentionally dereferences — would panic if called on a typed-nil. 
+ return []byte(`"` + p.Name + `"`), nil +} + +func Test_SSE_MarshaledEvent_WriteTo_Retry(t *testing.T) { + t.Parallel() + + me := MarshaledEvent{ + ID: "evt_3", + Type: "test", + Data: "x", + Retry: 3000, + } + + var buf bytes.Buffer + _, err := me.WriteTo(&buf) + require.NoError(t, err) + require.Contains(t, buf.String(), "retry: 3000\n") +} + +func Test_SSE_Dispatcher(t *testing.T) { + t.Parallel() + + c := newDispatcher(time.Second) + + // Add batched events + c.AddEvent(MarshaledEvent{ID: "1", Data: "a"}) + c.AddEvent(MarshaledEvent{ID: "2", Data: "b"}) + + // Add coalesced events (last wins) + c.AddState("key1", MarshaledEvent{ID: "3", Data: "old"}) + c.AddState("key1", MarshaledEvent{ID: "4", Data: "new"}) + c.AddState("key2", MarshaledEvent{ID: "5", Data: "other"}) + + require.Equal(t, 4, c.pending()) + + events := c.WriteTo() + require.Len(t, events, 4) + + // Batched first + require.Equal(t, "a", events[0].Data) + require.Equal(t, "b", events[1].Data) + + // Coalesced: key1 = "new" (last wins), key2 = "other" + require.Equal(t, "new", events[2].Data) + require.Equal(t, "other", events[3].Data) + + // Should be empty now + require.Nil(t, c.WriteTo()) +} + +func Test_SSE_AdaptiveThrottler(t *testing.T) { + t.Parallel() + + at := newAdaptiveThrottler(2 * time.Second) + + // First flush always passes + require.True(t, at.shouldFlush("conn1", 0.0)) + + // Second flush immediately — should fail (too soon) + require.False(t, at.shouldFlush("conn1", 0.0)) + + // Clean up + at.remove("conn1") +} + +func Test_SSE_Publish_Stats(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + hub.Publish(Event{Type: "test", Topics: []string{"t"}, Data: "hello"}) + time.Sleep(50 * time.Millisecond) // let run loop process + + stats := hub.Stats() + require.Equal(t, int64(1), stats.EventsPublished) +} + +func Test_SSE_Shutdown(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := hub.Shutdown(ctx) + require.NoError(t, err) + require.True(t, hub.draining.Load()) +} + +func Test_SSE_Shutdown_Idempotent(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // First call shuts down + require.NoError(t, hub.Shutdown(ctx)) + + // Second call must not panic (sync.Once guards close) + require.NoError(t, hub.Shutdown(ctx)) +} + +func Test_SSE_Shutdown_Background_Context(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + + err := hub.Shutdown(context.Background()) + require.NoError(t, err) +} + +func Test_SSE_Draining_RejectsConnection(t *testing.T) { + t.Parallel() + + app := fiber.New() + handler, hub := NewWithHub(Config{ + OnConnect: func(_ fiber.Ctx, conn *Connection) error { + conn.Topics = []string{"test"} + return nil + }, + }) + + app.Get("/events", handler) + + // Start draining + hub.draining.Store(true) + defer func() { + close(hub.shutdown) + <-hub.stopped + }() + + req, err := http.NewRequest(fiber.MethodGet, "/events", http.NoBody) + require.NoError(t, err) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusServiceUnavailable, resp.StatusCode) +} + +func Test_SSE_Connection_Lifecycle(t *testing.T) { + t.Parallel() + + conn := newConnection("test-id", []string{"t"}, 10, time.Second) + 
require.Equal(t, "test-id", conn.ID) + require.False(t, conn.IsClosed()) + + conn.Close() + require.True(t, conn.IsClosed()) + + // Double close should not panic + conn.Close() +} + +func Test_SSE_Connection_TrySend_Backpressure(t *testing.T) { + t.Parallel() + + conn := newConnection("test", nil, 2, time.Second) + + require.True(t, conn.trySend(MarshaledEvent{Data: "1"})) + require.True(t, conn.trySend(MarshaledEvent{Data: "2"})) + + // Buffer full + require.False(t, conn.trySend(MarshaledEvent{Data: "3"})) + require.Equal(t, int64(1), conn.MessagesDropped.Load()) +} + +func Test_SSE_WriteComment(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + err := writeComment(&buf, "heartbeat") + require.NoError(t, err) + require.Equal(t, ": heartbeat\n\n", buf.String()) +} + +func Test_SSE_WriteRetry(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + err := writeRetry(&buf, 3000) + require.NoError(t, err) + require.Equal(t, "retry: 3000\n\n", buf.String()) +} + +func Test_SSE_MaxLifetime_Unlimited(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub(Config{ + MaxLifetime: -1, // unlimited + }) + require.Equal(t, time.Duration(-1), hub.cfg.MaxLifetime) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) +} + +func Test_SSE_SetPaused(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // SetPaused on non-existent connection should not panic + hub.SetPaused("nonexistent", true) + + // Add a connection manually + conn := newConnection("test-conn", []string{"t"}, 10, time.Second) + hub.mu.Lock() + hub.connections["test-conn"] = conn + hub.mu.Unlock() + + hub.SetPaused("test-conn", true) + require.True(t, conn.paused.Load()) + + hub.SetPaused("test-conn", false) + require.False(t, conn.paused.Load()) +} + +// --------------------------------------------------------------------------- +// Coverage-boost tests +// --------------------------------------------------------------------------- + +func Test_SSE_New_Wrapper(t *testing.T) { + t.Parallel() + handler := New() + require.NotNil(t, handler) +} + +func Test_SSE_SanitizeSSEField(t *testing.T) { + t.Parallel() + + require.Equal(t, "clean", sanitizeSSEField("clean")) + require.Equal(t, "ab", sanitizeSSEField("a\nb")) + require.Equal(t, "ab", sanitizeSSEField("a\rb")) + require.Equal(t, "ab", sanitizeSSEField("a\r\nb")) + require.Equal(t, "abc", sanitizeSSEField("a\r\nb\nc")) +} + +func Test_SSE_MarshalEvent_SanitizesIDAndType(t *testing.T) { + t.Parallel() + + me := marshalEvent(&Event{ + ID: "id\r\ninjected", + Type: "type\ninjected", + Data: "safe", + }) + require.Equal(t, "idinjected", me.ID) + require.Equal(t, "typeinjected", me.Type) +} + +func Test_SSE_MarshalEvent_JsonMarshalerError(t *testing.T) { + t.Parallel() + + me := marshalEvent(&Event{Data: badMarshaler{}}) + require.Contains(t, me.Data, "error") +} + +func Test_SSE_MarshalEvent_DefaultMarshalError(t *testing.T) { + t.Parallel() + + // A channel cannot be JSON-marshaled + me := marshalEvent(&Event{Data: make(chan int)}) + require.Contains(t, me.Data, "error") +} + +// badMarshaler implements json.Marshaler and always returns an error. 
+type badMarshaler struct{} + +func (badMarshaler) MarshalJSON() ([]byte, error) { + return nil, errors.New("marshal failed") +} + +func Test_SSE_WriteTo_EmptyFields(t *testing.T) { + t.Parallel() + + me := MarshaledEvent{ + Data: "x", + Retry: -1, + } + var buf bytes.Buffer + _, err := me.WriteTo(&buf) + require.NoError(t, err) + output := buf.String() + // No id: or event: lines + require.NotContains(t, output, "id:") + require.NotContains(t, output, "event:") + require.Contains(t, output, "data: x\n") +} + +func Test_SSE_ConnMatchesGroup(t *testing.T) { + t.Parallel() + + conn := newConnection("c1", []string{"t"}, 10, time.Second) + conn.Metadata["tenant_id"] = "t_1" + conn.Metadata["role"] = "admin" + + require.True(t, connMatchesGroup(conn, map[string]string{"tenant_id": "t_1"})) + require.True(t, connMatchesGroup(conn, map[string]string{"tenant_id": "t_1", "role": "admin"})) + require.False(t, connMatchesGroup(conn, map[string]string{"tenant_id": "t_2"})) + require.False(t, connMatchesGroup(conn, map[string]string{"missing": "key"})) + require.True(t, connMatchesGroup(conn, map[string]string{})) // empty group matches all +} + +func Test_SSE_SendHeartbeat(t *testing.T) { + t.Parallel() + + conn := newConnection("hb", []string{"t"}, 10, time.Second) + + // First heartbeat should succeed + conn.sendHeartbeat() + // Second should be silently dropped (buffer 1) + conn.sendHeartbeat() + + // Drain the heartbeat channel + select { + case <-conn.heartbeat: + default: + t.Fatal("expected heartbeat in channel") + } +} + +func Test_SSE_WriteLoop_Events(t *testing.T) { + t.Parallel() + + conn := newConnection("wl", []string{"t"}, 10, time.Second) + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + // Send an event + heartbeat, then close + conn.trySend(MarshaledEvent{ID: "e1", Type: "test", Data: "hello", Retry: -1}) + conn.sendHeartbeat() + + go func() { + time.Sleep(50 * time.Millisecond) + conn.Close() + }() + + conn.writeLoop(w) + + output := buf.String() + require.Contains(t, output, "id: e1\n") + require.Contains(t, output, "event: test\n") + require.Contains(t, output, "data: hello\n") + require.Contains(t, output, ": heartbeat\n") + require.Equal(t, int64(1), conn.MessagesSent.Load()) +} + +func Test_SSE_WriteLoop_ChannelClose(t *testing.T) { + t.Parallel() + + conn := newConnection("wlc", []string{"t"}, 10, time.Second) + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + // Close the send channel directly to test the !ok path + close(conn.send) + conn.writeLoop(w) + // Should return without panic +} + +func Test_SSE_TopicMatchesAny(t *testing.T) { + t.Parallel() + + require.True(t, topicMatchesAny([]string{"orders", "products"}, "orders")) + require.True(t, topicMatchesAny([]string{"orders.*"}, "orders.created")) + require.False(t, topicMatchesAny([]string{"orders", "products"}, "users")) + require.False(t, topicMatchesAny(nil, "anything")) +} + +func Test_SSE_ConnMatchesTopic(t *testing.T) { + t.Parallel() + + conn := newConnection("ct", []string{"orders.*", "products"}, 10, time.Second) + require.True(t, connMatchesTopic(conn, "orders.created")) + require.True(t, connMatchesTopic(conn, "products")) + require.False(t, connMatchesTopic(conn, "users")) +} + +func Test_SSE_EffectiveInterval_AllBranches(t *testing.T) { + t.Parallel() + + at := newAdaptiveThrottler(2 * time.Second) + + // saturation > 0.8 → maxInterval + require.Equal(t, at.maxInterval, at.effectiveInterval(0.9)) + // saturation > 0.5 → baseInterval * 2 + require.Equal(t, at.baseInterval*2, 
at.effectiveInterval(0.6)) + // saturation < 0.1 → minInterval + require.Equal(t, at.minInterval, at.effectiveInterval(0.05)) + // default → baseInterval + require.Equal(t, at.baseInterval, at.effectiveInterval(0.3)) +} + +func Test_SSE_Throttler_Cleanup(t *testing.T) { + t.Parallel() + + at := newAdaptiveThrottler(time.Second) + at.shouldFlush("old-conn", 0.0) + at.shouldFlush("new-conn", 0.0) + + // Make "old-conn" stale + at.mu.Lock() + at.lastFlush["old-conn"] = time.Now().Add(-20 * time.Minute) + at.mu.Unlock() + + at.cleanup(time.Now().Add(-10 * time.Minute)) + + at.mu.Lock() + _, oldExists := at.lastFlush["old-conn"] + _, newExists := at.lastFlush["new-conn"] + at.mu.Unlock() + + require.False(t, oldExists, "old conn should be cleaned up") + require.True(t, newExists, "new conn should remain") +} + +func Test_SSE_SetPaused_Callbacks(t *testing.T) { + t.Parallel() + + var paused, resumed bool + _, hub := NewWithHub(Config{ + OnPause: func(_ *Connection) { paused = true }, + OnResume: func(_ *Connection) { resumed = true }, + }) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("cb-conn", []string{"t"}, 10, time.Second) + hub.mu.Lock() + hub.connections["cb-conn"] = conn + hub.mu.Unlock() + + hub.SetPaused("cb-conn", true) + require.True(t, paused) + + hub.SetPaused("cb-conn", false) + require.True(t, resumed) +} + +func Test_SSE_RouteEvent_WithGroup(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Add two connections with different tenants + conn1 := newConnection("c1", []string{"orders"}, 10, time.Second) + conn1.Metadata["tenant_id"] = "t_1" + conn2 := newConnection("c2", []string{"orders"}, 10, time.Second) + conn2.Metadata["tenant_id"] = "t_2" + + hub.mu.Lock() + hub.connections["c1"] = conn1 + hub.connections["c2"] = conn2 + hub.topicIndex["orders"] = map[string]struct{}{"c1": {}, "c2": {}} + hub.mu.Unlock() + + // Publish with group targeting t_1 only + hub.Publish(Event{ + Type: "test", + Topics: []string{"orders"}, + Data: "for-t1", + Group: map[string]string{"tenant_id": "t_1"}, + Priority: PriorityInstant, + }) + + time.Sleep(100 * time.Millisecond) + + // conn1 should have received the event, conn2 should not + require.Equal(t, int64(0), conn1.MessagesDropped.Load()) + // Check send channel + select { + case me := <-conn1.send: + require.Contains(t, me.Data, "for-t1") + default: + t.Fatal("expected event in conn1 send channel") + } + + select { + case <-conn2.send: + t.Fatal("conn2 should NOT have received the event") + default: + // correct + } +} + +func Test_SSE_RouteEvent_GroupOnly(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Connection with metadata but no topic match — group-only delivery + conn := newConnection("g1", []string{"unrelated"}, 10, time.Second) + conn.Metadata["role"] = "admin" + + hub.mu.Lock() + hub.connections["g1"] = conn + hub.topicIndex["unrelated"] = map[string]struct{}{"g1": {}} + hub.mu.Unlock() + + // Publish with group only (no topic overlap) + hub.Publish(Event{ + Type: "admin-alert", + Data: "alert", + Group: map[string]string{"role": "admin"}, + Priority: PriorityInstant, + }) + + 
time.Sleep(100 * time.Millisecond) + + select { + case me := <-conn.send: + require.Contains(t, me.Data, "alert") + default: + t.Fatal("expected event via group match") + } +} + +func Test_SSE_RouteEvent_WildcardConn(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("wc1", []string{"orders.*"}, 10, time.Second) + + hub.mu.Lock() + hub.connections["wc1"] = conn + hub.wildcardConns["wc1"] = struct{}{} + hub.mu.Unlock() + + hub.Publish(Event{ + Type: "test", + Topics: []string{"orders.created"}, + Data: "wildcard-match", + Priority: PriorityInstant, + }) + + time.Sleep(100 * time.Millisecond) + + select { + case me := <-conn.send: + require.Contains(t, me.Data, "wildcard-match") + default: + t.Fatal("wildcard connection should have received the event") + } +} + +func Test_SSE_RouteEvent_PausedSkipsNonInstant(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("p1", []string{"t"}, 10, time.Second) + conn.paused.Store(true) + + hub.mu.Lock() + hub.connections["p1"] = conn + hub.topicIndex["t"] = map[string]struct{}{"p1": {}} + hub.mu.Unlock() + + // P1 event should be skipped for paused connection + hub.Publish(Event{ + Type: "batch", + Topics: []string{"t"}, + Data: "batched", + Priority: PriorityBatched, + }) + + time.Sleep(100 * time.Millisecond) + + require.Equal(t, 0, conn.dispatcher.pending()) + + // P0 (instant) should still deliver + hub.Publish(Event{ + Type: "urgent", + Topics: []string{"t"}, + Data: "instant", + Priority: PriorityInstant, + }) + + time.Sleep(100 * time.Millisecond) + + select { + case me := <-conn.send: + require.Contains(t, me.Data, "instant") + default: + t.Fatal("P0 event should deliver to paused connection") + } +} + +func Test_SSE_RouteEvent_TTLExpired(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Publish an expired event + hub.Publish(Event{ + Type: "old", + Topics: []string{"t"}, + Data: "expired", + Priority: PriorityInstant, + TTL: time.Millisecond, + CreatedAt: time.Now().Add(-time.Second), + }) + + time.Sleep(100 * time.Millisecond) + + stats := hub.Stats() + require.Equal(t, int64(1), stats.EventsDropped) +} + +func Test_SSE_DeliverToConn_AllPriorities(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("dc", []string{"t"}, 10, time.Second) + + me := MarshaledEvent{ID: "e1", Data: "test"} + + // Test instant delivery + hub.deliverToConn(conn, &Event{Priority: PriorityInstant}, me) + select { + case <-conn.send: + default: + t.Fatal("instant event should be in send channel") + } + + // Test batched delivery + hub.deliverToConn(conn, &Event{Priority: PriorityBatched}, me) + require.Equal(t, 1, conn.dispatcher.pending()) + conn.dispatcher.WriteTo() + + // Test coalesced delivery + hub.deliverToConn(conn, &Event{Priority: PriorityCoalesced, Type: "progress", CoalesceKey: "k1"}, me) + require.Equal(t, 1, conn.dispatcher.pending()) + 
conn.dispatcher.WriteTo() + + // Test coalesced without explicit key — uses Type + hub.deliverToConn(conn, &Event{Priority: PriorityCoalesced, Type: "counter"}, me) + require.Equal(t, 1, conn.dispatcher.pending()) +} + +func Test_SSE_FlushAll(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub(Config{FlushInterval: 50 * time.Millisecond}) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("fl", []string{"t"}, 10, 50*time.Millisecond) + + hub.mu.Lock() + hub.connections["fl"] = conn + hub.topicIndex["t"] = map[string]struct{}{"fl": {}} + hub.mu.Unlock() + + // Add batched events to the coalescer + conn.dispatcher.AddEvent(MarshaledEvent{ID: "b1", Data: "batch1"}) + conn.dispatcher.AddEvent(MarshaledEvent{ID: "b2", Data: "batch2"}) + + // Wait for throttler to allow flush, then flush + time.Sleep(100 * time.Millisecond) + hub.flushAll() + + // Events should now be in the send channel + require.Len(t, conn.send, 2) +} + +func Test_SSE_FlushAll_TTLExpiry(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub(Config{FlushInterval: 50 * time.Millisecond}) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("fle", []string{"t"}, 10, 50*time.Millisecond) + + hub.mu.Lock() + hub.connections["fle"] = conn + hub.topicIndex["t"] = map[string]struct{}{"fle": {}} + hub.mu.Unlock() + + // Add an expired event to the coalescer + conn.dispatcher.AddEvent(MarshaledEvent{ + ID: "exp", + Data: "expired", + TTL: time.Millisecond, + CreatedAt: time.Now().Add(-time.Second), + }) + + time.Sleep(100 * time.Millisecond) + hub.flushAll() + + // Event should be dropped, not delivered + require.Empty(t, conn.send) + require.Equal(t, int64(1), hub.metrics.eventsDropped.Load()) +} + +func Test_SSE_SendHeartbeats(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub(Config{HeartbeatInterval: 50 * time.Millisecond}) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("hb", []string{"t"}, 10, time.Second) + // Set lastWrite to long ago + conn.lastWrite.Store(time.Now().Add(-time.Minute)) + + hub.mu.Lock() + hub.connections["hb"] = conn + hub.mu.Unlock() + + hub.sendHeartbeats() + + // Should have a heartbeat pending + select { + case <-conn.heartbeat: + default: + t.Fatal("expected heartbeat") + } +} + +func Test_SSE_SendHeartbeats_SkipsClosed(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub(Config{HeartbeatInterval: 50 * time.Millisecond}) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("closed-hb", []string{"t"}, 10, time.Second) + conn.lastWrite.Store(time.Now().Add(-time.Minute)) + conn.Close() + + hub.mu.Lock() + hub.connections["closed-hb"] = conn + hub.mu.Unlock() + + // Should not panic + hub.sendHeartbeats() +} + +func Test_SSE_RemoveConnection(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("rm", []string{"orders", "products"}, 10, time.Second) + + hub.addConnection(conn) + + stats := hub.Stats() + require.Equal(t, 
1, stats.ActiveConnections) + require.Equal(t, 2, stats.TotalTopics) + + hub.removeConnection(conn) + + stats = hub.Stats() + require.Equal(t, 0, stats.ActiveConnections) + require.Equal(t, 0, stats.TotalTopics) + + // Remove again should be no-op + hub.removeConnection(conn) +} + +func Test_SSE_RemoveConnection_Wildcard(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("rmw", []string{"orders.*"}, 10, time.Second) + + hub.addConnection(conn) + + hub.mu.RLock() + _, hasWildcard := hub.wildcardConns["rmw"] + hub.mu.RUnlock() + require.True(t, hasWildcard) + + hub.removeConnection(conn) + + hub.mu.RLock() + _, hasWildcard = hub.wildcardConns["rmw"] + hub.mu.RUnlock() + require.False(t, hasWildcard) +} + +func Test_SSE_Publish_BufferFull(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Fill the event buffer (size 1024) + for range 2000 { + hub.Publish(Event{Type: "flood", Topics: []string{"t"}, Data: "x"}) + } + + time.Sleep(100 * time.Millisecond) + stats := hub.Stats() + // At least some events must have landed in the pipeline AND some must + // have been dropped by the non-blocking `default:` branch in Publish. + // Asserting only EventsPublished > 0 would let a regression that makes + // Publish blocking pass silently — the drop counter is the actual + // invariant this test exists to pin. + require.Positive(t, stats.EventsPublished) + require.Positive(t, stats.EventsDropped) +} + +// testReplayer is a minimal in-memory Replayer implementation used by tests +// that exercise hub.replayEvents and hub.initStream. The production +// MemoryReplayer has been removed from the library surface. 
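+// Its Replay returns the entries stored strictly after the lastEventID
+// anchor whose topics overlap the subscriber's topics, in insertion order;
+// a deliberately simple stand-in for a real replay backend.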
+type testReplayer struct { + entries []testReplayEntry + mu sync.RWMutex +} + +type testReplayEntry struct { + topics []string + event MarshaledEvent +} + +//nolint:gocritic // hugeParam: signature must match Replayer interface +func (r *testReplayer) Store(event MarshaledEvent, topics []string) error { + r.mu.Lock() + defer r.mu.Unlock() + r.entries = append(r.entries, testReplayEntry{event: event, topics: topics}) + return nil +} + +func (r *testReplayer) Replay(lastEventID string, topics []string) ([]MarshaledEvent, error) { + if lastEventID == "" { + return nil, nil + } + r.mu.RLock() + defer r.mu.RUnlock() + // Find the index of lastEventID + idx := -1 + for i, entry := range r.entries { + if entry.event.ID == lastEventID { + idx = i + break + } + } + if idx == -1 { + return nil, nil + } + var out []MarshaledEvent + for _, entry := range r.entries[idx+1:] { + if topicsOverlap(entry.topics, topics) { + out = append(out, entry.event) + } + } + return out, nil +} + +func (r *testReplayer) count() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.entries) +} + +func topicsOverlap(a, b []string) bool { + for _, x := range a { + for _, y := range b { + if topicMatch(y, x) || topicMatch(x, y) || x == y { + return true + } + } + } + return false +} + +func Test_SSE_ReplayEvents(t *testing.T) { + t.Parallel() + + replayer := &testReplayer{} + require.NoError(t, replayer.Store(MarshaledEvent{ID: "r1", Data: "d1", Retry: -1}, []string{"t"})) + require.NoError(t, replayer.Store(MarshaledEvent{ID: "r2", Data: "d2", Retry: -1}, []string{"t"})) + + _, hub := NewWithHub(Config{Replayer: replayer}) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("replay-conn", []string{"t"}, 10, time.Second) + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + err := hub.replayEvents(w, conn, "r1") + require.NoError(t, err) + require.Contains(t, buf.String(), "id: r2") +} + +func Test_SSE_ReplayEvents_NoReplayer(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("no-replay", []string{"t"}, 10, time.Second) + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + err := hub.replayEvents(w, conn, "some-id") + require.NoError(t, err) + require.Empty(t, buf.String()) +} + +func Test_SSE_InitStream(t *testing.T) { + t.Parallel() + + replayer := &testReplayer{} + require.NoError(t, replayer.Store(MarshaledEvent{ID: "i1", Data: "d1", Retry: -1}, []string{"t"})) + require.NoError(t, replayer.Store(MarshaledEvent{ID: "i2", Data: "d2", Retry: -1}, []string{"t"})) + + _, hub := NewWithHub(Config{Replayer: replayer}) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("init-conn", []string{"t"}, 10, time.Second) + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + err := hub.initStream(w, conn, "i1") + require.NoError(t, err) + + output := buf.String() + require.Contains(t, output, "retry: 3000") + require.Contains(t, output, "id: i2") + require.Contains(t, output, `event: connected`) +} + +func Test_SSE_RouteEvent_ReplayerStore(t *testing.T) { + t.Parallel() + + replayer := &testReplayer{} + _, hub := NewWithHub(Config{Replayer: replayer}) + defer func() { + ctx, cancel := 
context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Publish a non-group event — should be stored in replayer + hub.Publish(Event{Type: "test", Topics: []string{"t"}, Data: "stored"}) + time.Sleep(100 * time.Millisecond) + + // Publish a group event — should NOT be stored in replayer + hub.Publish(Event{ + Type: "test", + Topics: []string{"t"}, + Data: "not-stored", + Group: map[string]string{"tenant_id": "t_1"}, + }) + time.Sleep(100 * time.Millisecond) + + // The replayer should only have 1 event (the non-group one) + require.Equal(t, 1, replayer.count()) +} + +func Test_SSE_Shutdown_Timeout(t *testing.T) { + // Not parallel. Previously this test used t.Parallel() and returned + // without waiting for hub.run() to drain — run() is still alive inside + // <-h.shutdown executing broadcastShutdown + time.Sleep(drainDelay) + // (~200ms) and would outlive the test, mutating hub.connections under + // mu.Lock concurrently with other parallel tests. Running serial and + // awaiting hub.stopped eliminates that cross-test goroutine leak. + + _, hub := NewWithHub() + + // Pre-cancel context — Shutdown should surface ctx.Err(). + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := hub.Shutdown(ctx) + require.Error(t, err, "Shutdown with canceled ctx must return an error") + require.ErrorIs(t, err, context.Canceled) + + // Wait for the run loop to actually exit so we don't leak the + // goroutine into subsequent tests. Bounded wait so a regression that + // never closes `stopped` fails loudly. + select { + case <-hub.stopped: + case <-time.After(2 * time.Second): + t.Fatal("hub.run() did not exit after Shutdown with canceled ctx") + } +} + +func Benchmark_SSE_Publish(b *testing.B) { + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(b, hub.Shutdown(ctx)) + }() + + event := Event{ + Type: "test", + Topics: []string{"benchmark"}, + Data: "hello", + } + + b.ResetTimer() + for b.Loop() { + hub.Publish(event) + } +} + +func Benchmark_SSE_TopicMatch(b *testing.B) { + b.ResetTimer() + for b.Loop() { + topicMatch("notifications.*", "notifications.orders") + } +} + +func Benchmark_SSE_TopicMatch_Exact(b *testing.B) { + b.ResetTimer() + for b.Loop() { + topicMatch("notifications.orders", "notifications.orders") + } +} + +func Benchmark_SSE_MarshalEvent(b *testing.B) { + event := &Event{ + Type: "test", + Data: map[string]string{"key": "value", "foo": "bar"}, + } + + b.ResetTimer() + for b.Loop() { + marshalEvent(event) + } +} + +func Benchmark_SSE_WriteTo(b *testing.B) { + me := MarshaledEvent{ + ID: "evt_1", + Type: "test", + Data: `{"key":"value"}`, + } + + w := bufio.NewWriter(io.Discard) + + b.ResetTimer() + for b.Loop() { + me.WriteTo(w) //nolint:errcheck // benchmark: error irrelevant for perf measurement + } +} + +func Benchmark_SSE_Coalescer(b *testing.B) { + c := newDispatcher(time.Second) + me := MarshaledEvent{ID: "1", Data: "test"} + + b.ResetTimer() + for b.Loop() { + c.AddState("key", me) + c.WriteTo() + } +} + +func Benchmark_SSE_GenerateID(b *testing.B) { + b.ResetTimer() + for b.Loop() { + generateID() + } +} + +// mockBridge implements SubscriberBridge for testing. 
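+// The injected onSubscribe func lets each test script the bridge: deliver a
+// payload, block on ctx.Done() until shutdown, or return an error to drive
+// the retry loop.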
+type mockBridge struct { + onSubscribe func(ctx context.Context, channel string, onMessage func(string)) error +} + +func (m *mockBridge) Subscribe(ctx context.Context, channel string, onMessage func(string)) error { + return m.onSubscribe(ctx, channel, onMessage) +} + +func Test_SSE_Bridge_Publishes(t *testing.T) { + t.Parallel() + + delivered := make(chan string, 1) + bridge := &mockBridge{ + onSubscribe: func(ctx context.Context, _ string, onMessage func(string)) error { + onMessage("test-payload") + delivered <- "ok" + <-ctx.Done() + return ctx.Err() + }, + } + + _, hub := NewWithHub(Config{ + Bridges: []BridgeConfig{{ + Subscriber: bridge, + Channel: "test-channel", + EventType: "notification", + }}, + }) + + select { + case <-delivered: + case <-time.After(2 * time.Second): + t.Fatal("bridge did not deliver message in time") + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + + stats := hub.Stats() + require.Equal(t, int64(1), stats.EventsPublished) +} + +func Test_SSE_Bridge_CancelsOnShutdown(t *testing.T) { + t.Parallel() + + subscribed := make(chan struct{}, 1) + bridge := &mockBridge{ + onSubscribe: func(ctx context.Context, _ string, _ func(string)) error { + subscribed <- struct{}{} + <-ctx.Done() + return ctx.Err() + }, + } + + _, hub := NewWithHub(Config{ + Bridges: []BridgeConfig{{ + Subscriber: bridge, + Channel: "ch", + EventType: "evt", + }}, + }) + + select { + case <-subscribed: + case <-time.After(2 * time.Second): + t.Fatal("bridge did not subscribe in time") + } + // Shutdown should cancel the bridge context and wait for the goroutine. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) +} + +func Test_SSE_Bridge_Multiple(t *testing.T) { + t.Parallel() + + delivered := make(chan struct{}, 2) + bridge := &mockBridge{ + onSubscribe: func(ctx context.Context, channel string, onMessage func(string)) error { + onMessage("msg-from-" + channel) + delivered <- struct{}{} + <-ctx.Done() + return ctx.Err() + }, + } + + _, hub := NewWithHub(Config{ + Bridges: []BridgeConfig{ + {Subscriber: bridge, Channel: "ch1", EventType: "e1"}, + {Subscriber: bridge, Channel: "ch2", EventType: "e2"}, + }, + }) + + for range 2 { + select { + case <-delivered: + case <-time.After(2 * time.Second): + t.Fatal("bridge did not deliver both messages in time") + } + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + + stats := hub.Stats() + require.Equal(t, int64(2), stats.EventsPublished) +} + +func Test_SSE_Bridge_Transform(t *testing.T) { + t.Parallel() + + done := make(chan struct{}, 1) + bridge := &mockBridge{ + onSubscribe: func(ctx context.Context, _ string, onMessage func(string)) error { + onMessage("raw-data") + done <- struct{}{} + <-ctx.Done() + return ctx.Err() + }, + } + + _, hub := NewWithHub(Config{ + Bridges: []BridgeConfig{{ + Subscriber: bridge, + Channel: "ch", + EventType: "default", + Transform: func(payload string) *Event { + return &Event{ + Type: "transformed", + Data: "transformed:" + payload, + Topics: []string{"custom-topic"}, + } + }, + }}, + }) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("bridge did not deliver in time") + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + + stats := hub.Stats() + require.Equal(t, 
int64(1), stats.EventsPublished) +} + +func Test_SSE_Bridge_TransformNilSkipsMessage(t *testing.T) { + t.Parallel() + + done := make(chan struct{}, 1) + bridge := &mockBridge{ + onSubscribe: func(ctx context.Context, _ string, onMessage func(string)) error { + onMessage("skip-this") + done <- struct{}{} + <-ctx.Done() + return ctx.Err() + }, + } + + _, hub := NewWithHub(Config{ + Bridges: []BridgeConfig{{ + Subscriber: bridge, + Channel: "ch", + EventType: "evt", + Transform: func(_ string) *Event { + return nil + }, + }}, + }) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("bridge did not deliver in time") + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + + stats := hub.Stats() + require.Equal(t, int64(0), stats.EventsPublished) +} + +func Test_SSE_Bridge_RetriesOnError(t *testing.T) { + // Cannot run in parallel — we mutate the package-level bridgeRetryDelay + // so the retry loop is deterministic within the test's time budget. + original := bridgeRetryDelay + bridgeRetryDelay = 20 * time.Millisecond + t.Cleanup(func() { bridgeRetryDelay = original }) + + var attempts atomic.Int32 + secondAttemptBlocked := make(chan struct{}) + bridge := &mockBridge{ + onSubscribe: func(ctx context.Context, _ string, _ func(string)) error { + n := attempts.Add(1) + if n < 2 { + return errors.New("transient error") + } + // Second attempt reached — signal the test and block until + // Shutdown cancels the ctx. This proves the loop retried + // past the error rather than exiting. + close(secondAttemptBlocked) + <-ctx.Done() + return ctx.Err() + }, + } + + _, hub := NewWithHub(Config{ + Bridges: []BridgeConfig{{ + Subscriber: bridge, + Channel: "ch", + EventType: "e", + }}, + }) + + // Wait for the retry to actually happen. Before reaching here, + // attempts must be exactly 2: one error + one in-progress call. + select { + case <-secondAttemptBlocked: + case <-time.After(2 * time.Second): + t.Fatalf("bridge did not retry after error (attempts=%d)", attempts.Load()) + } + require.Equal(t, int32(2), attempts.Load(), "expected exactly 2 Subscribe calls") + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) +} + +func Test_SSE_Bridge_BuildEvent_Defaults(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Non-transform: event built entirely from config defaults. + cfg := &BridgeConfig{ + Channel: "ch", + EventType: "my-event", + CoalesceKey: "k", + TTL: 5 * time.Second, + Priority: PriorityCoalesced, + } + event := hub.buildBridgeEvent(cfg, "my-topic", "payload") + require.NotNil(t, event) + require.Equal(t, "my-event", event.Type) + require.Equal(t, "payload", event.Data) + require.Equal(t, []string{"my-topic"}, event.Topics) + require.Equal(t, PriorityCoalesced, event.Priority) + require.Equal(t, "k", event.CoalesceKey) + require.Equal(t, 5*time.Second, event.TTL) + + // Transform path: only missing Topics/Type filled from defaults. 
+ cfgT := &BridgeConfig{ + EventType: "fallback-type", + Transform: func(_ string) *Event { + return &Event{Priority: PriorityInstant, Data: "x"} + }, + } + event = hub.buildBridgeEvent(cfgT, "fallback-topic", "raw") + require.NotNil(t, event) + require.Equal(t, "fallback-type", event.Type) + require.Equal(t, []string{"fallback-topic"}, event.Topics) + require.Equal(t, PriorityInstant, event.Priority) + + // Transform nil filters message. + cfgT2 := &BridgeConfig{ + EventType: "x", + Transform: func(_ string) *Event { return nil }, + } + event = hub.buildBridgeEvent(cfgT2, "default-topic", "x") + require.Nil(t, event) +} + +func Test_SSE_Bridge_PanicsWithoutSubscriber(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + _, _ = NewWithHub(Config{ + Bridges: []BridgeConfig{{ + Channel: "ch", + EventType: "e", + }}, + }) + }) +} + +// ────────────────────────────────────────────────────────────────────────────── +// Coverage boosters — targeted tests for previously-uncovered branches. +// ────────────────────────────────────────────────────────────────────────────── + +func Test_SSE_Publish_DropsDuringDrain(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + + // Flip the drain flag and Publish — event must be counted as dropped, + // not published. Exercises the early-return branch in Publish. + hub.draining.Store(true) + hub.Publish(Event{Type: "x", Topics: []string{"t"}, Data: "d"}) + + stats := hub.Stats() + require.Equal(t, int64(1), stats.EventsDropped) + require.Equal(t, int64(0), stats.EventsPublished) + + hub.draining.Store(false) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) +} + +func Test_SSE_Publish_StampsCreatedAtWhenTTLSet(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // TTL > 0 with zero CreatedAt — Publish must stamp CreatedAt so + // routeEvent can compute age correctly. + before := time.Now() + hub.Publish(Event{ + Type: "x", + Topics: []string{"t"}, + Data: "d", + TTL: time.Second, + }) + time.Sleep(30 * time.Millisecond) + + // No direct getter for the enqueued event; just assert the + // corresponding counter to ensure we hit the enqueue branch. 
+ stats := hub.Stats() + require.Equal(t, int64(1), stats.EventsPublished) + require.Less(t, time.Since(before), time.Second) +} + +func Test_SSE_WriteRetry_SkipsNonPositive(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + require.NoError(t, writeRetry(&buf, 0)) + require.NoError(t, writeRetry(&buf, -42)) + require.Empty(t, buf.String(), "non-positive ms must not emit a retry: directive") + + require.NoError(t, writeRetry(&buf, 1500)) + require.Contains(t, buf.String(), "retry: 1500\n") +} + +func Test_SSE_TrackEventType_EmptyDefaultsToMessage(t *testing.T) { + t.Parallel() + + m := &hubMetrics{eventsByType: make(map[string]*atomic.Int64)} + m.trackEventType("") + m.trackEventType("") + m.trackEventType("custom") + + snap := m.snapshotEventsByType() + require.Equal(t, int64(2), snap["message"], "empty event type falls back to \"message\"") + require.Equal(t, int64(1), snap["custom"]) +} + +func Test_SSE_MatchGroupConns_EmptyGroupIsNoOp(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // With no Group set on the event, matchGroupConns must short-circuit + // without scanning connections — the early-return branch. + seen := make(map[string]struct{}) + hub.mu.RLock() + hub.matchGroupConns(&Event{Type: "x", Topics: []string{"t"}}, seen) + hub.mu.RUnlock() + require.Empty(t, seen) +} + +func Test_SSE_WatchLifetime_NoOpWhenDisabled(t *testing.T) { + t.Parallel() + + // MaxLifetime <= 0 must leave watchLifetime as a no-op (no goroutine + // spawned, no eventual Close on the connection). + _, hub := NewWithHub(Config{MaxLifetime: -1}) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("c1", []string{"t"}, 8, 100*time.Millisecond) + hub.watchLifetime(conn) + + // If watchLifetime spawned a goroutine it would close `conn.done` + // eventually (with MaxLifetime<=0 it must NOT). Allow plenty of + // scheduler time then assert conn is still alive. + time.Sleep(50 * time.Millisecond) + require.False(t, conn.IsClosed(), "watchLifetime must not close the conn when MaxLifetime<=0") +} + +func Test_SSE_ReplayEvents_NoReplayerOrEmptyLastEventID(t *testing.T) { + t.Parallel() + + // nil Replayer OR empty Last-Event-ID — both branches return nil + // without touching the writer. + hub := &Hub{cfg: Config{Replayer: nil}} + conn := newConnection("c1", []string{"t"}, 8, 100*time.Millisecond) + var buf bytes.Buffer + require.NoError(t, hub.replayEvents(bufio.NewWriter(&buf), conn, "some-id")) + require.Empty(t, buf.String()) + + hub2 := &Hub{cfg: Config{Replayer: &testReplayer{}}} + require.NoError(t, hub2.replayEvents(bufio.NewWriter(&buf), conn, "")) + require.Empty(t, buf.String()) +} + +// failingWriter writes `limit` bytes successfully then returns errWrite on +// subsequent writes. Used to hit the error branches in initStream, replayEvents, +// sendConnectedEvent, and writeLoop without spinning up a real TCP listener. 
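+// (The configured err is returned once written reaches limit.) Note that the
+// tests wrap it in a bufio.Writer, so the injected error typically surfaces
+// when the code under test flushes, not on the buffered write itself.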
+type failingWriter struct { + err error + written int + limit int +} + +func (fw *failingWriter) Write(p []byte) (int, error) { + if fw.err != nil && fw.written >= fw.limit { + return 0, fw.err + } + fw.written += len(p) + return len(p), nil +} + +func Test_SSE_InitStream_PropagatesWriteErrors(t *testing.T) { + t.Parallel() + + // Fail on the very first write so writeRetry returns an error — + // exercises initStream's first `if err != nil { return err }` branch. + hub := &Hub{cfg: Config{RetryMS: 3000}} + conn := newConnection("c1", []string{"t"}, 8, 100*time.Millisecond) + fw := &failingWriter{limit: 0, err: errors.New("forced write error")} + err := hub.initStream(bufio.NewWriter(fw), conn, "") + require.Error(t, err) +} + +func Test_SSE_ReplayEvents_HandlesReplayerError(t *testing.T) { + t.Parallel() + + // Replayer returning an error must be treated as best-effort — caller + // gets nil and no events are written. + hub := &Hub{cfg: Config{Replayer: &errorReplayer{err: errors.New("store down")}}} + conn := newConnection("c1", []string{"t"}, 8, 100*time.Millisecond) + var buf bytes.Buffer + require.NoError(t, hub.replayEvents(bufio.NewWriter(&buf), conn, "last-id")) + require.Empty(t, buf.String()) +} + +type errorReplayer struct{ err error } + +func (*errorReplayer) Store(MarshaledEvent, []string) error { return nil } +func (e *errorReplayer) Replay(string, []string) ([]MarshaledEvent, error) { + return nil, e.err +} + +func Test_SSE_ReplayEvents_WritesEventsAndFlushes(t *testing.T) { + t.Parallel() + + // Replayer returning events must produce written frames terminated + // with a flush — exercises the write-and-flush branch. + r := &testReplayer{} + // testReplayer.Replay returns entries AFTER the lastEventID marker + // entry — store the marker first, then the two we want replayed. + require.NoError(t, r.Store(MarshaledEvent{ID: "last", Type: "test", Data: "anchor"}, []string{"t"})) + require.NoError(t, r.Store(MarshaledEvent{ID: "e1", Type: "test", Data: "one"}, []string{"t"})) + require.NoError(t, r.Store(MarshaledEvent{ID: "e2", Type: "test", Data: "two"}, []string{"t"})) + + hub := &Hub{cfg: Config{Replayer: r}} + conn := newConnection("c1", []string{"t"}, 8, 100*time.Millisecond) + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) + require.NoError(t, hub.replayEvents(bw, conn, "last")) + require.NoError(t, bw.Flush()) + out := buf.String() + require.Contains(t, out, "id: e1\n") + require.Contains(t, out, "id: e2\n") + require.Contains(t, out, "data: one\n") + require.Contains(t, out, "data: two\n") +} + +func Test_SSE_SendConnectedEvent_PropagatesWriteError(t *testing.T) { + t.Parallel() + + conn := newConnection("c1", []string{"t"}, 8, 100*time.Millisecond) + fw := &failingWriter{limit: 0, err: errors.New("no space")} + err := sendConnectedEvent(bufio.NewWriter(fw), conn) + require.Error(t, err) +} + +func Test_SSE_WriteLoop_HeartbeatFlush(t *testing.T) { + t.Parallel() + + // Drive writeLoop: fire a heartbeat then a real event, then close, + // so we hit the heartbeat branch, the normal event branch, and the + // done-exit branch — previously uncovered paths in writeLoop. + conn := newConnection("c1", []string{"t"}, 8, 50*time.Millisecond) + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) + + done := make(chan struct{}) + go func() { + conn.writeLoop(bw) + close(done) + }() + + conn.sendHeartbeat() + // Heartbeat fires, wait briefly for flush. 
+	time.Sleep(30 * time.Millisecond)
+
+	conn.trySend(MarshaledEvent{ID: "evt_x", Type: "test", Data: "hi"})
+	time.Sleep(30 * time.Millisecond)
+
+	conn.Close()
+	<-done
+
+	require.NoError(t, bw.Flush())
+	out := buf.String()
+	require.Contains(t, out, ": heartbeat\n", "heartbeat comment present")
+	require.Contains(t, out, "data: hi\n", "real event data present")
+	require.Equal(t, int64(1), conn.MessagesSent.Load(), "real event counted once")
+	require.Equal(t, "evt_x", conn.LastEventID.Load())
+}
diff --git a/middleware/sse/stats.go b/middleware/sse/stats.go
new file mode 100644
index 00000000000..f0202c2e074
--- /dev/null
+++ b/middleware/sse/stats.go
@@ -0,0 +1,74 @@
+package sse
+
+import (
+	"sync"
+	"sync/atomic"
+)
+
+// HubStats provides a snapshot of the hub's current state.
+type HubStats struct {
+	// ConnectionsByTopic maps each topic to its subscriber count.
+	ConnectionsByTopic map[string]int `json:"connections_by_topic"`
+
+	// EventsByType maps each SSE event type to its lifetime count.
+	EventsByType map[string]int64 `json:"events_by_type"`
+
+	// EventsPublished is the lifetime count of events published to the hub.
+	EventsPublished int64 `json:"events_published"`
+
+	// EventsDropped is the lifetime count of events dropped before delivery, for any reason.
+	EventsDropped int64 `json:"events_dropped"`
+
+	// ActiveConnections is the total number of open SSE connections.
+	ActiveConnections int `json:"active_connections"`
+
+	// TotalTopics is the number of unique topics with at least one subscriber.
+	TotalTopics int `json:"total_topics"`
+}
+
+// hubMetrics tracks lifetime counters for the hub.
+type hubMetrics struct {
+	eventsByType    map[string]*atomic.Int64
+	eventsByTypeMu  sync.RWMutex
+	eventsPublished atomic.Int64
+	eventsDropped   atomic.Int64
+}
+
+// trackEventType increments the counter for a specific event type.
+func (m *hubMetrics) trackEventType(eventType string) {
+	if eventType == "" {
+		eventType = "message"
+	}
+
+	m.eventsByTypeMu.RLock()
+	counter, ok := m.eventsByType[eventType]
+	m.eventsByTypeMu.RUnlock()
+
+	if ok {
+		counter.Add(1)
+		return
+	}
+
+	m.eventsByTypeMu.Lock()
+	if counter, ok = m.eventsByType[eventType]; ok {
+		m.eventsByTypeMu.Unlock()
+		counter.Add(1)
+		return
+	}
+	counter = &atomic.Int64{}
+	counter.Add(1)
+	m.eventsByType[eventType] = counter
+	m.eventsByTypeMu.Unlock()
+}
+
+// snapshotEventsByType returns a copy of the per-event-type counters.
+func (m *hubMetrics) snapshotEventsByType() map[string]int64 {
+	m.eventsByTypeMu.RLock()
+	defer m.eventsByTypeMu.RUnlock()
+
+	result := make(map[string]int64, len(m.eventsByType))
+	for k, v := range m.eventsByType {
+		result[k] = v.Load()
+	}
+	return result
+}
diff --git a/middleware/sse/throttle.go b/middleware/sse/throttle.go
new file mode 100644
index 00000000000..d259c0f0adc
--- /dev/null
+++ b/middleware/sse/throttle.go
@@ -0,0 +1,91 @@
+package sse
+
+import (
+	"sync"
+	"time"
+)
+
+// adaptiveThrottler monitors per-connection buffer saturation and adjusts
+// the effective flush interval. Connections with high buffer usage get
+// longer flush intervals (fewer sends), reducing backpressure.
+type adaptiveThrottler struct {
+	lastFlush    map[string]time.Time
+	mu           sync.Mutex
+	baseInterval time.Duration
+	minInterval  time.Duration
+	maxInterval  time.Duration
+}
+
+func newAdaptiveThrottler(baseInterval time.Duration) *adaptiveThrottler {
+	// Start with the "nice" bounds, then tighten them so the invariant
+	// min <= base <= max holds for any baseInterval. Without the extra
+	// clamp, extreme configs (e.g. baseInterval < 25ms or > 10s) would
+	// invert the throttling policy — saturated connections flushing
+	// faster than idle ones.
+	minInt := max(baseInterval/4, 100*time.Millisecond)
+	maxInt := min(baseInterval*4, 10*time.Second)
+	if minInt > baseInterval {
+		minInt = baseInterval
+	}
+	if maxInt < baseInterval {
+		maxInt = baseInterval
+	}
+	return &adaptiveThrottler{
+		lastFlush:    make(map[string]time.Time),
+		baseInterval: baseInterval,
+		minInterval:  minInt,
+		maxInterval:  maxInt,
+	}
+}
+
+// effectiveInterval calculates the flush interval for a connection based
+// on its buffer saturation (0.0 = empty, 1.0 = full).
+func (at *adaptiveThrottler) effectiveInterval(saturation float64) time.Duration {
+	switch {
+	case saturation > 0.8:
+		return at.maxInterval
+	case saturation > 0.5:
+		return at.baseInterval * 2
+	case saturation < 0.1:
+		return at.minInterval
+	default:
+		return at.baseInterval
+	}
+}
+
+// shouldFlush returns true if enough time has passed since the last flush.
+func (at *adaptiveThrottler) shouldFlush(connID string, saturation float64) bool {
+	at.mu.Lock()
+	defer at.mu.Unlock()
+
+	interval := at.effectiveInterval(saturation)
+	last, ok := at.lastFlush[connID]
+	if !ok {
+		at.lastFlush[connID] = time.Now()
+		return true
+	}
+
+	if time.Since(last) >= interval {
+		at.lastFlush[connID] = time.Now()
+		return true
+	}
+	return false
+}
+
+// remove cleans up tracking for a disconnected connection.
+func (at *adaptiveThrottler) remove(connID string) {
+	at.mu.Lock()
+	delete(at.lastFlush, connID)
+	at.mu.Unlock()
+}
+
+// cleanup removes stale entries older than the given cutoff.
+func (at *adaptiveThrottler) cleanup(cutoff time.Time) {
+	at.mu.Lock()
+	defer at.mu.Unlock()
+	for k, v := range at.lastFlush {
+		if v.Before(cutoff) {
+			delete(at.lastFlush, k)
+		}
+	}
+}
diff --git a/middleware/sse/topic.go b/middleware/sse/topic.go
new file mode 100644
index 00000000000..dd1f87f4cd1
--- /dev/null
+++ b/middleware/sse/topic.go
@@ -0,0 +1,61 @@
+package sse
+
+import (
+	"strings"
+)
+
+// topicMatch checks if a subscription pattern matches a concrete topic.
+// Supports NATS-style wildcards:
+//
+//   - * matches exactly one segment (between dots)
+//   - > matches one or more trailing segments (must be last token)
+//   - No wildcards = exact match
+//
+// Examples:
+//
+//	topicMatch("notifications.*", "notifications.orders") → true
+//	topicMatch("notifications.*", "notifications.orders.new") → false
+//	topicMatch("analytics.>", "analytics.live") → true
+//	topicMatch("analytics.>", "analytics.live.visitors") → true
+func topicMatch(pattern, topic string) bool {
+	if !strings.ContainsAny(pattern, "*>") {
+		return pattern == topic
+	}
+
+	patParts := strings.Split(pattern, ".")
+	topParts := strings.Split(topic, ".")
+
+	for i, pp := range patParts {
+		switch pp {
+		case ">":
+			// > must be the last token and matches 1+ remaining segments
+			return i == len(patParts)-1 && i < len(topParts)
+		case "*":
+			if i >= len(topParts) {
+				return false
+			}
+		default:
+			if i >= len(topParts) || pp != topParts[i] {
+				return false
+			}
+		}
+	}
+
+	return len(patParts) == len(topParts)
+}
+
+// topicMatchesAny returns true if the concrete topic matches any of the patterns.
+func topicMatchesAny(patterns []string, topic string) bool {
+	for _, p := range patterns {
+		if topicMatch(p, topic) {
+			return true
+		}
+	}
+	return false
+}
+
+// connMatchesTopic returns true if a connection's subscription patterns
+// match the given concrete topic.
+func connMatchesTopic(conn *Connection, topic string) bool {
+	return topicMatchesAny(conn.Topics, topic)
+}
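
// ---------------------------------------------------------------------------
// Editor's note: the sketch below is illustrative and is NOT part of the diff
// above. It restates, as in-package test code, the NATS-style wildcard
// semantics documented on topicMatch and the interval bounds that
// newAdaptiveThrottler derives from a 2s base interval. The test names are
// hypothetical, and the snippet assumes the same testify `require` dependency
// used by the test file earlier in this diff.
// ---------------------------------------------------------------------------

package sse

import (
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

// Mirrors the examples in topicMatch's doc comment, plus two edge cases the
// implementation enforces: "*" never spans a dot and ">" requires at least
// one trailing segment.
func Test_SSE_TopicMatch_WildcardSketch(t *testing.T) {
	t.Parallel()

	cases := []struct {
		pattern string
		topic   string
		want    bool
	}{
		{"notifications.*", "notifications.orders", true},      // * matches exactly one segment
		{"notifications.*", "notifications.orders.new", false}, // * never spans a dot
		{"analytics.>", "analytics.live", true},                // > matches one trailing segment
		{"analytics.>", "analytics.live.visitors", true},       // > matches many trailing segments
		{"analytics.>", "analytics", false},                    // > requires at least one segment
		{"orders", "orders", true},                             // no wildcard means exact match
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, topicMatch(tc.pattern, tc.topic), "pattern %q vs topic %q", tc.pattern, tc.topic)
	}
}

// For a 2s base interval, newAdaptiveThrottler derives min = max(2s/4, 100ms)
// = 500ms and max = min(2s*4, 10s) = 8s, so neither clamp fires; the four
// saturation bands then map onto 8s, 4s, 2s, and 500ms respectively.
func Test_SSE_AdaptiveThrottler_BoundsSketch(t *testing.T) {
	t.Parallel()

	at := newAdaptiveThrottler(2 * time.Second)
	require.Equal(t, 8*time.Second, at.effectiveInterval(0.9))         // saturated (> 0.8): slow down to max
	require.Equal(t, 4*time.Second, at.effectiveInterval(0.6))         // busy (> 0.5): twice the base interval
	require.Equal(t, 2*time.Second, at.effectiveInterval(0.3))         // normal: base interval
	require.Equal(t, 500*time.Millisecond, at.effectiveInterval(0.05)) // near-idle (< 0.1): speed up to min
}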