Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion crates/apollo_gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ impl ProcessTxBlockingTask {
) -> GatewayResult<Self> {
let stateful_tx_validator = gateway
.stateful_tx_validator_factory
.instantiate_validator(gateway.state_reader_factory.as_ref())?;
.instantiate_validator(gateway.state_reader_factory.as_ref())
.await?;
Ok(Self {
stateful_tx_validator,
mempool_client: gateway.mempool_client.clone(),
Expand Down
18 changes: 12 additions & 6 deletions crates/apollo_gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ use crate::stateful_transaction_validator::{
MockStatefulTransactionValidatorFactoryTrait,
MockStatefulTransactionValidatorTrait,
StatefulTransactionValidatorFactoryTrait,
StatefulTransactionValidatorTrait,
};
use crate::stateless_transaction_validator::MockStatelessTransactionValidatorTrait;

Expand Down Expand Up @@ -314,14 +315,15 @@ async fn run_add_tx_and_extract_metrics(
AddTxResults { result, metric_handle_for_queries, metrics }
}

fn process_tx_task(
async fn process_tx_task(
stateful_transaction_validator_factory: MockStatefulTransactionValidatorFactoryTrait,
) -> ProcessTxBlockingTask {
let state_reader_factory = Arc::new(MockStateReaderFactory::new());
let stateful_tx_validator = StatefulTransactionValidatorFactoryTrait::instantiate_validator(
&stateful_transaction_validator_factory,
state_reader_factory.as_ref(),
)
.await
.expect("instantiate_validator should be mocked in tests");
ProcessTxBlockingTask {
stateful_tx_validator,
Expand Down Expand Up @@ -567,11 +569,15 @@ async fn process_tx_returns_error_when_extract_state_nonce_and_run_validations_f
.expect_extract_state_nonce_and_run_validations()
.return_once(|_, _, _| Err(expected_error));

mock_stateful_transaction_validator_factory
.expect_instantiate_validator()
.return_once(|_| Ok(Box::new(mock_stateful_transaction_validator)));
mock_stateful_transaction_validator_factory.expect_instantiate_validator().return_once(|_| {
Box::pin(async {
Ok::<Box<dyn StatefulTransactionValidatorTrait>, _>(Box::new(
mock_stateful_transaction_validator,
))
})
});

let process_tx_task = process_tx_task(mock_stateful_transaction_validator_factory);
let process_tx_task = process_tx_task(mock_stateful_transaction_validator_factory).await;

let result = tokio::task::spawn_blocking(move || process_tx_task.process_tx()).await.unwrap();

Expand Down Expand Up @@ -615,7 +621,7 @@ async fn add_tx_returns_error_when_instantiating_validator_fails(
let expected_error = StarknetError { code: error_code.clone(), message: "placeholder".into() };
mock_stateful_transaction_validator_factory
.expect_instantiate_validator()
.return_once(|_| Err(expected_error));
.return_once(|_| Box::pin(async { Err::<_, _>(expected_error) }));

// Build gateway and inject the failing factory.
let mut gateway = mock_dependencies.gateway();
Expand Down
7 changes: 5 additions & 2 deletions crates/apollo_gateway/src/stateful_transaction_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use apollo_gateway_types::deprecated_gateway_error::{
use apollo_gateway_types::errors::GatewaySpecError;
use apollo_mempool_types::communication::SharedMempoolClient;
use apollo_proc_macros::sequencer_latency_histogram;
use axum::async_trait;
use blockifier::blockifier::stateful_validator::{
StatefulValidator,
StatefulValidatorTrait as BlockifierStatefulValidatorTrait,
Expand Down Expand Up @@ -38,9 +39,10 @@ mod stateful_transaction_validator_test;

type BlockifierStatefulValidator = StatefulValidator<Box<dyn MempoolStateReader>>;

#[async_trait]
#[cfg_attr(test, mockall::automock)]
pub trait StatefulTransactionValidatorFactoryTrait: Send + Sync {
fn instantiate_validator(
async fn instantiate_validator(
&self,
state_reader_factory: &dyn StateReaderFactory,
) -> StatefulTransactionValidatorResult<Box<dyn StatefulTransactionValidatorTrait>>;
Expand All @@ -50,9 +52,10 @@ pub struct StatefulTransactionValidatorFactory {
pub chain_info: ChainInfo,
}

#[async_trait]
impl StatefulTransactionValidatorFactoryTrait for StatefulTransactionValidatorFactory {
// TODO(Ayelet): Move state_reader_factory and chain_info to the struct.
fn instantiate_validator(
async fn instantiate_validator(
&self,
state_reader_factory: &dyn StateReaderFactory,
) -> StatefulTransactionValidatorResult<Box<dyn StatefulTransactionValidatorTrait>> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ async fn test_extract_state_nonce_and_run_validations(
}

#[rstest]
fn test_instantiate_validator() {
#[tokio::test]
async fn test_instantiate_validator() {
let stateful_validator_factory = StatefulTransactionValidatorFactory {
config: StatefulTransactionValidatorConfig::default(),
chain_info: ChainInfo::create_for_testing(),
Expand All @@ -173,7 +174,8 @@ fn test_instantiate_validator() {
.expect_get_state_reader_from_latest_block()
.return_once(|| latest_state_reader);

let validator = stateful_validator_factory.instantiate_validator(&mock_state_reader_factory);
let validator =
stateful_validator_factory.instantiate_validator(&mock_state_reader_factory).await;
assert!(validator.is_ok());
}

Expand Down
Loading