Skip to content

Commit 314c447

Browse files
authored
Merge pull request #2929 from pyth-network/correcting-fee-calc-order
fix(stylus) - fixed order of fee calculation
2 parents ea00af3 + 3d76021 commit 314c447

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

target_chains/stylus/contracts/pyth-receiver/src/integration_tests.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod test {
88
use motsu::prelude::*;
99
use pythnet_sdk::wire::v1::{AccumulatorUpdateData, Proof};
1010
use std::time::Duration;
11+
use stylus_sdk::types::AddressVM;
1112
use wormhole_contract::WormholeContract;
1213

1314
const PYTHNET_CHAIN_ID: u16 = 26;
@@ -118,8 +119,12 @@ mod test {
118119
let result = pyth_contract
119120
.sender_and_value(alice, update_fee)
120121
.update_price_feeds(update_data);
122+
121123
assert!(result.is_ok());
122124

125+
assert_eq!(alice.balance(), U256::ZERO);
126+
assert_eq!(pyth_contract.balance(), update_fee);
127+
123128
let price_result = pyth_contract
124129
.sender(alice)
125130
.get_price_unsafe(ban_usd_feed_id());
@@ -169,11 +174,17 @@ mod test {
169174
.update_price_feeds(update_data1);
170175
assert!(result1.is_ok());
171176

177+
assert_eq!(alice.balance(), update_fee2);
178+
assert_eq!(pyth_contract.balance(), update_fee1);
179+
172180
let result2 = pyth_contract
173181
.sender_and_value(alice, update_fee2)
174182
.update_price_feeds(update_data2);
175183
assert!(result2.is_ok());
176184

185+
assert_eq!(alice.balance(), U256::ZERO);
186+
assert_eq!(pyth_contract.balance(), update_fee1 + update_fee2);
187+
177188
let price_result = pyth_contract
178189
.sender(alice)
179190
.get_price_unsafe(ban_usd_feed_id());
@@ -243,6 +254,9 @@ mod test {
243254
.update_price_feeds(update_data);
244255
assert!(result.is_ok());
245256

257+
assert_eq!(alice.balance(), U256::ZERO);
258+
assert_eq!(pyth_contract.balance(), update_fee);
259+
246260
let price_result = pyth_contract
247261
.sender(alice)
248262
.get_price_no_older_than(btc_usd_feed_id(), u64::MAX);
@@ -269,6 +283,9 @@ mod test {
269283
.update_price_feeds(update_data);
270284
assert!(result.is_ok());
271285

286+
assert_eq!(alice.balance(), U256::ZERO);
287+
assert_eq!(pyth_contract.balance(), update_fee);
288+
272289
let price_result = pyth_contract
273290
.sender(alice)
274291
.get_price_no_older_than(btc_usd_feed_id(), 1);
@@ -298,6 +315,9 @@ mod test {
298315
.update_price_feeds(update_data);
299316
assert!(result.is_ok());
300317

318+
assert_eq!(alice.balance(), U256::ZERO);
319+
assert_eq!(pyth_contract.balance(), update_fee);
320+
301321
let first_price_result = pyth_contract
302322
.sender(alice)
303323
.get_price_unsafe(ban_usd_feed_id());
@@ -339,6 +359,9 @@ mod test {
339359
.update_price_feeds(update_data);
340360
assert!(result.is_ok());
341361

362+
assert_eq!(alice.balance(), U256::ZERO);
363+
assert_eq!(pyth_contract.balance(), update_fee);
364+
342365
assert!(pyth_contract
343366
.sender(alice)
344367
.price_feed_exists(ban_usd_feed_id()));
@@ -380,6 +403,9 @@ mod test {
380403
.sender_and_value(alice, update_fee)
381404
.update_price_feeds(update_data);
382405

406+
assert_eq!(alice.balance(), U256::ZERO);
407+
assert_eq!(pyth_contract.balance(), update_fee);
408+
383409
assert!(result.is_ok());
384410

385411
let price_result = pyth_contract
@@ -407,6 +433,9 @@ mod test {
407433
.sender_and_value(alice, update_fee)
408434
.update_price_feeds(update_data);
409435

436+
assert_eq!(alice.balance(), U256::ZERO);
437+
assert_eq!(pyth_contract.balance(), update_fee);
438+
410439
assert!(result.is_ok());
411440

412441
let price_result1 = pyth_contract

target_chains/stylus/contracts/pyth-receiver/src/lib.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,17 +219,17 @@ impl PythReceiver {
219219
&mut self,
220220
update_data: Vec<Vec<u8>>,
221221
) -> Result<(), PythReceiverError> {
222-
for data in &update_data {
223-
self.update_price_feeds_internal(data.clone(), 0, 0, false)?;
224-
}
225-
226-
let total_fee = self.get_update_fee(update_data)?;
222+
let total_fee = self.get_update_fee(update_data.clone())?;
227223

228224
let value = self.vm().msg_value();
229225

230226
if value < total_fee {
231227
return Err(PythReceiverError::InsufficientFee);
232228
}
229+
230+
for data in &update_data {
231+
self.update_price_feeds_internal(data.clone(), 0, 0, false)?;
232+
}
233233
Ok(())
234234
}
235235

0 commit comments

Comments
 (0)