Skip to content

Commit f6ca2bc

Browse files
committed
env variable for conformer set at the top
1 parent 81bc93d commit f6ca2bc

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

submission_runner.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@
2121
import itertools
2222
import json
2323
import os
24+
25+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
26+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings.
27+
# disable only for deepspeech if it works fine for other workloads.
28+
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'
29+
2430
import struct
2531
import time
2632
from types import MappingProxyType
@@ -30,12 +36,10 @@
3036
from absl import flags
3137
from absl import logging
3238
import jax
39+
import tensorflow as tf
3340
import torch
3441
import torch.distributed as dist
3542

36-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings.
37-
import tensorflow as tf
38-
3943
# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
4044
# it unavailable to JAX.
4145
tf.config.set_visible_devices([], 'GPU')
@@ -52,9 +56,6 @@
5256
from algorithmic_efficiency.pytorch_utils import sync_ddp_time
5357
from algorithmic_efficiency.workloads import workloads
5458

55-
# disable only for deepspeech if it works fine for other workloads.
56-
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'
57-
5859
# TODO(znado): make a nicer registry of workloads that lookup in.
5960
BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR
6061

@@ -702,12 +703,13 @@ def main(_):
702703
]:
703704
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80'
704705

706+
if base_workload != 'librispeech_conformer':
707+
# Remove the environment variable (only for workloads other than librispeech conformer).
708+
del os.environ['PYTORCH_CUDA_ALLOC_CONF']
709+
705710
if FLAGS.set_pytorch_max_split_size:
706711
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
707712

708-
if FLAGS.framework == 'pytorch' and base_workload == 'librispeech_conformer':
709-
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
710-
711713
# Extend path according to framework.
712714
workload_metadata['workload_path'] = os.path.join(
713715
BASE_WORKLOADS_DIR,

0 commit comments

Comments
 (0)