Skip to content

Commit f5ca9e2

Browse files
committed
feat: Add strict mode to create_session_state_header_provider to raise errors for non-primitive types and skip empty string values.
1 parent 35de869 commit f5ca9e2

2 files changed

Lines changed: 96 additions & 12 deletions

File tree

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,61 +64,97 @@
6464
logger = logging.getLogger("google_adk." + __name__)
6565

6666

67-
def _validate_header_value(state_key: str, value: Any) -> None:
68-
"""Validates that a state value is suitable for use in a header."""
67+
def _validate_header_value(
68+
state_key: str, value: Any, strict: bool = False
69+
) -> None:
70+
"""Validates that a state value is suitable for use in a header.
71+
72+
Args:
73+
state_key: The key being validated.
74+
value: The value to validate.
75+
strict: If True, raises ValueError for non-primitive types.
76+
77+
Raises:
78+
ValueError: If strict=True and value is not a primitive type.
79+
"""
6980
if not isinstance(value, (str, int, float, bool)):
70-
logger.warning(
71-
'Value for state key "%s" is of type %s, which may not serialize'
72-
' correctly into a header. Consider pre-serializing complex values or'
73-
' using state_header_format.',
74-
state_key,
75-
type(value).__name__,
81+
msg = (
82+
f'Value for state key "{state_key}" is of type'
83+
f' {type(value).__name__}, which may not serialize correctly into a'
84+
' header. Consider pre-serializing complex values or using'
85+
' state_header_format.'
7686
)
87+
if strict:
88+
raise ValueError(msg)
89+
logger.warning(msg)
7790

7891

7992
def create_session_state_header_provider(
8093
state_key: str,
8194
header_name: str = "Authorization",
8295
header_format: str = "Bearer {value}",
8396
default_value: Optional[str] = None,
97+
strict: bool = False,
8498
) -> HeaderProvider:
8599
"""Creates a header provider that extracts values from session state.
86100
87101
This utility function generates a header_provider callable that can be used
88102
with McpToolset to automatically extract values from the session state and
89103
format them as HTTP headers for MCP server connections.
90104
105+
.. warning::
106+
**Security Best Practice**: For sensitive, short-lived tokens like JWTs,
107+
use ``request_state`` instead of ``session.state`` to avoid persisting
108+
sensitive data to the database. Pass tokens via
109+
``RunAgentRequest.request_state``, which will override ``session.state``
110+
for the duration of the request without being persisted.
111+
91112
Args:
92-
state_key: The key to look up in session.state.
113+
state_key: The key to look up in session.state (or request_state).
93114
header_name: The HTTP header name to set (default: 'Authorization').
94115
header_format: Format string for the header value. Use {value} as a
95116
placeholder for the state value (default: 'Bearer {value}').
96117
default_value: Default value if state_key is not found in session state.
97118
If None, the header is omitted when the key is missing.
119+
strict: If True, raises ValueError when non-primitive types are
120+
encountered. If False (default), logs a warning instead.
98121
99122
Returns:
100123
A callable that takes a ReadonlyContext and returns a dictionary of
101124
headers to be used for the MCP session.
102125
126+
Raises:
127+
ValueError: If strict=True and a non-primitive type is found in state.
128+
103129
Example::
104130
131+
# Example 1: Using request_state for JWT tokens (recommended)
105132
toolset = McpToolset(
106133
connection_params=StreamableHTTPConnectionParams(
107134
url="http://api.example.com/mcp"
108135
),
109136
header_provider=create_session_state_header_provider(
110-
state_key="jwt_token",
137+
state_key="jwt_token", # Will read from request_state first
111138
header_name="Authorization",
112139
header_format="Bearer {value}"
113140
)
114141
)
142+
143+
# Client sends request with ephemeral JWT
144+
response = await agent.run(
145+
RunAgentRequest(
146+
session_id="user-123",
147+
request_state={"jwt_token": "eyJhbG..."} # Ephemeral, not persisted
148+
)
149+
)
115150
"""
116151

117152
def provider(ctx: ReadonlyContext) -> Dict[str, str]:
118153
value = ctx.state.get(state_key, default_value)
119-
if value is None:
154+
# Skip header if value is None or empty string
155+
if value is None or value == "":
120156
return {}
121-
_validate_header_value(state_key, value)
157+
_validate_header_value(state_key, value, strict=strict)
122158
formatted_value = header_format.format(value=value)
123159
return {header_name: formatted_value}
124160

tests/unittests/tools/mcp_tool/test_jwt_token_propagation.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,54 @@ def test_none_value_in_state_returns_empty(self):
129129

130130
assert headers == {}
131131

132+
def test_empty_string_value_returns_empty(self):
133+
"""Test that empty string value in state returns empty dict."""
134+
mock_context = Mock(spec=ReadonlyContext)
135+
mock_context.state = {"jwt_token": ""}
136+
137+
provider = create_session_state_header_provider(state_key="jwt_token")
138+
139+
headers = provider(mock_context)
140+
141+
assert headers == {}
142+
143+
def test_strict_mode_with_primitive_types(self):
144+
"""Test that strict mode works properly with primitive types."""
145+
mock_context = Mock(spec=ReadonlyContext)
146+
147+
# Test with string
148+
mock_context.state = {"token": "my-token"}
149+
provider = create_session_state_header_provider(
150+
state_key="token", strict=True
151+
)
152+
headers = provider(mock_context)
153+
assert headers == {"Authorization": "Bearer my-token"}
154+
155+
# Test with int
156+
mock_context.state = {"count": 42}
157+
provider = create_session_state_header_provider(
158+
state_key="count", header_name="X-Count", header_format="{value}", strict=True
159+
)
160+
headers = provider(mock_context)
161+
assert headers == {"X-Count": "42"}
162+
163+
def test_strict_mode_raises_on_non_primitive_types(self):
164+
"""Test that strict mode raises ValueError for non-primitive types."""
165+
mock_context = Mock(spec=ReadonlyContext)
166+
mock_context.state = {"complex_data": {"nested": "dict"}}
167+
168+
provider = create_session_state_header_provider(
169+
state_key="complex_data", strict=True
170+
)
171+
172+
with pytest.raises(ValueError) as exc_info:
173+
provider(mock_context)
174+
175+
assert "complex_data" in str(exc_info.value)
176+
assert "dict" in str(exc_info.value)
177+
assert "may not serialize correctly" in str(exc_info.value)
178+
179+
132180

133181
class TestMcpToolsetConfigStateHeaderMapping:
134182
"""Test suite for state_header_mapping configuration."""

0 commit comments

Comments
 (0)