diff --git a/cli/src/modules/rpc.rs b/cli/src/modules/rpc.rs
index 6e42922b14..4a8d3b959a 100644
--- a/cli/src/modules/rpc.rs
+++ b/cli/src/modules/rpc.rs
@@ -2,6 +2,15 @@ use crate::imports::*;
 use convert_case::{Case, Casing};
 use kaspa_rpc_core::api::ops::RpcApiOps;

+fn parse_get_headers_direction(direction: Option<&str>) -> Result<bool> {
+    match direction.map(str::to_ascii_lowercase).as_deref() {
+        None => Ok(true),
+        Some("ascending" | "asc" | "true" | "1") => Ok(true),
+        Some("descending" | "desc" | "false" | "0") => Ok(false),
+        Some(_) => Err(Error::custom("Direction must be one of: ascending, asc, true, 1, descending, desc, false, 0")),
+    }
+}
+
 #[derive(Default, Handler)]
 #[help("Execute RPC commands against the connected Kaspa node")]
 pub struct Rpc;
@@ -162,10 +171,16 @@ impl Rpc {
                 let result = rpc.shutdown_call(None, ShutdownRequest {}).await?;
                 self.println(&ctx, result);
             }
-            // RpcApiOps::GetHeaders => {
-            //     let result = rpc.get_headers_call(GetHeadersRequest { }).await?;
-            //     self.println(&ctx, result);
-            // }
+            RpcApiOps::GetHeaders => {
+                if argv.len() < 2 {
+                    return Err(Error::custom("Usage: rpc get-headers <start_hash> <limit> [ascending|descending|true|false]"));
+                }
+                let start_hash = RpcHash::from_hex(argv.remove(0).as_str())?;
+                let limit = argv.remove(0).parse::<u64>()?;
+                let is_ascending = parse_get_headers_direction(argv.first().map(String::as_str))?;
+                let result = rpc.get_headers_call(None, GetHeadersRequest { start_hash, limit, is_ascending }).await?;
+                self.println(&ctx, result);
+            }
             RpcApiOps::GetUtxosByAddresses => {
                 if argv.is_empty() {
                     return Err(Error::custom("Please specify at least one address"));
@@ -340,3 +355,29 @@ impl Rpc {
         Ok(())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn get_headers_direction_defaults_to_ascending() {
+        assert!(parse_get_headers_direction(None).unwrap());
+    }
+
+    #[test]
+    fn get_headers_direction_accepts_aliases() {
+        for direction in ["ascending", "asc", "true", "1"] {
+            assert!(parse_get_headers_direction(Some(direction)).unwrap());
+        }
+
+        for direction in ["descending", "desc", "false", "0"] {
+            assert!(!parse_get_headers_direction(Some(direction)).unwrap());
+        }
+    }
+
+    #[test]
+    fn get_headers_direction_rejects_unknown_values() {
+        assert!(parse_get_headers_direction(Some("sideways")).is_err());
+    }
+}
diff --git a/components/consensusmanager/src/session.rs b/components/consensusmanager/src/session.rs
index e1b183d298..3bbe641f01 100644
--- a/components/consensusmanager/src/session.rs
+++ b/components/consensusmanager/src/session.rs
@@ -338,7 +338,7 @@ impl ConsensusSessionOwned {
             .await
     }

-    /// Returns the antipast of block `hash` from the POV of `context`, i.e. `antipast(hash) ∩ past(context)`.
+    /// Returns the antipast of block `hash` from the POV of `context`, i.e. the intersection of `antipast(hash)` and `past(context)`.
     /// Since this might be an expensive operation for deep blocks, we allow the caller to specify a limit
     /// `max_traversal_allowed` on the maximum amount of blocks to traverse for obtaining the answer
     pub async fn async_get_antipast_from_pov(
@@ -367,6 +367,11 @@ impl ConsensusSessionOwned {
         self.clone().spawn_blocking(move |c| c.create_virtual_selected_chain_block_locator(low, high)).await
     }

+    /// Returns up to `limit` hashes from `start` toward the virtual selected-chain tip, including `start`.
+    pub async fn async_get_virtual_selected_chain_from(&self, start: Hash, limit: usize) -> ConsensusResult<Vec<Hash>> {
+        self.clone().spawn_blocking(move |c| c.get_virtual_selected_chain_from(start, limit)).await
+    }
+
     pub async fn async_create_block_locator_from_pruning_point(&self, high: Hash, limit: usize) -> ConsensusResult<Vec<Hash>> {
         self.clone().spawn_blocking(move |c| c.create_block_locator_from_pruning_point(high, limit)).await
     }
diff --git a/consensus/core/src/api/mod.rs b/consensus/core/src/api/mod.rs
index 0215f3a994..005af23279 100644
--- a/consensus/core/src/api/mod.rs
+++ b/consensus/core/src/api/mod.rs
@@ -275,7 +275,7 @@ pub trait ConsensusApi: Send + Sync {
         unimplemented!()
     }

-    /// Returns the antipast of block `hash` from the POV of `context`, i.e. `antipast(hash) ∩ past(context)`.
+    /// Returns the antipast of block `hash` from the POV of `context`, i.e. the intersection of `antipast(hash)` and `past(context)`.
     /// Since this might be an expensive operation for deep blocks, we allow the caller to specify a limit
     /// `max_traversal_allowed` on the maximum amount of blocks to traverse for obtaining the answer
     fn get_antipast_from_pov(&self, hash: Hash, context: Hash, max_traversal_allowed: Option<u64>) -> ConsensusResult<Vec<Hash>> {
         unimplemented!()
     }
@@ -295,6 +295,11 @@ pub trait ConsensusApi: Send + Sync {
         unimplemented!()
     }

+    /// Returns up to `limit` hashes from `start` toward the virtual selected-chain tip, including `start`.
+    fn get_virtual_selected_chain_from(&self, start: Hash, limit: usize) -> ConsensusResult<Vec<Hash>> {
+        unimplemented!()
+    }
+
     fn create_block_locator_from_pruning_point(&self, high: Hash, limit: usize) -> ConsensusResult<Vec<Hash>> {
         unimplemented!()
     }
diff --git a/consensus/src/consensus/mod.rs b/consensus/src/consensus/mod.rs
index 81d92750df..77cda757dc 100644
--- a/consensus/src/consensus/mod.rs
+++ b/consensus/src/consensus/mod.rs
@@ -61,6 +61,7 @@ use kaspa_consensus_core::{
         consensus::{ConsensusError, ConsensusResult},
         difficulty::DifficultyError,
         pruning::PruningImportError,
+        sync::SyncManagerError,
         tx::TxResult,
     },
     header::Header,
@@ -1159,6 +1160,32 @@ impl ConsensusApi for Consensus {
         Ok(self.services.sync_manager.create_virtual_selected_chain_block_locator(low, high)?)
     }

+    fn get_virtual_selected_chain_from(&self, start: Hash, limit: usize) -> ConsensusResult<Vec<Hash>> {
+        let _guard = self.pruning_lock.blocking_read();
+        self.validate_block_exists(start)?;
+
+        if limit == 0 {
+            return Ok(Vec::new());
+        }
+
+        let selected_chain = self.storage.selected_chain_store.read();
+        let start_index =
+            selected_chain.get_by_hash(start).optional().unwrap().ok_or(SyncManagerError::BlockNotInSelectedParentChain(start))?;
+        let (tip_index, _) =
+            selected_chain.get_tip().map_err(|err| ConsensusError::GeneralOwned(format!("selected chain tip read failed: {err}")))?;
+        let limit_end_index = start_index.saturating_add(limit.saturating_sub(1) as u64);
+        let end_index = cmp::min(tip_index, limit_end_index);
+        let mut hashes = Vec::with_capacity((end_index - start_index + 1) as usize);
+        for index in start_index..=end_index {
+            hashes.push(
+                selected_chain
+                    .get_by_index(index)
+                    .map_err(|err| ConsensusError::GeneralOwned(format!("selected chain hash read failed: {err}")))?,
+            );
+        }
+        Ok(hashes)
+    }
+
     fn pruning_point_headers(&self) -> Vec<Arc<Header>> {
         // PRUNE SAFETY: index is monotonic and past pruning point headers are expected permanently
         let (pruning_point, pruning_index) = self.pruning_point_store.read().pruning_point_and_index().unwrap();
diff --git a/rpc/core/src/api/rpc.rs b/rpc/core/src/api/rpc.rs
index 649e883250..767d8ff007 100644
--- a/rpc/core/src/api/rpc.rs
+++ b/rpc/core/src/api/rpc.rs
@@ -312,7 +312,9 @@ pub trait RpcApi: Sync + Send + AnySync {
     }
     async fn shutdown_call(&self, connection: Option<&DynRpcConnection>, request: ShutdownRequest) -> RpcResult<ShutdownResponse>;

-    /// Requests headers between the given `start_hash` and the current virtual, up to the given limit.
+    /// Requests selected-parent-chain headers from `start_hash`, up to the given inclusive limit.
+    ///
+    /// Ascending requests walk toward the sink. Descending requests walk toward genesis.
     async fn get_headers(&self, start_hash: RpcHash, limit: u64, is_ascending: bool) -> RpcResult<Vec<RpcHeader>> {
         Ok(self.get_headers_call(None, GetHeadersRequest::new(start_hash, limit, is_ascending)).await?.headers)
     }
diff --git a/rpc/core/src/model/message.rs b/rpc/core/src/model/message.rs
index ed3f9b43ce..8a032f02c0 100644
--- a/rpc/core/src/model/message.rs
+++ b/rpc/core/src/model/message.rs
@@ -1215,8 +1215,11 @@ impl Deserializer for ShutdownResponse {
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct GetHeadersRequest {
+    /// The selected-parent-chain header where traversal starts.
     pub start_hash: RpcHash,
+    /// The maximum number of headers to return, including `start_hash`.
     pub limit: u64,
+    /// When true, walk toward the sink; when false, walk toward genesis.
     pub is_ascending: bool,
 }
diff --git a/rpc/core/src/wasm/message.rs b/rpc/core/src/wasm/message.rs
index 9afa292bee..5de07c8fb5 100644
--- a/rpc/core/src/wasm/message.rs
+++ b/rpc/core/src/wasm/message.rs
@@ -1003,7 +1003,11 @@ declare! {
     IGetHeadersRequest,
     r#"
     /**
+     * Selected-parent-chain header request.
      *
+     * If `isAscending` is true, traversal walks from `startHash` toward the
+     * sink. Otherwise, traversal walks from `startHash` toward genesis.
+     * `limit` includes `startHash` when greater than zero.
     *
     * @category Node RPC
     */
diff --git a/rpc/grpc/core/proto/rpc.proto b/rpc/grpc/core/proto/rpc.proto
index b9e283959e..79b3ec6981 100644
--- a/rpc/grpc/core/proto/rpc.proto
+++ b/rpc/grpc/core/proto/rpc.proto
@@ -489,8 +489,9 @@ message ShutdownResponseMessage {
     RPCError error = 1000;
 }

-// GetHeadersRequestMessage requests headers between the given startHash and the
-// current virtual, up to the given limit.
+// GetHeadersRequestMessage requests selected-parent-chain headers from
+// startHash, up to the inclusive limit. Ascending requests walk toward the
+// sink; descending requests walk toward genesis.
 message GetHeadersRequestMessage {
     string startHash = 1;
     uint64 limit = 2;
@@ -499,6 +500,7 @@ message GetHeadersRequestMessage {

 message GetHeadersResponseMessage {
     repeated string headers = 1;
+    repeated RpcBlockHeader blockHeaders = 2;
     RPCError error = 1000;
 }
diff --git a/rpc/grpc/core/src/convert/message.rs b/rpc/grpc/core/src/convert/message.rs
index d60acb3c56..0a3dba2301 100644
--- a/rpc/grpc/core/src/convert/message.rs
+++ b/rpc/grpc/core/src/convert/message.rs
@@ -327,7 +327,11 @@ from!(item: &kaspa_rpc_core::GetHeadersRequest, protowire::GetHeadersRequestMess
     Self { start_hash: item.start_hash.to_string(), limit: item.limit, is_ascending: item.is_ascending }
 });
 from!(item: RpcResult<&kaspa_rpc_core::GetHeadersResponse>, protowire::GetHeadersResponseMessage, {
-    Self { headers: item.headers.iter().map(|x| x.hash.to_string()).collect(), error: None }
+    Self {
+        headers: item.headers.iter().map(|x| x.hash.to_string()).collect(),
+        block_headers: item.headers.iter().map(protowire::RpcBlockHeader::from).collect(),
+        error: None,
+    }
 });

 from!(item: &kaspa_rpc_core::GetUtxosByAddressesRequest, protowire::GetUtxosByAddressesRequestMessage, {
@@ -848,8 +852,19 @@ try_from!(item: &protowire::GetHeadersRequestMessage, kaspa_rpc_core::GetHeaders
     Self { start_hash: RpcHash::from_str(&item.start_hash)?, limit: item.limit, is_ascending: item.is_ascending }
 });
 try_from!(item: &protowire::GetHeadersResponseMessage, RpcResult<kaspa_rpc_core::GetHeadersResponse>, {
-    // TODO
-    Self { headers: vec![] }
+    if !item.headers.is_empty() && item.block_headers.is_empty() {
+        return Err(RpcError::General("get headers response contains only header hashes without full headers".to_string()));
+    }
+    let headers = item.block_headers.iter().map(kaspa_rpc_core::RpcHeader::try_from).collect::<RpcResult<Vec<_>>>()?;
+    if item.headers.len() != headers.len() {
+        return Err(RpcError::General("get headers response has inconsistent legacy hashes and full headers".to_string()));
+    }
+    for (legacy_hash, header) in item.headers.iter().zip(headers.iter()) {
+        if RpcHash::from_str(legacy_hash)? != header.hash {
+            return Err(RpcError::General("get headers response has inconsistent legacy hashes and full headers".to_string()));
+        }
+    }
+    Self { headers }
 });

 try_from!(item: &protowire::GetUtxosByAddressesRequestMessage, kaspa_rpc_core::GetUtxosByAddressesRequest, {
@@ -1099,9 +1114,112 @@ try_from!(&protowire::NotifySinkBlueScoreChangedResponseMessage, RpcResult
-    use crate::protowire::{self, SubmitBlockResponseMessage, submit_block_response_message::RejectReason};
+    fn new_unique() -> RpcHash {
+        use std::sync::atomic::{AtomicU64, Ordering};
+        static COUNTER: AtomicU64 = AtomicU64::new(1);
+        let c = COUNTER.fetch_add(1, Ordering::Relaxed);
+        RpcHash::from_u64_word(c)
+    }
+
+    fn rpc_header() -> RpcHeader {
+        Header::new_finalized(
+            0,
+            vec![vec![new_unique()]].try_into().unwrap(),
+            new_unique(),
+            new_unique(),
+            new_unique(),
+            123,
+            456,
+            789,
+            101_112,
+            131_415.into(),
+            161_718,
+            new_unique(),
+        )
+        .into()
+    }
+
+    fn legacy_header_hash() -> String {
+        new_unique().to_string()
+    }
+
+    #[test]
+    fn get_headers_response_accepts_empty_success() {
+        let protowire = GetHeadersResponseMessage { headers: Vec::new(), block_headers: Vec::new(), error: None };
+
+        let rpc_core: RpcResult<GetHeadersResponse> = (&protowire).try_into();
+
+        assert!(matches!(rpc_core, Ok(GetHeadersResponse { headers }) if headers.is_empty()));
+    }
+
+    #[test]
+    fn get_headers_response_rejects_hashes_without_full_headers() {
+        let protowire = GetHeadersResponseMessage { headers: vec![legacy_header_hash()], block_headers: Vec::new(), error: None };
+
+        let rpc_core: RpcResult<GetHeadersResponse> = (&protowire).try_into();
+
+        assert!(
+            matches!(rpc_core, Err(RpcError::General(message)) if message == "get headers response contains only header hashes without full headers")
+        );
+    }
+
+    #[test]
+    fn get_headers_response_accepts_matching_legacy_hashes_and_full_headers() {
+        let header = rpc_header();
+        let protowire =
+            GetHeadersResponseMessage { headers: vec![header.hash.to_string()], block_headers: vec![(&header).into()], error: None };
+
+        let rpc_core: RpcResult<GetHeadersResponse> = (&protowire).try_into();
+
+        let headers = rpc_core.unwrap().headers;
+        assert_eq!(headers.len(), 1);
+        assert_eq!(headers[0].hash, header.hash);
+    }
+
+    #[test]
+    fn get_headers_response_rejects_legacy_hash_count_mismatch() {
+        let header = rpc_header();
+        let protowire = GetHeadersResponseMessage { headers: Vec::new(), block_headers: vec![(&header).into()], error: None };
+
+        let rpc_core: RpcResult<GetHeadersResponse> = (&protowire).try_into();
+
+        assert!(
+            matches!(rpc_core, Err(RpcError::General(message)) if message == "get headers response has inconsistent legacy hashes and full headers")
+        );
+    }
+
+    #[test]
+    fn get_headers_response_rejects_legacy_hash_mismatch() {
+        let header = rpc_header();
+        let protowire =
+            GetHeadersResponseMessage { headers: vec![legacy_header_hash()], block_headers: vec![(&header).into()], error: None };
+
+        let rpc_core: RpcResult<GetHeadersResponse> = (&protowire).try_into();
+
+        assert!(
+            matches!(rpc_core, Err(RpcError::General(message)) if message == "get headers response has inconsistent legacy hashes and full headers")
+        );
+    }
+
+    #[test]
+    fn get_headers_response_error_takes_precedence_over_payload() {
+        let protowire = GetHeadersResponseMessage {
+            headers: vec![legacy_header_hash()],
+            block_headers: Vec::new(),
+            error: Some(protowire::RpcError { message: "upstream error".to_string() }),
+        };
+
+        let rpc_core: RpcResult<GetHeadersResponse> = (&protowire).try_into();
+
+        assert!(matches!(rpc_core, Err(RpcError::General(message)) if message == "upstream error"));
+    }

     #[test]
     fn test_submit_block_response() {
diff --git a/rpc/service/src/service.rs b/rpc/service/src/service.rs
index ee64be96c1..8ad9ed06db 100644
--- a/rpc/service/src/service.rs
+++ b/rpc/service/src/service.rs
@@ -125,6 +125,7 @@ pub struct RpcCoreService {
 }

 const RPC_CORE: &str = "rpc-core";
+const GET_HEADERS_MAX_WINDOW_SIZE: u32 = MAX_SAFE_WINDOW_SIZE;

 impl RpcCoreService {
     pub const IDENT: &'static str = "rpc-core-service";
@@ -924,9 +925,45 @@ NOTE: This error usually indicates an RPC conversion error between the node and
     async fn get_headers_call(
         &self,
         _connection: Option<&DynRpcConnection>,
-        _request: GetHeadersRequest,
+        request: GetHeadersRequest,
     ) -> RpcResult<GetHeadersResponse> {
-        Err(RpcError::NotImplemented)
+        if request.limit > u64::from(GET_HEADERS_MAX_WINDOW_SIZE) {
+            let requested_limit = request.limit.min(u64::from(u32::MAX)) as u32;
+            return Err(RpcError::WindowSizeExceedingMaximum(requested_limit, GET_HEADERS_MAX_WINDOW_SIZE));
+        }
+
+        let requested_limit = request.limit as u32;
+        let session = self.consensus_manager.consensus().session().await;
+        session.async_get_header(request.start_hash).await?;
+
+        let limit = requested_limit as usize;
+        if limit == 0 {
+            return Ok(GetHeadersResponse::new(Vec::new()));
+        }
+
+        let hashes = if request.is_ascending {
+            session.async_get_virtual_selected_chain_from(request.start_hash, limit).await?
+        } else {
+            session.async_get_virtual_selected_chain_from(request.start_hash, 1).await?;
+            let mut hashes = Vec::with_capacity(limit);
+            let mut current = request.start_hash;
+            while hashes.len() < limit {
+                hashes.push(current);
+                if current == self.config.genesis.hash {
+                    break;
+                }
+                current = session.async_get_ghostdag_data(current).await?.selected_parent;
+            }
+            hashes
+        };
+
+        let mut headers = Vec::with_capacity(hashes.len());
+        for hash in hashes {
+            let header = session.async_get_header(hash).await?;
+            headers.push(header.as_ref().into());
+        }
+
+        Ok(GetHeadersResponse::new(headers))
     }

     async fn get_block_dag_info_call(
diff --git a/rpc/wrpc/wasm/src/client.rs b/rpc/wrpc/wasm/src/client.rs
index 4254d2d156..f8e28f8a0a 100644
--- a/rpc/wrpc/wasm/src/client.rs
+++ b/rpc/wrpc/wasm/src/client.rs
@@ -1018,8 +1018,8 @@ build_wrpc_wasm_bindgen_interface!(
         GetDaaScoreTimestampEstimate,
         /// Feerate estimates (experimental)
         GetFeeEstimateExperimental,
-        /// Retrieves block headers from the Kaspa BlockDAG.
-        /// Returned information: List of block headers.
+        /// Retrieves selected-parent-chain block headers from the Kaspa BlockDAG.
+        /// Returned information: List of block headers, inclusive of the start hash when the limit is greater than zero.
         GetHeaders,
         /// Retrieves mempool entries from the Kaspa node's mempool.
         /// Returned information: List of mempool entries.
diff --git a/testing/integration/src/rpc_tests.rs b/testing/integration/src/rpc_tests.rs
index d3cc0fdc1f..034f54a86d 100644
--- a/testing/integration/src/rpc_tests.rs
+++ b/testing/integration/src/rpc_tests.rs
@@ -15,7 +15,11 @@ use kaspa_notify::{
         SinkBlueScoreChangedScope, UtxosChangedScope, VirtualChainChangedScope, VirtualDaaScoreChangedScope,
     },
 };
-use kaspa_rpc_core::{Notification, api::rpc::RpcApi, model::*};
+use kaspa_rpc_core::{
+    Notification,
+    api::rpc::{MAX_SAFE_WINDOW_SIZE, RpcApi},
+    model::*,
+};
 use kaspa_utils::{fd_budget, networking::ContextualNetAddress};
 use kaspad_lib::args::Args;
 use tokio::task::JoinHandle;
@@ -36,6 +40,43 @@ macro_rules! tst {
tst { }; } +async fn submit_simnet_block( + rpc_client: &kaspa_grpc_client::GrpcClient, + event_receiver: &async_channel::Receiver, + expected_daa_score: u64, +) -> Hash { + let GetBlockTemplateResponse { block, .. } = rpc_client + .get_block_template_call( + None, + GetBlockTemplateRequest { pay_address: Address::new(Prefix::Simnet, Version::PubKey, &[0u8; 32]), extra_data: Vec::new() }, + ) + .await + .unwrap(); + + let header: Header = (&block.header).try_into().unwrap(); + let block_hash = header.hash; + let response = rpc_client.submit_block(block, false).await.unwrap(); + assert_eq!(response.report, SubmitBlockReport::Success); + + while let Ok(notification) = match tokio::time::timeout(Duration::from_secs(1), event_receiver.recv()).await { + Ok(res) => res, + Err(elapsed) => panic!("expected virtual event before {}", elapsed), + } { + match notification { + Notification::VirtualDaaScoreChanged(msg) if msg.virtual_daa_score == expected_daa_score => { + break; + } + Notification::VirtualDaaScoreChanged(msg) if msg.virtual_daa_score > expected_daa_score => { + panic!("DAA score too high for number of submitted blocks") + } + Notification::VirtualDaaScoreChanged(_) => {} + _ => {} + } + } + + block_hash +} + /// `cargo test --release --package kaspa-testing-integration --lib -- rpc_tests::sanity_test` #[tokio::test] async fn sanity_test() { @@ -408,12 +449,12 @@ async fn sanity_test() { KaspadPayloadOps::GetHeaders => { let rpc_client = client.clone(); tst!(op, { - let response_result = rpc_client + let response = rpc_client .get_headers_call(None, GetHeadersRequest { start_hash: SIMNET_GENESIS.hash, limit: 1, is_ascending: true }) - .await; - - // Err because it's currently unimplemented - assert!(response_result.is_err()); + .await + .unwrap(); + assert_eq!(response.headers.len(), 1); + assert_eq!(response.headers[0].hash, SIMNET_GENESIS.hash); }) } @@ -803,3 +844,59 @@ async fn sanity_test() { drop(client); daemon.shutdown(); } + +#[tokio::test] +async fn get_headers_selected_chain_test() { + kaspa_core::log::try_init_logger("info"); + kaspa_core::panic::configure_panic(); + + let args = Args { + simnet: true, + disable_upnp: true, + enable_unsynced_mining: true, + block_template_cache_lifetime: Some(0), + unsafe_rpc: true, + ..Default::default() + }; + + let fd_total_budget = fd_budget::test_limit(); + let mut daemon = Daemon::new_random_with_args(args, fd_total_budget); + let client = daemon.start().await; + + let (sender, event_receiver) = async_channel::unbounded(); + client.start(Some(Arc::new(ChannelNotify::new(sender)))).await; + client.start_notify(Default::default(), Scope::VirtualDaaScoreChanged(VirtualDaaScoreChangedScope {})).await.unwrap(); + + let first_hash = submit_simnet_block(&client, &event_receiver, 1).await; + let second_hash = submit_simnet_block(&client, &event_receiver, 2).await; + + let empty_headers = client.get_headers(SIMNET_GENESIS.hash, 0, true).await.unwrap(); + assert!(empty_headers.is_empty()); + + let genesis_only = client.get_headers(SIMNET_GENESIS.hash, 1, true).await.unwrap(); + assert_eq!(genesis_only.iter().map(|header| header.hash).collect::>(), vec![SIMNET_GENESIS.hash]); + + let ascending = client.get_headers(SIMNET_GENESIS.hash, 3, true).await.unwrap(); + assert_eq!(ascending.iter().map(|header| header.hash).collect::>(), vec![SIMNET_GENESIS.hash, first_hash, second_hash]); + + let sink_only = client.get_headers(second_hash, 10, true).await.unwrap(); + assert_eq!(sink_only.iter().map(|header| header.hash).collect::>(), vec![second_hash]); + 
+    let descending = client.get_headers(second_hash, 3, false).await.unwrap();
+    assert_eq!(descending.iter().map(|header| header.hash).collect::<Vec<_>>(), vec![second_hash, first_hash, SIMNET_GENESIS.hash]);
+
+    for header in ascending.iter().chain(descending.iter()) {
+        let block = client.get_block(header.hash, false).await.unwrap();
+        assert_eq!(block.header.hash, header.hash);
+    }
+
+    let missing_hash = Hash::from_bytes([42; 32]);
+    assert!(client.get_headers(missing_hash, 1, true).await.is_err());
+
+    assert!(client.get_headers(SIMNET_GENESIS.hash, u64::from(MAX_SAFE_WINDOW_SIZE) + 1, true).await.is_err());
+
+    let _ = client.shutdown_call(None, ShutdownRequest {}).await.unwrap();
+    client.disconnect().await.unwrap();
+    drop(client);
+    daemon.shutdown();
+}
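
Illustrative usage sketch (not part of the patch above): one way a caller might page through the selected parent chain using the `RpcApi::get_headers(start_hash, limit, is_ascending)` method this change introduces. The helper name `collect_chain_headers` and the paging strategy are hypothetical; the only APIs assumed to exist are the trait method added here and the `hash` field on `RpcHeader`.

use kaspa_rpc_core::{api::rpc::RpcApi, RpcHash, RpcHeader, RpcResult};

// Hypothetical helper: walk from `start` toward the sink, `page_size` headers per call.
// Each ascending page includes its start hash, so follow-up pages reuse the last
// returned hash as the next cursor and drop the duplicated first entry.
async fn collect_chain_headers(rpc: &impl RpcApi, start: RpcHash, page_size: u64) -> RpcResult<Vec<RpcHeader>> {
    let mut all: Vec<RpcHeader> = Vec::new();
    let mut cursor = start;
    loop {
        // Ascending request: the returned page includes `cursor` itself.
        let page = rpc.get_headers(cursor, page_size, true).await?;
        let Some(last) = page.last().map(|header| header.hash) else { break };
        // Skip the first entry on follow-up pages since it repeats the cursor.
        let skip = usize::from(!all.is_empty());
        all.extend(page.into_iter().skip(skip));
        if last == cursor {
            // The walk did not advance, so the cursor is already the sink.
            break;
        }
        cursor = last;
    }
    Ok(all)
}

A page size of at least two is needed for the loop to make progress, since a one-header page only ever returns the cursor itself; a real caller would also want to bound the total number of collected headers.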