Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 131 additions & 3 deletions crates/price-estimation/src/competition/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
use {
super::{QuoteVerificationMode, native::NativePriceEstimating},
crate::PriceEstimationError,
crate::{
PriceEstimateResult,
PriceEstimating,
PriceEstimationError,
Query,
StreamingPriceEstimating,
},
futures::{
future::{BoxFuture, FutureExt},
stream::{FuturesUnordered, StreamExt},
stream::{BoxStream, FuturesUnordered, StreamExt},
},
gas_price_estimation::GasPriceEstimating,
model::order::OrderKind,
Expand Down Expand Up @@ -168,6 +174,21 @@ impl<T: Send + Sync + 'static> CompetitionEstimator<T> {
}
}

impl StreamingPriceEstimating for CompetitionEstimator<Arc<dyn PriceEstimating>> {
/// Runs every estimator concurrently across all stages and yields each
/// result as it arrives. No ranking, no early return. The caller stops
/// by dropping the stream.
fn estimate_stream(&self, query: Arc<Query>) -> BoxStream<'_, PriceEstimateResult> {
let futures: FuturesUnordered<BoxFuture<'_, PriceEstimateResult>> = FuturesUnordered::new();
for stage in &self.stages {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we iterate self.stages and call each leaf estimator's estimate directly — it does not go through CompetitionEstimator::estimate. So it bypasses the wrapper layer that does is_reasonable and emit_quote_event - is that intended?

for (_name, estimator) in stage {
futures.push(estimator.estimate(query.clone()));

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should return the name with the result for logging/metric purposes when being used in the followup stacked PR?

}
}
futures.boxed()
Comment on lines +182 to +188

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible to turn this into a single chain:

        self.stages
            .iter()
            .flatten()
            .map(|(_name, estimator)| estimator.estimate(query.clone()))
            .collect::<FuturesUnordered<_>>()
            .boxed()

Feel free to ignore if you prefer the current logic.

}
}

struct Context<'a, ESTIMATOR, QUERY> {
/// the estimator that is supposed to compute a price
estimator: &'a ESTIMATOR,
Expand Down Expand Up @@ -253,11 +274,13 @@ mod tests {
HEALTHY_PRICE_ESTIMATION_TIME,
MockPriceEstimating,
PriceEstimating,
PriceEstimationError,
Query,
StreamingPriceEstimating,
},
alloy::primitives::{Address, U256},
anyhow::anyhow,
futures::channel::oneshot::channel,
futures::{StreamExt, channel::oneshot::channel},
model::order::OrderKind,
number::nonzero::NonZeroU256,
std::time::Duration,
Expand Down Expand Up @@ -598,6 +621,111 @@ mod tests {
racing.estimate(query).await.unwrap();
}

fn make_query() -> Arc<Query> {
Arc::new(Query {
verification: Default::default(),
sell_token: Address::with_last_byte(0),
buy_token: Address::with_last_byte(1),
in_amount: NonZeroU256::try_from(1).unwrap(),
kind: OrderKind::Sell,
block_dependent: false,
timeout: HEALTHY_PRICE_ESTIMATION_TIME,
})
}

#[tokio::test]
async fn estimate_stream_yields_all_results() {
let fast = {
let mut m = MockPriceEstimating::new();
m.expect_estimate().times(1).returning(|_| {
async {
Ok(Estimate {
out_amount: U256::from(1u64),
gas: 1,
..Default::default()
})
}
.boxed()
});
m
};
let slow = {
let mut m = MockPriceEstimating::new();
m.expect_estimate().times(1).returning(|_| {
async {
sleep(Duration::from_millis(10)).await;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldnt you do a notify instead? making this one block until the fast stream is "done"

Ok(Estimate {
out_amount: U256::from(2u64),
gas: 1,
..Default::default()
})
}
.boxed()
});
m
};

let estimator: CompetitionEstimator<Arc<dyn PriceEstimating>> = CompetitionEstimator::new(
vec![vec![
("fast".to_owned(), Arc::new(fast)),
("slow".to_owned(), Arc::new(slow)),
]],
PriceRanking::MaxOutAmount,
);

let results: Vec<_> = estimator.estimate_stream(make_query()).collect().await;

assert_eq!(results.len(), 2);
assert!(results.iter().all(|r| r.is_ok()));
// `fast` (out_amount 1) resolves immediately while `slow` sleeps, so a
// stream that yields as results arrive must emit `fast` first. A serial
// collect-then-yield implementation would fail this.
let amounts: Vec<_> = results
.iter()
.map(|r| r.as_ref().unwrap().out_amount)
.collect();
assert_eq!(amounts, vec![U256::from(1u64), U256::from(2u64)]);
}

#[tokio::test]
async fn estimate_stream_passes_through_errors() {
let ok = {
let mut m = MockPriceEstimating::new();
m.expect_estimate().times(1).returning(|_| {
async {
Ok(Estimate {
out_amount: U256::from(1u64),
gas: 1,
..Default::default()
})
}
.boxed()
});
m
};
let err = {
let mut m = MockPriceEstimating::new();
m.expect_estimate()
.times(1)
.returning(|_| async { Err(PriceEstimationError::NoLiquidity) }.boxed());
m
};

let estimator: CompetitionEstimator<Arc<dyn PriceEstimating>> = CompetitionEstimator::new(
vec![vec![
("ok".to_owned(), Arc::new(ok)),
("err".to_owned(), Arc::new(err)),
Comment on lines +716 to +717

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test could give slightly stronger guarantees that an error doesn't abort the entire stream if you emit the ok response after a tiny sleep.

Having the ordered results can also make the assertion logic at the end a bit easier to read.

]],
PriceRanking::MaxOutAmount,
);

let results: Vec<_> = estimator.estimate_stream(make_query()).collect().await;

assert_eq!(results.len(), 2);
assert_eq!(results.iter().filter(|r| r.is_ok()).count(), 1);
assert_eq!(results.iter().filter(|r| r.is_err()).count(), 1);
}

#[test]
fn custom_solver_errors_have_higher_priority_than_generic_errors() {
let custom_errors = [
Expand Down
9 changes: 8 additions & 1 deletion crates/price-estimation/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use {
crate::trade_finding::QuoteExecution,
alloy::primitives::{Address, U256},
anyhow::Result,
futures::future::BoxFuture,
futures::{future::BoxFuture, stream::BoxStream},
model::order::OrderKind,
number::nonzero::NonZeroU256,
rate_limit::RateLimiter,
Expand Down Expand Up @@ -223,6 +223,13 @@ pub trait PriceEstimating: Send + Sync + 'static {
fn estimate(&self, query: Arc<Query>) -> BoxFuture<'_, PriceEstimateResult>;
}

/// Like `PriceEstimating`, but yields every estimator's result as it
/// completes instead of collapsing to the single best one.
#[cfg_attr(any(test, feature = "test-util"), mockall::automock)]
pub trait StreamingPriceEstimating: Send + Sync + 'static {
fn estimate_stream(&self, query: Arc<Query>) -> BoxStream<'_, PriceEstimateResult>;
}

pub const HEALTHY_PRICE_ESTIMATION_TIME: Duration = Duration::from_millis(5_000);

pub async fn rate_limited<T>(
Expand Down
Loading