Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 170 additions & 2 deletions oxygent/databases/db_redis/local_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import json
import time
from collections import deque
from typing import Dict, Optional, Union
from typing import Dict, Optional, Union, List, Tuple
import bisect

from ...config import Config

Expand All @@ -26,10 +27,12 @@ class LocalRedis:
- Automatic expiration handling with TTL support
- List operations with configurable size limits
- Value type validation and conversion
- Sorted set (zset) operations with score-based ordering
"""

def __init__(self, *, yield_on_ops: bool = True):
self.data: Dict[str, deque] = {}
self.zsets: Dict[str, Tuple[List[float], List[str]]] = {} # (scores, members)
self.expiry: Dict[str, float] = {}
self.default_expire_time = Config.get_redis_expire_time()
self.default_list_max_size = Config.get_redis_max_size()
Expand Down Expand Up @@ -143,9 +146,174 @@ def _check_expiry(self, key: str):
key: The key to check for expiration
"""
if key in self.expiry and time.time() > self.expiry[key]:
del self.data[key]
if key in self.data:
del self.data[key]
if key in self.zsets:
del self.zsets[key]
del self.expiry[key]

async def zadd(self, key: str, mapping: Dict[str, float], ex: Optional[int] = None) -> int:
    """Add members to a sorted set, or update their score if they already exist.

    Args:
        key: The sorted set key.
        mapping: Dictionary of member -> score pairs.
        ex: Expiration time in seconds (default: the configured expire time).

    Returns:
        int: Number of new members added (score updates are not counted).
    """
    if ex is None:
        ex = self.default_expire_time

    # Drop stale data first so a write never merges into an expired set
    # (keeps behavior consistent with zincrby/zrange/zrem, which all
    # check expiry before touching the data).
    self._check_expiry(key)

    if key not in self.zsets:
        self.zsets[key] = ([], [])

    scores, members = self.zsets[key]
    added = 0

    for member, score in mapping.items():
        if member in members:
            # Existing member: remove its old entry, then re-insert below
            # at the position matching its new score.
            idx = members.index(member)
            del scores[idx]
            del members[idx]
        else:
            added += 1
        # Insert while keeping `scores` sorted ascending; `members` stays
        # index-aligned with `scores` (parallel lists).
        insert_pos = bisect.bisect_left(scores, score)
        scores.insert(insert_pos, score)
        members.insert(insert_pos, member)

    # Refresh the TTL on every write.
    self.expiry[key] = time.time() + ex

    if self._yield_on_ops:
        await asyncio.sleep(0)
    return added

async def zrange(
    self,
    key: str,
    start: int,
    stop: int,
    withscores: bool = False
) -> Union[List[str], List[Tuple[str, float]]]:
    """Return a range of members in a sorted set, by index (Redis ZRANGE).

    Args:
        key: The sorted set key.
        start: Starting index (0-based; negative counts from the end).
        stop: Ending index (inclusive; -1 means the last element).
        withscores: Whether to return (member, score) tuples.

    Returns:
        List of members, or list of (member, score) tuples when
        withscores is True. Empty list for a missing/expired key or an
        out-of-range window.
    """
    self._check_expiry(key)
    if key not in self.zsets:
        return []

    scores, members = self.zsets[key]
    n = len(members)

    # Redis-style index normalization: negative indices count from the
    # end; a start that underflows is clamped to 0 rather than wrapping
    # around (the old code let it stay negative and re-wrap in the slice).
    if start < 0:
        start = max(n + start, 0)
    if stop < 0:
        stop = n + stop
    # The window is empty when it ends before the first element or
    # starts after it ends; Redis returns an empty array here, whereas
    # the unclamped slice could return spurious members.
    if stop < 0 or start > stop:
        return []

    result = members[start:stop + 1]
    if withscores:
        return list(zip(result, scores[start:stop + 1]))
    return result

async def zrem(self, key: str, *members: str) -> int:
    """Remove the given members from a sorted set.

    Args:
        key: The sorted set key.
        *members: Members to remove; unknown members are ignored.

    Returns:
        int: Count of members that were actually removed.
    """
    self._check_expiry(key)
    if key not in self.zsets:
        return 0

    score_list, member_list = self.zsets[key]
    removed_count = 0

    for target in members:
        # EAFP: index() raises ValueError for members not in the set.
        try:
            pos = member_list.index(target)
        except ValueError:
            continue
        del score_list[pos]
        del member_list[pos]
        removed_count += 1

    # Once the set becomes empty, drop the key entirely (data + TTL),
    # mirroring Redis's removal of empty sorted sets.
    if not member_list:
        del self.zsets[key]
        del self.expiry[key]

    if self._yield_on_ops:
        await asyncio.sleep(0)
    return removed_count

async def zincrby(self, key: str, increment: float, member: str, ex: Optional[int] = None) -> float:
    """Increment the score of a member in a sorted set (Redis ZINCRBY).

    Args:
        key: The sorted set key.
        increment: Amount added to the member's score.
        member: The member to increment; created with score=increment
            when absent (Redis treats a missing member as score 0).
        ex: Expiration time in seconds (default: the configured expire time).

    Returns:
        float: The member's new score.
    """
    if ex is None:
        ex = self.default_expire_time

    self._check_expiry(key)

    if key not in self.zsets:
        self.zsets[key] = ([], [])
    # Refresh the TTL on every write, whether or not the key pre-existed
    # (the old if/else assigned the same expiry in both branches).
    self.expiry[key] = time.time() + ex

    scores, members_list = self.zsets[key]

    if member in members_list:
        # Existing member: compute the new score, then remove the old
        # entry so the single insert below restores sorted order.
        idx = members_list.index(member)
        new_score = scores[idx] + increment
        del scores[idx]
        del members_list[idx]
    else:
        new_score = increment

    # Single insertion path shared by both branches; keeps `scores`
    # sorted ascending and `members_list` index-aligned with it.
    insert_pos = bisect.bisect_left(scores, new_score)
    scores.insert(insert_pos, new_score)
    members_list.insert(insert_pos, member)

    if self._yield_on_ops:
        await asyncio.sleep(0)

    return new_score

async def close(self):
# This method is async to maintain compatibility with the Redis interface
# Async for interface compatibility
Expand Down
86 changes: 81 additions & 5 deletions oxygent/oxy/function_tools/function_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@
import asyncio
import functools
import concurrent.futures
import os
import threading
import logging

from pydantic import Field

from ..base_tool import BaseTool
from .function_tool import FunctionTool

logger = logging.getLogger(__name__)


class FunctionHub(BaseTool):
"""Central hub for registering and managing Python functions as tools.
Expand All @@ -34,12 +39,17 @@ def __init__(self, **data):
"""Initialize the FunctionHub with thread pool support."""
super().__init__(**data)
self._thread_pool = None # Private attribute for thread pool
self._thread_pool_lock = threading.Lock() # Lock for thread pool initialization

@property
def thread_pool(self):
    """Lazy initialization of thread pool with thread safety.

    Uses double-checked locking: the unlocked fast path avoids lock
    contention once the pool exists, and the re-check under the lock
    prevents two racing threads from each creating a pool.

    Returns:
        concurrent.futures.ThreadPoolExecutor: The shared pool, created
        on first access.
    """
    if self._thread_pool is None:
        with self._thread_pool_lock:
            if self._thread_pool is None:  # Double-checked locking pattern
                # Size the pool for I/O-ish workloads: at least 4 workers,
                # scaled by CPU count, capped at 32.
                cpu_count = os.cpu_count() or 1
                max_workers = min(max(cpu_count * 3, 4), 32)
                self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
    return self._thread_pool

async def init(self):
Expand Down Expand Up @@ -108,7 +118,73 @@ async def async_func(*args, **kwargs):
return decorator

async def cleanup(self):
    """Clean up resources, including the thread pool.

    Ensures proper shutdown of the thread pool to prevent resource
    leaks and dangling threads, waiting for all pending tasks to
    complete before shutting down.

    The cleanup process is idempotent - multiple calls are safe
    (subsequent calls see ``_thread_pool is None`` and do nothing).
    """
    if self._thread_pool:
        try:
            logger.info(f"FunctionHub {self.name}: Starting thread pool cleanup...")
            # Wait for all pending tasks to complete
            self._thread_pool.shutdown(wait=True)
            logger.info(f"FunctionHub {self.name}: Thread pool shutdown completed")
        except Exception as e:
            logger.error(f"FunctionHub {self.name}: Error during thread pool cleanup: {e}")
        finally:
            # Even if shutdown fails, drop the reference so the pool
            # cannot be reused and can be garbage collected.
            self._thread_pool = None

def is_thread_pool_active(self):
    """Report whether a thread pool currently exists.

    Returns:
        bool: True when the lazily-created pool has been initialized
        and not yet cleaned up, False otherwise.
    """
    pool = self._thread_pool
    return pool is not None

def get_thread_pool_info(self):
    """Describe the current thread pool.

    Returns:
        dict: Pool information (worker count and shutdown flag), or a
        generic status marker if introspection fails; None when no pool
        has been initialized.
    """
    pool = self._thread_pool
    if pool is None:
        return None

    try:
        # ThreadPoolExecutor exposes no public introspection API, so
        # read the CPython-internal attributes defensively via getattr.
        info = {
            "initialized": True,
            "workers": getattr(pool, '_max_workers', 'unknown'),
            "shutdown": getattr(pool, '_shutdown', False),
        }
    except Exception:
        info = {"initialized": True, "status": "active"}
    return info

async def __aenter__(self):
    """Enter the async context, initializing the hub.

    Returns:
        FunctionHub: This instance, ready for use inside the
        ``async with`` block.
    """
    await self.init()
    return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Exit the async context, releasing hub resources.

    Cleanup runs regardless of whether the body raised; any in-flight
    exception is not suppressed because this method returns None.

    Args:
        exc_type: Exception type if an exception occurred, else None.
        exc_val: Exception value if an exception occurred, else None.
        exc_tb: Exception traceback if an exception occurred, else None.
    """
    await self.cleanup()
Loading