diff --git a/Cargo.toml b/Cargo.toml index 23821e3eda..1a260f53c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ exclude = [ "pgrx-examples/numeric", "pgrx-examples/pgtrybuilder", "pgrx-examples/operators", + "pgrx-examples/parallel_scan_lwlock", "pgrx-examples/range", "pgrx-examples/schemas", "pgrx-examples/shmem", diff --git a/pgrx-examples/parallel_scan_lwlock/.gitignore b/pgrx-examples/parallel_scan_lwlock/.gitignore new file mode 100644 index 0000000000..066dfdc21c --- /dev/null +++ b/pgrx-examples/parallel_scan_lwlock/.gitignore @@ -0,0 +1,7 @@ +.DS_Store +.idea/ +/target +*.iml +**/*.rs.bk +Cargo.lock +sql/shmem-1.0.sql diff --git a/pgrx-examples/parallel_scan_lwlock/Cargo.toml b/pgrx-examples/parallel_scan_lwlock/Cargo.toml new file mode 100644 index 0000000000..f0428e9de4 --- /dev/null +++ b/pgrx-examples/parallel_scan_lwlock/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "parallel_scan_lwlock" +version = "0.0.0" +edition = "2024" +publish = false + +[lib] +crate-type = ["cdylib", "lib"] + +[[bin]] +name = "pgrx_embed_parallel_scan_lwlock" +path = "./src/bin/pgrx_embed.rs" + +[features] +default = ["pg17"] +pg13 = ["pgrx/pg13", "pgrx-tests/pg13"] +pg14 = ["pgrx/pg14", "pgrx-tests/pg14"] +pg15 = ["pgrx/pg15", "pgrx-tests/pg15"] +pg16 = ["pgrx/pg16", "pgrx-tests/pg16"] +pg17 = ["pgrx/pg17", "pgrx-tests/pg17"] +pg18 = ["pgrx/pg18", "pgrx-tests/pg18"] +pg_test = [] + +[dependencies] +pgrx = { path = "../../pgrx", default-features = false } + +[dev-dependencies] +pgrx-tests = { path = "../../pgrx-tests" } + +# uncomment these if compiling outside of 'pgrx' +# [profile.dev] +# panic = "unwind" + +# [profile.release] +# panic = "unwind" +# opt-level = 3 +# lto = "fat" +# codegen-units = 1 diff --git a/pgrx-examples/parallel_scan_lwlock/README.md b/pgrx-examples/parallel_scan_lwlock/README.md new file mode 100644 index 0000000000..c2b6573502 --- /dev/null +++ b/pgrx-examples/parallel_scan_lwlock/README.md @@ -0,0 +1,10 @@ +## Postgres Dynamic Shared 
Memory in Parallel Foreign Scans and LWLock Support + +Important: +> Extensions that use shared memory **must** be loaded via `postgresql.conf`'s +>`shared_preload_libraries` configuration setting. + +The example in [src/lib.rs](src/lib.rs) implements a parallel scan implementation for a "generator" foreign table type +that yields numbers from an integer counter, guarded by a dynamically allocated LWLock. It is meant to illustrate how to +set up a lock for foreign scans, keeping the shared memory handling as simple as possible. Check out the `shmem` example +project for more insights on shared memory handling in `pgrx` extensions. diff --git a/pgrx-examples/parallel_scan_lwlock/parallel_scan_lwlock.control b/pgrx-examples/parallel_scan_lwlock/parallel_scan_lwlock.control new file mode 100644 index 0000000000..20d92e4e15 --- /dev/null +++ b/pgrx-examples/parallel_scan_lwlock/parallel_scan_lwlock.control @@ -0,0 +1,5 @@ +comment = 'parallel_scan_lwlock: Created by pgrx' +default_version = '@CARGO_VERSION@' +module_pathname = 'parallel_scan_lwlock' +relocatable = false +superuser = false diff --git a/pgrx-examples/parallel_scan_lwlock/src/bin/pgrx_embed.rs b/pgrx-examples/parallel_scan_lwlock/src/bin/pgrx_embed.rs new file mode 100644 index 0000000000..5f5c4d8581 --- /dev/null +++ b/pgrx-examples/parallel_scan_lwlock/src/bin/pgrx_embed.rs @@ -0,0 +1 @@ +::pgrx::pgrx_embed!(); diff --git a/pgrx-examples/parallel_scan_lwlock/src/lib.rs b/pgrx-examples/parallel_scan_lwlock/src/lib.rs new file mode 100644 index 0000000000..99e39afe81 --- /dev/null +++ b/pgrx-examples/parallel_scan_lwlock/src/lib.rs @@ -0,0 +1,342 @@ +//! Supports both parallel and non-parallel scans, avoiding code duplication. +//! +//! Initialization: +//! * [_PG_init] for startup initialization +//! * [counter_handler] for startup registration of FDW handler callbacks +//! +//! Scan start: +//! * [pgrx_get_foreign_rel_size] for table-level estimates +//! 
* [pgrx_get_foreign_paths] provides the scan strategies available (just scan and parallel_scan) +//! * [pgrx_get_foreign_plan] processes clauses and constructs an executable plan +//! +//! Sequential scan execution: +//! * [pgrx_begin_foreign_scan] prepares the execution, initializing executable expressions +//! * [pgrx_iterate_foreign_scan] fetches the next row from the block, advancing to the next if necessary +//! * [pgrx_end_foreign_scan] cleans up on scan complete +//! +//! Parallel scan execution: +//! * [pgrx_begin_foreign_scan] prepares the execution, initializing executable expressions +//! * [pgrx_estimate_dsm_foreign_scan] calculates the amount of shared memory needed for the block list (prefiltered by temporal index) +//! * [pgrx_initialize_dsm_foreign_scan] creates the shared memory locks in the allocated DSM +//! * [pgrx_initialize_worker_foreign_scan] saves the pointer to the mapped shared memory in the worker process space +//! * See `Sequential scan execution` above +//! +//! Other support functions: +//! * [counter_validator] called for option validation +//! +//! # Extension setup example +//! +//! First add the extension to the `shared_preload_libraries` in `postgresql.conf` +//! +//! shared_preload_libraries = 'parallel_scan_lwlock.so' +//! +//! Then run the project with `cargo pgrx run` and execute: +//! +//! CREATE EXTENSION parallel_scan_lwlock; +//! CREATE FOREIGN DATA WRAPPER counter HANDLER counter_handler VALIDATOR counter_validator; +//! CREATE SERVER counter_srv FOREIGN DATA WRAPPER counter; +//! IMPORT FOREIGN SCHEMA my_counter FROM SERVER counter_srv INTO public; +//! +//! ## Query examples +//! +//! select * from my_counter limit 10; +//! +//! 
select count(distinct(c.counter)) as count from (select counter from my_counter limit 10000000) as c; + +mod routine; + +use std::ffi::{c_void, CStr}; +use pgrx::*; +use pgrx::pg_sys::*; +use pgrx::lwlock::scan::*; + +use crate::routine::*; + +pg_module_magic!(name, version); + +static SCAN_COUNTER_LOCKS: ParallelScanLwLockTranche = ParallelScanLwLockTranche::new(c"scan_counter_lock"); + +#[pg_guard] +pub extern "C-unwind" fn _PG_init() { + if unsafe { !pgrx::pg_sys::process_shared_preload_libraries_in_progress } { + pgrx::error!("this extension must be loaded via shared_preload_libraries."); + } + pg_shmem_init!(SCAN_COUNTER_LOCKS); +} + +// CREATE FOREIGN DATA WRAPPER counter_fdw HANDLER counter_handler VALIDATOR counter_validator; + +/// ```pgrxsql +/// CREATE OR REPLACE FUNCTION "counter_handler"() RETURNS fdw_handler +/// STRICT LANGUAGE c /* Rust */ +/// AS 'MODULE_PATHNAME', '@FUNCTION_NAME@'; +/// ``` +#[pg_extern] +fn counter_handler() -> CounterFdwRoutine { + CounterFdwRoutine(FdwRoutine { + GetForeignRelSize: Some(pgrx_get_foreign_rel_size), + GetForeignPaths: Some(pgrx_get_foreign_paths), + GetForeignPlan: Some(pgrx_get_foreign_plan), + BeginForeignScan: Some(pgrx_begin_foreign_scan), + IterateForeignScan: Some(pgrx_iterate_foreign_scan), + EndForeignScan: Some(pgrx_end_foreign_scan), + ReScanForeignScan: Some(pgrx_re_scan_foreign_scan), + IsForeignScanParallelSafe: Some(pgrx_is_foreign_scan_parallel_safe), + EstimateDSMForeignScan: Some(pgrx_estimate_dsm_foreign_scan), + InitializeDSMForeignScan: Some(pgrx_initialize_dsm_foreign_scan), + ReInitializeDSMForeignScan: Some(pgrx_reinitialize_dsm_foreign_scan), + InitializeWorkerForeignScan: Some(pgrx_initialize_worker_foreign_scan), + ShutdownForeignScan: Some(pgrx_shutdown_foreign_scan), + ImportForeignSchema: Some(pgrx_import_foreign_schema), + ..EMPTY_FDW + }) +} + +/// ```pgrxsql +/// CREATE OR REPLACE FUNCTION "counter_validator"(text[], oid) RETURNS void +/// STRICT LANGUAGE c /* Rust */ +/// AS 
'MODULE_PATHNAME', '@FUNCTION_NAME@'; +/// ``` +#[pg_extern] +fn counter_validator(_fcinfo: pg_sys::FunctionCallInfo) { + // No options +} + +#[cfg(test)] +pub mod pg_test { + pub fn setup(_options: Vec<&str>) {} + + pub fn postgresql_conf_options() -> Vec<&'static str> { + vec!["shared_preload_libraries='parallel_scan_lwlock'"] + } +} + +#[pg_guard] +unsafe extern "C-unwind" fn pgrx_get_foreign_rel_size(_root: *mut PlannerInfo, _baserel: *mut RelOptInfo, _foreigntableid: Oid) { + let expected_query_rows = 1_000_000f64; + (*_baserel).tuples = expected_query_rows; + (*_baserel).rows = expected_query_rows; +} + +#[pg_guard] +unsafe extern "C-unwind" fn pgrx_get_foreign_paths(root: *mut PlannerInfo, baserel: *mut RelOptInfo, _foreigntableid: Oid) { + // Just some funky cost estimates + let startup_cost = 100f64; + let total_cost = 1_000_000_000f64; + + let sequential_path = create_foreignscan_path( + root, + baserel, + std::ptr::null_mut(), // Use default target + (*baserel).rows, + #[cfg(not(any(feature="pg13", feature="pg14", feature="pg15", feature="pg16", feature="pg17")))] + 0, // disabled_nodes + startup_cost, + total_cost, + std::ptr::null_mut(), // pathkeys + std::ptr::null_mut(), // no outer rel + std::ptr::null_mut(), // no extra plan + #[cfg(not(any(feature="pg13", feature="pg14", feature="pg15", feature="pg16")))] + std::ptr::null_mut(), // fdw_restrictinfo + std::ptr::null_mut(), // no fdw_private data while planning + ); + add_path(baserel, sequential_path as *mut Path); + + let work_factor = max_parallel_workers_per_gather as f64; + let parallel_path = create_foreignscan_path( + root, + baserel, + std::ptr::null_mut(), // Use default target + (*baserel).rows, + #[cfg(not(any(feature="pg13", feature="pg14", feature="pg15", feature="pg16", feature="pg17")))] + 0, // disabled_nodes + startup_cost, + total_cost / work_factor, + std::ptr::null_mut(), // pathkeys + std::ptr::null_mut(), // no outer rel + std::ptr::null_mut(), // no extra plan + #[cfg(not(any(feature="pg13", 
feature="pg14", feature="pg15", feature="pg16")))] + std::ptr::null_mut(), // fdw_restrictinfo + std::ptr::null_mut(), // no fdw_private data while planning + ); + // Path might not be parallel_safe if parallel execution is disabled via max_parallel_workers_per_gather=0, + // failing an assertion in add_partial_path + if (*parallel_path).path.parallel_safe { + (*parallel_path).path.parallel_aware = true; + (*parallel_path).path.parallel_workers = max_parallel_workers_per_gather; + add_partial_path(baserel, parallel_path as *mut Path); + } +} + +#[pg_guard] +extern "C-unwind" fn pgrx_get_foreign_plan(_root: *mut PlannerInfo, baserel: *mut RelOptInfo, _foreigntableid: Oid, _best_path: *mut ForeignPath, tlist: *mut List, scan_clauses: *mut List, outer_plan: *mut Plan) -> *mut ForeignScan { + let where_clauses = unsafe { extract_actual_clauses(scan_clauses, false) }; + unsafe { make_foreignscan( + tlist, + where_clauses, + (*baserel).relid, + std::ptr::null_mut(), // fdw_exprs: no expressions to be evaluated by Postgres + std::ptr::null_mut(), // fdw_private: no fdw data while planning + std::ptr::null_mut(), // fdw_scan_tlist + std::ptr::null_mut(), // fdw_recheck_quals + outer_plan, + )} +} + +struct CounterScanState { + counter_lock: ParallelScanLwLock, +} + +#[pg_guard] +extern "C-unwind" fn pgrx_begin_foreign_scan(foreign_scan_state: *mut ForeignScanState, _eflags: ::std::os::raw::c_int) { + let scan_state = CounterScanState { + counter_lock: SCAN_COUNTER_LOCKS.lock_for(0i64), + }; + unsafe { (*foreign_scan_state).fdw_state = PgMemoryContexts::CurrentMemoryContext.leak_and_drop_on_delete(scan_state) as *mut ::std::os::raw::c_void }; +} + +#[pg_guard] +unsafe extern "C-unwind" fn pgrx_iterate_foreign_scan(node: *mut ForeignScanState) -> *mut TupleTableSlot { + let scan_state = &mut *((*node).fdw_state as *mut c_void as *mut CounterScanState); + + let slot = (*node).ss.ss_ScanTupleSlot; + assert!(!slot.is_null()); + (*(*slot).tts_ops).clear.unwrap()(slot); + + let 
column_count = (*(*slot).tts_tupleDescriptor).natts as usize; + assert_eq!(column_count, 1, "Foreign table column count should be always = 1 in this example"); + let nulls = std::slice::from_raw_parts_mut((*slot).tts_isnull, column_count); + let values: &mut [Datum] = std::slice::from_raw_parts_mut((*slot).tts_values, column_count); + + let mut counter_lock = scan_state.counter_lock.exclusive(); + let current_value = *counter_lock; + *counter_lock = *counter_lock + 1; + + (*nulls)[0] = false; + (*values)[0] = current_value.into_datum().expect("Unexpected conversion error from i64 to Datum"); + + let slot = &mut *slot; + assert!(!slot.tts_tupleDescriptor.is_null()); + assert!(slot.tts_flags & TTS_FLAG_EMPTY as u16 != 0); + slot.tts_flags &= !(TTS_FLAG_EMPTY as u16); + slot.tts_nvalid = (*slot.tts_tupleDescriptor).natts as i16; + slot +} + +#[pg_guard] +unsafe extern "C-unwind" fn pgrx_end_foreign_scan(node: *mut ForeignScanState) { + let _scan_state = &mut *((*node).fdw_state as *mut c_void as *mut CounterScanState); + // Nothing to clean-up, as our fdw state is quite naive +} + +#[pg_guard] +unsafe extern "C-unwind" fn pgrx_re_scan_foreign_scan(node: *mut ForeignScanState) { + pgrx_end_foreign_scan(node); + pgrx_begin_foreign_scan(node, 0); +} + +#[pg_guard] +unsafe extern "C-unwind" fn pgrx_is_foreign_scan_parallel_safe(_root: *mut PlannerInfo, _rel: *mut RelOptInfo, _rte: *mut RangeTblEntry) -> bool { + true +} + +#[pg_guard] +unsafe extern "C-unwind" fn pgrx_estimate_dsm_foreign_scan(_foreign_scan_state: *mut ForeignScanState, _pcxt: *mut ParallelContext) -> Size { + ParallelScanLwLock::<i64>::mem_size() +} + +#[pg_guard] +unsafe extern "C-unwind" fn pgrx_initialize_dsm_foreign_scan(foreign_scan_state: *mut ForeignScanState, _pcxt: *mut ParallelContext, shared_mem: *mut ::std::os::raw::c_void) { + warning!("Initializing parallel foreign scan leader (PID: {}) DSM", std::process::id()); + let scan_state = &mut *((*foreign_scan_state).fdw_state as *mut 
CounterScanState); + scan_state.counter_lock.initialize_dsm_and_register_leader(shared_mem); +} + +#[pg_guard] +unsafe extern "C-unwind" fn pgrx_reinitialize_dsm_foreign_scan(node: *mut ForeignScanState, _pcxt: *mut ParallelContext, _coordinate: *mut ::std::os::raw::c_void) { + let scan_state = &mut *((*node).fdw_state as *mut CounterScanState); + *(scan_state.counter_lock.exclusive()) = 0i64; +} + +#[pg_guard] +unsafe extern "C-unwind" fn pgrx_initialize_worker_foreign_scan(foreign_scan_state: *mut ForeignScanState, _toc: *mut shm_toc, coordinate: *mut ::std::os::raw::c_void) { + warning!("Initializing parallel foreign scan worker (PID: {}) DSM", std::process::id()); + let scan_state = &mut *((*foreign_scan_state).fdw_state as *mut CounterScanState); + scan_state.counter_lock.register_parallel_worker(coordinate); +} + +#[pg_guard] +unsafe extern "C-unwind" fn pgrx_shutdown_foreign_scan(_node: *mut ForeignScanState) { + // Nothing to do here as well +} + +#[pg_guard] +extern "C-unwind" fn pgrx_import_foreign_schema(stmt: *mut ImportForeignSchemaStmt, _server_oid: Oid) -> *mut List { + let stmt = unsafe { &(*stmt) }; + let table_name = unsafe { CStr::from_ptr(quote_qualified_identifier(std::ptr::null_mut(), stmt.remote_schema)).to_str().unwrap() }; + let server_name = unsafe { CStr::from_ptr(quote_qualified_identifier(std::ptr::null_mut(), stmt.server_name)).to_str().unwrap() }; + let create_table_statement = format!(r#"CREATE FOREIGN TABLE {} ("counter" bigint) SERVER {}"#, table_name, server_name); + + pgrx::memcx::current_context(|memcx| { + let mut statements = pgrx::list::List::default(); + statements.unstable_push_in_context(pgrx::StringInfo::from(create_table_statement).into_char_ptr() as *const c_void as *mut c_void, memcx); + statements.into_ptr() + }) +} + + +pub static EMPTY_FDW: FdwRoutine = FdwRoutine { + type_: NodeTag::T_FdwRoutine, + GetForeignRelSize: None, + GetForeignPaths: None, + GetForeignPlan: None, + BeginForeignScan: None, + 
IterateForeignScan: None, + ReScanForeignScan: None, + EndForeignScan: None, + GetForeignJoinPaths: None, + GetForeignUpperPaths: None, + AddForeignUpdateTargets: None, + PlanForeignModify: None, + BeginForeignModify: None, + ExecForeignInsert: None, + ExecForeignUpdate: None, + ExecForeignDelete: None, + EndForeignModify: None, + BeginForeignInsert: None, + EndForeignInsert: None, + IsForeignRelUpdatable: None, + PlanDirectModify: None, + BeginDirectModify: None, + IterateDirectModify: None, + EndDirectModify: None, + GetForeignRowMarkType: None, + RefetchForeignRow: None, + RecheckForeignScan: None, + ExplainForeignScan: None, + ExplainForeignModify: None, + ExplainDirectModify: None, + AnalyzeForeignTable: None, + ImportForeignSchema: None, + IsForeignScanParallelSafe: None, + EstimateDSMForeignScan: None, + InitializeDSMForeignScan: None, + ReInitializeDSMForeignScan: None, + InitializeWorkerForeignScan: None, + ShutdownForeignScan: None, + ReparameterizeForeignPathByChild: None, + #[cfg(not(feature="pg13"))] + ExecForeignBatchInsert: None, + #[cfg(not(feature="pg13"))] + ExecForeignTruncate: None, + #[cfg(not(feature="pg13"))] + ForeignAsyncConfigureWait: None, + #[cfg(not(feature="pg13"))] + ForeignAsyncNotify: None, + #[cfg(not(feature="pg13"))] + ForeignAsyncRequest: None, + #[cfg(not(feature="pg13"))] + GetForeignModifyBatchSize: None, + #[cfg(not(feature="pg13"))] + IsForeignPathAsyncCapable: None, +}; diff --git a/pgrx-examples/parallel_scan_lwlock/src/routine.rs b/pgrx-examples/parallel_scan_lwlock/src/routine.rs new file mode 100644 index 0000000000..462d11d58d --- /dev/null +++ b/pgrx-examples/parallel_scan_lwlock/src/routine.rs @@ -0,0 +1,25 @@ +use pgrx::callconv::*; +use pgrx::pgbox::*; +use pgrx::pg_sys::*; +use pgrx::pgrx_sql_entity_graph::metadata::*; + +pub(crate) struct CounterFdwRoutine(pub FdwRoutine); + +unsafe impl SqlTranslatable for CounterFdwRoutine { + fn argument_sql() -> std::result::Result { + Ok(SqlMapping::literal("fdw_handler")) 
+ } + + fn return_sql() -> std::result::Result { + Ok(Returns::One(SqlMapping::literal("fdw_handler"))) + } +} + +unsafe impl BoxRet for CounterFdwRoutine { + unsafe fn box_into<'fcx>(self, fcinfo: &mut FcInfo<'fcx>) -> pgrx::datum::Datum<'fcx> { + let mut pgbox = unsafe { PgBox::::alloc_node(NodeTag::T_FdwRoutine) }; + *pgbox = self.0; + let datum = Datum::from(pgbox.into_pg()); + fcinfo.return_raw_datum(datum) + } +} diff --git a/pgrx-tests/src/tests/shmem_tests.rs b/pgrx-tests/src/tests/shmem_tests.rs index 41f92d883a..6598ba83a0 100644 --- a/pgrx-tests/src/tests/shmem_tests.rs +++ b/pgrx-tests/src/tests/shmem_tests.rs @@ -7,12 +7,13 @@ //LICENSE All rights reserved. //LICENSE //LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. +use pgrx::lwlock::dsm::{DsmLwLock, DsmLwLockTranche}; +use pgrx::lwlock::scan::{ParallelScanLwLock, ParallelScanLwLockTranche}; use pgrx::prelude::*; -use pgrx::{PgAtomic, PgLwLock, pg_shmem_init}; -use std::sync::atomic::AtomicBool; - #[cfg(feature = "cshim")] use pgrx::spinlock::PgSpinLock; +use pgrx::{PgAtomic, PgLwLock, pg_shmem_init}; +use std::sync::atomic::AtomicBool; static ATOMIC: PgAtomic = unsafe { PgAtomic::new(c"pgrx_tests_atomic") }; static LWLOCK: PgLwLock = unsafe { PgLwLock::new(c"pgrx_tests_lwlock") }; @@ -20,20 +21,85 @@ static LWLOCK: PgLwLock = unsafe { PgLwLock::new(c"pgrx_tests_lwlock") }; #[cfg(feature = "cshim")] static SPINLOCK: PgAtomic> = unsafe { PgAtomic::new(c"pgrx_tests_spinlock") }; +static DSMLWLOCK: DsmLwLockTranche = DsmLwLockTranche::new(c"pgrx_tests_dsm_lwlock"); +static DSMLWLOCKMEM: TestDSM = + unsafe { TestDSM::new(DsmLwLock::::mem_size(), c"pgrx_tests_dsm_lwlock_mem") }; +static SCANLWLOCK: ParallelScanLwLockTranche = + ParallelScanLwLockTranche::new(c"pgrx_tests_scan_lwlock"); +static SCANLWLOCKMEM: TestDSM = + unsafe { TestDSM::new(ParallelScanLwLock::::mem_size(), c"pgrx_tests_scan_lwlock_mem") }; + #[pg_guard] pub extern "C-unwind" fn 
_PG_init() { // This ensures that this functionality works across PostgreSQL versions pg_shmem_init!(ATOMIC); pg_shmem_init!(LWLOCK); + #[cfg(feature = "cshim")] pg_shmem_init!(SPINLOCK = PgSpinLock::new(0)); + + pg_shmem_init!(DSMLWLOCK); + pg_shmem_init!(DSMLWLOCKMEM); + pg_shmem_init!(SCANLWLOCK); + pg_shmem_init!(SCANLWLOCKMEM); +} + +// Allocates just plain shared memory. +// TODO: Should be easier by using GetNamedDSMSegment when its bindings are included. +struct TestDSM { + size: usize, + name: &'static std::ffi::CStr, + inner: std::cell::UnsafeCell<*mut std::ffi::c_void>, +} + +impl TestDSM { + pub const unsafe fn new(size: usize, name: &'static std::ffi::CStr) -> Self { + Self { size, name, inner: std::cell::UnsafeCell::new(std::ptr::null_mut()) } + } + + pub unsafe fn mem(&self) -> *mut std::ffi::c_void { + *(self.inner.get()) + } +} + +unsafe impl Sync for TestDSM {} + +impl pgrx::PgSharedMemoryInitialization for TestDSM { + type Value = (); + + unsafe fn on_shmem_request(&'static self) { + unsafe { + pgrx::pg_sys::RequestAddinShmemSpace(self.size); + } + } + + unsafe fn on_shmem_startup(&'static self, _value: ()) { + unsafe { + use pgrx::pg_sys; + + let shm_name = self.name; + let addin_shmem_init_lock = &raw mut (*pg_sys::MainLWLockArray.add(21)).lock; + pg_sys::LWLockAcquire(addin_shmem_init_lock, pg_sys::LWLockMode::LW_EXCLUSIVE); + + let mut found = false; + let fv_shmem = pg_sys::ShmemInitStruct(shm_name.as_ptr(), self.size, &mut found); + assert!(fv_shmem.is_aligned(), "shared memory is not aligned"); + + *self.inner.get() = fv_shmem; + + pg_sys::LWLockRelease(addin_shmem_init_lock); + } + } } + #[cfg(any(test, feature = "pg_test"))] #[pgrx::pg_schema] mod tests { #[allow(unused_imports)] use crate as pgrx_tests; + use pgrx::lwlock::dsm::DsmLwLockHandle; + use pgrx::lwlock::scan::ParallelScanLwLock; use pgrx::prelude::*; #[pg_test] @@ -75,4 +141,79 @@ mod tests { drop(lock); } } + + fn init_dsm_lwlock() -> DsmLwLockHandle { + use 
super::{DSMLWLOCK, DSMLWLOCKMEM}; + let data: bool = false; + unsafe { + let shmem = DSMLWLOCKMEM.mem(); + DSMLWLOCK.init(shmem, &data as *const bool); + DSMLWLOCK.register(shmem) + } + } + + #[pg_test] + #[should_panic(expected = "cache lookup failed for type 0")] + pub fn dsm_test_behaves_normally_when_elog_while_holding_lock() { + let handle = init_dsm_lwlock(); + let _lock = handle.exclusive(); + // Call into pg_guarded postgres function which internally reports an error + unsafe { pg_sys::format_type_extended(pg_sys::InvalidOid, -1, 0) }; + } + + #[pg_test] + pub fn dsm_test_lock_is_released_on_drop() { + let handle = init_dsm_lwlock(); + let lock = handle.exclusive(); + drop(lock); + let _lock = handle.exclusive(); + } + + #[pg_test] + pub fn dsm_test_lock_is_released_on_unwind() { + let handle = init_dsm_lwlock(); + let _res = std::panic::catch_unwind(|| { + let _lock = handle.exclusive(); + panic!("get out") + }); + let _lock = handle.exclusive(); + } + + fn init_scan_lwlock() -> ParallelScanLwLock<bool> { + use super::{SCANLWLOCK, SCANLWLOCKMEM}; + let data: bool = false; + let mut lock = SCANLWLOCK.lock_for(data); + unsafe { + lock.initialize_dsm_and_register_leader(SCANLWLOCKMEM.mem()); + } + lock + } + + #[pg_test] + #[should_panic(expected = "cache lookup failed for type 0")] + pub fn scan_test_behaves_normally_when_elog_while_holding_lock() { + let mut handle = init_scan_lwlock(); + let _lock = handle.exclusive(); + // Call into pg_guarded postgres function which internally reports an error + unsafe { pg_sys::format_type_extended(pg_sys::InvalidOid, -1, 0) }; + } + + #[pg_test] + pub fn scan_test_lock_is_released_on_drop() { + let mut handle = init_scan_lwlock(); + let lock = handle.exclusive(); + drop(lock); + let _lock = handle.exclusive(); + } + + #[pg_test] + pub fn scan_test_lock_is_released_on_unwind() { + let mut handle = init_scan_lwlock(); + let handle_ref = &handle; + let _res = std::panic::catch_unwind(|| { + let _lock = handle_ref.shared(); + 
panic!("get out") + }); + let _lock = handle.exclusive(); + } } diff --git a/pgrx/src/lwlock.rs b/pgrx/src/lwlock.rs index 43ab26dbb3..2e5295001a 100644 --- a/pgrx/src/lwlock.rs +++ b/pgrx/src/lwlock.rs @@ -73,6 +73,27 @@ impl PgLwLock { } } +struct AddinShmemInitLock(*mut crate::pg_sys::LWLock); + +impl AddinShmemInitLock { + unsafe fn exclusive() -> Self { + const ADDIN_SHMEM_INIT_LOCK_POS: usize = 21; + let lock = &raw mut (*crate::pg_sys::MainLWLockArray.add(ADDIN_SHMEM_INIT_LOCK_POS)).lock; + crate::pg_sys::LWLockAcquire(lock, crate::pg_sys::LWLockMode::LW_EXCLUSIVE); + Self(lock) + } +} + +impl Drop for AddinShmemInitLock { + fn drop(&mut self) { + unsafe { + if !self.0.is_null() { + crate::pg_sys::LWLockRelease(self.0); + } + } + } +} + impl PgSharedMemoryInitialization for PgLwLock { type Value = T; @@ -88,8 +109,7 @@ impl PgSharedMemoryInitialization for PgLwLock { use crate::pg_sys; let shm_name = self.name; - let addin_shmem_init_lock = &raw mut (*pg_sys::MainLWLockArray.add(21)).lock; - pg_sys::LWLockAcquire(addin_shmem_init_lock, pg_sys::LWLockMode::LW_EXCLUSIVE); + let addin_shmem_init_lock = AddinShmemInitLock::exclusive(); let mut found = false; let fv_shmem = @@ -105,7 +125,7 @@ impl PgSharedMemoryInitialization for PgLwLock { *self.inner.get() = fv_shmem; - pg_sys::LWLockRelease(addin_shmem_init_lock); + drop(addin_shmem_init_lock); } } } @@ -183,3 +203,589 @@ unsafe fn release_unless_elog_unwinding(lock: *mut crate::pg_sys::LWLock) { crate::pg_sys::LWLockRelease(lock); } } + +/// LWLock for dynamic shared memory (DSM). +pub mod dsm { + use crate::lwlock::AddinShmemInitLock; + use std::cell::UnsafeCell; + use std::ffi::{CStr, c_int, c_void}; + + /// A PostgreSQL LWLock-backed locking mechanism for dynamic shared memory (DSM). + /// + /// This is a lower level component which defines operations close to the PostgreSQL LWLock API. 
+ /// If you are interested in locking the shared memory of a foreign parallel scan, please refer + /// to [ParallelScanLwLock](crate::lwlock::scan::ParallelScanLwLock) and + /// [ParallelScanLwLockTranche](crate::lwlock::scan::ParallelScanLwLockTranche). + /// + /// # Usage + /// + /// First, the user may need to obtain a new tranche ID, which is a marker for a family of + /// LWLocks. It can be retrieved by calling [new_lwlock_tranche_id], and it's recommended + /// that a tranche ID for a LWLock family is retrieved once per server instance. This function + /// can be called during the extension setup in the SHMEM hooks (which are made available by + /// pgrx through the [crate::PgSharedMemoryInitialization] trait and the [crate::pg_shmem_init!] + /// macro). If you find more convenient a tranche ID provided at startup time, refer to the + /// documentation of the [DsmLwLockTranche] type. + /// + /// In addition to the tranche ID, the DSM lock must be initialized along with the data it + /// wraps. Data is byte-wise copied from the pointer passed to [DsmLwLock::init] to the DSM. Its + /// type must be self-contained, avoiding any reference or pointer to local memory, like + /// heap-allocated data structures. The user must guarantee that this requirement is met by + /// using a type implementing the [crate::PGRXSharedMemory] trait, or by providing its own + /// implementation. + /// + /// After the lock is initialized on the DSM, it must be registered in every involved process + /// with [DsmLwLock::register], which returns a [DsmLwLockHandle] that can be stored in the + /// process local memory. The handle provides methods to obtain + /// [exclusive](DsmLwLockHandle::exclusive) or [shared](DsmLwLockHandle::shared) lock guards. + /// When dropped, a guard releases the lock. 
Quoting the PostgreSQL documentation, each process + /// using the tranche must register it separately, as "dynamic shared memory segments aren't + /// guaranteed to be mapped at the same address in all coordinating backends, so storing the + /// registration in the main shared memory segment wouldn't work for that case". + #[repr(C)] + pub struct DsmLwLock { + lock: crate::pg_sys::LWLock, + data: T, + } + + /// Handle to a DSM LWLock. + /// It can be obtained from [DsmLwLock::register] after the DSM is initialized by + /// [DsmLwLock::init]. + pub struct DsmLwLockHandle { + handle: *mut DsmLwLock, + } + + impl DsmLwLock { + /// Memory size in bytes required to store a lock instance, along its wrapped value. + pub const fn mem_size() -> usize { + size_of::>() + } + + /// Initialize the DSM with a lock and a copy of the wrapped value. + /// + /// # Safety + /// + /// * `dsm` must not be null. + /// * `dsm` must be aligned. + /// * `dsm` must have at least [Self::mem_size] space. + /// * `data` must not be null. + /// * `data` must not point to an address inside the `dsm` memory allocation. + pub unsafe fn init(dsm: *mut c_void, tranche_id: c_int, data: *const T) { + assert!(dsm.is_aligned(), "dynamic shared memory is not aligned"); + let dsm = dsm as *mut DsmLwLock; + (&raw mut (*dsm).lock).write(crate::pg_sys::LWLock::default()); + crate::pg_sys::LWLockInitialize(&raw mut (*dsm).lock, tranche_id); + (&raw mut (*dsm).data).copy_from_nonoverlapping(data, 1); + } + + /// Register the lock tranche to associate its ID with a name. + /// + /// # Safety + /// + /// * `dsm` was already initialized with [Self::init] (and therefore all its safety requirements are met). 
+ pub unsafe fn register(dsm: *mut c_void, name: &'static CStr) -> DsmLwLockHandle { + assert!(dsm.is_aligned(), "dynamic shared memory is not aligned"); + let dsm = dsm as *mut DsmLwLock; + crate::pg_sys::LWLockRegisterTranche((*dsm).lock.tranche as _, name.as_ptr()); + DsmLwLockHandle { handle: dsm } + } + } + + impl DsmLwLockHandle { + /// Obtain a shared lock (which comes with `&T` access). + pub fn shared(&self) -> super::PgLwLockShareGuard<'_, T> { + assert!(!self.handle.is_null(), "unregistered DSM LWLock handle"); + unsafe { + let lock_ptr = (&raw mut (*self.handle).lock); + crate::pg_sys::LWLockAcquire(lock_ptr, crate::pg_sys::LWLockMode::LW_SHARED); + super::PgLwLockShareGuard { + data: (&raw const (*self.handle).data) + .as_ref() + .expect("Unexpected null raw pointer to field"), + lock: lock_ptr, + } + } + } + + /// Obtain an exclusive lock (which comes with `&mut T` access). + pub fn exclusive(&self) -> super::PgLwLockExclusiveGuard<'_, T> { + assert!(!self.handle.is_null(), "unregistered DSM LWLock handle"); + unsafe { + let lock_ptr = (&raw mut (*self.handle).lock); + crate::pg_sys::LWLockAcquire(lock_ptr, crate::pg_sys::LWLockMode::LW_EXCLUSIVE); + super::PgLwLockExclusiveGuard { + data: (&raw mut (*self.handle).data) + .as_mut() + .expect("Unexpected null raw pointer to field"), + lock: lock_ptr, + } + } + } + } + + /// Request a new tranche ID for dynamically allocated LWLocks. + /// + /// # Caution + /// + /// Use parsimoniously, for locks store tranche IDs in 16-bit unsigned integers. The user + /// should not request a tranche ID per LWLock instance, but per LWLock family instead, + /// which groups instances of locks created for the same purpose. + /// + /// # Panics + /// + /// This function checks whether the next tranche ID exceeds the unsigned 16-bit boundary, + /// to avoid subtle errors inside PostgreSQL LWLock API that may associate a new lock to a + /// completely different tranche because of the ID truncation. 
+ pub fn new_lwlock_tranche_id() -> c_int { + let tranche_id = unsafe { crate::pg_sys::LWLockNewTrancheId() }; + if tranche_id > (u16::MAX as i32) { + panic!( + "all valid LWLock tranche IDs have been consumed: this or any other extension is probably requesting a new tranche ID on every dynamic LWLock creation" + ); + } + tranche_id + } + + /// Component that obtains a LWLock tranche ID on Postgres Shared Memory initialization. + /// + /// To be used as the type for a static global, initialized by the [crate::pg_shmem_init!] macro + /// in the extension `_PG_init()` function. + /// + /// ```rust,no_run + /// use ::pgrx::*; + /// use ::pgrx_pg_sys::*; + /// use ::pgrx::lwlock::dsm::*; + /// + /// static LOCKS_FOR_MY_TASK: DsmLwLockTranche = DsmLwLockTranche::new(c"my_task_lock"); + /// + /// #[allow(non_snake_case)] + /// #[pg_guard] + /// pub extern "C-unwind" fn _PG_init() { + /// //... + /// pg_shmem_init!(LOCKS_FOR_MY_TASK); + /// } + /// ``` + /// + /// This type provides convenience methods for initializing and registering LWLocks on the DSM. + /// Safety requirements are those specified in the respective [DsmLwLock] functions. + pub struct DsmLwLockTranche { + name: &'static CStr, + lock: UnsafeCell>, + } + + /// UnsafeCell cannot be shared between threads safely, we allow its use within static globals. + unsafe impl Sync for DsmLwLockTranche {} + + impl DsmLwLockTranche { + /// Define a LWLock tranche, along with the tranche name that backends will associate locks + /// to when created from this tranche. + pub const fn new(name: &'static CStr) -> Self { + Self { name, lock: UnsafeCell::new(None) } + } + + /// The name assigned to this tranche. + pub const fn name(&self) -> &'static CStr { + self.name + } + + /// Get the tranche ID. + /// Make sure that the static global is initialized by the [crate::pg_shmem_init!] macro in + /// `_PG_init()`. 
+ /// + /// # Panics + /// + /// This method must not be invoked on an uninitialized tranche, otherwise it will panic. + pub fn tranche_id(&self) -> c_int { + unsafe { + (*self.lock.get()).expect("uninitialized DSM LWLock tranche (use pg_shmem_init!() in _PG_init() to initialize it)") + } + } + + /// Initialize the DSM with a lock and a copy of the wrapped value. + /// + /// # Panics + /// + /// This method must not be invoked on an uninitialized tranche, otherwise it will panic. + /// + /// # Safety + /// + /// * `dsm` must not be null. + /// * `dsm` must have at least [DsmLwLock::mem_size] space. + /// * `data` must not be null. + /// * `data` must not point to an address inside the `dsm` memory allocation. + pub unsafe fn init(&self, dsm: *mut c_void, data: *const T) + where + T: crate::PGRXSharedMemory, + { + DsmLwLock::::init(dsm, self.tranche_id(), data) + } + + /// Register the lock tranche to associate its ID with a name. + /// + /// # Safety + /// + /// * `dsm` was already initialized with [Self::init] (and therefore all its safety + /// requirements are met). + pub unsafe fn register(&self, dsm: *mut c_void) -> DsmLwLockHandle + where + T: crate::PGRXSharedMemory, + { + DsmLwLock::::register(dsm, self.name) + } + } + + impl crate::PgSharedMemoryInitialization for DsmLwLockTranche { + type Value = (); + + unsafe fn on_shmem_request(&'static self) { + // Nothing to do here + } + + unsafe fn on_shmem_startup(&'static self, _value: Self::Value) { + let addin_shmem_init_lock = AddinShmemInitLock::exclusive(); + if (*self.lock.get()).is_none() { + *self.lock.get() = Some(new_lwlock_tranche_id()); + } + drop(addin_shmem_init_lock); + } + } +} + +/// LWLock for dynamic shared memory (DSM) during parallel foreign scans. 
pub mod scan {
    use super::dsm::{DsmLwLock, DsmLwLockHandle, DsmLwLockTranche};
    use std::borrow::{Borrow, BorrowMut};
    use std::ffi::{CStr, c_void};
    use std::ops::{Deref, DerefMut};
    use std::panic::AssertUnwindSafe;

    /// Tracks whether the protected value still lives in backend-local memory or has already been
    /// moved behind the DSM-resident LWLock.
    enum ParallelScanSharedState<T, A> {
        Local(A),
        Shared(DsmLwLockHandle<T>),
    }

    /// This type of lock is designed to manage access to the DSM allocated for parallel foreign
    /// scans, to coordinate work among parallel workers. Some of the FDW routines exposed by
    /// PostgreSQL provide a template for setting up shared state once, in the leader process, and
    /// then use it to initialize parallel workers local state. Creating a LWLock in the DSM from
    /// the leader process guarantees that the lock is initialized once. When using this locking
    /// mechanism for other types of shared memory, the user must guarantee that the lock
    /// initialization is run once by one of the participating worker processes.
    ///
    /// The value wrapped by this lock is initially stored within the local memory, then it's copied
    /// to the shared memory upon its initialization. This transitional local state is necessary,
    /// because the FDW routines dedicated to the DSM initialization may not be called for parallel
    /// foreign scans with the participation of a single parallel worker process.
    ///
    /// The shared state type must implement the [crate::PGRXSharedMemory] trait. Then, a reference
    /// to the LWLock-wrapped state can be stored in the parallel scan state with
    /// [ParallelScanLwLock].
    ///
    /// ```rust,no_run
    /// use ::pgrx::*;
    /// use ::pgrx_pg_sys::*;
    /// use ::pgrx::lwlock::scan::*;
    ///
    /// struct MySharedState {
    ///     a: usize,
    ///     b: i64,
    /// }
    ///
    /// unsafe impl PGRXSharedMemory for MySharedState {}
    ///
    /// struct MyParallelScanState {
    ///     local_data: Vec<i64>,
    ///     shared_data: ParallelScanLwLock<MySharedState, MySharedState>,
    /// }
    /// ```
    ///
    /// The user may need to store large datatypes in the DSM, e.g. buffers or other data structures
    /// with a pre-allocated capacity. As the wrapped type must not contain any reference or pointer
    /// (see [DsmLwLock]), the value to be allocated may be too large to fit on the stack. To
    /// overcome such kind of memory issues, the user should use instead a [Box] or any other boxing
    /// type that implements [BorrowMut] for its type.
    ///
    /// ```rust,no_run
    /// use ::pgrx::*;
    /// use ::pgrx_pg_sys::*;
    /// use ::pgrx::lwlock::scan::*;
    ///
    /// struct MyLargeSharedBuffer {
    ///     start: usize,
    ///     end: usize,
    ///     buffer: [u8; 10 * 1024 * 1024], // 10MB on the stack -> ☠️
    /// }
    ///
    /// unsafe impl PGRXSharedMemory for MyLargeSharedBuffer {} // yet DSM-safe
    ///
    /// impl MyLargeSharedBuffer {
    ///
    ///     pub fn boxed() -> Box<MyLargeSharedBuffer> {
    ///         unsafe {
    ///             let ptr = std::alloc::alloc_zeroed(std::alloc::Layout::new::<MyLargeSharedBuffer>()) as *mut MyLargeSharedBuffer;
    ///             (&raw mut (*ptr).start).write(0usize);
    ///             (&raw mut (*ptr).end).write(0usize);
    ///             Box::from_raw(ptr)
    ///         }
    ///     }
    /// }
    ///
    /// struct MyParallelScanState {
    ///     local_buffer: Vec<u8>,
    ///     shared_buffer: ParallelScanLwLock<MyLargeSharedBuffer, Box<MyLargeSharedBuffer>>,
    /// }
    /// ```
    ///
    /// # Lock initialization
    ///
    /// The user should declare a static global for a LWLock tranche of type
    /// [ParallelScanLwLockTranche] to be initialized with the [crate::pg_shmem_init!] macro in the
    /// extension `_PG_init()` function. Then the lock methods must be called within the PostgreSQL
    /// parallel foreign scan routines.
    ///
    /// ```rust,no_run
    /// use ::pgrx::*;
    /// use ::pgrx_pg_sys::*;
    /// use ::pgrx::lwlock::dsm::*;
    /// use ::pgrx::lwlock::scan::*;
    ///
    /// struct MySharedState {
    ///     //...
    /// }
    ///
    /// unsafe impl PGRXSharedMemory for MySharedState {}
    ///
    /// struct MyParallelScanState {
    ///     local_state: Vec<i64>,
    ///     shared_state: ParallelScanLwLock<MySharedState, Box<MySharedState>>,
    /// }
    ///
    /// static PARALLEL_SCAN_LWLOCKS: ParallelScanLwLockTranche = ParallelScanLwLockTranche::new(c"parallel_scan_lock");
    ///
    /// #[allow(non_snake_case)]
    /// #[pg_guard]
    /// pub extern "C-unwind" fn _PG_init() {
    ///     // other required initialization
    ///     pg_shmem_init!(PARALLEL_SCAN_LWLOCKS);
    /// }
    ///
    /// #[pg_guard]
    /// extern "C-unwind" fn pgrx_begin_foreign_scan(foreign_scan_state: *mut ForeignScanState, _eflags: ::std::os::raw::c_int) {
    ///     // Other foreign scan setup...
    ///
    ///     let scan_state = MyParallelScanState {
    ///         local_state: vec![],
    ///         shared_state: PARALLEL_SCAN_LWLOCKS.lock_for(Box::new(MySharedState { /* ... */ } )),
    ///     };
    ///
    ///     // We rely on Postgres memory context to drop our value on delete
    ///     unsafe { (*foreign_scan_state).fdw_state = PgMemoryContexts::CurrentMemoryContext.leak_and_drop_on_delete(scan_state) as *mut std::os::raw::c_void };
    /// }
    ///
    /// #[pg_guard]
    /// unsafe extern "C-unwind" fn pgrx_is_foreign_scan_parallel_safe(_root: *mut PlannerInfo, _rel: *mut RelOptInfo, _rte: *mut RangeTblEntry) -> bool {
    ///     true
    /// }
    ///
    /// #[pg_guard]
    /// unsafe extern "C-unwind" fn pgrx_estimate_dsm_foreign_scan(_foreign_scan_state: *mut ForeignScanState, _pcxt: *mut ParallelContext) -> Size {
    ///     ParallelScanLwLock::<MySharedState, Box<MySharedState>>::mem_size()
    /// }
    ///
    /// #[pg_guard]
    /// unsafe extern "C-unwind" fn pgrx_initialize_dsm_foreign_scan(foreign_scan_state: *mut ForeignScanState, _pcxt: *mut ParallelContext, shared_mem: *mut ::std::os::raw::c_void) {
    ///     let scan_state = &mut *((*foreign_scan_state).fdw_state as *mut MyParallelScanState);
    ///     // Initialize DSM and update the scan state to refer to the DSM-stored LWLock
    ///     scan_state.shared_state.initialize_dsm_and_register_leader(shared_mem);
    ///     // Initialize leader process local state, if needed
    /// }
    ///
    /// #[pg_guard]
    /// unsafe extern "C-unwind" fn pgrx_reinitialize_dsm_foreign_scan(_node: *mut ForeignScanState, _pcxt: *mut ParallelContext, _coordinate: *mut ::std::os::raw::c_void) {
    ///     // Do some re-initialization, if needed
    /// }
    ///
    /// #[pg_guard]
    /// unsafe extern "C-unwind" fn pgrx_initialize_worker_foreign_scan(foreign_scan_state: *mut ForeignScanState, _toc: *mut shm_toc, coordinate: *mut ::std::os::raw::c_void) {
    ///     let scan_state = &mut *((*foreign_scan_state).fdw_state as *mut MyParallelScanState);
    ///     // Update scan_state, replacing the invalid pointer from the parent process with the remapped DSM pointer
    ///     scan_state.shared_state.register_parallel_worker(coordinate);
    ///     // Initialize parallel worker local state, if needed
    /// }
    ///
    /// #[pg_guard]
    /// unsafe extern "C-unwind" fn pgrx_shutdown_foreign_scan(_node: *mut ForeignScanState) {
    ///     // Invoked when the node will not be executed to completion, if you wish to take some action
    ///     // before the DSM segment is destroyed. Look at FDW callbacks documentation for more info.
    /// }
    /// ```
    pub struct ParallelScanLwLock<T, A>
    where
        A: BorrowMut<T>,
    {
        tranche: AssertUnwindSafe<&'static DsmLwLockTranche>,
        data: ParallelScanSharedState<T, A>,
    }

    /// A shared LWLock guard that skips locking if the lock was not moved to shared memory yet.
    pub enum ParallelScanLwLockShareGuard<'a, T> {
        Local(&'a T),
        Shared(super::PgLwLockShareGuard<'a, T>),
    }

    // NOTE(review): unconditional `Sync` (no `T: Sync` bound), mirroring the exclusive guard;
    // relies on the LWLock serializing access to the shared value — confirm this is intended.
    unsafe impl<T> Sync for ParallelScanLwLockShareGuard<'_, T> {}

    impl<T> Deref for ParallelScanLwLockShareGuard<'_, T> {
        type Target = T;

        #[inline]
        fn deref(&self) -> &T {
            match self {
                Self::Local(value) => value,
                Self::Shared(guard) => guard.deref(),
            }
        }
    }

    /// An exclusive LWLock guard that skips locking if the lock was not moved to shared memory yet.
+ pub enum ParallelScanLwLockExclusiveGuard<'a, T> { + Local(&'a mut T), + Shared(super::PgLwLockExclusiveGuard<'a, T>), + } + + unsafe impl Sync for ParallelScanLwLockExclusiveGuard<'_, T> {} + + impl Deref for ParallelScanLwLockExclusiveGuard<'_, T> { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + match self { + Self::Local(value) => value, + Self::Shared(guard) => guard.deref(), + } + } + } + + impl DerefMut for ParallelScanLwLockExclusiveGuard<'_, T> { + #[inline] + fn deref_mut(&mut self) -> &mut T { + match self { + Self::Local(value) => value, + Self::Shared(guard) => guard.deref_mut(), + } + } + } + + impl ParallelScanLwLock + where + A: BorrowMut, + { + /// Constructs a new LWLock, given a tranche and an initial value (or a box type from which + /// you can [BorrowMut] it). + pub fn new(tranche: &'static DsmLwLockTranche, value: A) -> Self { + Self { tranche: AssertUnwindSafe(tranche), data: ParallelScanSharedState::Local(value) } + } + + /// Obtain a shared lock (which comes with `&T` access). + pub fn shared(&self) -> ParallelScanLwLockShareGuard<'_, T> { + match &self.data { + ParallelScanSharedState::Local(value) => { + ParallelScanLwLockShareGuard::Local(value.borrow()) + } + ParallelScanSharedState::Shared(handle) => { + ParallelScanLwLockShareGuard::Shared(handle.shared()) + } + } + } + + /// Obtain an exclusive lock (which comes with `&mut T` access). + pub fn exclusive(&mut self) -> ParallelScanLwLockExclusiveGuard<'_, T> { + match &mut self.data { + ParallelScanSharedState::Local(value) => { + ParallelScanLwLockExclusiveGuard::Local(value.borrow_mut()) + } + ParallelScanSharedState::Shared(handle) => { + ParallelScanLwLockExclusiveGuard::Shared(handle.exclusive()) + } + } + } + } + + impl ParallelScanLwLock + where + A: BorrowMut, + { + pub const fn mem_size() -> usize { + DsmLwLock::::mem_size() + } + + /// To be called by the leader process of a parallel foreign scan within the + /// `pgrx_initialize_dsm_foreign_scan` function. 
+ /// + /// # Panics + /// + /// This method panics if it or [Self::register_parallel_worker] were already called. + /// + /// # Safety + /// + /// * `dsm` must not be null. + /// * `dsm` must have at least [ParallelScanLwLock::mem_size] space. + pub unsafe fn initialize_dsm_and_register_leader(&mut self, dsm: *mut c_void) { + match &self.data { + ParallelScanSharedState::Local(value) => { + self.tranche.init(dsm, >::borrow(value) as *const T); + self.data = ParallelScanSharedState::Shared(self.tranche.register(dsm)); + } + ParallelScanSharedState::Shared(_) => { + panic!("DSM LWLock already initialized"); + } + } + } + + /// To be called by parallel worker processes of a parallel foreign scan within the + /// `pgrx_initialize_worker_foreign_scan` function. + /// + /// # Safety + /// + /// * The leader process must have invoked [Self::initialize_dsm_and_register_leader] on + /// `dsm` first (and therefore all its safety requirements are met). + pub unsafe fn register_parallel_worker(&mut self, dsm: *mut c_void) { + self.data = ParallelScanSharedState::Shared(self.tranche.register(dsm)); + } + } + + /// LWLock tranche for foreign parallel scans. + /// Similar to the [DsmLwLockTranche] type, offering a more convenient method to create a LWLock + /// directly from the tranche instance. + pub struct ParallelScanLwLockTranche(DsmLwLockTranche); + + impl ParallelScanLwLockTranche { + /// Define a LWLock tranche, along with the tranche name that backends will associate locks + /// to when created from this tranche. + pub const fn new(name: &'static CStr) -> Self { + Self(DsmLwLockTranche::new(name)) + } + + /// Creates a new LWLock associated to this tranche. 
+ pub fn lock_for(&'static self, value: A) -> ParallelScanLwLock + where + T: crate::PGRXSharedMemory, + A: BorrowMut, + { + ParallelScanLwLock::new(&self.0, value) + } + } + + impl crate::PgSharedMemoryInitialization for ParallelScanLwLockTranche { + type Value = (); + + unsafe fn on_shmem_request(&'static self) { + self.0.on_shmem_request() + } + + unsafe fn on_shmem_startup(&'static self, value: Self::Value) { + self.0.on_shmem_startup(value) + } + } +}