diff --git a/helix-db/src/grammar.pest b/helix-db/src/grammar.pest index e7474702..e3d46b61 100644 --- a/helix-db/src/grammar.pest +++ b/helix-db/src/grammar.pest @@ -223,7 +223,7 @@ rerank_mmr = { "RerankMMR" ~ "(" ~ "lambda" ~ ":" ~ evaluates_to_number ~ ("," ~ // --------------------------------------------------------------------- // Vector steps // --------------------------------------------------------------------- -search_vector = { "SearchV" ~ "<" ~ identifier_upper ~ ">" ~ "(" ~ vector_data ~ "," ~ (integer | identifier) ~ ")" }// ~ ("::" ~ pre_filter)? } +search_vector = { "SearchV" ~ "<" ~ identifier_upper ~ ">" ~ "(" ~ vector_data ~ "," ~ (integer | identifier) ~ ("," ~ pre_filter)? ~ ")" } bm25_search = { "SearchBM25" ~ "<" ~ identifier_upper ~ ">" ~ "(" ~ (string_literal | identifier) ~ "," ~ (integer | identifier) ~ ")" } pre_filter = { "PREFILTER" ~ "(" ~ (evaluates_to_bool | anonymous_traversal) ~ ")" } BatchAddV = { "BatchAddV" ~ "<" ~ identifier_upper ~ ">" ~ "(" ~ identifier ~ ")" } diff --git a/helix-db/src/helix_engine/vector_core/utils.rs b/helix-db/src/helix_engine/vector_core/utils.rs index 545b7d02..c6ac2a69 100644 --- a/helix-db/src/helix_engine/vector_core/utils.rs +++ b/helix-db/src/helix_engine/vector_core/utils.rs @@ -129,10 +129,13 @@ impl<'db, 'arena, 'txn, 'q> VectorFilter<'db, 'arena, 'txn, 'q> continue; } - if properties.label == label + // Expand properties into item BEFORE applying filter + // so that filter can access item.get_property() correctly + item.expand_from_vector_without_data(properties); + + if item.label == label && (filter.is_none() || filter.unwrap().iter().all(|f| f(&item, txn))) { - item.expand_from_vector_without_data(properties); result.push(item); break; } diff --git a/helix-db/src/helixc/analyzer/methods/traversal_validation.rs b/helix-db/src/helixc/analyzer/methods/traversal_validation.rs index 439a3fc7..45ff204f 100644 --- a/helix-db/src/helixc/analyzer/methods/traversal_validation.rs +++ b/helix-db/src/helixc/analyzer/methods/traversal_validation.rs @@ -631,56 +631,56 @@ pub(crate) fn validate_traversal<'a>( } }; - // let pre_filter: Option> = match &sv.pre_filter { - // Some(expr) => { - // let (_, stmt) = infer_expr_type( - // ctx, - // expr, - // scope, - // original_query, - // Some(Type::Vector(sv.vector_type.clone())), - // gen_query, - // ); - // // Where/boolean ops don't change the element type, - // // so `cur_ty` stays the same. - // assert!(stmt.is_some()); - // let stmt = stmt.unwrap(); - // let mut gen_traversal = GeneratedTraversal { - // traversal_type: TraversalType::NestedFrom(GenRef::Std("v".to_string())), - // steps: vec![], - // should_collect: ShouldCollect::ToVec, - // source_step: Separator::Empty(SourceStep::Anonymous), - // }; - // match stmt { - // GeneratedStatement::Traversal(tr) => { - // gen_traversal - // .steps - // .push(Separator::Period(GeneratedStep::Where(Where::Ref( - // WhereRef { - // expr: BoExp::Expr(tr), - // }, - // )))); - // } - // GeneratedStatement::BoExp(expr) => { - // gen_traversal - // .steps - // .push(Separator::Period(GeneratedStep::Where(match expr { - // BoExp::Exists(mut traversal) => { - // traversal.should_collect = ShouldCollect::No; - // Where::Ref(WhereRef { - // expr: BoExp::Exists(traversal), - // }) - // } - // _ => Where::Ref(WhereRef { expr }), - // }))); - // } - // _ => unreachable!(), - // } - // Some(vec![BoExp::Expr(gen_traversal)]) - // } - // None => None, - // }; - let pre_filter = None; + let pre_filter: Option> = match &sv.pre_filter { + Some(expr) => { + let (_, stmt) = infer_expr_type( + ctx, + expr, + scope, + original_query, + Some(Type::Vector(sv.vector_type.clone())), + gen_query, + ); + // Where/boolean ops don't change the element type, + // so `cur_ty` stays the same. + if stmt.is_none() { + generate_error!( + ctx, + original_query, + sv.loc.clone(), + E601, + "invalid pre_filter expression" + ); + return None; + } + let stmt = stmt.unwrap(); + match stmt { + GeneratedStatement::Traversal(tr) => { + Some(vec![BoExp::Expr(tr)]) + } + GeneratedStatement::BoExp(expr) => { + match expr { + BoExp::Exists(mut traversal) => { + traversal.should_collect = ShouldCollect::No; + Some(vec![BoExp::Exists(traversal)]) + } + _ => Some(vec![expr]), + } + } + _ => { + generate_error!( + ctx, + original_query, + sv.loc.clone(), + E601, + "pre_filter must be a boolean expression" + ); + return None; + } + } + } + None => None, + }; gen_traversal.traversal_type = TraversalType::Ref; gen_traversal.should_collect = ShouldCollect::ToVec; diff --git a/helix-db/src/helixc/generator/source_steps.rs b/helix-db/src/helixc/generator/source_steps.rs index b491125f..ba6cf37c 100644 --- a/helix-db/src/helixc/generator/source_steps.rs +++ b/helix-db/src/helixc/generator/source_steps.rs @@ -443,7 +443,7 @@ impl Display for SearchVector { self.label, pre_filter .iter() - .map(|f| format!("|v: &HVector, txn: &RoTxn| {f}")) + .map(|f| format!("|val: &HVector, txn: &RoTxn| {f}")) .collect::>() .join(", ") ), diff --git a/helix-db/src/helixc/parser/expression_parse_methods.rs b/helix-db/src/helixc/parser/expression_parse_methods.rs index c754e550..f22eab5d 100644 --- a/helix-db/src/helixc/parser/expression_parse_methods.rs +++ b/helix-db/src/helixc/parser/expression_parse_methods.rs @@ -492,7 +492,25 @@ impl HelixParser { }); } Rule::pre_filter => { - pre_filter = Some(Box::new(self.parse_expression(p)?)); + // Extract the inner expression from PREFILTER(...) + let inner = p.into_inner().next().ok_or_else(|| { + ParserError::from("PREFILTER requires an expression") + })?; + // Handle the inner rule directly (anonymous_traversal or evaluates_to_bool) + let expr = match inner.as_rule() { + Rule::anonymous_traversal => Expression { + loc: inner.loc(), + expr: ExpressionType::Traversal(Box::new(self.parse_anon_traversal(inner)?)), + }, + Rule::evaluates_to_bool => { + let bool_inner = inner.into_inner().next().ok_or_else(|| { + ParserError::from("evaluates_to_bool requires inner expression") + })?; + self.parse_expression(bool_inner)? + }, + _ => self.parse_expression(inner)?, + }; + pre_filter = Some(Box::new(expr)); } _ => { return Err(ParserError::from(format!(