diff --git a/zebrad/src/commands/start.rs b/zebrad/src/commands/start.rs index ac0c93c525c..f44e090e40c 100644 --- a/zebrad/src/commands/start.rs +++ b/zebrad/src/commands/start.rs @@ -28,11 +28,15 @@ use color_eyre::eyre::{eyre, Report}; use futures::{select, FutureExt}; use tokio::sync::oneshot; use tower::builder::ServiceBuilder; +use tower::util::BoxService; -use crate::components::{tokio::RuntimeRun, Inbound}; -use crate::config::ZebradConfig; use crate::{ - components::{mempool, tokio::TokioComponent, ChainSync}, + components::{ + mempool::{self, Mempool}, + tokio::{RuntimeRun, TokioComponent}, + ChainSync, Inbound, + }, + config::ZebradConfig, prelude::*, }; @@ -65,7 +69,8 @@ impl StartCmd { .await; info!("initializing mempool"); - let mempool = mempool::Mempool::new(config.network.network); + let mempool_service = BoxService::new(Mempool::new(config.network.network)); + let mempool = ServiceBuilder::new().buffer(20).service(mempool_service); info!("initializing network"); // The service that our node uses to respond to requests by peers. The @@ -80,7 +85,7 @@ impl StartCmd { state.clone(), chain_verifier.clone(), tx_verifier.clone(), - mempool, + mempool.clone(), )); let (peer_set, address_book) = diff --git a/zebrad/src/components/inbound.rs b/zebrad/src/components/inbound.rs index ece90903a11..2e54a4cf87f 100644 --- a/zebrad/src/components/inbound.rs +++ b/zebrad/src/components/inbound.rs @@ -21,8 +21,11 @@ use zebra_consensus::transaction; use zebra_consensus::{chain::VerifyChainError, error::TransactionError}; use zebra_network::AddressBook; -use super::mempool::downloads::{ - Downloads as TxDownloads, TRANSACTION_DOWNLOAD_TIMEOUT, TRANSACTION_VERIFY_TIMEOUT, +use super::mempool::{ + self as mp, + downloads::{ + Downloads as TxDownloads, TRANSACTION_DOWNLOAD_TIMEOUT, TRANSACTION_VERIFY_TIMEOUT, + }, }; // Re-use the syncer timeouts for consistency. use super::{ @@ -38,13 +41,14 @@ use downloads::Downloads as BlockDownloads; type Outbound = Buffer, zn::Request>; type State = Buffer, zs::Request>; +type Mempool = Buffer, mp::Request>; type BlockVerifier = Buffer, block::Hash, VerifyChainError>, Arc>; type TxVerifier = Buffer< BoxService, transaction::Request, >; type InboundBlockDownloads = BlockDownloads, Timeout, State>; -type InboundTxDownloads = TxDownloads, Timeout, State>; +type InboundTxDownloads = TxDownloads, Timeout, State, Mempool>; pub type NetworkSetupData = (Outbound, Arc>); @@ -134,7 +138,7 @@ pub struct Inbound { state: State, /// A service that manages transactions in the memory pool. - mempool: mempool::Mempool, + mempool: Mempool, } impl Inbound { @@ -143,7 +147,7 @@ impl Inbound { state: State, block_verifier: BlockVerifier, tx_verifier: TxVerifier, - mempool: mempool::Mempool, + mempool: Mempool, ) -> Self { Self { network_setup: Setup::AwaitingNetwork { @@ -195,6 +199,7 @@ impl Service for Inbound { Timeout::new(outbound, TRANSACTION_DOWNLOAD_TIMEOUT), Timeout::new(tx_verifier, TRANSACTION_VERIFY_TIMEOUT), self.state.clone(), + self.mempool.clone(), )); result = Ok(()); Setup::Initialized { @@ -350,6 +355,7 @@ impl Service for Inbound { zn::Request::PushTransaction(_transaction) => { debug!("ignoring unimplemented request"); // TODO: send to Tx Download & Verify Stream + // https://github.com/ZcashFoundation/zebra/issues/2692 async { Ok(zn::Response::Nil) }.boxed() } zn::Request::AdvertiseTransactionIds(transactions) => { diff --git a/zebrad/src/components/inbound/tests.rs b/zebrad/src/components/inbound/tests.rs index bced22afb58..857cdf285e7 100644 --- a/zebrad/src/components/inbound/tests.rs +++ b/zebrad/src/components/inbound/tests.rs @@ -3,7 +3,7 @@ use std::collections::HashSet; use super::mempool::{unmined_transactions_in_blocks, Mempool}; use tokio::sync::oneshot; -use tower::{builder::ServiceBuilder, ServiceExt}; +use tower::{builder::ServiceBuilder, util::BoxService, ServiceExt}; use zebra_chain::{ parameters::Network, @@ -26,6 +26,9 @@ async fn mempool_requests_for_transactions() { let added_transactions = add_some_stuff_to_mempool(&mut mempool_service, network); let added_transaction_ids: Vec = added_transactions.iter().map(|t| t.id).collect(); + let mempool_service = BoxService::new(mempool_service); + let mempool = ServiceBuilder::new().buffer(1).service(mempool_service); + let (block_verifier, transaction_verifier) = zebra_consensus::chain::init(consensus_config.clone(), network, state_service.clone()) .await; @@ -39,7 +42,7 @@ async fn mempool_requests_for_transactions() { state_service, block_verifier.clone(), transaction_verifier.clone(), - mempool_service, + mempool, )); // Test `Request::MempoolTransactionIds` diff --git a/zebrad/src/components/mempool.rs b/zebrad/src/components/mempool.rs index b1d3f0e323c..7bec123aebe 100644 --- a/zebrad/src/components/mempool.rs +++ b/zebrad/src/components/mempool.rs @@ -15,7 +15,7 @@ use zebra_chain::{ transaction::{UnminedTx, UnminedTxId}, }; -use crate::BoxError; +pub use crate::BoxError; mod crawler; pub mod downloads; @@ -35,12 +35,14 @@ pub use self::storage::tests::unmined_transactions_in_blocks; pub enum Request { TransactionIds, TransactionsById(HashSet), + RejectedTransactionIds(HashSet), } #[derive(Debug)] pub enum Response { Transactions(Vec), TransactionIds(Vec), + RejectedTransactionIds(Vec), } /// Mempool async management and query service. @@ -93,6 +95,11 @@ impl Service for Mempool { let rsp = Ok(self.storage.clone().transactions(ids)).map(Response::Transactions); async move { rsp }.boxed() } + Request::RejectedTransactionIds(ids) => { + let rsp = Ok(self.storage.clone().rejected_transactions(ids)) + .map(Response::RejectedTransactionIds); + async move { rsp }.boxed() + } } } } diff --git a/zebrad/src/components/mempool/downloads.rs b/zebrad/src/components/mempool/downloads.rs index 56b51bd2daa..14b5f3ba21a 100644 --- a/zebrad/src/components/mempool/downloads.rs +++ b/zebrad/src/components/mempool/downloads.rs @@ -21,6 +21,7 @@ use zebra_consensus::transaction as tx; use zebra_network as zn; use zebra_state as zs; +use crate::components::mempool as mp; use crate::components::sync::{BLOCK_DOWNLOAD_TIMEOUT, BLOCK_VERIFY_TIMEOUT}; type BoxError = Box; @@ -85,7 +86,7 @@ pub enum DownloadAction { /// Represents a [`Stream`] of download and verification tasks. #[pin_project] #[derive(Debug)] -pub struct Downloads +pub struct Downloads where ZN: Service + Send + 'static, ZN::Future: Send, @@ -93,6 +94,8 @@ where ZV::Future: Send, ZS: Service + Send + Clone + 'static, ZS::Future: Send, + ZM: Service + Send + Clone + 'static, + ZM::Future: Send, { // Services /// A service that forwards requests to connected peers, and returns their @@ -105,6 +108,9 @@ where /// A service that manages cached blockchain state. state: ZS, + /// A service that manages the mempool. + mempool: ZM, + // Internal downloads state /// A list of pending transaction download and verify tasks. #[pin] @@ -115,7 +121,7 @@ where cancel_handles: HashMap>, } -impl Stream for Downloads +impl Stream for Downloads where ZN: Service + Send + Clone + 'static, ZN::Future: Send, @@ -123,6 +129,8 @@ where ZV::Future: Send, ZS: Service + Send + Clone + 'static, ZS::Future: Send, + ZM: Service + Send + Clone + 'static, + ZM::Future: Send, { type Item = Result; @@ -158,7 +166,7 @@ where } } -impl Downloads +impl Downloads where ZN: Service + Send + Clone + 'static, ZN::Future: Send, @@ -166,6 +174,8 @@ where ZV::Future: Send, ZS: Service + Send + Clone + 'static, ZS::Future: Send, + ZM: Service + Send + Clone + 'static, + ZM::Future: Send, { /// Initialize a new download stream with the provided `network` and /// `verifier` services. @@ -173,11 +183,12 @@ where /// The [`Downloads`] stream is agnostic to the network policy, so retry and /// timeout limits should be applied to the `network` service passed into /// this constructor. - pub fn new(network: ZN, verifier: ZV, state: ZS) -> Self { + pub fn new(network: ZN, verifier: ZV, state: ZS, mempool: ZM) -> Self { Self { network, verifier, state, + mempool, pending: FuturesUnordered::new(), cancel_handles: HashMap::new(), } @@ -213,19 +224,11 @@ where let network = self.network.clone(); let verifier = self.verifier.clone(); - let state = self.state.clone(); + let mut state = self.state.clone(); + let mut mempool = self.mempool.clone(); let fut = async move { - // TODO: adapt this for transaction / mempool - // // Check if the block is already in the state. - // // BUG: check if the hash is in any chain (#862). - // // Depth only checks the main chain. - // match state.oneshot(zs::Request::Depth(hash)).await { - // Ok(zs::Response::Depth(None)) => Ok(()), - // Ok(zs::Response::Depth(Some(_))) => Err("already present".into()), - // Ok(_) => unreachable!("wrong response"), - // Err(e) => Err(e), - // }?; + Self::should_download(&mut state, &mut mempool, txid).await?; let height = match state.oneshot(zs::Request::Tip).await { Ok(zs::Response::Tip(None)) => Err("no block at the tip".into()), @@ -298,4 +301,67 @@ where DownloadAction::AddedToQueue } + + /// Check if transaction should be downloaded and verified. + /// + /// If it is already in the mempool (or in its rejected list) + /// or in state, then it shouldn't be downloaded (and an error is returned). + async fn should_download( + state: &mut ZS, + mempool: &mut ZM, + txid: UnminedTxId, + ) -> Result<(), BoxError> { + // Check if the transaction is already in the mempool. + match mempool + .ready_and() + .await? + .call(mp::Request::TransactionsById( + [txid].iter().cloned().collect(), + )) + .await + { + Ok(mp::Response::Transactions(txs)) => { + if txs.is_empty() { + Ok(()) + } else { + Err("already present in mempool".into()) + } + } + Ok(_) => unreachable!("wrong response"), + Err(e) => Err(e), + }?; + + // Check if the transaction is in the mempool rejected list. + match mempool + .oneshot(mp::Request::RejectedTransactionIds( + [txid].iter().cloned().collect(), + )) + .await + { + Ok(mp::Response::RejectedTransactionIds(txs)) => { + if txs.is_empty() { + Ok(()) + } else { + Err("in mempool rejected list".into()) + } + } + Ok(_) => unreachable!("wrong response"), + Err(e) => Err(e), + }?; + + // Check if the transaction is already in the state. + match state + .ready_and() + .await? + .call(zs::Request::Transaction(txid.mined_id())) + .await + { + Ok(zs::Response::Transaction(None)) => Ok(()), + Ok(zs::Response::Transaction(Some(_))) => Err("already present in state".into()), + Ok(_) => unreachable!("wrong response"), + Err(e) => Err(e), + }?; + + Ok(()) + } } diff --git a/zebrad/src/components/mempool/storage.rs b/zebrad/src/components/mempool/storage.rs index 0aa3727b754..08f9e1b4478 100644 --- a/zebrad/src/components/mempool/storage.rs +++ b/zebrad/src/components/mempool/storage.rs @@ -103,4 +103,12 @@ impl Storage { .filter(|tx| tx_ids.contains(&tx.id)) .collect() } + + /// Returns the set of [`UnminedTxId`]s matching ids in the rejected list. + pub fn rejected_transactions(self, tx_ids: HashSet) -> Vec { + tx_ids + .into_iter() + .filter(|tx| self.rejected.contains_key(tx)) + .collect() + } } diff --git a/zebrad/src/components/mempool/storage/tests.rs b/zebrad/src/components/mempool/storage/tests.rs index 332dc059948..803c40b5db8 100644 --- a/zebrad/src/components/mempool/storage/tests.rs +++ b/zebrad/src/components/mempool/storage/tests.rs @@ -47,6 +47,19 @@ fn mempool_storage_basic_for_network(network: Network) -> Result<()> { assert!(!storage.clone().contains(&tx.id)); } + // Query all the ids we have for rejected, get back `total - MEMPOOL_SIZE` + let all_ids: HashSet = unmined_transactions.iter().map(|tx| tx.id).collect(); + let rejected_ids: HashSet = unmined_transactions + .iter() + .take(total_transactions - MEMPOOL_SIZE) + .map(|tx| tx.id) + .collect(); + // Convert response to a `HashSet` as we need a fixed order to compare. + let rejected_response: HashSet = + storage.rejected_transactions(all_ids).into_iter().collect(); + + assert_eq!(rejected_response, rejected_ids); + Ok(()) } diff --git a/zebrad/src/components/mempool/tests.rs b/zebrad/src/components/mempool/tests.rs index 57c8070457a..a8634d5169b 100644 --- a/zebrad/src/components/mempool/tests.rs +++ b/zebrad/src/components/mempool/tests.rs @@ -22,15 +22,21 @@ async fn mempool_service_basic() -> Result<(), Report> { .oneshot(Request::TransactionIds) .await .unwrap(); - let transaction_ids = match response { + let genesis_transaction_ids = match response { Response::TransactionIds(ids) => ids, _ => unreachable!("will never happen in this test"), }; // Test `Request::TransactionsById` - let hash_set = transaction_ids.iter().copied().collect::>(); + let genesis_transactions_hash_set = genesis_transaction_ids + .iter() + .copied() + .collect::>(); let response = service - .oneshot(Request::TransactionsById(hash_set)) + .clone() + .oneshot(Request::TransactionsById( + genesis_transactions_hash_set.clone(), + )) .await .unwrap(); let transactions = match response { @@ -42,5 +48,26 @@ async fn mempool_service_basic() -> Result<(), Report> { // response of `Request::TransactionsById` assert_eq!(genesis_transactions.1[0], transactions[0]); + // Insert more transactions into the mempool storage. + // This will cause the genesis transaction to be moved into rejected. + let more_transactions = unmined_transactions_in_blocks(10, network); + for tx in more_transactions.1.iter().skip(1) { + service.storage.insert(tx.clone())?; + } + + // Test `Request::RejectedTransactionIds` + let response = service + .oneshot(Request::RejectedTransactionIds( + genesis_transactions_hash_set, + )) + .await + .unwrap(); + let rejected_ids = match response { + Response::RejectedTransactionIds(ids) => ids, + _ => unreachable!("will never happen in this test"), + }; + + assert_eq!(rejected_ids, genesis_transaction_ids); + Ok(()) }