22
22
import json
23
23
import os
24
24
25
- os .environ ['PYTORCH_CUDA_ALLOC_CONF' ] = 'expandable_segments:True'
26
- os .environ ["TF_CPP_MIN_LOG_LEVEL" ] = "2" # Disables tensorRT, cuda warnings.
27
- # disable only for deepspeech if it works fine for other workloads.
28
- os .environ ['XLA_FLAGS' ] = '--xla_gpu_enable_triton_gemm=false'
29
-
30
25
import struct
31
26
import time
32
27
from types import MappingProxyType
56
51
from algorithmic_efficiency .pytorch_utils import sync_ddp_time
57
52
from algorithmic_efficiency .workloads import workloads
58
53
54
+ # Environment variables
55
+ os .environ ["TF_CPP_MIN_LOG_LEVEL" ] = "2" # Disables tensorRT, cuda warnings.
56
- # Disable only for deepspeech if it works fine for other workloads.
57
+ os .environ ['XLA_FLAGS' ] = '--xla_gpu_enable_triton_gemm=false'
58
+
59
59
# TODO(znado): make a nicer registry to look up workloads in.
60
60
BASE_WORKLOADS_DIR = workloads .BASE_WORKLOADS_DIR
61
61
@@ -681,6 +681,14 @@ def main(_):
681
681
else :
682
682
profiler = PassThroughProfiler ()
683
683
684
+ # Set PyTorch environment variables before initializing with DDP.
685
+ base_workload = workloads .get_base_workload_name (FLAGS .workload )
686
+ if base_workload == 'librispeech_conformer' :
687
+ os .environ ['PYTORCH_CUDA_ALLOC_CONF' ] = 'expandable_segments:True'
688
+
689
+ if FLAGS .set_pytorch_max_split_size :
690
+ os .environ ['PYTORCH_CUDA_ALLOC_CONF' ] = 'max_split_size_mb:256'
691
+
684
692
if FLAGS .framework == 'pytorch' :
685
693
pytorch_init (USE_PYTORCH_DDP , RANK , profiler )
686
694
@@ -692,9 +700,6 @@ def main(_):
692
700
693
701
workload_metadata = WORKLOADS [FLAGS .workload ]
694
702
695
- # Prevent OOM on librispeech conformer.
696
- base_workload = workloads .get_base_workload_name (FLAGS .workload )
697
-
698
703
if base_workload in [
699
704
'librispeech_conformer' ,
700
705
'librispeech_deepspeech' ,
@@ -703,13 +708,6 @@ def main(_):
703
708
]:
704
709
os .environ ['XLA_PYTHON_CLIENT_MEM_FRACTION' ] = '0.80'
705
710
706
- if base_workload != 'librispeech_conformer' :
707
- # Remove the environment variable (only for workloads other than librispeech conformer).
708
- del os .environ ['PYTORCH_CUDA_ALLOC_CONF' ]
709
-
710
- if FLAGS .set_pytorch_max_split_size :
711
- os .environ ['PYTORCH_CUDA_ALLOC_CONF' ] = 'max_split_size_mb:256'
712
-
713
711
# Extend path according to framework.
714
712
workload_metadata ['workload_path' ] = os .path .join (
715
713
BASE_WORKLOADS_DIR ,
0 commit comments