Skip to content

Commit 58159c5

Browse files
committed
Adding mem_fraction 0.80 for JAX workloads to resolve OOM on certain workloads
1 parent d7eebf8 commit 58159c5

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

docker/Dockerfile

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ RUN cd /algorithmic-efficiency && pip install -e '.[full]'
9292

9393
RUN cd /algorithmic-efficiency && git fetch origin
9494
RUN cd /algorithmic-efficiency && git pull
95-
RUN pip install wandb
9695

9796
# Todo: remove this, this is temporary for developing
9897
COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh

submission_runner.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -693,12 +693,21 @@ def main(_):
693693

694694
# Prevent OOM on librispeech conformer.
695695
base_workload = workloads.get_base_workload_name(FLAGS.workload)
696-
if base_workload == 'librispeech_conformer':
697-
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85'
696+
697+
if base_workload == [
698+
'librispeech_conformer',
699+
'librispeech_deepspeech',
700+
'imagenet_vit',
701+
'criteo1tb'
702+
]:
703+
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80'
698704

699705
if FLAGS.set_pytorch_max_split_size:
700706
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
701707

708+
if FLAGS.framework == 'pytorch' and base_workload == 'librispeech_conformer':
709+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
710+
702711
# Extend path according to framework.
703712
workload_metadata['workload_path'] = os.path.join(
704713
BASE_WORKLOADS_DIR,

0 commit comments

Comments
 (0)