Skip to content

Commit 573e026

Browse files
committed
gh-19068 - prevent thundering herd on jwks fetch
Signed-off-by: anschnapp <a.snap@t-online.de>
1 parent 6e894fd commit 573e026

File tree

2 files changed

+61
-14
lines changed

2 files changed

+61
-14
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.text.ParseException;
2020
import java.util.Collections;
2121
import java.util.List;
22+
import java.util.Objects;
2223
import java.util.Set;
2324
import java.util.concurrent.atomic.AtomicReference;
2425

@@ -44,6 +45,12 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource {
4445
*/
4546
private final AtomicReference<Mono<JWKSet>> cachedJWKSet = new AtomicReference<>(Mono.empty());
4647

48+
/**
49+
* In-flight JWK set fetch request, used to coalesce concurrent fetches into a single
50+
* HTTP call.
51+
*/
52+
private final AtomicReference<@Nullable Mono<JWKSet>> inflightRequest = new AtomicReference<>();
53+
4754
/**
4855
* The cached JWK set URL.
4956
*/
@@ -101,24 +108,23 @@ private Mono<List<JWK>> get(JWKSelector jwkSelector, JWKSet jwkSet) {
101108
}
102109

103110
/**
104-
* Updates the cached JWK set from the configured URL.
111+
* Updates the cached JWK set from the configured URL. Concurrent calls are coalesced
112+
* into a single HTTP request to prevent thundering herd during cold start.
105113
* @return The updated JWK set.
106114
* @throws RemoteKeySourceException If JWK retrieval failed.
107115
*/
108116
private Mono<JWKSet> getJWKSet() {
109-
// @formatter:off
110-
return this.jwkSetUrlProvider
111-
.flatMap((jwkSetURL) -> this.webClient.get()
112-
.uri(jwkSetURL)
113-
.retrieve()
114-
.bodyToMono(String.class)
115-
)
116-
.map(this::parse)
117-
.doOnNext((jwkSet) -> this.cachedJWKSet
118-
.set(Mono.just(jwkSet))
119-
)
120-
.cache();
121-
// @formatter:on
117+
Mono<JWKSet> fetch = Mono.defer(() -> this.jwkSetUrlProvider
118+
.flatMap((jwkSetURL) -> this.webClient.get().uri(jwkSetURL).retrieve().bodyToMono(String.class))
119+
.map(this::parse)
120+
.doOnNext((jwkSet) -> {
121+
this.cachedJWKSet.set(Mono.just(jwkSet));
122+
this.inflightRequest.set(null);
123+
})
124+
.doOnError((ex) -> this.inflightRequest.set(null))
125+
.doOnCancel(() -> this.inflightRequest.set(null))
126+
.switchIfEmpty(Mono.fromRunnable(() -> this.inflightRequest.set(null)))).cache();
127+
return Objects.requireNonNull(this.inflightRequest.updateAndGet((v) -> (v != null) ? v : fetch));
122128
}
123129

124130
private JWKSet parse(String body) {

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
package org.springframework.security.oauth2.jwt;
1818

19+
import java.time.Duration;
1920
import java.util.Collections;
2021
import java.util.List;
22+
import java.util.concurrent.TimeUnit;
2123
import java.util.function.Supplier;
2224

2325
import com.nimbusds.jose.jwk.JWK;
@@ -32,7 +34,9 @@
3234
import org.junit.jupiter.api.extension.ExtendWith;
3335
import org.mockito.Mock;
3436
import org.mockito.junit.jupiter.MockitoExtension;
37+
import reactor.core.publisher.Flux;
3538
import reactor.core.publisher.Mono;
39+
import reactor.core.scheduler.Schedulers;
3640

3741
import org.springframework.web.reactive.function.client.WebClientResponseException;
3842

@@ -166,6 +170,43 @@ public void getWhenNoMatchAndKeyIdMatchThenEmpty() {
166170
assertThat(this.source.get(this.selector).block()).isEmpty();
167171
}
168172

173+
@Test
174+
public void getWhenConcurrentRequestsThenSingleFetch() {
175+
// given
176+
given(this.matcher.matches(any())).willReturn(true);
177+
int concurrentRequests = 10;
178+
for (int i = 0; i < concurrentRequests; i++) {
179+
this.server.enqueue(new MockResponse().setBody(this.keys).setBodyDelay(100, TimeUnit.MILLISECONDS));
180+
}
181+
182+
// when
183+
List<List<JWK>> results = Flux.range(0, concurrentRequests)
184+
.flatMap((i) -> this.source.get(this.selector).subscribeOn(Schedulers.parallel()), concurrentRequests)
185+
.collectList()
186+
.block(Duration.ofSeconds(5));
187+
188+
// then
189+
assertThat(results).hasSize(concurrentRequests);
190+
assertThat(this.server.getRequestCount()).isEqualTo(1);
191+
}
192+
193+
@Test
194+
public void getWhenEmptyResponseThenNextCallSucceeds() {
195+
// given
196+
given(this.matcher.matches(any())).willReturn(true);
197+
this.source = new ReactiveRemoteJWKSource(Mono.fromSupplier(this.mockStringSupplier));
198+
// first call: supplier returns null URL, causing empty Mono from jwkSetUrlProvider
199+
willReturn(null).given(this.mockStringSupplier).get();
200+
201+
// when: first call completes empty
202+
List<JWK> firstResult = this.source.get(this.selector).block(Duration.ofSeconds(5));
203+
204+
// then: inflight is cleared and second call can succeed
205+
willReturn(this.server.url("/").toString()).given(this.mockStringSupplier).get();
206+
List<JWK> secondResult = this.source.get(this.selector).block(Duration.ofSeconds(5));
207+
assertThat(secondResult).isNotEmpty();
208+
}
209+
169210
@Test
170211
public void getShouldRecoverAndReturnKeysAfterErrorCase() {
171212
given(this.matcher.matches(any())).willReturn(true);

0 commit comments

Comments
 (0)