diff --git a/src/api/views.rs b/src/api/views.rs index ea52bb2..148a0e5 100644 --- a/src/api/views.rs +++ b/src/api/views.rs @@ -178,10 +178,12 @@ pub async fn create_view( ))); } - // Validate SQL is SELECT only - let sql_upper = req.sql.trim().to_uppercase(); - if !sql_upper.starts_with("SELECT") { - return Err(ApiError::BadRequest("SQL must be a SELECT statement".to_string())); + // Validate SQL is query-only (SELECT or WITH ... SELECT) + let sql_upper = req.sql.trim_start().to_uppercase(); + if !sql_upper.starts_with("SELECT") && !sql_upper.starts_with("WITH") { + return Err(ApiError::BadRequest( + "SQL must be a SELECT statement (CTEs with WITH are allowed)".to_string(), + )); } // Parse signature if provided @@ -209,8 +211,8 @@ pub async fn create_view( let sql = if let Some(ref sig) = signature { let sql = sig.normalize_table_references(&req.sql); let sql = sig.rewrite_filters_for_pushdown(&sql); - let cte = sig.to_cte_sql_clickhouse(); - format!("WITH {} {}", cte, sql) + let ctes = vec![sig.to_cte_sql_clickhouse()]; + crate::query::merge_ctes_into_query(&sql, &ctes) } else { req.sql.clone() }; diff --git a/src/clickhouse.rs b/src/clickhouse.rs index 072f151..a3ee9a0 100644 --- a/src/clickhouse.rs +++ b/src/clickhouse.rs @@ -103,7 +103,7 @@ impl ClickHouseEngine { } let ctes: Vec = sigs.iter().map(|sig| sig.to_cte_sql_clickhouse()).collect(); - format!("WITH {} {sql}", ctes.join(", ")) + crate::query::merge_ctes_into_query(&sql, &ctes) } else { sql.to_string() }; diff --git a/src/query/mod.rs b/src/query/mod.rs index 8f7470c..5dfab9e 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -24,3 +24,77 @@ pub fn convert_hex_literals_postgres(sql: &str) -> String { .replace_all(sql, r"'\x$1'") .into_owned() } + +/// Merge generated CTE definitions into a user query. +/// +/// If query already starts with a WITH clause, generated CTEs are prepended +/// to the existing CTE list. Otherwise, a new WITH clause is added. +pub fn merge_ctes_into_query(sql: &str, generated_ctes: &[String]) -> String { + if generated_ctes.is_empty() { + return sql.to_string(); + } + + let ctes = generated_ctes.join(", "); + let trimmed = sql.trim_start(); + let leading_ws = &sql[..sql.len() - trimmed.len()]; + + if starts_with_keyword(trimmed, "WITH RECURSIVE") { + let rest = trimmed["WITH RECURSIVE".len()..].trim_start(); + return format!("{leading_ws}WITH RECURSIVE {ctes}, {rest}"); + } + + if starts_with_keyword(trimmed, "WITH") { + let rest = trimmed["WITH".len()..].trim_start(); + return format!("{leading_ws}WITH {ctes}, {rest}"); + } + + format!("{leading_ws}WITH {ctes} {trimmed}") +} + +fn starts_with_keyword(input: &str, keyword: &str) -> bool { + input.len() >= keyword.len() + && input[..keyword.len()].eq_ignore_ascii_case(keyword) + && input[keyword.len()..] + .chars() + .next() + .is_none_or(|c| c.is_ascii_whitespace()) +} + +#[cfg(test)] +mod tests { + use super::merge_ctes_into_query; + + #[test] + fn test_merge_ctes_into_plain_select() { + let merged = merge_ctes_into_query( + "SELECT n FROM numbers", + &["Transfer AS (SELECT 1)".to_string()], + ); + assert_eq!(merged, "WITH Transfer AS (SELECT 1) SELECT n FROM numbers"); + } + + #[test] + fn test_merge_ctes_into_existing_with() { + let merged = merge_ctes_into_query( + "WITH numbers AS (SELECT 1 AS n) SELECT n FROM numbers", + &["Transfer AS (SELECT 1)".to_string()], + ); + assert_eq!( + merged, + "WITH Transfer AS (SELECT 1), numbers AS (SELECT 1 AS n) SELECT n FROM numbers" + ); + } + + #[test] + fn test_merge_ctes_into_existing_with_recursive() { + let merged = merge_ctes_into_query( + "WITH RECURSIVE r AS (SELECT 1) SELECT * FROM r", + &["Transfer AS (SELECT 1)".to_string()], + ); + assert_eq!( + merged, + "WITH RECURSIVE Transfer AS (SELECT 1), r AS (SELECT 1) SELECT * FROM r" + ); + } + +} diff --git a/src/service/mod.rs b/src/service/mod.rs index 73f0065..f8bd015 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -174,7 +174,7 @@ pub async fn execute_query_postgres( .iter() .map(|sig| sig.to_cte_sql_postgres_filtered(filter)) .collect(); - format!("WITH {} {sql}", ctes.join(", ")) + crate::query::merge_ctes_into_query(&sql, &ctes) } else { sql.to_string() }; @@ -502,4 +502,3 @@ mod tests { } } - diff --git a/tests/api_live_test.rs b/tests/api_live_test.rs index 357e829..b0da653 100644 --- a/tests/api_live_test.rs +++ b/tests/api_live_test.rs @@ -235,6 +235,73 @@ async fn test_query_with_signature_cte() { } } +#[tokio::test] +#[serial(db)] +async fn test_query_with_user_cte_succeeds_without_signature() { + let db = TestDb::empty().await; + let broadcaster = Arc::new(Broadcaster::new()); + let (pools, chain_id) = make_pools(db.pool.clone()); + let mut app = make_test_service(pools, chain_id, broadcaster).await; + + let response = app + .call( + Request::builder() + .method("GET") + .uri("/query?sql=WITH%20numbers%20AS%20%28SELECT%201%20AS%20n%29%20SELECT%20n%20FROM%20numbers&chainId=1") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["ok"], true); + assert_eq!(json["columns"], serde_json::json!(["n"])); + assert_eq!(json["row_count"], 1); +} + +#[tokio::test] +#[serial(db)] +async fn test_query_with_signature_and_user_cte_succeeds() { + let db = TestDb::empty().await; + let broadcaster = Arc::new(Broadcaster::new()); + let (pools, chain_id) = make_pools(db.pool.clone()); + let mut app = make_test_service(pools, chain_id, broadcaster).await; + + let sig = "Transfer(address%20indexed%20from%2Caddress%20indexed%20to%2Cuint256%20value)"; + let uri = format!( + "/query?sql=WITH%20numbers%20AS%20%28SELECT%201%20AS%20n%29%20SELECT%20n%20FROM%20numbers&chainId=1&signature={sig}" + ); + + let response = app + .call( + Request::builder() + .method("GET") + .uri(&uri) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["ok"], true); + assert_eq!(json["columns"], serde_json::json!(["n"])); + assert_eq!(json["row_count"], 1); +} + #[tokio::test] #[serial(db)] async fn test_query_rejects_non_select() { diff --git a/tests/smoke_test.rs b/tests/smoke_test.rs index 7ccd91f..ce1ebd6 100644 --- a/tests/smoke_test.rs +++ b/tests/smoke_test.rs @@ -931,6 +931,66 @@ async fn test_query_logs_with_event_signature() { assert!(result.columns.contains(&"value".to_string())); } +#[tokio::test] +#[serial(db)] +async fn test_query_with_user_cte_succeeds_without_signature() { + let db = TestDb::empty().await; + let opts = default_options(); + + let result = execute_query_postgres( + &db.pool, + "WITH numbers AS (SELECT 1 AS n) SELECT n FROM numbers", + &[], + &opts, + ) + .await + .expect("Query with user CTE failed"); + + assert_eq!(result.engine.as_deref(), Some("postgres")); + assert_eq!(result.row_count, 1); + assert_eq!(result.columns, vec!["n".to_string()]); +} + +#[tokio::test] +#[serial(db)] +async fn test_query_with_signature_and_user_cte_succeeds() { + let db = TestDb::empty().await; + let opts = default_options(); + + let result = execute_query_postgres( + &db.pool, + "WITH numbers AS (SELECT 1 AS n) SELECT n FROM numbers", + &["Transfer(address indexed from, address indexed to, uint256 value)"], + &opts, + ) + .await + .expect("Query with signature + user CTE should succeed"); + + assert_eq!(result.engine.as_deref(), Some("postgres")); + assert_eq!(result.row_count, 1); + assert_eq!(result.columns, vec!["n".to_string()]); +} + +#[tokio::test] +#[serial(db)] +async fn test_query_with_signature_and_multiple_user_ctes_succeeds() { + let db = TestDb::empty().await; + let opts = default_options(); + + let result = execute_query_postgres( + &db.pool, + "WITH first_cte AS (SELECT 1 AS n), second_cte AS (SELECT n + 1 AS n FROM first_cte) SELECT n FROM second_cte", + &["Transfer(address indexed from, address indexed to, uint256 value)"], + &opts, + ) + .await + .expect("Query with signature + multiple user CTEs should succeed"); + + assert_eq!(result.engine.as_deref(), Some("postgres")); + assert_eq!(result.row_count, 1); + assert_eq!(result.columns, vec!["n".to_string()]); +} + #[tokio::test] #[serial(db)] async fn test_query_receipts() { @@ -1284,5 +1344,3 @@ async fn test_query_daily_stats_pattern() { assert!(result.columns.contains(&"day".to_string())); assert!(result.columns.contains(&"transfer_count".to_string())); } - -