Skip to content

Commit 5733f74

Browse files
authored
refactor: remove subscription and use internal interval (#31)
Subscription is really a bad way to interact with the agent. This PR changes the messaging structure to send all prices as a batch update in an configured interval to minimize agent <> publisher interactions. The pythd has undergone a rewrite because there's no async jsonrpc client library in python which supports batch requests.
1 parent 0f95b2d commit 5733f74

File tree

8 files changed

+224
-211
lines changed

8 files changed

+224
-211
lines changed

config/config.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# the configuration for the chosen engine as described below.
55
provider_engine = 'pyth_replicator'
66

7-
product_update_interval_secs = 10
7+
price_update_interval_secs = 1.0
8+
product_update_interval_secs = 60
89
health_check_port = 8000
910

1011
# The health check will return a failure status if no price data has been published within the specified time frame.
@@ -22,7 +23,7 @@ endpoint = 'ws://127.0.0.1:8910'
2223
# coin_gecko_id = 'bitcoin'
2324

2425
[publisher.pyth_replicator]
25-
http_endpoint = 'https://pythnet.rpcpool.com'
26-
ws_endpoint = 'wss://pythnet.rpcpool.com'
26+
http_endpoint = 'https://api2.pythnet.pyth.network'
27+
ws_endpoint = 'wss://api2.pythnet.pyth.network'
2728
first_mapping = 'AHtgzX45WTKfkPG53L6WYhGEXwQkN1BVknET3sVsLL8J'
2829
program_key = 'FsJ3A3u2vn5cTVofAjvy6y5kwABJAqYWpe4975bi2epH'

example_publisher/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
_DEFAULT_CONFIG_PATH = os.path.join("config", "config.toml")
1515

1616

17-
log_level = logging._nameToLevel[os.environ.get("LOG_LEVEL", "DEBUG").upper()]
17+
log_level = logging._nameToLevel[os.environ.get("LOG_LEVEL", "INFO").upper()]
1818
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(log_level))
1919

2020
log = structlog.get_logger()

example_publisher/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class Config:
4949
pythd: Pythd
5050
health_check_port: int
5151
health_check_threshold_secs: int
52+
price_update_interval_secs: float = ts.option(default=1.0)
5253
product_update_interval_secs: int = ts.option(default=60)
5354
coin_gecko: Optional[CoinGeckoConfig] = ts.option(default=None)
5455
pyth_replicator: Optional[PythReplicatorConfig] = ts.option(default=None)

example_publisher/providers/pyth_replicator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ async def _update_loop(self) -> None:
102102
update.timestamp,
103103
)
104104

105-
log.info(
105+
log.debug(
106106
"Received a price update", symbol=symbol, price=self._prices[symbol]
107107
)
108108

@@ -118,7 +118,7 @@ async def _update_accounts_loop(self) -> None:
118118

119119
await asyncio.sleep(self._config.account_update_interval_secs)
120120

121-
def upd_products(self, *args) -> None:
121+
def upd_products(self, product_symbols: List[Symbol]) -> None:
122122
# This provider stores all the possible feeds and
123123
# does not care about the desired products as knowing
124124
# them does not improve the performance of the replicator

example_publisher/publisher.py

Lines changed: 47 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from example_publisher.providers.coin_gecko import CoinGecko
88
from example_publisher.config import Config
99
from example_publisher.providers.pyth_replicator import PythReplicator
10-
from example_publisher.pythd import Pythd, SubscriptionId
10+
from example_publisher.pythd import PriceUpdate, Pythd, SubscriptionId
1111

1212

1313
log = get_logger()
@@ -50,7 +50,6 @@ def __init__(self, config: Config) -> None:
5050

5151
self.pythd: Pythd = Pythd(
5252
address=config.pythd.endpoint,
53-
on_notify_price_sched=self.on_notify_price_sched,
5453
)
5554
self.subscriptions: Dict[SubscriptionId, Product] = {}
5655
self.products: List[Product] = []
@@ -66,18 +65,17 @@ def is_healthy(self) -> bool:
6665
async def start(self):
6766
await self.pythd.connect()
6867

69-
self._product_update_task = asyncio.create_task(
70-
self._start_product_update_loop()
71-
)
72-
73-
async def _start_product_update_loop(self):
7468
await self._upd_products()
69+
70+
self._product_update_task = asyncio.create_task(self._product_update_loop())
71+
self._price_update_task = asyncio.create_task(self._price_update_loop())
72+
7573
self.provider.start()
7674

75+
async def _product_update_loop(self):
7776
while True:
78-
await self._upd_products()
79-
await self._subscribe_notify_price_sched()
8077
await asyncio.sleep(self.config.product_update_interval_secs)
78+
await self._upd_products()
8179

8280
async def _upd_products(self):
8381
log.debug("fetching product accounts from Pythd")
@@ -114,58 +112,51 @@ async def _upd_products(self):
114112

115113
self.provider.upd_products([product.symbol for product in self.products])
116114

117-
async def _subscribe_notify_price_sched(self):
118-
# Subscribe to Pythd's notify_price_sched for each product that
119-
# is not subscribed yet. Unfortunately there is no way to unsubscribe
120-
# to the prices that are no longer available.
121-
log.debug("subscribing to notify_price_sched")
122-
123-
subscriptions = {}
124-
for product in self.products:
125-
if not product.subscription_id:
126-
subscription_id = await self.pythd.subscribe_price_sched(
127-
product.price_account
115+
async def _price_update_loop(self):
116+
while True:
117+
price_updates = []
118+
for product in self.products:
119+
price = self.provider.latest_price(product.symbol)
120+
if not price:
121+
log.info("latest price not available", symbol=product.symbol)
122+
continue
123+
124+
scaled_price = self.apply_exponent(price.price, product.exponent)
125+
scaled_conf = self.apply_exponent(price.conf, product.exponent)
126+
127+
price_updates.append(
128+
PriceUpdate(
129+
account=product.price_account,
130+
price=scaled_price,
131+
conf=scaled_conf,
132+
status=TRADING,
133+
)
134+
)
135+
log.debug(
136+
"sending price update",
137+
symbol=product.symbol,
138+
price_account=product.price_account,
139+
price=price.price,
140+
conf=price.conf,
141+
scaled_price=scaled_price,
142+
scaled_conf=scaled_conf,
128143
)
129-
product.subscription_id = subscription_id
130-
131-
subscriptions[product.subscription_id] = product
132-
133-
self.subscriptions = subscriptions
134-
135-
async def on_notify_price_sched(self, subscription: int) -> None:
136144

137-
log.debug("received notify_price_sched", subscription=subscription)
138-
if subscription not in self.subscriptions:
139-
return
145+
self.last_successful_update = (
146+
price.timestamp
147+
if self.last_successful_update is None
148+
else max(self.last_successful_update, price.timestamp)
149+
)
140150

141-
# Look up the current price and confidence interval of the product
142-
product = self.subscriptions[subscription]
143-
price = self.provider.latest_price(product.symbol)
144-
if not price:
145-
log.info("latest price not available", symbol=product.symbol)
146-
return
151+
log.info(
152+
"sending batch update_price",
153+
num_price_updates=len(price_updates),
154+
total_products=len(self.products),
155+
)
147156

148-
# Scale the price and confidence interval using the Pyth exponent
149-
scaled_price = self.apply_exponent(price.price, product.exponent)
150-
scaled_conf = self.apply_exponent(price.conf, product.exponent)
157+
await self.pythd.update_price_batch(price_updates)
151158

152-
# Send the price update
153-
log.info(
154-
"sending update_price",
155-
product_account=product.product_account,
156-
price_account=product.price_account,
157-
price=scaled_price,
158-
conf=scaled_conf,
159-
symbol=product.symbol,
160-
)
161-
await self.pythd.update_price(
162-
product.price_account, scaled_price, scaled_conf, TRADING
163-
)
164-
self.last_successful_update = (
165-
price.timestamp
166-
if self.last_successful_update is None
167-
else max(self.last_successful_update, price.timestamp)
168-
)
159+
await asyncio.sleep(self.config.price_update_interval_secs)
169160

170161
def apply_exponent(self, x: float, exp: int) -> int:
171162
return int(x * (10 ** (-exp)))

example_publisher/pythd.py

Lines changed: 78 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import asyncio
21
from dataclasses import dataclass, field
3-
import sys
4-
import traceback
5-
from dataclasses_json import config, DataClassJsonMixin
6-
from typing import Callable, Coroutine, List
2+
import json
3+
from dataclasses_json import config, DataClassJsonMixin, dataclass_json
4+
from dataclasses_json.undefined import Undefined
5+
from typing import List, Any, Optional
76
from structlog import get_logger
8-
from jsonrpc_websocket import Server
7+
from websockets.client import connect, WebSocketClientProtocol
8+
from asyncio import Lock
99

1010
log = get_logger()
1111

@@ -15,12 +15,22 @@
1515
TRADING = "trading"
1616

1717

18+
@dataclass_json(undefined=Undefined.EXCLUDE)
1819
@dataclass
1920
class Price(DataClassJsonMixin):
2021
account: str
2122
exponent: int = field(metadata=config(field_name="price_exponent"))
2223

2324

25+
@dataclass
26+
class PriceUpdate(DataClassJsonMixin):
27+
account: str
28+
price: int
29+
conf: int
30+
status: str
31+
32+
33+
@dataclass_json(undefined=Undefined.EXCLUDE)
2434
@dataclass
2535
class Metadata(DataClassJsonMixin):
2636
symbol: str
@@ -34,56 +44,77 @@ class Product(DataClassJsonMixin):
3444
prices: List[Price] = field(metadata=config(field_name="price"))
3545

3646

47+
@dataclass
48+
class JSONRPCRequest(DataClassJsonMixin):
49+
id: int
50+
method: str
51+
params: List[Any] | Any
52+
jsonrpc: str = "2.0"
53+
54+
55+
@dataclass
56+
class JSONRPCResponse(DataClassJsonMixin):
57+
id: int
58+
result: Optional[Any] = None
59+
error: Optional[Any] = None
60+
jsonrpc: str = "2.0"
61+
62+
3763
class Pythd:
3864
def __init__(
3965
self,
4066
address: str,
41-
on_notify_price_sched: Callable[[SubscriptionId], Coroutine[None, None, None]],
4267
) -> None:
4368
self.address = address
44-
self.server: Server
45-
self.on_notify_price_sched = on_notify_price_sched
46-
self._tasks = set()
69+
self.client: WebSocketClientProtocol
70+
self.id_counter = 0
71+
self.lock = Lock()
4772

4873
async def connect(self):
49-
self.server = Server(self.address)
50-
self.server.notify_price_sched = self._notify_price_sched
51-
task = await self.server.ws_connect()
52-
task.add_done_callback(Pythd._on_connection_done)
53-
self._tasks.add(task)
54-
55-
@staticmethod
56-
def _on_connection_done(task):
57-
log.error("pythd connection closed")
58-
if not task.cancelled() and task.exception() is not None:
59-
e = task.exception()
60-
traceback.print_exception(None, e, e.__traceback__)
61-
sys.exit(1)
62-
63-
async def subscribe_price_sched(self, account: str) -> int:
64-
subscription = (await self.server.subscribe_price_sched(account=account))[
65-
"subscription"
66-
]
67-
log.debug(
68-
"subscribed to price_sched", account=account, subscription=subscription
74+
self.client = await connect(self.address)
75+
76+
def _create_request(self, method: str, params: List[Any] | Any) -> JSONRPCRequest:
77+
self.id_counter += 1
78+
return JSONRPCRequest(
79+
id=self.id_counter,
80+
method=method,
81+
params=params,
6982
)
70-
return subscription
7183

72-
def _notify_price_sched(self, subscription: int) -> None:
73-
log.debug("notify_price_sched RPC call received", subscription=subscription)
74-
task = asyncio.get_event_loop().create_task(
75-
self.on_notify_price_sched(subscription)
76-
)
77-
self._tasks.add(task)
78-
task.add_done_callback(self._tasks.discard)
84+
async def send_request(self, request: JSONRPCRequest) -> JSONRPCResponse:
85+
# Using a lock will result in a synchronous execution of the send_request method
86+
# and response retrieval which makes the code easier but is not good for performance.
87+
# It is not recommended to use this behaviour where there are concurrent requests
88+
# being made to the server.
89+
async with self.lock:
90+
await self.client.send(request.to_json())
91+
response = await self.client.recv()
92+
return JSONRPCResponse.from_json(response)
93+
94+
async def send_batch_request(
95+
self, requests: List[JSONRPCRequest]
96+
) -> List[JSONRPCResponse]:
97+
async with self.lock:
98+
await self.client.send(
99+
json.dumps([request.to_dict() for request in requests])
100+
)
101+
response = await self.client.recv()
102+
return JSONRPCResponse.schema().loads(response, many=True)
79103

80104
async def all_products(self) -> List[Product]:
81-
result = await self.server.get_product_list()
82-
return [Product.from_dict(d) for d in result]
83-
84-
async def update_price(
85-
self, account: str, price: int, conf: int, status: str
86-
) -> None:
87-
await self.server.update_price(
88-
account=account, price=price, conf=conf, status=status
89-
)
105+
request = self._create_request("get_product_list", [])
106+
result = await self.send_request(request)
107+
if result.result:
108+
return Product.schema().load(result.result, many=True)
109+
else:
110+
raise ValueError(f"Error fetching products: {result.to_json()}")
111+
112+
async def update_price_batch(self, price_updates: List[PriceUpdate]) -> None:
113+
requests = [
114+
self._create_request("update_price", price_update.to_dict())
115+
for price_update in price_updates
116+
]
117+
results = await self.send_batch_request(requests)
118+
if any(result.error for result in results):
119+
results_json_str = JSONRPCResponse.schema().dumps(results, many=True)
120+
raise ValueError(f"Error updating prices: {results_json_str}")

0 commit comments

Comments
 (0)