Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/engram/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from ._models import (
BM25Retrieval,
CommittedOperation,
CommittedOperations,
ConversationInput,
FetchRetrieval,
HybridRetrieval,
Memory,
MessageInput,
NamedRetrievalType,
PreExtractedInput,
PreExtractedItem,
RetrievalConfig,
RetrievalConfigModel,
Run,
RunStatus,
SearchResults,
Expand All @@ -15,6 +19,7 @@
ToolCallFuncInput,
ToolCallInput,
Topic,
VectorRetrieval,
)
from .async_client import AsyncEngramClient
from .client import EngramClient
Expand All @@ -32,18 +37,22 @@
"APIError",
"AsyncEngramClient",
"AuthenticationError",
"BM25Retrieval",
"CommittedOperation",
"CommittedOperations",
"ConnectionError",
"ConversationInput",
"EngramClient",
"EngramError",
"EngramTimeoutError",
"FetchRetrieval",
"HybridRetrieval",
"Memory",
"MessageInput",
"NamedRetrievalType",
"PreExtractedInput",
"PreExtractedItem",
"RetrievalConfig",
"RetrievalConfigModel",
"Run",
"RunStatus",
"SearchResults",
Expand All @@ -53,5 +62,6 @@
"ToolCallInput",
"Topic",
"ValidationError",
"VectorRetrieval",
"__version__",
]
14 changes: 12 additions & 2 deletions src/engram/_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,40 @@
from .memory import (
AddInput,
BM25Retrieval,
ConversationInput,
FetchRetrieval,
HybridRetrieval,
Memory,
MessageInput,
NamedRetrievalType,
PreExtractedInput,
PreExtractedItem,
RetrievalConfig,
RetrievalConfigModel,
SearchResults,
StringInput,
ToolCallCustomInput,
ToolCallFuncInput,
ToolCallInput,
Topic,
TopicSelector,
VectorRetrieval,
)
from .run import CommittedOperation, CommittedOperations, Run, RunStatus

__all__ = [
"AddInput",
"BM25Retrieval",
"CommittedOperation",
"CommittedOperations",
"ConversationInput",
"FetchRetrieval",
"HybridRetrieval",
"Memory",
"MessageInput",
"NamedRetrievalType",
"PreExtractedInput",
"PreExtractedItem",
"RetrievalConfig",
"RetrievalConfigModel",
"Run",
"RunStatus",
"SearchResults",
Expand All @@ -35,4 +44,5 @@
"ToolCallInput",
"Topic",
"TopicSelector",
"VectorRetrieval",
]
29 changes: 26 additions & 3 deletions src/engram/_models/memory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Iterator, Sequence
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Literal, TypeAlias


Expand Down Expand Up @@ -91,9 +91,32 @@ class ConversationInput:


@dataclass(slots=True)
class RetrievalConfig:
retrieval_type: Literal["vector", "bm25", "hybrid", "fetch"]
class VectorRetrieval:
limit: int | None = None
retrieval_type: Literal["vector"] = field(default="vector", init=False)


@dataclass(slots=True)
class BM25Retrieval:
limit: int | None = None
retrieval_type: Literal["bm25"] = field(default="bm25", init=False)


@dataclass(slots=True)
class HybridRetrieval:
limit: int | None = None
retrieval_type: Literal["hybrid"] = field(default="hybrid", init=False)


@dataclass(slots=True)
class FetchRetrieval:
limit: int | None = None
retrieval_type: Literal["fetch"] = field(default="fetch", init=False)


NamedRetrievalType: TypeAlias = Literal["vector", "bm25", "hybrid", "fetch"]

RetrievalConfigModel: TypeAlias = VectorRetrieval | BM25Retrieval | HybridRetrieval | FetchRetrieval


@dataclass(slots=True)
Expand Down
7 changes: 4 additions & 3 deletions src/engram/_resources/memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from .._models import (
AddInput,
Memory,
RetrievalConfig,
NamedRetrievalType,
RetrievalConfigModel,
Run,
SearchResults,
TopicSelector,
Expand Down Expand Up @@ -85,7 +86,7 @@ def search(
topics: list[TopicSelector] | None = None,
user_id: str | None = None,
group: str | None = None,
retrieval_config: RetrievalConfig | None = None,
retrieval_config: RetrievalConfigModel | NamedRetrievalType | None = None,
properties: dict[str, str] | None = None,
) -> SearchResults:
body = build_search_body(
Expand Down Expand Up @@ -157,7 +158,7 @@ async def search(
topics: list[TopicSelector] | None = None,
user_id: str | None = None,
group: str | None = None,
retrieval_config: RetrievalConfig | None = None,
retrieval_config: RetrievalConfigModel | NamedRetrievalType | None = None,
properties: dict[str, str] | None = None,
) -> SearchResults:
body = build_search_body(
Expand Down
22 changes: 20 additions & 2 deletions src/engram/_serialization/_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@

from .._models import (
AddInput,
BM25Retrieval,
ConversationInput,
FetchRetrieval,
HybridRetrieval,
NamedRetrievalType,
PreExtractedInput,
RetrievalConfig,
RetrievalConfigModel,
StringInput,
ToolCallInput,
Topic,
TopicSelector,
VectorRetrieval,
)


Expand Down Expand Up @@ -123,11 +128,16 @@ def build_search_body(
topics: list[TopicSelector] | None,
user_id: str | None,
group: str | None,
retrieval_config: RetrievalConfig | None,
retrieval_config: RetrievalConfigModel | NamedRetrievalType | None,
properties: dict[str, str] | None = None,
) -> dict[str, Any]:
body: dict[str, Any] = {"query": query}
if retrieval_config is not None:
if isinstance(retrieval_config, str):
name = retrieval_config
retrieval_config = _retrieval_type_to_config.get(name)
if retrieval_config is None:
raise ValueError(f"Unrecognised retrieval type: {name}")
body["retrieval_config"] = {
"retrieval_type": retrieval_config.retrieval_type,
"limit": retrieval_config.limit,
Expand All @@ -142,3 +152,11 @@ def build_search_body(
if properties is not None:
body["properties"] = dict(properties)
return body


_retrieval_type_to_config: dict[NamedRetrievalType, RetrievalConfigModel] = {
"vector": VectorRetrieval(),
"bm25": BM25Retrieval(),
"hybrid": HybridRetrieval(),
"fetch": FetchRetrieval(),
}
21 changes: 19 additions & 2 deletions tests/test_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
MessageInput,
PreExtractedInput,
PreExtractedItem,
RetrievalConfig,
StringInput,
ToolCallFuncInput,
ToolCallInput,
Topic,
VectorRetrieval,
)
from engram.async_client import DEFAULT_BASE_URL, AsyncEngramClient
from engram.errors import APIError, AuthenticationError, ValidationError
Expand Down Expand Up @@ -363,7 +363,7 @@ def handler(request: httpx.Request) -> httpx.Response:
await client.memories.search(
query="find this",
topics=["a"],
retrieval_config=RetrievalConfig(retrieval_type="vector", limit=5),
retrieval_config=VectorRetrieval(limit=5),
)
body = json.loads(captured[0].content)
assert body["query"] == "find this"
Expand All @@ -372,6 +372,23 @@ def handler(request: httpx.Request) -> httpx.Response:
assert body["retrieval_config"]["limit"] == 5


@pytest.mark.asyncio
async def test_search_string_retrieval_config() -> None:
captured: list[httpx.Request] = []

def handler(request: httpx.Request) -> httpx.Response:
captured.append(request)
return httpx.Response(200, json={"memories": [], "total": 0})

client = _make_client_with_handler(handler)
for retrieval_type in ("vector", "bm25", "hybrid", "fetch"):
captured.clear()
await client.memories.search(query="test", retrieval_config=retrieval_type)
body = json.loads(captured[0].content)
assert body["retrieval_config"]["retrieval_type"] == retrieval_type
assert body["retrieval_config"]["limit"] is None


# ── properties / list ───────────────────────────────────────────────────


Expand Down
20 changes: 18 additions & 2 deletions tests/test_client_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
MessageInput,
PreExtractedInput,
PreExtractedItem,
RetrievalConfig,
StringInput,
ToolCallFuncInput,
ToolCallInput,
Topic,
VectorRetrieval,
)
from engram.client import DEFAULT_BASE_URL, EngramClient
from engram.errors import APIError, AuthenticationError, ValidationError
Expand Down Expand Up @@ -367,7 +367,7 @@ def handler(request: httpx.Request) -> httpx.Response:
client.memories.search(
query="find this",
topics=["a"],
retrieval_config=RetrievalConfig(retrieval_type="vector", limit=5),
retrieval_config=VectorRetrieval(limit=5),
)
body = json.loads(captured[0].content)
assert body["query"] == "find this"
Expand All @@ -376,6 +376,22 @@ def handler(request: httpx.Request) -> httpx.Response:
assert body["retrieval_config"]["limit"] == 5


def test_search_string_retrieval_config() -> None:
captured: list[httpx.Request] = []

def handler(request: httpx.Request) -> httpx.Response:
captured.append(request)
return httpx.Response(200, json={"memories": [], "total": 0})

client = _make_client_with_handler(handler)
for retrieval_type in ("vector", "bm25", "hybrid", "fetch"):
captured.clear()
client.memories.search(query="test", retrieval_config=retrieval_type)
body = json.loads(captured[0].content)
assert body["retrieval_config"]["retrieval_type"] == retrieval_type
assert body["retrieval_config"]["limit"] is None


def test_search_no_retrieval_config_by_default() -> None:
captured: list[httpx.Request] = []

Expand Down
19 changes: 16 additions & 3 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,22 @@ def test_public_imports() -> None:
APIError,
AsyncEngramClient,
AuthenticationError,
BM25Retrieval,
CommittedOperation,
CommittedOperations,
ConnectionError,
ConversationInput,
EngramClient,
EngramError,
EngramTimeoutError,
FetchRetrieval,
HybridRetrieval,
Memory,
MessageInput,
NamedRetrievalType,
PreExtractedInput,
PreExtractedItem,
RetrievalConfig,
RetrievalConfigModel,
Run,
RunStatus,
SearchResults,
Expand All @@ -25,6 +29,7 @@ def test_public_imports() -> None:
ToolCallInput,
Topic,
ValidationError,
VectorRetrieval,
)

assert isinstance(EngramClient, type)
Expand All @@ -40,7 +45,10 @@ def test_public_imports() -> None:
assert isinstance(SearchResults, type)
assert isinstance(PreExtractedInput, type)
assert isinstance(PreExtractedItem, type)
assert isinstance(RetrievalConfig, type)
assert isinstance(VectorRetrieval, type)
assert isinstance(BM25Retrieval, type)
assert isinstance(HybridRetrieval, type)
assert isinstance(FetchRetrieval, type)
assert isinstance(CommittedOperation, type)
assert isinstance(CommittedOperations, type)
assert isinstance(ConversationInput, type)
Expand All @@ -55,18 +63,22 @@ def test_public_imports() -> None:
"APIError",
"AsyncEngramClient",
"AuthenticationError",
"BM25Retrieval",
"CommittedOperation",
"CommittedOperations",
"ConnectionError",
"ConversationInput",
"EngramClient",
"EngramError",
"EngramTimeoutError",
"FetchRetrieval",
"HybridRetrieval",
"Memory",
"MessageInput",
"NamedRetrievalType",
"PreExtractedInput",
"PreExtractedItem",
"RetrievalConfig",
"RetrievalConfigModel",
"Run",
"RunStatus",
"SearchResults",
Expand All @@ -76,6 +88,7 @@ def test_public_imports() -> None:
"ToolCallInput",
"Topic",
"ValidationError",
"VectorRetrieval",
"__version__",
}
assert set(engram.__all__) == expected_exports
Expand Down
Loading
Loading