10
10
import warnings
11
11
from dataclasses import asdict
12
12
from os import path
13
+ from pathlib import Path
13
14
from pprint import pformat
14
- from typing import Callable , Iterable , List , Type
15
+ from typing import Callable , Iterable , List , Optional , Type
15
16
16
17
import torchx .specs as specs
17
18
import yaml
18
19
from torchx .cli .cmd_base import SubCommand
19
20
from torchx .cli .conf_helpers import parse_args_children
20
21
from torchx .runner import get_runner
22
+ from torchx .util import entrypoints
21
23
22
24
23
25
class UnsupportFeatureError (Exception ):
@@ -94,35 +96,53 @@ def _parse_run_config(arg: str) -> specs.RunConfig:
94
96
95
97
# TODO kiuk@ move read_conf_file + _builtins to the Runner once the Runner is API stable
96
98
97
- _CONFIG_DIR : str = path . join ( path . dirname ( __file__ ), " config" )
99
+ _CONFIG_DIR : Path = Path ( "torchx/cli/ config" )
98
100
_CONFIG_EXT = ".torchx"
99
101
100
102
103
+ def get_file_contents (conf_file : str ) -> Optional [str ]:
104
+ """
105
+ Reads the ``conf_file`` relative to the root of the project.
106
+ Returns ``None`` if ``$root/$conf_file`` does not exist.
107
+ Example: ``get_file("torchx/cli/config/foo.txt")``
108
+ """
109
+
110
+ root = path .dirname (__file__ ).replace (__name__ .replace ("." , path .sep ), "" )
111
+ abspath = path .join (root , conf_file )
112
+ if path .exists (abspath ):
113
+ with open (abspath , "r" ) as f :
114
+ return f .read ()
115
+ else :
116
+ return None
117
+
118
+
101
119
def read_conf_file (conf_file : str ) -> str :
102
- builtin_conf = path .join (_CONFIG_DIR , conf_file )
120
+ builtin_conf = entrypoints .load (
121
+ "torchx.file" ,
122
+ "get_file_contents" ,
123
+ default = get_file_contents ,
124
+ )(str (_CONFIG_DIR / conf_file ))
103
125
104
126
# user provided conf file precedes the builtin config
105
127
# just print a warning but use the user provided one
106
128
if path .exists (conf_file ):
107
- if path . exists ( builtin_conf ) :
129
+ if builtin_conf :
108
130
warnings .warn (
109
131
f"The provided config file: { conf_file } overlaps"
110
132
f" with a built-in. It is recommended that you either"
111
133
f" rename the config file or use abs path."
112
134
f" Will use: { path .abspath (conf_file )} for this run."
113
135
)
114
- else : # conf_file does not exist fallback to builtin
115
- conf_file = builtin_conf
116
-
117
- if not path .exists (conf_file ):
136
+ with open (conf_file , "r" ) as f :
137
+ return f .read ()
138
+ elif builtin_conf : # conf_file does not exist fallback to builtin
139
+ return builtin_conf
140
+ else : # neither conf_file nor builtin exists, raise error
118
141
raise FileNotFoundError (
119
- f"{ conf_file } does not exist and/or is not a builtin."
142
+ f"{ conf_file } does not exist and is not a builtin."
120
143
" For a list of available builtins run `torchx builtins`"
121
144
)
122
145
123
- with open (conf_file , "r" ) as f :
124
- return f .read ()
125
-
126
146
127
147
def _builtins () -> List [str ]:
128
148
builtins : List [str ] = []
@@ -143,7 +163,7 @@ def run(self, args: argparse.Namespace) -> None:
143
163
num_builtins = len (builtin_configs )
144
164
print (f"Found { num_builtins } builtin configs:" )
145
165
for i , name in enumerate (builtin_configs ):
146
- print (f" { i + 1 :2d} . { name } " )
166
+ print (f" { i + 1 :2d} . { name } " )
147
167
148
168
149
169
class CmdRun (SubCommand ):
0 commit comments