Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changes/improve-sso-samlresponse-validation
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
* Improved SAMLResponse validation by rejecting large responses, deeply nested documents, or documents with too many nodes.
* Added rate limiting to the SSO callback endpoint.
3 changes: 3 additions & 0 deletions cmd/fleet/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,9 @@ func runServeCmd(cmd *cobra.Command, configManager configpkg.Manager, debug, dev
if config.MDM.SSORateLimitPerMinute > 0 {
extra = append(extra, service.WithMdmSsoRateLimit(throttled.PerMin(config.MDM.SSORateLimitPerMinute)))
}
if config.Auth.SSORateLimitPerMinute > 0 {
extra = append(extra, service.WithSsoRateLimit(throttled.PerMin(config.Auth.SSORateLimitPerMinute)))
}
Comment thread
lucasmrod marked this conversation as resolved.
extra = append(extra, service.WithHTTPSigVerifier(httpSigVerifier))

apiHandler = service.MakeHandler(svc, config, httpLogger, limiterStore, redisPool, carveStore,
Expand Down
4 changes: 4 additions & 0 deletions server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ type AuthConfig struct {
SaltKeySize int `yaml:"salt_key_size"`
SsoSessionValidityPeriod time.Duration `yaml:"sso_session_validity_period"`
RequireHTTPMessageSignature bool `yaml:"require_http_message_signature"`
SSORateLimitPerMinute int `yaml:"sso_rate_limit_per_minute"`
}

// AppConfig defines configs related to HTTP
Expand Down Expand Up @@ -1386,6 +1387,8 @@ func (man Manager) addConfigs() {
"Timeout from SSO start to SSO callback")
man.addConfigBool("auth.require_http_message_signature", false,
"Require HTTP message signatures for fleetd requests (Premium feature)")
man.addConfigInt("auth.sso_rate_limit_per_minute", 0,
"Number of allowed requests per minute to the SSO callback endpoint (default uses the login rate limit value in a dedicated bucket)")
Comment thread
lucasmrod marked this conversation as resolved.

// App
man.addConfigString("app.token_key", "CHANGEME",
Expand Down Expand Up @@ -1875,6 +1878,7 @@ func (man Manager) LoadConfig() FleetConfig {
SaltKeySize: man.getConfigInt("auth.salt_key_size"),
SsoSessionValidityPeriod: man.getConfigDuration("auth.sso_session_validity_period"),
RequireHTTPMessageSignature: man.getConfigBool("auth.require_http_message_signature"),
SSORateLimitPerMinute: man.getConfigInt("auth.sso_rate_limit_per_minute"),
},
App: AppConfig{
TokenKeySize: man.getConfigInt("app.token_key_size"),
Expand Down
6 changes: 6 additions & 0 deletions server/fleet/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ const (
// MaxMultiScriptQuerySize, sets a max size for payloads that take multiple scripts and SQL queries.
MaxMultiScriptQuerySize int64 = 5 * units.MiB
MaxMicrosoftMDMSize int64 = 2 * units.MiB
// MaxSSOCallbackSize bounds the body of the unauthenticated SSO callback
// endpoints (regular and MDM). The body carries a base64-encoded
// SAMLResponse; legitimate responses are well under 50 KiB even after
// base64 inflation, so 256 KiB leaves generous headroom for large
// enterprise IdP responses while keeping pre-auth attacks surface small.
MaxSSOCallbackSize int64 = 256 * units.KiB
// MaxAppleMDMRequestBodySize bounds Apple MDM check-in and command-result
// request bodies. Results are stored in a MEDIUMTEXT column (max 16,777,215
// bytes), so the limit must not exceed that boundary.
Expand Down
58 changes: 44 additions & 14 deletions server/service/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ func checkLicenseExpiration(svc fleet.Service) func(context.Context, http.Respon
type extraHandlerOpts struct {
loginRateLimit *throttled.Rate
mdmSsoRateLimit *throttled.Rate
ssoRateLimit *throttled.Rate
httpSigVerifier mux.MiddlewareFunc
}

Expand All @@ -87,6 +88,15 @@ func WithMdmSsoRateLimit(r throttled.Rate) ExtraHandlerOption {
}
}

// WithSsoRateLimit configures the rate of the SSO callback's dedicated rate
// limit bucket (the rate defaults to the login rate limit otherwise; the bucket
// is always separate from the login bucket).
func WithSsoRateLimit(r throttled.Rate) ExtraHandlerOption {
return func(o *extraHandlerOpts) {
o.ssoRateLimit = &r
}
}

func WithHTTPSigVerifier(m mux.MiddlewareFunc) ExtraHandlerOption {
return func(o *extraHandlerOpts) {
o.httpSigVerifier = m
Expand Down Expand Up @@ -1147,8 +1157,38 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
ne.WithCustomMiddleware(orgLogoLimiter).
GET("/api/_version_/fleet/logo", getOrgLogoEndpoint, getOrgLogoRequest{})

// Rate limiters shared across the login/SSO endpoints. These are defined
// here (ahead of the password-login registrations below) so the
// unauthenticated SSO callback can reuse the same login bucket.
limiter := ratelimit.NewMiddleware(limitStore)

// By default, MDM SSO shares the login rate limit bucket; if MDM SSO limit is overridden, MDM SSO gets its
// own rate limit bucket.
loginRateLimit := throttled.PerMin(10)
if extra.loginRateLimit != nil {
loginRateLimit = *extra.loginRateLimit
}
loginLimiter := limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})
mdmSsoLimiter := loginLimiter
if extra.mdmSsoRateLimit != nil {
mdmSsoLimiter = limiter.Limit("mdm_sso", throttled.RateQuota{MaxRate: *extra.mdmSsoRateLimit, MaxBurst: 9})
}
// The SSO callback gets its own dedicated bucket (separate from the login
// bucket) so a flood on the unauthenticated callback can't exhaust the
// rate-limit budget that legitimate password logins depend on. The rate
// defaults to the login rate unless explicitly overridden.
ssoRateLimit := loginRateLimit

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to do the same thing for the mdmSsoLimiter? Right now both the logic and mdm sso limiter share the same bucket

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to keep this new rate limiting separate to reduce breaking other endpoints (this new rate limit will be separate from login/MDM-SSO endpoints)

if extra.ssoRateLimit != nil {
ssoRateLimit = *extra.ssoRateLimit
}
ssoLimiter := limiter.Limit("sso", throttled.RateQuota{MaxRate: ssoRateLimit, MaxBurst: 9})

ne.POST("/api/v1/fleet/sso", initiateSSOEndpoint, initiateSSORequest{})
ne.POST("/api/v1/fleet/sso/callback", makeCallbackSSOEndpoint(config.Server.URLPrefix), callbackSSORequest{})
// The SSO callback is unauthenticated and internet-reachable. Rate-limit it
// (dedicated bucket) and cap the body to keep pre-auth attacks surface small.
ne.WithCustomMiddleware(ssoLimiter).
WithRequestBodySizeLimit(fleet.MaxSSOCallbackSize).
POST("/api/v1/fleet/sso/callback", makeCallbackSSOEndpoint(config.Server.URLPrefix), callbackSSORequest{})
ne.GET("/api/v1/fleet/sso", settingsSSOEndpoint, nil)

// the websocket distributed query results endpoint is a bit different - the
Expand All @@ -1160,23 +1200,10 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
makeStreamDistributedQueryCampaignResultsHandler(config.Server, svc, logger))

quota := throttled.RateQuota{MaxRate: throttled.PerHour(10), MaxBurst: forgotPasswordRateLimitMaxBurst}
limiter := ratelimit.NewMiddleware(limitStore)
ne.
WithCustomMiddleware(limiter.Limit("forgot_password", quota)).
POST("/api/_version_/fleet/forgot_password", forgotPasswordEndpoint, forgotPasswordRequest{})

// By default, MDM SSO shares the login rate limit bucket; if MDM SSO limit is overridden, MDM SSO gets its
// own rate limit bucket.
loginRateLimit := throttled.PerMin(10)
if extra.loginRateLimit != nil {
loginRateLimit = *extra.loginRateLimit
}
loginLimiter := limiter.Limit("login", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})
mdmSsoLimiter := loginLimiter
if extra.mdmSsoRateLimit != nil {
mdmSsoLimiter = limiter.Limit("mdm_sso", throttled.RateQuota{MaxRate: *extra.mdmSsoRateLimit, MaxBurst: 9})
}

ne.WithCustomMiddleware(loginLimiter).
POST("/api/_version_/fleet/login", loginEndpoint, fleet.LoginRequest{})
ne.WithCustomMiddleware(limiter.Limit("mfa", throttled.RateQuota{MaxRate: loginRateLimit, MaxBurst: 9})).
Expand All @@ -1191,7 +1218,10 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC

neAppleMDM.WithCustomMiddleware(mdmSsoLimiter).
POST("/api/_version_/fleet/mdm/sso", initiateMDMSSOEndpoint, initiateMDMSSORequest{})
// Same posture as the regular SSO callback: rate-limited (already) plus a
// tight body cap to keep pre-auth attacks surface small.
ne.WithCustomMiddleware(mdmSsoLimiter).
WithRequestBodySizeLimit(fleet.MaxSSOCallbackSize).
POST("/api/_version_/fleet/mdm/sso/callback", callbackMDMSSOEndpoint, callbackMDMSSORequest{})

// Register all deprecated URL path aliases from the declarative table.
Expand Down
11 changes: 11 additions & 0 deletions server/service/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,17 @@ func decodeCallbackRequest(ctx context.Context, r *http.Request) (
Message: "missing SAMLResponse",
}, "missing SAMLResponse in SSO callback")
}
// Cap the SAMLResponse value itself, not just the request body. FormValue
// reads from both the POST body and the URL query string, and
// WithRequestBodySizeLimit only bounds the body — so without this check the
// body cap is trivially bypassed by sending the payload as a
// ?SAMLResponse= query argument. This guards both the regular and MDM SSO
// callbacks, which share this decoder.
if int64(len(samlResponseValue)) > fleet.MaxSSOCallbackSize {
return "", nil, ctxerr.Wrap(ctx, &fleet.BadRequestError{
Message: "SAMLResponse too large",
}, "SAMLResponse exceeds maximum size in SSO callback")
}
decodedSAMLResponseValue, err := sso.DecodeSAMLResponse(samlResponseValue)
if err != nil {
return "", nil, ctxerr.Wrap(ctx, &fleet.BadRequestError{
Expand Down
31 changes: 31 additions & 0 deletions server/service/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ package service

import (
"context"
"encoding/base64"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -652,3 +656,30 @@ func TestInitiateSSOWithInvalidURL(t *testing.T) {
require.ErrorAs(t, err, &badReqErr)
require.Contains(t, badReqErr.Message, "invalid SSO URL")
}

func TestDecodeCallbackRequestSAMLResponseSizeCap(t *testing.T) {
// The SSO callbacks read SAMLResponse from FormValue, which covers both the
// POST body and the URL query string. WithRequestBodySizeLimit only bounds
// the body, so the value-level cap must reject an oversized query argument.
t.Run("oversized SAMLResponse in query string is rejected", func(t *testing.T) {
oversized := strings.Repeat("A", int(fleet.MaxSSOCallbackSize)+1)
r := httptest.NewRequest("POST", "/api/v1/fleet/sso/callback?SAMLResponse="+oversized, nil)

_, _, err := decodeCallbackRequest(t.Context(), r)
require.Error(t, err)
var bre *fleet.BadRequestError
require.ErrorAs(t, err, &bre)
require.Contains(t, bre.Message, "too large")
})

t.Run("normally-sized SAMLResponse passes the size check", func(t *testing.T) {
small := base64.StdEncoding.EncodeToString([]byte("<x/>"))
form := url.Values{"SAMLResponse": {small}}
r := httptest.NewRequest("POST", "/api/v1/fleet/sso/callback", strings.NewReader(form.Encode()))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")

_, decoded, err := decodeCallbackRequest(t.Context(), r)
require.NoError(t, err)
require.Equal(t, "<x/>", string(decoded))
})
}
53 changes: 53 additions & 0 deletions server/sso/authorization_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,22 @@ import (
"slices"
"strings"

"github.com/beevik/etree"
"github.com/crewjam/saml"
"github.com/fleetdm/fleet/v4/server/fleet"
)

const (
// maxSAMLResponseDepth bounds how deeply nested the SAMLResponse XML may
// be. Legitimate SAML responses are shallow; deep nesting is the signature
// of a canonicalization bomb, which runs before any certificate or signature
// check on the unauthenticated SSO callback endpoints.
maxSAMLResponseDepth = 100
// maxSAMLResponseElements bounds the total number of XML elements in the
// SAMLResponse, for the same reason (a bomb can be wide rather than deep).
maxSAMLResponseElements = 5000
)

// Since there's not a standard for display names, I have collected the most
// commonly used attribute names for it.
//
Expand Down Expand Up @@ -109,8 +121,49 @@ func validateAudiences(assertion *saml.Assertion, expectedAudiences []string) er
return fmt.Errorf("wrong audience: %+v", assertion.Conditions.AudienceRestrictions)
}

// validateSAMLResponseShape parses the decoded SAMLResponse XML and rejects
// documents that are excessively deep or have too many elements before they
// reach goxmldsig's pre-signature canonicalization, which (as of the time of writing)
// has no traversal limit of its own.
func validateSAMLResponseShape(samlResponse []byte) error {
doc := etree.NewDocument()
if err := doc.ReadFromBytes(samlResponse); err != nil {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one minor note here is that this loads the whole payload into memory, but given the current guards at the middleware layer, an attack would consume at most 256 KB * 10 = ~2.5 MB per minute which should be ok.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, both the rate limit + size cap should keep things safe.

return fmt.Errorf("parsing SAMLResponse XML: %w", err)
}
root := doc.Root()
if root == nil {
return errors.New("SAMLResponse has no root element")
}

count := 0
var walk func(el *etree.Element, depth int) error
walk = func(el *etree.Element, depth int) error {
if depth > maxSAMLResponseDepth {
return fmt.Errorf("SAMLResponse exceeds maximum nesting depth of %d", maxSAMLResponseDepth)
}
count++
if count > maxSAMLResponseElements {
return fmt.Errorf("SAMLResponse exceeds maximum element count of %d", maxSAMLResponseElements)
}
for _, child := range el.ChildElements() {
if err := walk(child, depth+1); err != nil {
return err
}
}
return nil
}
return walk(root, 1)
}

// ParseAndVerifySAMLResponse runs the parsing and validation of SAMLResponses.
func ParseAndVerifySAMLResponse(samlProvider *saml.ServiceProvider, samlResponse []byte, requestID string, acsURL *url.URL) (fleet.Auth, error) {
// Reject oversized/over-nested documents before handing them to
// crewjam/saml -> goxmldsig, whose pre-signature canonicalization is (at the time of writing)
// unbounded and runs without authentication.
if err := validateSAMLResponseShape(samlResponse); err != nil {
return nil, err
}

verifiedAssertion, err := samlProvider.ParseXMLResponse(samlResponse, []string{requestID}, *acsURL)
if err != nil {
if samlErr, ok := err.(*saml.InvalidResponseError); ok {
Expand Down
66 changes: 66 additions & 0 deletions server/sso/authorization_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sso
import (
"fmt"
"net/url"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -688,3 +689,68 @@ func TestDecodeOktaResponseWithCustomAttrs(t *testing.T) {
},
}, attrs)
}

func TestValidateSAMLResponseShape(t *testing.T) {
t.Run("valid shallow response passes", func(t *testing.T) {
const samlResponse = `<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol">
<saml:Assertion xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion">
<saml:Subject><saml:NameID>john@example.com</saml:NameID></saml:Subject>
</saml:Assertion>
</samlp:Response>`
require.NoError(t, validateSAMLResponseShape([]byte(samlResponse)))
})

t.Run("invalid XML is rejected", func(t *testing.T) {
require.Error(t, validateSAMLResponseShape([]byte("not xml <<<")))
})

t.Run("empty document is rejected", func(t *testing.T) {
require.Error(t, validateSAMLResponseShape([]byte("")))
})

t.Run("excessive nesting depth is rejected", func(t *testing.T) {
var sb strings.Builder
depth := maxSAMLResponseDepth + 50
sb.WriteString(`<root xmlns:p0="urn:p0">`)
for i := 1; i <= depth; i++ {
fmt.Fprintf(&sb, `<p%d:n xmlns:p%d="urn:p%d">`, i, i, i)
}
for i := depth; i >= 1; i-- {
fmt.Fprintf(&sb, `</p%d:n>`, i)
}
sb.WriteString(`</root>`)

err := validateSAMLResponseShape([]byte(sb.String()))
require.Error(t, err)
require.Contains(t, err.Error(), "maximum nesting depth")
})

t.Run("just under the depth limit passes", func(t *testing.T) {
var sb strings.Builder
// root is depth 1, so depth-1 additional nested children keep us at the cap.
sb.WriteString(`<root>`)
for i := 1; i < maxSAMLResponseDepth; i++ {
sb.WriteString(`<n>`)
}
for i := 1; i < maxSAMLResponseDepth; i++ {
sb.WriteString(`</n>`)
}
sb.WriteString(`</root>`)
require.NoError(t, validateSAMLResponseShape([]byte(sb.String())))
})

t.Run("excessive element count is rejected", func(t *testing.T) {
// A wide (shallow) document with more than maxSAMLResponseElements
// children also triggers the O(N^2) canonicalization cost.
var sb strings.Builder
sb.WriteString(`<root>`)
for i := 0; i <= maxSAMLResponseElements; i++ {
sb.WriteString(`<n/>`)
}
sb.WriteString(`</root>`)

err := validateSAMLResponseShape([]byte(sb.String()))
require.Error(t, err)
require.Contains(t, err.Error(), "maximum element count")
})
}
Loading