9
9
import logging
10
10
import ssl
11
11
import time
12
+ import warnings
13
+ from unittest .mock import AsyncMock
12
14
from hashlib import blake2b
13
15
from typing import (
14
16
Optional ,
@@ -530,15 +532,31 @@ def __init__(
530
532
self ._exit_task = None
531
533
self ._open_subscriptions = 0
532
534
self ._options = options if options else {}
533
- self .last_received = time .time ()
535
+ try :
536
+ now = asyncio .get_running_loop ().time ()
537
+ except RuntimeError :
538
+ warnings .warn (
539
+ "You are instantiating the AsyncSubstrateInterface Websocket outside of an event loop. "
540
+ "Verify this is intended."
541
+ )
542
+ now = asyncio .new_event_loop ().time ()
543
+ self .last_received = now
544
+ self .last_sent = now
534
545
535
546
async def __aenter__ (self ):
536
547
async with self ._lock :
537
548
self ._in_use += 1
538
549
await self .connect ()
539
550
return self
540
551
552
+ @staticmethod
553
+ async def loop_time () -> float :
554
+ return asyncio .get_running_loop ().time ()
555
+
541
556
async def connect (self , force = False ):
557
+ now = await self .loop_time ()
558
+ self .last_received = now
559
+ self .last_sent = now
542
560
if self ._exit_task :
543
561
self ._exit_task .cancel ()
544
562
if not self ._initialized or force :
@@ -594,7 +612,7 @@ async def _recv(self) -> None:
594
612
try :
595
613
# TODO consider wrapping this in asyncio.wait_for and use that for the timeout logic
596
614
response = json .loads (await self .ws .recv (decode = False ))
597
- self .last_received = time . time ()
615
+ self .last_received = await self . loop_time ()
598
616
async with self ._lock :
599
617
# note that these 'subscriptions' are all waiting sent messages which have not received
600
618
# responses, and thus are not the same as RPC 'subscriptions', which are unique
@@ -630,12 +648,12 @@ async def send(self, payload: dict) -> int:
630
648
Returns:
631
649
id: the internal ID of the request (incremented int)
632
650
"""
633
- # async with self._lock:
634
651
original_id = get_next_id ()
635
652
# self._open_subscriptions += 1
636
653
await self .max_subscriptions .acquire ()
637
654
try :
638
655
await self .ws .send (json .dumps ({** payload , ** {"id" : original_id }}))
656
+ self .last_sent = await self .loop_time ()
639
657
return original_id
640
658
except (ConnectionClosed , ssl .SSLError , EOFError ):
641
659
async with self ._lock :
@@ -697,13 +715,16 @@ def __init__(
697
715
self .chain_endpoint = url
698
716
self .url = url
699
717
self ._chain = chain_name
700
- self .ws = Websocket (
701
- url ,
702
- options = {
703
- "max_size" : self .ws_max_size ,
704
- "write_limit" : 2 ** 16 ,
705
- },
706
- )
718
+ if not _mock :
719
+ self .ws = Websocket (
720
+ url ,
721
+ options = {
722
+ "max_size" : self .ws_max_size ,
723
+ "write_limit" : 2 ** 16 ,
724
+ },
725
+ )
726
+ else :
727
+ self .ws = AsyncMock (spec = Websocket )
707
728
self ._lock = asyncio .Lock ()
708
729
self .config = {
709
730
"use_remote_preset" : use_remote_preset ,
@@ -726,9 +747,11 @@ def __init__(
726
747
self ._initializing = False
727
748
self .registry_type_map = {}
728
749
self .type_id_to_name = {}
750
+ self ._mock = _mock
729
751
730
752
async def __aenter__ (self ):
731
- await self .initialize ()
753
+ if not self ._mock :
754
+ await self .initialize ()
732
755
return self
733
756
734
757
async def initialize (self ):
@@ -2120,7 +2143,11 @@ async def _make_rpc_request(
2120
2143
2121
2144
if request_manager .is_complete :
2122
2145
break
2123
- if time .time () - self .ws .last_received >= self .retry_timeout :
2146
+ if (
2147
+ (current_time := await self .ws .loop_time ()) - self .ws .last_received
2148
+ >= self .retry_timeout
2149
+ and current_time - self .ws .last_sent >= self .retry_timeout
2150
+ ):
2124
2151
if attempt >= self .max_retries :
2125
2152
logger .warning (
2126
2153
f"Timed out waiting for RPC requests { attempt } times. Exiting."
0 commit comments