Skip to content

Commit d0392f1

Browse files
hadarohanafacebook-github-bot
authored andcommitted
Use generic TypedDict for runcfg (#482)
Summary: Pull Request resolved: #482 Pull Request resolved: #481 Use generic TypedDict for runcfg rather than Mapping. This eliminates the need for casts and type asserts to use the values. Reviewed By: d4l3k Differential Revision: D36072685 fbshipit-source-id: 3f8d57c385210839fd21c4e2690339055a78a29a
1 parent 8d83333 commit d0392f1

18 files changed

+193
-126
lines changed

torchx/runner/api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class Runner:
4949
def __init__(
5050
self,
5151
name: str,
52+
# pyre-fixme: Scheduler opts
5253
schedulers: Dict[str, Scheduler],
5354
component_defaults: Optional[Dict[str, Dict[str, str]]] = None,
5455
) -> None:
@@ -539,6 +540,7 @@ def log_lines(
539540
)
540541
return log_iter
541542

543+
# pyre-fixme: Scheduler opts
542544
def _scheduler(self, scheduler: str) -> Scheduler:
543545
sched = self._schedulers.get(scheduler)
544546
if not sched:
@@ -548,7 +550,10 @@ def _scheduler(self, scheduler: str) -> Scheduler:
548550
return sched
549551

550552
def _scheduler_app_id(
551-
self, app_handle: AppHandle, check_session: bool = True
553+
self,
554+
app_handle: AppHandle,
555+
check_session: bool = True
556+
# pyre-fixme: Scheduler opts
552557
) -> Tuple[Scheduler, str, str]:
553558
"""
554559
Returns the scheduler and app_id from the app_handle.

torchx/schedulers/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919

2020
class SchedulerFactory(Protocol):
21+
# pyre-fixme: Scheduler opts
2122
def __call__(self, session_name: str, **kwargs: object) -> Scheduler:
2223
...
2324

@@ -71,7 +72,9 @@ def get_default_scheduler_name() -> str:
7172

7273

7374
def get_schedulers(
74-
session_name: str, **scheduler_params: object
75+
session_name: str,
76+
**scheduler_params: object
77+
# pyre-fixme: Scheduler opts
7578
) -> Dict[str, Scheduler]:
7679
"""
7780
get_schedulers returns all available schedulers.

torchx/schedulers/api.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,12 @@
1010
from dataclasses import dataclass, field
1111
from datetime import datetime
1212
from enum import Enum
13-
from typing import Iterable, List, Mapping, Optional
13+
from typing import Generic, Iterable, List, Optional, TypeVar
1414

1515
from torchx.specs import (
1616
AppDef,
1717
AppDryRunInfo,
1818
AppState,
19-
CfgVal,
2019
NONE,
2120
NULL_RESOURCE,
2221
Role,
@@ -62,7 +61,10 @@ class DescribeAppResponse:
6261
roles: List[Role] = field(default_factory=list)
6362

6463

65-
class Scheduler(abc.ABC):
64+
T = TypeVar("T")
65+
66+
67+
class Scheduler(abc.ABC, Generic[T]):
6668
"""
6769
An interface abstracting functionalities of a scheduler.
6870
Implementors need only implement those methods annotated with
@@ -93,7 +95,7 @@ def close(self) -> None:
9395
def submit(
9496
self,
9597
app: AppDef,
96-
cfg: Mapping[str, CfgVal],
98+
cfg: T,
9799
workspace: Optional[str] = None,
98100
) -> str:
99101
"""
@@ -129,7 +131,7 @@ def schedule(self, dryrun_info: AppDryRunInfo) -> str:
129131

130132
raise NotImplementedError()
131133

132-
def submit_dryrun(self, app: AppDef, cfg: Mapping[str, CfgVal]) -> AppDryRunInfo:
134+
def submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo:
133135
"""
134136
Rather than submitting the request to run the app, returns the
135137
request object that would have been submitted to the underlying
@@ -138,7 +140,9 @@ def submit_dryrun(self, app: AppDef, cfg: Mapping[str, CfgVal]) -> AppDryRunInfo
138140
to the scheduler implementation's documentation regarding
139141
the actual return type.
140142
"""
143+
# pyre-fixme: Generic cfg type passed to resolve
141144
resolved_cfg = self.run_opts().resolve(cfg)
145+
# pyre-fixme: _submit_dryrun takes Generic type for resolved_cfg
142146
dryrun_info = self._submit_dryrun(app, resolved_cfg)
143147
for role in app.roles:
144148
dryrun_info = role.pre_proc(self.backend, dryrun_info)
@@ -147,7 +151,7 @@ def submit_dryrun(self, app: AppDef, cfg: Mapping[str, CfgVal]) -> AppDryRunInfo
147151
return dryrun_info
148152

149153
@abc.abstractmethod
150-
def _submit_dryrun(self, app: AppDef, cfg: Mapping[str, CfgVal]) -> AppDryRunInfo:
154+
def _submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo:
151155
raise NotImplementedError()
152156

153157
def run_opts(self) -> runopts:

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
Callable,
4444
Dict,
4545
Iterable,
46-
Mapping,
4746
Optional,
4847
Tuple,
4948
TYPE_CHECKING,
@@ -65,14 +64,14 @@
6564
AppDef,
6665
AppState,
6766
BindMount,
68-
CfgVal,
6967
DeviceMount,
7068
macros,
7169
Role,
7270
runopts,
7371
VolumeMount,
7472
)
7573
from torchx.workspace.docker_workspace import DockerWorkspace
74+
from typing_extensions import TypedDict
7675

7776
if TYPE_CHECKING:
7877
from docker import DockerClient
@@ -224,7 +223,12 @@ def _local_session() -> "boto3.session.Session":
224223
return boto3.session.Session()
225224

226225

227-
class AWSBatchScheduler(Scheduler, DockerWorkspace):
226+
class AWSBatchOpts(TypedDict, total=False):
227+
queue: str
228+
image_repo: Optional[str]
229+
230+
231+
class AWSBatchScheduler(Scheduler[AWSBatchOpts], DockerWorkspace):
228232
"""
229233
AWSBatchScheduler is a TorchX scheduling interface to AWS Batch.
230234
@@ -326,16 +330,14 @@ def schedule(self, dryrun_info: AppDryRunInfo[BatchJob]) -> str:
326330

327331
return f"{req.queue}:{req.name}"
328332

329-
def _submit_dryrun(
330-
self, app: AppDef, cfg: Mapping[str, CfgVal]
331-
) -> AppDryRunInfo[BatchJob]:
333+
def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJob]:
332334
queue = cfg.get("queue")
333335
if not isinstance(queue, str):
334336
raise TypeError(f"config value 'queue' must be a string, got {queue}")
335337
name = make_unique(app.name)
336338

337339
# map any local images to the remote image
338-
images_to_push = self._update_app_images(app, cfg)
340+
images_to_push = self._update_app_images(app, cfg.get("image_repo"))
339341

340342
nodes = []
341343

@@ -390,6 +392,7 @@ def _submit_dryrun(
390392
)
391393
info = AppDryRunInfo(req, repr)
392394
info._app = app
395+
# pyre-fixme: AppDryRunInfo
393396
info._cfg = cfg
394397
return info
395398

torchx/schedulers/docker_scheduler.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import tempfile
1111
from dataclasses import dataclass
1212
from datetime import datetime
13-
from typing import Any, Dict, Iterable, List, Mapping, Optional, TYPE_CHECKING, Union
13+
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union
1414

1515
import torchx
1616
import yaml
@@ -28,7 +28,6 @@
2828
AppDef,
2929
AppState,
3030
BindMount,
31-
CfgVal,
3231
DeviceMount,
3332
is_terminal,
3433
macros,
@@ -39,6 +38,7 @@
3938
VolumeMount,
4039
)
4140
from torchx.workspace.docker_workspace import DockerWorkspace
41+
from typing_extensions import TypedDict
4242

4343

4444
if TYPE_CHECKING:
@@ -93,7 +93,11 @@ def has_docker() -> bool:
9393
return False
9494

9595

96-
class DockerScheduler(Scheduler, DockerWorkspace):
96+
class DockerOpts(TypedDict, total=False):
97+
copy_env: Optional[List[str]]
98+
99+
100+
class DockerScheduler(Scheduler[DockerOpts], DockerWorkspace):
97101
"""
98102
DockerScheduler is a TorchX scheduling interface to Docker.
99103
@@ -187,9 +191,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[DockerJob]) -> str:
187191

188192
return req.app_id
189193

190-
def _submit_dryrun(
191-
self, app: AppDef, cfg: Mapping[str, CfgVal]
192-
) -> AppDryRunInfo[DockerJob]:
194+
def _submit_dryrun(self, app: AppDef, cfg: DockerOpts) -> AppDryRunInfo[DockerJob]:
193195
from docker.types import DeviceRequest, Mount
194196

195197
default_env = {}
@@ -301,6 +303,7 @@ def _submit_dryrun(
301303

302304
info = AppDryRunInfo(req, repr)
303305
info._app = app
306+
# pyre-fixme: AppDryRunInfo
304307
info._cfg = cfg
305308
return info
306309

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
AppDef,
5757
AppState,
5858
BindMount,
59-
CfgVal,
6059
DeviceMount,
6160
macros,
6261
ReplicaState,
@@ -68,6 +67,7 @@
6867
VolumeMount,
6968
)
7069
from torchx.workspace.docker_workspace import DockerWorkspace
70+
from typing_extensions import TypedDict
7171

7272

7373
if TYPE_CHECKING:
@@ -434,7 +434,15 @@ def __repr__(self) -> str:
434434
return str(self)
435435

436436

437-
class KubernetesScheduler(Scheduler, DockerWorkspace):
437+
class KubernetesOpts(TypedDict, total=False):
438+
namespace: Optional[str]
439+
queue: str
440+
image_repo: Optional[str]
441+
service_account: Optional[str]
442+
priority_class: Optional[str]
443+
444+
445+
class KubernetesScheduler(Scheduler[KubernetesOpts], DockerWorkspace):
438446
"""
439447
KubernetesScheduler is a TorchX scheduling interface to Kubernetes.
440448
@@ -590,14 +598,14 @@ def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str:
590598
return f'{namespace}:{resp["metadata"]["name"]}'
591599

592600
def _submit_dryrun(
593-
self, app: AppDef, cfg: Mapping[str, CfgVal]
601+
self, app: AppDef, cfg: KubernetesOpts
594602
) -> AppDryRunInfo[KubernetesJob]:
595603
queue = cfg.get("queue")
596604
if not isinstance(queue, str):
597605
raise TypeError(f"config value 'queue' must be a string, got {queue}")
598606

599607
# map any local images to the remote image
600-
images_to_push = self._update_app_images(app, cfg)
608+
images_to_push = self._update_app_images(app, cfg.get("image_repo"))
601609

602610
service_account = cfg.get("service_account")
603611
assert service_account is None or isinstance(
@@ -616,6 +624,7 @@ def _submit_dryrun(
616624
)
617625
info = AppDryRunInfo(req, repr)
618626
info._app = app
627+
# pyre-fixme: AppDryRunInfo
619628
info._cfg = cfg
620629
return info
621630

0 commit comments

Comments
 (0)