Skip to content

python ss58 conversion #143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: feat/thewhaleking/distribute-runtime
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion async_substrate_interface/async_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
_determine_if_old_runtime_call,
_bt_decode_to_dict_or_list,
legacy_scale_decode,
convert_account_ids,
)
from async_substrate_interface.utils.storage import StorageKey
from async_substrate_interface.type_registry import _TYPE_REGISTRY
Expand Down Expand Up @@ -733,6 +734,7 @@ def __init__(
_mock: bool = False,
_log_raw_websockets: bool = False,
ws_shutdown_timer: float = 5.0,
decode_ss58: bool = False,
):
"""
The asyncio-compatible version of the subtensor interface commands we use in bittensor. It is important to
Expand All @@ -752,10 +754,15 @@ def __init__(
_mock: whether to use mock version of the subtensor interface
_log_raw_websockets: whether to log raw websocket requests during RPC requests
ws_shutdown_timer: how long after the last connection your websocket should close
decode_ss58: Whether to decode AccountIds to SS58 or leave them in raw bytes tuples.

"""
super().__init__(
type_registry, type_registry_preset, use_remote_preset, ss58_format
type_registry,
type_registry_preset,
use_remote_preset,
ss58_format,
decode_ss58,
)
self.max_retries = max_retries
self.retry_timeout = retry_timeout
Expand Down Expand Up @@ -816,6 +823,7 @@ async def initialize(self):

if ss58_prefix_constant:
self.ss58_format = ss58_prefix_constant.value
runtime.ss58_format = ss58_prefix_constant.value
self.initialized = True
self._initializing = False

Expand Down Expand Up @@ -994,6 +1002,15 @@ async def decode_scale(
runtime = await self.init_runtime(block_hash=block_hash)
if runtime.metadata_v15 is not None or force_legacy is True:
obj = decode_by_type_string(type_string, runtime.registry, scale_bytes)
if self.decode_ss58:
try:
type_str_int = int(type_string.split("::")[1])
decoded_type_str = runtime.type_id_to_name[type_str_int]
obj = convert_account_ids(
obj, decoded_type_str, runtime.ss58_format
)
except (ValueError, KeyError):
pass
Comment on lines +1005 to +1013
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hacky logic :D

else:
obj = legacy_scale_decode(type_string, scale_bytes, runtime)
if return_scale_obj:
Expand Down Expand Up @@ -1105,6 +1122,7 @@ async def _get_runtime_for_version(
metadata_v15=metadata_v15,
runtime_info=runtime_info,
registry=registry,
ss58_format=self.ss58_format,
)
self.runtime_cache.add_item(
block=block_number,
Expand Down Expand Up @@ -3471,6 +3489,7 @@ async def query_map(
value_type,
key_hashers,
ignore_decoding_errors,
self.decode_ss58,
)
return AsyncQueryMapResult(
records=result,
Expand Down
21 changes: 20 additions & 1 deletion async_substrate_interface/sync_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
_bt_decode_to_dict_or_list,
decode_query_map,
legacy_scale_decode,
convert_account_ids,
)
from async_substrate_interface.utils.storage import StorageKey
from async_substrate_interface.type_registry import _TYPE_REGISTRY
Expand Down Expand Up @@ -486,6 +487,7 @@ def __init__(
retry_timeout: float = 60.0,
_mock: bool = False,
_log_raw_websockets: bool = False,
decode_ss58: bool = False,
):
"""
The sync compatible version of the subtensor interface commands we use in bittensor. Use this instance only
Expand All @@ -503,10 +505,15 @@ def __init__(
retry_timeout: how to long wait since the last ping to retry the RPC request
_mock: whether to use mock version of the subtensor interface
_log_raw_websockets: whether to log raw websocket requests during RPC requests
decode_ss58: Whether to decode AccountIds to SS58 or leave them in raw bytes tuples.

"""
super().__init__(
type_registry, type_registry_preset, use_remote_preset, ss58_format
type_registry,
type_registry_preset,
use_remote_preset,
ss58_format,
decode_ss58,
)
self.max_retries = max_retries
self.retry_timeout = retry_timeout
Expand Down Expand Up @@ -560,6 +567,7 @@ def initialize(self):
)
if ss58_prefix_constant:
self.ss58_format = ss58_prefix_constant.value
self.runtime.ss58_format = ss58_prefix_constant.value
self.initialized = True

def __exit__(self, exc_type, exc_val, exc_tb):
Expand Down Expand Up @@ -693,6 +701,15 @@ def decode_scale(
obj = decode_by_type_string(
type_string, self.runtime.registry, scale_bytes
)
if self.decode_ss58:
try:
type_str_int = int(type_string.split("::")[1])
decoded_type_str = self.runtime.type_id_to_name[type_str_int]
obj = convert_account_ids(
obj, decoded_type_str, self.ss58_format
)
except (ValueError, KeyError):
pass
else:
obj = legacy_scale_decode(type_string, scale_bytes, self.runtime)
if return_scale_obj:
Expand Down Expand Up @@ -834,6 +851,7 @@ def get_runtime_for_version(
metadata_v15=metadata_v15,
runtime_info=runtime_info,
registry=registry,
ss58_format=self.ss58_format,
)
self.runtime_cache.add_item(
block=block_number,
Expand Down Expand Up @@ -3009,6 +3027,7 @@ def query_map(
value_type,
key_hashers,
ignore_decoding_errors,
self.decode_ss58,
)
return QueryMapResult(
records=result,
Expand Down
4 changes: 4 additions & 0 deletions async_substrate_interface/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def __init__(
metadata_v15=None,
runtime_info=None,
registry=None,
ss58_format=SS58_FORMAT,
):
self.ss58_format = ss58_format
self.config = {}
self.chain = chain
self.type_registry = type_registry
Expand Down Expand Up @@ -551,8 +553,10 @@ def __init__(
type_registry_preset: Optional[str] = None,
use_remote_preset: bool = False,
ss58_format: Optional[int] = None,
decode_ss58: bool = False,
):
# We load a very basic RuntimeConfigurationObject that is only used for the initial metadata decoding
self.decode_ss58 = decode_ss58
self.runtime_config = RuntimeConfigurationObject(ss58_format=ss58_format)
self.ss58_format = ss58_format
self.runtime_config.update_type_registry(load_type_registry_preset(name="core"))
Expand Down
93 changes: 87 additions & 6 deletions async_substrate_interface/utils/decoding.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Union, TYPE_CHECKING
from typing import Union, TYPE_CHECKING, Any

from bt_decode import AxonInfo, PrometheusInfo, decode_list
from scalecodec import ScaleBytes
from scalecodec import ScaleBytes, ss58_encode

from async_substrate_interface.utils import hex_to_bytes
from async_substrate_interface.types import ScaleObj
Expand Down Expand Up @@ -81,6 +81,7 @@ def decode_query_map(
value_type,
key_hashers,
ignore_decoding_errors,
decode_ss58: bool = False,
):
def concat_hash_len(key_hasher: str) -> int:
"""
Expand Down Expand Up @@ -120,12 +121,19 @@ def concat_hash_len(key_hasher: str) -> int:
)
middl_index = len(all_decoded) // 2
decoded_keys = all_decoded[:middl_index]
decoded_values = [ScaleObj(x) for x in all_decoded[middl_index:]]
for dk, dv in zip(decoded_keys, decoded_values):
decoded_values = all_decoded[middl_index:]
for (kts, vts), (dk, dv) in zip(
zip(pre_decoded_key_types, pre_decoded_value_types),
zip(decoded_keys, decoded_values),
):
Comment on lines +125 to +128
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for kts, vts, dk, dv in zip(
    pre_decoded_key_types,
    pre_decoded_value_types,
    decoded_keys,
    decoded_values,
):

wouldn't it be the same logic with a more elegant look?
looks like this is optimization too. no?

try:
# strip key_hashers to use as item key
if len(param_types) - len(params) == 1:
item_key = dk[1]
if decode_ss58:
if kts[kts.index(", ") + 2 : kts.index(")")] == "scale_info::0":
item_key = ss58_encode(bytes(item_key[0]), runtime.ss58_format)

else:
item_key = tuple(
dk[key + 1] for key in range(len(params), len(param_types) + 1, 2)
Expand All @@ -135,9 +143,17 @@ def concat_hash_len(key_hasher: str) -> int:
if not ignore_decoding_errors:
raise
item_key = None

item_value = dv
result.append([item_key, item_value])
if decode_ss58:
try:
value_type_str_int = int(vts.split("::")[1])
decoded_type_str = runtime.type_id_to_name[value_type_str_int]
item_value = convert_account_ids(
dv, decoded_type_str, runtime.ss58_format
)
except (ValueError, KeyError):
pass
result.append([item_key, ScaleObj(item_value)])
return result


Expand All @@ -154,3 +170,68 @@ def legacy_scale_decode(
obj.decode(check_remaining=runtime.config.get("strict_scale_decode"))

return obj.value


def is_accountid32(value: Any) -> bool:
return (
isinstance(value, tuple)
and len(value) == 32
and all(isinstance(b, int) and 0 <= b <= 255 for b in value)
)


def convert_account_ids(value: Any, type_str: str, ss58_format=42) -> Any:
if "AccountId32" not in type_str:
return value

# Option<T>
if type_str.startswith("Option<") and value is not None:
inner_type = type_str[7:-1]
return convert_account_ids(value, inner_type)
# Vec<T>
if type_str.startswith("Vec<") and isinstance(value, (list, tuple)):
inner_type = type_str[4:-1]
return tuple(convert_account_ids(v, inner_type) for v in value)

# Vec<Vec<T>>
if type_str.startswith("Vec<Vec<") and isinstance(value, (list, tuple)):
inner_type = type_str[8:-2]
return tuple(
tuple(convert_account_ids(v2, inner_type) for v2 in v1) for v1 in value
)

# Tuple
if type_str.startswith("(") and isinstance(value, (list, tuple)):
inner_parts = split_tuple_type(type_str)
return tuple(convert_account_ids(v, t) for v, t in zip(value, inner_parts))

# AccountId32
if type_str == "AccountId32" and is_accountid32(value[0]):
return ss58_encode(bytes(value[0]), ss58_format=ss58_format)

# Fallback
return value


def split_tuple_type(type_str: str) -> list[str]:
"""
Splits a type string like '(AccountId32, Vec<StakeInfo>)' into ['AccountId32', 'Vec<StakeInfo>']
Handles nested generics.
"""
s = type_str[1:-1]
parts = []
depth = 0
current = ""
for char in s:
if char == "," and depth == 0:
parts.append(current.strip())
current = ""
else:
if char == "<":
depth += 1
elif char == ">":
depth -= 1
current += char
if current:
parts.append(current.strip())
return parts
2 changes: 2 additions & 0 deletions tests/helpers/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@
)

ARCHIVE_ENTRYPOINT = "wss://archive.chain.opentensor.ai:443"

LATENT_LITE_ENTRYPOINT = "wss://lite.sub.latent.to:443"
45 changes: 43 additions & 2 deletions tests/unit_tests/asyncio_/test_substrate_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from unittest.mock import AsyncMock, MagicMock, ANY

import pytest
from scalecodec import ss58_encode
from websockets.exceptions import InvalidURI

from async_substrate_interface.async_substrate import AsyncSubstrateInterface
from async_substrate_interface.types import ScaleObj
from tests.helpers.settings import ARCHIVE_ENTRYPOINT
from tests.helpers.settings import ARCHIVE_ENTRYPOINT, LATENT_LITE_ENTRYPOINT


@pytest.mark.asyncio
Expand Down Expand Up @@ -100,7 +101,7 @@ async def test_runtime_call(monkeypatch):
@pytest.mark.asyncio
async def test_websocket_shutdown_timer():
# using default ws shutdown timer of 5.0 seconds
async with AsyncSubstrateInterface("wss://lite.sub.latent.to:443") as substrate:
async with AsyncSubstrateInterface(LATENT_LITE_ENTRYPOINT) as substrate:
await substrate.get_chain_head()
await asyncio.sleep(6)
assert (
Expand Down Expand Up @@ -141,3 +142,43 @@ async def test_legacy_decoding():
block_hash=block_hash,
)
assert timestamp.value == 1716358476004


@pytest.mark.asyncio
async def test_ss58_conversion():
async with AsyncSubstrateInterface(
LATENT_LITE_ENTRYPOINT, ss58_format=42, decode_ss58=False
) as substrate:
block_hash = await substrate.get_chain_finalised_head()
qm = await substrate.query_map(
"SubtensorModule",
"OwnedHotkeys",
block_hash=block_hash,
)
# only do the first page, bc otherwise this will be massive
for key, value in qm.records:
assert isinstance(key, tuple)
assert isinstance(value, ScaleObj)
assert isinstance(value.value, list)
assert len(key) == 1
for key_tuple in value.value:
assert len(key_tuple[0]) == 32
random_key = key_tuple[0]

ss58_of_key = ss58_encode(bytes(random_key), substrate.ss58_format)
assert isinstance(ss58_of_key, str)

substrate.decode_ss58 = True # change to decoding True

qm = await substrate.query_map(
"SubtensorModule",
"OwnedHotkeys",
block_hash=block_hash,
)
for key, value in qm.records:
assert isinstance(key, str)
assert isinstance(value, ScaleObj)
assert isinstance(value.value, list)
if len(value.value) > 0:
for decoded_key in value.value:
assert isinstance(decoded_key, str)
Loading