diff --git a/src/query/validator.rs b/src/query/validator.rs index e9e6e50..89a3f4f 100644 --- a/src/query/validator.rs +++ b/src/query/validator.rs @@ -2,8 +2,8 @@ use std::collections::HashSet; use anyhow::{anyhow, Result}; use sqlparser::ast::{ - Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, ObjectName, Query, SetExpr, - Statement, TableFactor, TableWithJoins, + Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArguments, + GroupByWithModifier, ObjectName, Query, SetExpr, Statement, TableFactor, TableWithJoins, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; @@ -207,10 +207,16 @@ fn validate_set_expr( validate_expr(selection, cte_names, depth)?; } - // Validate GROUP BY expressions - if let sqlparser::ast::GroupByExpr::Expressions(exprs, _) = &select.group_by { - for expr in exprs { - validate_expr(expr, cte_names, depth)?; + // Validate GROUP BY expressions and modifiers (GROUPING SETS, ROLLUP, CUBE) + match &select.group_by { + sqlparser::ast::GroupByExpr::Expressions(exprs, modifiers) => { + for expr in exprs { + validate_expr(expr, cte_names, depth)?; + } + validate_group_by_modifiers(modifiers, cte_names, depth)?; + } + sqlparser::ast::GroupByExpr::All(modifiers) => { + validate_group_by_modifiers(modifiers, cte_names, depth)?; } } @@ -610,6 +616,66 @@ fn is_allowed_function(name: &str) -> bool { ALLOWED_FUNCTIONS.contains(&bare_name) } +/// Recursively validate expressions inside GROUP BY modifiers (GROUPING SETS, etc.) +fn validate_group_by_modifiers( + modifiers: &[GroupByWithModifier], + cte_names: &HashSet, + depth: usize, +) -> Result<()> { + for modifier in modifiers { + if let GroupByWithModifier::GroupingSets(expr) = modifier { + validate_expr(expr, cte_names, depth)?; + } + } + Ok(()) +} + +/// Maximum allowed length argument for string amplification functions (lpad, rpad, repeat). +const MAX_STRING_PAD_LENGTH: i64 = 100_000; + +/// Functions whose first numeric argument (length/count) must be capped to prevent +/// memory exhaustion via string amplification. +const STRING_AMPLIFICATION_FUNCTIONS: &[&str] = &["lpad", "rpad", "repeat"]; + +/// Validate that string amplification functions (lpad, rpad, repeat) don't have +/// excessively large length arguments that could exhaust memory. +fn validate_string_amplification(func_name: &str, func: &Function) -> Result<()> { + let bare_name = func_name.rsplit('.').next().unwrap_or(func_name); + if !STRING_AMPLIFICATION_FUNCTIONS.contains(&bare_name) { + return Ok(()); + } + + if let FunctionArguments::List(arg_list) = &func.args { + // The length/count argument is the 2nd arg for lpad/rpad, 2nd for repeat + let length_arg_idx = match bare_name { + "lpad" | "rpad" => 1, // lpad(string, length, [fill]) + "repeat" => 1, // repeat(string, count) + _ => return Ok(()), + }; + + if let Some(arg) = arg_list.args.get(length_arg_idx) { + if let FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(v))) + | FunctionArg::Named { + arg: FunctionArgExpr::Expr(Expr::Value(v)), + .. + } = arg + { + if let sqlparser::ast::Value::Number(n, _) = &v.value { + if let Ok(num) = n.parse::() { + if num > MAX_STRING_PAD_LENGTH { + return Err(anyhow!( + "Function '{bare_name}' length argument {num} exceeds maximum ({MAX_STRING_PAD_LENGTH})" + )); + } + } + } + } + } + } + + Ok(()) +} + fn validate_function(func: &Function, cte_names: &HashSet, depth: usize) -> Result<()> { let func_name = func.name.to_string().to_lowercase(); @@ -617,6 +683,9 @@ fn validate_function(func: &Function, cte_names: &HashSet, depth: usize) return Err(anyhow!("Function '{}' is not allowed", func_name)); } + // Check string amplification functions for excessive length arguments + validate_string_amplification(&func_name, func)?; + if let FunctionArguments::List(arg_list) = &func.args { for arg in &arg_list.args { if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) @@ -628,6 +697,15 @@ fn validate_function(func: &Function, cte_names: &HashSet, depth: usize) validate_expr(expr, cte_names, depth)?; } } + + // Validate expressions inside function argument clauses (e.g. ORDER BY within aggregates) + for clause in &arg_list.clauses { + if let FunctionArgumentClause::OrderBy(order_exprs) = clause { + for order_expr in order_exprs { + validate_expr(&order_expr.expr, cte_names, depth)?; + } + } + } } // Validate FILTER (WHERE ...) clause @@ -1014,4 +1092,85 @@ mod tests { validate_query("SELECT * FROM blocks FETCH FIRST 10 ROWS ONLY").is_err() ); } + + // === Audit finding: GROUPING SETS bypasses function allowlist === + + #[test] + fn test_rejects_set_config_in_grouping_sets() { + assert!(validate_query( + "SELECT 1 FROM blocks GROUP BY ALL GROUPING SETS ((set_config('a','b',true)))" + ) + .is_err()); + } + + #[test] + fn test_rejects_dangerous_function_in_grouping_sets_with_expressions() { + assert!(validate_query( + "SELECT num FROM blocks GROUP BY num GROUPING SETS ((pg_sleep(1)))" + ) + .is_err()); + } + + #[test] + fn test_allows_normal_group_by() { + assert!(validate_query("SELECT num, COUNT(*) FROM blocks GROUP BY num").is_ok()); + assert!(validate_query( + "SELECT num, COUNT(*) FROM blocks GROUP BY ALL" + ) + .is_ok()); + } + + // === Audit finding: Aggregate ORDER BY bypasses function allowlist === + + #[test] + fn test_rejects_set_config_in_aggregate_order_by() { + assert!(validate_query( + "SELECT string_agg(hash, ',' ORDER BY set_config('a','b',true)) FROM blocks" + ) + .is_err()); + } + + #[test] + fn test_rejects_pg_sleep_in_aggregate_order_by() { + assert!(validate_query( + "SELECT array_agg(num ORDER BY pg_sleep(1)) FROM blocks" + ) + .is_err()); + } + + #[test] + fn test_allows_normal_aggregate_with_order_by() { + assert!(validate_query( + "SELECT string_agg(hash, ',' ORDER BY num) FROM blocks" + ) + .is_ok()); + assert!(validate_query( + "SELECT array_agg(num ORDER BY num DESC) FROM blocks" + ) + .is_ok()); + } + + // === Audit finding: lpad/rpad/repeat memory exhaustion === + + #[test] + fn test_rejects_lpad_huge_length() { + assert!(validate_query("SELECT lpad('x', 999999999) FROM blocks").is_err()); + } + + #[test] + fn test_rejects_rpad_huge_length() { + assert!(validate_query("SELECT rpad('x', 999999999) FROM blocks").is_err()); + } + + #[test] + fn test_rejects_repeat_huge_count() { + assert!(validate_query("SELECT repeat('x', 999999999) FROM blocks").is_err()); + } + + #[test] + fn test_allows_lpad_rpad_small_length() { + assert!(validate_query("SELECT lpad(hash, 66, '0') FROM blocks").is_ok()); + assert!(validate_query("SELECT rpad(hash, 66, '0') FROM blocks").is_ok()); + assert!(validate_query("SELECT repeat('0', 10) FROM blocks").is_ok()); + } }