Skip to content

Commit 063626f

Browse files
committed
feat: module lookup for building trackers
1 parent fad46b6 commit 063626f

File tree

6 files changed

+99
-20
lines changed

6 files changed

+99
-20
lines changed

Diff for: torchx/schedulers/kubernetes_scheduler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.6.0/installer/volcano-development.yaml
2424
2525
See the
26-
`Volcano Quickstart <https://github.com/volcano-sh/volcano#user-content-quick-start-guide>`_
26+
`Volcano Quickstart <https://github.com/volcano-sh/volcano#quick-start-guide>`_
2727
for more information.
2828
"""
2929

Diff for: torchx/tracker/__init__.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
-------------
3838
To enable tracking it requires:
3939
40-
1. Defining tracker backends (entrypoints and configuration) on launcher side using :doc:`runner.config`
40+
1. Defining tracker backends (entrypoints/modules and configuration) on launcher side using :doc:`runner.config`
4141
2. Adding entrypoints within a user job using entry_points (`specification`_)
4242
4343
.. _specification: https://packaging.python.org/en/latest/specifications/entry-points/
@@ -49,13 +49,13 @@
4949
User can define any number of tracker backends under **torchx:tracker** section in :doc:`runner.config`, where:
5050
* Key: is an arbitrary name for the tracker, where the name will be used to configure its properties
5151
under [tracker:<TRACKER_NAME>]
52-
* Value: is *entrypoint/factory method* that must be available within user job. The value will be injected into a
52+
* Value: is *entrypoint* or *module* factory method that must be available within user job. The value will be injected into a
5353
user job and used to construct tracker implementation.
5454
5555
.. code-block:: ini
5656
5757
[torchx:tracker]
58-
tracker_name=<entry_point>
58+
tracker_name=<entry_point_or_module_factory_method>
5959
6060
6161
Each tracker can be additionally configured (currently limited to `config` parameter) under `[tracker:<TRACKER NAME>]` section:
@@ -71,11 +71,15 @@
7171
7272
[torchx:tracker]
7373
tracker1=tracker1
74-
tracker12=backend_2_entry_point
74+
tracker2=backend_2_entry_point
75+
tracker3=torchx.tracker.mlflow:create_tracker
7576
7677
[tracker:tracker1]
7778
config=s3://my_bucket/config.json
7879
80+
[tracker:tracker3]
81+
config=my_config.json
82+
7983
8084
2. User job configuration (Advanced)
8185
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Diff for: torchx/tracker/api.py

+11-14
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import Iterable, Mapping, Optional
1515

1616
from torchx.util.entrypoints import load_group
17+
from torchx.util.modules import load_module
1718

1819
logger: logging.Logger = logging.getLogger(__name__)
1920

@@ -177,30 +178,26 @@ def _extract_tracker_name_and_config_from_environ() -> Mapping[str, Optional[str
177178

178179

179180
def build_trackers(
180-
entrypoint_and_config: Mapping[str, Optional[str]]
181+
factory_and_config: Mapping[str, Optional[str]]
181182
) -> Iterable[TrackerBase]:
182183
trackers = []
183184

184-
entrypoint_factories = load_group("torchx.tracker")
185+
entrypoint_factories = load_group("torchx.tracker") or {}
185186
if not entrypoint_factories:
186-
logger.warning(
187-
"No 'torchx.tracker' entry_points are defined. Tracking will not capture any data."
188-
)
189-
return trackers
187+
logger.warning("No 'torchx.tracker' entry_points are defined.")
190188

191-
for entrypoint_key, config in entrypoint_and_config.items():
192-
if entrypoint_key not in entrypoint_factories:
189+
for factory_name, config in factory_and_config.items():
190+
factory = entrypoint_factories.get(factory_name) or load_module(factory_name)
191+
if not factory:
193192
logger.warning(
194-
f"Could not find `{entrypoint_key}` tracker entrypoint. Skipping..."
193+
f"No tracker factory `{factory_name}` found in entry_points or modules. See https://pytorch.org/torchx/main/tracker.html#module-torchx.tracker"
195194
)
196195
continue
197-
factory = entrypoint_factories[entrypoint_key]
198196
if config:
199-
logger.info(f"Tracker config found for `{entrypoint_key}` as `{config}`")
200-
tracker = factory(config)
197+
logger.info(f"Tracker config found for `{factory_name}` as `{config}`")
201198
else:
202-
logger.info(f"No tracker config specified for `{entrypoint_key}`")
203-
tracker = factory(None)
199+
logger.info(f"No tracker config specified for `{factory_name}`")
200+
tracker = factory(config)
204201
trackers.append(tracker)
205202
return trackers
206203

Diff for: torchx/tracker/test/api_test.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from collections import defaultdict
1010
from typing import cast, DefaultDict, Dict, Iterable, Mapping, Optional, Tuple
1111
from unittest import mock, TestCase
12-
from unittest.mock import patch
12+
from unittest.mock import MagicMock, patch
1313

1414
from torchx.tracker import app_run_from_env
1515
from torchx.tracker.api import (
@@ -27,6 +27,8 @@
2727
TrackerSource,
2828
)
2929

30+
from torchx.tracker.mlflow import MLflowTracker
31+
3032
RunId = str
3133

3234
DEFAULT_SOURCE: str = "__parent__"
@@ -271,6 +273,26 @@ def test_build_trackers_with_no_entrypoints_group_defined(self) -> None:
271273
trackers = build_trackers(tracker_names)
272274
self.assertEqual(0, len(list(trackers)))
273275

276+
def test_build_trackers_with_module(self) -> None:
277+
module = MagicMock()
278+
module.return_value = MagicMock(spec=MLflowTracker)
279+
with patch(
280+
"torchx.tracker.api.load_group",
281+
return_value=None,
282+
) and patch(
283+
"torchx.tracker.api.load_module",
284+
return_value=module,
285+
):
286+
tracker_names = {
287+
"torchx.tracker.mlflow:create_tracker": (config := "myconfig.txt")
288+
}
289+
trackers = build_trackers(tracker_names)
290+
trackers = list(trackers)
291+
self.assertEqual(1, len(trackers))
292+
tracker = trackers[0]
293+
self.assertIsInstance(tracker, MLflowTracker)
294+
module.assert_called_once_with(config)
295+
274296
def test_build_trackers(self) -> None:
275297
with patch(
276298
"torchx.tracker.api.load_group",

Diff for: torchx/util/modules.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import importlib
8+
from types import ModuleType
9+
from typing import Callable, Optional, Union
10+
11+
12+
def load_module(path: str) -> Union[ModuleType, Optional[Callable[..., object]]]:
13+
"""
14+
Loads and returns the module/module attr represented by the ``path``: ``full.module.path:optional_attr``
15+
16+
::
17+
18+
19+
1. ``load_module("this.is.a_module:fn")`` -> equivalent to ``this.is.a_module.fn``
20+
1. ``load_module("this.is.a_module")`` -> equivalent to ``this.is.a_module``
21+
"""
22+
parts = path.split(":", 2)
23+
module_path, method = parts[0], parts[1] if len(parts) > 1 else None
24+
module = None
25+
i, n = -1, len(module_path)
26+
try:
27+
while i < n:
28+
i = module_path.find(".", i + 1)
29+
i = i if i >= 0 else n
30+
module = importlib.import_module(module_path[:i])
31+
return getattr(module, method) if method else module
32+
except Exception:
33+
return None

Diff for: torchx/util/test/modules_test.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from torchx.util.modules import load_module
10+
11+
12+
class ModulesTest(unittest.TestCase):
13+
def test_load_module(self) -> None:
14+
result = load_module("os.path")
15+
import os
16+
17+
self.assertEqual(result, os.path)
18+
19+
def test_load_module_method(self) -> None:
20+
result = load_module("os.path:join")
21+
import os
22+
23+
self.assertEqual(result, os.path.join)

0 commit comments

Comments
 (0)