|
64 | 64 | logger = logging.getLogger("google_adk." + __name__) |
65 | 65 |
|
66 | 66 |
|
67 | | -def _validate_header_value(state_key: str, value: Any) -> None: |
68 | | - """Validates that a state value is suitable for use in a header.""" |
| 67 | +def _validate_header_value( |
| 68 | + state_key: str, value: Any, strict: bool = False |
| 69 | +) -> None: |
| 70 | + """Validates that a state value is suitable for use in a header. |
| 71 | +
|
| 72 | + Args: |
| 73 | + state_key: The key being validated. |
| 74 | + value: The value to validate. |
| 75 | + strict: If True, raises ValueError for non-primitive types. |
| 76 | +
|
| 77 | + Raises: |
| 78 | + ValueError: If strict=True and value is not a primitive type. |
| 79 | + """ |
69 | 80 | if not isinstance(value, (str, int, float, bool)): |
70 | | - logger.warning( |
71 | | - 'Value for state key "%s" is of type %s, which may not serialize' |
72 | | - ' correctly into a header. Consider pre-serializing complex values or' |
73 | | - ' using state_header_format.', |
74 | | - state_key, |
75 | | - type(value).__name__, |
| 81 | + msg = ( |
| 82 | + f'Value for state key "{state_key}" is of type' |
| 83 | + f' {type(value).__name__}, which may not serialize correctly into a' |
| 84 | + ' header. Consider pre-serializing complex values or using' |
| 85 | + ' state_header_format.' |
76 | 86 | ) |
| 87 | + if strict: |
| 88 | + raise ValueError(msg) |
| 89 | + logger.warning(msg) |
77 | 90 |
|
78 | 91 |
|
79 | 92 | def create_session_state_header_provider( |
80 | 93 | state_key: str, |
81 | 94 | header_name: str = "Authorization", |
82 | 95 | header_format: str = "Bearer {value}", |
83 | 96 | default_value: Optional[str] = None, |
| 97 | + strict: bool = False, |
84 | 98 | ) -> HeaderProvider: |
85 | 99 | """Creates a header provider that extracts values from session state. |
86 | 100 |
|
87 | 101 | This utility function generates a header_provider callable that can be used |
88 | 102 | with McpToolset to automatically extract values from the session state and |
89 | 103 | format them as HTTP headers for MCP server connections. |
90 | 104 |
|
| 105 | + .. warning:: |
| 106 | + **Security Best Practice**: For sensitive, short-lived tokens like JWTs, |
| 107 | + use ``request_state`` instead of ``session.state`` to avoid persisting |
| 108 | + sensitive data to the database. Pass tokens via |
| 109 | + ``RunAgentRequest.request_state``, which will override ``session.state`` |
| 110 | + for the duration of the request without being persisted. |
| 111 | +
|
91 | 112 | Args: |
92 | | - state_key: The key to look up in session.state. |
| 113 | + state_key: The key to look up in session.state (or request_state). |
93 | 114 | header_name: The HTTP header name to set (default: 'Authorization'). |
94 | 115 | header_format: Format string for the header value. Use {value} as a |
95 | 116 | placeholder for the state value (default: 'Bearer {value}'). |
96 | 117 | default_value: Default value if state_key is not found in session state. |
97 | 118 | If None, the header is omitted when the key is missing. |
| 119 | + strict: If True, raises ValueError when non-primitive types are |
| 120 | + encountered. If False (default), logs a warning instead. |
98 | 121 |
|
99 | 122 | Returns: |
100 | 123 | A callable that takes a ReadonlyContext and returns a dictionary of |
101 | 124 | headers to be used for the MCP session. |
102 | 125 |
|
| 126 | + Raises: |
| 127 | + ValueError: If strict=True and a non-primitive type is found in state. |
| 128 | +
|
103 | 129 | Example:: |
104 | 130 |
|
| 131 | + # Example 1: Using request_state for JWT tokens (recommended) |
105 | 132 | toolset = McpToolset( |
106 | 133 | connection_params=StreamableHTTPConnectionParams( |
107 | 134 | url="http://api.example.com/mcp" |
108 | 135 | ), |
109 | 136 | header_provider=create_session_state_header_provider( |
110 | | - state_key="jwt_token", |
| 137 | + state_key="jwt_token", # Will read from request_state first |
111 | 138 | header_name="Authorization", |
112 | 139 | header_format="Bearer {value}" |
113 | 140 | ) |
114 | 141 | ) |
| 142 | +
|
| 143 | + # Client sends request with ephemeral JWT |
| 144 | + response = await agent.run( |
| 145 | + RunAgentRequest( |
| 146 | + session_id="user-123", |
| 147 | + request_state={"jwt_token": "eyJhbG..."} # Ephemeral, not persisted |
| 148 | + ) |
| 149 | + ) |
115 | 150 | """ |
116 | 151 |
|
117 | 152 | def provider(ctx: ReadonlyContext) -> Dict[str, str]: |
118 | 153 | value = ctx.state.get(state_key, default_value) |
119 | | - if value is None: |
| 154 | + # Skip header if value is None or empty string |
| 155 | + if value is None or value == "": |
120 | 156 | return {} |
121 | | - _validate_header_value(state_key, value) |
| 157 | + _validate_header_value(state_key, value, strict=strict) |
122 | 158 | formatted_value = header_format.format(value=value) |
123 | 159 | return {header_name: formatted_value} |
124 | 160 |
|
|
0 commit comments