Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions crates/apollo_gateway/src/rpc_state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,20 @@ impl RpcStateReader {
}
}

#[async_trait]
impl MempoolStateReader for RpcStateReader {
fn get_block_info(&self) -> StateResult<BlockInfo> {
async fn get_block_info(&self) -> StateResult<BlockInfo> {
let get_block_params = GetBlockWithTxHashesParams { block_id: self.block_id };
let reader = self.clone();

let rpc_request_value = tokio::task::spawn_blocking(move || {
reader.send_rpc_request("starknet_getBlockWithTxHashes", get_block_params)
})
.await
.map_err(|e| StateError::StateReadError(format!("JoinError: {e}")))??;

// The response from the rpc is a full block but we only deserialize the header.
let block_header: BlockHeader = serde_json::from_value(
self.send_rpc_request("starknet_getBlockWithTxHashes", get_block_params)?,
)
.map_err(serde_err_to_state_err)?;
let block_header: BlockHeader =
serde_json::from_value(rpc_request_value).map_err(serde_err_to_state_err)?;
let block_info = block_header.try_into()?;
Ok(block_info)
}
Expand Down
3 changes: 1 addition & 2 deletions crates/apollo_gateway/src/rpc_state_reader_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ async fn test_get_block_info() {
);

let client = RpcStateReader::from_latest(&config);
let result =
tokio::task::spawn_blocking(move || client.get_block_info()).await.unwrap().unwrap();
let result = client.get_block_info().await.unwrap();
assert_eq!(result, expected_result);
mock.assert_async().await;
}
Expand Down
8 changes: 5 additions & 3 deletions crates/apollo_gateway/src/state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
use starknet_api::state::StorageKey;
use starknet_types_core::felt::Felt;

#[async_trait]
pub trait MempoolStateReader: BlockifierStateReader + Send + Sync {
fn get_block_info(&self) -> Result<BlockInfo, StateError>;
async fn get_block_info(&self) -> Result<BlockInfo, StateError>;
}

#[async_trait]
Expand All @@ -25,9 +26,10 @@ pub trait StateReaderFactory: Send + Sync {
// By default, a Box<dyn Trait> does not implement the trait of the object it contains.
// Therefore, for using the Box<dyn MempoolStateReader>, that the StateReaderFactory creates,
// we need to implement the MempoolStateReader trait for Box<dyn MempoolStateReader>.
#[async_trait]
impl MempoolStateReader for Box<dyn MempoolStateReader> {
fn get_block_info(&self) -> Result<BlockInfo, StateError> {
self.as_ref().get_block_info()
async fn get_block_info(&self) -> Result<BlockInfo, StateError> {
self.as_ref().get_block_info().await
}
}

Expand Down
3 changes: 2 additions & 1 deletion crates/apollo_gateway/src/state_reader_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ pub struct TestStateReader {
pub blockifier_state_reader: DictStateReader,
}

#[async_trait]
impl MempoolStateReader for TestStateReader {
fn get_block_info(&self) -> Result<BlockInfo, StateError> {
async fn get_block_info(&self) -> Result<BlockInfo, StateError> {
Ok(self.block_info.clone())
}
}
Expand Down
5 changes: 3 additions & 2 deletions crates/apollo_gateway/src/stateful_transaction_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl StatefulTransactionValidatorFactoryTrait for StatefulTransactionValidatorFa
e,
)
})?;
let latest_block_info = get_latest_block_info(&state_reader)?;
let latest_block_info = get_latest_block_info(&state_reader).await?;

let state = CachedState::new(state_reader);
let mut versioned_constants = VersionedConstants::get_versioned_constants(
Expand Down Expand Up @@ -334,10 +334,11 @@ fn skip_stateful_validations(
Ok(false)
}

pub fn get_latest_block_info(
pub async fn get_latest_block_info(
state_reader: &dyn MempoolStateReader,
) -> StatefulTransactionValidatorResult<BlockInfo> {
state_reader
.get_block_info()
.await
.map_err(|e| StarknetError::internal_with_logging("Failed to get latest block info", e))
}
9 changes: 6 additions & 3 deletions crates/apollo_gateway/src/sync_state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use async_trait::async_trait;
use blockifier::execution::contract_class::RunnableCompiledClass;
use blockifier::state::errors::StateError;
use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult};
use futures::executor::block_on;
use starknet_api::block::{BlockHash, BlockInfo, BlockNumber, GasPriceVector, GasPrices};
use starknet_api::contract_class::ContractClass;
use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
Expand Down Expand Up @@ -55,9 +54,13 @@ impl SyncStateReader {
}
}

#[async_trait]
impl MempoolStateReader for SyncStateReader {
fn get_block_info(&self) -> StateResult<BlockInfo> {
let block = block_on(self.state_sync_client.get_block(self.block_number))
async fn get_block_info(&self) -> StateResult<BlockInfo> {
let block = self
.state_sync_client
.get_block(self.block_number)
.await
.map_err(|e| StateError::StateReadError(e.to_string()))?;

let block_header = block.block_header_without_hash;
Expand Down
2 changes: 1 addition & 1 deletion crates/apollo_gateway/src/sync_state_reader_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async fn test_get_block_info() {
block_number,
tokio::runtime::Handle::current(),
);
let result = state_sync_reader.get_block_info().unwrap();
let result = state_sync_reader.get_block_info().await.unwrap();

assert_eq!(
result,
Expand Down
Loading