@@ -1,7 +1,7 @@
 import asyncio
-from collections import deque
 import datetime
 import json
+from collections import deque
 from pathlib import Path
 from typing import Any
 
@@ -42,9 +42,7 @@ async def set_weights(
             # If weights will not be set on chain, we should not synchronize.
             augmented_weights = weights
         else:
-            augmented_weights = await weight_syncer.get_augmented_weights(
-                weights=weights, uid=shared_settings.UID
-            )
+            augmented_weights = await weight_syncer.get_augmented_weights(weights=weights, uid=shared_settings.UID)
     except BaseException as ex:
         logger.exception(f"Issue with setting weights: {ex}")
         augmented_weights = weights
@@ -60,8 +58,7 @@ async def set_weights(
 
         # Convert to uint16 weights and uids.
         uint_uids, uint_weights = bt.utils.weight_utils.convert_weights_and_uids_for_emit(
-            uids=processed_weight_uids,
-            weights=processed_weights
+            uids=processed_weight_uids, weights=processed_weights
         )
     except Exception as ex:
         logger.exception(f"Skipping weight setting: {ex}")
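For context on what the reflowed call does: convert_weights_and_uids_for_emit maps float weights onto the uint16 range the chain accepts. A minimal runnable sketch of that scaling, assuming the helper's usual behaviour of dropping non-positive weights and normalising by the maximum (illustration only, not the library implementation):

import numpy as np

U16_MAX = 65535

def to_uint16_sketch(uids: np.ndarray, weights: np.ndarray) -> tuple[list[int], list[int]]:
    # Sketch only: keep positive weights and scale so the largest maps to U16_MAX.
    mask = weights > 0
    scaled = weights[mask] / weights[mask].max()
    return uids[mask].tolist(), (scaled * U16_MAX).astype(np.uint16).tolist()

uids, vals = to_uint16_sketch(np.arange(4), np.array([0.1, 0.0, 0.4, 0.5]))
# uids == [0, 2, 3], vals == [13107, 52428, 65535]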
@@ -167,9 +164,7 @@ async def _load_rewards(self):
                 if payload is None:
                     raise ValueError(f"Malformed weight history file: {data}")
 
-                self.reward_history.append(
-                    {int(uid): {"reward": float(reward)} for uid, reward in payload.items()}
-                )
+                self.reward_history.append({int(uid): {"reward": float(reward)} for uid, reward in payload.items()})
         except BaseException as exc:
             self.reward_history: deque[dict[int, dict[str, Any]]] | None = deque(maxlen=self.reward_history_len)
             logger.error(f"Couldn't load rewards from file, resetting weight history: {exc}")
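The single-line append above builds one history entry per stored payload, coercing JSON string keys to int uids. A hypothetical example of that shape (the on-disk layout and the maxlen value are not visible in this diff, so both are invented here):

from collections import deque

# Invented payload: JSON object keys arrive as strings, rewards as raw numbers.
payload = {"1": 0.25, "42": 0.75}

reward_history: deque[dict[int, dict[str, float]]] = deque(maxlen=24)  # maxlen invented
reward_history.append({int(uid): {"reward": float(reward)} for uid, reward in payload.items()})
# reward_history[-1] == {1: {"reward": 0.25}, 42: {"reward": 0.75}}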
@@ -217,8 +212,7 @@ async def merge_task_rewards(cls, reward_events: list[list[WeightedRewardEvent]]
                 processed_rewards = task_rewards / max(1, (np.sum(task_rewards[task_rewards > 0]) + 1e-10))
             else:
                 processed_rewards = cls.apply_steepness(
-                    raw_rewards=task_rewards,
-                    steepness=shared_settings.REWARD_STEEPNESS
+                    raw_rewards=task_rewards, steepness=shared_settings.REWARD_STEEPNESS
                 )
             processed_rewards *= task_config.probability
 
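One detail worth keeping in mind when reading this hunk: the max(1, ...) guard means rewards are only rescaled when the positive rewards sum to more than 1, and each task's contribution is then weighted by its probability. A runnable illustration with invented numbers:

import numpy as np

task_rewards = np.array([1.2, 0.0, 0.8, -0.1], dtype=np.float32)
denom = max(1, np.sum(task_rewards[task_rewards > 0]) + 1e-10)  # 2.0 here; stays 1 if positives sum below 1
processed_rewards = task_rewards / denom                        # [0.6, 0.0, 0.4, -0.05]
processed_rewards *= 0.5                                        # invented task probability
print(processed_rewards)                                        # [0.3, 0.0, 0.2, -0.025]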
@@ -238,11 +232,11 @@ def apply_steepness(cls, raw_rewards: npt.NDArray[np.float32], steepness: float
         p > 0.5 makes the function more exponential (winner takes all).
         """
         # 6.64385619 = ln(100)/ln(2) -> this way if p = 0.5, the exponent is exactly 1.
-        exponent = (steepness ** 6.64385619) * 100
+        exponent = (steepness**6.64385619) * 100
         raw_rewards = np.array(raw_rewards) / max(1, (np.sum(raw_rewards[raw_rewards > 0]) + 1e-10))
         positive_rewards = np.clip(raw_rewards, 1e-10, np.inf)
         normalised_rewards = positive_rewards / np.max(positive_rewards)
-        post_func_rewards = normalised_rewards ** exponent
+        post_func_rewards = normalised_rewards**exponent
         all_rewards = post_func_rewards / (np.sum(post_func_rewards) + 1e-10)
         all_rewards[raw_rewards <= 0] = raw_rewards[raw_rewards <= 0]
         return all_rewards
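A quick numeric check of the exponent mapping in this hunk: since 6.64385619 = ln(100)/ln(2), a steepness of 0.5 gives 0.5**6.64385619 * 100 = (1/100) * 100 = 1, i.e. the identity, while larger steepness values sharpen the distribution toward winner-takes-all:

# Sanity check of the exponent formula above, not part of the patch.
for p in (0.5, 0.7, 0.9):
    exponent = (p**6.64385619) * 100
    print(p, round(exponent, 2))
# 0.5 -> 1.0, 0.7 -> 9.35, 0.9 -> 49.66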
@@ -251,13 +245,13 @@ async def run_step(self):
         await asyncio.sleep(0.01)
         try:
             if self.reward_events is None:
-                logger.error(f"No rewards events were found, skipping weight setting")
+                logger.error("No rewards events were found, skipping weight setting")
                 return
 
             final_rewards = await self.merge_task_rewards(self.reward_events)
 
             if final_rewards is None:
-                logger.error(f"No rewards were found, skipping weight setting")
+                logger.error("No rewards were found, skipping weight setting")
                 return
 
             await self._save_rewards(final_rewards)