Skip to content

Commit 081adc6

Browse files
apollo_gateway: use get_nonce outside extract_state_nonce_and_run_validations
1 parent 8b444c3 commit 081adc6

File tree

8 files changed

+129
-52
lines changed

8 files changed

+129
-52
lines changed

crates/apollo_gateway/src/gateway.rs

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ use crate::errors::{
4242
GatewayResult,
4343
};
4444
use crate::metrics::{register_metrics, GatewayMetricHandle, GATEWAY_ADD_TX_LATENCY};
45-
use crate::state_reader::StateReaderFactory;
45+
use crate::state_reader::{
46+
GatewayStateReaderWithCompiledClasses,
47+
MempoolStateReader,
48+
StateReaderFactory,
49+
};
4650
use crate::stateful_transaction_validator::{
4751
StatefulTransactionValidatorFactory,
4852
StatefulTransactionValidatorFactoryTrait,
@@ -154,7 +158,7 @@ impl Gateway {
154158
transaction_converter_err_to_deprecated_gw_err(&tx_signature, e)
155159
})?;
156160

157-
let blocking_task =
161+
let mut blocking_task =
158162
ProcessTxBlockingTask::new(self, executable_tx, tokio::runtime::Handle::current())
159163
.await
160164
.map_err(|e| {
@@ -165,10 +169,22 @@ impl Gateway {
165169
metric_counters.record_add_tx_failure(&e);
166170
e
167171
})?;
172+
173+
let state_reader = blocking_task.get_state_reader().await?;
174+
175+
let account_address = blocking_task.executable_tx.contract_address();
176+
let account_nonce = state_reader
177+
.get_account_nonce(account_address)
178+
.await
179+
.map_err(|e| StarknetError::internal_with_logging("Failed to get account nonce", e))?;
180+
181+
let stateful_tx_validator = blocking_task.create_stateful_validator(state_reader).await?;
182+
168183
// Run the blocking task in the current span.
169184
let curr_span = Span::current();
170-
let handle =
171-
tokio::task::spawn_blocking(move || curr_span.in_scope(|| blocking_task.process_tx()));
185+
let handle = tokio::task::spawn_blocking(move || {
186+
curr_span.in_scope(|| blocking_task.process_tx(account_nonce, stateful_tx_validator))
187+
});
172188
let handle_result = handle.await;
173189
let nonce = match handle_result {
174190
Ok(Ok(nonce)) => nonce,
@@ -243,7 +259,8 @@ impl Gateway {
243259
/// CPU-intensive transaction processing, spawned in a blocking thread to avoid blocking other tasks
244260
/// from running.
245261
struct ProcessTxBlockingTask {
246-
stateful_transaction_validator: Box<dyn StatefulTransactionValidatorTrait + Send>,
262+
state_reader_factory: Arc<dyn StateReaderFactory>,
263+
stateful_tx_validator_factory: Arc<dyn StatefulTransactionValidatorFactoryTrait>,
247264
mempool_client: SharedMempoolClient,
248265
executable_tx: AccountTransaction,
249266
runtime: tokio::runtime::Handle,
@@ -255,22 +272,54 @@ impl ProcessTxBlockingTask {
255272
executable_tx: AccountTransaction,
256273
runtime: tokio::runtime::Handle,
257274
) -> GatewayResult<Self> {
258-
// TODO(Itamar): Extract creating validator to a separate function.
259-
let stateful_transaction_validator = gateway
260-
.stateful_tx_validator_factory
261-
.instantiate_validator(gateway.state_reader_factory.clone())
262-
.await?;
263275
Ok(Self {
264-
stateful_transaction_validator,
276+
state_reader_factory: gateway.state_reader_factory.clone(),
277+
stateful_tx_validator_factory: gateway.stateful_tx_validator_factory.clone(),
265278
mempool_client: gateway.mempool_client.clone(),
266279
executable_tx,
267280
runtime,
268281
})
269282
}
270283

271-
fn process_tx(mut self) -> GatewayResult<Nonce> {
272-
let nonce = self.stateful_transaction_validator.extract_state_nonce_and_run_validations(
284+
pub async fn get_state_reader(
285+
&self,
286+
) -> GatewayResult<Box<dyn GatewayStateReaderWithCompiledClasses>> {
287+
let state_reader = self
288+
.stateful_tx_validator_factory
289+
.get_state_reader_for_validation(self.state_reader_factory.clone())
290+
.await
291+
.map_err(|e| {
292+
StarknetError::internal_with_logging(
293+
"Failed to get state reader from latest block",
294+
e,
295+
)
296+
})?;
297+
Ok(state_reader)
298+
}
299+
300+
pub async fn create_stateful_validator(
301+
&mut self,
302+
state_reader: Box<dyn GatewayStateReaderWithCompiledClasses>,
303+
) -> GatewayResult<Box<dyn StatefulTransactionValidatorTrait>> {
304+
self.stateful_tx_validator_factory
305+
.create_validator_from_state_reader(state_reader)
306+
.await
307+
.map_err(|e| {
308+
StarknetError::internal_with_logging(
309+
"Failed to create stateful validator from state reader",
310+
e,
311+
)
312+
})
313+
}
314+
315+
fn process_tx(
316+
self,
317+
account_nonce: Nonce,
318+
mut stateful_tx_validator: Box<dyn StatefulTransactionValidatorTrait>,
319+
) -> GatewayResult<Nonce> {
320+
let nonce = stateful_tx_validator.extract_state_nonce_and_run_validations(
273321
&self.executable_tx,
322+
account_nonce,
274323
self.mempool_client,
275324
self.runtime,
276325
)?;

crates/apollo_gateway/src/gateway_test.rs

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ use crate::state_reader_test_utils::{local_test_state_reader_factory, TestStateR
9898
use crate::stateful_transaction_validator::{
9999
MockStatefulTransactionValidatorFactoryTrait,
100100
MockStatefulTransactionValidatorTrait,
101-
StatefulTransactionValidatorFactoryTrait,
102101
};
103102
use crate::stateless_transaction_validator::MockStatelessTransactionValidatorTrait;
104103

@@ -317,17 +316,12 @@ async fn run_add_tx_and_extract_metrics(
317316
}
318317

319318
async fn process_tx_task(
320-
stateful_transaction_validator_factory: MockStatefulTransactionValidatorFactoryTrait,
319+
stateful_tx_validator_factory: MockStatefulTransactionValidatorFactoryTrait,
321320
) -> ProcessTxBlockingTask {
322321
let state_reader_factory = Arc::new(MockStateReaderFactory::new());
323-
let stateful_tx_validator = StatefulTransactionValidatorFactoryTrait::instantiate_validator(
324-
&stateful_transaction_validator_factory,
325-
state_reader_factory,
326-
)
327-
.await
328-
.expect("instantiate_validator should be mocked in tests");
329322
ProcessTxBlockingTask {
330-
stateful_transaction_validator: stateful_tx_validator,
323+
state_reader_factory,
324+
stateful_tx_validator_factory: Arc::new(stateful_tx_validator_factory),
331325
mempool_client: Arc::new(MockMempoolClient::new()),
332326
executable_tx: executable_invoke_tx(invoke_args()),
333327
runtime: tokio::runtime::Handle::current(),
@@ -559,7 +553,7 @@ fn test_full_cycle_dump_deserialize_authorized_declarer_accounts(
559553
async fn process_tx_returns_error_when_extract_state_nonce_and_run_validations_fails(
560554
#[case] error_code: StarknetErrorCode,
561555
mut mock_stateful_transaction_validator: MockStatefulTransactionValidatorTrait,
562-
mut mock_stateful_transaction_validator_factory: MockStatefulTransactionValidatorFactoryTrait,
556+
mock_stateful_transaction_validator_factory: MockStatefulTransactionValidatorFactoryTrait,
563557
) {
564558
let expected_error = StarknetError {
565559
code: error_code.clone(),
@@ -568,15 +562,15 @@ async fn process_tx_returns_error_when_extract_state_nonce_and_run_validations_f
568562

569563
mock_stateful_transaction_validator
570564
.expect_extract_state_nonce_and_run_validations()
571-
.return_once(|_, _, _| Err(expected_error));
572-
573-
mock_stateful_transaction_validator_factory
574-
.expect_instantiate_validator()
575-
.return_once(|_| Ok(Box::new(mock_stateful_transaction_validator)));
565+
.return_once(|_, _, _, _| Err(expected_error));
576566

577567
let process_tx_task = process_tx_task(mock_stateful_transaction_validator_factory).await;
578-
579-
let result = tokio::task::spawn_blocking(move || process_tx_task.process_tx()).await.unwrap();
568+
let account_nonce = nonce!(0);
569+
let result = tokio::task::spawn_blocking(move || {
570+
process_tx_task.process_tx(account_nonce, Box::new(mock_stateful_transaction_validator))
571+
})
572+
.await
573+
.unwrap();
580574

581575
assert!(result.is_err());
582576
assert_eq!(result.unwrap_err().code, error_code);
@@ -605,19 +599,18 @@ async fn stateless_transaction_validator_error(mut mock_dependencies: MockDepend
605599

606600
#[rstest]
607601
#[tokio::test]
608-
async fn add_tx_returns_error_when_instantiating_validator_fails(
602+
async fn add_tx_returns_error_when_getting_state_reader_fails(
609603
mut mock_dependencies: MockDependencies,
610604
mut mock_stateful_transaction_validator_factory: MockStatefulTransactionValidatorFactoryTrait,
611605
) {
612606
// Prepare transaction conversion to reach instantiation.
613607
let tx_args = invoke_args();
614608
setup_transaction_converter_mock(&mut mock_dependencies.mock_transaction_converter, &tx_args);
615609

616-
// Fail validator instantiation.
617610
let error_code = StarknetErrorCode::UnknownErrorCode("StarknetErrorCode.InternalError".into());
618611
let expected_error = StarknetError { code: error_code.clone(), message: "placeholder".into() };
619612
mock_stateful_transaction_validator_factory
620-
.expect_instantiate_validator()
613+
.expect_get_state_reader_for_validation()
621614
.return_once(|_| Err(expected_error));
622615

623616
// Build gateway and inject the failing factory.

crates/apollo_gateway/src/rpc_state_reader.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,19 @@ impl MempoolStateReader for RpcStateReader {
117117
let block_info = block_header.try_into()?;
118118
Ok(block_info)
119119
}
120+
121+
async fn get_account_nonce(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
122+
let get_nonce_params = GetNonceParams { block_id: self.block_id, contract_address };
123+
let result = self.send_rpc_request("starknet_getNonce", get_nonce_params);
124+
match result {
125+
Ok(value) => {
126+
let nonce: Nonce = serde_json::from_value(value).map_err(serde_err_to_state_err)?;
127+
Ok(nonce)
128+
}
129+
Err(RPCStateReaderError::ContractAddressNotFound(_)) => Ok(Nonce::default()),
130+
Err(e) => Err(e)?,
131+
}
132+
}
120133
}
121134

122135
impl BlockifierStateReader for RpcStateReader {

crates/apollo_gateway/src/state_reader.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use starknet_types_core::felt::Felt;
1717
#[async_trait]
1818
pub trait MempoolStateReader: BlockifierStateReader + Send + Sync {
1919
async fn get_block_info(&self) -> StateResult<BlockInfo>;
20+
async fn get_account_nonce(&self, contract_address: ContractAddress) -> StateResult<Nonce>;
2021
}
2122

2223
#[cfg_attr(test, automock)]
@@ -74,6 +75,10 @@ impl MempoolStateReader for Box<dyn GatewayStateReaderWithCompiledClasses> {
7475
async fn get_block_info(&self) -> StateResult<BlockInfo> {
7576
self.as_ref().get_block_info().await
7677
}
78+
79+
async fn get_account_nonce(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
80+
self.as_ref().get_account_nonce(contract_address).await
81+
}
7782
}
7883

7984
impl GatewayStateReaderWithCompiledClasses for Box<dyn GatewayStateReaderWithCompiledClasses> {}
@@ -85,4 +90,8 @@ impl MempoolStateReader
8590
async fn get_block_info(&self) -> StateResult<BlockInfo> {
8691
self.state_reader.get_block_info().await
8792
}
93+
94+
async fn get_account_nonce(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
95+
self.state_reader.get_account_nonce(contract_address).await
96+
}
8897
}

crates/apollo_gateway/src/state_reader_test_utils.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ impl MempoolStateReader for TestStateReader {
3434
async fn get_block_info(&self) -> Result<BlockInfo, StateError> {
3535
Ok(self.block_info.clone())
3636
}
37+
38+
async fn get_account_nonce(
39+
&self,
40+
contract_address: ContractAddress,
41+
) -> Result<Nonce, StateError> {
42+
self.blockifier_state_reader.get_nonce_at(contract_address)
43+
}
3744
}
3845

3946
impl FetchCompiledClasses for TestStateReader {

crates/apollo_gateway/src/stateful_transaction_validator.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,12 @@ pub trait StatefulTransactionValidatorFactoryTrait: Send + Sync {
5757
state_reader_factory: Arc<dyn StateReaderFactory>,
5858
) -> StatefulTransactionValidatorResult<Box<dyn StatefulTransactionValidatorTrait>>;
5959

60+
// 1) Get a state reader for validation (async, uses the factory).
6061
async fn get_state_reader_for_validation(
6162
&self,
6263
state_reader_factory: Arc<dyn StateReaderFactory>,
6364
) -> StatefulTransactionValidatorResult<Box<dyn GatewayStateReaderWithCompiledClasses>>;
64-
65+
// 2) Create the validator from a provided state reader (async, uses ChainInfo/config).
6566
async fn create_validator_from_state_reader(
6667
&self,
6768
state_reader: Box<dyn GatewayStateReaderWithCompiledClasses>,
@@ -147,6 +148,7 @@ pub trait StatefulTransactionValidatorTrait: Send {
147148
fn extract_state_nonce_and_run_validations(
148149
&mut self,
149150
executable_tx: &ExecutableTransaction,
151+
account_nonce: Nonce,
150152
mempool_client: SharedMempoolClient,
151153
runtime: tokio::runtime::Handle,
152154
) -> StatefulTransactionValidatorResult<Nonce>;
@@ -168,11 +170,10 @@ impl<B: BlockifierStatefulValidatorTrait + Send> StatefulTransactionValidatorTra
168170
fn extract_state_nonce_and_run_validations(
169171
&mut self,
170172
executable_tx: &ExecutableTransaction,
173+
account_nonce: Nonce,
171174
mempool_client: SharedMempoolClient,
172175
runtime: tokio::runtime::Handle,
173176
) -> StatefulTransactionValidatorResult<Nonce> {
174-
let address = executable_tx.contract_address();
175-
let account_nonce = self.get_nonce(address)?;
176177
self.run_transaction_validations(executable_tx, account_nonce, mempool_client, runtime)?;
177178
Ok(account_nonce)
178179
}

crates/apollo_gateway/src/stateful_transaction_validator_test.rs

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -62,32 +62,24 @@ async fn test_get_nonce_fail_on_extract_state_nonce_and_run_validations() {
6262
// The mempool does not have any transactions from the sender.
6363
Ok(false)
6464
});
65-
let mempool_client = Arc::new(mock_mempool_client);
66-
let runtime = tokio::runtime::Handle::current();
65+
// Removed unused mempool client and runtime in this test since nonce is fetched outside now.
6766

6867
let mut stateful_validator = StatefulTransactionValidator {
6968
config: StatefulTransactionValidatorConfig::default(),
7069
blockifier_stateful_tx_validator: mock_blockifier_validator,
7170
};
7271

7372
let executable_tx = create_executable_invoke_tx(CairoVersion::Cairo1(RunnableCairo1::Casm));
74-
let result = tokio::task::spawn_blocking(move || {
75-
stateful_validator.extract_state_nonce_and_run_validations(
76-
&executable_tx,
77-
mempool_client,
78-
runtime,
79-
)
73+
// get_nonce should fail and return an InternalError.
74+
let err = tokio::task::spawn_blocking(move || {
75+
stateful_validator.get_nonce(executable_tx.contract_address())
8076
})
8177
.await
82-
.unwrap();
78+
.unwrap()
79+
.unwrap_err();
8380
assert_eq!(
84-
result,
85-
Err(StarknetError {
86-
code: StarknetErrorCode::UnknownErrorCode(
87-
"StarknetErrorCode.InternalError".to_string()
88-
),
89-
message: "Internal error".to_string(),
90-
})
81+
err.code,
82+
StarknetErrorCode::UnknownErrorCode("StarknetErrorCode.InternalError".to_string())
9183
);
9284
}
9385

@@ -145,6 +137,7 @@ async fn test_extract_state_nonce_and_run_validations(
145137
let result = tokio::task::spawn_blocking(move || {
146138
stateful_validator.extract_state_nonce_and_run_validations(
147139
&executable_tx,
140+
account_nonce,
148141
mempool_client,
149142
runtime,
150143
)

crates/apollo_gateway/src/sync_state_reader.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,18 @@ impl MempoolStateReader for SyncStateReader {
131131

132132
Ok(block_info)
133133
}
134+
135+
async fn get_account_nonce(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
136+
let res = self.state_sync_client.get_nonce_at(self.block_number, contract_address).await;
137+
138+
match res {
139+
Ok(value) => Ok(value),
140+
Err(StateSyncClientError::StateSyncError(StateSyncError::ContractNotFound(_))) => {
141+
Ok(Nonce::default())
142+
}
143+
Err(e) => Err(StateError::StateReadError(e.to_string())),
144+
}
145+
}
134146
}
135147

136148
impl FetchCompiledClasses for SyncStateReader {

0 commit comments

Comments
 (0)