Skip to content

Commit 40c9c67

Browse files
authored
Merge pull request #808 from macrocosm-os/release/v3.0.4
v3.0.4 Release
2 parents 1c33f2c + 823a6c2 commit 40c9c67

File tree

13 files changed

+2076
-1834
lines changed

13 files changed

+2076
-1834
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,5 @@ cython_debug/
117117

118118
# VS Code
119119
.vscode
120+
121+
wandb/

apex/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,4 @@ def setup_logger(log_file_path: str | Path | None = None, level: str = "INFO") -
4444
return logger
4545

4646

47-
setup_logger(log_file_path="logs.log", level="DEBUG")
47+
setup_logger(log_file_path="logs.log", level="INFO")

apex/common/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class Config(BaseModel):
1515
chain: ConfigClass = Field(default_factory=ConfigClass)
1616
websearch: ConfigClass = Field(default_factory=ConfigClass)
1717
logger_db: ConfigClass = Field(default_factory=ConfigClass)
18+
logger_wandb: ConfigClass = Field(default_factory=ConfigClass)
1819
weight_syncer: ConfigClass = Field(default_factory=ConfigClass)
1920
miner_sampler: ConfigClass = Field(default_factory=ConfigClass)
2021
miner_scorer: ConfigClass = Field(default_factory=ConfigClass)

apex/common/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@ class MinerGeneratorResults(BaseModel):
1313
query: str
1414
generator_hotkeys: list[str]
1515
generator_results: list[str]
16+
generator_times: list[float]
1617

1718

1819
class MinerDiscriminatorResults(BaseModel):
1920
query: str
2021
generator_hotkey: str
2122
generator_result: str
23+
generator_time: float
2224
generator_score: float
2325
discriminator_hotkeys: list[str]
2426
discriminator_results: list[str]

apex/validator/logger_wandb.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from collections.abc import Mapping
2+
from typing import Any
3+
4+
import wandb
5+
from loguru import logger
6+
7+
from apex import __version__
8+
from apex.common.async_chain import AsyncChain
9+
from apex.common.models import MinerDiscriminatorResults
10+
11+
12+
def approximate_tokens(text: str) -> int:
13+
"""Count the number of tokens in a text."""
14+
return len(text) // 4
15+
16+
17+
class LoggerWandb:
18+
def __init__(
19+
self,
20+
async_chain: AsyncChain,
21+
project: str = "apex-gan-arena",
22+
api_key: str | None = None,
23+
):
24+
self.run: Any | None = None
25+
if project and api_key:
26+
try:
27+
# Authenticate with W&B, then initialize the run
28+
wandb.login(key=api_key)
29+
self.run = wandb.init(
30+
entity="macrocosmos",
31+
project=project,
32+
config={
33+
"hotkey": async_chain.wallet.hotkey.ss58_address,
34+
"netuid": async_chain.netuid,
35+
"version": __version__,
36+
},
37+
)
38+
logger.info(f"Initialized W&B run: {self.run.id}")
39+
except Exception as e:
40+
logger.error(f"Failed to initialize W&B run: {e}")
41+
else:
42+
logger.warning("W&B API key not provided, skipping logging to W&B")
43+
44+
async def log(
45+
self,
46+
reference: str | None = None,
47+
discriminator_results: MinerDiscriminatorResults | None = None,
48+
tool_history: list[dict[str, str]] | None = None,
49+
) -> None:
50+
"""Log an event to wandb."""
51+
if self.run:
52+
if discriminator_results:
53+
processed_event = self.process_event(discriminator_results.model_dump())
54+
processed_event["reference"] = reference
55+
processed_event["tool_history"] = tool_history
56+
self.run.log(processed_event)
57+
58+
def process_event(self, event: Mapping[str, Any]) -> dict[str, Any]:
59+
"""Preprocess an event before logging it."""
60+
reference = event.get("reference", "")
61+
generation = event.get("generation", "")
62+
generator_tokens = approximate_tokens(generation)
63+
reference_tokens = approximate_tokens(reference)
64+
65+
processed_event: dict[str, Any] = dict(event)
66+
processed_event["generator_tokens"] = generator_tokens
67+
processed_event["reference_tokens"] = reference_tokens
68+
69+
return processed_event

apex/validator/miner_sampler.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,19 +156,32 @@ async def query_miners(
156156
return ""
157157
return str(result)
158158

159+
async def query_miners_with_times(
160+
self, body: dict[str, Any], endpoint: str, hotkey: str | None = None, timeout: float = TIMEOUT
161+
) -> tuple[str, float]:
162+
"""Query the miners for the query."""
163+
start_time = time.time()
164+
result = await self.query_miners(body, endpoint, hotkey, timeout)
165+
return result, time.time() - start_time
166+
159167
async def query_generators(self, query: str) -> MinerGeneratorResults:
160168
"""Query the miners for the query."""
161169
miner_information = await self._sample_miners(sample_size=self._generator_sample_size)
162170
body = {"step": "generator", "query": query}
163171

164172
hotkeys: list[str] = []
165-
tasks: list[Coroutine[str, str, Any]] = []
173+
tasks: list[Coroutine[tuple[str, float], str, Any]] = []
166174

167175
for miner_info in miner_information:
168176
hotkeys.append(miner_info.hotkey)
169-
tasks.append(self.query_miners(body=body, endpoint=miner_info.address, hotkey=miner_info.hotkey))
177+
tasks.append(self.query_miners_with_times(body=body, endpoint=miner_info.address, hotkey=miner_info.hotkey))
170178
generator_results = await asyncio.gather(*tasks)
171-
return MinerGeneratorResults(query=query, generator_hotkeys=hotkeys, generator_results=generator_results)
179+
return MinerGeneratorResults(
180+
query=query,
181+
generator_hotkeys=hotkeys,
182+
generator_results=[result[0] for result in generator_results],
183+
generator_times=[result[1] for result in generator_results],
184+
)
172185

173186
async def query_discriminators(
174187
self,
@@ -181,19 +194,20 @@ async def query_discriminators(
181194
miner_information = await self._sample_miners(sample_size=self._discriminator_sample_size)
182195
# Flip the coin for the generator.
183196
if ground_truth and generator_results:
184-
selected_generator: tuple[str, str] = random.choice(
197+
selected_generator: tuple[str, str, float] = random.choice(
185198
list(
186199
zip(
187200
generator_results.generator_hotkeys,
188201
generator_results.generator_results,
202+
generator_results.generator_times,
189203
strict=False,
190204
)
191205
)
192206
)
193207
else:
194208
if reference is None:
195209
raise ValueError("Reference cannot be None when not using miner generator results")
196-
selected_generator = (VALIDATOR_REFERENCE_LABEL, reference)
210+
selected_generator = (VALIDATOR_REFERENCE_LABEL, reference, 0.0)
197211

198212
body = {
199213
"step": "discriminator",
@@ -202,7 +216,7 @@ async def query_discriminators(
202216
}
203217

204218
hotkeys: list[str] = []
205-
tasks: list[Coroutine[str, str, Any]] = []
219+
tasks: list[Coroutine[tuple[str, float], str, Any]] = []
206220
for miner_info in miner_information:
207221
hotkeys.append(miner_info.hotkey)
208222
tasks.append(self.query_miners(body=body, endpoint=miner_info.address, hotkey=miner_info.hotkey))
@@ -244,6 +258,7 @@ async def query_discriminators(
244258
generator_hotkey=selected_generator[0],
245259
generator_result=selected_generator[1],
246260
generator_score=generator_result_float,
261+
generator_time=selected_generator[2],
247262
discriminator_hotkeys=hotkeys,
248263
discriminator_results=parsed_discriminator_results,
249264
discriminator_scores=discriminator_results_float,

apex/validator/pipeline.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from apex.services.llm.llm_base import LLMBase
1212
from apex.services.websearch.websearch_base import WebSearchBase
1313
from apex.validator import generate_query, generate_reference
14-
from apex.validator.logger_apex import LoggerApex
1514
from apex.validator.logger_local import LoggerLocal
15+
from apex.validator.logger_wandb import LoggerWandb
1616
from apex.validator.miner_sampler import MinerSampler
1717

1818

@@ -23,7 +23,7 @@ def __init__(
2323
miner_sampler: MinerSampler,
2424
llm: LLMBase,
2525
deep_research: DeepResearchBase,
26-
logger_apex: LoggerApex | None = None,
26+
logger_wandb: LoggerWandb | None = None,
2727
num_consumers: int = 5,
2828
timeout_consumer: float = 1200,
2929
timeout_producer: float = 240,
@@ -36,7 +36,7 @@ def __init__(
3636
self.miner_registry = miner_sampler
3737
self.llm = llm
3838
self.deep_research = deep_research
39-
self.logger_apex = logger_apex
39+
self.logger_wandb = logger_wandb
4040
self.num_consumers = num_consumers
4141
self.timeout_consumer = timeout_consumer
4242
self.timeout_producer = timeout_producer
@@ -109,8 +109,8 @@ async def run_single(self, task: QueryTask) -> str:
109109
query=query, generator_results=generator_results, reference=reference, ground_truth=ground_truth
110110
)
111111

112-
if self.logger_apex:
113-
await self.logger_apex.log(
112+
if self.logger_wandb:
113+
await self.logger_wandb.log(
114114
reference=reference, discriminator_results=discriminator_results, tool_history=tool_history
115115
)
116116

config/mainnet.yaml.example

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ websearch:
1212
kwargs:
1313
key: "TAVILY_API_KEY"
1414

15+
logger_wandb:
16+
kwargs:
17+
project: "apex-gan-arena"
18+
api_key: "YOUR_WANDB_API_KEY"
19+
1520
llm:
1621
kwargs:
1722
key: "CHUTES_API_KEY"

config/testnet.yaml.example

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ websearch:
1010
kwargs:
1111
key: "TAVILY_API_KEY"
1212

13+
logger_wandb:
14+
kwargs:
15+
project: "apex-gan-arena"
16+
api_key: "YOUR_WANDB_API_KEY"
17+
1318
llm:
1419
kwargs:
1520
key: "CHUTES_API_KEY"

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "apex"
3-
version = "3.0.3"
3+
version = "3.0.4"
44
description = "Bittensor Subnet 1: Apex"
55
readme = "README.md"
66
requires-python = "~=3.11"
@@ -34,6 +34,8 @@ dependencies = [
3434
"types-cachetools>=6.0.0.20250525",
3535
"dotenv>=0.9.9",
3636
"pytest-mock>=3.14.1",
37+
"wandb>=0.21.1",
38+
"ruff>=0.12.5",
3739
]
3840

3941

0 commit comments

Comments
 (0)