Skip to content

Commit 90afd02

Browse files
committed
update
1 parent 84dc98c commit 90afd02

File tree

3 files changed

+111
-50
lines changed

3 files changed

+111
-50
lines changed

target_chains/ethereum/contracts/forge-test/utils/PythTestUtils.t.sol

Lines changed: 77 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -375,10 +375,10 @@ contract PythUtilsTest is Test, WormholeTestUtils, PythTestUtils, IPythEvents {
375375
int32 expo1,
376376
int64 price2,
377377
int32 expo2,
378-
int32 targetExpo,
378+
uint8 targetDecimals,
379379
int64 expectedPrice
380380
) internal {
381-
int64 price = PythUtils.deriveCrossRate(price1, expo1, price2, expo2, targetExpo);
381+
int64 price = PythUtils.deriveCrossRate(price1, expo1, price2, expo2, targetDecimals);
382382
assertEq(price, expectedPrice);
383383
}
384384

@@ -387,11 +387,11 @@ contract PythUtilsTest is Test, WormholeTestUtils, PythTestUtils, IPythEvents {
387387
int32 expo1,
388388
int64 price2,
389389
int32 expo2,
390-
int32 targetExpo,
390+
uint8 targetDecimals,
391391
bytes4 expectedError
392392
) internal {
393393
vm.expectRevert(expectedError);
394-
PythUtils.deriveCrossRate(price1, expo1, price2, expo2, targetExpo);
394+
PythUtils.deriveCrossRate(price1, expo1, price2, expo2, targetDecimals);
395395
}
396396

397397
function testConvertToUnit() public {
@@ -403,6 +403,18 @@ contract PythUtilsTest is Test, WormholeTestUtils, PythTestUtils, IPythEvents {
403403
vm.expectRevert(PythErrors.InvalidInputExpo.selector);
404404
PythUtils.convertToUint(100, -256, 18);
405405

406+
// This test will fail as the 10 ** 237 is too large for a uint256
407+
vm.expectRevert(PythErrors.ExponentOverflow.selector);
408+
assertEq(PythUtils.convertToUint(100, -255, 18), 0);
409+
410+
// Combined Exponent can't be greater than 77
411+
vm.expectRevert(PythErrors.ExponentOverflow.selector);
412+
assertEq(PythUtils.convertToUint(100, 60, 18), 0);
413+
414+
// Combined Exponent can't be less than -77
415+
vm.expectRevert(PythErrors.ExponentOverflow.selector);
416+
assertEq(PythUtils.convertToUint(100, -96, 18), 0);
417+
406418
// Negative Exponent Tests
407419
// Price with 18 decimals and exponent -5
408420
assertEq(
@@ -438,39 +450,76 @@ contract PythUtilsTest is Test, WormholeTestUtils, PythTestUtils, IPythEvents {
438450

439451

440452
// Edge Cases
441-
// This test will fail as the 10 ** 237 is too large for a uint256
442-
// assertEq(PythUtils.convertToUint(100, -255, 18), 0);
443-
// assertEq(PythUtils.convertToUint(100, 255, 18), 100_00_000_000_000_000_000_000_000);
453+
// 1. Test: price = 0, any expo/decimals returns 0
454+
assertEq(PythUtils.convertToUint(0, -77, 0), 0);
455+
assertEq(PythUtils.convertToUint(0, 0, 0), 0);
456+
assertEq(PythUtils.convertToUint(0, 77, 0), 0);
457+
assertEq(PythUtils.convertToUint(0, -77, 77), 0);
458+
459+
// 2. Test: smallest positive price, maximum downward exponent (should round to zero)
460+
assertEq(PythUtils.convertToUint(1, -77, 0), 0);
461+
assertEq(PythUtils.convertToUint(1, -77, 77), 1);
462+
463+
// 3. Test: combinedExpo == 0 (should be identical to price)
464+
assertEq(PythUtils.convertToUint(123456, 0, 0), 123456);
465+
assertEq(PythUtils.convertToUint(123456, -5, 5), 123456); // -5 + 5 == 0
466+
467+
// 4. Test: combinedExpo > 0 (should shift price up)
468+
assertEq(PythUtils.convertToUint(123456, 5, 0), 12345600000);
469+
assertEq(PythUtils.convertToUint(123456, 5, 2), 1234560000000);
470+
471+
// 5. Test: combinedExpo < 0 (should shift price down)
472+
assertEq(PythUtils.convertToUint(123456, -5, 0), 1);
473+
assertEq(PythUtils.convertToUint(123456, -5, 2), 123);
474+
475+
// 6. Test: division with truncation
476+
assertEq(PythUtils.convertToUint(999, -2, 0), 9); // 999/100 = 9 (truncated)
477+
assertEq(PythUtils.convertToUint(199, -2, 0), 1); // 199/100 = 1 (truncated)
478+
assertEq(PythUtils.convertToUint(99, -2, 0), 0); // 99/100 = 0 (truncated)
479+
480+
// 7. Test: Big price and scaling, but outside of bounds
481+
vm.expectRevert(PythErrors.CombinedPriceOverflow.selector);
482+
assertEq(PythUtils.convertToUint(100_000_000, 10, 60),0);
483+
484+
// 8. Test: Big price and scaling
485+
assertEq(PythUtils.convertToUint(100_000_000, -80, 10),0);
486+
487+
// 9. Test: Decimals just save from truncation
488+
assertEq(PythUtils.convertToUint(5, -1, 1), 5); // 5/10*10 = 5
489+
assertEq(PythUtils.convertToUint(5, -1, 2), 50); // 5/10*100 = 50
444490
}
445491

446-
function testCombinePrices() public {
492+
function testDeriveCrossRate() public {
447493

448-
// Basic Tests
449-
assertCrossRateEquals(500, -8, 500, -8, -5, 100000);
450-
assertCrossRateEquals(10_000, -8, 100, -2, -5, 10);
451-
assertCrossRateEquals(10_000, -2, 100, -8, -4, 1_000_000_000_000);
494+
// Test 1: Prices can't be negative
495+
assertCrossRateReverts(-100, -2, 100, -2, 5, PythErrors.NegativeInputPrice.selector);
496+
assertCrossRateReverts(100, -2, -100, -2, 5, PythErrors.NegativeInputPrice.selector);
497+
assertCrossRateReverts(-100, -2, -100, -2, 5, PythErrors.NegativeInputPrice.selector);
452498

453-
// Negative Price Tests
454-
assertCrossRateReverts(-100, -2, 100, -2, -5, PythErrors.NegativeInputPrice.selector);
455-
assertCrossRateReverts(100, -2, -100, -2, -5, PythErrors.NegativeInputPrice.selector);
456-
assertCrossRateReverts(-100, -2, -100, -2, -5, PythErrors.NegativeInputPrice.selector);
457499

458-
// Positive Exponent Tests
459-
assertCrossRateReverts(100, 2, 100, -2, -5, PythErrors.InvalidInputExpo.selector);
460-
assertCrossRateReverts(100, -2, 100, 2, -5, PythErrors.InvalidInputExpo.selector);
461-
assertCrossRateReverts(100, 2, 100, 2, -5, PythErrors.InvalidInputExpo.selector);
500+
// Test 2: Exponent can't be positive
501+
assertCrossRateReverts(100, 2, 100, -2, 5, PythErrors.InvalidInputExpo.selector);
502+
assertCrossRateReverts(100, -2, 100, 2, 5, PythErrors.InvalidInputExpo.selector);
503+
assertCrossRateReverts(100, 2, 100, 2, 5, PythErrors.InvalidInputExpo.selector);
504+
505+
// Test 3: Exponent can't be less than -255
506+
assertCrossRateReverts(100, -256, 100, -2, 5, PythErrors.InvalidInputExpo.selector);
507+
assertCrossRateReverts(100, -2, 100, -256, 5, PythErrors.InvalidInputExpo.selector);
462508

463-
// Invalid Target Exponent Tests
464-
assertCrossRateReverts(100, -2, 100, -2, 1, PythErrors.InvalidTargetExpo.selector);
509+
// Test 4: Basic Tests
510+
assertCrossRateEquals(500, -8, 500, -8, 5, 100000);
511+
assertCrossRateEquals(10_000, -8, 100, -2, 5, 10);
512+
assertCrossRateEquals(10_000, -2, 100, -8, 5, 1_000_000_000_000);
465513

466-
// Different Exponent Tests
467-
assertCrossRateEquals(10_000, -2, 100, -4, -4, 100_000_000);
468-
assertCrossRateEquals(10_000, -2, 10_000, -1, -2, 10);
514+
// Test 5: Different Exponent Tests
515+
assertCrossRateEquals(10_000, -2, 100, -4, 0, 10_000); // 10_000 / 100 = 100 * 10(-2 - -4) = 10_000 with 0 decimals = 10_000
516+
assertCrossRateEquals(10_000, -2, 100, -4, 5, 0); // 10_000 / 100 = 100 * 10(-2 - -4) = 10_000 with 5 decimals = 0
517+
assertCrossRateEquals(10_000, -2, 10_000, -1, 5, 0); // It will truncate to 0
469518
assertCrossRateEquals(10_000, -10, 10_000, -2, 0, 0); // It will truncate to 0
470519

471-
// Exponent Edge Tests
472-
assertCrossRateEquals(10_000, 0, 100, 0, 0, 100);
473-
assertCrossRateEquals(10_000, 0, 100, 0, -255, 100);
520+
// // Exponent Edge Tests
521+
// assertCrossRateEquals(10_000, 0, 100, 0, 0, 100);
522+
// assertCrossRateEquals(10_000, 0, 100, 0, -255, 100);
474523
// assertCrossRateEquals(10_000, 0, 100, -255, -255, 100, -255);
475524
// assertCrossRateEquals(10_000, -255, 100, 0, 0, 100, 0);
476525

target_chains/ethereum/sdk/solidity/PythErrors.sol

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ library PythErrors {
5353
error NegativeInputPrice();
5454
// The Input Exponent is invalid.
5555
error InvalidInputExpo();
56-
// The target exponent is invalid.
57-
error InvalidTargetExpo();
5856
// The combined price is greater than int64.max.
5957
error CombinedPriceOverflow();
58+
// The exponent is greater than 77 or less than -77.
59+
error ExponentOverflow();
6060
}

target_chains/ethereum/sdk/solidity/PythUtils.sol

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ library PythUtils {
1717
/// @dev Function will lose precision if targetDecimals is less than the Pyth price decimals.
1818
/// This method will truncate any digits that cannot be represented by the targetDecimals.
1919
/// e.g. If the price is 0.000123 and the targetDecimals is 2, the result will be 0
20+
/// This function will overflow if the combined exponent(targetDecimals + expo) is greater than 77 or less than -77.
21+
/// This function will also revert if prices combined with the targetDecimals are greater than 10 ** 77 or less than 10 ** -77.
2022
function convertToUint(
2123
int64 price,
2224
int32 expo,
@@ -28,17 +30,26 @@ library PythUtils {
2830
if (expo < -255) {
2931
revert PythErrors.InvalidInputExpo();
3032
}
33+
// Compute the combined exponent as an int256 for safety
34+
int256 combinedExpo = int256(uint256(targetDecimals)) + int256(expo);
3135

32-
int256 combinedExpo = int32(int8(targetDecimals)) + expo;
36+
// Bounds check: prevent overflow/underflow with base 10 exponentiation
37+
// Calculation: 10 ** n <= (2 ** 256) - 1
38+
// n <= log10((2 ** 256) - 1)
39+
// n <= 77.2
40+
if (combinedExpo > 77 || combinedExpo < -77) revert PythErrors.ExponentOverflow();
41+
42+
// price is int64, always >= 0 here
43+
uint256 unsignedPrice = uint256(uint64(price));
3344

3445
if (combinedExpo > 0) {
35-
(bool success, uint256 result) = Math.tryMul(uint64(price), 10 ** uint(uint32(int32(combinedExpo))));
46+
(bool success, uint256 result) = Math.tryMul(unsignedPrice, 10 ** uint(combinedExpo));
3647
if (!success) {
3748
revert PythErrors.CombinedPriceOverflow();
3849
}
3950
return result;
4051
} else {
41-
(bool success, uint256 result) = Math.tryDiv(uint64(price), 10 ** uint(Math.abs(combinedExpo)));
52+
(bool success, uint256 result) = Math.tryDiv(unsignedPrice, 10 ** uint(Math.abs(combinedExpo)));
4253
if (!success) {
4354
revert PythErrors.CombinedPriceOverflow();
4455
}
@@ -51,7 +62,7 @@ library PythUtils {
5162
/// @param expo1 The exponent of the first price
5263
/// @param price2 The second price (c/b)
5364
/// @param expo2 The exponent of the second price
54-
/// @param targetExpo The target exponent of the cross-rate
65+
/// @param targetDecimals The target number of decimals for the cross-rate
5566
/// @return crossRate The cross-rate (a/c)
5667
/// @dev This function will revert if either price is negative or if the exponents are invalid.
5768
/// @dev This function will also revert if the cross-rate is greater than int64.max
@@ -61,7 +72,7 @@ library PythUtils {
6172
int32 expo1,
6273
int64 price2,
6374
int32 expo2,
64-
int32 targetExpo
75+
uint8 targetExpo
6576
) public pure returns (int64 crossRate) {
6677
// Check if the input prices are negative
6778
if (price1 < 0 || price2 < 0) {
@@ -71,37 +82,38 @@ library PythUtils {
7182
if (expo1 > 0 || expo2 > 0 || expo1 < -255 || expo2 < -255) {
7283
revert PythErrors.InvalidInputExpo();
7384
}
74-
// Check if the target exponent is valid and not less than -255
75-
if (targetExpo > 0 || targetExpo < -255) {
76-
revert PythErrors.InvalidTargetExpo();
77-
}
7885

7986
// Calculate the combined price with precision of 36
8087
uint256 fixedPointPrice = Math.mulDiv(uint64(price1), 10 ** PRECISION, uint64(price2));
81-
// TODO: Check for underflow
8288
int32 combinedExpo = expo1 - expo2 - int32(PRECISION);
83-
console.log("fixedPointPrice", fixedPointPrice);
84-
console.log("combinedExpo", combinedExpo);
85-
console.log("targetExpo", targetExpo);
89+
90+
console.log("PythUtils.deriveCrossRate: fixedPointPrice", fixedPointPrice);
91+
console.log("PythUtils.deriveCrossRate: combinedExpo", combinedExpo);
92+
93+
// uint256 crossRateUnchecked = convertToUint(fixedPointPrice, combinedExpo, targetDecimals);
94+
95+
96+
8697
// Convert the price to the target exponent
87-
uint256 combined;
98+
// We can't use the convertToUint function because it accepts int64 and we need to use uint256
99+
uint256 fixedPointPrice;
88100
if (combinedExpo >= targetExpo) {
89101
console.log("combinedExpo >= targetExpo");
90102
// If combinedExpo is greater than or equal to targetExpo, we need to multiply
91-
combined = fixedPointPrice * 10 ** uint32(combinedExpo + targetExpo);
103+
fixedPointPrice = fixedPointPrice * 10 ** uint32(combinedExpo + targetExpo);
92104
} else {
93105
console.log("combinedExpo - targetExpo", combinedExpo - targetExpo);
94106
// If combinedExpo is less than targetExpo, we need to divide
95-
combined = fixedPointPrice / 10 ** uint32(targetExpo - combinedExpo);
107+
fixedPointPrice = fixedPointPrice / 10 ** uint32(targetExpo - combinedExpo);
96108
}
97109

98-
console.log("combined", combined);
110+
console.log("PythUtils.deriveCrossRate: crossRateUnchecked", fixedPointPrice);
99111

100112
// Check if the combined price fits in int64
101-
if (combined > uint256(uint64(type(int64).max))) {
102-
revert();
113+
if (fixedPointPrice > uint256(uint64(type(int64).max))) {
114+
revert PythErrors.CombinedPriceOverflow();
103115
}
104116

105-
return int64(uint64(combined));
117+
return int64(uint64(fixedPointPrice));
106118
}
107119
}

0 commit comments

Comments
 (0)