diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 168e5f4493a1e8..f5d936e140017d 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -100,16 +100,27 @@ def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: elif not isinstance(credentials["api_key_value"], str): raise ToolProviderCredentialValidationError("api_key_value must be a string") + # NOTE: `ToolRuntime` can be shared across repeated tool invocations (multi-step agents, retries, parallel + # calls). Do not mutate `runtime.credentials` when assembling request headers. + api_key_value = credentials["api_key_value"] if "api_key_header_prefix" in credentials: api_key_header_prefix = credentials["api_key_header_prefix"] - if api_key_header_prefix == "basic" and credentials["api_key_value"]: - credentials["api_key_value"] = f"Basic {credentials['api_key_value']}" - elif api_key_header_prefix == "bearer" and credentials["api_key_value"]: - credentials["api_key_value"] = f"Bearer {credentials['api_key_value']}" + if ( + api_key_header_prefix == "basic" + and api_key_value + and not api_key_value.lower().startswith("basic ") + ): + api_key_value = f"Basic {api_key_value}" + elif ( + api_key_header_prefix == "bearer" + and api_key_value + and not api_key_value.lower().startswith("bearer ") + ): + api_key_value = f"Bearer {api_key_value}" elif api_key_header_prefix == "custom": pass - headers[api_key_header] = credentials["api_key_value"] + headers[api_key_header] = api_key_value elif credentials["auth_type"] == "api_key_query": # For query parameter authentication, we don't add anything to headers diff --git a/api/tests/unit_tests/core/tools/test_custom_tool.py b/api/tests/unit_tests/core/tools/test_custom_tool.py index f525baeaf237c5..fd79be528d0085 100644 --- a/api/tests/unit_tests/core/tools/test_custom_tool.py +++ b/api/tests/unit_tests/core/tools/test_custom_tool.py @@ -91,6 +91,37 @@ def test_assembling_request_auth_header_assembly(): assert tool.assembling_request(parameters={}) == {} +def test_assembling_request_does_not_mutate_runtime_credentials_and_avoids_double_prefix(): + tool = _build_tool() + + tool.runtime.credentials = { + "auth_type": "api_key_header", + "api_key_header_prefix": "bearer", + "api_key_value": "abc", + } + headers1 = tool.assembling_request(parameters={}) + headers2 = tool.assembling_request(parameters={}) + assert headers1["Authorization"] == "Bearer abc" + assert headers2["Authorization"] == "Bearer abc" + assert tool.runtime.credentials["api_key_value"] == "abc" + + tool.runtime.credentials = { + "auth_type": "api_key_header", + "api_key_header_prefix": "bearer", + "api_key_value": "Bearer already-prefixed", + } + headers = tool.assembling_request(parameters={}) + assert headers["Authorization"] == "Bearer already-prefixed" + + tool.runtime.credentials = { + "auth_type": "api_key_header", + "api_key_header_prefix": "basic", + "api_key_value": "Basic already-prefixed", + } + headers = tool.assembling_request(parameters={}) + assert headers["Authorization"] == "Basic already-prefixed" + + def test_assembling_request_runtime_auth_errors(): tool = _build_tool()