Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 124 additions & 6 deletions crates/apollo_gateway/src/sync_state_reader_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ use apollo_state_sync_types::state_sync_types::SyncBlock;
use apollo_test_utils::{get_rng, GetTestInstance};
use blockifier::execution::contract_class::RunnableCompiledClass;
use blockifier::state::errors::StateError;
use blockifier::state::global_cache::CompiledClasses;
use blockifier::state::state_api::{StateReader, StateResult};
use blockifier::state::state_api_test_utils::assert_eq_state_result;
use blockifier::state::state_reader_and_contract_manager::FetchCompiledClasses;
use mockall::predicate;
use rstest::rstest;
use starknet_api::block::{
Expand All @@ -27,6 +29,7 @@ use starknet_api::block::{
use starknet_api::contract_class::ContractClass;
use starknet_api::core::{ClassHash, SequencerContractAddress};
use starknet_api::data_availability::L1DataAvailabilityMode;
use starknet_api::state::SierraContractClass;
use starknet_api::{class_hash, contract_address, felt, nonce, storage_key};

use crate::state_reader::MempoolStateReader;
Expand Down Expand Up @@ -221,9 +224,9 @@ async fn test_get_class_hash_at() {
)]
#[tokio::test]
async fn test_get_compiled_class(
#[case] class_manager_client_result: ClassManagerClientResult<Option<ExecutableClass>>,
#[case] n_calls_to_class_manager_client: usize,
#[case] sync_client_result: StateSyncClientResult<bool>,
#[case] get_executable_result: ClassManagerClientResult<Option<ExecutableClass>>,
#[case] n_calls_to_get_executable: usize,
#[case] is_class_declared_at_result: StateSyncClientResult<bool>,
#[case] expected_result: StateResult<RunnableCompiledClass>,
#[case] class_hash: ClassHash,
) {
Expand All @@ -234,15 +237,15 @@ async fn test_get_compiled_class(

mock_class_manager_client
.expect_get_executable()
.times(n_calls_to_class_manager_client)
.times(n_calls_to_get_executable)
.with(predicate::eq(class_hash))
.return_once(move |_| class_manager_client_result);
.return_once(move |_| get_executable_result);

mock_state_sync_client
.expect_is_class_declared_at()
.times(1)
.with(predicate::eq(block_number), predicate::eq(class_hash))
.return_once(move |_, _| sync_client_result);
.return_once(move |_, _| is_class_declared_at_result);

let state_sync_reader = SyncStateReader::from_number(
Arc::new(mock_state_sync_client),
Expand Down Expand Up @@ -271,3 +274,118 @@ async fn test_get_compiled_class_panics_when_class_exists_in_sync_but_not_in_cla
)
.await;
}

#[rstest]
#[case::cairo_0_class_declared(
Ok(true),
Ok(Some(ContractClass::test_deprecated_casm_contract_class())),
1,
Ok(None),
0,
Ok(CompiledClasses::from_runnable_for_testing(
RunnableCompiledClass::test_deprecated_casm_contract_class(),
))
)]
#[case::class_declared(
Ok(true),
Ok(Some(ContractClass::test_casm_contract_class())),
1,
Ok(Some(SierraContractClass::default())),
1,
Ok(CompiledClasses::from_runnable_for_testing(
RunnableCompiledClass::test_casm_contract_class(),
))
)]
#[case::class_not_declared_but_in_class_manager(
Ok(false),
Ok(Some(ContractClass::test_casm_contract_class())),
0,
Ok(Some(SierraContractClass::default())),
0,
Err(StateError::UndeclaredClassHash(*DUMMY_CLASS_HASH)),
)]
#[case::class_not_declared(
Ok(false),
Ok(None),
0,
Ok(None),
0,
Err(StateError::UndeclaredClassHash(*DUMMY_CLASS_HASH)),
)]
#[tokio::test]
async fn test_fetch_compiled_classes_get_compiled_classes(
#[case] is_class_declared_at_result: StateSyncClientResult<bool>,
#[case] get_executable_result: ClassManagerClientResult<Option<ExecutableClass>>,
#[case] n_calls_to_get_executable: usize,
#[case] get_sierra_result: ClassManagerClientResult<Option<SierraContractClass>>,
#[case] n_calls_to_get_sierra: usize,
#[case] expected_result: StateResult<CompiledClasses>,
) {
let mut mock_state_sync_client = MockStateSyncClient::new();
let mut mock_class_manager_client = MockClassManagerClient::new();

let block_number = BlockNumber(0);
let class_hash = *DUMMY_CLASS_HASH;

mock_state_sync_client
.expect_is_class_declared_at()
.times(1)
.with(predicate::eq(block_number), predicate::eq(class_hash))
.return_once(|_, _| is_class_declared_at_result);
mock_class_manager_client
.expect_get_executable()
.times(n_calls_to_get_executable)
.with(predicate::eq(class_hash))
.return_once(|_| get_executable_result);
mock_class_manager_client
.expect_get_sierra()
.times(n_calls_to_get_sierra)
.with(predicate::eq(class_hash))
.return_once(|_| get_sierra_result);

let state_sync_reader = SyncStateReader::from_number(
Arc::new(mock_state_sync_client),
Arc::new(mock_class_manager_client),
block_number,
tokio::runtime::Handle::current(),
);

let result =
tokio::task::spawn_blocking(move || state_sync_reader.get_compiled_classes(class_hash))
.await
.unwrap();
assert_eq_state_result(&result, &expected_result);
}

#[rstest]
#[case::declared(Ok(true), Ok(true))]
#[case::not_declared(Ok(false), Ok(false))]
#[tokio::test]
async fn test_fetch_compiled_classes_is_declared(
#[case] is_cairo_1_declared_result: StateSyncClientResult<bool>,
#[case] expected_result: StateResult<bool>,
) {
let mut mock_state_sync_client = MockStateSyncClient::new();
let mock_class_manager_client = MockClassManagerClient::new();

let block_number = BlockNumber(0);
let class_hash = *DUMMY_CLASS_HASH;

mock_state_sync_client
.expect_is_cairo_1_class_declared_at()
.times(1)
.with(predicate::eq(block_number), predicate::eq(class_hash))
.return_once(|_, _| is_cairo_1_declared_result);

let state_sync_reader = SyncStateReader::from_number(
Arc::new(mock_state_sync_client),
Arc::new(mock_class_manager_client),
block_number,
tokio::runtime::Handle::current(),
);

let result = tokio::task::spawn_blocking(move || state_sync_reader.is_declared(class_hash))
.await
.unwrap();
assert_eq!(result.expect("unexpected error in state reader"), expected_result.unwrap())
}
4 changes: 2 additions & 2 deletions crates/blockifier/src/state/global_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::execution::native::contract_class::NativeCompiledClassV1;

pub const GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST: usize = 600;

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub enum CompiledClasses {
V0(CompiledClassV0),
V1(CompiledClassV1, Arc<SierraContractClass>),
Expand Down Expand Up @@ -56,7 +56,7 @@ impl CompiledClasses {
pub type RawClassCache = GlobalContractCache<CompiledClasses>;

#[cfg(feature = "cairo_native")]
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(test, derive(PartialEq))]
pub enum CachedCairoNative {
Compiled(NativeCompiledClassV1),
Expand Down
8 changes: 3 additions & 5 deletions crates/blockifier/src/state/state_api_test_utils.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use crate::execution::contract_class::RunnableCompiledClass;
use std::fmt::Debug;

use crate::state::errors::StateError;
use crate::state::state_api::StateResult;

pub fn assert_eq_state_result(
a: &StateResult<RunnableCompiledClass>,
b: &StateResult<RunnableCompiledClass>,
) {
pub fn assert_eq_state_result<T: PartialEq + Debug>(a: &StateResult<T>, b: &StateResult<T>) {
match (a, b) {
(Ok(a), Ok(b)) => assert_eq!(a, b),
(Err(StateError::UndeclaredClassHash(a)), Err(StateError::UndeclaredClassHash(b))) => {
Expand Down
Loading