Skip to content

Commit 3e556dc

Browse files
authored
Split training tests to separate test shard (#9281)
1 parent 97b75ea commit 3e556dc

File tree

6 files changed

+54
-28
lines changed

6 files changed

+54
-28
lines changed

.github/workflows/_tpu_ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jobs:
2424
- test/tpu/run_expensive_test_1.sh
2525
- test/tpu/run_expensive_test_2.sh
2626
- test/tpu/run_pallas_test.sh
27+
- test/tpu/run_training_tests.sh
2728
steps:
2829
- name: Checkout actions
2930
if: inputs.has_code_changes == 'true'

examples/data_parallel/train_resnet_spmd_data_parallel.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import sys
22
import os
3+
import time
34
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
45
sys.path.append(example_folder)
56
from train_resnet_base import TrainResNetBase
@@ -46,4 +47,8 @@ def __init__(self):
4647

4748
if __name__ == '__main__':
4849
spmd_ddp = TrainResNetXLASpmdDDP()
50+
51+
start_time = time.time()
4952
spmd_ddp.start_training()
53+
end_time = time.time()
54+
print(f"Finished training in {end_time - start_time:.3f}s")

examples/train_decoder_only_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,11 @@ def start_training(self):
156156
if decoder_cls is not None:
157157
params.append(decoder_cls)
158158
base = TrainDecoderOnlyBase(*params, num_steps=args.num_steps, config=config)
159+
160+
start_time = time.time()
159161
base.start_training()
162+
end_time = time.time()
163+
print(f"Finished training in {end_time - start_time:.3f}s")
160164

161165
if args.print_metrics:
162166
print(torch_xla._XLAC._xla_metrics_report())

examples/train_resnet_amp.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from train_resnet_base import TrainResNetBase
22

33
import itertools
4+
import time
45

56
import torch_xla
67
import torch_xla.distributed.xla_multiprocessing as xmp
@@ -33,4 +34,8 @@ def train_loop_fn(self, loader, epoch):
3334

3435
if __name__ == '__main__':
3536
xla_amp = TrainResNetXLAAMP()
37+
38+
start_time = time.time()
3639
xla_amp.start_training()
40+
end_time = time.time()
41+
print(f"Finished training in {end_time - start_time:.3f}s")

test/tpu/run_tests.sh

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -95,31 +95,3 @@ run_test "$_TEST_DIR/quantized_ops/test_dot_general.py"
9595
run_xla_ir_hlo_debug run_test "$_TEST_DIR/test_user_computation_debug_cache.py"
9696
run_test "$_TEST_DIR/test_data_type.py"
9797
run_test "$_TEST_DIR/test_compilation_cache_utils.py"
98-
99-
# run examples, each test should takes <2 minutes
100-
run_test "$_TEST_DIR/../examples/data_parallel/train_resnet_spmd_data_parallel.py"
101-
run_test "$_TEST_DIR/../examples/fsdp/train_decoder_only_fsdp_v2.py"
102-
run_test "$_TEST_DIR/../examples/train_resnet_amp.py"
103-
run_test "$_TEST_DIR/../examples/train_decoder_only_base.py"
104-
run_test "$_TEST_DIR/../examples/train_decoder_only_base.py" scan.decoder_with_scan.DecoderWithScan \
105-
--num-steps 30 # TODO(https://github.com/pytorch/xla/issues/8632): Reduce scan tracing overhead
106-
107-
# HACK: don't confuse local `torch_xla` folder with installed package
108-
# Python 3.11 has the permanent fix: https://stackoverflow.com/a/73636559
109-
# Egaer tests will take more HBM, only run them on TPU v4 CI
110-
TPU_VERSION=$(python -c "import sys; sys.path.remove(''); import torch_xla; print(torch_xla._internal.tpu.version())")
111-
if [[ -n "$TPU_VERSION" && "$TPU_VERSION" == "4" ]]; then
112-
run_test "$_TEST_DIR/dynamo/test_traceable_collectives.py"
113-
run_test "$_TEST_DIR/../examples/data_parallel/train_resnet_xla_ddp.py"
114-
run_test "$_TEST_DIR/../examples/fsdp/train_resnet_fsdp_auto_wrap.py"
115-
run_test "$_TEST_DIR/../examples/eager/train_decoder_only_eager.py"
116-
run_test "$_TEST_DIR/../examples/eager/train_decoder_only_eager_spmd_data_parallel.py"
117-
run_test "$_TEST_DIR/../examples/eager/train_decoder_only_eager_with_compile.py"
118-
run_test "$_TEST_DIR/../examples/eager/train_decoder_only_eager_multi_process.py"
119-
XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shapes.py" -v
120-
fi
121-
122-
if [[ -n "$TPU_VERSION" && "$TPU_VERSION" != "6" ]]; then
123-
# Test `tpu-info` CLI compatibility
124-
run_test "$_TPU_DIR/tpu_info/test_cli.py"
125-
fi

test/tpu/run_training_tests.sh

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#!/bin/bash
2+
set -xue
3+
4+
# Absolute path to the directory of this script.
5+
_TPU_DIR="$(
6+
cd "$(dirname "$0")"
7+
pwd -P
8+
)"
9+
10+
# Absolute path to the test/ directory.
11+
_TEST_DIR="$(dirname "$_TPU_DIR")"
12+
13+
# run examples, each test should takes <2 minutes
14+
python3 "$_TEST_DIR/../examples/data_parallel/train_resnet_spmd_data_parallel.py"
15+
python3 "$_TEST_DIR/../examples/fsdp/train_decoder_only_fsdp_v2.py"
16+
python3 "$_TEST_DIR/../examples/train_resnet_amp.py"
17+
python3 "$_TEST_DIR/../examples/train_decoder_only_base.py"
18+
python3 "$_TEST_DIR/../examples/train_decoder_only_base.py" scan.decoder_with_scan.DecoderWithScan \
19+
--num-steps 30 # TODO(https://github.com/pytorch/xla/issues/8632): Reduce scan tracing overhead
20+
21+
# HACK: don't confuse local `torch_xla` folder with installed package
22+
# Python 3.11 has the permanent fix: https://stackoverflow.com/a/73636559
23+
# Egaer tests will take more HBM, only run them on TPU v4 CI
24+
TPU_VERSION=$(python -c "import sys; sys.path.remove(''); import torch_xla; print(torch_xla._internal.tpu.version())")
25+
if [[ -n "$TPU_VERSION" && "$TPU_VERSION" == "4" ]]; then
26+
python3 "$_TEST_DIR/dynamo/test_traceable_collectives.py"
27+
python3 "$_TEST_DIR/../examples/data_parallel/train_resnet_xla_ddp.py"
28+
python3 "$_TEST_DIR/../examples/fsdp/train_resnet_fsdp_auto_wrap.py"
29+
python3 "$_TEST_DIR/../examples/eager/train_decoder_only_eager.py"
30+
python3 "$_TEST_DIR/../examples/eager/train_decoder_only_eager_spmd_data_parallel.py"
31+
python3 "$_TEST_DIR/../examples/eager/train_decoder_only_eager_with_compile.py"
32+
python3 "$_TEST_DIR/../examples/eager/train_decoder_only_eager_multi_process.py"
33+
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 "$_TEST_DIR/ds/test_dynamic_shapes.py" -v
34+
fi
35+
36+
if [[ -n "$TPU_VERSION" && "$TPU_VERSION" != "6" ]]; then
37+
# Test `tpu-info` CLI compatibility
38+
python3 "$_TPU_DIR/tpu_info/test_cli.py"
39+
fi

0 commit comments

Comments
 (0)