diff --git a/crates/experimentation_platform/src/api/experiments/handlers.rs b/crates/experimentation_platform/src/api/experiments/handlers.rs index 2e0b864e1..9c81a7576 100644 --- a/crates/experimentation_platform/src/api/experiments/handlers.rs +++ b/crates/experimentation_platform/src/api/experiments/handlers.rs @@ -870,7 +870,7 @@ pub async fn discard( Ok((updated_experiment, config_version_id)) } -pub async fn get_applicable_variants_helper( +pub fn get_applicable_variants_helper( db_conn: &mut PooledConnection>, context: Map, dimensions_info: &HashMap, @@ -956,8 +956,7 @@ async fn get_applicable_variants_handler( &dimensions_info, identifier, &workspace_context, - ) - .await?; + )?; let variants = exps .into_iter() diff --git a/crates/service_utils/src/db.rs b/crates/service_utils/src/db.rs index f272a2e73..86f2f6f9f 100644 --- a/crates/service_utils/src/db.rs +++ b/crates/service_utils/src/db.rs @@ -2,7 +2,63 @@ use diesel::{ PgConnection, r2d2::{ConnectionManager, Pool}, }; +use superposition_types::{DBConnection, result}; pub mod utils; pub type PgSchemaConnectionPool = Pool>; + +/// Helper macro to run a database query with connection management and error handling. +/// Example usage: +/// ```rust,ignore +/// run_query!(db_pool, conn, { +/// // Your query logic here, using `conn` as the database connection +/// }); +/// ``` +#[macro_export] +macro_rules! run_query { + ($db_pool:expr, $conn:ident, $body:expr) => {{ + let mut $conn = $db_pool.get().map_err(|e| { + superposition_macros::unexpected_error!( + "Unable to get db connection from pool, error: {}", + e + ) + })?; + diesel::Connection::set_prepared_statement_cache_size( + &mut $conn, + diesel::connection::CacheSize::Disabled, + ); + + $body + }}; +} + +/// Helper function to run a database transaction with connection management and error handling. +/// Example usage: +/// ```rust,ignore +/// run_transaction(&db_pool, |conn| { +/// // Your transactional query logic here, using `conn` as the database +/// // connection within the transaction +/// Ok(result) // Return a result from the transaction block +/// }); +/// ``` +pub fn run_transaction( + db_pool: &PgSchemaConnectionPool, + query_fn: F, +) -> result::Result +where + F: FnOnce(&mut DBConnection) -> result::Result, +{ + let mut conn = db_pool.get().map_err(|e| { + superposition_macros::unexpected_error!( + "Unable to get db connection from pool, error: {}", + e + ) + })?; + diesel::Connection::set_prepared_statement_cache_size( + &mut conn, + diesel::connection::CacheSize::Disabled, + ); + + diesel::Connection::transaction(&mut conn, query_fn) +} diff --git a/crates/service_utils/src/helpers.rs b/crates/service_utils/src/helpers.rs index 74de9af24..a5e0c92c9 100644 --- a/crates/service_utils/src/helpers.rs +++ b/crates/service_utils/src/helpers.rs @@ -35,11 +35,15 @@ use superposition_types::{ }, superposition_schema::superposition::workspaces, }, - result::{self}, + result, }; -use crate::encryption::{EncryptionError, decrypt_secret, decrypt_workspace_key}; -use crate::service::types::{AppState, SchemaName, WorkspaceContext}; +use crate::{ + db::PgSchemaConnectionPool, + encryption::{EncryptionError, decrypt_secret, decrypt_workspace_key}, + run_query, + service::types::{AppState, SchemaName, WorkspaceContext}, +}; // using named group to capture which type (secrets/variables) the regex was // because variables and secrets need to be handled differently inside webhook execution @@ -206,11 +210,18 @@ pub fn parse_config_tags( pub fn get_workspace( workspace_schema_name: &SchemaName, - db_conn: &mut DBConnection, + db_pool: &PgSchemaConnectionPool, ) -> result::Result { - let workspace = workspaces::dsl::workspaces - .filter(workspaces::workspace_schema_name.eq(workspace_schema_name.to_string())) - .get_result::(db_conn)?; + let workspace = run_query!( + db_pool, + conn, + workspaces::dsl::workspaces + .filter( + workspaces::workspace_schema_name.eq(workspace_schema_name.to_string()), + ) + .get_result::(&mut conn) + )?; + Ok(workspace) } diff --git a/crates/service_utils/src/middlewares/auth_n/helpers.rs b/crates/service_utils/src/middlewares/auth_n/helpers.rs index 5e017a05a..8b3993ebf 100644 --- a/crates/service_utils/src/middlewares/auth_n/helpers.rs +++ b/crates/service_utils/src/middlewares/auth_n/helpers.rs @@ -1,6 +1,6 @@ use actix_web::{HttpRequest, web::Data}; use diesel::{ - Connection, ExpressionMethods, RunQueryDsl, + ExpressionMethods, RunQueryDsl, query_dsl::methods::{OrderDsl, SelectDsl}, }; use superposition_types::database::superposition_schema::superposition::organisations; @@ -20,9 +20,6 @@ pub(super) fn fetch_org_ids_from_db( match app_state.db_pool.get() { Ok(mut conn) => { - conn.set_prepared_statement_cache_size( - diesel::connection::CacheSize::Disabled, - ); let orgs = organisations::table .order(organisations::created_at.desc()) .select(organisations::id) diff --git a/crates/service_utils/src/middlewares/workspace_context.rs b/crates/service_utils/src/middlewares/workspace_context.rs index d1c76af19..474580777 100644 --- a/crates/service_utils/src/middlewares/workspace_context.rs +++ b/crates/service_utils/src/middlewares/workspace_context.rs @@ -10,7 +10,7 @@ use actix_web::{ }; use futures_util::future::LocalBoxFuture; use regex::Regex; -use superposition_macros::{bad_argument, unexpected_error}; +use superposition_macros::bad_argument; use crate::helpers::get_workspace; use crate::{ @@ -137,14 +137,8 @@ where (true, Some(workspace_id)) => { let schema = format!("{}_{}", *organisation, *workspace_id); let schema_name = SchemaName(schema); - let workspace_settings = { - let mut db_conn = app_state - .db_pool - .get() - .map_err(|err| unexpected_error!("{}", err))?; - - get_workspace(&schema_name, &mut db_conn)? - }; + let workspace_settings = + get_workspace(&schema_name, &app_state.db_pool)?; req.extensions_mut().insert(workspace_id.clone()); req.extensions_mut().insert(WorkspaceContext { diff --git a/crates/superposition/src/organisation/handlers.rs b/crates/superposition/src/organisation/handlers.rs index 5ab4ba816..3d990333c 100644 --- a/crates/superposition/src/organisation/handlers.rs +++ b/crates/superposition/src/organisation/handlers.rs @@ -1,11 +1,11 @@ use actix_web::{ Scope, get, post, routes, - web::{Json, Path, Query}, + web::{Data, Json, Path, Query}, }; use chrono::Utc; use diesel::prelude::*; use idgenerator::IdInstance; -use service_utils::service::types::DbConnection; +use service_utils::{run_query, service::types::AppState}; use superposition_derives::authorized; use superposition_types::{ PaginatedResponse, User, @@ -32,11 +32,9 @@ pub fn endpoints() -> Scope { #[post("")] pub async fn create_handler( request: Json, - db_conn: DbConnection, + state: Data, user: User, ) -> superposition::Result> { - let DbConnection(mut conn) = db_conn; - // Generating a numeric ID from IdInstance and prefixing it with `orgid` let numeric_id = IdInstance::next_id(); let org_id = format!("orgid{}", numeric_id); @@ -58,9 +56,11 @@ pub async fn create_handler( updated_by: user.get_username(), }; - let new_org = diesel::insert_into(organisations::table) - .values(&new_org) - .get_result(&mut conn)?; + let new_org = run_query!(state.db_pool, conn, { + diesel::insert_into(organisations::table) + .values(&new_org) + .get_result(&mut conn) + })?; Ok(Json(new_org)) } @@ -72,18 +72,19 @@ pub async fn create_handler( pub async fn update_handler( org_id: Path, request: Json, - db_conn: DbConnection, + state: Data, user: User, ) -> superposition::Result> { - let DbConnection(mut conn) = db_conn; let org_id = org_id.into_inner(); let now = Utc::now(); let req = request.into_inner(); - let updated_org = diesel::update(organisations::table) - .filter(organisations::id.eq(org_id)) - .set((req, updated_at.eq(now), updated_by.eq(user.get_email()))) - .get_result(&mut conn)?; + let updated_org = run_query!(state.db_pool, conn, { + diesel::update(organisations::table) + .filter(organisations::id.eq(org_id)) + .set((req, updated_at.eq(now), updated_by.eq(user.get_email()))) + .get_result(&mut conn) + })?; Ok(Json(updated_org)) } @@ -92,13 +93,15 @@ pub async fn update_handler( #[get("/{org_id}")] pub async fn get_handler( org_id: Path, - db_conn: DbConnection, + state: Data, ) -> superposition::Result> { - let DbConnection(mut conn) = db_conn; - - let org = organisations::table - .find(org_id.as_str()) - .first::(&mut conn)?; + let org = run_query!( + state.db_pool, + conn, + organisations::table + .find(org_id.as_str()) + .first::(&mut conn) + )?; Ok(Json(org)) } @@ -106,21 +109,27 @@ pub async fn get_handler( #[authorized] #[get("")] pub async fn list_handler( - db_conn: DbConnection, + state: Data, filters: Query, ) -> superposition::Result>> { - let DbConnection(mut conn) = db_conn; - if let Some(true) = filters.all { - let result: Vec = organisations::table - .order(organisations::created_at.desc()) - .get_results(&mut conn)?; + let result = run_query!( + state.db_pool, + conn, + organisations::table + .order(organisations::created_at.desc()) + .get_results(&mut conn) + )?; return Ok(Json(PaginatedResponse::all(result))); } // Get total count of organisations - let total_items: i64 = organisations::table.count().get_result(&mut conn)?; + let total_items = run_query!( + state.db_pool, + conn, + organisations::table.count().get_result(&mut conn) + )?; // Set up pagination let limit = filters.count.unwrap_or(10); @@ -136,7 +145,7 @@ pub async fn list_handler( } // Get paginated results - let data: Vec = builder.load(&mut conn)?; + let data = run_query!(state.db_pool, conn, builder.load(&mut conn))?; let total_pages = (total_items as f64 / limit as f64).ceil() as i64; diff --git a/crates/superposition/src/resolve/handlers.rs b/crates/superposition/src/resolve/handlers.rs index dc54bdc28..95a3c2667 100644 --- a/crates/superposition/src/resolve/handlers.rs +++ b/crates/superposition/src/resolve/handlers.rs @@ -9,7 +9,10 @@ use context_aware_config::api::config::helpers::{ }; use experimentation_platform::api::experiments::handlers::get_applicable_variants_helper; use serde_json::{Map, Value}; -use service_utils::service::types::{AppState, DbConnection, WorkspaceContext}; +use service_utils::{ + run_query, + service::types::{AppState, WorkspaceContext}, +}; use superposition_derives::authorized; use superposition_types::{ api::config::{ContextPayload, MergeStrategy, ResolveConfigQuery}, @@ -32,19 +35,22 @@ async fn resolve_with_exp_handler( req: HttpRequest, body: Option>, merge_strategy: Header, - db_conn: DbConnection, dimension_params: DimensionQuery, query_filters: superposition_query::Query, identifier_query: superposition_query::Query, workspace_context: WorkspaceContext, state: Data, ) -> superposition::Result { - let DbConnection(mut conn) = db_conn; let query_filters = query_filters.into_inner(); let identifier_query = identifier_query.into_inner(); - let max_created_at = get_max_created_at(&mut conn, &workspace_context.schema_name) - .map_err(|e| log::error!("failed to fetch max timestamp from event_log : {e}")) - .ok(); + // TODO: Granularise the connection usage in this function once all crates are migrated + let max_created_at = run_query!( + state.db_pool, + conn, + get_max_created_at(&mut conn, &workspace_context.schema_name) + ) + .map_err(|e| log::error!("failed to fetch max timestamp from event_log : {e}")) + .ok(); if identifier_query.identifier.is_none() && is_not_modified(max_created_at, &req) { return Ok(HttpResponse::NotModified().finish()); @@ -59,38 +65,57 @@ async fn resolve_with_exp_handler( // This value is separately needed, as in the following check the value before the modification is required let config_ver = config_version.to_owned(); - let mut config = generate_config_from_version( - &mut config_version, - &mut conn, - &workspace_context.schema_name, + // TODO: Granularise the connection usage in this function once all crates are migrated + let mut config = run_query!( + state.db_pool, + conn, + generate_config_from_version( + &mut config_version, + &mut conn, + &workspace_context.schema_name, + ) )?; if let (None, Some(identifier)) = (config_ver, identifier_query.identifier) { let context_map: &Map = &query_data; - let (applicable_variants, _) = get_applicable_variants_helper( - &mut conn, - context_map.clone(), - &config.dimensions, - identifier, - &workspace_context, - ) - .await?; + // TODO: Granularise the connection usage in this function once all crates are migrated + let (applicable_variants, _) = run_query!( + state.db_pool, + conn, + get_applicable_variants_helper( + &mut conn, + context_map.clone(), + &config.dimensions, + identifier, + &workspace_context, + ) + )?; query_data.insert("variantIds".to_string(), applicable_variants.into()); } - let resolved_config = resolve( - &mut config, - query_data, - merge_strategy, - &mut conn, - &query_filters, - &workspace_context, - &state.master_encryption_key, + // TODO: Granularise the connection usage in this function once all crates are migrated + let resolved_config = run_query!( + state.db_pool, + conn, + resolve( + &mut config, + query_data, + merge_strategy, + &mut conn, + &query_filters, + &workspace_context, + &state.master_encryption_key, + ) )?; let mut resp = HttpResponse::Ok(); add_last_modified_to_header(max_created_at, is_smithy, &mut resp); - add_audit_id_to_header(&mut conn, &mut resp, &workspace_context.schema_name); + // TODO: Granularise the connection usage in this function once all crates are migrated + run_query!( + state.db_pool, + conn, + add_audit_id_to_header(&mut conn, &mut resp, &workspace_context.schema_name) + ); add_config_version_to_header(&config_version, &mut resp); Ok(resp.json(resolved_config)) } diff --git a/crates/superposition/src/webhooks/handlers.rs b/crates/superposition/src/webhooks/handlers.rs index abf4c0b69..44ab6e76a 100644 --- a/crates/superposition/src/webhooks/handlers.rs +++ b/crates/superposition/src/webhooks/handlers.rs @@ -6,10 +6,14 @@ use actix_web::{ use chrono::Utc; use context_aware_config::helpers::validate_change_reason; use diesel::{ExpressionMethods, PgArrayExpressionMethods, QueryDsl, RunQueryDsl}; -use service_utils::service::types::{AppState, DbConnection, WorkspaceContext}; +use service_utils::{ + db::run_transaction, + run_query, + service::types::{AppState, WorkspaceContext}, +}; use superposition_derives::authorized; use superposition_types::{ - PaginatedResponse, User, + DBConnection, PaginatedResponse, User, api::webhook::{CreateWebhookRequest, UpdateWebhookRequest, WebhookName}, custom_query::PaginationParams, database::{ @@ -33,21 +37,29 @@ pub fn endpoints() -> Scope { async fn create_handler( workspace_context: WorkspaceContext, request: Json, - db_conn: DbConnection, user: User, state: Data, ) -> superposition::Result> { - let DbConnection(mut conn) = db_conn; let req = request.into_inner(); - validate_change_reason( - &workspace_context, - &req.change_reason, - &mut conn, - &state.master_encryption_key, + // TODO: Granularise the connection usage in this function once all crates are migrated + run_query!( + state.db_pool, + conn, + validate_change_reason( + &workspace_context, + &req.change_reason, + &mut conn, + &state.master_encryption_key, + ) )?; - validate_events(&req.events, None, &workspace_context.schema_name, &mut conn)?; + validate_events( + &req.events, + None, + &workspace_context.schema_name, + &state.db_pool, + )?; let now = Utc::now(); let webhook_data = Webhook { name: req.name.to_string(), @@ -67,12 +79,16 @@ async fn create_handler( last_modified_at: now, }; - diesel::insert_into(webhooks::table) - .values(&webhook_data) - .schema_name(&workspace_context.schema_name) - .execute(&mut conn)?; + let created = run_query!( + state.db_pool, + conn, + diesel::insert_into(webhooks::table) + .values(&webhook_data) + .schema_name(&workspace_context.schema_name) + .get_result::(&mut conn) + )?; - Ok(Json(webhook_data)) + Ok(Json(created)) } #[authorized] @@ -80,20 +96,23 @@ async fn create_handler( async fn update_handler( workspace_context: WorkspaceContext, params: web::Path, - db_conn: DbConnection, user: User, request: Json, state: Data, ) -> superposition::Result> { - let DbConnection(mut conn) = db_conn; let req = request.into_inner(); let w_name: String = params.into_inner().into(); - validate_change_reason( - &workspace_context, - &req.change_reason, - &mut conn, - &state.master_encryption_key, + // TODO: Granularise the connection usage in this function once all crates are migrated + run_query!( + state.db_pool, + conn, + validate_change_reason( + &workspace_context, + &req.change_reason, + &mut conn, + &state.master_encryption_key, + ) )?; if let Some(webhook_events) = &req.events { @@ -101,19 +120,23 @@ async fn update_handler( webhook_events, Some(&w_name), &workspace_context.schema_name, - &mut conn, + &state.db_pool, )?; } - let update = diesel::update(webhooks::table) - .filter(webhooks::name.eq(w_name)) - .set(( - req, - last_modified_at.eq(Utc::now()), - last_modified_by.eq(user.get_email()), - )) - .schema_name(&workspace_context.schema_name) - .get_result::(&mut conn)?; + let update = run_query!( + state.db_pool, + conn, + diesel::update(webhooks::table) + .filter(webhooks::name.eq(w_name)) + .set(( + req, + last_modified_at.eq(Utc::now()), + last_modified_by.eq(user.get_email()), + )) + .schema_name(&workspace_context.schema_name) + .get_result::(&mut conn) + )?; Ok(Json(update)) } @@ -123,13 +146,12 @@ async fn update_handler( async fn get_handler( workspace_context: WorkspaceContext, params: web::Path, - db_conn: DbConnection, + state: Data, ) -> superposition::Result> { - let DbConnection(mut conn) = db_conn; let webhook_row = fetch_webhook( ¶ms.into_inner(), &workspace_context.schema_name, - &mut conn, + &state.db_pool, )?; Ok(Json(webhook_row)) } @@ -138,22 +160,28 @@ async fn get_handler( #[get("")] async fn list_handler( workspace_context: WorkspaceContext, - db_conn: DbConnection, + state: Data, pagination: Query, ) -> superposition::Result>> { - let DbConnection(mut conn) = db_conn; - if let Some(true) = pagination.all { - let result: Vec = webhooks - .schema_name(&workspace_context.schema_name) - .get_results(&mut conn)?; + let result: Vec = run_query!( + state.db_pool, + conn, + webhooks + .schema_name(&workspace_context.schema_name) + .get_results(&mut conn) + )?; return Ok(Json(PaginatedResponse::all(result))); } - let total_items: i64 = webhooks - .count() - .schema_name(&workspace_context.schema_name) - .get_result(&mut conn)?; + let total_items: i64 = run_query!( + state.db_pool, + conn, + webhooks + .count() + .schema_name(&workspace_context.schema_name) + .get_result(&mut conn) + )?; let limit = pagination.count.unwrap_or(10); let mut builder = webhooks .schema_name(&workspace_context.schema_name) @@ -164,7 +192,7 @@ async fn list_handler( let offset = (page - 1) * limit; builder = builder.offset(offset); } - let data: Vec = builder.load(&mut conn)?; + let data: Vec = run_query!(state.db_pool, conn, builder.load(&mut conn))?; let total_pages = (total_items as f64 / limit as f64).ceil() as i64; Ok(Json(PaginatedResponse { @@ -179,23 +207,25 @@ async fn list_handler( async fn delete_handler( workspace_context: WorkspaceContext, params: web::Path, - db_conn: DbConnection, + state: Data, user: User, ) -> superposition::Result { - let DbConnection(mut conn) = db_conn; let w_name: String = params.into_inner().into(); - diesel::update(webhooks::table) - .filter(webhooks::name.eq(&w_name)) - .set(( - webhooks::last_modified_at.eq(Utc::now()), - webhooks::last_modified_by.eq(user.get_email()), - )) - .schema_name(&workspace_context.schema_name) - .execute(&mut conn)?; - diesel::delete(webhooks.filter(webhooks::name.eq(&w_name))) - .schema_name(&workspace_context.schema_name) - .execute(&mut conn)?; + run_transaction(&state.db_pool, |conn: &mut DBConnection| { + diesel::update(webhooks::table) + .filter(webhooks::name.eq(&w_name)) + .set(( + webhooks::last_modified_at.eq(Utc::now()), + webhooks::last_modified_by.eq(user.get_email()), + )) + .schema_name(&workspace_context.schema_name) + .execute(conn)?; + diesel::delete(webhooks.filter(webhooks::name.eq(&w_name))) + .schema_name(&workspace_context.schema_name) + .execute(conn)?; + Ok(()) + })?; Ok(HttpResponse::NoContent().finish()) } @@ -204,13 +234,16 @@ async fn delete_handler( async fn get_by_event_handler( workspace_context: WorkspaceContext, params: web::Path, - db_conn: DbConnection, + state: Data, ) -> superposition::Result> { - let DbConnection(mut conn) = db_conn; let event = params.into_inner(); - let webhook_row = webhooks - .filter(webhooks::events.contains(vec![event])) - .schema_name(&workspace_context.schema_name) - .first::(&mut conn)?; + let webhook_row = run_query!( + state.db_pool, + conn, + webhooks + .filter(webhooks::events.contains(vec![event])) + .schema_name(&workspace_context.schema_name) + .first(&mut conn) + )?; Ok(Json(webhook_row)) } diff --git a/crates/superposition/src/webhooks/helper.rs b/crates/superposition/src/webhooks/helper.rs index 2d556df50..c2a1472c8 100644 --- a/crates/superposition/src/webhooks/helper.rs +++ b/crates/superposition/src/webhooks/helper.rs @@ -1,24 +1,27 @@ +use diesel::{ExpressionMethods, QueryDsl, RunQueryDsl}; +use service_utils::{db::PgSchemaConnectionPool, run_query}; use superposition_macros::bad_argument; use superposition_types::{ - database::models::others::{Webhook, WebhookEvent}, + database::{ + models::others::{Webhook, WebhookEvent}, + schema::webhooks::{self, dsl}, + }, result as superposition, }; -use diesel::{ExpressionMethods, QueryDsl, RunQueryDsl}; -use diesel::{ - PgConnection, - r2d2::{ConnectionManager, PooledConnection}, -}; -use superposition_types::database::schema::webhooks::{self, dsl}; - pub fn validate_events( events: &[WebhookEvent], exclude_webhook: Option<&String>, schema_name: &String, - conn: &mut PooledConnection>, + db_pool: &PgSchemaConnectionPool, ) -> superposition::Result<()> { - let result: Vec = - dsl::webhooks.schema_name(schema_name).get_results(conn)?; + let result: Vec = run_query!( + db_pool, + conn, + dsl::webhooks + .schema_name(schema_name) + .get_results(&mut conn) + )?; for webhook in result { if exclude_webhook == Some(&webhook.name) { continue; @@ -35,10 +38,16 @@ pub fn validate_events( pub fn fetch_webhook( w_name: &String, schema_name: &String, - conn: &mut PooledConnection>, + db_pool: &PgSchemaConnectionPool, ) -> superposition::Result { - Ok(dsl::webhooks - .filter(webhooks::name.eq(w_name)) - .schema_name(schema_name) - .get_result::(conn)?) + let webhook = run_query!( + db_pool, + conn, + dsl::webhooks + .filter(webhooks::name.eq(w_name)) + .schema_name(schema_name) + .get_result::(&mut conn) + )?; + + Ok(webhook) } diff --git a/crates/superposition/src/workspace/handlers.rs b/crates/superposition/src/workspace/handlers.rs index 71f9b2257..93ba3ecba 100644 --- a/crates/superposition/src/workspace/handlers.rs +++ b/crates/superposition/src/workspace/handlers.rs @@ -6,26 +6,27 @@ use actix_web::{ }; use chrono::Utc; use diesel::{ - Connection, ExpressionMethods, PgConnection, QueryDsl, RunQueryDsl, - TextExpressionMethods, + ExpressionMethods, PgConnection, QueryDsl, RunQueryDsl, TextExpressionMethods, connection::SimpleConnection, r2d2::{ConnectionManager, PooledConnection}, }; use regex::Regex; use service_utils::{ + db::run_transaction, encryption::{ encrypt_workspace_key, generate_encryption_key, rotate_workspace_encryption_key_helper, }, helpers::get_workspace, + run_query, service::types::{ - AppState, DbConnection, OrganisationId, SchemaName, WorkspaceContext, WorkspaceId, + AppState, OrganisationId, SchemaName, WorkspaceContext, WorkspaceId, }, }; use superposition_derives::authorized; use superposition_macros::{bad_argument, db_error, unexpected_error, validation_error}; use superposition_types::{ - PaginatedResponse, User, + DBConnection, PaginatedResponse, User, api::{ I64Update, workspace::{ @@ -82,15 +83,18 @@ pub fn endpoints(scope: Scope) -> Scope { #[get("/{workspace_name}")] async fn get_handler( workspace_name: Path, - db_conn: DbConnection, + state: Data, org_id: OrganisationId, ) -> superposition::Result> { - let DbConnection(mut conn) = db_conn; let workspace_name = workspace_name.into_inner(); - let workspace: Workspace = workspaces::dsl::workspaces - .filter(workspaces::organisation_id.eq(&org_id.0)) - .filter(workspaces::workspace_name.eq(workspace_name)) - .get_result(&mut conn)?; + let workspace: Workspace = run_query!( + state.db_pool, + conn, + workspaces::dsl::workspaces + .filter(workspaces::organisation_id.eq(&org_id.0)) + .filter(workspaces::workspace_name.eq(workspace_name)) + .get_result(&mut conn) + )?; let response = WorkspaceResponse::from(workspace); Ok(Json(response)) } @@ -99,15 +103,17 @@ async fn get_handler( #[post("")] async fn create_handler( request: Json, - db_conn: DbConnection, + state: Data, org_id: OrganisationId, user: User, - state: web::Data, ) -> superposition::Result> { - let DbConnection(mut conn) = db_conn; - let org_info: Organisation = organisations::dsl::organisations - .filter(organisations::id.eq(&org_id.0)) - .get_result::(&mut conn)?; + let org_info = run_query!( + state.db_pool, + conn, + organisations::dsl::organisations + .filter(organisations::id.eq(&org_id.0)) + .get_result::(&mut conn) + )?; let timestamp = Utc::now(); let request = request.into_inner(); let email = user.get_email(); @@ -154,14 +160,13 @@ async fn create_handler( }; let created_workspace = - conn.transaction::(|transaction_conn| { - let mut inserted_workspace: Vec = - diesel::insert_into(workspaces::dsl::workspaces) - .values(workspace) - .get_results(transaction_conn)?; + run_transaction(&state.db_pool, |conn: &mut DBConnection| { + let inserted_workspace = diesel::insert_into(workspaces::table) + .values(workspace) + .get_result(conn)?; - setup_workspace_schema(transaction_conn, &workspace_schema_name)?; - Ok(inserted_workspace.remove(0)) + setup_workspace_schema(conn, &workspace_schema_name)?; + Ok::(inserted_workspace) })?; let response = WorkspaceResponse::from(created_workspace); Ok(Json(response)) @@ -174,7 +179,7 @@ async fn create_handler( async fn update_handler( workspace_name: web::Path, request: Json, - db_conn: DbConnection, + state: Data, org_id: OrganisationId, user: User, ) -> superposition::Result> { @@ -184,17 +189,17 @@ async fn update_handler( let schema_name = SchemaName(format!("{}_{}", *org_id, workspace_name)); // TODO: mandatory dimensions updation needs to be validated // for the existance of the dimensions in the workspace - let DbConnection(mut conn) = db_conn; - if let Some(I64Update::Add(version)) = request.config_version { - let _ = config_versions::config_versions - .select(config_versions::id) - .filter(config_versions::id.eq(version)) - .schema_name(&schema_name) - .first::(&mut conn)?; - } let updated_workspace = - conn.transaction::(|transaction_conn| { + run_transaction(&state.db_pool, |conn: &mut DBConnection| { + if let Some(I64Update::Add(version)) = request.config_version { + config_versions::config_versions + .select(config_versions::id) + .filter(config_versions::id.eq(version)) + .schema_name(&schema_name) + .first::(conn)?; + } + let updated_workspace = diesel::update(workspaces::table) .filter(workspaces::organisation_id.eq(&org_id.0)) .filter(workspaces::workspace_name.eq(workspace_name)) @@ -203,13 +208,13 @@ async fn update_handler( workspaces::last_modified_by.eq(user.email), workspaces::last_modified_at.eq(timestamp), )) - .get_result::(transaction_conn) + .get_result::(conn) .map_err(|err| { log::error!("failed to update workspace with error: {}", err); err })?; - Ok(updated_workspace) + Ok::(updated_workspace) })?; let response = WorkspaceResponse::from(updated_workspace); Ok(Json(response)) @@ -218,19 +223,22 @@ async fn update_handler( #[authorized] #[get("")] async fn list_handler( - db_conn: DbConnection, + state: Data, pagination_filters: Query, filters: Query, org_id: OrganisationId, ) -> superposition::Result>> { - let DbConnection(mut conn) = db_conn; if let Some(true) = pagination_filters.all { - let result: Vec = workspaces::dsl::workspaces - .filter(workspaces::organisation_id.eq(&org_id.0)) - .get_results::(&mut conn)? - .into_iter() - .map(WorkspaceResponse::from) - .collect(); + let result = run_query!( + state.db_pool, + conn, + workspaces::dsl::workspaces + .filter(workspaces::organisation_id.eq(&org_id.0)) + .get_results::(&mut conn) + )? + .into_iter() + .map(WorkspaceResponse::from) + .collect::>(); return Ok(Json(PaginatedResponse::all(result))); }; @@ -250,7 +258,11 @@ async fn list_handler( let count_query = query_builder(&filters); let base_query = query_builder(&filters); - let n_types: i64 = count_query.count().get_result(&mut conn)?; + let n_types = run_query!( + state.db_pool, + conn, + count_query.count().get_result(&mut conn) + )?; let limit = pagination_filters.count.unwrap_or(10); let mut builder = base_query .order(workspaces::dsl::created_at.desc()) @@ -259,11 +271,11 @@ async fn list_handler( let offset = (page - 1) * limit; builder = builder.offset(offset); } - let workspaces: Vec = builder - .load::(&mut conn)? - .into_iter() - .map(WorkspaceResponse::from) - .collect(); + let workspaces = + run_query!(state.db_pool, conn, builder.load::(&mut conn))? + .into_iter() + .map(WorkspaceResponse::from) + .collect::>(); let total_pages = (n_types as f64 / limit as f64).ceil() as i64; Ok(Json(PaginatedResponse { total_pages, @@ -308,27 +320,28 @@ fn validate_workspace_name(workspace_name: &String) -> superposition::Result<()> #[post("/{workspace_name}/db/migrate")] async fn migrate_schema_handler( workspace_name: Path, - db_conn: DbConnection, org_id: OrganisationId, state: Data, user: User, ) -> superposition::Result> { let workspace_name = workspace_name.into_inner(); - let DbConnection(mut conn) = db_conn; let schema_name = SchemaName(format!("{}_{}", *org_id, &workspace_name)); - let workspace = get_workspace(&schema_name, &mut conn)?; + let workspace = get_workspace(&schema_name, &state.db_pool)?; - conn.transaction::<(), superposition::AppError, _>(|transaction_conn| { - setup_workspace_schema(transaction_conn, &workspace.workspace_schema_name)?; + run_transaction(&state.db_pool, |conn: &mut DBConnection| { + setup_workspace_schema(conn, &workspace.workspace_schema_name)?; if workspace.encryption_key.is_empty() { match state.master_encryption_key { Some(ref master_encryption_key) => { let new_key = generate_encryption_key(); - let encrypted_key = - encrypt_workspace_key(&new_key, &master_encryption_key.current_key).map_err(|e| { - log::error!("Failed to encrypt workspace key: {}", e); - unexpected_error!("Failed to encrypt workspace key") - })?; + let encrypted_key = encrypt_workspace_key( + &new_key, + &master_encryption_key.current_key, + ) + .map_err(|e| { + log::error!("Failed to encrypt workspace key: {}", e); + unexpected_error!("Failed to encrypt workspace key") + })?; diesel::update(workspaces::table) .filter(workspaces::organisation_id.eq(&org_id.0)) @@ -336,9 +349,9 @@ async fn migrate_schema_handler( .set(( workspaces::encryption_key.eq(encrypted_key), workspaces::last_modified_by.eq(user.get_username()), - workspaces::last_modified_at.eq(Utc::now()) + workspaces::last_modified_at.eq(Utc::now()), )) - .execute(transaction_conn)?; + .execute(conn)?; } None => { log::warn!( @@ -349,7 +362,7 @@ async fn migrate_schema_handler( } } } - Ok(()) + Ok::<(), superposition::AppError>(()) })?; let response = WorkspaceResponse::from(workspace); @@ -361,12 +374,9 @@ async fn migrate_schema_handler( pub async fn rotate_encryption_key_handler( workspace_name: Path, user: User, - db_conn: DbConnection, org_id: OrganisationId, state: Data, ) -> superposition::Result> { - let DbConnection(mut conn) = db_conn; - let Some(ref master_encryption_key) = state.master_encryption_key else { log::error!("Master encryption key not configured"); return Err(bad_argument!( @@ -375,7 +385,7 @@ pub async fn rotate_encryption_key_handler( }; let schema_name = SchemaName(format!("{}_{}", *org_id, workspace_name.into_inner())); - let workspace = get_workspace(&schema_name, &mut conn)?; + let workspace = get_workspace(&schema_name, &state.db_pool)?; let workspace_context = WorkspaceContext { schema_name, organisation_id: org_id, @@ -383,8 +393,8 @@ pub async fn rotate_encryption_key_handler( settings: workspace, }; - let total_secrets_re_encrypted = conn - .transaction::(|conn| { + let total_secrets_re_encrypted = + run_transaction(&state.db_pool, |conn: &mut DBConnection| { rotate_workspace_encryption_key_helper( &workspace_context, conn, diff --git a/flake.nix b/flake.nix index 81d8b79ff..05273bb6a 100644 --- a/flake.nix +++ b/flake.nix @@ -70,6 +70,7 @@ cargo-watch cargo-edit cargo-msrv + cargo-expand diesel-cli leptosfmt wasm-pack