10
10
from dataclasses import dataclass , field
11
11
from datetime import datetime
12
12
from enum import Enum
13
- from typing import Iterable , List , Mapping , Optional
13
+ from typing import Generic , Iterable , List , Optional , TypeVar
14
14
15
15
from torchx .specs import (
16
16
AppDef ,
17
17
AppDryRunInfo ,
18
18
AppState ,
19
- CfgVal ,
20
19
NONE ,
21
20
NULL_RESOURCE ,
22
21
Role ,
@@ -62,7 +61,10 @@ class DescribeAppResponse:
62
61
roles : List [Role ] = field (default_factory = list )
63
62
64
63
65
- class Scheduler (abc .ABC ):
64
+ T = TypeVar ("T" )
65
+
66
+
67
+ class Scheduler (abc .ABC , Generic [T ]):
66
68
"""
67
69
An interface abstracting functionalities of a scheduler.
68
70
Implementors need only implement those methods annotated with
@@ -93,7 +95,7 @@ def close(self) -> None:
93
95
def submit (
94
96
self ,
95
97
app : AppDef ,
96
- cfg : Mapping [ str , CfgVal ] ,
98
+ cfg : T ,
97
99
workspace : Optional [str ] = None ,
98
100
) -> str :
99
101
"""
@@ -129,7 +131,7 @@ def schedule(self, dryrun_info: AppDryRunInfo) -> str:
129
131
130
132
raise NotImplementedError ()
131
133
132
- def submit_dryrun (self , app : AppDef , cfg : Mapping [ str , CfgVal ] ) -> AppDryRunInfo :
134
+ def submit_dryrun (self , app : AppDef , cfg : T ) -> AppDryRunInfo :
133
135
"""
134
136
Rather than submitting the request to run the app, returns the
135
137
request object that would have been submitted to the underlying
@@ -138,7 +140,9 @@ def submit_dryrun(self, app: AppDef, cfg: Mapping[str, CfgVal]) -> AppDryRunInfo
138
140
to the scheduler implementation's documentation regarding
139
141
the actual return type.
140
142
"""
143
+ # pyre-fixme: Generic cfg type passed to resolve
141
144
resolved_cfg = self .run_opts ().resolve (cfg )
145
+ # pyre-fixme: _submit_dryrun takes Generic type for resolved_cfg
142
146
dryrun_info = self ._submit_dryrun (app , resolved_cfg )
143
147
for role in app .roles :
144
148
dryrun_info = role .pre_proc (self .backend , dryrun_info )
@@ -147,7 +151,7 @@ def submit_dryrun(self, app: AppDef, cfg: Mapping[str, CfgVal]) -> AppDryRunInfo
147
151
return dryrun_info
148
152
149
153
@abc .abstractmethod
150
- def _submit_dryrun (self , app : AppDef , cfg : Mapping [ str , CfgVal ] ) -> AppDryRunInfo :
154
+ def _submit_dryrun (self , app : AppDef , cfg : T ) -> AppDryRunInfo :
151
155
raise NotImplementedError ()
152
156
153
157
def run_opts (self ) -> runopts :
0 commit comments