-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathproxy.py
More file actions
67 lines (53 loc) · 2.1 KB
/
proxy.py
File metadata and controls
67 lines (53 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""Provide utilities for proxying messages between two MCP transports."""
from __future__ import annotations
import inspect
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
import anyio
from mcp.shared._stream_protocols import ReadStream, WriteStream
from mcp.shared.message import SessionMessage
# A transport endpoint expressed as a ``(read_stream, write_stream)`` pair: the
# read side yields ``SessionMessage`` objects or ``Exception`` items, the write
# side accepts ``SessionMessage`` objects.
MessageStream = tuple[ReadStream[SessionMessage | Exception], WriteStream[SessionMessage]]
# Callback that receives an ``Exception``; it may be a plain function returning
# ``None`` or return an awaitable that resolves to ``None``.
ErrorHandler = Callable[[Exception], None | Awaitable[None]]
@asynccontextmanager
async def mcp_proxy(
    transport_to_client: MessageStream,
    transport_to_server: MessageStream,
    on_error: ErrorHandler | None = None,
) -> AsyncGenerator[None]:
    """Relay messages in both directions between two MCP transports.

    While the context is active, two background tasks run: one copies what is
    read from the client transport onto the server transport, the other copies
    in the opposite direction. Both tasks are cancelled when the context exits.

    Args:
        transport_to_client: ``(read, write)`` stream pair facing the client.
        transport_to_server: ``(read, write)`` stream pair facing the server.
        on_error: Optional callback passed to each forwarding task; it is
            notified about ``Exception`` items found on a read stream.
    """
    from_client, to_client = transport_to_client
    from_server, to_server = transport_to_server

    # Each route is (where to read, where to write); client->server is started first.
    routes = (
        (from_client, to_server),
        (from_server, to_client),
    )

    async with anyio.create_task_group() as forwarders:
        for source, sink in routes:
            forwarders.start_soon(_forward_messages, source, sink, on_error)
        try:
            yield
        finally:
            # Stop both forwarding loops as soon as the caller leaves the context.
            forwarders.cancel_scope.cancel()
async def _forward_messages(
    read_stream: ReadStream[SessionMessage | Exception],
    write_stream: WriteStream[SessionMessage],
    on_error: ErrorHandler | None,
) -> None:
    """Copy messages from ``read_stream`` to ``write_stream`` until a side goes away.

    Both streams are entered as async context managers for the duration of the
    loop. ``Exception`` items found on the read stream are reported through
    ``on_error`` and are not forwarded.

    Args:
        read_stream: Source of ``SessionMessage`` objects and ``Exception`` items.
        write_stream: Destination that receives every forwarded ``SessionMessage``.
        on_error: Optional sync or async callback for ``Exception`` items.
    """
    # A stream closed locally (ClosedResourceError) or rendered unusable because
    # its peer went away (BrokenResourceError) is a normal way for a proxied
    # connection to end, so neither should crash the enclosing task group.
    stream_gone = (anyio.ClosedResourceError, anyio.BrokenResourceError)

    try:
        async with write_stream, read_stream:
            async for item in read_stream:
                if isinstance(item, Exception):
                    await _run_error_handler(item, on_error)
                    continue
                try:
                    await write_stream.send(item)
                except stream_gone:
                    # The destination is gone; nothing more can be delivered.
                    break
    except stream_gone:
        # Raised while receiving or while entering/leaving the streams.
        return
async def _run_error_handler(error: Exception, on_error: ErrorHandler | None) -> None:
    """Invoke the optional error callback without letting it break the caller.

    The callback may be synchronous or return an awaitable; an awaitable result
    is awaited before returning. A failure inside the callback is logged with
    its traceback and then suppressed, so a faulty handler never propagates an
    ``Exception`` to the caller.

    Args:
        error: The exception to report.
        on_error: Callback to notify, or ``None`` to do nothing.
    """
    if on_error is None:
        return
    try:
        outcome = on_error(error)
        # Async handlers hand back an awaitable that only runs once awaited.
        if inspect.isawaitable(outcome):
            await outcome
    except Exception:
        # Reporting is best-effort: record the handler's failure instead of
        # swallowing it silently, but never re-raise.
        import logging

        logging.getLogger(__name__).exception(
            "Error handler raised while handling %r", error
        )