diff --git a/.github/README.md b/.github/README.md index bbefbc1f550..8d86363fa42 100644 --- a/.github/README.md +++ b/.github/README.md @@ -739,6 +739,7 @@ Here is a list of middleware that are included within the Fiber framework. | [favicon](https://github.com/gofiber/fiber/tree/main/middleware/favicon) | Ignore favicon from logs or serve from memory if a file path is provided. | | [healthcheck](https://github.com/gofiber/fiber/tree/main/middleware/healthcheck) | Liveness and Readiness probes for Fiber. | | [helmet](https://github.com/gofiber/fiber/tree/main/middleware/helmet) | Helps secure your apps by setting various HTTP headers. | +| [hostauthorization](https://github.com/gofiber/fiber/tree/main/middleware/hostauthorization) | Validates the Host header against a configurable allowlist, protecting against DNS rebinding attacks. | | [idempotency](https://github.com/gofiber/fiber/tree/main/middleware/idempotency) | Allows for fault-tolerant APIs where duplicate requests do not erroneously cause the same action performed multiple times on the server-side. | | [keyauth](https://github.com/gofiber/fiber/tree/main/middleware/keyauth) | Adds support for key based authentication. | | [limiter](https://github.com/gofiber/fiber/tree/main/middleware/limiter) | Adds Rate-limiting support to Fiber. Use to limit repeated requests to public APIs and/or endpoints such as password reset. | diff --git a/docs/middleware/hostauthorization.md b/docs/middleware/hostauthorization.md new file mode 100644 index 00000000000..990e65087c7 --- /dev/null +++ b/docs/middleware/hostauthorization.md @@ -0,0 +1,224 @@ +--- +id: hostauthorization +--- + +# Host Authorization + +Host authorization middleware for [Fiber](https://github.com/gofiber/fiber) that validates the incoming `Host` header against a configurable allowlist. Protects against [DNS rebinding attacks](https://en.wikipedia.org/wiki/DNS_rebinding) where an attacker-controlled domain resolves to the application's internal IP, causing browsers to send requests with a malicious Host header. + +## Signatures + +```go +func New(config ...Config) fiber.Handler +``` + +## Examples + +Import the middleware package: + +```go +import ( + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/hostauthorization" +) +``` + +Once your Fiber app is initialized, choose one of the following approaches: + +### Basic Usage + +```go +app.Use(hostauthorization.New(hostauthorization.Config{ + AllowedHosts: []string{"api.myapp.com"}, +})) + +app.Get("/users", func(c fiber.Ctx) error { + return c.JSON(getUsers()) +}) + +// Host: api.myapp.com → 200 OK +// Host: evil.com → 403 Forbidden +``` + +### Subdomain Wildcards + +A `*.` prefix matches any subdomain but **not** the bare domain itself: + +```go +app.Use(hostauthorization.New(hostauthorization.Config{ + AllowedHosts: []string{"*.myapp.com"}, +})) + +// Host: api.myapp.com → 200 OK +// Host: www.myapp.com → 200 OK +// Host: myapp.com → 403 Forbidden +``` + +To allow both the bare domain and all subdomains, include both: + +```go +AllowedHosts: []string{"myapp.com", "*.myapp.com"}, +``` + +### Internationalized Domain Names (IDN) + +Browsers always transmit the `Host` header in ASCII (Punycode) form, so IDN entries in `AllowedHosts` are converted to Punycode at startup. You can configure entries in either form — they are equivalent: + +```go +AllowedHosts: []string{"münchen.example.com"} // Unicode +AllowedHosts: []string{"xn--mnchen-3ya.example.com"} // Punycode (what the browser sends) +``` + +Both match an incoming request whose Host header is `xn--mnchen-3ya.example.com`. + +### Skipping Health Checks + +Use `Next` to bypass host validation for specific paths: + +```go +app.Use(hostauthorization.New(hostauthorization.Config{ + AllowedHosts: []string{"myapp.com", "*.myapp.com"}, + Next: func(c fiber.Ctx) bool { + return c.Path() == "/healthz" + }, +})) + +// Host: evil.com GET /healthz → 200 OK (skipped) +// Host: evil.com GET /users → 403 Forbidden +``` + +### Dynamic Validation + +Use `AllowedHostsFunc` for hosts that can't be known at startup: + +```go +app.Use(hostauthorization.New(hostauthorization.Config{ + AllowedHostsFunc: func(host string) bool { + // Look up tenant domains from database, cache, etc. + return isRegisteredTenant(host) + }, +})) +``` + +`AllowedHostsFunc` is only called when static `AllowedHosts` don't match, so you can combine both: + +```go +app.Use(hostauthorization.New(hostauthorization.Config{ + AllowedHosts: []string{"myapp.com", "*.myapp.com"}, + AllowedHostsFunc: func(host string) bool { + return isRegisteredCustomDomain(host) + }, +})) +``` + +### Custom Error Response + +The default response is **403 Forbidden**. **421 Misdirected Request** ([RFC 9110 §15.5.20](https://www.rfc-editor.org/rfc/rfc9110#section-15.5.20)) is a semantically closer choice for "wrong host for this server" — CDNs like Cloudflare and Fastly use it for this case. Either is reasonable; pick one via `ErrorHandler`: + +```go +// 403 with a JSON body +app.Use(hostauthorization.New(hostauthorization.Config{ + AllowedHosts: []string{"myapp.com"}, + ErrorHandler: func(c fiber.Ctx, err error) error { + return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ + "error": "unauthorized host", + }) + }, +})) + +// 421 Misdirected Request — closer to the RFC-defined semantics +app.Use(hostauthorization.New(hostauthorization.Config{ + AllowedHosts: []string{"myapp.com"}, + ErrorHandler: func(c fiber.Ctx, _ error) error { + return c.SendStatus(fiber.StatusMisdirectedRequest) // 421 + }, +})) +``` + +### Combined with Domain() Router + +`hostauthorization` acts as a security gate; [`Domain()`](https://docs.gofiber.io) handles routing: + +```go +// Security layer — reject anything not from our hosts +app.Use(hostauthorization.New(hostauthorization.Config{ + AllowedHosts: []string{"myapp.com", "*.myapp.com"}, + Next: func(c fiber.Ctx) bool { + return c.Path() == "/healthz" + }, +})) + +// Routing layer — direct allowed hosts to the right handlers +app.Domain("api.myapp.com").Get("/users", listUsers) +app.Domain(":tenant.myapp.com").Get("/dashboard", tenantDashboard) +app.Get("/healthz", healthCheck) +``` + +## Config + +| Property | Type | Description | Default | +|:-----------------|:------------------------------|:--------------------------------------------------------------------------------------------------|:--------| +| Next | `func(fiber.Ctx) bool` | Defines a function to skip this middleware when returned true. | `nil` | +| AllowedHosts | `[]string` | List of permitted hosts. Supports exact match and subdomain wildcard (`*.example.com`). | `nil` | +| AllowedHostsFunc | `func(string) bool` | Dynamic validator called only when no static AllowedHosts rule matches. Receives the normalized hostname: port stripped, trailing dot removed, IPv6 brackets removed, lowercased, IDN converted to Punycode. | `nil` | +| ErrorHandler | `fiber.ErrorHandler` | Called when a request is rejected. Receives `ErrForbiddenHost` as the error. | 403 | + +Either `AllowedHosts` or `AllowedHostsFunc` (or both) must be provided. The middleware panics at startup if neither is set. + +## Default Config + +```go +var ConfigDefault = Config{} +``` + +There is no useful default — you must provide at least `AllowedHosts` or `AllowedHostsFunc`. + +## Host Matching + +The middleware matches hosts in this order: + +1. **Exact match** — case-insensitive, port and trailing dot stripped, IDN labels in Punycode form +2. **Subdomain wildcard** — `"*.myapp.com"` matches `api.myapp.com` but not `myapp.com` +3. **AllowedHostsFunc** — called only if no static rule matched + +The first match wins. If nothing matches, `ErrorHandler` is called. + +## Host Normalization + +Before matching, both incoming hosts and `AllowedHosts` entries are normalized at startup: + +- Port is stripped (`example.com:8080` → `example.com`) +- Trailing dot removed (`example.com.` → `example.com`) +- IPv6 brackets removed (`[::1]` → `::1`) +- Lowercased +- IDN labels converted to ASCII/Punycode (`münchen.example.com` → `xn--mnchen-3ya.example.com`) +- RFC 1035 length limits enforced at startup: ≤253 chars total, ≤63 chars per label (panic on violation) + +## Filtering by Client IP + +This middleware filters by the `Host` *header*, not by the client's source IP. To restrict access by client IP, use Fiber's [`TrustProxy` / `TrustProxyConfig`](https://docs.gofiber.io/whats_new#trusted-proxies) configuration — those are the correct knobs for IP allowlisting and CIDR ranges of trusted proxies. + +## Proxy Support + +The middleware uses Fiber's `c.Hostname()`, which respects `X-Forwarded-Host` when [`TrustProxy`](https://docs.gofiber.io/api/fiber#config) is enabled. When `TrustProxy` is disabled (the default), `X-Forwarded-Host` is ignored and the raw `Host` header is used. + +fasthttp itself is HTTP/1.x only. HTTP/2 support requires an external library (e.g. `fasthttp2`) plugged in via `Server.NextProto`. Those libraries are responsible for mapping the HTTP/2 `:authority` pseudo-header to a Host value before the request reaches Fiber handlers, so the middleware should work transparently once H2 is wired up — but this is the H2 library's responsibility, not fasthttp's or this middleware's. + +## RFC Compliance + +- **RFC 9110 Section 7.2** — Host and port are separate components; port is stripped before matching +- **RFC 9110 Section 17.1** — Origin servers should reject misdirected requests +- **RFC 9112 Section 3.2** — Requests with missing Host headers should be rejected +- **RFC 1035** — `AllowedHosts` entries are validated against the 253-char total / 63-char per-label limits +- Returns **403 Forbidden** (not 400) because the request is syntactically valid but semantically unauthorized + +:::note +**RFC 9110 §15.5.20** defines **421 Misdirected Request** as a semantically closer response for host mismatches ("the request was directed at a server unable or unwilling to produce an authoritative response for the target URI"). CDNs like Cloudflare and Fastly use 421 for this case. To use 421 instead of 403, set a custom `ErrorHandler`: + +```go +ErrorHandler: func(c fiber.Ctx, err error) error { + return c.SendStatus(fiber.StatusMisdirectedRequest) // 421 +}, +``` + +::: diff --git a/middleware/hostauthorization/config.go b/middleware/hostauthorization/config.go new file mode 100644 index 00000000000..3ad8b2a68e7 --- /dev/null +++ b/middleware/hostauthorization/config.go @@ -0,0 +1,67 @@ +package hostauthorization + +import ( + "errors" + + "github.com/gofiber/fiber/v3" +) + +// ErrForbiddenHost is returned when the Host header does not match any allowed host. +var ErrForbiddenHost = errors.New("hostauthorization: forbidden host") + +// Config defines the config for the host authorization middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // Use this to exclude health check endpoints or other paths from host validation. + // + // Optional. Default: nil + Next func(c fiber.Ctx) bool + + // AllowedHostsFunc is a dynamic validator called only when no static + // AllowedHosts rule matches. Receives the normalized hostname: port stripped, + // trailing dot removed, IPv6 brackets removed, lowercased. + // Return true to allow. + // + // Optional. Default: nil + AllowedHostsFunc func(host string) bool + + // ErrorHandler is called when a request is rejected. + // Receives ErrForbiddenHost as the error. + // + // Optional. Default: returns 403 Forbidden. + ErrorHandler fiber.ErrorHandler + + // AllowedHosts is the list of permitted host values. + // Supports two match types: + // - Exact: "api.myapp.com" + // - Subdomain: "*.myapp.com" (matches any subdomain, NOT the bare domain — list both for apex+subdomains) + // + // Entries are normalized at startup: port stripped, trailing dot removed, + // lowercased, IDN labels converted to Punycode, RFC 1035 length limits enforced + // (≤253 total / ≤63 per-label). + // + // Required if AllowedHostsFunc is nil. + AllowedHosts []string +} + +// ConfigDefault is the default config. +var ConfigDefault = Config{} + +func configDefault(config ...Config) Config { + cfg := ConfigDefault + if len(config) > 0 { + cfg = config[0] + } + + if len(cfg.AllowedHosts) == 0 && cfg.AllowedHostsFunc == nil { + panic("hostauthorization: AllowedHosts or AllowedHostsFunc is required") + } + + if cfg.ErrorHandler == nil { + cfg.ErrorHandler = func(c fiber.Ctx, _ error) error { + return c.SendStatus(fiber.StatusForbidden) + } + } + + return cfg +} diff --git a/middleware/hostauthorization/hostauthorization.go b/middleware/hostauthorization/hostauthorization.go new file mode 100644 index 00000000000..75a3d9902aa --- /dev/null +++ b/middleware/hostauthorization/hostauthorization.go @@ -0,0 +1,165 @@ +package hostauthorization + +import ( + "fmt" + "net" + "strings" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/utils/v2" + utilsstrings "github.com/gofiber/utils/v2/strings" + "golang.org/x/net/idna" +) + +// RFC 1035 length limits. +const ( + maxDomainLength = 253 + maxLabelLength = 63 +) + +type parsedHosts struct { + exact map[string]struct{} + wildcardSuffixes []string +} + +// parseAllowedHosts splits AllowedHosts into exact and wildcard groups, +// normalizing entries (port strip, lowercase, IDN→Punycode) and enforcing +// RFC 1035 length limits. Panics on misconfiguration so it surfaces at startup. +func parseAllowedHosts(hosts []string) parsedHosts { + parsed := parsedHosts{ + exact: make(map[string]struct{}, len(hosts)), + } + + for _, h := range hosts { + h = utils.TrimSpace(h) + if h == "" { + continue + } + + // Reject the leading-dot form some other tools use; we want "*.example.com". + if len(h) > 1 && h[0] == '.' { + panic("hostauthorization: invalid host " + h + " — subdomain wildcards use the \"*.example.com\" form") + } + + isWildcard := strings.HasPrefix(h, "*.") + if isWildcard { + h = h[2:] + } + + h = normalizeHost(h) + if h == "" { + continue + } + + validateHostLength(h) + + if isWildcard { + // Stored with leading dot so the hot-path HasSuffix check stays alloc-free. + parsed.wildcardSuffixes = append(parsed.wildcardSuffixes, "."+h) + } else { + parsed.exact[h] = struct{}{} + } + } + + return parsed +} + +func validateHostLength(host string) { + if len(host) > maxDomainLength { + panic(fmt.Sprintf("hostauthorization: host %q exceeds RFC 1035 maximum of %d characters (%d chars)", + host, maxDomainLength, len(host))) + } + // IPv6 hosts contain colons and aren't dotted labels. + if strings.IndexByte(host, ':') >= 0 { + return + } + for label := range strings.SplitSeq(host, ".") { + if len(label) > maxLabelLength { + panic(fmt.Sprintf("hostauthorization: host %q has label %q exceeding RFC 1035 limit of %d characters (%d chars)", + host, label, maxLabelLength, len(label))) + } + } +} + +// normalizeHost strips port, trailing dot, and IPv6 brackets, lowercases, +// and converts IDN labels to Punycode (matching what browsers send). +func normalizeHost(host string) string { + // Fast path for plain hostnames — avoids net.SplitHostPort's error allocation. + if host != "" && host[0] != '[' && strings.IndexByte(host, ':') < 0 { + host = strings.TrimSuffix(host, ".") + host = utilsstrings.ToLower(host) + return toPunycode(host) + } + + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } else { + host = strings.TrimPrefix(host, "[") + host = strings.TrimSuffix(host, "]") + } + + host = strings.TrimSuffix(host, ".") + host = utilsstrings.ToLower(host) + return toPunycode(host) +} + +func toPunycode(host string) string { + if host == "" || strings.IndexByte(host, ':') >= 0 || isASCII(host) { + return host + } + if ascii, err := idna.Lookup.ToASCII(host); err == nil { + return ascii + } + // Non-convertible input falls through; it won't match any Punycode entry, + // which is the correct security default. + return host +} + +func isASCII(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] >= 0x80 { + return false + } + } + return true +} + +// matchHost evaluates exact → wildcard → AllowedHostsFunc. +// The func is a fallback only — never called when a static rule matched. +func matchHost(host string, parsed parsedHosts, allowedHostsFunc func(string) bool) bool { + if _, ok := parsed.exact[host]; ok { + return true + } + + for _, suffix := range parsed.wildcardSuffixes { + if strings.HasSuffix(host, suffix) { + return true + } + } + + if allowedHostsFunc != nil && allowedHostsFunc(host) { + return true + } + + return false +} + +// New creates a new host authorization middleware handler. +func New(config ...Config) fiber.Handler { + cfg := configDefault(config...) + parsed := parseAllowedHosts(cfg.AllowedHosts) + + return func(c fiber.Ctx) error { + if cfg.Next != nil && cfg.Next(c) { + return c.Next() + } + + host := normalizeHost(c.Hostname()) + + if matchHost(host, parsed, cfg.AllowedHostsFunc) { + return c.Next() + } + + return cfg.ErrorHandler(c, ErrForbiddenHost) + } +} diff --git a/middleware/hostauthorization/hostauthorization_test.go b/middleware/hostauthorization/hostauthorization_test.go new file mode 100644 index 00000000000..ce1886982cb --- /dev/null +++ b/middleware/hostauthorization/hostauthorization_test.go @@ -0,0 +1,792 @@ +package hostauthorization + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" +) + +func Test_ConfigDefault(t *testing.T) { + t.Parallel() + + cfg := configDefault(Config{ + AllowedHosts: []string{"example.com"}, + }) + require.NotNil(t, cfg.ErrorHandler) + require.Equal(t, []string{"example.com"}, cfg.AllowedHosts) +} + +func Test_ConfigPanicNoHostsOrFunc(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + configDefault(Config{}) + }) +} + +func Test_ConfigPanicEmptySlice(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + configDefault(Config{ + AllowedHosts: []string{}, + }) + }) +} + +func Test_ConfigPanicNoArgs(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + configDefault() + }) +} + +func Test_ConfigAllowedHostsFuncOnly(t *testing.T) { + t.Parallel() + + cfg := configDefault(Config{ + AllowedHostsFunc: func(host string) bool { + return host == "example.com" + }, + }) + require.NotNil(t, cfg.AllowedHostsFunc) +} + +func Test_ConfigPanicHostExceedsRFC1035TotalLength(t *testing.T) { + t.Parallel() + + tooLong := strings.Repeat("a", 254) + require.Panics(t, func() { + New(Config{ + AllowedHosts: []string{tooLong}, + }) + }) +} + +func Test_ConfigPanicLeadingDotForm(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + New(Config{ + AllowedHosts: []string{".myapp.com"}, + }) + }) +} + +func Test_ConfigPanicLabelExceedsRFC1035Length(t *testing.T) { + t.Parallel() + + tooLong := strings.Repeat("a", 64) + ".example.com" + require.Panics(t, func() { + New(Config{ + AllowedHosts: []string{tooLong}, + }) + }) +} + +func Test_ConfigCustomErrorHandler(t *testing.T) { + t.Parallel() + + custom := func(c fiber.Ctx, _ error) error { + return c.Status(fiber.StatusTeapot).SendString("nope") + } + + cfg := configDefault(Config{ + AllowedHosts: []string{"example.com"}, + ErrorHandler: custom, + }) + require.NotNil(t, cfg.ErrorHandler) +} + +func Test_NormalizeHost(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + {"plain host", "example.com", "example.com"}, + {"uppercase", "EXAMPLE.COM", "example.com"}, + {"trailing dot", "example.com.", "example.com"}, + {"host with port", "example.com:8080", "example.com"}, + {"uppercase host with port", "EXAMPLE.COM:8080", "example.com"}, + {"ipv4", "192.168.1.1", "192.168.1.1"}, + {"ipv4 with port", "192.168.1.1:8080", "192.168.1.1"}, + {"ipv6 brackets", "[::1]", "::1"}, + {"ipv6 bare", "::1", "::1"}, + {"ipv6 with port", "[::1]:8080", "::1"}, + {"empty", "", ""}, + {"idn unicode → punycode", "münchen.example.com", "xn--mnchen-3ya.example.com"}, + {"idn already punycode", "xn--mnchen-3ya.example.com", "xn--mnchen-3ya.example.com"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.expected, normalizeHost(tt.input)) + }) + } +} + +func Test_ParseAllowedHosts_SkipsBlankEntries(t *testing.T) { + t.Parallel() + + parsed := parseAllowedHosts([]string{"", " ", "example.com"}) + + require.True(t, matchHost("example.com", parsed, nil)) + require.False(t, matchHost("", parsed, nil)) +} + +func Test_MatchExact(t *testing.T) { + t.Parallel() + + parsed := parseAllowedHosts([]string{"example.com", "api.myapp.com"}) + + require.True(t, matchHost("example.com", parsed, nil)) + require.True(t, matchHost("api.myapp.com", parsed, nil)) + require.False(t, matchHost("evil.com", parsed, nil)) +} + +func Test_MatchExactCaseInsensitive(t *testing.T) { + t.Parallel() + + parsed := parseAllowedHosts([]string{"Example.COM"}) + + require.True(t, matchHost("example.com", parsed, nil)) +} + +func Test_MatchSubdomainWildcard(t *testing.T) { + t.Parallel() + + parsed := parseAllowedHosts([]string{"*.myapp.com"}) + + require.True(t, matchHost("api.myapp.com", parsed, nil)) + require.True(t, matchHost("www.myapp.com", parsed, nil)) + require.True(t, matchHost("deep.sub.myapp.com", parsed, nil)) + require.False(t, matchHost("myapp.com", parsed, nil), "bare domain must NOT match subdomain wildcard") + require.False(t, matchHost("evil.com", parsed, nil)) +} + +func Test_MatchSubdomainWildcard_IDN(t *testing.T) { + t.Parallel() + + parsed := parseAllowedHosts([]string{"*.münchen.example.com"}) + + require.True(t, matchHost(normalizeHost("api.münchen.example.com"), parsed, nil)) + require.True(t, matchHost("api.xn--mnchen-3ya.example.com", parsed, nil)) + require.False(t, matchHost("xn--mnchen-3ya.example.com", parsed, nil), "bare domain must NOT match subdomain wildcard") +} + +func Test_MatchAllowedHostsFunc(t *testing.T) { + t.Parallel() + + parsed := parseAllowedHosts([]string{"example.com"}) + fn := func(host string) bool { + return host == "dynamic.com" + } + + require.True(t, matchHost("example.com", parsed, fn)) + require.True(t, matchHost("dynamic.com", parsed, fn)) + require.False(t, matchHost("evil.com", parsed, fn)) +} + +func Test_MatchMixedCategories(t *testing.T) { + t.Parallel() + + parsed := parseAllowedHosts([]string{ + "example.com", + "*.myapp.com", + "127.0.0.1", + }) + + require.True(t, matchHost("example.com", parsed, nil)) + require.True(t, matchHost("api.myapp.com", parsed, nil)) + require.True(t, matchHost("127.0.0.1", parsed, nil)) + require.False(t, matchHost("evil.com", parsed, nil)) + require.False(t, matchHost("192.168.1.1", parsed, nil)) +} + +func Test_MatchEmptyHost(t *testing.T) { + t.Parallel() + + parsed := parseAllowedHosts([]string{"example.com"}) + + require.False(t, matchHost("", parsed, nil)) +} + +func Test_HostAuthorization_AllowedHost(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHosts: []string{"example.com"}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "example.com" + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +func Test_HostAuthorization_RejectedHost(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHosts: []string{"example.com"}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "evil.com" + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusForbidden, resp.StatusCode) +} + +func Test_HostAuthorization_EmptyHost(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHosts: []string{"example.com"}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "" + + // app.Test() substitutes "localhost" when req.Host is empty, which isn't in the allowlist. + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusForbidden, resp.StatusCode) +} + +func Test_HostAuthorization_HostWithPort(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHosts: []string{"example.com"}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "example.com:8080" + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +func Test_HostAuthorization_AllowedHostWithPort(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHosts: []string{"example.com:8080"}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "example.com:8080" + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + req2 := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req2.Host = "example.com" + + resp2, err := app.Test(req2) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp2.StatusCode) +} + +func Test_HostAuthorization_Next(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHosts: []string{"example.com"}, + Next: func(c fiber.Ctx) bool { + return c.Path() == "/healthz" + }, + })) + + app.Get("/healthz", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/healthz", http.NoBody) + req.Host = "evil.com" + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +func Test_HostAuthorization_SubdomainWildcard_Allowed(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHosts: []string{"*.myapp.com"}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "api.myapp.com" + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +func Test_HostAuthorization_SubdomainWildcard_BareDomainRejected(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHosts: []string{"*.myapp.com"}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "myapp.com" + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusForbidden, resp.StatusCode) +} + +func Test_HostAuthorization_AllowedHostsFunc_Allowed(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHostsFunc: func(host string) bool { + return host == "dynamic.com" + }, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "dynamic.com" + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +func Test_HostAuthorization_AllowedHostsFunc_Rejected(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHostsFunc: func(host string) bool { + return host == "dynamic.com" + }, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "evil.com" + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusForbidden, resp.StatusCode) +} + +func Test_HostAuthorization_CustomErrorHandler(t *testing.T) { + t.Parallel() + app := fiber.New() + + var handlerErr error + app.Use(New(Config{ + AllowedHosts: []string{"example.com"}, + ErrorHandler: func(c fiber.Ctx, err error) error { + handlerErr = err + return c.Status(fiber.StatusTeapot).SendString("custom rejection") + }, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "evil.com" + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusTeapot, resp.StatusCode) + require.ErrorIs(t, handlerErr, ErrForbiddenHost) +} + +func Test_HostAuthorization_CaseInsensitive(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHosts: []string{"Example.COM"}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "EXAMPLE.com" + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +func Test_HostAuthorization_TrailingDot(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHosts: []string{"example.com"}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "example.com." + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +func Test_HostAuthorization_ExactIP(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHosts: []string{"127.0.0.1"}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "127.0.0.1" + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +func Test_HostAuthorization_IDN_PunycodeRequest(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHosts: []string{"münchen.example.com"}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "xn--mnchen-3ya.example.com" + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +func Test_HostAuthorization_OverlappingRules(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + AllowedHosts: []string{ + "api.myapp.com", + "*.myapp.com", + }, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "api.myapp.com" + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +func Test_HostAuthorization_XForwardedHost_TrustProxy_Allowed(t *testing.T) { + t.Parallel() + + // app.Test() uses remote address 0.0.0.0; trust that proxy IP. + app := fiber.New(fiber.Config{ + TrustProxy: true, + TrustProxyConfig: fiber.TrustProxyConfig{ + Proxies: []string{"0.0.0.0"}, + }, + }) + + app.Use(New(Config{ + AllowedHosts: []string{"example.com"}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "proxy.internal" + req.Header.Set("X-Forwarded-Host", "example.com") + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +func Test_HostAuthorization_XForwardedHost_TrustProxy_Rejected(t *testing.T) { + t.Parallel() + + app := fiber.New(fiber.Config{ + TrustProxy: true, + TrustProxyConfig: fiber.TrustProxyConfig{ + Proxies: []string{"0.0.0.0"}, + }, + }) + + app.Use(New(Config{ + AllowedHosts: []string{"example.com"}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "example.com" + req.Header.Set("X-Forwarded-Host", "evil.com") + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusForbidden, resp.StatusCode) +} + +func Test_HostAuthorization_XForwardedHost_NoTrustProxy(t *testing.T) { + t.Parallel() + + app := fiber.New() + + app.Use(New(Config{ + AllowedHosts: []string{"example.com"}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "example.com" + req.Header.Set("X-Forwarded-Host", "evil.com") + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode, "X-Forwarded-Host should be ignored without TrustProxy") +} + +func Test_ErrForbiddenHostString(t *testing.T) { + t.Parallel() + + // Locked in so callers matching on err.Error() get a failing test on change. + require.Equal(t, "hostauthorization: forbidden host", ErrForbiddenHost.Error()) +} + +func Test_AllowedHostsFuncFallback(t *testing.T) { + t.Parallel() + + called := 0 + parsed := parseAllowedHosts([]string{"example.com"}) + fn := func(_ string) bool { + called++ + return false + } + + result := matchHost("example.com", parsed, fn) + require.True(t, result) + require.Equal(t, 0, called, "AllowedHostsFunc must not be called when a static host matches") + + result = matchHost("other.com", parsed, fn) + require.False(t, result) + require.Equal(t, 1, called, "AllowedHostsFunc must be called when no static rule matches") +} + +func Test_NormalizeHost_IPv6WithPortInConfig(t *testing.T) { + t.Parallel() + + parsed := parseAllowedHosts([]string{"[::1]:8080"}) + + require.True(t, matchHost("::1", parsed, nil)) + require.False(t, matchHost("::2", parsed, nil)) +} + +// --- Benchmarks --- + +func Benchmark_matchHost_ExactMatch(b *testing.B) { + parsed := parseAllowedHosts([]string{"example.com"}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + matchHost("example.com", parsed, nil) + } +} + +func Benchmark_matchHost_WildcardMatch(b *testing.B) { + parsed := parseAllowedHosts([]string{"*.myapp.com"}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + matchHost("api.myapp.com", parsed, nil) + } +} + +func Benchmark_matchHost_Mixed(b *testing.B) { + parsed := parseAllowedHosts([]string{ + "example.com", + "*.myapp.com", + "127.0.0.1", + }) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + matchHost("api.myapp.com", parsed, nil) + } +} + +// Worst-case linear HasSuffix scan: target only matches the last entry. +func Benchmark_matchHost_ManyWildcards(b *testing.B) { + const n = 100 + hosts := make([]string, n) + for i := 0; i < n; i++ { + hosts[i] = fmt.Sprintf("*.tenant%d.example.com", i) + } + parsed := parseAllowedHosts(hosts) + target := fmt.Sprintf("api.tenant%d.example.com", n-1) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + matchHost(target, parsed, nil) + } +} + +func Benchmark_HostAuthorization_ExactMatch(b *testing.B) { + app := fiber.New() + app.Use(New(Config{ + AllowedHosts: []string{"example.com"}, + })) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "example.com" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp, err := app.Test(req) + if err != nil { + b.Fatal(err) + } + resp.Body.Close() //nolint:errcheck // benchmark cleanup + } +} + +func Benchmark_HostAuthorization_Mixed(b *testing.B) { + app := fiber.New() + app.Use(New(Config{ + AllowedHosts: []string{ + "example.com", + "*.myapp.com", + "127.0.0.1", + }, + })) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Host = "api.myapp.com" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp, err := app.Test(req) + if err != nil { + b.Fatal(err) + } + resp.Body.Close() //nolint:errcheck // benchmark cleanup + } +} + +func FuzzNormalizeHost(f *testing.F) { + f.Add("example.com") + f.Add("example.com.") + f.Add("[::1]:8080") + f.Add("[::1]") + f.Add("*.myapp.com") + f.Add("192.168.1.1:443") + f.Add("münchen.example.com") + f.Add("") + f.Fuzz(func(_ *testing.T, input string) { + _ = normalizeHost(input) + }) +}