diff --git a/db/api_role.sql b/db/api_role.sql new file mode 100644 index 0000000..c709bcb --- /dev/null +++ b/db/api_role.sql @@ -0,0 +1,35 @@ +-- Read-only API role for query connections. +-- Defense-in-depth: even if the query validator is bypassed, +-- this role cannot modify data, execute arbitrary functions, +-- or exhaust server resources. +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_roles WHERE rolname = 'tidx_api') THEN + CREATE ROLE tidx_api WITH LOGIN PASSWORD 'tidx_api' NOSUPERUSER NOCREATEDB NOCREATEROLE; + END IF; +END $$; + +-- Revoke all privileges first (deny-by-default) +REVOKE ALL ON ALL TABLES IN SCHEMA public FROM tidx_api; +REVOKE ALL ON ALL SEQUENCES IN SCHEMA public FROM tidx_api; +REVOKE EXECUTE ON ALL FUNCTIONS IN SCHEMA public FROM tidx_api; + +-- Grant read-only access to indexed tables only +GRANT USAGE ON SCHEMA public TO tidx_api; +GRANT SELECT ON blocks, txs, logs, receipts, token_holders, token_balances TO tidx_api; + +-- Grant execute only on ABI decode helper functions +GRANT EXECUTE ON FUNCTION abi_uint(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_int(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_address(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_bool(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_bytes(bytea, int) TO tidx_api; +GRANT EXECUTE ON FUNCTION abi_string(bytea, int) TO tidx_api; +GRANT EXECUTE ON FUNCTION format_address(bytea) TO tidx_api; +GRANT EXECUTE ON FUNCTION format_uint(bytea) TO tidx_api; + +-- Resource limits (prevent DoS) +ALTER ROLE tidx_api CONNECTION LIMIT 64; +ALTER ROLE tidx_api SET statement_timeout = '30s'; +ALTER ROLE tidx_api SET work_mem = '64MB'; +ALTER ROLE tidx_api SET temp_file_limit = '256MB'; diff --git a/src/api/mod.rs b/src/api/mod.rs index 07a05b6..096049e 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -257,7 +257,7 @@ fn default_timeout() -> u64 { 5000 } fn default_limit() -> i64 { - 10000 + crate::query::HARD_LIMIT_MAX } #[derive(Serialize)] @@ -299,7 
+299,7 @@ async fn handle_query_once( let options = QueryOptions { timeout_ms: params.timeout_ms.clamp(100, 30000), - limit: params.limit.clamp(1, 100000), + limit: params.limit.clamp(1, crate::query::HARD_LIMIT_MAX), }; // Route to appropriate engine @@ -397,7 +397,7 @@ async fn handle_query_live( let signature = params.signature; let options = QueryOptions { timeout_ms: params.timeout_ms.clamp(100, 30000), - limit: params.limit.clamp(1, 100000), + limit: params.limit.clamp(1, crate::query::HARD_LIMIT_MAX), }; // Detect if this is an OLAP query (aggregations, etc.) @@ -477,7 +477,16 @@ async fn handle_query_live( } else { let catch_up_start = last_block_num + 1; for block_num in catch_up_start..=end { - let filtered_sql = inject_block_filter(&sql, block_num); + let filtered_sql = match inject_block_filter(&sql, block_num) { + Ok(s) => s, + Err(e) => { + yield Ok(SseEvent::default() + .event("error") + .json_data(serde_json::json!({ "ok": false, "error": e.to_string() })) + .unwrap()); + return; + } + }; match crate::service::execute_query_postgres(&pool, &filtered_sql, signature.as_deref(), &options).await { Ok(result) => { yield Ok(SseEvent::default() @@ -516,50 +525,85 @@ async fn handle_query_live( /// Inject a block number filter into SQL query for live streaming. /// Transforms queries to only return data for the specific block. /// Uses 'num' for blocks table, 'block_num' for txs/logs tables. +/// +/// Uses sqlparser AST manipulation to safely add the filter condition, +/// avoiding SQL injection risks from string-based splicing. 
#[doc(hidden)] -pub fn inject_block_filter(sql: &str, block_num: u64) -> String { - let sql_upper = sql.to_uppercase(); - - // Determine column name based on table being queried - let col = if sql_upper.contains("FROM BLOCKS") || sql_upper.contains("FROM \"BLOCKS\"") { - "num" - } else { - "block_num" +pub fn inject_block_filter(sql: &str, block_num: u64) -> Result<String, ApiError> { + use sqlparser::ast::{ + BinaryOperator, Expr, Ident, SetExpr, Statement, Value, }; - - // Find WHERE clause position - if let Some(where_pos) = sql_upper.find("WHERE") { - // Insert after WHERE - let insert_pos = where_pos + 5; - format!( - "{} {} = {} AND {}", - &sql[..insert_pos], - col, - block_num, - &sql[insert_pos..] - ) - } else if let Some(order_pos) = sql_upper.find("ORDER BY") { - // Insert WHERE before ORDER BY - format!( - "{} WHERE {} = {} {}", - &sql[..order_pos], - col, - block_num, - &sql[order_pos..] - ) - } else if let Some(limit_pos) = sql_upper.find("LIMIT") { - // Insert WHERE before LIMIT - format!( - "{} WHERE {} = {} {}", - &sql[..limit_pos], - col, - block_num, - &sql[limit_pos..] 
- ) - } else { - // Append WHERE at end - format!("{sql} WHERE {col} = {block_num}") + use sqlparser::dialect::GenericDialect; + use sqlparser::parser::Parser; + + let dialect = GenericDialect {}; + let mut statements = Parser::parse_sql(&dialect, sql) + .map_err(|e| ApiError::BadRequest(format!("SQL parse error: {e}")))?; + + if statements.len() != 1 { + return Err(ApiError::BadRequest( + "Live mode requires exactly one SQL statement".to_string(), + )); } + + let stmt = &mut statements[0]; + let query = match stmt { + Statement::Query(q) => q, + _ => { + return Err(ApiError::BadRequest( + "Live mode requires a SELECT query".to_string(), + )) + } + }; + + let select = match query.body.as_mut() { + SetExpr::Select(s) => s, + _ => { + return Err(ApiError::BadRequest( + "Live mode requires a simple SELECT query (UNION/INTERSECT not supported)" + .to_string(), + )) + } + }; + + let table_name: String = select + .from + .first() + .and_then(|twj| match &twj.relation { + sqlparser::ast::TableFactor::Table { name, .. } => { + name.0.last().and_then(|part| part.as_ident()).map(|ident| ident.value.to_lowercase()) + } + _ => None, + }) + .ok_or_else(|| { + ApiError::BadRequest( + "Live mode requires a query with a FROM table clause".to_string(), + ) + })?; + + let col_name = if table_name == "blocks" { "num" } else { "block_num" }; + + let col_expr = Expr::CompoundIdentifier(vec![ + Ident::new(&table_name), + Ident::new(col_name), + ]); + + let block_filter = Expr::BinaryOp { + left: Box::new(col_expr), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number(block_num.to_string(), false).into())), + }; + + select.selection = Some(match select.selection.take() { + Some(existing) => Expr::BinaryOp { + left: Box::new(Expr::Nested(Box::new(existing))), + op: BinaryOperator::And, + right: Box::new(block_filter), + }, + None => block_filter, + }); + + Ok(stmt.to_string()) } /// Rewrite analytics table references to include chain-specific database prefix. 
@@ -599,6 +643,19 @@ pub enum ApiError { NotFound(String), } +impl std::fmt::Display for ApiError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ApiError::BadRequest(msg) => write!(f, "{msg}"), + ApiError::Timeout => write!(f, "Query timeout"), + ApiError::QueryError(msg) => write!(f, "{msg}"), + ApiError::Internal(msg) => write!(f, "{msg}"), + ApiError::Forbidden(msg) => write!(f, "{msg}"), + ApiError::NotFound(msg) => write!(f, "{msg}"), + } + } +} + impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { let (status, message) = match self { diff --git a/src/db/schema.rs b/src/db/schema.rs index 1c51511..9cc11af 100644 --- a/src/db/schema.rs +++ b/src/db/schema.rs @@ -39,6 +39,9 @@ pub async fn run_migrations(pool: &Pool) -> Result<()> { // Load any optional extensions conn.batch_execute(include_str!("../../db/extensions.sql")).await?; + // Create read-only API role with SELECT-only access to indexed tables + conn.batch_execute(include_str!("../../db/api_role.sql")).await?; + Ok(()) } diff --git a/src/query/mod.rs b/src/query/mod.rs index e0d16a2..dee2f24 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -7,7 +7,7 @@ pub use parser::{ extract_order_by_columns, AbiParam, AbiType, EventSignature, }; pub use router::{route_query, QueryEngine}; -pub use validator::validate_query; +pub use validator::{validate_query, HARD_LIMIT_MAX}; use regex_lite::Regex; use std::sync::LazyLock; diff --git a/src/query/validator.rs b/src/query/validator.rs index 87a2a88..a856389 100644 --- a/src/query/validator.rs +++ b/src/query/validator.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use anyhow::{anyhow, Result}; use sqlparser::ast::{ Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, ObjectName, Query, SetExpr, @@ -6,18 +8,35 @@ use sqlparser::ast::{ use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; +const ALLOWED_TABLES: &[&str] = &[ + "blocks", + "txs", + 
"logs", + "receipts", + "token_holders", + "token_balances", +]; + +const MAX_QUERY_LENGTH: usize = 65_536; +const MAX_SUBQUERY_DEPTH: usize = 4; +pub const HARD_LIMIT_MAX: i64 = 10_000; + /// Validates that a SQL query is safe to execute. /// -/// Rejects: -/// - Multiple statements -/// - Non-SELECT statements (INSERT, UPDATE, DELETE, etc.) -/// - Data-modifying CTEs -/// - Dangerous functions (pg_sleep, read_csv, pg_read_file, etc.) -/// - System catalog access +/// Uses a reject-by-default approach: only explicitly allowed tables, +/// functions, and expression types are permitted. Everything else is rejected. pub fn validate_query(sql: &str) -> Result<()> { + if sql.len() > MAX_QUERY_LENGTH { + return Err(anyhow!( + "Query too large ({} bytes, max {})", + sql.len(), + MAX_QUERY_LENGTH + )); + } + let dialect = GenericDialect {}; - let statements = Parser::parse_sql(&dialect, sql) - .map_err(|e| anyhow!("SQL parse error: {e}"))?; + let statements = + Parser::parse_sql(&dialect, sql).map_err(|e| anyhow!("SQL parse error: {e}"))?; if statements.is_empty() { return Err(anyhow!("Empty query")); @@ -30,107 +49,260 @@ pub fn validate_query(sql: &str) -> Result<()> { let stmt = &statements[0]; match stmt { - Statement::Query(query) => validate_query_ast(query), + Statement::Query(query) => { + let cte_names = extract_cte_names(query); + validate_query_ast(query, &cte_names, 0) + } _ => Err(anyhow!("Only SELECT queries are allowed")), } } -fn validate_query_ast(query: &Query) -> Result<()> { - // Check CTEs for data-modifying statements - for cte in &query.with.as_ref().map_or(vec![], |w| w.cte_tables.clone()) { - validate_query_ast(&cte.query)?; +fn extract_cte_names(query: &Query) -> HashSet<String> { + let mut names = HashSet::new(); + if let Some(with) = &query.with { + for cte in &with.cte_tables { + names.insert(cte.alias.name.value.to_lowercase()); + } + } + names +} + +fn validate_query_ast(query: &Query, cte_names: &HashSet<String>, depth: usize) -> Result<()> { + if depth > 
MAX_SUBQUERY_DEPTH { + return Err(anyhow!( + "Subquery nesting too deep (max {} levels)", + MAX_SUBQUERY_DEPTH + )); + } + + // Block recursive CTEs (can cause endless loops / resource exhaustion) + if let Some(with) = &query.with { + if with.recursive { + return Err(anyhow!("Recursive CTEs are not allowed")); + } + } + + // Block FOR UPDATE / FOR SHARE locking clauses + if !query.locks.is_empty() { + return Err(anyhow!( + "Locking clauses (FOR UPDATE/SHARE) are not allowed" + )); + } + + // Block FETCH clause (alternative to LIMIT, could bypass cap) + if query.fetch.is_some() { + return Err(anyhow!("FETCH clause is not allowed, use LIMIT instead")); } - validate_set_expr(&query.body) + let mut all_cte_names = cte_names.clone(); + if let Some(with) = &query.with { + for cte in &with.cte_tables { + all_cte_names.insert(cte.alias.name.value.to_lowercase()); + } + } + + for cte in &query + .with + .as_ref() + .map_or(vec![], |w| w.cte_tables.clone()) + { + validate_query_ast(&cte.query, &all_cte_names, depth + 1)?; + } + + validate_set_expr(&query.body, &all_cte_names, depth)?; + + // Validate ORDER BY expressions + if let Some(order_by) = &query.order_by { + match &order_by.kind { + sqlparser::ast::OrderByKind::Expressions(exprs) => { + for order_expr in exprs { + validate_expr(&order_expr.expr, &all_cte_names, depth)?; + } + } + sqlparser::ast::OrderByKind::All(_) => {} + } + } + + // Validate LIMIT / OFFSET: only allow numeric literals + if let Some(limit_clause) = &query.limit_clause { + match limit_clause { + sqlparser::ast::LimitClause::LimitOffset { + limit, + offset, + limit_by, + } => { + if let Some(limit_expr) = limit { + validate_limit_expr(limit_expr, "LIMIT")?; + } + if let Some(offset) = offset { + validate_limit_expr(&offset.value, "OFFSET")?; + } + if !limit_by.is_empty() { + return Err(anyhow!("LIMIT BY is not allowed")); + } + } + sqlparser::ast::LimitClause::OffsetCommaLimit { offset, limit } => { + validate_limit_expr(offset, "OFFSET")?; + 
validate_limit_expr(limit, "LIMIT")?; + } + } + } + + Ok(()) } -fn validate_set_expr(set_expr: &SetExpr) -> Result<()> { +fn validate_limit_expr(expr: &Expr, context: &str) -> Result<()> { + match expr { + Expr::Value(v) => { + let val = &v.value; + match val { + sqlparser::ast::Value::Number(n, _) => { + if let Ok(num) = n.parse::<i64>() { + if num < 0 { + return Err(anyhow!("{context} must not be negative")); + } + if num > HARD_LIMIT_MAX { + return Err(anyhow!( + "{context} value {num} exceeds maximum ({HARD_LIMIT_MAX})" + )); + } + Ok(()) + } else { + Err(anyhow!("{context} must be a valid integer")) + } + } + sqlparser::ast::Value::Null => { + Err(anyhow!("{context} NULL is not allowed")) + } + _ => Err(anyhow!("{context} must be a numeric literal")), + } + } + _ => Err(anyhow!("{context} must be a numeric literal")), + } +} + +fn validate_set_expr( + set_expr: &SetExpr, + cte_names: &HashSet<String>, + depth: usize, +) -> Result<()> { match set_expr { SetExpr::Select(select) => { - // Validate FROM clause + // Reject SELECT INTO (creates objects) + if select.into.is_some() { + return Err(anyhow!("SELECT INTO is not allowed")); + } + for table in &select.from { - validate_table_with_joins(table)?; + validate_table_with_joins(table, cte_names, depth)?; } - // Validate SELECT expressions for item in &select.projection { if let sqlparser::ast::SelectItem::UnnamedExpr(expr) | sqlparser::ast::SelectItem::ExprWithAlias { expr, .. 
} = item { - validate_expr(expr)?; + validate_expr(expr, cte_names, depth)?; } } - // Validate WHERE clause if let Some(selection) = &select.selection { - validate_expr(selection)?; + validate_expr(selection, cte_names, depth)?; + } + + // Validate GROUP BY expressions + if let sqlparser::ast::GroupByExpr::Expressions(exprs, _) = &select.group_by { + for expr in exprs { + validate_expr(expr, cte_names, depth)?; + } + } + + // Validate HAVING + if let Some(having) = &select.having { + validate_expr(having, cte_names, depth)?; } Ok(()) } - SetExpr::Query(q) => validate_query_ast(q), + SetExpr::Query(q) => validate_query_ast(q, cte_names, depth), SetExpr::SetOperation { left, right, .. } => { - validate_set_expr(left)?; - validate_set_expr(right) + validate_set_expr(left, cte_names, depth)?; + validate_set_expr(right, cte_names, depth) + } + SetExpr::Values(values) => { + for row in &values.rows { + for expr in row { + validate_expr(expr, cte_names, depth)?; + } + } + Ok(()) } - SetExpr::Values(_) => Ok(()), SetExpr::Insert(_) => Err(anyhow!("INSERT not allowed")), SetExpr::Update(_) => Err(anyhow!("UPDATE not allowed")), SetExpr::Delete(_) => Err(anyhow!("DELETE not allowed")), SetExpr::Merge(_) => Err(anyhow!("MERGE not allowed")), - SetExpr::Table(_) => Ok(()), + SetExpr::Table(_) => Err(anyhow!("TABLE statement is not allowed")), } } -fn validate_table_with_joins(table: &TableWithJoins) -> Result<()> { - validate_table_factor(&table.relation)?; +fn validate_table_with_joins( + table: &TableWithJoins, + cte_names: &HashSet<String>, + depth: usize, +) -> Result<()> { + validate_table_factor(&table.relation, cte_names, depth)?; for join in &table.joins { - validate_table_factor(&join.relation)?; + validate_table_factor(&join.relation, cte_names, depth)?; + let constraint = match &join.join_operator { + sqlparser::ast::JoinOperator::Join(c) + | sqlparser::ast::JoinOperator::Inner(c) + | sqlparser::ast::JoinOperator::Left(c) + | sqlparser::ast::JoinOperator::LeftOuter(c) + | 
sqlparser::ast::JoinOperator::Right(c) + | sqlparser::ast::JoinOperator::RightOuter(c) + | sqlparser::ast::JoinOperator::FullOuter(c) + | sqlparser::ast::JoinOperator::CrossJoin(c) + | sqlparser::ast::JoinOperator::Semi(c) + | sqlparser::ast::JoinOperator::LeftSemi(c) + | sqlparser::ast::JoinOperator::RightSemi(c) + | sqlparser::ast::JoinOperator::Anti(c) + | sqlparser::ast::JoinOperator::LeftAnti(c) + | sqlparser::ast::JoinOperator::RightAnti(c) => Some(c), + _ => None, + }; + if let Some(sqlparser::ast::JoinConstraint::On(expr)) = constraint { + validate_expr(expr, cte_names, depth)?; + } } Ok(()) } -fn validate_table_factor(factor: &TableFactor) -> Result<()> { +fn validate_table_factor( + factor: &TableFactor, + cte_names: &HashSet<String>, + depth: usize, +) -> Result<()> { match factor { TableFactor::Table { name, args, .. } => { - // Check if this is a table-valued function like read_csv(...) if args.is_some() { - let func_name = name.to_string().to_lowercase(); - if is_dangerous_table_function(&func_name) { - return Err(anyhow!("Table function '{func_name}' is not allowed")); - } - } - validate_table_name(name) - } - TableFactor::Derived { subquery, .. } => validate_query_ast(subquery), - TableFactor::TableFunction { expr, .. } => { - // Block table functions that can read filesystem - if let Expr::Function(func) = expr { - let func_name = func.name.to_string().to_lowercase(); - if is_dangerous_table_function(&func_name) { - return Err(anyhow!("Table function '{func_name}' is not allowed")); - } - } - Ok(()) - } - TableFactor::Function { name, .. } => { - let func_name = name.to_string().to_lowercase(); - if is_dangerous_table_function(&func_name) { - return Err(anyhow!("Table function '{func_name}' is not allowed")); + return Err(anyhow!("Table functions are not allowed")); } - Ok(()) + validate_table_name(name, cte_names) } - TableFactor::NestedJoin { table_with_joins, .. } => { - validate_table_with_joins(table_with_joins) + TableFactor::Derived { subquery, .. 
} => { + validate_query_ast(subquery, cte_names, depth + 1) } - _ => Ok(()), + TableFactor::TableFunction { .. } => Err(anyhow!("Table functions are not allowed")), + TableFactor::Function { .. } => Err(anyhow!("Table functions are not allowed")), + TableFactor::NestedJoin { + table_with_joins, .. + } => validate_table_with_joins(table_with_joins, cte_names, depth), + _ => Err(anyhow!("Unsupported FROM clause type")), } } -fn validate_table_name(name: &ObjectName) -> Result<()> { +fn validate_table_name(name: &ObjectName, cte_names: &HashSet<String>) -> Result<()> { let full_name = name.to_string().to_lowercase(); - // Block system catalogs const BLOCKED_SCHEMAS: &[&str] = &[ "pg_catalog", "information_schema", @@ -140,11 +312,12 @@ fn validate_table_name(name: &ObjectName) -> Result<()> { for schema in BLOCKED_SCHEMAS { if full_name.starts_with(schema) { - return Err(anyhow!("Access to system catalog '{schema}' is not allowed")); + return Err(anyhow!( + "Access to system catalog '{schema}' is not allowed" + )); } } - // Block specific dangerous tables const BLOCKED_TABLES: &[&str] = &[ "pg_stat_activity", "pg_settings", @@ -160,137 +333,309 @@ fn validate_table_name(name: &ObjectName) -> Result<()> { } } - Ok(()) + let bare_name = name + .0 + .last() + .and_then(|part| part.as_ident()) + .map(|ident| ident.value.to_lowercase()) + .unwrap_or_default(); + + if ALLOWED_TABLES.contains(&bare_name.as_str()) { + return Ok(()); + } + + if cte_names.contains(&bare_name) { + return Ok(()); + } + + Err(anyhow!("Access to table '{bare_name}' is not allowed")) } -fn validate_expr(expr: &Expr) -> Result<()> { +/// Reject-by-default expression validation. +/// Only explicitly allowed expression types are permitted. +fn validate_expr(expr: &Expr, cte_names: &HashSet<String>, depth: usize) -> Result<()> { match expr { - Expr::Function(func) => validate_function(func), - Expr::Subquery(q) => validate_query_ast(q), - Expr::InSubquery { subquery, .. 
} => validate_query_ast(subquery), - Expr::Exists { subquery, .. } => validate_query_ast(subquery), + // Safe leaf nodes + Expr::Identifier(_) | Expr::CompoundIdentifier(_) => Ok(()), + Expr::Value(_) => Ok(()), + Expr::TypedString(_) => Ok(()), + Expr::Wildcard(_) | Expr::QualifiedWildcard(_, _) => Ok(()), + + // Function calls (validated against allowlist) + Expr::Function(func) => validate_function(func, cte_names, depth), + + // Subqueries (increment depth) + Expr::Subquery(q) => validate_query_ast(q, cte_names, depth + 1), + Expr::InSubquery { + expr, subquery, .. + } => { + validate_expr(expr, cte_names, depth)?; + validate_query_ast(subquery, cte_names, depth + 1) + } + Expr::Exists { subquery, .. } => validate_query_ast(subquery, cte_names, depth + 1), + + // Binary / unary operations Expr::BinaryOp { left, right, .. } => { - validate_expr(left)?; - validate_expr(right) + validate_expr(left, cte_names, depth)?; + validate_expr(right, cte_names, depth) } - Expr::UnaryOp { expr, .. } => validate_expr(expr), - Expr::Between { expr, low, high, .. } => { - validate_expr(expr)?; - validate_expr(low)?; - validate_expr(high) + Expr::UnaryOp { expr, .. } => validate_expr(expr, cte_names, depth), + + // Range expressions + Expr::Between { + expr, low, high, .. + } => { + validate_expr(expr, cte_names, depth)?; + validate_expr(low, cte_names, depth)?; + validate_expr(high, cte_names, depth) } - Expr::Case { operand, conditions, else_result, .. } => { + + // CASE WHEN + Expr::Case { + operand, + conditions, + else_result, + .. 
+ } => { if let Some(op) = operand { - validate_expr(op)?; + validate_expr(op, cte_names, depth)?; } for case_when in conditions { - validate_expr(&case_when.condition)?; - validate_expr(&case_when.result)?; + validate_expr(&case_when.condition, cte_names, depth)?; + validate_expr(&case_when.result, cte_names, depth)?; } if let Some(else_r) = else_result { - validate_expr(else_r)?; + validate_expr(else_r, cte_names, depth)?; } Ok(()) } - Expr::Cast { expr, .. } => validate_expr(expr), - Expr::Nested(e) => validate_expr(e), + + // Type casting + Expr::Cast { expr, .. } => validate_expr(expr, cte_names, depth), + Expr::Nested(e) => validate_expr(e, cte_names, depth), + + // IN list Expr::InList { expr, list, .. } => { - validate_expr(expr)?; + validate_expr(expr, cte_names, depth)?; for item in list { - validate_expr(item)?; + validate_expr(item, cte_names, depth)?; + } + Ok(()) + } + + // Boolean tests + Expr::IsNull(e) + | Expr::IsNotNull(e) + | Expr::IsTrue(e) + | Expr::IsFalse(e) + | Expr::IsNotTrue(e) + | Expr::IsNotFalse(e) + | Expr::IsUnknown(e) + | Expr::IsNotUnknown(e) => validate_expr(e, cte_names, depth), + + // Pattern matching + Expr::Like { expr, pattern, .. } | Expr::ILike { expr, pattern, .. } => { + validate_expr(expr, cte_names, depth)?; + validate_expr(pattern, cte_names, depth) + } + Expr::SimilarTo { expr, pattern, .. } => { + validate_expr(expr, cte_names, depth)?; + validate_expr(pattern, cte_names, depth) + } + + // ANY/ALL operators + Expr::AnyOp { right, .. } | Expr::AllOp { right, .. } => { + validate_expr(right, cte_names, depth) + } + + // IS DISTINCT FROM + Expr::IsDistinctFrom(a, b) | Expr::IsNotDistinctFrom(a, b) => { + validate_expr(a, cte_names, depth)?; + validate_expr(b, cte_names, depth) + } + + // SQL builtins parsed as dedicated Expr variants (not Function) + Expr::Extract { expr, .. } => validate_expr(expr, cte_names, depth), + Expr::Substring { expr, substring_from, substring_for, .. 
} => { + validate_expr(expr, cte_names, depth)?; + if let Some(from) = substring_from { + validate_expr(from, cte_names, depth)?; + } + if let Some(for_expr) = substring_for { + validate_expr(for_expr, cte_names, depth)?; + } + Ok(()) + } + Expr::Trim { expr, trim_what, .. } => { + validate_expr(expr, cte_names, depth)?; + if let Some(what) = trim_what { + validate_expr(what, cte_names, depth)?; + } + Ok(()) + } + Expr::Ceil { expr, .. } | Expr::Floor { expr, .. } => { + validate_expr(expr, cte_names, depth) + } + Expr::Position { expr, r#in, .. } => { + validate_expr(expr, cte_names, depth)?; + validate_expr(r#in, cte_names, depth) + } + Expr::Overlay { expr, overlay_what, overlay_from, overlay_for, .. } => { + validate_expr(expr, cte_names, depth)?; + validate_expr(overlay_what, cte_names, depth)?; + validate_expr(overlay_from, cte_names, depth)?; + if let Some(for_expr) = overlay_for { + validate_expr(for_expr, cte_names, depth)?; + } + Ok(()) + } + Expr::Collate { expr, .. } => validate_expr(expr, cte_names, depth), + Expr::AtTimeZone { timestamp, time_zone, .. 
} => { + validate_expr(timestamp, cte_names, depth)?; + validate_expr(time_zone, cte_names, depth) + } + + // Tuple / row constructors + Expr::Tuple(exprs) => { + for e in exprs { + validate_expr(e, cte_names, depth)?; } Ok(()) } - _ => Ok(()), + + // Array literal + Expr::Array(arr) => { + for e in &arr.elem { + validate_expr(e, cte_names, depth)?; + } + Ok(()) + } + + // Interval literal + Expr::Interval(_) => Ok(()), + + // Reject everything else (reject-by-default) + _ => Err(anyhow!("Unsupported expression type")), } } -fn validate_function(func: &Function) -> Result<()> { +const ALLOWED_FUNCTIONS: &[&str] = &[ + // ABI decode helpers (custom PostgreSQL functions) + "abi_uint", + "abi_int", + "abi_address", + "abi_bool", + "abi_bytes", + "abi_string", + "format_address", + "format_uint", + // Aggregates + "count", + "sum", + "avg", + "min", + "max", + // Scalar / null handling + "coalesce", + "nullif", + "greatest", + "least", + // Numeric + "abs", + "round", + "floor", + "ceil", + "ceiling", + "trunc", + "pow", + "power", + // String + "lower", + "upper", + "length", + "substring", + "substr", + "trim", + "ltrim", + "rtrim", + "replace", + "concat", + "left", + "right", + "lpad", + "rpad", + // Bytea / hex + "encode", + "decode", + "octet_length", + // Time + "date_trunc", + "extract", + "to_timestamp", + "now", + // Window functions + "row_number", + "rank", + "dense_rank", + "lag", + "lead", + "first_value", + "last_value", + "ntile", + "percent_rank", + "cume_dist", + // Type casting helpers + "cast", +]; + +fn is_allowed_function(name: &str) -> bool { + let bare_name = name.rsplit('.').next().unwrap_or(name); + ALLOWED_FUNCTIONS.contains(&bare_name) +} + +fn validate_function(func: &Function, cte_names: &HashSet<String>, depth: usize) -> Result<()> { let func_name = func.name.to_string().to_lowercase(); - if is_dangerous_function(&func_name) { - return Err(anyhow!("Function '{func_name}' is not allowed")); + if !is_allowed_function(&func_name) { + return 
Err(anyhow!("Function '{}' is not allowed", func_name)); } - // Recursively validate function arguments if let FunctionArguments::List(arg_list) = &func.args { for arg in &arg_list.args { if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) - | FunctionArg::Named { arg: FunctionArgExpr::Expr(expr), .. } = arg + | FunctionArg::Named { + arg: FunctionArgExpr::Expr(expr), + .. + } = arg { - validate_expr(expr)?; + validate_expr(expr, cte_names, depth)?; } } } - Ok(()) -} - -/// Check if a function is dangerous (DoS, file access, side effects). -fn is_dangerous_function(name: &str) -> bool { - const DANGEROUS: &[&str] = &[ - // PostgreSQL DoS/side-effect functions - "pg_sleep", - "pg_terminate_backend", - "pg_cancel_backend", - "pg_reload_conf", - "pg_rotate_logfile", - "pg_switch_wal", - "pg_create_restore_point", - "pg_start_backup", - "pg_stop_backup", - "set_config", - "current_setting", - // PostgreSQL file access - "pg_read_file", - "pg_read_binary_file", - "pg_ls_dir", - "pg_stat_file", - "lo_import", - "lo_export", - // PostgreSQL command execution - "pg_execute_server_program", - // ClickHouse system functions - "system.flush_logs", - "system.reload_config", - "system.shutdown", - "system.kill_query", - "system.drop_dns_cache", - "system.drop_mark_cache", - "system.drop_uncompressed_cache", - ]; + // Validate FILTER (WHERE ...) clause + if let Some(filter) = &func.filter { + validate_expr(filter, cte_names, depth)?; + } - DANGEROUS.iter().any(|&d| name == d || name.ends_with(&format!(".{d}"))) -} + // Validate WITHIN GROUP (ORDER BY ...) clause + for order_expr in &func.within_group { + validate_expr(&order_expr.expr, cte_names, depth)?; + } -/// Check if a table function is dangerous (filesystem access). 
-fn is_dangerous_table_function(name: &str) -> bool { - const DANGEROUS: &[&str] = &[ - // ClickHouse file/URL table functions - "file", - "url", - "s3", - "gcs", - "hdfs", - "remote", - "remoteSecure", - "cluster", - "clusterAllReplicas", - // ClickHouse input formats - "input", - "format", - // ClickHouse system access - "system", - "numbers", - "zeros", - "generateRandom", - // ClickHouse dictionary access (could leak data) - "dictGet", - "dictGetOrDefault", - "dictHas", - ]; + // Validate window function OVER clause + if let Some(window_type) = &func.over { + if let sqlparser::ast::WindowType::WindowSpec(spec) = window_type { + for expr in &spec.partition_by { + validate_expr(expr, cte_names, depth)?; + } + for order_expr in &spec.order_by { + validate_expr(&order_expr.expr, cte_names, depth)?; + } + } + } - DANGEROUS.iter().any(|&d| name == d || name.contains(&format!("{d}("))) + Ok(()) } #[cfg(test)] @@ -330,7 +675,6 @@ mod tests { #[test] fn test_rejects_data_modifying_cte() { - // This is the V1 bypass attempt let result = validate_query( "WITH x AS (UPDATE blocks SET miner = 'pwn' RETURNING 1) SELECT * FROM x", ); @@ -339,11 +683,9 @@ mod tests { #[test] fn test_rejects_comment_bypass() { - // Comments are stripped by parser, so this becomes a valid UPDATE let result = validate_query( "WITH x AS (UPDA/**/TE blocks SET miner = 'pwn' RETURNING 1) SELECT * FROM x", ); - // Parser will either fail to parse or recognize it as UPDATE assert!(result.is_err()); } @@ -385,22 +727,278 @@ mod tests { #[test] fn test_allows_window_functions() { + assert!( + validate_query("SELECT num, ROW_NUMBER() OVER (ORDER BY num) FROM blocks").is_ok() + ); + } + + #[test] + fn test_allows_subquery() { + assert!( + validate_query("SELECT * FROM blocks WHERE num IN (SELECT block_num FROM txs)").is_ok() + ); + } + + #[test] + fn test_rejects_nested_dangerous_function() { + assert!(validate_query("SELECT COALESCE(pg_sleep(1), 0)").is_err()); + } + + #[test] + fn 
test_rejects_sync_state() { + assert!(validate_query("SELECT * FROM sync_state").is_err()); + } + + #[test] + fn test_rejects_pg_tables() { + assert!(validate_query("SELECT * FROM pg_tables").is_err()); + } + + #[test] + fn test_rejects_unknown_table() { + assert!(validate_query("SELECT * FROM some_random_table").is_err()); + } + + #[test] + fn test_allows_cte_defined_table() { + assert!( + validate_query("WITH my_cte AS (SELECT * FROM blocks) SELECT * FROM my_cte").is_ok() + ); + } + + #[test] + fn test_rejects_dblink() { assert!(validate_query( - "SELECT num, ROW_NUMBER() OVER (ORDER BY num) FROM blocks" + "SELECT * FROM dblink('host=evil dbname=secrets', 'SELECT * FROM passwords')" ) - .is_ok()); + .is_err()); + assert!(validate_query("SELECT dblink_connect('myconn', 'host=evil')").is_err()); + assert!(validate_query("SELECT dblink_exec('myconn', 'DROP TABLE blocks')").is_err()); } #[test] - fn test_allows_subquery() { + fn test_allows_analytics_tables() { + assert!(validate_query("SELECT * FROM token_holders").is_ok()); + assert!(validate_query("SELECT * FROM token_balances").is_ok()); + assert!(validate_query("SELECT * FROM public.blocks").is_ok()); + } + + #[test] + fn test_rejects_recursive_cte() { + assert!(validate_query( + "WITH RECURSIVE r AS (SELECT 1 AS n UNION ALL SELECT n+1 FROM r) SELECT * FROM r" + ) + .is_err()); + } + + #[test] + fn test_rejects_generate_series() { + assert!(validate_query("SELECT generate_series(1, 1000000000)").is_err()); + assert!(validate_query( + "SELECT * FROM blocks WHERE num IN (SELECT generate_series(1, 1000000))" + ) + .is_err()); + } + + #[test] + fn test_rejects_values_function_bypass() { + assert!(validate_query("VALUES (pg_sleep(10))").is_err()); + assert!(validate_query("VALUES (pg_read_file('/etc/passwd'))").is_err()); + } + + #[test] + fn test_rejects_table_statement() { + assert!(validate_query("TABLE blocks").is_err()); + assert!(validate_query("TABLE pg_shadow").is_err()); + } + + #[test] + fn 
test_rejects_select_into() { + assert!(validate_query("SELECT * INTO newtable FROM blocks").is_err()); + } + + #[test] + fn test_rejects_lo_functions() { + assert!(validate_query("SELECT lo_get(12345)").is_err()); + assert!(validate_query("SELECT lo_open(12345, 262144)").is_err()); + } + + #[test] + fn test_rejects_admin_file_functions() { + assert!(validate_query("SELECT pg_file_read('/etc/passwd', 0, 1000)").is_err()); + assert!(validate_query("SELECT pg_file_write('/tmp/evil', 'data', false)").is_err()); + } + + #[test] + fn test_rejects_dangerous_function_in_having() { + assert!(validate_query( + "SELECT COUNT(*) FROM blocks GROUP BY num HAVING pg_sleep(1) IS NOT NULL" + ) + .is_err()); + } + + #[test] + fn test_rejects_dangerous_function_in_join_on() { + assert!( + validate_query("SELECT * FROM blocks JOIN txs ON pg_sleep(1) IS NOT NULL").is_err() + ); + } + + #[test] + fn test_allows_simple_values() { + assert!(validate_query("VALUES (1, 'hello'), (2, 'world')").is_ok()); + } + + #[test] + fn test_rejects_unknown_function() { + assert!(validate_query("SELECT md5('test') FROM blocks").is_err()); + assert!(validate_query("SELECT regexp_replace(hash, 'a', 'b') FROM blocks").is_err()); + } + + #[test] + fn test_allows_abi_helpers() { + assert!(validate_query("SELECT abi_uint(input) FROM txs").is_ok()); + assert!(validate_query("SELECT abi_address(input) FROM txs").is_ok()); + assert!(validate_query("SELECT format_address(miner) FROM blocks").is_ok()); + } + + #[test] + fn test_allows_common_functions() { + assert!(validate_query("SELECT COALESCE(gas_used, 0) FROM blocks").is_ok()); + assert!(validate_query("SELECT ABS(gas_used) FROM blocks").is_ok()); + assert!(validate_query("SELECT LOWER('test') FROM blocks").is_ok()); + assert!( + validate_query("SELECT date_trunc('hour', to_timestamp(ts)) FROM blocks").is_ok() + ); + } + + #[test] + fn test_rejects_all_table_functions() { + assert!(validate_query("SELECT * FROM generate_series(1, 100)").is_err()); + 
assert!(validate_query("SELECT * FROM unnest(ARRAY[1,2,3])").is_err()); + } + + #[test] + fn test_rejects_unsupported_table_factor() { + assert!(validate_query("SELECT * FROM UNNEST(ARRAY[1,2,3])").is_err()); + } + + // === New tests for this commit === + + #[test] + fn test_rejects_for_update() { + assert!(validate_query("SELECT * FROM blocks FOR UPDATE").is_err()); + assert!(validate_query("SELECT * FROM blocks FOR SHARE").is_err()); + } + + #[test] + fn test_rejects_dangerous_function_in_order_by() { + assert!( + validate_query("SELECT * FROM blocks ORDER BY pg_sleep(1)").is_err() + ); + } + + #[test] + fn test_rejects_excessive_limit() { + assert!(validate_query("SELECT * FROM blocks LIMIT 100000000").is_err()); + assert!(validate_query("SELECT * FROM blocks LIMIT 10001").is_err()); + } + + #[test] + fn test_allows_reasonable_limit() { + assert!(validate_query("SELECT * FROM blocks LIMIT 100").is_ok()); + assert!(validate_query("SELECT * FROM blocks LIMIT 10000").is_ok()); + assert!(validate_query("SELECT * FROM blocks LIMIT 1 OFFSET 5").is_ok()); + } + + #[test] + fn test_rejects_subquery_in_limit() { + assert!(validate_query("SELECT * FROM blocks LIMIT (SELECT 1)").is_err()); + } + + #[test] + fn test_rejects_deep_subquery_nesting() { + // 5 levels of derived table nesting exceeds MAX_SUBQUERY_DEPTH (4) + let deep = "SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM blocks) a) b) c) d) e"; + assert!(validate_query(deep).is_err()); + } + + #[test] + fn test_allows_moderate_subquery_nesting() { + // 3 levels of nesting is within limits + let moderate = "SELECT * FROM (SELECT * FROM (SELECT * FROM blocks) a) b"; + assert!(validate_query(moderate).is_ok()); + } + + #[test] + fn test_rejects_query_too_large() { + let huge = format!("SELECT * FROM blocks WHERE num IN ({})", "1,".repeat(70_000)); + assert!(validate_query(&huge).is_err()); + } + + #[test] + fn test_allows_order_by_column() { + assert!(validate_query("SELECT * 
FROM blocks ORDER BY num DESC").is_ok()); + assert!( + validate_query("SELECT * FROM blocks ORDER BY num DESC, hash ASC").is_ok() + ); + } + + #[test] + fn test_allows_cast_expression() { + assert!(validate_query("SELECT CAST(num AS TEXT) FROM blocks").is_ok()); + } + + #[test] + fn test_allows_between() { + assert!(validate_query("SELECT * FROM blocks WHERE num BETWEEN 1 AND 100").is_ok()); + } + + #[test] + fn test_allows_like() { + assert!(validate_query("SELECT * FROM txs WHERE hash LIKE '%abc%'").is_ok()); + } + + #[test] + fn test_allows_is_null() { + assert!(validate_query("SELECT * FROM blocks WHERE miner IS NOT NULL").is_ok()); + } + + #[test] + fn test_allows_case_when() { assert!(validate_query( - "SELECT * FROM blocks WHERE num IN (SELECT block_num FROM txs)" + "SELECT CASE WHEN num > 100 THEN 'big' ELSE 'small' END FROM blocks" ) .is_ok()); } #[test] - fn test_rejects_nested_dangerous_function() { - assert!(validate_query("SELECT COALESCE(pg_sleep(1), 0)").is_err()); + fn test_allows_array_literal() { + assert!(validate_query("SELECT * FROM blocks WHERE num = ANY(ARRAY[1,2,3])").is_ok()); + } + + #[test] + fn test_rejects_filter_clause_bypass() { + assert!(validate_query( + "SELECT COUNT(*) FILTER (WHERE pg_sleep(1) IS NOT NULL) FROM blocks" + ) + .is_err()); + } + + #[test] + fn test_rejects_limit_null() { + assert!(validate_query("SELECT * FROM blocks LIMIT NULL").is_err()); + } + + #[test] + fn test_rejects_negative_limit() { + assert!(validate_query("SELECT * FROM blocks LIMIT -1").is_err()); + } + + #[test] + fn test_rejects_fetch_clause() { + assert!( + validate_query("SELECT * FROM blocks FETCH FIRST 10 ROWS ONLY").is_err() + ); } } diff --git a/src/service/mod.rs b/src/service/mod.rs index 27a0d98..384dcf3 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -5,7 +5,7 @@ use std::time::Instant; use crate::db::Pool; use crate::metrics; -use crate::query::{extract_column_references, validate_query, EventSignature}; +use 
crate::query::{extract_column_references, validate_query, EventSignature, HARD_LIMIT_MAX}; #[derive(Debug, Clone, Serialize)] pub struct SyncStatus { @@ -109,7 +109,7 @@ impl Default for QueryOptions { fn default() -> Self { Self { timeout_ms: 5000, - limit: 10000, + limit: HARD_LIMIT_MAX, } } } @@ -145,13 +145,8 @@ pub async fn execute_query_postgres( sql.to_string() }; - // Add LIMIT if not present - let sql_upper = sql.to_uppercase(); - let sql = if !sql_upper.contains("LIMIT") { - format!("{} LIMIT {}", sql, options.limit) - } else { - sql - }; + // Add LIMIT if not present (AST-based detection to avoid string matching bypass) + let sql = append_limit_if_missing(&sql, options.limit); // Convert '0x...' hex literals to '\x...' bytea literals for PostgreSQL // Only replace hex values (40+ chars), not short '0x' prefixes used in concat() @@ -226,6 +221,21 @@ pub async fn execute_query_postgres( }) } +fn append_limit_if_missing(sql: &str, limit: i64) -> String { + use sqlparser::dialect::GenericDialect; + use sqlparser::parser::Parser; + + let dialect = GenericDialect {}; + if let Ok(stmts) = Parser::parse_sql(&dialect, sql) { + if let Some(sqlparser::ast::Statement::Query(query)) = stmts.first() { + if query.limit_clause.is_none() { + return format!("{sql} LIMIT {limit}"); + } + } + } + sql.to_string() +} + pub fn format_column_json(row: &tokio_postgres::Row, idx: usize) -> serde_json::Value { let col = &row.columns()[idx]; diff --git a/tests/api_live_test.rs b/tests/api_live_test.rs index 69e4de0..4608cd4 100644 --- a/tests/api_live_test.rs +++ b/tests/api_live_test.rs @@ -361,37 +361,57 @@ async fn test_query_live_returns_sse() { #[test] fn test_inject_block_filter_blocks_table() { let sql = "SELECT num, hash FROM blocks ORDER BY num DESC LIMIT 1"; - let filtered = inject_block_filter(sql, 100); - assert!(filtered.contains("num = 100"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 100).unwrap(); + assert!(filtered.contains("blocks.num = 100"), 
"got: {filtered}"); assert!(filtered.contains("ORDER BY"), "should preserve ORDER BY"); } #[test] fn test_inject_block_filter_txs_table() { let sql = "SELECT * FROM txs ORDER BY block_num DESC LIMIT 10"; - let filtered = inject_block_filter(sql, 200); - assert!(filtered.contains("block_num = 200"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 200).unwrap(); + assert!(filtered.contains("txs.block_num = 200"), "got: {filtered}"); } #[test] fn test_inject_block_filter_logs_table() { let sql = "SELECT * FROM logs WHERE address = '0x123' ORDER BY block_num DESC"; - let filtered = inject_block_filter(sql, 300); - assert!(filtered.contains("block_num = 300"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 300).unwrap(); + assert!(filtered.contains("logs.block_num = 300"), "got: {filtered}"); assert!(filtered.contains("address = '0x123'"), "should preserve existing WHERE"); } #[test] fn test_inject_block_filter_with_existing_where() { let sql = "SELECT * FROM txs WHERE gas_used > 21000 ORDER BY block_num DESC"; - let filtered = inject_block_filter(sql, 400); - assert!(filtered.contains("block_num = 400"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 400).unwrap(); + assert!(filtered.contains("txs.block_num = 400"), "got: {filtered}"); assert!(filtered.contains("gas_used > 21000"), "should preserve existing condition"); } #[test] fn test_inject_block_filter_no_order_by() { let sql = "SELECT COUNT(*) FROM blocks LIMIT 1"; - let filtered = inject_block_filter(sql, 500); - assert!(filtered.contains("num = 500"), "got: {filtered}"); + let filtered = inject_block_filter(sql, 500).unwrap(); + assert!(filtered.contains("blocks.num = 500"), "got: {filtered}"); +} + +#[test] +fn test_inject_block_filter_rejects_union() { + let sql = "SELECT * FROM txs UNION SELECT * FROM logs"; + assert!(inject_block_filter(sql, 100).is_err()); +} + +#[test] +fn test_inject_block_filter_rejects_non_select() { + let sql = "INSERT INTO txs VALUES (1)"; 
+ assert!(inject_block_filter(sql, 100).is_err()); +} + +#[test] +fn test_inject_block_filter_where_keyword_in_string_literal() { + let sql = "SELECT * FROM txs WHERE input = 'WHERE clause test'"; + let filtered = inject_block_filter(sql, 100).unwrap(); + assert!(filtered.contains("txs.block_num = 100"), "got: {filtered}"); + assert!(filtered.contains("'WHERE clause test'"), "should preserve string literal"); }