|
1 | 1 | # ruff: noqa: E402 |
2 | 2 | import asyncio |
| 3 | +from collections import deque |
3 | 4 | from pathlib import Path |
4 | 5 | from types import SimpleNamespace |
5 | 6 | from unittest.mock import MagicMock, patch |
@@ -90,7 +91,7 @@ def test_steepness(): |
90 | 91 | assert result[0] < 0, "Negative reward should remain negative" |
91 | 92 |
|
92 | 93 |
|
93 | | -def test_run_step_with_reward_events(): |
| 94 | +def test_run_step_with_reward_events(tmp_path: Path): |
94 | 95 | with ( |
95 | 96 | patch("shared.uids.get_uids") as mock_get_uids, |
96 | 97 | patch("prompting.weight_setting.weight_setter.TaskRegistry") as MockTaskRegistry, |
@@ -126,7 +127,7 @@ def __init__(self, task, uids, rewards, weight): |
126 | 127 |
|
127 | 128 | # Set up the mock mutable_globals. |
128 | 129 |
|
129 | | - weight_setter = WeightSetter(reward_history_path=Path("test_validator_rewards.jsonl")) |
| 130 | + weight_setter = WeightSetter(reward_history_path=tmp_path / "test_validator_rewards.jsonl") |
130 | 131 | reward_events = [ |
131 | 132 | [ |
132 | 133 | WeightedRewardEvent( |
@@ -165,6 +166,37 @@ def __init__(self, task, uids, rewards, weight): |
165 | 166 | mock_logger.warning.assert_not_called() |
166 | 167 |
|
167 | 168 |
|
| 169 | +def _make_snapshot(values: list[float]) -> dict[int, dict[str, float]]: |
| 170 | + return {uid: {"reward": v} for uid, v in enumerate(values)} |
| 171 | + |
| 172 | + |
@pytest.mark.asyncio
async def test_avg_reward_non_empty(tmp_path: Path) -> None:
    """Mean over two snapshots equals manual average."""
    weight_setter = WeightSetter(reward_history_path=tmp_path / "test_validator_rewards.jsonl")
    weight_setter.reward_history_len = 10
    weight_setter.reward_history = deque(maxlen=10)
    ascending = list(range(256))
    # Append the rewards once in ascending and once in descending order.
    for snapshot_values in (ascending, ascending[::-1]):
        weight_setter.reward_history.append(_make_snapshot(snapshot_values))

    averaged = await weight_setter._compute_avg_reward()

    # For each UID the two rewards are i and 255 - i, so every mean is 255 / 2.
    expected = np.full(256, 255 / 2, dtype=np.float32)
    assert averaged.dtype == np.float32
    assert np.allclose(averaged, expected, atol=1e-6)
| 189 | + |
@pytest.mark.asyncio
async def test_avg_reward_empty(tmp_path: Path) -> None:
    """Empty history returns a zero vector.

    Fix: dropped the unused ``monkeypatch`` fixture parameter — nothing in the
    test patched anything, so requesting the fixture only added noise.
    """
    ws = WeightSetter(reward_history_path=tmp_path / "test_validator_rewards.jsonl")
    ws.reward_history_len = 10
    ws.reward_history = deque(maxlen=10)
    result = await ws._compute_avg_reward()
    # With no snapshots recorded, every UID's average reward must be exactly 0.
    assert np.array_equal(result, np.zeros(256, dtype=np.float32))
| 199 | + |
168 | 200 | @pytest.mark.asyncio |
169 | 201 | async def test_set_weights(monkeypatch: MonkeyPatch): |
170 | 202 | """`set_weights` calls Subtensor.set_weights with processed vectors.""" |
|
0 commit comments