Skip to content

Commit b4ed6cc

Browse files
committed
Set environment variables for PyTorch before initializing with DDP.
1 parent f6ca2bc commit b4ed6cc

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

submission_runner.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,6 @@
2222
import json
2323
import os
2424

25-
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
26-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings.
27-
# disable only for deepspeech if it works fine for other workloads.
28-
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'
29-
3025
import struct
3126
import time
3227
from types import MappingProxyType
@@ -56,6 +51,11 @@
5651
from algorithmic_efficiency.pytorch_utils import sync_ddp_time
5752
from algorithmic_efficiency.workloads import workloads
5853

54+
# Environment variables
55+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings.
56+
# TODO: Disable Triton GEMM only for DeepSpeech if it works fine for the other workloads.
57+
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'
58+
5959
# TODO(znado): make a nicer registry of workloads that lookup in.
6060
BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR
6161

@@ -681,6 +681,14 @@ def main(_):
681681
else:
682682
profiler = PassThroughProfiler()
683683

684+
# Set PyTorch environment variables before initializing with DDP.
685+
base_workload = workloads.get_base_workload_name(FLAGS.workload)
686+
if base_workload == 'librispeech_conformer':
687+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
688+
689+
if FLAGS.set_pytorch_max_split_size:
690+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
691+
684692
if FLAGS.framework == 'pytorch':
685693
pytorch_init(USE_PYTORCH_DDP, RANK, profiler)
686694

@@ -692,9 +700,6 @@ def main(_):
692700

693701
workload_metadata = WORKLOADS[FLAGS.workload]
694702

695-
# Prevent OOM on librispeech conformer.
696-
base_workload = workloads.get_base_workload_name(FLAGS.workload)
697-
698703
if base_workload in [
699704
'librispeech_conformer',
700705
'librispeech_deepspeech',
@@ -703,13 +708,6 @@ def main(_):
703708
]:
704709
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80'
705710

706-
if base_workload != 'librispeech_conformer':
707-
# Remove the environment variable (only for workloads other than librispeech conformer).
708-
del os.environ['PYTORCH_CUDA_ALLOC_CONF']
709-
710-
if FLAGS.set_pytorch_max_split_size:
711-
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
712-
713711
# Extend path according to framework.
714712
workload_metadata['workload_path'] = os.path.join(
715713
BASE_WORKLOADS_DIR,

0 commit comments

Comments
 (0)