Skip to content

Commit b7e1fd0

Browse files
d4l3kfacebook-github-bot
authored andcommitted
added pyre support (#8)
Summary: This runs the pyre type checking from OSS. It also formats some files + adds some types + small test fixes. Pull Request resolved: #8 Reviewed By: kiukchung Differential Revision: D28652691 Pulled By: d4l3k fbshipit-source-id: 09e4ef1f82c1240a8da19b0018081a36b0f0bafd
1 parent c497683 commit b7e1fd0

File tree

10 files changed

+82
-18
lines changed

10 files changed

+82
-18
lines changed

.github/workflows/pyre.yaml

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
name: Pyre
2+
3+
on:
4+
push:
5+
branches:
6+
- master
7+
pull_request:
8+
9+
jobs:
10+
pyre:
11+
runs-on: ubuntu-18.04
12+
steps:
13+
- name: Setup Python
14+
uses: actions/setup-python@v2
15+
with:
16+
python-version: 3.8
17+
architecture: x64
18+
- name: Checkout TorchX
19+
uses: actions/checkout@v2
20+
- name: Install Dependencies
21+
run: |
22+
set -eux
23+
pip install -r dev-requirements.txt
24+
pip install pyre-check
25+
- name: Run Pyre
26+
run: scripts/pyre.sh

.pyre_configuration

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"source_directories": [
3+
"."
4+
],
5+
"strict": true,
6+
"exclude": [
7+
".*/build/.*",
8+
".*/docs/.*",
9+
".*/setup.py"
10+
]
11+
}

.watchmanconfig

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"root_files": [
3+
"torchx",
4+
".pyre_configuration",
5+
".watchmanconfig"
6+
]
7+
}

dev-requirements.txt

+4
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,7 @@ kfp==1.4.0
44
pyre-extensions>=0.0.21
55
black>=21.5b1
66
isort>=5.8.0
7+
pytorch-lightning>=0.5.3
8+
torch>=1.8.1
9+
torchvision>=0.9.1
10+
classy-vision>=0.5.0

scripts/pyre.sh

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#!/bin/sh
2+
# Copyright (c) Facebook, Inc. and its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
set -eux
9+
10+
SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])")
11+
pyre --search-path "${SITE_PACKAGES}" check

torchx/examples/lightning_classy_vision/lightning_classy_vision.torchx

+11-11
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,21 @@ arguments:
2525
import torchx.specs.api as torchx
2626
import torchx.schedulers.fb.resource as resource
2727

28-
container = torchx.Container(image=args.image).require(resources=resource.get(args.resource))
28+
container = torchx.Container(image=args.image).require(
29+
resources=resource.get(args.resource)
30+
)
2931
entrypoint = "main"
3032

3133
trainer_role = (
32-
torchx.Role(
33-
name="trainer"
34-
)
34+
torchx.Role(name="trainer")
3535
.runs(
36-
"main",
37-
"--output_path",
38-
args.output_path,
39-
"--load_path",
40-
args.load_path,
41-
"--log_dir",
42-
args.log_dir,
36+
"main",
37+
"--output_path",
38+
args.output_path,
39+
"--load_path",
40+
args.load_path,
41+
"--log_dir",
42+
args.log_dir,
4343
)
4444
.on(container)
4545
.replicas(1)

torchx/pipelines/kfp/test/suites.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from itertools import chain
1212

1313

14-
def _circleci_parallelism(suite):
14+
def _circleci_parallelism(suite: unittest.TestSuite) -> unittest.TestSuite:
1515
"""Allow for parallelism in CircleCI for speedier tests.."""
1616
if int(os.environ.get("CIRCLE_NODE_TOTAL", 0)) <= 1:
1717
# either not running on circleci, or we're not using parallelism.
@@ -23,14 +23,15 @@ def _circleci_parallelism(suite):
2323

2424
# right now each test is corresponds to a /file/. Certain files are slower than
2525
# others, so we want to flatten it
26+
# pyre-fixme[16]: `TestCase` has no attribute `_tests`.
2627
tests = [testfile._tests for testfile in suite._tests]
2728
tests = list(chain.from_iterable(tests))
2829
random.Random(42).shuffle(tests)
2930
tests = [t for i, t in enumerate(tests) if i % total == index]
3031
return unittest.TestSuite(tests)
3132

3233

33-
def unittests():
34+
def unittests() -> unittest.TestSuite:
3435
"""
3536
Short tests.
3637

torchx/runtime/container/test/main_test.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8+
import importlib
89
import json
910
import os.path
1011
import tempfile
1112
import unittest
12-
from typing import TypedDict, Optional
13+
from typing import Optional, TypedDict
1314

1415
import yaml
1516
from torchx.runtime.component import Component
1617
from torchx.runtime.container.main import main
1718
from torchx.runtime.plugins import TORCHX_CONFIG_ENV
18-
from torchx.runtime.storage import temppath, upload_blob, download_blob
19+
from torchx.runtime.storage import download_blob, temppath, upload_blob
1920

2021

2122
class SubConfig(TypedDict):
@@ -126,6 +127,8 @@ def test_config_plugins(self) -> None:
126127
"""
127128
from torchx.runtime.test import dummy_module
128129

130+
importlib.reload(dummy_module)
131+
129132
module = "torchx.runtime.test.dummy_module"
130133
config = {
131134
"plugins": {

torchx/schedulers/registry.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
def get_schedulers(
16-
session_name: str, **scheduler_params
16+
session_name: str, **scheduler_params: object
1717
) -> Dict[SchedulerBackend, Scheduler]:
1818
return {
1919
"local": local_scheduler.create_scheduler(session_name, **scheduler_params),

torchx/test/suites.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from itertools import chain
1212

1313

14-
def _circleci_parallelism(suite):
14+
def _circleci_parallelism(suite: unittest.TestSuite) -> unittest.TestSuite:
1515
"""Allow for parallelism in CircleCI for speedier tests.."""
1616
if int(os.environ.get("CIRCLE_NODE_TOTAL", 0)) <= 1:
1717
# either not running on circleci, or we're not using parallelism.
@@ -23,14 +23,15 @@ def _circleci_parallelism(suite):
2323

2424
# right now each test is corresponds to a /file/. Certain files are slower than
2525
# others, so we want to flatten it
26+
# pyre-fixme[16]: `TestCase` has no attribute `_tests`.
2627
tests = [testfile._tests for testfile in suite._tests]
2728
tests = list(chain.from_iterable(tests))
2829
random.Random(42).shuffle(tests)
2930
tests = [t for i, t in enumerate(tests) if i % total == index]
3031
return unittest.TestSuite(tests)
3132

3233

33-
def unittests():
34+
def unittests() -> unittest.TestSuite:
3435
"""
3536
Short tests.
3637

0 commit comments

Comments
 (0)