File tree 10 files changed +82
-18
lines changed
examples/lightning_classy_vision
10 files changed +82
-18
lines changed Original file line number Diff line number Diff line change
1
+ name : Pyre
2
+
3
+ on :
4
+ push :
5
+ branches :
6
+ - master
7
+ pull_request :
8
+
9
+ jobs :
10
+ pyre :
11
+ runs-on : ubuntu-18.04
12
+ steps :
13
+ - name : Setup Python
14
+ uses : actions/setup-python@v2
15
+ with :
16
+ python-version : 3.8
17
+ architecture : x64
18
+ - name : Checkout TorchX
19
+ uses : actions/checkout@v2
20
+ - name : Install Dependencies
21
+ run : |
22
+ set -eux
23
+ pip install -r dev-requirements.txt
24
+ pip install pyre-check
25
+ - name : Run Pyre
26
+ run : scripts/pyre.sh
Original file line number Diff line number Diff line change
1
+ {
2
+ "source_directories": [
3
+ "."
4
+ ],
5
+ "strict": true,
6
+ "exclude": [
7
+ ".*/build/.*",
8
+ ".*/docs/.*",
9
+ ".*/setup.py"
10
+ ]
11
+ }
Original file line number Diff line number Diff line change
1
+ {
2
+ "root_files" : [
3
+ " torchx" ,
4
+ " .pyre_configuration" ,
5
+ " .watchmanconfig"
6
+ ]
7
+ }
Original file line number Diff line number Diff line change @@ -4,3 +4,7 @@ kfp==1.4.0
4
4
pyre-extensions>=0.0.21
5
5
black>=21.5b1
6
6
isort>=5.8.0
7
+ pytorch-lightning>=0.5.3
8
+ torch>=1.8.1
9
+ torchvision>=0.9.1
10
+ classy-vision>=0.5.0
Original file line number Diff line number Diff line change
1
+ #! /bin/sh
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ set -eux
9
+
10
+ SITE_PACKAGES=$( python -c " import site; print(site.getsitepackages()[0])" )
11
+ pyre --search-path " ${SITE_PACKAGES} " check
Original file line number Diff line number Diff line change @@ -25,21 +25,21 @@ arguments:
25
25
import torchx.specs.api as torchx
26
26
import torchx.schedulers.fb.resource as resource
27
27
28
- container = torchx.Container(image=args.image).require(resources=resource.get(args.resource))
28
+ container = torchx.Container(image=args.image).require(
29
+ resources=resource.get(args.resource)
30
+ )
29
31
entrypoint = "main"
30
32
31
33
trainer_role = (
32
- torchx.Role(
33
- name="trainer"
34
- )
34
+ torchx.Role(name="trainer")
35
35
.runs(
36
- "main",
37
- "--output_path",
38
- args.output_path,
39
- "--load_path",
40
- args.load_path,
41
- "--log_dir",
42
- args.log_dir,
36
+ "main",
37
+ "--output_path",
38
+ args.output_path,
39
+ "--load_path",
40
+ args.load_path,
41
+ "--log_dir",
42
+ args.log_dir,
43
43
)
44
44
.on(container)
45
45
.replicas(1)
Original file line number Diff line number Diff line change 11
11
from itertools import chain
12
12
13
13
14
- def _circleci_parallelism (suite ) :
14
+ def _circleci_parallelism (suite : unittest . TestSuite ) -> unittest . TestSuite :
15
15
"""Allow for parallelism in CircleCI for speedier tests.."""
16
16
if int (os .environ .get ("CIRCLE_NODE_TOTAL" , 0 )) <= 1 :
17
17
# either not running on circleci, or we're not using parallelism.
@@ -23,14 +23,15 @@ def _circleci_parallelism(suite):
23
23
24
24
# right now each test is corresponds to a /file/. Certain files are slower than
25
25
# others, so we want to flatten it
26
+ # pyre-fixme[16]: `TestCase` has no attribute `_tests`.
26
27
tests = [testfile ._tests for testfile in suite ._tests ]
27
28
tests = list (chain .from_iterable (tests ))
28
29
random .Random (42 ).shuffle (tests )
29
30
tests = [t for i , t in enumerate (tests ) if i % total == index ]
30
31
return unittest .TestSuite (tests )
31
32
32
33
33
- def unittests ():
34
+ def unittests () -> unittest . TestSuite :
34
35
"""
35
36
Short tests.
36
37
Original file line number Diff line number Diff line change 5
5
# This source code is licensed under the BSD-style license found in the
6
6
# LICENSE file in the root directory of this source tree.
7
7
8
+ import importlib
8
9
import json
9
10
import os .path
10
11
import tempfile
11
12
import unittest
12
- from typing import TypedDict , Optional
13
+ from typing import Optional , TypedDict
13
14
14
15
import yaml
15
16
from torchx .runtime .component import Component
16
17
from torchx .runtime .container .main import main
17
18
from torchx .runtime .plugins import TORCHX_CONFIG_ENV
18
- from torchx .runtime .storage import temppath , upload_blob , download_blob
19
+ from torchx .runtime .storage import download_blob , temppath , upload_blob
19
20
20
21
21
22
class SubConfig (TypedDict ):
@@ -126,6 +127,8 @@ def test_config_plugins(self) -> None:
126
127
"""
127
128
from torchx .runtime .test import dummy_module
128
129
130
+ importlib .reload (dummy_module )
131
+
129
132
module = "torchx.runtime.test.dummy_module"
130
133
config = {
131
134
"plugins" : {
Original file line number Diff line number Diff line change 13
13
14
14
15
15
def get_schedulers (
16
- session_name : str , ** scheduler_params
16
+ session_name : str , ** scheduler_params : object
17
17
) -> Dict [SchedulerBackend , Scheduler ]:
18
18
return {
19
19
"local" : local_scheduler .create_scheduler (session_name , ** scheduler_params ),
Original file line number Diff line number Diff line change 11
11
from itertools import chain
12
12
13
13
14
- def _circleci_parallelism (suite ) :
14
+ def _circleci_parallelism (suite : unittest . TestSuite ) -> unittest . TestSuite :
15
15
"""Allow for parallelism in CircleCI for speedier tests.."""
16
16
if int (os .environ .get ("CIRCLE_NODE_TOTAL" , 0 )) <= 1 :
17
17
# either not running on circleci, or we're not using parallelism.
@@ -23,14 +23,15 @@ def _circleci_parallelism(suite):
23
23
24
24
# right now each test is corresponds to a /file/. Certain files are slower than
25
25
# others, so we want to flatten it
26
+ # pyre-fixme[16]: `TestCase` has no attribute `_tests`.
26
27
tests = [testfile ._tests for testfile in suite ._tests ]
27
28
tests = list (chain .from_iterable (tests ))
28
29
random .Random (42 ).shuffle (tests )
29
30
tests = [t for i , t in enumerate (tests ) if i % total == index ]
30
31
return unittest .TestSuite (tests )
31
32
32
33
33
- def unittests ():
34
+ def unittests () -> unittest . TestSuite :
34
35
"""
35
36
Short tests.
36
37
You can’t perform that action at this time.
0 commit comments