diff --git a/tools/python/pyproject.toml b/tools/python/pyproject.toml index 74969cbe..c9585969 100644 --- a/tools/python/pyproject.toml +++ b/tools/python/pyproject.toml @@ -19,11 +19,13 @@ openai = ["openai>=1.0.0", "agents>=0.0.84"] langchain = ["langchain>=0.1.0"] crewai = ["crewai>=0.1.0"] strands = ["strands>=0.1.0"] +mcp-payments = ["stripe>=7.0.0"] all = [ "stripe-agent-toolkit[openai]", "stripe-agent-toolkit[langchain]", "stripe-agent-toolkit[crewai]", "stripe-agent-toolkit[strands]", + "stripe-agent-toolkit[mcp-payments]", ] dev = [ "pytest>=7.0.0", diff --git a/tools/python/stripe_agent_toolkit/mcp/__init__.py b/tools/python/stripe_agent_toolkit/mcp/__init__.py new file mode 100644 index 00000000..cc550e1f --- /dev/null +++ b/tools/python/stripe_agent_toolkit/mcp/__init__.py @@ -0,0 +1,5 @@ +"""MCP utilities for Stripe Agent Toolkit.""" + +from .register_paid_tool import register_paid_tool, PaidToolOptions + +__all__ = ["register_paid_tool", "PaidToolOptions"] diff --git a/tools/python/stripe_agent_toolkit/mcp/register_paid_tool.py b/tools/python/stripe_agent_toolkit/mcp/register_paid_tool.py new file mode 100644 index 00000000..1fa60542 --- /dev/null +++ b/tools/python/stripe_agent_toolkit/mcp/register_paid_tool.py @@ -0,0 +1,317 @@ +"""Register a paid MCP tool with Stripe Checkout gating.""" + +from __future__ import annotations + +import inspect +import json +from typing import Any, Callable, Optional +from typing_extensions import TypedDict + +try: + import stripe +except ImportError: # pragma: no cover - exercised via runtime use + stripe = None # type: ignore[assignment] + + +class PaidToolOptions(TypedDict): + """Options for registering a paid MCP tool.""" + + payment_reason: str + meter_event: Optional[str] + stripe_secret_key: str + user_email: str + checkout: dict[str, Any] + + +async def _maybe_await(value: Any) -> Any: + """Await the value when it is awaitable (helps with async mocks).""" + if inspect.isawaitable(value): + return await value + return value + + +def _as_list(data: Any) -> list[Any]: + """Extract API list payload from Stripe responses.""" + if isinstance(data, dict): + maybe_data = data.get("data") + if isinstance(maybe_data, list): + return maybe_data + return [] + + maybe_data = getattr(data, "data", None) + if isinstance(maybe_data, list): + return maybe_data + return [] + + +def _extract_error_message(error: Exception) -> str: + """Extract an actionable error message from Stripe exceptions.""" + raw = getattr(error, "raw", None) + if isinstance(raw, dict) and isinstance(raw.get("message"), str): + return raw["message"] + message = getattr(error, "message", None) + if isinstance(message, str): + return message + return str(error) or "Unknown error" + + +def _make_result( + payload: dict[str, Any], + *, + is_error: bool = False, +) -> dict[str, Any]: + """Format return payload for MCP tool responses.""" + result: dict[str, Any] = { + "content": [ + { + "type": "text", + "text": json.dumps(payload), + } + ] + } + if is_error: + result["isError"] = True + return result + + +async def register_paid_tool( + mcp_server: Any, + tool_name: str, + tool_description: str, + params_schema: Any, + callback: Callable[..., Any], + options: PaidToolOptions, +) -> None: + """Register a paid tool that enforces Stripe payment before execution.""" + line_items = options["checkout"].get("line_items") + price_id: Optional[str] = None + if isinstance(line_items, list): + for item in line_items: + if isinstance(item, dict): + maybe_price = item.get("price") + if isinstance(maybe_price, str): + price_id = maybe_price + break + + if not price_id: + raise ValueError( + "Price ID is required for a paid MCP tool. Learn more about " + "prices: https://docs.stripe.com/products-prices/" + "how-products-and-prices-work" + ) + + if stripe is None: + raise ImportError( + "The Stripe SDK is required. Install with " + "`stripe-agent-toolkit[mcp-payments]`." + ) + + app_info = { + "name": "stripe-agent-toolkit-mcp-payments", + "version": "0.7.0", + "url": "https://github.com/stripe/ai", + } + + if hasattr(stripe, "StripeClient"): + stripe_client = stripe.StripeClient( + options["stripe_secret_key"], + app_info=app_info, + ) + else: + stripe.api_key = options["stripe_secret_key"] + if hasattr(stripe, "set_app_info"): + stripe.set_app_info( + app_info["name"], + app_info["version"], + app_info["url"], + ) + stripe_client = stripe + + async def get_or_create_customer(email: str) -> str: + customers = await _maybe_await( + stripe_client.customers.list({"email": email}) + ) + customer_id: Optional[str] = None + for customer in _as_list(customers): + customer_email = ( + customer.get("email") + if isinstance(customer, dict) + else getattr(customer, "email", None) + ) + if customer_email == email: + customer_id = ( + customer.get("id") + if isinstance(customer, dict) + else getattr(customer, "id", None) + ) + break + + if not customer_id: + customer = await _maybe_await( + stripe_client.customers.create({"email": email}) + ) + if isinstance(customer, dict): + customer_id = customer.get("id") + else: + customer_id = getattr(customer, "id", None) + + if not isinstance(customer_id, str) or not customer_id: + raise RuntimeError("Failed to resolve Stripe customer ID") + return customer_id + + async def is_tool_paid_for(name: str, customer_id: str) -> bool: + sessions = await _maybe_await( + stripe_client.checkout.sessions.list( + {"customer": customer_id, "limit": 100} + ) + ) + paid_session: Optional[Any] = None + for session in _as_list(sessions): + metadata = ( + session.get("metadata") + if isinstance(session, dict) + else getattr(session, "metadata", None) + ) or {} + tool_name_meta = ( + metadata.get("toolName") + if isinstance(metadata, dict) + else getattr(metadata, "toolName", None) + ) + payment_status = ( + session.get("payment_status") + if isinstance(session, dict) + else getattr(session, "payment_status", None) + ) + if tool_name_meta == name and payment_status == "paid": + paid_session = session + break + + if paid_session is None: + return False + + subscription = ( + paid_session.get("subscription") + if isinstance(paid_session, dict) + else getattr(paid_session, "subscription", None) + ) + if subscription: + subs = await _maybe_await( + stripe_client.subscriptions.list( + {"customer": customer_id, "status": "active"} + ) + ) + for sub in _as_list(subs): + items = ( + sub.get("items") + if isinstance(sub, dict) + else getattr(sub, "items", None) + ) + item_data = ( + items.get("data") + if isinstance(items, dict) + else getattr(items, "data", None) + ) + if not isinstance(item_data, list): + continue + for item in item_data: + price = ( + item.get("price") + if isinstance(item, dict) + else getattr(item, "price", None) + ) + item_price_id = ( + price.get("id") + if isinstance(price, dict) + else getattr(price, "id", None) + ) + if item_price_id == price_id: + return True + return False + + return True + + async def create_checkout_session( + payment_type: str, + customer_id: str, + ) -> dict[str, Any]: + try: + checkout = dict(options["checkout"]) + metadata = dict(checkout.get("metadata") or {}) + metadata["toolName"] = tool_name + checkout["metadata"] = metadata + checkout["customer"] = customer_id or None + + session = await _maybe_await( + stripe_client.checkout.sessions.create(checkout) + ) + checkout_url = ( + session.get("url") + if isinstance(session, dict) + else getattr(session, "url", None) + ) + return _make_result( + { + "status": "payment_required", + "data": { + "paymentType": payment_type, + "checkoutUrl": checkout_url, + "paymentReason": options["payment_reason"], + }, + } + ) + except Exception as error: + message = _extract_error_message(error) + return _make_result( + { + "status": "error", + "error": message, + }, + is_error=True, + ) + + async def record_usage(customer_id: str) -> None: + meter_event = options.get("meter_event") + if not meter_event: + return + await _maybe_await( + stripe_client.billing.meter_events.create( + { + "event_name": meter_event, + "payload": { + "stripe_customer_id": customer_id, + "value": "1", + }, + } + ) + ) + + async def wrapped_callback(*args: Any, **kwargs: Any) -> dict[str, Any]: + try: + customer_id = await get_or_create_customer(options["user_email"]) + paid_for_tool = await is_tool_paid_for(tool_name, customer_id) + payment_type = ( + "usageBased" + if options.get("meter_event") + else "oneTimeSubscription" + ) + if not paid_for_tool: + return await create_checkout_session(payment_type, customer_id) + + if payment_type == "usageBased": + await record_usage(customer_id) + + callback_result = callback(*args, **kwargs) + return await _maybe_await(callback_result) + except Exception as error: + message = _extract_error_message(error) + return _make_result( + { + "status": "error", + "error": message, + }, + is_error=True, + ) + + mcp_server.tool(tool_name, tool_description, params_schema)( + wrapped_callback + ) diff --git a/tools/python/tests/test_register_paid_tool.py b/tools/python/tests/test_register_paid_tool.py new file mode 100644 index 00000000..78ff9551 --- /dev/null +++ b/tools/python/tests/test_register_paid_tool.py @@ -0,0 +1,230 @@ +"""Tests for paid MCP tool registration.""" + +import json +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from stripe_agent_toolkit.mcp import register_paid_tool + + +def _base_options() -> dict: + return { + "payment_reason": "Paid tool access", + "meter_event": None, + "stripe_secret_key": "sk_test_123", + "user_email": "user@example.com", + "checkout": { + "line_items": [{"price": "price_123", "quantity": 1}], + "mode": "payment", + "success_url": "https://example.com/success", + "cancel_url": "https://example.com/cancel", + }, + } + + +class FakeServer: + """Simple MCP server mock with decorator tool registration.""" + + def __init__(self): + self.tool = MagicMock(side_effect=self._tool) + self.registered_callback = None + + def _tool(self, name, description, params_schema): + def decorator(callback): + self.registered_callback = callback + return callback + + return decorator + + +@pytest.mark.asyncio +async def test_registers_tool_on_mcp_server(): + server = FakeServer() + options = _base_options() + callback = AsyncMock(return_value={"content": []}) + + stripe_client = MagicMock() + mock_stripe = MagicMock() + mock_stripe.StripeClient.return_value = stripe_client + + with patch( + "stripe_agent_toolkit.mcp.register_paid_tool.stripe", + mock_stripe, + ): + await register_paid_tool( + server, + "my_tool", + "My paid tool", + {"type": "object"}, + callback, + options, + ) + + server.tool.assert_called_once_with( + "my_tool", + "My paid tool", + {"type": "object"}, + ) + assert server.registered_callback is not None + + +@pytest.mark.asyncio +async def test_creates_customer_when_none_exists(): + server = FakeServer() + options = _base_options() + callback = AsyncMock( + return_value={"content": [{"type": "text", "text": "ok"}]} + ) + + stripe_client = MagicMock() + stripe_client.customers.list = AsyncMock(return_value={"data": []}) + stripe_client.customers.create = AsyncMock(return_value={"id": "cus_new"}) + stripe_client.checkout.sessions.list = AsyncMock( + return_value={ + "data": [ + { + "metadata": {"toolName": "my_tool"}, + "payment_status": "paid", + "subscription": None, + } + ] + } + ) + + mock_stripe = MagicMock() + mock_stripe.StripeClient.return_value = stripe_client + + with patch( + "stripe_agent_toolkit.mcp.register_paid_tool.stripe", + mock_stripe, + ): + await register_paid_tool( + server, + "my_tool", + "desc", + {"type": "object"}, + callback, + options, + ) + result = await server.registered_callback({}) + + stripe_client.customers.create.assert_awaited_once_with( + {"email": "user@example.com"} + ) + callback.assert_awaited_once() + assert result["content"][0]["text"] == "ok" + + +@pytest.mark.asyncio +async def test_creates_checkout_session_for_unpaid_tool(): + server = FakeServer() + options = _base_options() + callback = AsyncMock( + return_value={"content": [{"type": "text", "text": "ok"}]} + ) + + stripe_client = MagicMock() + stripe_client.customers.list = AsyncMock( + return_value={"data": [{"id": "cus_123", "email": "user@example.com"}]} + ) + stripe_client.checkout.sessions.list = AsyncMock(return_value={"data": []}) + stripe_client.checkout.sessions.create = AsyncMock( + return_value={"url": "https://checkout.stripe.com/test"} + ) + + mock_stripe = MagicMock() + mock_stripe.StripeClient.return_value = stripe_client + + with patch( + "stripe_agent_toolkit.mcp.register_paid_tool.stripe", + mock_stripe, + ): + await register_paid_tool( + server, + "my_tool", + "desc", + {"type": "object"}, + callback, + options, + ) + result = await server.registered_callback({}) + + callback.assert_not_called() + payload = json.loads(result["content"][0]["text"]) + assert payload["status"] == "payment_required" + assert payload["data"]["checkoutUrl"] == "https://checkout.stripe.com/test" + assert payload["data"]["paymentType"] == "oneTimeSubscription" + + +@pytest.mark.asyncio +async def test_usage_based_meter_event_recorded(): + server = FakeServer() + options = _base_options() + options["meter_event"] = "tool_usage" + callback = AsyncMock( + return_value={"content": [{"type": "text", "text": "ok"}]} + ) + + stripe_client = MagicMock() + stripe_client.customers.list = AsyncMock( + return_value={"data": [{"id": "cus_123", "email": "user@example.com"}]} + ) + stripe_client.checkout.sessions.list = AsyncMock( + return_value={ + "data": [ + { + "metadata": {"toolName": "my_tool"}, + "payment_status": "paid", + "subscription": None, + } + ] + } + ) + stripe_client.billing = SimpleNamespace( + meter_events=SimpleNamespace(create=AsyncMock(return_value={})) + ) + + mock_stripe = MagicMock() + mock_stripe.StripeClient.return_value = stripe_client + + with patch( + "stripe_agent_toolkit.mcp.register_paid_tool.stripe", + mock_stripe, + ): + await register_paid_tool( + server, + "my_tool", + "desc", + {"type": "object"}, + callback, + options, + ) + await server.registered_callback({}) + + stripe_client.billing.meter_events.create.assert_awaited_once_with( + { + "event_name": "tool_usage", + "payload": {"stripe_customer_id": "cus_123", "value": "1"}, + } + ) + callback.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_missing_price_id_raises_error(): + server = FakeServer() + options = _base_options() + options["checkout"]["line_items"] = [{"quantity": 1}] + callback = AsyncMock(return_value={"content": []}) + + with pytest.raises(ValueError, match="Price ID is required"): + await register_paid_tool( + server, + "my_tool", + "desc", + {"type": "object"}, + callback, + options, + )