|
17 | 17 | import asyncio |
18 | 18 | import logging |
19 | 19 | import sys |
| 20 | +from typing import Any |
20 | 21 | from typing import Callable |
21 | 22 | from typing import Dict |
22 | 23 | from typing import List |
|
57 | 58 | raise e |
58 | 59 |
|
59 | 60 | from .mcp_tool import MCPTool |
| 61 | +from .types import HeaderProvider |
| 62 | + |
60 | 63 |
|
61 | 64 | logger = logging.getLogger("google_adk." + __name__) |
62 | 65 |
|
63 | 66 |
|
| 67 | +def _validate_header_value(state_key: str, value: Any) -> None: |
| 68 | + """Validates that a state value is suitable for use in a header.""" |
| 69 | + if not isinstance(value, (str, int, float, bool)): |
| 70 | + logger.warning( |
| 71 | + 'Value for state key "%s" is of type %s, which may not serialize' |
| 72 | + ' correctly into a header. Consider pre-serializing complex values or' |
| 73 | + ' using state_header_format.', |
| 74 | + state_key, |
| 75 | + type(value).__name__, |
| 76 | + ) |
| 77 | + |
| 78 | + |
64 | 79 | def create_session_state_header_provider( |
65 | 80 | state_key: str, |
66 | 81 | header_name: str = "Authorization", |
67 | 82 | header_format: str = "Bearer {value}", |
68 | 83 | default_value: Optional[str] = None, |
69 | | -) -> Callable[[ReadonlyContext], Dict[str, str]]: |
| 84 | +) -> HeaderProvider: |
70 | 85 | """Creates a header provider that extracts values from session state. |
71 | 86 |
|
72 | 87 | This utility function generates a header_provider callable that can be used |
@@ -103,20 +118,34 @@ def provider(ctx: ReadonlyContext) -> Dict[str, str]: |
103 | 118 | value = ctx.state.get(state_key, default_value) |
104 | 119 | if value is None: |
105 | 120 | return {} |
106 | | - if not isinstance(value, (str, int, float, bool)): |
107 | | - logger.warning( |
108 | | - 'Value for state key "%s" is of type %s, which may not serialize' |
109 | | - ' correctly into a header. Consider pre-serializing complex values or' |
110 | | - ' using a different header_format.', |
111 | | - state_key, |
112 | | - type(value).__name__, |
113 | | - ) |
| 121 | + _validate_header_value(state_key, value) |
114 | 122 | formatted_value = header_format.format(value=value) |
115 | 123 | return {header_name: formatted_value} |
116 | 124 |
|
117 | 125 | return provider |
118 | 126 |
|
119 | 127 |
|
| 128 | +def create_combined_header_provider( |
| 129 | + providers: List[HeaderProvider], |
| 130 | +) -> HeaderProvider: |
| 131 | + """Creates a header provider that combines multiple providers. |
| 132 | +
|
| 133 | + Args: |
| 134 | + providers: A list of header providers to combine. |
| 135 | +
|
| 136 | + Returns: |
| 137 | + A single header provider that merges the results of all input providers. |
| 138 | + """ |
| 139 | + |
| 140 | + def combined_provider(ctx: ReadonlyContext) -> Dict[str, str]: |
| 141 | + headers = {} |
| 142 | + for provider in providers: |
| 143 | + headers.update(provider(ctx)) |
| 144 | + return headers |
| 145 | + |
| 146 | + return combined_provider |
| 147 | + |
| 148 | + |
120 | 149 | class McpToolset(BaseToolset): |
121 | 150 | """Connects to a MCP Server, and retrieves MCP Tools into ADK Tools. |
122 | 151 |
|
@@ -162,9 +191,7 @@ def __init__( |
162 | 191 | auth_scheme: Optional[AuthScheme] = None, |
163 | 192 | auth_credential: Optional[AuthCredential] = None, |
164 | 193 | require_confirmation: Union[bool, Callable[..., bool]] = False, |
165 | | - header_provider: Optional[ |
166 | | - Callable[[ReadonlyContext], Dict[str, str]] |
167 | | - ] = None, |
| 194 | + header_provider: Optional[HeaderProvider] = None, |
168 | 195 | ): |
169 | 196 | """Initializes the McpToolset. |
170 | 197 |
|
@@ -298,30 +325,17 @@ def from_config( |
298 | 325 | state_mapping = mcp_toolset_config.state_header_mapping |
299 | 326 | state_format = mcp_toolset_config.state_header_format or {} |
300 | 327 |
|
301 | | - def config_based_header_provider( |
302 | | - ctx: ReadonlyContext, |
303 | | - ) -> Dict[str, str]: |
304 | | - headers = {} |
305 | | - for state_key, header_name in state_mapping.items(): |
306 | | - value = ctx.state.get(state_key) |
307 | | - if value is not None: |
308 | | - if not isinstance(value, (str, int, float, bool)): |
309 | | - logger.warning( |
310 | | - 'Value for state key "%s" is of type %s, which may not' |
311 | | - ' serialize correctly into a header. Consider pre-serializing' |
312 | | - ' complex values or using a different header_format.', |
313 | | - state_key, |
314 | | - type(value).__name__, |
315 | | - ) |
316 | | - # Apply formatting if specified for this header |
317 | | - if header_name in state_format: |
318 | | - formatted_value = state_format[header_name].format(value=value) |
319 | | - else: |
320 | | - formatted_value = str(value) |
321 | | - headers[header_name] = formatted_value |
322 | | - return headers |
323 | | - |
324 | | - header_provider = config_based_header_provider |
| 328 | + providers = [ |
| 329 | + create_session_state_header_provider( |
| 330 | + state_key=state_key, |
| 331 | + header_name=header_name, |
| 332 | + header_format=state_format.get(header_name, "{value}"), |
| 333 | + default_value=None, |
| 334 | + ) |
| 335 | + for state_key, header_name in state_mapping.items() |
| 336 | + ] |
| 337 | + |
| 338 | + header_provider = create_combined_header_provider(providers) |
325 | 339 |
|
326 | 340 | return cls( |
327 | 341 | connection_params=connection_params, |
|
0 commit comments