-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathtest_stdio.py
More file actions
128 lines (104 loc) · 5.02 KB
/
test_stdio.py
File metadata and controls
128 lines (104 loc) · 5.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import io
import sys
from io import TextIOWrapper
import anyio
import pytest
from mcp.server.stdio import stdio_server
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
@pytest.mark.anyio
async def test_stdio_server():
    """Round-trip JSON-RPC messages through stdio_server() using in-memory text streams.

    Messages written (one JSON document per line) to the fake stdin must arrive
    on the read stream, and messages sent on the write stream must appear on the
    fake stdout as one JSON document per line.
    """
    stdin = io.StringIO()
    stdout = io.StringIO()

    messages = [
        JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"),
        JSONRPCResponse(jsonrpc="2.0", id=2, result={}),
    ]
    for message in messages:
        stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n")
    stdin.seek(0)

    # Messages the server side will send back out over stdout.
    responses = [
        JSONRPCRequest(jsonrpc="2.0", id=3, method="ping"),
        JSONRPCResponse(jsonrpc="2.0", id=4, result={}),
    ]

    # Bound the whole exchange (as the other stdio tests do) so a stuck
    # reader/writer fails the test instead of hanging the suite.
    with anyio.fail_after(5):
        async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as (
            read_stream,
            write_stream,
        ):
            received_messages: list[JSONRPCMessage] = []
            async with read_stream:
                # Use a distinct name: items here are SessionMessage | Exception,
                # not the JSONRPC models bound to `message` above.
                async for incoming in read_stream:
                    if isinstance(incoming, Exception):  # pragma: no cover
                        raise incoming
                    received_messages.append(incoming.message)
                    if len(received_messages) == len(messages):
                        break

            # Verify received messages arrive intact and in order.
            assert received_messages == messages

            # Test sending responses from the server.
            async with write_stream:
                for response in responses:
                    await write_stream.send(SessionMessage(response))

    # The server context has exited, so every queued write has been performed.
    stdout.seek(0)
    output_lines = stdout.readlines()
    assert len(output_lines) == len(responses)

    received_responses = [jsonrpc_message_adapter.validate_json(line.strip()) for line in output_lines]
    assert received_responses == responses
@pytest.mark.anyio
async def test_stdio_server_no_crlf(monkeypatch: pytest.MonkeyPatch):
    """Raw bytes written to stdout must use LF (\\n), never CRLF (\\r\\n).

    On Windows, TextIOWrapper with the default newline=None translates \\n to
    \\r\\n on write, which corrupts NDJSON framing for JSON-RPC. The fix is to
    pass newline="" to TextIOWrapper so no translation occurs.

    The test goes through stdio_server()'s default path (it wraps
    sys.stdout.buffer itself), so it is the server's own TextIOWrapper
    configuration that is checked, not a wrapper built by the test. Newline
    translation only differs from LF where os.linesep is "\\r\\n" (Windows), so
    on other platforms this test cannot detect a regression.
    """

    class _UncloseableBytesIO(io.BytesIO):
        """BytesIO whose contents remain readable after a wrapper closes it."""

        def close(self) -> None:
            # The server's TextIOWrapper around sys.stdout.buffer closes its
            # buffer when it is finalized; ignore that so getvalue() still works.
            pass

    raw_stdout = _UncloseableBytesIO()

    # Expose in-memory byte buffers through sys.stdin/sys.stdout so that
    # stdio_server() with no arguments wraps them itself.
    monkeypatch.setattr(sys, "stdin", TextIOWrapper(io.BytesIO(b""), encoding="utf-8"))
    monkeypatch.setattr(sys, "stdout", TextIOWrapper(raw_stdout, encoding="utf-8"))

    message = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")

    with anyio.fail_after(5):
        async with stdio_server() as (read_stream, write_stream):
            async with write_stream:
                await write_stream.send(SessionMessage(message))
            async with read_stream:
                pass

    raw_bytes = raw_stdout.getvalue()
    assert len(raw_bytes) > 0, "expected output bytes"
    assert raw_bytes.endswith(b"\n"), "output must end with LF"
    assert b"\r\n" not in raw_bytes, "output must not contain CRLF"
@pytest.mark.anyio
async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch):
    """Ensure undecodable stdin bytes do not take down the stdio server.

    The bad bytes decode to U+FFFD replacement characters, the resulting line
    is not valid JSON, and that failure is delivered on the read stream as an
    exception object. A well-formed message on the following line must still
    be parsed and delivered.
    """
    ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")

    # 0xFF and 0xFE can never appear in well-formed UTF-8.
    bad_line = b"\xff\xfe\n"
    good_line = ping.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n"

    # With no arguments, stdio_server() wraps sys.stdin.buffer itself (using
    # errors='replace'), so feed our bytes in through a stand-in sys.stdin.
    fake_stdin = TextIOWrapper(io.BytesIO(bad_line + good_line), encoding="utf-8")
    monkeypatch.setattr(sys, "stdin", fake_stdin)
    monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8"))

    with anyio.fail_after(5):
        async with stdio_server() as (incoming, outgoing):
            # Nothing is sent in this test, so close the outbound side immediately.
            await outgoing.aclose()
            async with incoming:  # pragma: no branch
                # Line one: U+FFFD U+FFFD is not JSON -> surfaces as an exception item.
                bad = await incoming.receive()
                assert isinstance(bad, Exception)

                # Line two: parsing recovers and the real message comes through.
                good = await incoming.receive()
                assert isinstance(good, SessionMessage)
                assert good.message == ping