Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.server.RequestPredicates;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerRequest;
Expand Down Expand Up @@ -189,8 +190,8 @@ private WebFluxSseServerTransportProvider(McpJsonMapper jsonMapper, String baseU
this.contextExtractor = contextExtractor;
this.securityValidator = securityValidator;
this.routerFunction = RouterFunctions.route()
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
.GET(this.sseEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM), this::handleSseConnection)
.POST(this.messageEndpoint, RequestPredicates.accept(MediaType.APPLICATION_JSON), this::handleMessage)
.build();

if (keepAliveInterval != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.web.reactive.function.server.RequestPredicates;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerRequest;
Expand Down Expand Up @@ -83,8 +84,8 @@ private WebFluxStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndp
this.contextExtractor = contextExtractor;
this.securityValidator = securityValidator;
this.routerFunction = RouterFunctions.route()
.GET(this.mcpEndpoint, this::handleGet)
.POST(this.mcpEndpoint, this::handlePost)
.POST(this.mcpEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON),
this::handlePost)
.build();
}

Expand Down Expand Up @@ -114,10 +115,6 @@ public RouterFunction<?> getRouterFunction() {
return this.routerFunction;
}

private Mono<ServerResponse> handleGet(ServerRequest request) {
return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build();
}

private Mono<ServerResponse> handlePost(ServerRequest request) {
if (this.isClosing) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
Expand All @@ -134,12 +131,6 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {

McpTransportContext transportContext = this.contextExtractor.extract(request);

List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)
&& acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) {
return ServerResponse.badRequest().build();
}

return request.bodyToMono(String.class).<ServerResponse>flatMap(body -> {
try {
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, body);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.reactive.function.server.RequestPredicates;
import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions;
import org.springframework.web.reactive.function.server.ServerRequest;
Expand Down Expand Up @@ -104,8 +105,9 @@ private WebFluxStreamableServerTransportProvider(McpJsonMapper jsonMapper, Strin
this.disallowDelete = disallowDelete;
this.securityValidator = securityValidator;
this.routerFunction = RouterFunctions.route()
.GET(this.mcpEndpoint, this::handleGet)
.POST(this.mcpEndpoint, this::handlePost)
.GET(this.mcpEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM), this::handleGet)
.POST(this.mcpEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON),
this::handlePost)
.DELETE(this.mcpEndpoint, this::handleDelete)
.build();

Expand Down Expand Up @@ -216,11 +218,6 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
McpTransportContext transportContext = this.contextExtractor.extract(request);

return Mono.defer(() -> {
List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) {
return ServerResponse.badRequest().build();
}

if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) {
return ServerResponse.badRequest().build(); // TODO: say we need a session
// id
Expand Down Expand Up @@ -279,12 +276,6 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {

McpTransportContext transportContext = this.contextExtractor.extract(request);

List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)
&& acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) {
return ServerResponse.badRequest().build();
}

return request.bodyToMono(String.class).<ServerResponse>flatMap(body -> {
try {
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, body);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
import reactor.core.publisher.Mono;

import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.web.servlet.function.RequestPredicates;
import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.RouterFunctions;
import org.springframework.web.servlet.function.ServerRequest;
Expand Down Expand Up @@ -175,8 +177,8 @@ private WebMvcSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUr
this.contextExtractor = contextExtractor;
this.securityValidator = securityValidator;
this.routerFunction = RouterFunctions.route()
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
.GET(this.sseEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM), this::handleSseConnection)
.POST(this.messageEndpoint, RequestPredicates.accept(MediaType.APPLICATION_JSON), this::handleMessage)
.build();

if (keepAliveInterval != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.springframework.ai.mcp.server.webmvc.transport;

import java.io.IOException;
import java.util.List;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.json.McpJsonDefaults;
Expand All @@ -37,6 +36,7 @@

import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.web.servlet.function.RequestPredicates;
import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.RouterFunctions;
import org.springframework.web.servlet.function.ServerRequest;
Expand Down Expand Up @@ -85,8 +85,8 @@ private WebMvcStatelessServerTransport(McpJsonMapper jsonMapper, String mcpEndpo
this.contextExtractor = contextExtractor;
this.securityValidator = securityValidator;
this.routerFunction = RouterFunctions.route()
.GET(this.mcpEndpoint, this::handleGet)
.POST(this.mcpEndpoint, this::handlePost)
.POST(this.mcpEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON),
this::handlePost)
.build();
}

Expand Down Expand Up @@ -116,10 +116,6 @@ public RouterFunction<ServerResponse> getRouterFunction() {
return this.routerFunction;
}

private ServerResponse handleGet(ServerRequest request) {
return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build();
}

private ServerResponse handlePost(ServerRequest request) {
if (this.isClosing) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
Expand All @@ -136,12 +132,6 @@ private ServerResponse handlePost(ServerRequest request) {

McpTransportContext transportContext = this.contextExtractor.extract(request);

List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON)
&& acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) {
return ServerResponse.badRequest().build();
}

var handler = this.mcpHandler;
if (handler == null) {
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.web.servlet.function.RequestPredicates;
import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.RouterFunctions;
import org.springframework.web.servlet.function.ServerRequest;
Expand Down Expand Up @@ -151,8 +152,9 @@ private WebMvcStreamableServerTransportProvider(McpJsonMapper jsonMapper, String
this.contextExtractor = contextExtractor;
this.securityValidator = securityValidator;
this.routerFunction = RouterFunctions.route()
.GET(this.mcpEndpoint, this::handleGet)
.POST(this.mcpEndpoint, this::handlePost)
.GET(this.mcpEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM), this::handleGet)
.POST(this.mcpEndpoint, RequestPredicates.accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON),
this::handlePost)
.DELETE(this.mcpEndpoint, this::handleDelete)
.build();

Expand Down Expand Up @@ -280,11 +282,6 @@ private ServerResponse handleGet(ServerRequest request) {
return ServerResponse.status(e.getStatusCode()).body(message);
}

List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) {
return ServerResponse.badRequest().body("Invalid Accept header. Expected TEXT_EVENT_STREAM");
}

McpTransportContext transportContext = this.contextExtractor.extract(request);

if (request.headers().header(HttpHeaders.MCP_SESSION_ID).isEmpty()) {
Expand Down Expand Up @@ -369,15 +366,6 @@ private ServerResponse handlePost(ServerRequest request) {
return ServerResponse.status(e.getStatusCode()).body(message);
}

List<MediaType> acceptHeaders = request.headers().asHttpHeaders().getAccept();
if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)
|| !acceptHeaders.contains(MediaType.APPLICATION_JSON)) {
return ServerResponse.badRequest()
.body(McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND)
.message("Invalid Accept headers. Expected TEXT_EVENT_STREAM and APPLICATION_JSON")
.build());
}

McpTransportContext transportContext = this.contextExtractor.extract(request);

try {
Expand Down