Skip to content
Open
37 changes: 24 additions & 13 deletions graphiti_core/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,21 +967,32 @@ def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityE
episodes = record['episodes']
if provider == GraphProvider.KUZU:
attributes = json.loads(record['attributes']) if record['attributes'] else {}
elif provider == GraphProvider.NEO4J:
# Neo4j: Try new JSON format first, fall back to old spread format
raw_attrs = record.get('attributes', '')
if raw_attrs and isinstance(raw_attrs, str):
# New format: JSON string in e.attributes
attributes = json.loads(raw_attrs)
else:
# Old format: attributes spread as individual properties
all_props = record.get('all_properties', {})
if all_props:
attributes = dict(all_props)
for key in ('uuid', 'source_node_uuid', 'target_node_uuid', 'fact',
'fact_embedding', 'name', 'group_id', 'episodes',
'created_at', 'expired_at', 'valid_at', 'invalid_at',
'reference_time', 'attributes'):
attributes.pop(key, None)
else:
attributes = {}
else:
# FalkorDB, Neptune: Original behavior
attributes = record['attributes']
attributes.pop('uuid', None)
attributes.pop('source_node_uuid', None)
attributes.pop('target_node_uuid', None)
attributes.pop('fact', None)
attributes.pop('fact_embedding', None)
attributes.pop('name', None)
attributes.pop('group_id', None)
attributes.pop('episodes', None)
attributes.pop('created_at', None)
attributes.pop('expired_at', None)
attributes.pop('valid_at', None)
attributes.pop('invalid_at', None)
attributes.pop('reference_time', None)
for key in ('uuid', 'source_node_uuid', 'target_node_uuid', 'fact',
'fact_embedding', 'name', 'group_id', 'episodes',
'created_at', 'expired_at', 'valid_at', 'invalid_at',
'reference_time'):
attributes.pop(key, None)

edge = EntityEdge(
uuid=record['uuid'],
Expand Down
14 changes: 12 additions & 2 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,14 +1417,24 @@ async def add_episode_bulk(

@handle_multiple_group_ids
async def build_communities(
self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
self,
group_ids: list[str] | None = None,
driver: GraphDriver | None = None,
sample_size: int | None = None,
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
"""
Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
the content of these communities.

Parameters
----------
group_ids : list[str] | None
Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
sample_size : int | None
Optional. If set, each community's LLM summary is built from only
the top `sample_size` most representative members (ranked by highest
in-community weighted degree, then by longest summary). This
dramatically reduces LLM cost on large graphs: without sampling,
summary cost grows with total node count; with sampling it grows with
the number of communities. Recommended for graphs with more than
10k nodes.
"""
if driver is None:
driver = self.clients.driver
Expand All @@ -1433,7 +1443,7 @@ async def build_communities(
await remove_communities(driver)

community_nodes, community_edges = await build_communities(
driver, self.llm_client, group_ids
driver, self.llm_client, group_ids, sample_size=sample_size
)

await semaphore_gather(
Expand Down
17 changes: 17 additions & 0 deletions graphiti_core/models/edges/edge_db_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,23 @@ def get_entity_edge_return_query(provider: GraphProvider) -> str:
properties(e) AS attributes
"""

if provider == GraphProvider.NEO4J:
return """
e.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.group_id AS group_id,
e.created_at AS created_at,
e.name AS name,
e.fact AS fact,
e.episodes AS episodes,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at,
COALESCE(e.attributes, '') AS attributes,
properties(e) AS all_properties
"""

return """
e.uuid AS uuid,
n.uuid AS source_node_uuid,
Expand Down
12 changes: 12 additions & 0 deletions graphiti_core/models/nodes/node_db_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,18 @@ def get_entity_node_return_query(provider: GraphProvider) -> str:
n.summary AS summary,
n.attributes AS attributes
"""

if provider == GraphProvider.NEO4J:
return """
n.uuid AS uuid,
n.name AS name,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
COALESCE(n.attributes, '') AS attributes,
properties(n) AS all_properties
"""

return """
n.uuid AS uuid,
Expand Down
23 changes: 23 additions & 0 deletions graphiti_core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,30 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
if provider == GraphProvider.KUZU:
attributes = json.loads(record['attributes']) if record['attributes'] else {}
elif provider == GraphProvider.NEO4J:
# Neo4j: Try new JSON format first, fall back to old spread format
raw_attrs = record.get('attributes', '')
if raw_attrs and isinstance(raw_attrs, str):
# New format: JSON string in n.attributes
attributes = json.loads(raw_attrs)
else:
# Old format: attributes spread as individual properties
all_props = record.get('all_properties', {})
if all_props:
attributes = dict(all_props)
# Remove known system fields
attributes.pop('uuid', None)
attributes.pop('name', None)
attributes.pop('group_id', None)
attributes.pop('name_embedding', None)
attributes.pop('summary', None)
attributes.pop('created_at', None)
attributes.pop('labels', None)
attributes.pop('attributes', None) # Remove the empty attributes field
else:
attributes = {}
else:
# FalkorDB, Neptune: Original behavior
attributes = record['attributes']
attributes.pop('uuid', None)
attributes.pop('name', None)
Expand Down
10 changes: 10 additions & 0 deletions graphiti_core/utils/bulk_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,12 @@ async def add_nodes_and_edges_bulk_tx(
if driver.provider == GraphProvider.KUZU:
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
entity_data['attributes'] = json.dumps(attributes)
elif driver.provider == GraphProvider.NEO4J:
# Neo4j: Serialize attributes to JSON string to support nested structures
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
entity_data['attributes'] = json.dumps(attributes) if attributes else '{}'
else:
# FalkorDB, Neptune: Keep original behavior (spread attributes)
entity_data.update(node.attributes or {})

nodes.append(entity_data)
Expand All @@ -208,7 +213,12 @@ async def add_nodes_and_edges_bulk_tx(
if driver.provider == GraphProvider.KUZU:
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
edge_data['attributes'] = json.dumps(attributes)
elif driver.provider == GraphProvider.NEO4J:
# Neo4j: Serialize attributes to JSON string to support nested structures
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
edge_data['attributes'] = json.dumps(attributes) if attributes else '{}'
else:
# FalkorDB, Neptune: Keep original behavior (spread attributes)
edge_data.update(edge.attributes or {})

edges.append(edge_data)
Expand Down
Loading
Loading