From f66c4fcb24f734c2aff132cf8ea45ee0e5b5b42a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E8=A1=8D=E5=8D=8E?= Date: Tue, 23 Dec 2025 21:16:28 +0800 Subject: [PATCH] add token limit --- oxygent/config.py | 48 +++++++ oxygent/mas.py | 83 ++++++++++++ oxygent/oxy/base_oxy.py | 27 ++++ oxygent/rate_limiter.py | 289 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 447 insertions(+) create mode 100644 oxygent/rate_limiter.py diff --git a/oxygent/config.py b/oxygent/config.py index 5572c5ff..b1b0fbfc 100644 --- a/oxygent/config.py +++ b/oxygent/config.py @@ -109,6 +109,12 @@ class Config: "mcp_is_keep_alive": True, "is_concurrent_init": True, }, + "rate_limiter": { + "enabled": False, + "default_rate": 1.0, + "default_capacity": 10, + "per_oxy_limits": {}, + }, } @classmethod @@ -623,3 +629,45 @@ def set_tool_is_concurrent_init(cls, is_concurrent_init): @classmethod def get_tool_is_concurrent_init(cls): return cls.get_module_config("tool", "is_concurrent_init") + + """ rate_limiter """ + + @classmethod + def set_rate_limiter_config(cls, rate_limiter_config): + return cls.set_module_config("rate_limiter", rate_limiter_config) + + @classmethod + def get_rate_limiter_config(cls): + return cls.get_module_config("rate_limiter") + + @classmethod + def set_rate_limiter_enabled(cls, enabled: bool): + cls.set_module_config("rate_limiter", "enabled", enabled) + + @classmethod + def get_rate_limiter_enabled(cls) -> bool: + return cls.get_module_config("rate_limiter", "enabled", False) + + @classmethod + def set_rate_limiter_default_rate(cls, rate: float): + cls.set_module_config("rate_limiter", "default_rate", rate) + + @classmethod + def get_rate_limiter_default_rate(cls) -> float: + return cls.get_module_config("rate_limiter", "default_rate", 1.0) + + @classmethod + def set_rate_limiter_default_capacity(cls, capacity: int): + cls.set_module_config("rate_limiter", "default_capacity", capacity) + + @classmethod + def get_rate_limiter_default_capacity(cls) -> int: + return cls.get_module_config("rate_limiter", "default_capacity", 10) + + @classmethod + def set_rate_limiter_per_oxy_limits(cls, per_oxy_limits: dict): + cls.set_module_config("rate_limiter", "per_oxy_limits", per_oxy_limits) + + @classmethod + def get_rate_limiter_per_oxy_limits(cls) -> dict: + return cls.get_module_config("rate_limiter", "per_oxy_limits", {}) diff --git a/oxygent/mas.py b/oxygent/mas.py index 2b368c8c..a8af3983 100644 --- a/oxygent/mas.py +++ b/oxygent/mas.py @@ -35,6 +35,7 @@ from .db_factory import DBFactory from .log_setup import setup_logging from .oxy import Oxy +from .rate_limiter import get_rate_limit_manager, RateLimitManager from .oxy.agents.base_agent import BaseAgent from .oxy.agents.remote_agent import RemoteAgent from .oxy.base_flow import BaseFlow @@ -97,6 +98,10 @@ class MAS(BaseModel): func_interceptor: Optional[Callable] = Field( lambda x: None, exclude=True, description="interceptor function" ) + + rate_limiter: Optional[RateLimitManager] = Field( + None, exclude=True, description="Rate limiter manager" + ) func_process_message: Optional[Callable] = Field( lambda x, oxy_request: x, exclude=True, description="process message function" @@ -127,6 +132,10 @@ def __init__(self, **kwargs): Config.set_app_name(self.name) else: self.name = Config.get_app_name() + + # Initialize rate limiter if enabled + if Config.get_rate_limiter_enabled(): + self._init_rate_limiter() async def __aenter__(self): await self.init() @@ -181,6 +190,14 @@ def add_oxy(self, oxy: Oxy): if oxy.name in self.oxy_name_to_oxy: raise Exception(f"oxy [{oxy.name}] already exists.") self.oxy_name_to_oxy[oxy.name] = oxy + + # Create rate limiter for the new oxy if rate limiting is enabled + if self.rate_limiter and Config.get_rate_limiter_enabled(): + per_oxy_limits = Config.get_rate_limiter_per_oxy_limits() + oxy_limits = per_oxy_limits.get(oxy.name, {}) + rate = oxy_limits.get("rate", Config.get_rate_limiter_default_rate()) + capacity = oxy_limits.get("capacity", Config.get_rate_limiter_default_capacity()) + self._create_oxy_limiter(oxy.name, rate, capacity) def add_oxy_list(self, oxy_list: list[Oxy]): """Register a list of Oxy objects. @@ -190,6 +207,72 @@ def add_oxy_list(self, oxy_list: list[Oxy]): """ for oxy in oxy_list: self.add_oxy(oxy) + + def _init_rate_limiter(self): + """Initialize the rate limiter manager and create limiters for oxy instances.""" + logger.info("Initializing rate limiter...") + self.rate_limiter = get_rate_limit_manager() + self.rate_limiter.enable() + + # Create default limiter with configuration + default_rate = Config.get_rate_limiter_default_rate() + default_capacity = Config.get_rate_limiter_default_capacity() + + # Create limiters for existing oxy instances + for oxy_name, oxy in self.oxy_name_to_oxy.items(): + self._create_oxy_limiter(oxy_name, default_rate, default_capacity) + + # Apply per-oxy limits from configuration + per_oxy_limits = Config.get_rate_limiter_per_oxy_limits() + for oxy_name, limits in per_oxy_limits.items(): + if oxy_name in self.oxy_name_to_oxy: + rate = limits.get("rate", default_rate) + capacity = limits.get("capacity", default_capacity) + self._create_oxy_limiter(oxy_name, rate, capacity) + + logger.info(f"Rate limiter initialized with {len(self.rate_limiter._limiters)} limiters") + + def _create_oxy_limiter(self, oxy_name: str, rate: float, capacity: int): + """Create a rate limiter for a specific oxy instance.""" + if self.rate_limiter: + self.rate_limiter.create_limiter(oxy_name, rate, capacity) + logger.debug(f"Created rate limiter for oxy '{oxy_name}': rate={rate}, capacity={capacity}") + + def check_rate_limit(self, oxy_name: str, tokens: int = 1) -> bool: + """Check if rate limit allows the operation for an oxy instance. + + Args: + oxy_name: Name of the oxy instance + tokens: Number of tokens to acquire + + Returns: + True if operation is allowed, False otherwise + """ + if not self.rate_limiter: + return True + return self.rate_limiter.check_rate_limit(oxy_name, tokens) + + async def check_rate_limit_async(self, oxy_name: str, tokens: int = 1) -> bool: + """Async check if rate limit allows the operation for an oxy instance. + + Args: + oxy_name: Name of the oxy instance + tokens: Number of tokens to acquire + + Returns: + True if operation is allowed, False otherwise + """ + if not self.rate_limiter: + return True + return await self.rate_limiter.check_rate_limit_async(oxy_name, tokens) + + def get_rate_limiter_manager(self) -> Optional[RateLimitManager]: + """Get the rate limiter manager. + + Returns: + Rate limiter manager if initialized, None otherwise + """ + return self.rate_limiter async def init(self): """Initialize the MAS. This coroutine performs all necessary setup steps to diff --git a/oxygent/oxy/base_oxy.py b/oxygent/oxy/base_oxy.py index 866fd1e7..ad09716d 100644 --- a/oxygent/oxy/base_oxy.py +++ b/oxygent/oxy/base_oxy.py @@ -151,6 +151,11 @@ class Oxy(BaseModel, ABC): timeout: float = Field(3600, description="Timeout in seconds.") retries: int = Field(2) delay: float = Field(1.0) + + rate_limiter_enabled: bool = Field( + default_factory=Config.get_rate_limiter_enabled, + description="Enable rate limiting for this oxy" + ) def __init__(self, **kwargs): super().__init__(**kwargs) @@ -409,6 +414,28 @@ async def _pre_send_message(self, oxy_request: OxyRequest): ) async def _before_execute(self, oxy_request: OxyRequest) -> OxyRequest: + """Check rate limit before execution.""" + if (self.mas and + self.rate_limiter_enabled and + Config.get_rate_limiter_enabled()): + # Check rate limit for this oxy instance + allowed = await self.mas.check_rate_limit_async(self.name) + if not allowed: + logger.warning( + f"Rate limit exceeded for oxy {self.name}", + extra={ + "trace_id": oxy_request.current_trace_id, + "node_id": oxy_request.node_id, + }, + ) + # Create a rate limited response + from ..schemas import OxyResponse, OxyState + rate_limited_response = OxyResponse( + state=OxyState.FAILED, + output=f"Rate limit exceeded for {self.name}. Please try again later.", + ) + rate_limited_response.oxy_request = oxy_request + raise Exception(f"Rate limit exceeded for {self.name}") return oxy_request @abstractmethod diff --git a/oxygent/rate_limiter.py b/oxygent/rate_limiter.py new file mode 100644 index 00000000..e4b06b5b --- /dev/null +++ b/oxygent/rate_limiter.py @@ -0,0 +1,289 @@ +"""Token rate limiter for OxyGent system. + +This module provides token bucket rate limiting functionality with both +synchronous and asynchronous token acquisition methods. +""" + +import asyncio +import logging +import time +from typing import Optional + +logger = logging.getLogger(__name__) + + +class TokenLimiter: + """Token bucket rate limiter implementation. + + Provides thread-safe token-based rate limiting with both sync and async + interfaces. Uses the token bucket algorithm to control request rates. + + Attributes: + rate (float): Token refill rate (tokens per second) + capacity (int): Maximum number of tokens in the bucket + tokens (float): Current number of available tokens + last_refill_time (float): Last time tokens were refilled + _lock (asyncio.Lock): Async lock for thread-safe operations + """ + + def __init__(self, rate: float = 1.0, capacity: int = 10): + """Initialize the token limiter. + + Args: + rate: Token refill rate (tokens per second) + capacity: Maximum number of tokens in the bucket + """ + self.rate = rate + self.capacity = capacity + self.tokens = float(capacity) + self.last_refill_time = time.time() + self._lock = asyncio.Lock() + + def _refill_tokens(self): + """Refill tokens based on elapsed time.""" + current_time = time.time() + elapsed = current_time - self.last_refill_time + tokens_to_add = elapsed * self.rate + + if tokens_to_add > 0: + self.tokens = min(self.capacity, self.tokens + tokens_to_add) + self.last_refill_time = current_time + + def acquire_sync(self, tokens: int = 1) -> bool: + """Synchronously acquire tokens. + + Args: + tokens: Number of tokens to acquire + + Returns: + True if tokens were acquired, False if not enough tokens available + """ + self._refill_tokens() + + if self.tokens >= tokens: + self.tokens -= tokens + logger.debug(f"Acquired {tokens} tokens. Remaining: {self.tokens}") + return True + else: + logger.debug(f"Failed to acquire {tokens} tokens. Available: {self.tokens}") + return False + + async def acquire_async(self, tokens: int = 1) -> bool: + """Asynchronously acquire tokens. + + Args: + tokens: Number of tokens to acquire + + Returns: + True if tokens were acquired, False if not enough tokens available + """ + async with self._lock: + self._refill_tokens() + + if self.tokens >= tokens: + self.tokens -= tokens + logger.debug(f"Acquired {tokens} tokens. Remaining: {self.tokens}") + return True + else: + logger.debug(f"Failed to acquire {tokens} tokens. Available: {self.tokens}") + return False + + async def wait_for_tokens(self, tokens: int = 1, timeout: Optional[float] = None) -> bool: + """Wait until the specified number of tokens are available. + + Args: + tokens: Number of tokens to wait for + timeout: Maximum time to wait in seconds (None for no timeout) + + Returns: + True if tokens were acquired, False if timeout occurred + """ + start_time = time.time() + + while True: + if await self.acquire_async(tokens): + return True + + # Calculate wait time needed + tokens_needed = tokens - self.tokens + wait_time = tokens_needed / self.rate + + # Check timeout + if timeout is not None: + elapsed = time.time() - start_time + if elapsed >= timeout: + logger.warning(f"Timeout waiting for {tokens} tokens") + return False + wait_time = min(wait_time, timeout - elapsed) + + logger.debug(f"Waiting {wait_time:.2f}s for {tokens} tokens") + await asyncio.sleep(wait_time) + + def get_available_tokens(self) -> float: + """Get the current number of available tokens. + + Returns: + Current number of tokens available + """ + self._refill_tokens() + return self.tokens + + def get_wait_time(self, tokens: int = 1) -> float: + """Calculate the time needed to wait for the specified number of tokens. + + Args: + tokens: Number of tokens needed + + Returns: + Time in seconds needed to wait + """ + self._refill_tokens() + + if self.tokens >= tokens: + return 0.0 + + tokens_needed = tokens - self.tokens + return tokens_needed / self.rate + + def reset(self): + """Reset the token bucket to full capacity.""" + self.tokens = float(self.capacity) + self.last_refill_time = time.time() + logger.info(f"Token limiter reset to {self.capacity} tokens") + + def update_rate(self, rate: float): + """Update the token refill rate. + + Args: + rate: New token refill rate (tokens per second) + """ + self.rate = rate + logger.info(f"Token limiter rate updated to {rate} tokens/second") + + def update_capacity(self, capacity: int): + """Update the token bucket capacity. + + Args: + capacity: New maximum number of tokens + """ + self.capacity = capacity + self.tokens = min(self.tokens, float(capacity)) + logger.info(f"Token limiter capacity updated to {capacity} tokens") + + +class RateLimitManager: + """Manager for multiple token limiters. + + Manages rate limiters for different entities in the OxyGent system. + """ + + def __init__(self): + """Initialize the rate limit manager.""" + self._limiters: dict[str, TokenLimiter] = {} + self._enabled = False + + def enable(self): + """Enable rate limiting.""" + self._enabled = True + logger.info("Rate limiting enabled") + + def disable(self): + """Disable rate limiting.""" + self._enabled = False + logger.info("Rate limiting disabled") + + def is_enabled(self) -> bool: + """Check if rate limiting is enabled. + + Returns: + True if rate limiting is enabled + """ + return self._enabled + + def create_limiter(self, name: str, rate: float = 1.0, capacity: int = 10) -> TokenLimiter: + """Create a new token limiter. + + Args: + name: Name of the limiter + rate: Token refill rate + capacity: Token bucket capacity + + Returns: + Created token limiter + """ + limiter = TokenLimiter(rate, capacity) + self._limiters[name] = limiter + logger.info(f"Created token limiter '{name}' with rate={rate}, capacity={capacity}") + return limiter + + def get_limiter(self, name: str) -> Optional[TokenLimiter]: + """Get a token limiter by name. + + Args: + name: Name of the limiter + + Returns: + Token limiter if found, None otherwise + """ + return self._limiters.get(name) + + def remove_limiter(self, name: str): + """Remove a token limiter. + + Args: + name: Name of the limiter to remove + """ + if name in self._limiters: + del self._limiters[name] + logger.info(f"Removed token limiter '{name}'") + + def check_rate_limit(self, name: str, tokens: int = 1) -> bool: + """Check if rate limit allows the operation. + + Args: + name: Name of the limiter + tokens: Number of tokens to acquire + + Returns: + True if operation is allowed, False otherwise + """ + if not self._enabled: + return True + + limiter = self._limiters.get(name) + if limiter is None: + return True + + return limiter.acquire_sync(tokens) + + async def check_rate_limit_async(self, name: str, tokens: int = 1) -> bool: + """Async check if rate limit allows the operation. + + Args: + name: Name of the limiter + tokens: Number of tokens to acquire + + Returns: + True if operation is allowed, False otherwise + """ + if not self._enabled: + return True + + limiter = self._limiters.get(name) + if limiter is None: + return True + + return await limiter.acquire_async(tokens) + + +# Global rate limit manager instance +_global_rate_limit_manager = RateLimitManager() + + +def get_rate_limit_manager() -> RateLimitManager: + """Get the global rate limit manager. + + Returns: + Global rate limit manager instance + """ + return _global_rate_limit_manager \ No newline at end of file