1
1
import base64
2
2
import binascii
3
- import struct
4
3
from struct import unpack
5
- from typing import Any , Dict , List , Optional
4
+ from typing import List , Literal , Optional , Union , cast
6
5
7
6
from Crypto .Hash import keccak
8
7
from loguru import logger
@@ -164,7 +163,7 @@ def __str__(self):
164
163
165
164
166
165
# Referenced from https://github.com/pyth-network/pyth-crosschain/blob/110caed6be3be7885773d2f6070b143cc13fb0ee/price_service/server/src/encoding.ts#L24
167
- def encode_vaa_for_chain (vaa , vaa_format , buffer = False ):
166
+ def encode_vaa_for_chain (vaa : str , vaa_format : str , buffer = False ) -> Union [ bytes , str ] :
168
167
# check if vaa is already in vaa_format
169
168
if isinstance (vaa , str ):
170
169
if vaa_format == DEFAULT_VAA_ENCODING :
@@ -197,7 +196,7 @@ def encode_vaa_for_chain(vaa, vaa_format, buffer=False):
197
196
198
197
# Referenced from https://github.com/wormhole-foundation/wormhole/blob/main/sdk/js/src/vaa/wormhole.ts#L26-L56
199
198
def parse_vaa (vaa , encoding ):
200
- vaa = encode_vaa_for_chain (vaa , encoding , buffer = True )
199
+ vaa = cast ( bytes , encode_vaa_for_chain (vaa , encoding , buffer = True ) )
201
200
202
201
num_signers = vaa [5 ]
203
202
sig_length = 66
@@ -284,7 +283,7 @@ def parse_batch_price_attestation(bytes_):
284
283
offset += 2
285
284
286
285
price_attestations = []
287
- for i in range (batch_len ):
286
+ for _ in range (batch_len ):
288
287
price_attestations .append (
289
288
parse_price_attestation (bytes_ [offset : offset + attestation_size ])
290
289
)
@@ -401,13 +400,13 @@ def is_accumulator_update(vaa, encoding=DEFAULT_VAA_ENCODING) -> bool:
401
400
Returns:
402
401
bool: True if the VAA is an accumulator update, False otherwise.
403
402
"""
404
- if encode_vaa_for_chain (vaa , encoding , buffer = True )[:4 ].hex () == ACCUMULATOR_MAGIC :
403
+ if cast ( bytes , encode_vaa_for_chain (vaa , encoding , buffer = True ) )[:4 ].hex () == ACCUMULATOR_MAGIC :
405
404
return True
406
405
return False
407
406
408
407
409
408
# Referenced from https://github.com/pyth-network/pyth-crosschain/blob/110caed6be3be7885773d2f6070b143cc13fb0ee/price_service/server/src/rest.ts#L139
410
- def vaa_to_price_infos (vaa , encoding = DEFAULT_VAA_ENCODING ) -> List [PriceInfo ]:
409
+ def vaa_to_price_infos (vaa , encoding : Literal [ "hex" , "base64" ] = DEFAULT_VAA_ENCODING ) -> Optional [ List [PriceInfo ] ]:
411
410
if is_accumulator_update (vaa , encoding ):
412
411
return extract_price_info_from_accumulator_update (vaa , encoding )
413
412
parsed_vaa = parse_vaa (vaa , encoding )
@@ -425,7 +424,7 @@ def vaa_to_price_infos(vaa, encoding=DEFAULT_VAA_ENCODING) -> List[PriceInfo]:
425
424
return price_infos
426
425
427
426
428
- def vaa_to_price_info (id , vaa , encoding = DEFAULT_VAA_ENCODING ) -> Optional [PriceInfo ]:
427
+ def vaa_to_price_info (id : str , vaa : str , encoding : Literal [ "hex" , "base64" ] = DEFAULT_VAA_ENCODING ) -> Optional [PriceInfo ]:
429
428
"""
430
429
This function retrieves a specific PriceInfo object from a given VAA.
431
430
@@ -502,14 +501,21 @@ def price_attestation_to_price_feed(price_attestation):
502
501
503
502
# Referenced from https://github.com/pyth-network/pyth-crosschain/blob/1a00598334e52fc5faf967eb1170d7fc23ad828b/price_service/server/src/rest.ts#L137
504
503
def extract_price_info_from_accumulator_update (
505
- update_data , encoding
506
- ) -> Optional [Dict [str , Any ]]:
504
+ update_data : str ,
505
+ encoding : Literal ["hex" , "base64" ]
506
+ ) -> Optional [List [PriceInfo ]]:
507
507
parsed_update_data = parse_accumulator_update (update_data , encoding )
508
+ if parsed_update_data is None :
509
+ return None
510
+
508
511
vaa_buffer = parsed_update_data .vaa
509
512
if encoding == "hex" :
510
513
vaa_str = vaa_buffer .hex ()
511
514
elif encoding == "base64" :
512
515
vaa_str = base64 .b64encode (vaa_buffer ).decode ("ascii" )
516
+ else :
517
+ raise ValueError (f"Invalid encoding: { encoding } " )
518
+
513
519
parsed_vaa = parse_vaa (vaa_str , encoding )
514
520
price_infos = []
515
521
for update in parsed_update_data .updates :
@@ -581,7 +587,6 @@ def extract_price_info_from_accumulator_update(
581
587
582
588
return price_infos
583
589
584
-
585
590
def compress_accumulator_update (update_data_list , encoding ) -> List [str ]:
586
591
"""
587
592
This function compresses a list of accumulator update data by combining those with the same VAA.
@@ -593,17 +598,21 @@ def compress_accumulator_update(update_data_list, encoding) -> List[str]:
593
598
594
599
Returns:
595
600
List[str]: A list of serialized accumulator update data. Each item in the list is a hexadecimal string representing
596
- an accumulator update data. The updates with the same VAA are combined and split into chunks of 255 updates each.
601
+ an accumulator update data. The updates with the same VAA payload are combined and split into chunks of 255 updates each.
597
602
"""
598
603
parsed_data_dict = {} # Use a dictionary for O(1) lookup
599
604
# Combine the ones with the same VAA to a list
600
605
for update_data in update_data_list :
601
606
parsed_update_data = parse_accumulator_update (update_data , encoding )
602
- vaa = parsed_update_data .vaa
603
607
604
- if vaa not in parsed_data_dict :
605
- parsed_data_dict [vaa ] = []
606
- parsed_data_dict [vaa ].append (parsed_update_data )
608
+ if parsed_update_data is None :
609
+ raise ValueError (f"Invalid accumulator update data: { update_data } " )
610
+
611
+ payload = parse_vaa (parsed_update_data .vaa .hex (), "hex" )["payload" ]
612
+
613
+ if payload not in parsed_data_dict :
614
+ parsed_data_dict [payload ] = []
615
+ parsed_data_dict [payload ].append (parsed_update_data )
607
616
parsed_data_list = list (parsed_data_dict .values ())
608
617
609
618
# Combines accumulator update data with the same VAA into a single dictionary
@@ -698,7 +707,7 @@ def serialize_accumulator_update(data, encoding):
698
707
return base64 .b64encode (serialized_data ).decode ("ascii" )
699
708
700
709
701
- def parse_accumulator_update (update_data , encoding ) :
710
+ def parse_accumulator_update (update_data : str , encoding : str ) -> Optional [ AccumulatorUpdate ] :
702
711
"""
703
712
This function parses an accumulator update data.
704
713
@@ -724,7 +733,8 @@ def parse_accumulator_update(update_data, encoding):
724
733
725
734
If the update type is not 0, the function logs an info message and returns None.
726
735
"""
727
- encoded_update_data = encode_vaa_for_chain (update_data , encoding , buffer = True )
736
+ encoded_update_data = cast (bytes , encode_vaa_for_chain (update_data , encoding , buffer = True ))
737
+
728
738
offset = 0
729
739
magic = encoded_update_data [offset : offset + 4 ]
730
740
offset += 4
0 commit comments