diff --git a/payjoin-cli/src/app/config.rs b/payjoin-cli/src/app/config.rs index f0cd22a0b..e4cb713d4 100644 --- a/payjoin-cli/src/app/config.rs +++ b/payjoin-cli/src/app/config.rs @@ -338,6 +338,8 @@ fn handle_subcommands(config: Builder, cli: &Cli) -> Result Ok(config), #[cfg(feature = "v2")] Commands::Fallback { .. } => Ok(config), + #[cfg(feature = "v2")] + Commands::Cancel { .. } => Ok(config), } } diff --git a/payjoin-cli/src/app/mod.rs b/payjoin-cli/src/app/mod.rs index 48499bebe..d1fc93fec 100644 --- a/payjoin-cli/src/app/mod.rs +++ b/payjoin-cli/src/app/mod.rs @@ -32,6 +32,8 @@ pub trait App: Send + Sync { async fn history(&self) -> Result<()>; #[cfg(feature = "v2")] async fn fallback_sender(&self, session_id: SessionId) -> Result<()>; + #[cfg(feature = "v2")] + async fn cancel_sender(&self, session_id: SessionId) -> Result<()>; fn create_original_psbt( &self, diff --git a/payjoin-cli/src/app/v1.rs b/payjoin-cli/src/app/v1.rs index 5d98ddc26..a6bd79a8e 100644 --- a/payjoin-cli/src/app/v1.rs +++ b/payjoin-cli/src/app/v1.rs @@ -132,6 +132,11 @@ impl AppTrait for App { async fn fallback_sender(&self, _session_id: crate::db::v2::SessionId) -> Result<()> { anyhow::bail!("fallback is only supported for v2 (BIP77) sessions") } + + #[cfg(feature = "v2")] + async fn cancel_sender(&self, _session_id: crate::db::v2::SessionId) -> Result<()> { + anyhow::bail!("cancel is only supported for v2 (BIP77) sessions") + } } impl App { diff --git a/payjoin-cli/src/app/v2/mod.rs b/payjoin-cli/src/app/v2/mod.rs index 82fb87dcc..67e1100da 100644 --- a/payjoin-cli/src/app/v2/mod.rs +++ b/payjoin-cli/src/app/v2/mod.rs @@ -13,8 +13,8 @@ use payjoin::receive::v2::{ WantsOutputs, }; use payjoin::send::v2::{ - replay_event_log as replay_sender_event_log, PollingForProposal, SendSession, Sender, - SenderBuilder, SessionOutcome as SenderSessionOutcome, WithReplyKey, + replay_event_log as replay_sender_event_log, PendingFallback, PollingForProposal, SendSession, + Sender, SenderBuilder, SessionOutcome as SenderSessionOutcome, WithReplyKey, }; use payjoin::{ImplementationError, PjParam, Uri}; use tokio::sync::watch; @@ -55,8 +55,9 @@ impl StatusText for SendSession { SendSession::Closed(session_outcome) => match session_outcome { SenderSessionOutcome::Failure => "Session failure", SenderSessionOutcome::Success(_) => "Session success", - SenderSessionOutcome::Cancel => "Session cancelled", + SenderSessionOutcome::Aborted => "Session aborted", }, + SendSession::PendingFallback(_) => "Session aborted", } } } @@ -486,29 +487,71 @@ impl AppTrait for App { async fn fallback_sender(&self, session_id: SessionId) -> Result<()> { let persister = SenderPersister::from_id(self.db.clone(), session_id.clone()); - let (session, history) = replay_sender_event_log(&persister)?; - - if let SendSession::Closed(SenderSessionOutcome::Success(proposal)) = session { - let txid = proposal.clone().extract_tx_unchecked_fee_rate().compute_txid(); - println!( - "Session {session_id} already produced payjoin transaction {txid}. \ - Broadcasting the original now would double-spend against it. \ - If the payjoin tx needs re-broadcast, run \ - `bitcoin-cli gettransaction {txid}` to fetch the hex, then \ - `bitcoin-cli sendrawtransaction `." - ); - return Ok(()); - } + let (session, _history) = replay_sender_event_log(&persister)?; - let fallback_tx = history.fallback_tx(); - self.wallet().broadcast_tx(&fallback_tx)?; - println!("Broadcasted fallback transaction txid: {}", fallback_tx.compute_txid()); + let pending: Sender = match session { + SendSession::PendingFallback(sender) => sender, + SendSession::WithReplyKey(sender) => sender.cancel().save(&persister)?, + SendSession::PollingForProposal(sender) => sender.cancel().save(&persister)?, + SendSession::Closed(SenderSessionOutcome::Success(proposal)) => { + let txid = proposal.extract_tx_unchecked_fee_rate().compute_txid(); + println!( + "Session {session_id} already produced payjoin transaction {txid}. \ + Broadcasting the original now would double-spend against it. \ + If the payjoin tx needs re-broadcast, run \ + `bitcoin-cli gettransaction {txid}` to fetch the hex, then \ + `bitcoin-cli sendrawtransaction `." + ); + return Ok(()); + } + SendSession::Closed(_) => { + println!("Session {session_id} is already closed. Nothing left to do."); + return Ok(()); + } + }; + + self.wallet().broadcast_tx(pending.fallback_tx())?; + println!("Broadcasted fallback transaction txid: {}", pending.fallback_tx().compute_txid()); if let Err(e) = SessionPersister::close(&persister) { tracing::warn!("Failed to close session {session_id} after fallback: {e}"); } Ok(()) } + + async fn cancel_sender(&self, session_id: SessionId) -> Result<()> { + let persister = SenderPersister::from_id(self.db.clone(), session_id.clone()); + let (session, _history) = replay_sender_event_log(&persister)?; + + match session { + SendSession::WithReplyKey(sender) => { + sender.cancel().save(&persister)?; + } + SendSession::PollingForProposal(sender) => { + sender.cancel().save(&persister)?; + } + SendSession::PendingFallback(_) => { + println!("Session {session_id} is already cancelled."); + return Ok(()); + } + SendSession::Closed(SenderSessionOutcome::Success(proposal)) => { + let txid = proposal.extract_tx_unchecked_fee_rate().compute_txid(); + println!( + "Session {session_id} already produced payjoin transaction {txid}. \ + Cannot cancel a completed session." + ); + return Ok(()); + } + SendSession::Closed(_) => { + println!("Session {session_id} is already closed."); + return Ok(()); + } + } + println!( + "Session {session_id} cancelled. Run `payjoin-cli fallback {session_id}` to broadcast the original transaction." + ); + Ok(()) + } } impl App { @@ -538,10 +581,14 @@ impl App { return Ok(()); } SendSession::Closed(SenderSessionOutcome::Failure) - | SendSession::Closed(SenderSessionOutcome::Cancel) => { + | SendSession::Closed(SenderSessionOutcome::Aborted) => { + println!("Session is closed. Nothing left to do"); + return Ok(()); + } + SendSession::PendingFallback(_) => { let id = persister.session_id(); println!( - "Session {id} ended without payjoin. Run `payjoin-cli fallback {id}` to broadcast the original transaction." + "Session {id} was cancelled. Run `payjoin-cli fallback {id}` to broadcast the original transaction." ); return Ok(()); } diff --git a/payjoin-cli/src/cli/mod.rs b/payjoin-cli/src/cli/mod.rs index 7cef2551b..b7faf9b3e 100644 --- a/payjoin-cli/src/cli/mod.rs +++ b/payjoin-cli/src/cli/mod.rs @@ -139,6 +139,13 @@ pub enum Commands { #[arg(required = true)] session_id: i64, }, + #[cfg(feature = "v2")] + /// Cancel a sender session and broadcast the fallback transaction (BIP77/v2 only) + Cancel { + /// The session ID to cancel + #[arg(required = true)] + session_id: i64, + }, } pub fn parse_amount_in_sat(s: &str) -> Result { diff --git a/payjoin-cli/src/main.rs b/payjoin-cli/src/main.rs index 6b2b038f4..8f1c4741f 100644 --- a/payjoin-cli/src/main.rs +++ b/payjoin-cli/src/main.rs @@ -82,6 +82,10 @@ async fn main() -> Result<()> { Commands::Fallback { session_id } => { app.fallback_sender(SessionId(*session_id)).await?; } + #[cfg(feature = "v2")] + Commands::Cancel { session_id } => { + app.cancel_sender(SessionId(*session_id)).await?; + } }; Ok(()) diff --git a/payjoin-cli/tests/e2e.rs b/payjoin-cli/tests/e2e.rs index fe761e1fe..035781a96 100644 --- a/payjoin-cli/tests/e2e.rs +++ b/payjoin-cli/tests/e2e.rs @@ -712,7 +712,42 @@ mod e2e { // Ensure the fallback was not broadcast yet let mempool_size = sender.get_mempool_info().expect("should be able to get mempool").unbroadcast_count; - assert_eq!(mempool_size, 0, "fallback should not be in mempool"); + assert_eq!(mempool_size, 0, "fallback should not be in mempool before cancel"); + + // Run `payjoin-cli cancel ` and assert session is cancelled without broadcast + let mut cli_cancel = Command::new(payjoin_cli) + .arg("--root-certificate") + .arg(cert_path) + .arg("--rpchost") + .arg(&sender_rpchost) + .arg("--cookie-file") + .arg(cookie_file) + .arg("--db-path") + .arg(&sender_db_path) + .arg("--ohttp-relays") + .arg(ohttp_relay) + .arg("cancel") + .arg(session_id.to_string()) + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()) + .spawn() + .expect("Failed to execute payjoin-cli cancel"); + + let mut cancel_stdout = + cli_cancel.stdout.take().expect("failed to take stdout of cancel"); + let timeout = tokio::time::Duration::from_secs(10); + let cancel_line = tokio::time::timeout( + timeout, + wait_for_stdout_match(&mut cancel_stdout, |l| l.contains("cancelled")), + ) + .await?; + terminate(cli_cancel).await.expect("Failed to kill payjoin-cli cancel"); + assert!(cancel_line.is_some(), "cancel should output cancellation confirmation"); + + // Ensure the fallback was NOT broadcast after cancel + let mempool_size = + sender.get_mempool_info().expect("should be able to get mempool").unbroadcast_count; + assert_eq!(mempool_size, 0, "fallback should not be in mempool after cancel"); // Run `payjoin-cli fallback ` and assert broadcast let mut cli_fallback = Command::new(payjoin_cli) diff --git a/payjoin-ffi/csharp/UnitTests.cs b/payjoin-ffi/csharp/UnitTests.cs index f96b1318f..d924a5486 100644 --- a/payjoin-ffi/csharp/UnitTests.cs +++ b/payjoin-ffi/csharp/UnitTests.cs @@ -211,13 +211,13 @@ public void SenderCancelFromWithReplyKey() .BuildRecommended(1000) .Save(senderPersister); var cancelTransition = withReplyKey.Cancel(); - var fallbackTx = cancelTransition.Save(senderPersister); - Assert.NotNull(fallbackTx); - Assert.NotEmpty(fallbackTx); + var pendingFallback = cancelTransition.Save(senderPersister); + Assert.NotNull(pendingFallback); + Assert.NotEmpty(pendingFallback.FallbackTx()); var result = PayjoinMethods.ReplaySenderEventLog(senderPersister); var state = result.State(); - Assert.IsType(state); + Assert.IsType(state); } [Fact] @@ -238,13 +238,13 @@ public async Task SenderCancelFromWithReplyKeyAsync() .BuildRecommended(1000) .SaveAsync(senderPersister); var cancelTransition = withReplyKey.Cancel(); - var fallbackTx = await cancelTransition.SaveAsync(senderPersister); - Assert.NotNull(fallbackTx); - Assert.NotEmpty(fallbackTx); + var pendingFallback = await cancelTransition.SaveAsync(senderPersister); + Assert.NotNull(pendingFallback); + Assert.NotEmpty(pendingFallback.FallbackTx()); var result = await PayjoinMethods.ReplaySenderEventLogAsync(senderPersister); var state = result.State(); - Assert.IsType(state); + Assert.IsType(state); } } diff --git a/payjoin-ffi/dart/test/test_payjoin_unit_test.dart b/payjoin-ffi/dart/test/test_payjoin_unit_test.dart index 3e9baa1d0..f472e6b1c 100644 --- a/payjoin-ffi/dart/test/test_payjoin_unit_test.dart +++ b/payjoin-ffi/dart/test/test_payjoin_unit_test.dart @@ -183,14 +183,14 @@ void main() { .buildRecommended(minFeeRateSatPerKwu: 1000) .save(persister: sender_persister); var cancelTransition = withReplyKey.cancel(); - var fallbackTx = cancelTransition.save(persister: sender_persister); - expect(fallbackTx, isNotNull); - expect(fallbackTx.length, greaterThan(0)); + var pendingFallback = cancelTransition.save(persister: sender_persister); + expect(pendingFallback, isNotNull); + expect(pendingFallback!.fallbackTx().length, greaterThan(0)); final result = payjoin.replaySenderEventLog(persister: sender_persister); expect( result.state(), - isA(), - reason: "sender should be in Closed state after cancel", + isA(), + reason: "sender should be in Cancelled state after cancel", ); }); @@ -215,18 +215,18 @@ void main() { .buildRecommended(minFeeRateSatPerKwu: 1000) .saveAsync(persister: sender_persister); var cancelTransition = withReplyKey.cancel(); - var fallbackTx = await cancelTransition.saveAsync( + var pendingFallback = await cancelTransition.saveAsync( persister: sender_persister, ); - expect(fallbackTx, isNotNull); - expect(fallbackTx.length, greaterThan(0)); + expect(pendingFallback, isNotNull); + expect(pendingFallback!.fallbackTx().length, greaterThan(0)); final result = await payjoin.replaySenderEventLogAsync( persister: sender_persister, ); expect( result.state(), - isA(), - reason: "sender should be in Closed state after cancel", + isA(), + reason: "sender should be in Cancelled state after cancel", ); }); }); diff --git a/payjoin-ffi/javascript/test/unit.test.ts b/payjoin-ffi/javascript/test/unit.test.ts index 1199bcc74..b7eada06a 100644 --- a/payjoin-ffi/javascript/test/unit.test.ts +++ b/payjoin-ffi/javascript/test/unit.test.ts @@ -237,10 +237,10 @@ describe("Sender cancel tests", () => { .save(senderPersister); const cancelTransition = withReplyKey.cancel(); - const fallbackTx = cancelTransition.save(senderPersister); - assert.ok(fallbackTx, "fallback tx should be returned"); + const pendingFallback = cancelTransition.save(senderPersister); + assert.ok(pendingFallback, "pending fallback should be returned"); assert.ok( - fallbackTx.byteLength > 0, + pendingFallback.fallbackTx().byteLength > 0, "fallback tx bytes should be non-empty", ); @@ -248,8 +248,8 @@ describe("Sender cancel tests", () => { const state = result.state(); assert.strictEqual( state.tag, - "Closed", - "State should be Closed after cancel", + "Cancelled", + "State should be Cancelled after cancel", ); }); @@ -285,10 +285,11 @@ describe("Sender cancel tests", () => { .saveAsync(senderPersister); const cancelTransition = withReplyKey.cancel(); - const fallbackTx = await cancelTransition.saveAsync(senderPersister); - assert.ok(fallbackTx, "fallback tx should be returned"); + const pendingFallback = + await cancelTransition.saveAsync(senderPersister); + assert.ok(pendingFallback, "pending fallback should be returned"); assert.ok( - fallbackTx.byteLength > 0, + pendingFallback.fallbackTx().byteLength > 0, "fallback tx bytes should be non-empty", ); @@ -296,8 +297,8 @@ describe("Sender cancel tests", () => { const state = result.state(); assert.strictEqual( state.tag, - "Closed", - "State should be Closed after cancel", + "Cancelled", + "State should be Cancelled after cancel", ); }); }); diff --git a/payjoin-ffi/python/test/test_payjoin_unit_test.py b/payjoin-ffi/python/test/test_payjoin_unit_test.py index c63836de5..5b6410d7a 100644 --- a/payjoin-ffi/python/test/test_payjoin_unit_test.py +++ b/payjoin-ffi/python/test/test_payjoin_unit_test.py @@ -214,11 +214,11 @@ def test_sender_cancel(self): payjoin.SenderBuilder(psbt, uri).build_recommended(1000).save(persister) ) cancel_transition = with_reply_key.cancel() - fallback_tx = cancel_transition.save(persister) - self.assertIsNotNone(fallback_tx) - self.assertTrue(len(fallback_tx) > 0) + pending_fallback = cancel_transition.save(persister) + self.assertIsNotNone(pending_fallback) + self.assertTrue(len(pending_fallback.fallback_tx()) > 0) result = payjoin.replay_sender_event_log(persister) - self.assertTrue(result.state().is_CLOSED()) + self.assertTrue(result.state().is_CANCELLED()) class TestSenderCancelAsync(unittest.TestCase): @@ -251,11 +251,11 @@ async def run_test(): .save_async(persister) ) cancel_transition = with_reply_key.cancel() - fallback_tx = await cancel_transition.save_async(persister) - self.assertIsNotNone(fallback_tx) - self.assertTrue(len(fallback_tx) > 0) + pending_fallback = await cancel_transition.save_async(persister) + self.assertIsNotNone(pending_fallback) + self.assertTrue(len(pending_fallback.fallback_tx()) > 0) result = await payjoin.replay_sender_event_log_async(persister) - self.assertTrue(result.state().is_CLOSED()) + self.assertTrue(result.state().is_CANCELLED()) asyncio.run(run_test()) diff --git a/payjoin-ffi/src/send/error.rs b/payjoin-ffi/src/send/error.rs index e0ef6022d..36914cf0f 100644 --- a/payjoin-ffi/src/send/error.rs +++ b/payjoin-ffi/src/send/error.rs @@ -201,3 +201,15 @@ where SenderPersistedError::Unexpected } } + +impl From> for SenderPersistedError +where + S: std::error::Error + Send + Sync + 'static, +{ + fn from(err: payjoin::persist::PersistedError) -> Self { + if let Some(storage_err) = err.storage_error() { + return SenderPersistedError::from(ImplementationError::new(storage_err)); + } + SenderPersistedError::Unexpected + } +} diff --git a/payjoin-ffi/src/send/mod.rs b/payjoin-ffi/src/send/mod.rs index d78863edc..fef623233 100644 --- a/payjoin-ffi/src/send/mod.rs +++ b/payjoin-ffi/src/send/mod.rs @@ -51,53 +51,48 @@ macro_rules! impl_save_for_transition { }; } -/// A terminal transition produced by cancelling a sender session. #[derive(uniffi::Object)] -pub struct SenderCancelTransition { - transition: RwLock< - Option< - payjoin::persist::TerminalTransition< - payjoin::send::v2::SessionEvent, - payjoin::bitcoin::Transaction, +#[allow(clippy::type_complexity)] +pub struct SenderCancelTransition( + Arc< + RwLock< + Option< + payjoin::persist::NextStateTransition< + payjoin::send::v2::SessionEvent, + payjoin::send::v2::Sender, + >, >, >, >, -} +); #[uniffi::export] impl SenderCancelTransition { - /// Persist the cancellation and return the fallback transaction. - /// - /// The fallback transaction is the consensus-encoded raw transaction bytes of - /// the sender's original transaction that should be broadcast to complete the - /// payment without Payjoin. pub fn save( &self, persister: Arc, - ) -> Result, SenderPersistedError> { + ) -> Result { let adapter = CallbackPersisterAdapter::new(persister); - let mut inner = self.transition.write().expect("Lock should not be poisoned"); + let mut inner = self.0.write().expect("Lock should not be poisoned"); let value = inner.take().expect("Already saved or moved"); - let fallback = value - .save(&adapter) - .map_err(|e| SenderPersistedError::from(ImplementationError::new(e)))?; - Ok(payjoin::bitcoin::consensus::serialize(&fallback)) + let res = value.save(&adapter).map_err(|e| ForeignError::InternalError(e.to_string()))?; + Ok(res.into()) } pub async fn save_async( &self, persister: Arc, - ) -> Result, SenderPersistedError> { + ) -> Result { let adapter = AsyncCallbackPersisterAdapter::new(persister); let value = { - let mut inner = self.transition.write().expect("Lock should not be poisoned"); + let mut inner = self.0.write().expect("Lock should not be poisoned"); inner.take().expect("Already saved or moved") }; - let fallback = value + let res = value .save_async(&adapter) .await - .map_err(|e| SenderPersistedError::from(ImplementationError::new(e)))?; - Ok(payjoin::bitcoin::consensus::serialize(&fallback)) + .map_err(|e| ForeignError::InternalError(e.to_string()))?; + Ok(res.into()) } } @@ -107,14 +102,15 @@ macro_rules! impl_cancel_for_sender { impl $ty { /// Cancel the Payjoin session immediately. /// - /// Returns a [`SenderCancelTransition`] that, once persisted, yields the fallback - /// transaction. The fallback transaction is the sender's original transaction - /// that should be broadcast to complete the payment without Payjoin. + /// Returns a [`SenderCancelTransition`] that, once persisted, yields a + /// [`PendingFallback`] state. Call [`PendingFallback::fallback_tx`] to get + /// the original transaction and [`PendingFallback::broadcasted`] after + /// broadcasting it. /// /// This is a terminal transition — the session cannot be used after cancellation. pub fn cancel(&self) -> SenderCancelTransition { let transition = self.0.clone().cancel(); - SenderCancelTransition { transition: RwLock::new(Some(transition)) } + SenderCancelTransition(Arc::new(RwLock::new(Some(transition)))) } } }; @@ -175,8 +171,8 @@ impl SenderSessionOutcome { matches!(self.0, payjoin::send::v2::SessionOutcome::Failure) } - pub fn is_cancelled(&self) -> bool { - matches!(self.0, payjoin::send::v2::SessionOutcome::Cancel) + pub fn is_aborted(&self) -> bool { + matches!(self.0, payjoin::send::v2::SessionOutcome::Aborted) } } @@ -184,6 +180,7 @@ impl SenderSessionOutcome { pub enum SendSession { WithReplyKey { inner: Arc }, PollingForProposal { inner: Arc }, + Cancelled { inner: Arc }, Closed { inner: Arc }, } @@ -195,6 +192,8 @@ impl From for SendSession { Self::WithReplyKey { inner: Arc::new(inner.into()) }, SendSession::PollingForProposal(inner) => Self::PollingForProposal { inner: Arc::new(inner.into()) }, + SendSession::PendingFallback(inner) => + Self::Cancelled { inner: Arc::new(inner.into()) }, SendSession::Closed(session_outcome) => Self::Closed { inner: Arc::new(session_outcome.into()) }, } @@ -645,6 +644,74 @@ impl PollingForProposal { } } +#[derive(Clone, uniffi::Object)] +pub struct PendingFallback(payjoin::send::v2::Sender); + +impl From> for PendingFallback { + fn from(value: payjoin::send::v2::Sender) -> Self { + Self(value) + } +} + +#[derive(uniffi::Object)] +#[allow(clippy::type_complexity)] +pub struct BroadcastedTransition( + Arc< + RwLock< + Option< + payjoin::persist::MaybeSuccessTransition< + payjoin::send::v2::SessionEvent, + (), + std::convert::Infallible, + >, + >, + >, + >, +); + +#[uniffi::export] +impl BroadcastedTransition { + pub fn save( + &self, + persister: Arc, + ) -> Result<(), SenderPersistedError> { + let adapter = CallbackPersisterAdapter::new(persister); + let mut inner = self.0.write().expect("Lock should not be poisoned"); + let value = inner.take().expect("Already saved or moved"); + value.save(&adapter).map_err(SenderPersistedError::from) + } + + pub async fn save_async( + &self, + persister: Arc, + ) -> Result<(), SenderPersistedError> { + let adapter = AsyncCallbackPersisterAdapter::new(persister); + let value = { + let mut inner = self.0.write().expect("Lock should not be poisoned"); + inner.take().expect("Already saved or moved") + }; + value.save_async(&adapter).await.map_err(SenderPersistedError::from) + } +} + +#[uniffi::export] +impl PendingFallback { + /// Returns the fallback transaction as consensus-encoded raw bytes. + /// + /// This is the sender's original transaction that should be broadcast to + /// complete the payment without Payjoin. + pub fn fallback_tx(&self) -> Vec { + payjoin::bitcoin::consensus::serialize(self.0.fallback_tx()) + } + + /// Indicate that the fallback transaction has been broadcast. + /// + /// Persist the returned [`BroadcastedTransition`] to close the session. + pub fn broadcasted(&self) -> BroadcastedTransition { + BroadcastedTransition(Arc::new(RwLock::new(Some(self.0.broadcasted())))) + } +} + /// Session persister that should save and load events as JSON strings. #[uniffi::export(with_foreign)] pub trait JsonSenderSessionPersister: Send + Sync { diff --git a/payjoin/src/core/send/v2/mod.rs b/payjoin/src/core/send/v2/mod.rs index f0a345967..61a9622de 100644 --- a/payjoin/src/core/send/v2/mod.rs +++ b/payjoin/src/core/send/v2/mod.rs @@ -46,8 +46,8 @@ use crate::error::{InternalReplayError, ReplayError}; use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeSecretKey}; use crate::ohttp::{ohttp_encapsulate, process_get_res, process_post_res}; use crate::persist::{ - MaybeFatalTransition, MaybeSuccessTransitionWithNoResults, NextStateTransition, - TerminalTransition, + MaybeFatalTransition, MaybeSuccessTransition, MaybeSuccessTransitionWithNoResults, + NextStateTransition, }; use crate::uri::v2::PjParam; use crate::uri::ShortId; @@ -246,17 +246,24 @@ impl Sender { } impl Sender { - /// Cancel the Payjoin session immediately. - /// - /// Returns a [`TerminalTransition`] that, once persisted, yields the fallback - /// transaction. The fallback transaction is the sender's original transaction that + /// Cancel the Payjoin session and once the transition is persisted, return a [`PendingFallback`] state. + /// The fallback transaction is the sender's original transaction that /// should be broadcast to complete the payment without Payjoin. - /// - /// This is a terminal transition — the session cannot be used after cancellation. - pub fn cancel(self) -> TerminalTransition { - let fallback = - self.session_context.psbt_ctx.original_psbt.clone().extract_tx_unchecked_fee_rate(); - TerminalTransition::new(SessionEvent::Closed(SessionOutcome::Cancel), fallback) + pub fn cancel(self) -> NextStateTransition> { + NextStateTransition::success( + SessionEvent::Cancelled(), + Sender { + state: PendingFallback { + fallback_tx: self + .session_context + .psbt_ctx + .original_psbt + .clone() + .extract_tx_unchecked_fee_rate(), + }, + session_context: self.session_context, + }, + ) } } @@ -268,6 +275,7 @@ impl Sender { pub enum SendSession { WithReplyKey(Sender), PollingForProposal(Sender), + PendingFallback(Sender), Closed(SessionOutcome), } @@ -287,6 +295,30 @@ impl SendSession { SendSession::PollingForProposal(_state), SessionEvent::Closed(SessionOutcome::Success(proposal)), ) => Ok(SendSession::Closed(SessionOutcome::Success(proposal))), + (SendSession::WithReplyKey(state), SessionEvent::Cancelled()) => + Ok(SendSession::PendingFallback(Sender { + state: PendingFallback { + fallback_tx: state + .session_context + .psbt_ctx + .original_psbt + .clone() + .extract_tx_unchecked_fee_rate(), + }, + session_context: state.session_context, + })), + (SendSession::PollingForProposal(state), SessionEvent::Cancelled()) => + Ok(SendSession::PendingFallback(Sender { + state: PendingFallback { + fallback_tx: state + .session_context + .psbt_ctx + .original_psbt + .clone() + .extract_tx_unchecked_fee_rate(), + }, + session_context: state.session_context, + })), (_, SessionEvent::Closed(session_outcome)) => Ok(SendSession::Closed(session_outcome)), (current_state, event) => Err(InternalReplayError::InvalidEvent( Box::new(event), @@ -555,6 +587,23 @@ impl Sender { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PendingFallback { + fallback_tx: bitcoin::Transaction, +} + +impl Sender { + /// Returns the fallback transaction that should be broadcast to complete the payment without Payjoin. + pub fn fallback_tx(&self) -> &bitcoin::Transaction { &self.fallback_tx } + + /// Indicates that the fallback transaction has been broadcast and the session is complete. + pub fn broadcasted( + &self, + ) -> MaybeSuccessTransition { + MaybeSuccessTransition::success(SessionEvent::Closed(SessionOutcome::Aborted), ()) + } +} + #[cfg(test)] mod test { use std::str::FromStr; @@ -732,7 +781,12 @@ mod test { .cancel() .save(&persister) .expect("save should succeed"); - assert_eq!(fallback, expected_tx, "cancel from {}", stringify!($state)); + assert_eq!( + *fallback.fallback_tx(), + expected_tx, + "cancel from {}", + stringify!($state) + ); }}; } diff --git a/payjoin/src/core/send/v2/session.rs b/payjoin/src/core/send/v2/session.rs index 1f0f8de18..524f42972 100644 --- a/payjoin/src/core/send/v2/session.rs +++ b/payjoin/src/core/send/v2/session.rs @@ -130,7 +130,7 @@ impl SessionHistory { match self.events.last() { Some(SessionEvent::Closed(outcome)) => match outcome { SessionOutcome::Success(_) => SessionStatus::Completed, - SessionOutcome::Failure | SessionOutcome::Cancel => SessionStatus::Failed, + SessionOutcome::Failure | SessionOutcome::Aborted => SessionStatus::Failed, }, _ => SessionStatus::Active, } @@ -153,6 +153,8 @@ pub enum SessionEvent { Created(Box), /// Sender POSTed the Original PSBT and is waiting to receive a Proposal PSBT PostedOriginalPsbt(), + /// User initiated cancellation of the session + Cancelled(), /// Closed successful or failed session Closed(SessionOutcome), } @@ -164,8 +166,8 @@ pub enum SessionOutcome { Success(bitcoin::Psbt), /// Payjoin failed to complete due to a counterparty deviation from the protocol Failure, - /// Payjoin was cancelled by the user - Cancel, + /// Payjoin was aborted by the user + Aborted, } #[cfg(test)] @@ -221,7 +223,8 @@ mod tests { SessionEvent::PostedOriginalPsbt(), SessionEvent::Closed(SessionOutcome::Success(PARSED_ORIGINAL_PSBT.clone())), SessionEvent::Closed(SessionOutcome::Failure), - SessionEvent::Closed(SessionOutcome::Cancel), + SessionEvent::Closed(SessionOutcome::Aborted), + SessionEvent::Cancelled(), ]; for event in test_cases {