|
21 | 21 | import itertools
|
22 | 22 | import json
|
23 | 23 | import os
|
| 24 | + |
| 25 | +os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' |
| 26 | +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. |
| 27 | +# Disable Triton GEMM in XLA. Restrict this to deepspeech only if Triton GEMM works fine for the other workloads. |
| 28 | +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' |
| 29 | + |
24 | 30 | import struct
|
25 | 31 | import time
|
26 | 32 | from types import MappingProxyType
|
|
30 | 36 | from absl import flags
|
31 | 37 | from absl import logging
|
32 | 38 | import jax
|
| 39 | +import tensorflow as tf |
33 | 40 | import torch
|
34 | 41 | import torch.distributed as dist
|
35 | 42 |
|
36 |
| -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. |
37 |
| -import tensorflow as tf |
38 |
| - |
39 | 43 | # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
|
40 | 44 | # it unavailable to JAX.
|
41 | 45 | tf.config.set_visible_devices([], 'GPU')
|
|
52 | 56 | from algorithmic_efficiency.pytorch_utils import sync_ddp_time
|
53 | 57 | from algorithmic_efficiency.workloads import workloads
|
54 | 58 |
|
55 |
| -# disable only for deepspeech if it works fine for other workloads. |
56 |
| -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' |
57 |
| - |
58 | 59 | # TODO(znado): make a nicer registry of workloads that lookup in.
|
59 | 60 | BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR
|
60 | 61 |
|
@@ -702,12 +703,13 @@ def main(_):
|
702 | 703 | ]:
|
703 | 704 | os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80'
|
704 | 705 |
|
| 706 | + if base_workload != 'librispeech_conformer': |
| 707 | + # Remove the environment variable (only for workloads other than librispeech conformer). |
| 708 | + del os.environ['PYTORCH_CUDA_ALLOC_CONF'] |
| 709 | + |
705 | 710 | if FLAGS.set_pytorch_max_split_size:
|
706 | 711 | os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
|
707 | 712 |
|
708 |
| - if FLAGS.framework == 'pytorch' and base_workload == 'librispeech_conformer': |
709 |
| - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' |
710 |
| - |
711 | 713 | # Extend path according to framework.
|
712 | 714 | workload_metadata['workload_path'] = os.path.join(
|
713 | 715 | BASE_WORKLOADS_DIR,
|
|
0 commit comments