Skip to content

Commit 35de869

Browse files
committed
refactor: Introduce HeaderProvider type and modularize header generation logic in MCP tools.
1 parent 8b5c296 commit 35de869

3 files changed

Lines changed: 72 additions & 39 deletions

File tree

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from ..base_authenticated_tool import BaseAuthenticatedTool
5757
# import
5858
from ..tool_context import ToolContext
59+
from .types import HeaderProvider
5960

6061
logger = logging.getLogger("google_adk." + __name__)
6162

@@ -78,9 +79,7 @@ def __init__(
7879
auth_scheme: Optional[AuthScheme] = None,
7980
auth_credential: Optional[AuthCredential] = None,
8081
require_confirmation: Union[bool, Callable[..., bool]] = False,
81-
header_provider: Optional[
82-
Callable[[ReadonlyContext], Dict[str, str]]
83-
] = None,
82+
header_provider: Optional[HeaderProvider] = None,
8483
):
8584
"""Initializes an McpTool.
8685

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 50 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import asyncio
1818
import logging
1919
import sys
20+
from typing import Any
2021
from typing import Callable
2122
from typing import Dict
2223
from typing import List
@@ -57,16 +58,30 @@
5758
raise e
5859

5960
from .mcp_tool import MCPTool
61+
from .types import HeaderProvider
62+
6063

6164
logger = logging.getLogger("google_adk." + __name__)
6265

6366

67+
def _validate_header_value(state_key: str, value: Any) -> None:
68+
"""Validates that a state value is suitable for use in a header."""
69+
if not isinstance(value, (str, int, float, bool)):
70+
logger.warning(
71+
'Value for state key "%s" is of type %s, which may not serialize'
72+
' correctly into a header. Consider pre-serializing complex values or'
73+
' using state_header_format.',
74+
state_key,
75+
type(value).__name__,
76+
)
77+
78+
6479
def create_session_state_header_provider(
6580
state_key: str,
6681
header_name: str = "Authorization",
6782
header_format: str = "Bearer {value}",
6883
default_value: Optional[str] = None,
69-
) -> Callable[[ReadonlyContext], Dict[str, str]]:
84+
) -> HeaderProvider:
7085
"""Creates a header provider that extracts values from session state.
7186
7287
This utility function generates a header_provider callable that can be used
@@ -103,20 +118,34 @@ def provider(ctx: ReadonlyContext) -> Dict[str, str]:
103118
value = ctx.state.get(state_key, default_value)
104119
if value is None:
105120
return {}
106-
if not isinstance(value, (str, int, float, bool)):
107-
logger.warning(
108-
'Value for state key "%s" is of type %s, which may not serialize'
109-
' correctly into a header. Consider pre-serializing complex values or'
110-
' using a different header_format.',
111-
state_key,
112-
type(value).__name__,
113-
)
121+
_validate_header_value(state_key, value)
114122
formatted_value = header_format.format(value=value)
115123
return {header_name: formatted_value}
116124

117125
return provider
118126

119127

128+
def create_combined_header_provider(
129+
providers: List[HeaderProvider],
130+
) -> HeaderProvider:
131+
"""Creates a header provider that combines multiple providers.
132+
133+
Args:
134+
providers: A list of header providers to combine.
135+
136+
Returns:
137+
A single header provider that merges the results of all input providers.
138+
"""
139+
140+
def combined_provider(ctx: ReadonlyContext) -> Dict[str, str]:
141+
headers = {}
142+
for provider in providers:
143+
headers.update(provider(ctx))
144+
return headers
145+
146+
return combined_provider
147+
148+
120149
class McpToolset(BaseToolset):
121150
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
122151
@@ -162,9 +191,7 @@ def __init__(
162191
auth_scheme: Optional[AuthScheme] = None,
163192
auth_credential: Optional[AuthCredential] = None,
164193
require_confirmation: Union[bool, Callable[..., bool]] = False,
165-
header_provider: Optional[
166-
Callable[[ReadonlyContext], Dict[str, str]]
167-
] = None,
194+
header_provider: Optional[HeaderProvider] = None,
168195
):
169196
"""Initializes the McpToolset.
170197
@@ -298,30 +325,17 @@ def from_config(
298325
state_mapping = mcp_toolset_config.state_header_mapping
299326
state_format = mcp_toolset_config.state_header_format or {}
300327

301-
def config_based_header_provider(
302-
ctx: ReadonlyContext,
303-
) -> Dict[str, str]:
304-
headers = {}
305-
for state_key, header_name in state_mapping.items():
306-
value = ctx.state.get(state_key)
307-
if value is not None:
308-
if not isinstance(value, (str, int, float, bool)):
309-
logger.warning(
310-
'Value for state key "%s" is of type %s, which may not'
311-
' serialize correctly into a header. Consider pre-serializing'
312-
' complex values or using a different header_format.',
313-
state_key,
314-
type(value).__name__,
315-
)
316-
# Apply formatting if specified for this header
317-
if header_name in state_format:
318-
formatted_value = state_format[header_name].format(value=value)
319-
else:
320-
formatted_value = str(value)
321-
headers[header_name] = formatted_value
322-
return headers
323-
324-
header_provider = config_based_header_provider
328+
providers = [
329+
create_session_state_header_provider(
330+
state_key=state_key,
331+
header_name=header_name,
332+
header_format=state_format.get(header_name, "{value}"),
333+
default_value=None,
334+
)
335+
for state_key, header_name in state_mapping.items()
336+
]
337+
338+
header_provider = create_combined_header_provider(providers)
325339

326340
return cls(
327341
connection_params=connection_params,
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Callable
16+
from typing import Dict
17+
18+
from ...agents.readonly_context import ReadonlyContext
19+
20+
HeaderProvider = Callable[[ReadonlyContext], Dict[str, str]]

0 commit comments

Comments
 (0)