diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index d61ff3ec8..b6390d2eb 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -967,21 +967,32 @@ def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityE episodes = record['episodes'] if provider == GraphProvider.KUZU: attributes = json.loads(record['attributes']) if record['attributes'] else {} + elif provider == GraphProvider.NEO4J: + # Neo4j: Try new JSON format first, fall back to old spread format + raw_attrs = record.get('attributes', '') + if raw_attrs and isinstance(raw_attrs, str): + # New format: JSON string in e.attributes + attributes = json.loads(raw_attrs) + else: + # Old format: attributes spread as individual properties + all_props = record.get('all_properties', {}) + if all_props: + attributes = dict(all_props) + for key in ('uuid', 'source_node_uuid', 'target_node_uuid', 'fact', + 'fact_embedding', 'name', 'group_id', 'episodes', + 'created_at', 'expired_at', 'valid_at', 'invalid_at', + 'reference_time', 'attributes'): + attributes.pop(key, None) + else: + attributes = {} else: + # FalkorDB, Neptune: Original behavior attributes = record['attributes'] - attributes.pop('uuid', None) - attributes.pop('source_node_uuid', None) - attributes.pop('target_node_uuid', None) - attributes.pop('fact', None) - attributes.pop('fact_embedding', None) - attributes.pop('name', None) - attributes.pop('group_id', None) - attributes.pop('episodes', None) - attributes.pop('created_at', None) - attributes.pop('expired_at', None) - attributes.pop('valid_at', None) - attributes.pop('invalid_at', None) - attributes.pop('reference_time', None) + for key in ('uuid', 'source_node_uuid', 'target_node_uuid', 'fact', + 'fact_embedding', 'name', 'group_id', 'episodes', + 'created_at', 'expired_at', 'valid_at', 'invalid_at', + 'reference_time'): + attributes.pop(key, None) edge = EntityEdge( uuid=record['uuid'], diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py 
index 8c4943e89..975c46f42 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -1417,7 +1417,10 @@ async def add_episode_bulk( @handle_multiple_group_ids async def build_communities( - self, group_ids: list[str] | None = None, driver: GraphDriver | None = None + self, + group_ids: list[str] | None = None, + driver: GraphDriver | None = None, + sample_size: int | None = None, ) -> tuple[list[CommunityNode], list[CommunityEdge]]: """ Use a community clustering algorithm to find communities of nodes. Create community nodes summarising @@ -1425,6 +1428,13 @@ async def build_communities( ---------- group_ids : list[str] | None Optional. Create communities only for the listed group_ids. If blank the entire graph will be used. + sample_size : int | None + Optional. If set, each community's LLM summary is built from only + the top-K most representative members (highest in-community + weighted degree, then longest summary). Dramatically reduces LLM + cost on large graphs — without sampling, summary cost grows with + total node count; with sampling it grows with the number of + communities. Recommended for graphs >10k nodes. 
""" if driver is None: driver = self.clients.driver @@ -1433,7 +1443,7 @@ async def build_communities( await remove_communities(driver) community_nodes, community_edges = await build_communities( - driver, self.llm_client, group_ids + driver, self.llm_client, group_ids, sample_size=sample_size ) await semaphore_gather( diff --git a/graphiti_core/models/edges/edge_db_queries.py b/graphiti_core/models/edges/edge_db_queries.py index 15d4c71f6..4af97795b 100644 --- a/graphiti_core/models/edges/edge_db_queries.py +++ b/graphiti_core/models/edges/edge_db_queries.py @@ -205,6 +205,23 @@ def get_entity_edge_return_query(provider: GraphProvider) -> str: properties(e) AS attributes """ + if provider == GraphProvider.NEO4J: + return """ + e.uuid AS uuid, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + e.group_id AS group_id, + e.created_at AS created_at, + e.name AS name, + e.fact AS fact, + e.episodes AS episodes, + e.expired_at AS expired_at, + e.valid_at AS valid_at, + e.invalid_at AS invalid_at, + COALESCE(e.attributes, '') AS attributes, + properties(e) AS all_properties + """ + return """ e.uuid AS uuid, n.uuid AS source_node_uuid, diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index 6b91f0a0b..e62477e2f 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -279,6 +279,18 @@ def get_entity_node_return_query(provider: GraphProvider) -> str: n.summary AS summary, n.attributes AS attributes """ + + if provider == GraphProvider.NEO4J: + return """ + n.uuid AS uuid, + n.name AS name, + n.group_id AS group_id, + n.created_at AS created_at, + n.summary AS summary, + labels(n) AS labels, + COALESCE(n.attributes, '') AS attributes, + properties(n) AS all_properties + """ return """ n.uuid AS uuid, diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 7527bb37b..8c8ad6f10 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py 
@@ -1033,7 +1033,30 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode: def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode: if provider == GraphProvider.KUZU: attributes = json.loads(record['attributes']) if record['attributes'] else {} + elif provider == GraphProvider.NEO4J: + # Neo4j: Try new JSON format first, fall back to old spread format + raw_attrs = record.get('attributes', '') + if raw_attrs and isinstance(raw_attrs, str): + # New format: JSON string in n.attributes + attributes = json.loads(raw_attrs) + else: + # Old format: attributes spread as individual properties + all_props = record.get('all_properties', {}) + if all_props: + attributes = dict(all_props) + # Remove known system fields + attributes.pop('uuid', None) + attributes.pop('name', None) + attributes.pop('group_id', None) + attributes.pop('name_embedding', None) + attributes.pop('summary', None) + attributes.pop('created_at', None) + attributes.pop('labels', None) + attributes.pop('attributes', None) # Remove the empty attributes field + else: + attributes = {} else: + # FalkorDB, Neptune: Original behavior attributes = record['attributes'] attributes.pop('uuid', None) attributes.pop('name', None) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 40e9c57b7..1581f7696 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -181,7 +181,12 @@ async def add_nodes_and_edges_bulk_tx( if driver.provider == GraphProvider.KUZU: attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {} entity_data['attributes'] = json.dumps(attributes) + elif driver.provider == GraphProvider.NEO4J: + # Neo4j: Serialize attributes to JSON string to support nested structures + attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {} + entity_data['attributes'] = json.dumps(attributes) if attributes else '{}' else: + # FalkorDB, Neptune: Keep 
async def _build_group_projection(
    driver: GraphDriver, group_id: str
) -> dict[str, list[Neighbor]]:
    """Fetch the RELATES_TO projection for every entity in one group.

    One query is issued per entity returned by
    ``EntityNode.get_by_group_ids``. Each query counts the RELATES_TO edges
    between that entity and each neighboring entity in the same group.

    Args:
        driver: Graph driver used to run the queries.
        group_id: Group whose entities are projected.

    Returns:
        Mapping from each entity uuid to its in-group neighbors (neighbor
        uuid plus the edge count reported by the query). An entity for which
        the query returns no rows maps to an empty list.
    """
    # The MATCH clause depends only on the provider, so the full query text
    # is assembled once here rather than once per node inside the loop.
    if driver.provider == GraphProvider.KUZU:
        # On Kuzu the pattern goes through an intermediate RelatesToNode_.
        match_query = """
        MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m: Entity {group_id: $group_id})
        """
    else:
        match_query = """
        MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[e:RELATES_TO]-(m: Entity {group_id: $group_id})
        """
    neighbor_query = (
        match_query
        + """
        WITH count(e) AS count, m.uuid AS uuid
        RETURN
            uuid,
            count
        """
    )

    projection: dict[str, list[Neighbor]] = {}
    nodes = await EntityNode.get_by_group_ids(driver, [group_id])
    for node in nodes:
        records, _, _ = await driver.execute_query(
            neighbor_query,
            uuid=node.uuid,
            group_id=group_id,
        )
        projection[node.uuid] = [
            Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records
        ]
    return projection
+ """ if driver.graph_operations_interface: try: - return await driver.graph_operations_interface.get_community_clusters(driver, group_ids) + clusters = await driver.graph_operations_interface.get_community_clusters( + driver, group_ids + ) + if return_projection: + return clusters, {} + return clusters except NotImplementedError: pass community_clusters: list[list[EntityNode]] = [] + combined_projection: dict[str, list[Neighbor]] = {} if group_ids is None: group_id_values, _, _ = await driver.execute_query( @@ -51,31 +110,9 @@ async def get_community_clusters( group_ids = group_id_values[0]['group_ids'] if group_id_values else [] for group_id in group_ids: - projection: dict[str, list[Neighbor]] = {} - nodes = await EntityNode.get_by_group_ids(driver, [group_id]) - for node in nodes: - match_query = """ - MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[e:RELATES_TO]-(m: Entity {group_id: $group_id}) - """ - if driver.provider == GraphProvider.KUZU: - match_query = """ - MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m: Entity {group_id: $group_id}) - """ - records, _, _ = await driver.execute_query( - match_query - + """ - WITH count(e) AS count, m.uuid AS uuid - RETURN - uuid, - count - """, - uuid=node.uuid, - group_id=group_id, - ) - - projection[node.uuid] = [ - Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records - ] + projection = await _build_group_projection(driver, group_id) + if return_projection: + combined_projection.update(projection) cluster_uuids = label_propagation(projection) @@ -87,48 +124,108 @@ async def get_community_clusters( ) ) + if return_projection: + return community_clusters, combined_projection return community_clusters +LABEL_PROPAGATION_OSCILLATION_WINDOW = 8 +_LABEL_PROPAGATION_RNG_SEED = 42 + + def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]: - # Implement the label propagation community detection algorithm. 
- # 1. Start with each node being assigned its own community - # 2. Each node will take on the community of the plurality of its neighbors - # 3. Ties are broken by going to the largest community - # 4. Continue until no communities change during propagation + # Asynchronous label propagation with shuffled node order and oscillation + # detection. This is the form described by Raghavan et al. (2007), + # "Near linear time algorithm to detect community structures in + # large-scale networks". + # + # Algorithm: + # 1. Each node starts in its own community. + # 2. In each pass, visit nodes in a FRESH random order. + # 3. For each node, move it to the plurality-weight community among its + # neighbors, using the CURRENT (in-place) community assignments. + # Reading the live state (not a snapshot) is the key correctness fix + # over the naive synchronous form — once a node flips, its neighbors + # see the new label immediately, breaking ping-pong loops. + # 4. Break ties deterministically by preferring the higher community id, + # and only move if the candidate strictly improves on the current + # support (so well-connected nodes stay put under ties). + # 5. Terminate on natural convergence (no node changed in a full pass). + # As a belt-and-suspenders safeguard, also break if the full state + # repeats within a short recent window — async LPA is known to + # converge on undirected graphs, but a cycle detector catches any + # edge case we have not anticipated. + # + # Rationale: the synchronous form (batch update from a frozen snapshot) + # is vulnerable to flip-flop oscillation on graphs with high-degree hub + # nodes. Tied candidate scores cause groups of nodes to swap labels + # symmetrically every iteration, which repeats forever. Async updates + # eliminate that class of failure and empirically converge in O(log n) + # iterations on real-world graphs. 
+ + import random + from collections import deque community_map = {uuid: i for i, uuid in enumerate(projection.keys())} + node_order = list(projection.keys()) + + rng = random.Random(_LABEL_PROPAGATION_RNG_SEED) + recent_state_hashes: deque[int] = deque(maxlen=LABEL_PROPAGATION_OSCILLATION_WINDOW) while True: + rng.shuffle(node_order) no_change = True - new_community_map: dict[str, int] = {} - for uuid, neighbors in projection.items(): + for uuid in node_order: + neighbors = projection[uuid] + if not neighbors: + continue + curr_community = community_map[uuid] community_candidates: dict[int, int] = defaultdict(int) for neighbor in neighbors: + # In-place read — picks up changes from earlier in this pass. community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count - community_lst = [ - (count, community) for community, count in community_candidates.items() - ] - - community_lst.sort(reverse=True) - candidate_rank, community_candidate = community_lst[0] if community_lst else (0, -1) - if community_candidate != -1 and candidate_rank > 1: - new_community = community_candidate - else: - new_community = max(community_candidate, curr_community) - new_community_map[uuid] = new_community + if not community_candidates: + continue + + # Pick (count desc, community_id desc) — determinism on ties. + best_community, best_count = max( + community_candidates.items(), + key=lambda item: (item[1], item[0]), + ) + curr_support = community_candidates.get(curr_community, 0) + + # Only move on strict improvement, or on tie with a deterministic + # preference for the higher community id. This prevents a node + # from churning between equally-supported communities forever. 
def _select_representative_members(
    community_cluster: list[EntityNode],
    projection: dict[str, list[Neighbor]] | None,
    sample_size: int,
) -> list[EntityNode]:
    """Pick the top-K members most likely to characterize the community.

    Members are ranked in descending order by: in-community weighted degree,
    then summary length, then name (so ties resolve deterministically).
    In-community degree is the sum of ``edge_count`` over a member's
    projection neighbors that are themselves members of this cluster.

    Args:
        community_cluster: All members of the community.
        projection: Optional ``{uuid -> neighbors}`` mapping. When it is
            ``None`` or empty every member's degree is 0, so the ranking
            falls back to summary length and then name.
        sample_size: Maximum number of members to return.

    Returns:
        ``community_cluster`` itself (same list object) when it already fits
        within ``sample_size``; an empty list when ``sample_size`` is zero or
        negative; otherwise a new list holding the ``sample_size``
        highest-ranked members.
    """
    if len(community_cluster) <= sample_size:
        return community_cluster
    if sample_size <= 0:
        # Without this guard a negative bound reaches the slice below and is
        # interpreted as "drop the last |sample_size| members".
        return []

    member_uuids = {member.uuid for member in community_cluster}

    def in_community_degree(entity: EntityNode) -> int:
        if not projection:
            return 0
        return sum(
            neighbor.edge_count
            for neighbor in projection.get(entity.uuid, [])
            if neighbor.node_uuid in member_uuids
        )

    ranked = sorted(
        community_cluster,
        key=lambda entity: (
            in_community_degree(entity),
            len(entity.summary or ''),
            entity.name,
        ),
        reverse=True,
    )
    return ranked[:sample_size]
+ """ + summary_members = ( + _select_representative_members(community_cluster, projection, sample_size) + if sample_size is not None + else community_cluster + ) + + summaries = [entity.summary for entity in summary_members] length = len(summaries) while length > 1: odd_one_out: str | None = None @@ -196,8 +351,10 @@ async def build_community( summaries = new_summaries length = len(summaries) - summary = truncate_at_sentence(summaries[0], MAX_SUMMARY_CHARS) - name = await generate_summary_description(llm_client, summary) + summary = truncate_at_sentence(summaries[0], MAX_SUMMARY_CHARS) if summaries else '' + name = ( + await generate_summary_description(llm_client, summary) if summary else 'community' + ) now = utc_now() community_node = CommunityNode( name=name, @@ -208,7 +365,13 @@ async def build_community( ) community_edges = build_community_edges(community_cluster, community_node, now) - logger.debug(f'Built community {community_node.uuid} with {len(community_edges)} edges') + logger.debug( + 'Built community %s with %d member edges (summary from %d/%d members)', + community_node.uuid, + len(community_edges), + len(summary_members), + len(community_cluster), + ) return community_node, community_edges @@ -217,14 +380,35 @@ async def build_communities( driver: GraphDriver, llm_client: LLMClient, group_ids: list[str] | None, + *, + sample_size: int | None = None, ) -> tuple[list[CommunityNode], list[CommunityEdge]]: - community_clusters = await get_community_clusters(driver, group_ids) + """Cluster entities into communities and build a summary node for each. + + Args: + driver: Graph driver. + llm_client: LLM client for community summarization. + group_ids: Scope clustering to these group ids (or all if None). + sample_size: If set, each community's summary is built from only + the top-K most representative members (by in-community weighted + degree, then summary length). Reduces LLM cost from O(total nodes) + to O(num_communities * sample_size). 
Recommended for graphs + >10k nodes. + """ + clusters_result = await get_community_clusters(driver, group_ids, return_projection=True) + assert isinstance(clusters_result, tuple) + community_clusters, projection = clusters_result semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY) async def limited_build_community(cluster): async with semaphore: - return await build_community(llm_client, cluster) + return await build_community( + llm_client, + cluster, + projection=projection, + sample_size=sample_size, + ) communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list( await semaphore_gather( diff --git a/tests/test_neo4j_nested_attributes_int.py b/tests/test_neo4j_nested_attributes_int.py new file mode 100644 index 000000000..7de4ae6cd --- /dev/null +++ b/tests/test_neo4j_nested_attributes_int.py @@ -0,0 +1,208 @@ +"""Integration test for Neo4j nested attributes serialization. + +Tests that entities and edges with complex nested attributes (Maps of Lists, Lists of Maps) +are properly serialized to JSON strings for Neo4j storage. + +This test addresses a bug where Neo4j would reject entity/edge attributes containing +nested structures with the error: +Neo.ClientError.Statement.TypeError - Property values can only be of primitive types +or arrays thereof. 
@pytest.mark.integration
async def test_nested_entity_attributes(graph_driver, embedder):
    """Round-trip an entity whose attributes hold nested maps and lists.

    Skipped on every provider other than Neo4j.
    """
    if graph_driver.provider != GraphProvider.NEO4J:
        pytest.skip("This test is specific to Neo4j nested attributes serialization")

    nested_attributes = {
        # Flat list of primitives.
        "discovered_resources": ["resource1", "resource2", "resource3"],
        # Map whose values are a list and another map.
        "metadata": {
            "analysis": ["analysis_item1", "analysis_item2"],
            "nested_map": {"key1": "value1", "key2": "value2"},
        },
        # Map nested two levels deep.
        "activity_log": {
            "initiated_actions": ["action1", "action2"],
            "completed_tasks": {
                "task_list": ["task1", "task2"],
                "priority": "high",
            },
        },
        # Plain scalar values.
        "count": 42,
        "status": "active",
    }
    entity = EntityNode(
        uuid="test-entity-nested-attrs-001",
        name="Test Entity with Nested Attributes",
        group_id="test-group-nested",
        labels=["Entity", "TestType"],
        created_at=datetime.now(UTC),
        summary="Test entity for nested attributes",
        attributes=nested_attributes,
    )

    # Persist, then read the same uuid back.
    await entity.generate_name_embedding(embedder)
    await entity.save(graph_driver)
    retrieved = await EntityNode.get_by_uuid(graph_driver, entity.uuid)

    assert retrieved is not None, "Entity should be retrievable"
    assert retrieved.uuid == entity.uuid
    assert retrieved.name == entity.name

    # Whole-dict equality first, then spot checks at each nesting level.
    assert retrieved.attributes == entity.attributes, "Attributes should be preserved exactly"
    loaded = retrieved.attributes
    assert loaded["discovered_resources"] == ["resource1", "resource2", "resource3"]
    assert loaded["metadata"]["analysis"] == ["analysis_item1", "analysis_item2"]
    assert loaded["metadata"]["nested_map"]["key1"] == "value1"
    assert loaded["activity_log"]["completed_tasks"]["task_list"] == ["task1", "task2"]
    assert loaded["count"] == 42
    assert loaded["status"] == "active"
@pytest.mark.integration
async def test_empty_and_none_attributes(graph_driver, embedder):
    """Round-trip an entity with ``{}`` attributes and one with a ``None`` value.

    Skipped on every provider other than Neo4j.
    """
    if graph_driver.provider != GraphProvider.NEO4J:
        pytest.skip("This test is specific to Neo4j nested attributes serialization")

    async def save_and_reload(node):
        # Embed, persist, then fetch the same uuid back from the graph.
        await node.generate_name_embedding(embedder)
        await node.save(graph_driver)
        return await EntityNode.get_by_uuid(graph_driver, node.uuid)

    # Case 1: empty attribute map.
    entity_empty = EntityNode(
        uuid="test-entity-empty-attrs-001",
        name="Entity with Empty Attributes",
        group_id="test-group-nested",
        labels=["Entity", "TestType"],
        created_at=datetime.now(UTC),
        summary="Test entity with empty attributes",
        attributes={},
    )
    retrieved_empty = await save_and_reload(entity_empty)
    assert retrieved_empty is not None
    assert retrieved_empty.attributes == {}

    # Case 2: attribute map containing an explicit None value.
    entity_none = EntityNode(
        uuid="test-entity-none-attrs-001",
        name="Entity with None Attributes",
        group_id="test-group-nested",
        labels=["Entity", "TestType"],
        created_at=datetime.now(UTC),
        summary="Test entity with None attributes",
        attributes={"key1": None, "key2": "value2"},
    )
    retrieved_none = await save_and_reload(entity_none)
    assert retrieved_none is not None
    assert retrieved_none.attributes["key1"] is None
    assert retrieved_none.attributes["key2"] == "value2"
def _assert_partition(clusters: list[list[str]], expected_nodes: set[str]) -> None:
    """Assert that ``clusters`` is an exact partition of ``expected_nodes``.

    Args:
        clusters: Clusters to check, each a list of node ids.
        expected_nodes: The node ids that must be covered.

    Raises:
        AssertionError: If a node id occurs more than once across
            ``clusters``, or if the set of ids seen differs from
            ``expected_nodes`` in either direction.
    """
    seen: set[str] = set()
    for cluster in clusters:
        for node in cluster:
            assert node not in seen, f"node {node} appears in multiple clusters"
            seen.add(node)

    # Report both directions of the mismatch: a stray extra node would
    # otherwise fail with an empty "missing nodes" set and no hint why.
    missing = expected_nodes - seen
    unexpected = seen - expected_nodes
    assert not missing and not unexpected, (
        f"missing nodes: {missing}; unexpected nodes: {unexpected}"
    )
def test_real_world_pathological_graph_converges():
    """Regression test from an observed production failure.

    A 48-node knowledge graph with a central "Threshold" node
    (uuid `d689c03c`) connected to 14+ entities caused the synchronous
    batch implementation to oscillate indefinitely — a fixed subset of
    19 nodes kept flipping between two states forever.

    This projection is a simplified version of the failing graph. With
    the synchronous implementation it never returned; the async form
    converges in milliseconds.
    """
    hub = "hub"
    sat_heavy = [f"sat_h{i}" for i in range(4)]  # strong connections to hub
    sat_light = [f"sat_l{i}" for i in range(10)]  # weak connections to hub

    edges: list[tuple[str, str, int]] = []
    # Strong ties: hub <-> each heavy satellite (edge count 29).
    edges.extend((hub, sat, 29) for sat in sat_heavy)
    # Weak ties: hub <-> each light satellite (edge count 1).
    edges.extend((hub, sat, 1) for sat in sat_light)
    # Pair up consecutive light satellites (0-1, 2-3, ...) so each pair forms
    # a triangle with the hub and creates tie ambiguity.
    for i in range(0, len(sat_light) - 1, 2):
        edges.append((sat_light[i], sat_light[i + 1], 1))
    # Two floating dyads with no link to the hub.
    edges.append(("pair_a1", "pair_a2", 1))
    edges.append(("pair_b1", "pair_b2", 1))

    projection = _make_projection(edges)

    # perf_counter is monotonic; time.time() is wall-clock and can jump if
    # the system clock is adjusted mid-test, skewing the elapsed check.
    start = time.perf_counter()
    clusters = label_propagation(projection)
    elapsed = time.perf_counter() - start

    all_nodes = {hub, *sat_heavy, *sat_light, "pair_a1", "pair_a2", "pair_b1", "pair_b2"}
    _assert_partition(clusters, all_nodes)
    assert elapsed < 1.0, f"pathological graph should converge fast; took {elapsed:.2f}s"
    # Sanity: the hub's cluster must also contain every heavy satellite.
    hub_cluster = next(c for c in clusters if hub in c)
    for sat in sat_heavy:
        assert sat in hub_cluster, f"{sat} should be in hub's community"
+ """ + edges = [ + ("a", "b", 1), + ("b", "c", 1), + ("c", "a", 1), + ("d", "e", 1), + ("e", "f", 1), + ("f", "d", 1), + ("a", "d", 1), + ] + projection = _make_projection(edges) + + first = label_propagation(projection) + second = label_propagation(projection) + + # Canonicalize (sort within cluster, sort list of clusters) + def canon(cs: list[list[str]]) -> list[list[str]]: + return sorted([sorted(c) for c in cs]) + + assert canon(first) == canon(second) + + +@pytest.mark.parametrize("n", [50, 200]) +def test_ring_graph_of_varying_sizes(n: int): + """Rings are edge cases for label propagation.""" + edges = [(f"r{i}", f"r{(i + 1) % n}", 1) for i in range(n)] + projection = _make_projection(edges) + start = time.time() + clusters = label_propagation(projection) + elapsed = time.time() - start + _assert_partition(clusters, {f"r{i}" for i in range(n)}) + assert elapsed < 2.0, f"ring of {n} should converge fast; took {elapsed:.2f}s"