5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
import argparse
8
- import glob
9
- import importlib
10
- import os
11
- from dataclasses import dataclass
12
- from inspect import getmembers , isfunction
13
- from typing import Dict , List , Optional , Union
8
+ from dataclasses import asdict
9
+ from pprint import pformat
10
+ from typing import Dict , List , Union , cast
14
11
15
12
import torchx .specs as specs
16
13
from pyre_extensions import none_throws
17
14
from torchx .cli .cmd_base import SubCommand
18
15
from torchx .runner import get_runner
19
- from torchx .specs .file_linter import get_fn_docstring , validate
20
- from torchx .util import entrypoints
21
- from torchx .util .io import COMPONENTS_DIR , get_abspath , read_conf_file
16
+ from torchx .specs .finder import get_components , _Component
22
17
from torchx .util .types import to_dict
23
18
24
19
@@ -38,96 +33,19 @@ def _parse_run_config(arg: str) -> specs.RunConfig:
38
33
return conf
39
34
40
35
41
- def _to_module (filepath : str ) -> str :
42
- path , _ = os .path .splitext (filepath )
43
- return path .replace (os .path .sep , "." )
44
-
45
-
46
- def _get_builtin_description (filepath : str , function_name : str ) -> Optional [str ]:
47
- source = read_conf_file (filepath )
48
- if len (validate (source , torchx_function = function_name )) != 0 :
49
- return None
50
-
51
- func_definition , _ = none_throws (get_fn_docstring (source , function_name ))
52
- return func_definition
53
-
54
-
55
- @dataclass
56
- class BuiltinComponent :
57
- definition : str
58
- description : str
59
-
60
-
61
- def _get_component_definition (module : str , function_name : str ) -> str :
62
- if module .startswith ("torchx.components" ):
63
- module = module .split ("torchx.components." )[1 ]
64
- return f"{ module } .{ function_name } "
65
-
66
-
67
- def _to_relative (filepath : str ) -> str :
68
- if os .path .isabs (filepath ):
69
- # make path torchx/components/$suffix out of the abs
70
- rel_path = filepath .split (str (COMPONENTS_DIR ))[1 ]
71
- return f"{ str (COMPONENTS_DIR )} { rel_path } "
72
- else :
73
- return os .path .join (str (COMPONENTS_DIR ), filepath )
74
-
75
-
76
- def _get_components_from_file (filepath : str ) -> List [BuiltinComponent ]:
77
- components_path = _to_relative (filepath )
78
- components_module_path = _to_module (components_path )
79
- module = importlib .import_module (components_module_path )
80
- functions = getmembers (module , isfunction )
81
- buitin_functions = []
82
- for function_name , _ in functions :
83
- # Ignore private functions.
84
- if function_name .startswith ("_" ):
85
- continue
86
- component_desc = _get_builtin_description (filepath , function_name )
87
- if component_desc :
88
- definition = _get_component_definition (
89
- components_module_path , function_name
90
- )
91
- builtin_component = BuiltinComponent (
92
- definition = definition ,
93
- description = component_desc ,
94
- )
95
- buitin_functions .append (builtin_component )
96
- return buitin_functions
97
-
98
-
99
- def _allowed_path (path : str ) -> bool :
100
- filename = os .path .basename (path )
101
- if filename .startswith ("_" ):
102
- return False
103
- return True
104
-
105
-
106
- def _builtins () -> List [BuiltinComponent ]:
107
- components_dir = entrypoints .load (
108
- "torchx.file" , "get_dir_path" , default = get_abspath
109
- )(COMPONENTS_DIR )
110
-
111
- builtins : List [BuiltinComponent ] = []
112
- search_pattern = os .path .join (components_dir , "**" , "*.py" )
113
- for filepath in glob .glob (search_pattern , recursive = True ):
114
- if not _allowed_path (filepath ):
115
- continue
116
- components = _get_components_from_file (filepath )
117
- builtins += components
118
- return builtins
119
-
120
-
121
36
class CmdBuiltins (SubCommand ):
122
37
def add_arguments (self , subparser : argparse .ArgumentParser ) -> None :
123
- pass # no arguments
38
+ pass
39
+
40
+ def _builtins (self ) -> Dict [str , _Component ]:
41
+ return get_components ()
124
42
125
43
def run (self , args : argparse .Namespace ) -> None :
126
- builtin_configs = _builtins ()
127
- num_builtins = len (builtin_configs )
44
+ builtin_components = self . _builtins ()
45
+ num_builtins = len (builtin_components )
128
46
print (f"Found { num_builtins } builtin configs:" )
129
- for i , component in enumerate (builtin_configs ):
130
- print (f" { i + 1 :2d} . { component .definition } - { component . description } " )
47
+ for i , component in enumerate (builtin_components . values () ):
48
+ print (f" { i + 1 :2d} . { component .name } " )
131
49
132
50
133
51
class CmdRun (SubCommand ):
@@ -172,15 +90,23 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
172
90
def run (self , args : argparse .Namespace ) -> None :
173
91
# TODO: T91790598 - remove the if condition when all apps are migrated to pure python
174
92
runner = get_runner ()
175
- app_handle = runner .run_from_path (
93
+ result = runner .run_component (
176
94
args .conf_file ,
177
95
args .conf_args ,
178
96
args .scheduler ,
179
97
args .scheduler_args ,
180
98
dryrun = args .dryrun ,
181
99
)
182
100
183
- if not args .dryrun :
101
+ if args .dryrun :
102
+ app_dryrun_info = cast (specs .AppDryRunInfo , result )
103
+ print ("=== APPLICATION ===" )
104
+ print (pformat (asdict (app_dryrun_info ._app ), indent = 2 , width = 80 ))
105
+
106
+ print ("=== SCHEDULER REQUEST ===" )
107
+ print (app_dryrun_info )
108
+ else :
109
+ app_handle = cast (specs .AppHandle , result )
184
110
if args .scheduler == "local" :
185
111
runner .wait (app_handle )
186
112
else :
0 commit comments