-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathtest_sampling_callback.py
More file actions
175 lines (149 loc) · 6.56 KB
/
test_sampling_callback.py
File metadata and controls
175 lines (149 loc) · 6.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import pytest
from mcp import Client
from mcp.client.session import ClientSession
from mcp.server.mcpserver import Context, MCPServer
from mcp.shared._context import RequestContext
from mcp.types import (
CreateMessageRequestParams,
CreateMessageResult,
CreateMessageResultWithTools,
SamplingMessage,
TextContent,
ToolUseContent,
)
@pytest.mark.anyio
async def test_sampling_callback():
    """Exercise a sampling tool with and without a client-side sampling callback.

    With a callback registered on the client, the tool call is expected to
    succeed and return "true". Without one, the call is expected to come back
    as an error result carrying the "Sampling not supported" message.
    """
    app = MCPServer("test")

    # Fixed result handed back by the client callback for every sampling request.
    canned_reply = CreateMessageResult(
        role="assistant",
        content=TextContent(type="text", text="This is a response from the sampling callback"),
        model="test-model",
        stop_reason="endTurn",
    )

    async def reply_with_canned_result(
        context: RequestContext[ClientSession],
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult:
        return canned_reply

    # NOTE: the tool function keeps its original name and gets no docstring,
    # since the server may introspect both when registering the tool.
    @app.tool("test_sampling")
    async def test_sampling_tool(message: str, ctx: Context) -> bool:
        # Forward the tool argument to the client as a single user message.
        user_turn = SamplingMessage(role="user", content=TextContent(type="text", text=message))
        reply = await ctx.session.create_message(messages=[user_turn], max_tokens=100)
        assert reply == canned_reply
        return True

    tool_args = {"message": "Test message for sampling"}

    # Client that supplies a sampling callback: the tool call should complete.
    async with Client(app, sampling_callback=reply_with_canned_result) as client:
        outcome = await client.call_tool("test_sampling", tool_args)
        assert outcome.is_error is False
        first_block = outcome.content[0]
        assert isinstance(first_block, TextContent)
        assert first_block.text == "true"

    # Client with no sampling callback: the same call should report an error.
    async with Client(app) as client:
        outcome = await client.call_tool("test_sampling", tool_args)
        assert outcome.is_error is True
        first_block = outcome.content[0]
        assert isinstance(first_block, TextContent)
        assert first_block.text == "Error executing tool test_sampling: Sampling not supported"
@pytest.mark.anyio
async def test_set_sampling_callback():
    """Check that the sampling callback can be installed and removed on a live session.

    The same tool is called three times on one client: before any callback is
    set (error expected), after installing a callback (success expected), and
    after resetting the callback to None (error expected again).
    """
    mcp_server = MCPServer("test")

    # Result the replacement callback returns for every sampling request.
    replacement_result = CreateMessageResult(
        role="assistant",
        content=TextContent(type="text", text="Updated response"),
        model="updated-model",
        stop_reason="endTurn",
    )

    async def updated_callback(
        context: RequestContext[ClientSession],
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult:
        return replacement_result

    @mcp_server.tool("do_sample")
    async def do_sample(message: str, ctx: Context) -> bool:
        user_turn = SamplingMessage(role="user", content=TextContent(type="text", text=message))
        reply = await ctx.session.create_message(messages=[user_turn], max_tokens=100)
        assert reply == replacement_result
        return True

    async with Client(mcp_server) as client:

        async def invoke_do_sample():
            # Every phase below issues exactly the same tool call.
            return await client.call_tool("do_sample", {"message": "test"})

        # Phase 1: no callback installed yet, so the call is rejected.
        outcome = await invoke_do_sample()
        assert outcome.is_error is True

        # Phase 2: install the callback; the call now goes through.
        client.session.set_sampling_callback(updated_callback)
        outcome = await invoke_do_sample()
        assert outcome.is_error is False
        first_block = outcome.content[0]
        assert isinstance(first_block, TextContent)
        assert first_block.text == "true"

        # Phase 3: clear the callback again; the call is rejected once more.
        client.session.set_sampling_callback(None)
        outcome = await invoke_do_sample()
        assert outcome.is_error is True
@pytest.mark.anyio
async def test_create_message_backwards_compat_single_content():
    """Test backwards compatibility: create_message without tools returns single content.

    The checks run inside the server-side tool, where the result of
    ``create_message`` is available. The outer test then confirms the tool
    call completed without error and returned "true".
    """
    server = MCPServer("test")

    # Callback returns single content (text)
    callback_return = CreateMessageResult(
        role="assistant",
        content=TextContent(type="text", text="Hello from LLM"),
        model="test-model",
        stop_reason="endTurn",
    )

    async def sampling_callback(
        context: RequestContext[ClientSession],
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult:
        return callback_return

    @server.tool("test_backwards_compat")
    async def test_tool(message: str, ctx: Context) -> bool:
        # Call create_message WITHOUT tools
        result = await ctx.session.create_message(
            messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))],
            max_tokens=100,
        )
        # Backwards compat: result should be CreateMessageResult
        assert isinstance(result, CreateMessageResult)
        # Content should be single (not a list) - this is the key backwards compat check
        assert isinstance(result.content, TextContent)
        assert result.content.text == "Hello from LLM"
        # CreateMessageResult should NOT have content_as_list (that's on WithTools).
        # Use a plain hasattr check: content_as_list is read as an attribute rather
        # than called, so an additional "or not callable(...)" escape hatch would let
        # this assertion pass even if the attribute were present.
        assert not hasattr(result, "content_as_list")
        return True

    async with Client(server, sampling_callback=sampling_callback) as client:
        result = await client.call_tool("test_backwards_compat", {"message": "Test"})
        # A failed assertion inside the tool is expected to surface here as an
        # error result rather than as an exception in the test body.
        assert result.is_error is False
        assert isinstance(result.content[0], TextContent)
        assert result.content[0].text == "true"
@pytest.mark.anyio
async def test_create_message_result_with_tools_type():
    """Verify content_as_list on CreateMessageResultWithTools for single and list content."""
    # Only the result type is exercised here; going through the create_message
    # overload would require client capability setup.
    lone_tool_use = ToolUseContent(type="tool_use", id="call_123", name="get_weather", input={"city": "SF"})
    single_block_result = CreateMessageResultWithTools(
        role="assistant",
        content=lone_tool_use,
        model="test-model",
        stop_reason="toolUse",
    )

    # A single content block is exposed as a one-element list.
    single_blocks = single_block_result.content_as_list
    assert len(single_blocks) == 1
    assert single_blocks[0].type == "tool_use"

    # Content that is already a list keeps its length and ordering.
    preamble = TextContent(type="text", text="Let me check the weather")
    follow_up_tool_use = ToolUseContent(type="tool_use", id="call_456", name="get_weather", input={"city": "NYC"})
    multi_block_result = CreateMessageResultWithTools(
        role="assistant",
        content=[preamble, follow_up_tool_use],
        model="test-model",
        stop_reason="toolUse",
    )

    multi_blocks = multi_block_result.content_as_list
    assert len(multi_blocks) == 2
    assert [block.type for block in multi_blocks] == ["text", "tool_use"]