diff --git a/mcp/transport/mcp-spring-webflux/src/main/java/org/springframework/ai/mcp/server/webflux/transport/WebFluxSseServerTransportProvider.java b/mcp/transport/mcp-spring-webflux/src/main/java/org/springframework/ai/mcp/server/webflux/transport/WebFluxSseServerTransportProvider.java index de70c12f07..56ce34c2ad 100644 --- a/mcp/transport/mcp-spring-webflux/src/main/java/org/springframework/ai/mcp/server/webflux/transport/WebFluxSseServerTransportProvider.java +++ b/mcp/transport/mcp-spring-webflux/src/main/java/org/springframework/ai/mcp/server/webflux/transport/WebFluxSseServerTransportProvider.java @@ -49,6 +49,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.server.RequestPredicates; import org.springframework.web.reactive.function.server.RouterFunction; import org.springframework.web.reactive.function.server.RouterFunctions; import org.springframework.web.reactive.function.server.ServerRequest; @@ -189,8 +190,8 @@ private WebFluxSseServerTransportProvider(McpJsonMapper jsonMapper, String baseU this.contextExtractor = contextExtractor; this.securityValidator = securityValidator; this.routerFunction = RouterFunctions.route() - .GET(this.sseEndpoint, this::handleSseConnection) - .POST(this.messageEndpoint, this::handleMessage) + .GET(this.sseEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM), this::handleSseConnection) + .POST(this.messageEndpoint, RequestPredicates.accept(MediaType.APPLICATION_JSON), this::handleMessage) .build(); if (keepAliveInterval != null) { diff --git a/mcp/transport/mcp-spring-webflux/src/main/java/org/springframework/ai/mcp/server/webflux/transport/WebFluxStatelessServerTransport.java b/mcp/transport/mcp-spring-webflux/src/main/java/org/springframework/ai/mcp/server/webflux/transport/WebFluxStatelessServerTransport.java index c94f62b306..fe177f80c3 100644 --- a/mcp/transport/mcp-spring-webflux/src/main/java/org/springframework/ai/mcp/server/webflux/transport/WebFluxStatelessServerTransport.java +++ b/mcp/transport/mcp-spring-webflux/src/main/java/org/springframework/ai/mcp/server/webflux/transport/WebFluxStatelessServerTransport.java @@ -39,6 +39,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.web.reactive.function.server.RequestPredicates; import org.springframework.web.reactive.function.server.RouterFunction; import org.springframework.web.reactive.function.server.RouterFunctions; import org.springframework.web.reactive.function.server.ServerRequest; @@ -83,8 +84,8 @@ private WebFluxStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndp this.contextExtractor = contextExtractor; this.securityValidator = securityValidator; this.routerFunction = RouterFunctions.route() - .GET(this.mcpEndpoint, this::handleGet) - .POST(this.mcpEndpoint, this::handlePost) + .POST(this.mcpEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON), + this::handlePost) .build(); } @@ -114,10 +115,6 @@ public RouterFunction getRouterFunction() { return this.routerFunction; } - private Mono handleGet(ServerRequest request) { - return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); - } - private Mono handlePost(ServerRequest request) { if (this.isClosing) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); @@ -134,12 +131,6 @@ private Mono handlePost(ServerRequest request) { McpTransportContext transportContext = this.contextExtractor.extract(request); - List acceptHeaders = request.headers().asHttpHeaders().getAccept(); - if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) - && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { - return ServerResponse.badRequest().build(); - } - return request.bodyToMono(String.class).flatMap(body -> { try { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, body); diff --git a/mcp/transport/mcp-spring-webflux/src/main/java/org/springframework/ai/mcp/server/webflux/transport/WebFluxStreamableServerTransportProvider.java b/mcp/transport/mcp-spring-webflux/src/main/java/org/springframework/ai/mcp/server/webflux/transport/WebFluxStreamableServerTransportProvider.java index 068a3e1eea..cb9e74f565 100644 --- a/mcp/transport/mcp-spring-webflux/src/main/java/org/springframework/ai/mcp/server/webflux/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp/transport/mcp-spring-webflux/src/main/java/org/springframework/ai/mcp/server/webflux/transport/WebFluxStreamableServerTransportProvider.java @@ -50,6 +50,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.server.RequestPredicates; import org.springframework.web.reactive.function.server.RouterFunction; import org.springframework.web.reactive.function.server.RouterFunctions; import org.springframework.web.reactive.function.server.ServerRequest; @@ -104,8 +105,9 @@ private WebFluxStreamableServerTransportProvider(McpJsonMapper jsonMapper, Strin this.disallowDelete = disallowDelete; this.securityValidator = securityValidator; this.routerFunction = RouterFunctions.route() - .GET(this.mcpEndpoint, this::handleGet) - .POST(this.mcpEndpoint, this::handlePost) + .GET(this.mcpEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM), this::handleGet) + .POST(this.mcpEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON), + this::handlePost) .DELETE(this.mcpEndpoint, this::handleDelete) .build(); @@ -216,11 +218,6 @@ private Mono handleGet(ServerRequest request) { McpTransportContext transportContext = this.contextExtractor.extract(request); return Mono.defer(() -> { - List acceptHeaders = request.headers().asHttpHeaders().getAccept(); - if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) { - return ServerResponse.badRequest().build(); - } - if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { return ServerResponse.badRequest().build(); // TODO: say we need a session // id @@ -279,12 +276,6 @@ private Mono handlePost(ServerRequest request) { McpTransportContext transportContext = this.contextExtractor.extract(request); - List acceptHeaders = request.headers().asHttpHeaders().getAccept(); - if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) - && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { - return ServerResponse.badRequest().build(); - } - return request.bodyToMono(String.class).flatMap(body -> { try { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, body); diff --git a/mcp/transport/mcp-spring-webmvc/src/main/java/org/springframework/ai/mcp/server/webmvc/transport/WebMvcSseServerTransportProvider.java b/mcp/transport/mcp-spring-webmvc/src/main/java/org/springframework/ai/mcp/server/webmvc/transport/WebMvcSseServerTransportProvider.java index 6c2736557f..a7ab45588b 100644 --- a/mcp/transport/mcp-spring-webmvc/src/main/java/org/springframework/ai/mcp/server/webmvc/transport/WebMvcSseServerTransportProvider.java +++ b/mcp/transport/mcp-spring-webmvc/src/main/java/org/springframework/ai/mcp/server/webmvc/transport/WebMvcSseServerTransportProvider.java @@ -44,6 +44,8 @@ import reactor.core.publisher.Mono; import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.web.servlet.function.RequestPredicates; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.RouterFunctions; import org.springframework.web.servlet.function.ServerRequest; @@ -175,8 +177,8 @@ private WebMvcSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUr this.contextExtractor = contextExtractor; this.securityValidator = securityValidator; this.routerFunction = RouterFunctions.route() - .GET(this.sseEndpoint, this::handleSseConnection) - .POST(this.messageEndpoint, this::handleMessage) + .GET(this.sseEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM), this::handleSseConnection) + .POST(this.messageEndpoint, RequestPredicates.accept(MediaType.APPLICATION_JSON), this::handleMessage) .build(); if (keepAliveInterval != null) { diff --git a/mcp/transport/mcp-spring-webmvc/src/main/java/org/springframework/ai/mcp/server/webmvc/transport/WebMvcStatelessServerTransport.java b/mcp/transport/mcp-spring-webmvc/src/main/java/org/springframework/ai/mcp/server/webmvc/transport/WebMvcStatelessServerTransport.java index 0bde2ed777..4cc1332678 100644 --- a/mcp/transport/mcp-spring-webmvc/src/main/java/org/springframework/ai/mcp/server/webmvc/transport/WebMvcStatelessServerTransport.java +++ b/mcp/transport/mcp-spring-webmvc/src/main/java/org/springframework/ai/mcp/server/webmvc/transport/WebMvcStatelessServerTransport.java @@ -17,7 +17,6 @@ package org.springframework.ai.mcp.server.webmvc.transport; import java.io.IOException; -import java.util.List; import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.json.McpJsonDefaults; @@ -37,6 +36,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.web.servlet.function.RequestPredicates; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.RouterFunctions; import org.springframework.web.servlet.function.ServerRequest; @@ -85,8 +85,8 @@ private WebMvcStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpo this.contextExtractor = contextExtractor; this.securityValidator = securityValidator; this.routerFunction = RouterFunctions.route() - .GET(this.mcpEndpoint, this::handleGet) - .POST(this.mcpEndpoint, this::handlePost) + .POST(this.mcpEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON), + this::handlePost) .build(); } @@ -116,10 +116,6 @@ public RouterFunction getRouterFunction() { return this.routerFunction; } - private ServerResponse handleGet(ServerRequest request) { - return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); - } - private ServerResponse handlePost(ServerRequest request) { if (this.isClosing) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); @@ -136,12 +132,6 @@ private ServerResponse handlePost(ServerRequest request) { McpTransportContext transportContext = this.contextExtractor.extract(request); - List acceptHeaders = request.headers().asHttpHeaders().getAccept(); - if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) - && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { - return ServerResponse.badRequest().build(); - } - var handler = this.mcpHandler; if (handler == null) { return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) diff --git a/mcp/transport/mcp-spring-webmvc/src/main/java/org/springframework/ai/mcp/server/webmvc/transport/WebMvcStreamableServerTransportProvider.java b/mcp/transport/mcp-spring-webmvc/src/main/java/org/springframework/ai/mcp/server/webmvc/transport/WebMvcStreamableServerTransportProvider.java index de4136bc9d..0e061a9d03 100644 --- a/mcp/transport/mcp-spring-webmvc/src/main/java/org/springframework/ai/mcp/server/webmvc/transport/WebMvcStreamableServerTransportProvider.java +++ b/mcp/transport/mcp-spring-webmvc/src/main/java/org/springframework/ai/mcp/server/webmvc/transport/WebMvcStreamableServerTransportProvider.java @@ -47,6 +47,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.web.servlet.function.RequestPredicates; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.RouterFunctions; import org.springframework.web.servlet.function.ServerRequest; @@ -151,8 +152,9 @@ private WebMvcStreamableServerTransportProvider(McpJsonMapper jsonMapper, String this.contextExtractor = contextExtractor; this.securityValidator = securityValidator; this.routerFunction = RouterFunctions.route() - .GET(this.mcpEndpoint, this::handleGet) - .POST(this.mcpEndpoint, this::handlePost) + .GET(this.mcpEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM), this::handleGet) + .POST(this.mcpEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON), + this::handlePost) .DELETE(this.mcpEndpoint, this::handleDelete) .build(); @@ -280,11 +282,6 @@ private ServerResponse handleGet(ServerRequest request) { return ServerResponse.status(e.getStatusCode()).body(message); } - List acceptHeaders = request.headers().asHttpHeaders().getAccept(); - if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) { - return ServerResponse.badRequest().body("Invalid Accept header. Expected TEXT_EVENT_STREAM"); - } - McpTransportContext transportContext = this.contextExtractor.extract(request); if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) { @@ -369,15 +366,6 @@ private ServerResponse handlePost(ServerRequest request) { return ServerResponse.status(e.getStatusCode()).body(message); } - List acceptHeaders = request.headers().asHttpHeaders().getAccept(); - if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM) - || !acceptHeaders.contains(MediaType.APPLICATION_JSON)) { - return ServerResponse.badRequest() - .body(McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND) - .message("Invalid Accept headers. Expected TEXT_EVENT_STREAM and APPLICATION_JSON") - .build()); - } - McpTransportContext transportContext = this.contextExtractor.extract(request); try {