Skip to content

Commit 439f3fd

Browse files
Kiuk Chungfacebook-github-bot
authored andcommitted
fix FileNotFound errors when reading builtin configs from a xar, use entrypoint to register custom fb plugins for torchx
Summary: need to use parutils to load bundled builtin config files... Reviewed By: tierex Differential Revision: D28614050 fbshipit-source-id: a1869da90ffe07ad18fb3ac360f7215c8f568b6c
1 parent 33eadf8 commit 439f3fd

File tree

5 files changed

+120
-13
lines changed

5 files changed

+120
-13
lines changed

torchx/cli/cmd_run.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@
1010
import warnings
1111
from dataclasses import asdict
1212
from os import path
13+
from pathlib import Path
1314
from pprint import pformat
14-
from typing import Callable, Iterable, List, Type
15+
from typing import Callable, Iterable, List, Optional, Type
1516

1617
import torchx.specs as specs
1718
import yaml
1819
from torchx.cli.cmd_base import SubCommand
1920
from torchx.cli.conf_helpers import parse_args_children
2021
from torchx.runner import get_runner
22+
from torchx.util import entrypoints
2123

2224

2325
class UnsupportFeatureError(Exception):
@@ -94,35 +96,53 @@ def _parse_run_config(arg: str) -> specs.RunConfig:
9496

9597
# TODO kiuk@ move read_conf_file + _builtins to the Runner once the Runner is API stable
9698

97-
_CONFIG_DIR: str = path.join(path.dirname(__file__), "config")
99+
_CONFIG_DIR: Path = Path("torchx/cli/config")
98100
_CONFIG_EXT = ".torchx"
99101

100102

103+
def get_file_contents(conf_file: str) -> Optional[str]:
104+
"""
105+
Reads the ``conf_file`` relative to the root of the project.
106+
Returns ``None`` if ``$root/$conf_file`` does not exist.
107+
Example: ``get_file("torchx/cli/config/foo.txt")``
108+
"""
109+
110+
root = path.dirname(__file__).replace(__name__.replace(".", path.sep), "")
111+
abspath = path.join(root, conf_file)
112+
if path.exists(abspath):
113+
with open(abspath, "r") as f:
114+
return f.read()
115+
else:
116+
return None
117+
118+
101119
def read_conf_file(conf_file: str) -> str:
102-
builtin_conf = path.join(_CONFIG_DIR, conf_file)
120+
builtin_conf = entrypoints.load(
121+
"torchx.file",
122+
"get_file_contents",
123+
default=get_file_contents,
124+
)(str(_CONFIG_DIR / conf_file))
103125

104126
# user provided conf file precedes the builtin config
105127
# just print a warning but use the user provided one
106128
if path.exists(conf_file):
107-
if path.exists(builtin_conf):
129+
if builtin_conf:
108130
warnings.warn(
109131
f"The provided config file: {conf_file} overlaps"
110132
f" with a built-in. It is recommended that you either"
111133
f" rename the config file or use abs path."
112134
f" Will use: {path.abspath(conf_file)} for this run."
113135
)
114-
else: # conf_file does not exist fallback to builtin
115-
conf_file = builtin_conf
116-
117-
if not path.exists(conf_file):
136+
with open(conf_file, "r") as f:
137+
return f.read()
138+
elif builtin_conf: # conf_file does not exist fallback to builtin
139+
return builtin_conf
140+
else: # neither conf_file nor builtin exists, raise error
118141
raise FileNotFoundError(
119-
f"{conf_file} does not exist and/or is not a builtin."
142+
f"{conf_file} does not exist and is not a builtin."
120143
" For a list of available builtins run `torchx builtins`"
121144
)
122145

123-
with open(conf_file, "r") as f:
124-
return f.read()
125-
126146

127147
def _builtins() -> List[str]:
128148
builtins: List[str] = []
@@ -143,7 +163,7 @@ def run(self, args: argparse.Namespace) -> None:
143163
num_builtins = len(builtin_configs)
144164
print(f"Found {num_builtins} builtin configs:")
145165
for i, name in enumerate(builtin_configs):
146-
print(f" {i+1:2d}. {name}")
166+
print(f" {i + 1:2d}. {name}")
147167

148168

149169
class CmdRun(SubCommand):

torchx/util/__init__.py

Whitespace-only changes.

torchx/util/entrypoints.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from importlib import metadata
8+
from importlib.metadata import EntryPoint
9+
from typing import Dict
10+
11+
12+
# pyre-ignore-all-errors[3, 2]
13+
def load(group: str, name: str, default=None):
14+
"""
15+
Loads the entry point specified by
16+
17+
::
18+
19+
[group]
20+
name1 = this.is:a_function
21+
-- or --
22+
name2 = this.is.a.module
23+
24+
In case such an entry point is not found, an optional
25+
default is returned. If the default is not specified
26+
and the entry point is not found, then this method
27+
raises an error.
28+
"""
29+
30+
entrypoints = metadata.entry_points()
31+
32+
if group not in entrypoints and default:
33+
return default
34+
35+
eps: Dict[str, EntryPoint] = {ep.name: ep for ep in entrypoints[group]}
36+
37+
if name not in eps and default:
38+
return default
39+
else:
40+
ep = eps[name]
41+
return ep.load()

torchx/util/test/__init__.py

Whitespace-only changes.

torchx/util/test/entrypoints_test.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
from importlib.metadata import EntryPoint
9+
from typing import Dict
10+
from unittest.mock import MagicMock, patch
11+
12+
from torchx.util.entrypoints import load
13+
14+
15+
def foobar() -> str:
16+
return "foobar"
17+
18+
19+
def barbaz() -> str:
20+
return "barbaz"
21+
22+
23+
_ENTRY_POINT_TXT: str = """
24+
[entrypoints.test]
25+
foo = torchx.util.test.entrypoints_test:foobar
26+
"""
27+
28+
_ENTRY_POINTS: Dict[str, EntryPoint] = {
29+
# pyre-ignore[16]
30+
"entrypoints.test": EntryPoint._from_text(_ENTRY_POINT_TXT)
31+
}
32+
33+
_METADATA_EPS: str = "torchx.util.entrypoints.metadata.entry_points"
34+
35+
36+
class EntryPointsTest(unittest.TestCase):
37+
@patch(_METADATA_EPS, return_value=_ENTRY_POINTS)
38+
def test_load(self, mock_md_eps: MagicMock) -> None:
39+
print(type(load("entrypoints.test", "foo")))
40+
self.assertEqual("foobar", load("entrypoints.test", "foo")())
41+
42+
@patch(_METADATA_EPS, return_value=_ENTRY_POINTS)
43+
def test_load_with_default(self, mock_md_eps: MagicMock) -> None:
44+
self.assertEqual("barbaz", load("entrypoints.test", "missing", barbaz)())
45+
self.assertEqual("barbaz", load("entrypoints.missing", "foo", barbaz)())
46+
self.assertEqual("barbaz", load("entrypoints.missing", "missing", barbaz)())

0 commit comments

Comments
 (0)