Skip to content

Commit f652350

Browse files
committed
apollo_gateway: add tests for fetch_contract_classes on sync state reader
1 parent d7df3ad commit f652350

File tree

3 files changed

+128
-12
lines changed

3 files changed

+128
-12
lines changed

crates/apollo_gateway/src/sync_state_reader_test.rs

Lines changed: 124 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@ use apollo_state_sync_types::state_sync_types::SyncBlock;
1010
use apollo_test_utils::{get_rng, GetTestInstance};
1111
use blockifier::execution::contract_class::RunnableCompiledClass;
1212
use blockifier::state::errors::StateError;
13+
use blockifier::state::global_cache::CompiledClasses;
1314
use blockifier::state::state_api::{StateReader, StateResult};
1415
use blockifier::state::state_api_test_utils::assert_eq_state_result;
16+
use blockifier::state::state_reader_and_contract_manager::FetchCompiledClasses;
1517
use mockall::predicate;
1618
use rstest::rstest;
1719
use starknet_api::block::{
@@ -27,6 +29,7 @@ use starknet_api::block::{
2729
use starknet_api::contract_class::ContractClass;
2830
use starknet_api::core::{ClassHash, SequencerContractAddress};
2931
use starknet_api::data_availability::L1DataAvailabilityMode;
32+
use starknet_api::state::SierraContractClass;
3033
use starknet_api::{class_hash, contract_address, felt, nonce, storage_key};
3134

3235
use crate::state_reader::MempoolStateReader;
@@ -221,9 +224,9 @@ async fn test_get_class_hash_at() {
221224
)]
222225
#[tokio::test]
223226
async fn test_get_compiled_class(
224-
#[case] class_manager_client_result: ClassManagerClientResult<Option<ExecutableClass>>,
225-
#[case] n_calls_to_class_manager_client: usize,
226-
#[case] sync_client_result: StateSyncClientResult<bool>,
227+
#[case] get_executable_result: ClassManagerClientResult<Option<ExecutableClass>>,
228+
#[case] n_calls_to_get_executable: usize,
229+
#[case] is_class_declared_at_result: StateSyncClientResult<bool>,
227230
#[case] expected_result: StateResult<RunnableCompiledClass>,
228231
#[case] class_hash: ClassHash,
229232
) {
@@ -234,15 +237,15 @@ async fn test_get_compiled_class(
234237

235238
mock_class_manager_client
236239
.expect_get_executable()
237-
.times(n_calls_to_class_manager_client)
240+
.times(n_calls_to_get_executable)
238241
.with(predicate::eq(class_hash))
239-
.return_once(move |_| class_manager_client_result);
242+
.return_once(move |_| get_executable_result);
240243

241244
mock_state_sync_client
242245
.expect_is_class_declared_at()
243246
.times(1)
244247
.with(predicate::eq(block_number), predicate::eq(class_hash))
245-
.return_once(move |_, _| sync_client_result);
248+
.return_once(move |_, _| is_class_declared_at_result);
246249

247250
let state_sync_reader = SyncStateReader::from_number(
248251
Arc::new(mock_state_sync_client),
@@ -271,3 +274,118 @@ async fn test_get_compiled_class_panics_when_class_exists_in_sync_but_not_in_cla
271274
)
272275
.await;
273276
}
277+
278+
#[rstest]
279+
#[case::cairo_0_class_declared(
280+
Ok(true),
281+
Ok(Some(ContractClass::test_deprecated_casm_contract_class())),
282+
1,
283+
Ok(None),
284+
0,
285+
Ok(CompiledClasses::from_runnable_for_testing(
286+
RunnableCompiledClass::test_deprecated_casm_contract_class(),
287+
))
288+
)]
289+
#[case::class_declared(
290+
Ok(true),
291+
Ok(Some(ContractClass::test_casm_contract_class())),
292+
1,
293+
Ok(Some(SierraContractClass::default())),
294+
1,
295+
Ok(CompiledClasses::from_runnable_for_testing(
296+
RunnableCompiledClass::test_casm_contract_class(),
297+
))
298+
)]
299+
#[case::class_not_declared_but_in_class_manager(
300+
Ok(false),
301+
Ok(Some(ContractClass::test_casm_contract_class())),
302+
0,
303+
Ok(Some(SierraContractClass::default())),
304+
0,
305+
Err(StateError::UndeclaredClassHash(*DUMMY_CLASS_HASH)),
306+
)]
307+
#[case::class_not_declared(
308+
Ok(false),
309+
Ok(None),
310+
0,
311+
Ok(None),
312+
0,
313+
Err(StateError::UndeclaredClassHash(*DUMMY_CLASS_HASH)),
314+
)]
315+
#[tokio::test]
316+
async fn test_get_compiled_classes(
317+
#[case] is_class_declared_at_result: StateSyncClientResult<bool>,
318+
#[case] get_executable_result: ClassManagerClientResult<Option<ExecutableClass>>,
319+
#[case] n_calls_to_get_executable: usize,
320+
#[case] get_sierra_result: ClassManagerClientResult<Option<SierraContractClass>>,
321+
#[case] n_calls_to_get_sierra: usize,
322+
#[case] expected_result: StateResult<CompiledClasses>,
323+
) {
324+
let mut mock_state_sync_client = MockStateSyncClient::new();
325+
let mut mock_class_manager_client = MockClassManagerClient::new();
326+
327+
let block_number = BlockNumber(0);
328+
let class_hash = *DUMMY_CLASS_HASH;
329+
330+
mock_state_sync_client
331+
.expect_is_class_declared_at()
332+
.times(1)
333+
.with(predicate::eq(block_number), predicate::eq(class_hash))
334+
.return_once(|_, _| is_class_declared_at_result);
335+
mock_class_manager_client
336+
.expect_get_executable()
337+
.times(n_calls_to_get_executable)
338+
.with(predicate::eq(class_hash))
339+
.return_once(|_| get_executable_result);
340+
mock_class_manager_client
341+
.expect_get_sierra()
342+
.times(n_calls_to_get_sierra)
343+
.with(predicate::eq(class_hash))
344+
.return_once(|_| get_sierra_result);
345+
346+
let state_sync_reader = SyncStateReader::from_number(
347+
Arc::new(mock_state_sync_client),
348+
Arc::new(mock_class_manager_client),
349+
block_number,
350+
tokio::runtime::Handle::current(),
351+
);
352+
353+
let result =
354+
tokio::task::spawn_blocking(move || state_sync_reader.get_compiled_classes(class_hash))
355+
.await
356+
.unwrap();
357+
assert_eq_state_result(&result, &expected_result);
358+
}
359+
360+
#[rstest]
361+
#[case::declared(Ok(true), Ok(true))]
362+
#[case::not_declared(Ok(false), Ok(false))]
363+
#[tokio::test]
364+
async fn test_is_declared(
365+
#[case] is_cairo_1_declared_result: StateSyncClientResult<bool>,
366+
#[case] expected_result: StateResult<bool>,
367+
) {
368+
let mut mock_state_sync_client = MockStateSyncClient::new();
369+
let mock_class_manager_client = MockClassManagerClient::new();
370+
371+
let block_number = BlockNumber(0);
372+
let class_hash = *DUMMY_CLASS_HASH;
373+
374+
mock_state_sync_client
375+
.expect_is_cairo_1_class_declared_at()
376+
.times(1)
377+
.with(predicate::eq(block_number), predicate::eq(class_hash))
378+
.return_once(|_, _| is_cairo_1_declared_result);
379+
380+
let state_sync_reader = SyncStateReader::from_number(
381+
Arc::new(mock_state_sync_client),
382+
Arc::new(mock_class_manager_client),
383+
block_number,
384+
tokio::runtime::Handle::current(),
385+
);
386+
387+
let result = tokio::task::spawn_blocking(move || state_sync_reader.is_declared(class_hash))
388+
.await
389+
.unwrap();
390+
assert_eq!(result.expect("unexpected error in state reader"), expected_result.unwrap())
391+
}

crates/blockifier/src/state/global_cache.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::execution::native::contract_class::NativeCompiledClassV1;
99

1010
pub const GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST: usize = 600;
1111

12-
#[derive(Debug, Clone)]
12+
#[derive(Debug, Clone, PartialEq)]
1313
pub enum CompiledClasses {
1414
V0(CompiledClassV0),
1515
V1(CompiledClassV1, Arc<SierraContractClass>),

crates/blockifier/src/state/state_api_test_utils.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
use crate::execution::contract_class::RunnableCompiledClass;
1+
use std::fmt::Debug;
2+
23
use crate::state::errors::StateError;
34
use crate::state::state_api::StateResult;
45

5-
pub fn assert_eq_state_result(
6-
a: &StateResult<RunnableCompiledClass>,
7-
b: &StateResult<RunnableCompiledClass>,
8-
) {
6+
pub fn assert_eq_state_result<T: PartialEq + Debug>(a: &StateResult<T>, b: &StateResult<T>) {
97
match (a, b) {
108
(Ok(a), Ok(b)) => assert_eq!(a, b),
119
(Err(StateError::UndeclaredClassHash(a)), Err(StateError::UndeclaredClassHash(b))) => {

0 commit comments

Comments
 (0)