Skip to content

Commit 074670b

Browse files
authored
Compress by vaa payload (#42)
* refactor: improve type hints * fix: use vaa payload instead of vaa for accumulator compression * chore: bump version
1 parent dfa8ab4 commit 074670b

File tree

2 files changed

+29
-19
lines changed

2 files changed

+29
-19
lines changed

pythclient/price_feeds.py

+28-18
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import base64
22
import binascii
3-
import struct
43
from struct import unpack
5-
from typing import Any, Dict, List, Optional
4+
from typing import List, Literal, Optional, Union, cast
65

76
from Crypto.Hash import keccak
87
from loguru import logger
@@ -164,7 +163,7 @@ def __str__(self):
164163

165164

166165
# Referenced from https://github.com/pyth-network/pyth-crosschain/blob/110caed6be3be7885773d2f6070b143cc13fb0ee/price_service/server/src/encoding.ts#L24
167-
def encode_vaa_for_chain(vaa, vaa_format, buffer=False):
166+
def encode_vaa_for_chain(vaa: str, vaa_format: str, buffer=False) -> Union[bytes, str]:
168167
# check if vaa is already in vaa_format
169168
if isinstance(vaa, str):
170169
if vaa_format == DEFAULT_VAA_ENCODING:
@@ -197,7 +196,7 @@ def encode_vaa_for_chain(vaa, vaa_format, buffer=False):
197196

198197
# Referenced from https://github.com/wormhole-foundation/wormhole/blob/main/sdk/js/src/vaa/wormhole.ts#L26-L56
199198
def parse_vaa(vaa, encoding):
200-
vaa = encode_vaa_for_chain(vaa, encoding, buffer=True)
199+
vaa = cast(bytes, encode_vaa_for_chain(vaa, encoding, buffer=True))
201200

202201
num_signers = vaa[5]
203202
sig_length = 66
@@ -284,7 +283,7 @@ def parse_batch_price_attestation(bytes_):
284283
offset += 2
285284

286285
price_attestations = []
287-
for i in range(batch_len):
286+
for _ in range(batch_len):
288287
price_attestations.append(
289288
parse_price_attestation(bytes_[offset : offset + attestation_size])
290289
)
@@ -401,13 +400,13 @@ def is_accumulator_update(vaa, encoding=DEFAULT_VAA_ENCODING) -> bool:
401400
Returns:
402401
bool: True if the VAA is an accumulator update, False otherwise.
403402
"""
404-
if encode_vaa_for_chain(vaa, encoding, buffer=True)[:4].hex() == ACCUMULATOR_MAGIC:
403+
if cast(bytes, encode_vaa_for_chain(vaa, encoding, buffer=True))[:4].hex() == ACCUMULATOR_MAGIC:
405404
return True
406405
return False
407406

408407

409408
# Referenced from https://github.com/pyth-network/pyth-crosschain/blob/110caed6be3be7885773d2f6070b143cc13fb0ee/price_service/server/src/rest.ts#L139
410-
def vaa_to_price_infos(vaa, encoding=DEFAULT_VAA_ENCODING) -> List[PriceInfo]:
409+
def vaa_to_price_infos(vaa, encoding: Literal["hex", "base64"] = DEFAULT_VAA_ENCODING) -> Optional[List[PriceInfo]]:
411410
if is_accumulator_update(vaa, encoding):
412411
return extract_price_info_from_accumulator_update(vaa, encoding)
413412
parsed_vaa = parse_vaa(vaa, encoding)
@@ -425,7 +424,7 @@ def vaa_to_price_infos(vaa, encoding=DEFAULT_VAA_ENCODING) -> List[PriceInfo]:
425424
return price_infos
426425

427426

428-
def vaa_to_price_info(id, vaa, encoding=DEFAULT_VAA_ENCODING) -> Optional[PriceInfo]:
427+
def vaa_to_price_info(id: str, vaa: str, encoding: Literal["hex", "base64"] = DEFAULT_VAA_ENCODING) -> Optional[PriceInfo]:
429428
"""
430429
This function retrieves a specific PriceInfo object from a given VAA.
431430
@@ -502,14 +501,21 @@ def price_attestation_to_price_feed(price_attestation):
502501

503502
# Referenced from https://github.com/pyth-network/pyth-crosschain/blob/1a00598334e52fc5faf967eb1170d7fc23ad828b/price_service/server/src/rest.ts#L137
504503
def extract_price_info_from_accumulator_update(
505-
update_data, encoding
506-
) -> Optional[Dict[str, Any]]:
504+
update_data: str,
505+
encoding: Literal["hex", "base64"]
506+
) -> Optional[List[PriceInfo]]:
507507
parsed_update_data = parse_accumulator_update(update_data, encoding)
508+
if parsed_update_data is None:
509+
return None
510+
508511
vaa_buffer = parsed_update_data.vaa
509512
if encoding == "hex":
510513
vaa_str = vaa_buffer.hex()
511514
elif encoding == "base64":
512515
vaa_str = base64.b64encode(vaa_buffer).decode("ascii")
516+
else:
517+
raise ValueError(f"Invalid encoding: {encoding}")
518+
513519
parsed_vaa = parse_vaa(vaa_str, encoding)
514520
price_infos = []
515521
for update in parsed_update_data.updates:
@@ -581,7 +587,6 @@ def extract_price_info_from_accumulator_update(
581587

582588
return price_infos
583589

584-
585590
def compress_accumulator_update(update_data_list, encoding) -> List[str]:
586591
"""
587592
This function compresses a list of accumulator update data by combining those with the same VAA.
@@ -593,17 +598,21 @@ def compress_accumulator_update(update_data_list, encoding) -> List[str]:
593598
594599
Returns:
595600
List[str]: A list of serialized accumulator update data. Each item in the list is a hexadecimal string representing
596-
an accumulator update data. The updates with the same VAA are combined and split into chunks of 255 updates each.
601+
an accumulator update data. The updates with the same VAA payload are combined and split into chunks of 255 updates each.
597602
"""
598603
parsed_data_dict = {} # Use a dictionary for O(1) lookup
599604
# Combine the ones with the same VAA to a list
600605
for update_data in update_data_list:
601606
parsed_update_data = parse_accumulator_update(update_data, encoding)
602-
vaa = parsed_update_data.vaa
603607

604-
if vaa not in parsed_data_dict:
605-
parsed_data_dict[vaa] = []
606-
parsed_data_dict[vaa].append(parsed_update_data)
608+
if parsed_update_data is None:
609+
raise ValueError(f"Invalid accumulator update data: {update_data}")
610+
611+
payload = parse_vaa(parsed_update_data.vaa.hex(), "hex")["payload"]
612+
613+
if payload not in parsed_data_dict:
614+
parsed_data_dict[payload] = []
615+
parsed_data_dict[payload].append(parsed_update_data)
607616
parsed_data_list = list(parsed_data_dict.values())
608617

609618
# Combines accumulator update data with the same VAA into a single dictionary
@@ -698,7 +707,7 @@ def serialize_accumulator_update(data, encoding):
698707
return base64.b64encode(serialized_data).decode("ascii")
699708

700709

701-
def parse_accumulator_update(update_data, encoding):
710+
def parse_accumulator_update(update_data: str, encoding: str) -> Optional[AccumulatorUpdate]:
702711
"""
703712
This function parses an accumulator update data.
704713
@@ -724,7 +733,8 @@ def parse_accumulator_update(update_data, encoding):
724733
725734
If the update type is not 0, the function logs an info message and returns None.
726735
"""
727-
encoded_update_data = encode_vaa_for_chain(update_data, encoding, buffer=True)
736+
encoded_update_data = cast(bytes, encode_vaa_for_chain(update_data, encoding, buffer=True))
737+
728738
offset = 0
729739
magic = encoded_update_data[offset : offset + 4]
730740
offset += 4

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
setup(
99
name='pythclient',
10-
version='0.1.14',
10+
version='0.1.15',
1111
packages=['pythclient'],
1212
author='Pyth Developers',
1313
author_email='[email protected]',

0 commit comments

Comments
 (0)