|
14 | 14 | from typing import Iterable, Mapping, Optional
|
15 | 15 |
|
16 | 16 | from torchx.util.entrypoints import load_group
|
| 17 | +from torchx.util.modules import load_module |
17 | 18 |
|
18 | 19 | logger: logging.Logger = logging.getLogger(__name__)
|
19 | 20 |
|
@@ -177,30 +178,26 @@ def _extract_tracker_name_and_config_from_environ() -> Mapping[str, Optional[str
|
177 | 178 |
|
178 | 179 |
|
179 | 180 | def build_trackers(
|
180 |
| - entrypoint_and_config: Mapping[str, Optional[str]] |
| 181 | + factory_and_config: Mapping[str, Optional[str]] |
181 | 182 | ) -> Iterable[TrackerBase]:
|
182 | 183 | trackers = []
|
183 | 184 |
|
184 |
| - entrypoint_factories = load_group("torchx.tracker") |
| 185 | + entrypoint_factories = load_group("torchx.tracker") or {} |
185 | 186 | if not entrypoint_factories:
|
186 |
| - logger.warning( |
187 |
| - "No 'torchx.tracker' entry_points are defined. Tracking will not capture any data." |
188 |
| - ) |
189 |
| - return trackers |
| 187 | + logger.warning("No 'torchx.tracker' entry_points are defined.") |
190 | 188 |
|
191 |
| - for entrypoint_key, config in entrypoint_and_config.items(): |
192 |
| - if entrypoint_key not in entrypoint_factories: |
| 189 | + for factory_name, config in factory_and_config.items(): |
| 190 | + factory = entrypoint_factories.get(factory_name) or load_module(factory_name) |
| 191 | + if not factory: |
193 | 192 | logger.warning(
|
194 |
| - f"Could not find `{entrypoint_key}` tracker entrypoint. Skipping..." |
| 193 | + f"No tracker factory `{factory_name}` found in entry_points or modules. See https://pytorch.org/torchx/main/tracker.html#module-torchx.tracker" |
195 | 194 | )
|
196 | 195 | continue
|
197 |
| - factory = entrypoint_factories[entrypoint_key] |
198 | 196 | if config:
|
199 |
| - logger.info(f"Tracker config found for `{entrypoint_key}` as `{config}`") |
200 |
| - tracker = factory(config) |
| 197 | + logger.info(f"Tracker config found for `{factory_name}` as `{config}`") |
201 | 198 | else:
|
202 |
| - logger.info(f"No tracker config specified for `{entrypoint_key}`") |
203 |
| - tracker = factory(None) |
| 199 | + logger.info(f"No tracker config specified for `{factory_name}`") |
| 200 | + tracker = factory(config) |
204 | 201 | trackers.append(tracker)
|
205 | 202 | return trackers
|
206 | 203 |
|
|
0 commit comments