Skip to content

Commit d02e572

Browse files
d4l3kfacebook-github-bot
authored andcommitted
fix tests + torchx CLI (#6)
Summary: This fixes the tests so they pass on OSS + make it so torchx entry point is correctly setup. It also fixes a lot of imports to point to the torchx copy so it works in OSS since torchelastic is out of date there. Pull Request resolved: #6 Test Plan: Imported from GitHub, without a `Test Plan:` line. buck test //torchx/... run OSS CI Reviewed By: tierex Differential Revision: D28614651 Pulled By: d4l3k fbshipit-source-id: f8cd9577ac8ec2dbadf73d3539a4cbeb31b76bfa
1 parent 439f3fd commit d02e572

File tree

11 files changed

+46
-24
lines changed

11 files changed

+46
-24
lines changed

dev-requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
boto3>=1.11.11
22
moto>=2.0.6
33
kfp==1.4.0
4+
pyre-extensions>=0.0.21
5+
black>=21.5b1
6+
isort>=5.8.0

setup.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ def get_version():
3232
with open("requirements.txt") as f:
3333
reqs = f.read()
3434

35+
with open("dev-requirements.txt") as f:
36+
test_reqs = f.read()
37+
3538
version = get_version()
3639
print("-- Building version: " + version)
3740

@@ -49,9 +52,15 @@ def get_version():
4952
keywords=["pytorch", "machine learning"],
5053
python_requires=">=3.8",
5154
install_requires=reqs.strip().split("\n"),
55+
tests_requires=test_reqs.strip().split("\n"),
5256
include_package_data=True,
5357
packages=find_packages(exclude=("*.test", "aws*", "*.fb")),
5458
test_suite="torchx.test.suites.unittests",
59+
entry_points={
60+
"console_scripts": [
61+
"torchx=torchx.cli.main:main",
62+
],
63+
},
5564
# PyPI package information.
5665
classifiers=[
5766
"Development Status :: 4 - Beta",

torchx/cli/cmd_describe.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
import dataclasses
1010
import pprint
1111

12-
import torchelastic.tsm.driver as tsm
1312
from torchx.cli.cmd_base import SubCommand
13+
from torchx.runner import get_runner
14+
from torchx.specs import api
1415

1516

1617
class CmdDescribe(SubCommand):
@@ -23,9 +24,9 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
2324

2425
def run(self, args: argparse.Namespace) -> None:
2526
app_handle = args.app_handle
26-
scheduler, session_name, app_id = tsm.parse_app_handle(app_handle)
27-
session = tsm.session(name=session_name)
28-
app = session.describe(app_handle)
27+
scheduler, session_name, app_id = api.parse_app_handle(app_handle)
28+
runner = get_runner(name=session_name)
29+
app = runner.describe(app_handle)
2930

3031
if app:
3132
pprint.pprint(dataclasses.asdict(app), indent=2, width=80)

torchx/cli/cmd_runopts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
import argparse
99

10-
import torchelastic.tsm.driver as tsm
1110
from torchx.cli.cmd_base import SubCommand
11+
from torchx.runner.api import get_runner
1212

1313

1414
class CmdRunopts(SubCommand):
@@ -22,7 +22,7 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
2222

2323
def run(self, args: argparse.Namespace) -> None:
2424
scheduler = args.scheduler
25-
run_opts = tsm.session(name="default").run_opts()
25+
run_opts = get_runner().run_opts()
2626

2727
if not scheduler:
2828
print(run_opts)

torchx/cli/cmd_status.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
from string import Template
1313
from typing import List, Optional, Pattern
1414

15-
import torchelastic.tsm.driver as tsm
1615
from torchx.cli.cmd_base import SubCommand
16+
from torchx.runner import get_runner
17+
from torchx.specs import api
1718
from torchx.specs.api import NONE
1819

1920
_APP_STATUS_FORMAT_TEMPLATE = """Application:
@@ -71,7 +72,7 @@ def format_error_message(msg: str, header: str, width: int = 80) -> str:
7172
return "\n".join(lines)
7273

7374

74-
def format_replica_status(replica_status: tsm.ReplicaStatus) -> str:
75+
def format_replica_status(replica_status: api.ReplicaStatus) -> str:
7576
if replica_status.structured_error_msg != NONE:
7677
error_data = json.loads(replica_status.structured_error_msg)
7778
error_message = format_error_message(
@@ -88,8 +89,8 @@ def format_replica_status(replica_status: tsm.ReplicaStatus) -> str:
8889
else:
8990
data = f"{replica_status.state}"
9091
if replica_status.state in [
91-
tsm.ReplicaState.CANCELLED,
92-
tsm.ReplicaState.FAILED,
92+
api.ReplicaState.CANCELLED,
93+
api.ReplicaState.FAILED,
9394
]:
9495
data += " (no reply file)"
9596

@@ -102,7 +103,7 @@ def format_replica_status(replica_status: tsm.ReplicaStatus) -> str:
102103

103104

104105
def format_role_status(
105-
role_status: tsm.RoleStatus,
106+
role_status: api.RoleStatus,
106107
) -> str:
107108
replica_data = ""
108109

@@ -112,15 +113,15 @@ def format_role_status(
112113

113114

114115
def get_roles(
115-
roles: List[tsm.RoleStatus], filter_roles: Optional[List[str]] = None
116-
) -> List[tsm.RoleStatus]:
116+
roles: List[api.RoleStatus], filter_roles: Optional[List[str]] = None
117+
) -> List[api.RoleStatus]:
117118
if not filter_roles:
118119
return roles
119120
return [role_status for role_status in roles if role_status.role in filter_roles]
120121

121122

122123
def format_app_status(
123-
app_status: tsm.AppStatus,
124+
app_status: api.AppStatus,
124125
filter_roles: Optional[List[str]] = None,
125126
) -> str:
126127
roles_data = ""
@@ -153,9 +154,9 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
153154

154155
def run(self, args: argparse.Namespace) -> None:
155156
app_handle = args.app_handle
156-
scheduler, session_name, app_id = tsm.parse_app_handle(app_handle)
157-
session = tsm.session(name=session_name)
158-
app_status = session.status(app_handle)
157+
scheduler, session_name, app_id = api.parse_app_handle(app_handle)
158+
runner = get_runner(name=session_name)
159+
app_status = runner.status(app_handle)
159160
filter_roles = parse_list_arg(args.roles)
160161
if app_status:
161162
print(format_app_status(app_status, filter_roles))

torchx/cli/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def create_parser() -> ArgumentParser:
4848
return parser
4949

5050

51-
def main(argv: List[str]) -> None:
51+
def main(argv: List[str] = sys.argv[1:]) -> None:
5252
parser = create_parser()
5353
args = parser.parse_args(argv)
5454
if "func" not in args:
@@ -58,4 +58,4 @@ def main(argv: List[str]) -> None:
5858

5959

6060
if __name__ == "__main__":
61-
main(sys.argv[1:])
61+
main()

torchx/cli/test/__init__.py

Whitespace-only changes.

torchx/cli/test/cmd_describe_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torchx.specs.api import Application, Container, ElasticRole, Resource
1414

1515

16-
class CmdStatusTest(unittest.TestCase):
16+
class CmdDescribeTest(unittest.TestCase):
1717
def get_test_app(self) -> Application:
1818
resource = Resource(cpu=2, gpu=0, memMB=256)
1919
trainer = (

torchx/cli/test/cmd_status_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
# LICENSE file in the root directory of this source tree.
77

88
import argparse
9+
import os
10+
import time
911
import unittest
1012
from unittest.mock import patch
1113

@@ -71,6 +73,9 @@ def _get_test_app_status(self) -> AppStatus:
7173
return AppStatus(state=AppState.RUNNING, roles=[role_status])
7274

7375
def test_format_app_status(self) -> None:
76+
os.environ["TZ"] = "Europe/London"
77+
time.tzset()
78+
7479
app_status = self._get_test_app_status()
7580
actual_message = format_app_status(app_status)
7681
print(actual_message)
@@ -79,7 +84,7 @@ def test_format_app_status(self) -> None:
7984
Num Restarts: 0
8085
Roles:
8186
*worker[0]:FAILED (exitcode: -1)
82-
timestamp: 1970-01-15 15:13:02
87+
timestamp: 1970-01-16 00:13:02
8388
hostname: localhost
8489
error_msg: error
8590
worker[1]:RUNNING"""

torchx/cli/test/main_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import unittest
99
from pathlib import Path
1010

11-
import torchelastic.tsm.driver as tsm
1211
from torchx.cli.cmd_run import _parse_run_config
1312
from torchx.cli.main import main
13+
from torchx.specs import api
1414

1515

1616
_root: Path = Path(__file__).parent
@@ -48,12 +48,12 @@ def test_run_builtin_config(self) -> None:
4848
)
4949

5050
def test_run_scheduler_args_empty(self) -> None:
51-
self.assertEqual(_parse_run_config(""), tsm.RunConfig())
51+
self.assertEqual(_parse_run_config(""), api.RunConfig())
5252

5353
def test_run_scheduler_args_simple(self) -> None:
5454
self.assertEqual(
5555
_parse_run_config("a=1,b=2;3;4"),
56-
tsm.RunConfig(
56+
api.RunConfig(
5757
cfgs={
5858
"a": "1",
5959
"b": ["2", "3", "4"],

0 commit comments

Comments
 (0)