Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions api/core/tools/custom_tool/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,27 @@ def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
elif not isinstance(credentials["api_key_value"], str):
raise ToolProviderCredentialValidationError("api_key_value must be a string")

# NOTE: `ToolRuntime` can be shared across repeated tool invocations (multi-step agents, retries, parallel
# calls). Do not mutate `runtime.credentials` when assembling request headers.
api_key_value = credentials["api_key_value"]
if "api_key_header_prefix" in credentials:
api_key_header_prefix = credentials["api_key_header_prefix"]
if api_key_header_prefix == "basic" and credentials["api_key_value"]:
credentials["api_key_value"] = f"Basic {credentials['api_key_value']}"
elif api_key_header_prefix == "bearer" and credentials["api_key_value"]:
credentials["api_key_value"] = f"Bearer {credentials['api_key_value']}"
if (
api_key_header_prefix == "basic"
and api_key_value
and not api_key_value.lower().startswith("basic ")
):
api_key_value = f"Basic {api_key_value}"
elif (
api_key_header_prefix == "bearer"
and api_key_value
and not api_key_value.lower().startswith("bearer ")
):
api_key_value = f"Bearer {api_key_value}"
elif api_key_header_prefix == "custom":
pass

headers[api_key_header] = credentials["api_key_value"]
headers[api_key_header] = api_key_value

elif credentials["auth_type"] == "api_key_query":
# For query parameter authentication, we don't add anything to headers
Expand Down
31 changes: 31 additions & 0 deletions api/tests/unit_tests/core/tools/test_custom_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,37 @@ def test_assembling_request_auth_header_assembly():
assert tool.assembling_request(parameters={}) == {}


def test_assembling_request_does_not_mutate_runtime_credentials_and_avoids_double_prefix():
tool = _build_tool()

tool.runtime.credentials = {
"auth_type": "api_key_header",
"api_key_header_prefix": "bearer",
"api_key_value": "abc",
}
headers1 = tool.assembling_request(parameters={})
headers2 = tool.assembling_request(parameters={})
assert headers1["Authorization"] == "Bearer abc"
assert headers2["Authorization"] == "Bearer abc"
assert tool.runtime.credentials["api_key_value"] == "abc"

tool.runtime.credentials = {
"auth_type": "api_key_header",
"api_key_header_prefix": "bearer",
"api_key_value": "Bearer already-prefixed",
}
headers = tool.assembling_request(parameters={})
assert headers["Authorization"] == "Bearer already-prefixed"

tool.runtime.credentials = {
"auth_type": "api_key_header",
"api_key_header_prefix": "basic",
"api_key_value": "Basic already-prefixed",
}
headers = tool.assembling_request(parameters={})
assert headers["Authorization"] == "Basic already-prefixed"


def test_assembling_request_runtime_auth_errors():
tool = _build_tool()

Expand Down