Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/engram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ToolCallCustomInput,
ToolCallFuncInput,
ToolCallInput,
Topic,
)
from .async_client import AsyncEngramClient
from .client import EngramClient
Expand Down Expand Up @@ -50,6 +51,7 @@
"ToolCallCustomInput",
"ToolCallFuncInput",
"ToolCallInput",
"Topic",
"ValidationError",
"__version__",
]
4 changes: 4 additions & 0 deletions src/engram/_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
ToolCallCustomInput,
ToolCallFuncInput,
ToolCallInput,
Topic,
TopicSelector,
)
from .run import CommittedOperation, CommittedOperations, Run, RunStatus

Expand All @@ -31,4 +33,6 @@
"ToolCallCustomInput",
"ToolCallFuncInput",
"ToolCallInput",
"Topic",
"TopicSelector",
]
17 changes: 16 additions & 1 deletion src/engram/_models/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,21 @@ class RetrievalConfig:
limit: int | None = None


@dataclass(slots=True)
class Topic:
"""A topic with an optional per-topic property filter.

Use ``None`` as a property value to clear an inherited global filter
for this topic only.
"""

name: str
properties: dict[str, str | None] | None = None


TopicSelector: TypeAlias = str | Topic


@dataclass(slots=True)
class Memory:
id: str
Expand All @@ -106,9 +121,9 @@ class Memory:
created_at: str
updated_at: str
user_id: str | None = None
conversation_id: str | None = None
tags: list[str] | None = None
score: float | None = None
properties: dict[str, str] | None = None


class SearchResults(Sequence[Memory]):
Expand Down
32 changes: 21 additions & 11 deletions src/engram/_resources/memories.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from __future__ import annotations

from typing import TypeAlias
from uuid import UUID

from .._http import AsyncHttpTransport, HttpTransport
from .._models import AddInput, Memory, RetrievalConfig, Run, SearchResults
from .._models import (
AddInput,
Memory,
RetrievalConfig,
Run,
SearchResults,
TopicSelector,
)
from .._serialization import (
build_add_body,
build_memory_params,
Expand All @@ -16,6 +24,8 @@
_MEMORIES_PATH = "/v1/memories"
_MEMORIES_SEARCH_PATH = "/v1/memories/search"

_Topics: TypeAlias = list[TopicSelector] | None


def _memory_path(memory_id: str | UUID) -> str:
return f"{_MEMORIES_PATH}/{memory_id}"
Expand All @@ -32,14 +42,14 @@ def add(
input_data: AddInput,
*,
user_id: str | None = None,
conversation_id: str | None = None,
group: str | None = None,
properties: dict[str, str] | None = None,
) -> Run:
body = build_add_body(
input_data,
user_id=user_id,
conversation_id=conversation_id,
group=group,
properties=properties,
)
data = self._transport.request("POST", _MEMORIES_PATH, json=body)
return parse_run(data)
Expand Down Expand Up @@ -75,19 +85,19 @@ def search(
self,
*,
query: str,
topics: list[str] | None = None,
topics: _Topics = None,
Comment thread
danmichaeljones marked this conversation as resolved.
Outdated
user_id: str | None = None,
conversation_id: str | None = None,
group: str | None = None,
retrieval_config: RetrievalConfig | None = None,
properties: dict[str, str] | None = None,
) -> SearchResults:
body = build_search_body(
query=query,
topics=topics,
user_id=user_id,
conversation_id=conversation_id,
group=group,
retrieval_config=retrieval_config,
properties=properties,
)
data = self._transport.request("POST", _MEMORIES_SEARCH_PATH, json=body)
return parse_search_results(data)
Expand All @@ -104,14 +114,14 @@ async def add(
input_data: AddInput,
*,
user_id: str | None = None,
conversation_id: str | None = None,
group: str | None = None,
properties: dict[str, str] | None = None,
) -> Run:
body = build_add_body(
input_data,
user_id=user_id,
conversation_id=conversation_id,
group=group,
properties=properties,
)
data = await self._transport.request("POST", _MEMORIES_PATH, json=body)
return parse_run(data)
Expand Down Expand Up @@ -147,19 +157,19 @@ async def search(
self,
*,
query: str,
topics: list[str] | None = None,
topics: _Topics = None,
user_id: str | None = None,
conversation_id: str | None = None,
group: str | None = None,
retrieval_config: RetrievalConfig | None = None,
properties: dict[str, str] | None = None,
) -> SearchResults:
body = build_search_body(
query=query,
topics=topics,
user_id=user_id,
conversation_id=conversation_id,
group=group,
retrieval_config=retrieval_config,
properties=properties,
)
data = await self._transport.request("POST", _MEMORIES_SEARCH_PATH, json=body)
return parse_search_results(data)
38 changes: 29 additions & 9 deletions src/engram/_serialization/_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
RetrievalConfig,
StringInput,
ToolCallInput,
Topic,
TopicSelector,
)


Expand Down Expand Up @@ -65,20 +67,37 @@ def _serialize_conversation_content(content: ConversationInput) -> dict[str, Any
return {"conversation": conversation}


def _serialize_topic(topic: TopicSelector) -> str | dict[str, Any]:
Comment thread
danmichaeljones marked this conversation as resolved.
Outdated
if isinstance(topic, str):
return topic
if isinstance(topic, Topic):
out: dict[str, Any] = {"name": topic.name}
if topic.properties is not None:
out["properties"] = dict(topic.properties)
return out
raise TypeError(f"Unsupported topic type: {type(topic)}") # pragma: no cover


def _serialize_topics(topics: list[TopicSelector] | None) -> list[str | dict[str, Any]] | None:
if topics is None:
return None
return [_serialize_topic(t) for t in topics]


def build_add_body(
input_data: AddInput,
*,
user_id: str | None,
conversation_id: str | None,
group: str | None,
properties: dict[str, str] | None = None,
) -> dict[str, Any]:
body: dict[str, Any] = {"input": _serialize_input(input_data)}
if user_id is not None:
body["user_id"] = user_id
if conversation_id is not None:
body["conversation_id"] = conversation_id
if group is not None:
body["group"] = group
if properties is not None:
body["properties"] = dict(properties)
return body


Expand All @@ -98,24 +117,25 @@ def build_memory_params(
def build_search_body(
*,
query: str,
topics: list[str] | None,
topics: list[TopicSelector] | None,
user_id: str | None,
conversation_id: str | None,
group: str | None,
retrieval_config: RetrievalConfig | None,
properties: dict[str, str] | None = None,
) -> dict[str, Any]:
body: dict[str, Any] = {"query": query}
if retrieval_config is not None:
body["retrieval_config"] = {
"retrieval_type": retrieval_config.retrieval_type,
"limit": retrieval_config.limit,
}
if topics is not None:
body["topics"] = topics
serialized_topics = _serialize_topics(topics)
if serialized_topics is not None:
body["topics"] = serialized_topics
if user_id is not None:
body["user_id"] = user_id
if conversation_id is not None:
body["conversation_id"] = conversation_id
if group is not None:
body["group"] = group
if properties is not None:
body["properties"] = dict(properties)
return body
2 changes: 1 addition & 1 deletion src/engram/_serialization/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def parse_memory(data: dict[str, Any]) -> Memory:
created_at=data["created_at"],
updated_at=data["updated_at"],
user_id=data.get("user_id"),
conversation_id=data.get("conversation_id"),
tags=data.get("tags"),
score=data.get("score"),
properties=data.get("properties"),
)


Expand Down
48 changes: 44 additions & 4 deletions tests/test_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
StringInput,
ToolCallFuncInput,
ToolCallInput,
Topic,
)
from engram.async_client import DEFAULT_BASE_URL, AsyncEngramClient
from engram.errors import APIError, AuthenticationError, ValidationError
Expand Down Expand Up @@ -130,7 +131,7 @@ async def test_add_conversation() -> None:
result = await client.memories.add(
[{"role": "user", "content": "hi"}],
user_id="u1",
conversation_id="c1",
properties={"conversation_id": "c1"},
)
assert result.run_id == "r3"

Expand Down Expand Up @@ -233,7 +234,7 @@ async def test_add_conversation_content() -> None:
result = await client.memories.add(
ConversationInput(messages=[MessageInput(role="user", content="hi")]),
user_id="u1",
conversation_id="c1",
properties={"conversation_id": "c1"},
)
assert result.run_id == "r5"

Expand Down Expand Up @@ -262,15 +263,15 @@ def handler(request: httpx.Request) -> httpx.Response:
],
metadata={"session_id": "s1"},
),
conversation_id="c1",
properties={"conversation_id": "c1"},
)
body = json.loads(captured[0].content)
conv = body["input"]["conversation"]
assert conv["metadata"] == {"session_id": "s1"}
assert conv["messages"][1]["tool_calls"] == [
{"id": "tc1", "type": "function", "function": {"name": "search", "arguments": "{}"}}
]
assert body["conversation_id"] == "c1"
assert body["properties"] == {"conversation_id": "c1"}


# ── memories.get ────────────────────────────────────────────────────────
Expand Down Expand Up @@ -371,6 +372,45 @@ def handler(request: httpx.Request) -> httpx.Response:
assert body["retrieval_config"]["limit"] == 5


# ── properties / list ───────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_add_sends_properties() -> None:
captured: list[httpx.Request] = []

def handler(request: httpx.Request) -> httpx.Response:
captured.append(request)
return httpx.Response(200, json={"run_id": "r1", "status": "pending"})

client = _make_client_with_handler(handler)
await client.memories.add(
"hello",
properties={"region": "eu"},
)
body = json.loads(captured[0].content)
assert body["properties"] == {"region": "eu"}


@pytest.mark.asyncio
async def test_search_sends_topic_filters() -> None:
captured: list[httpx.Request] = []

def handler(request: httpx.Request) -> httpx.Response:
captured.append(request)
return httpx.Response(200, json={"memories": [], "total": 0})

client = _make_client_with_handler(handler)
await client.memories.search(
query="q",
topics=[Topic(name="t1", properties={"region": "eu"})],
properties={"tier": "pro"},
)
body = json.loads(captured[0].content)
assert body["topics"] == [{"name": "t1", "properties": {"region": "eu"}}]
assert body["properties"] == {"tier": "pro"}


# ── runs.get ────────────────────────────────────────────────────────────


Expand Down
Loading
Loading