
Commit 90d3fc0

ENH: Improve MetaMath training script runtime (#2894)
The training script of the MetaMathQA PEFT method comparison was calling cuda.empty_cache() and gc.collect() after each step. However, this is not really needed and it also slows down training considerably. It turns out that gc.collect() is not needed at all, so it has been removed, which yields a big improvement in runtime. As for empty_cache(), not calling it at all increases memory usage, but it does not need to be called every step; it is now called every 10th step.

Improvement (tested locally, 250 steps):

Removing gc.collect():
- runtime: 108 sec => 65 sec
- memory reserved max stays the same (19.3 GB)
- memory reserved 99th percentile stays the same (18.0 GB)
- memory reserved avg stays the same (12.0 GB)

Also calling empty_cache() only every 10 steps:
- runtime: 65 sec => 50 sec
- memory reserved max stays the same (19.3 GB)
- memory reserved 99th percentile: 18.0 GB => 19.3 GB
- memory reserved avg: 12.0 GB => 14.5 GB

Thus gc.collect() can be safely removed. And while calling empty_cache() only every 10th step does increase average memory usage, the peak is unaffected, which is what matters most in this benchmark, so it is a worthwhile tradeoff for the 23% speed improvement.

Note to maintainers: If this is merged, all MetaMathQA benchmarks should be re-run.
1 parent 3fc83e3 commit 90d3fc0
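
As a rough illustration of the cache-emptying schedule described in the commit message, the sketch below shows the pattern in isolation. It is not the actual run.py code: the training-loop wiring and the run_training_step helper are hypothetical, and the accelerator module is simply assumed to be torch.cuda (the real script imports infer_device from peft.utils to stay device-agnostic).

import torch

# Empty the device cache only every N steps; per the commit message, 10 is a
# good compromise between keeping memory down and lowering runtime overhead.
ACCELERATOR_EMPTY_CACHE_SCHEDULE = 10

# Assumed accelerator module; the real script stays device-agnostic.
torch_accelerator_module = torch.cuda

def train_loop(max_steps, run_training_step):
    for step in range(1, max_steps + 1):
        run_training_step(step)  # forward/backward/optimizer step (hypothetical callback)
        # Previously: empty_cache() plus gc.collect() on every step (slow).
        # Now: no gc.collect() at all, empty_cache() only every 10th step.
        if step % ACCELERATOR_EMPTY_CACHE_SCHEDULE == 0:
            torch_accelerator_module.empty_cache()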

File tree

  • method_comparison/MetaMathQA (1 file changed: +6, -10 lines)

method_comparison/MetaMathQA/run.py

Lines changed: 6 additions & 10 deletions
@@ -18,7 +18,6 @@
 
 import argparse
 import datetime as dt
-import gc
 import json
 import os
 import random
@@ -58,12 +57,10 @@
 from peft.utils import CONFIG_NAME, infer_device
 
 
-# # suppress all warnings
-# warnings.filterwarnings("ignore")  # FIXME?
-
-dtype_to_bytes_linear = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1, "int4": 0.5}
-# if lr scheduler with warmup is used, the ratio of warmup steps to total steps
-BUCKET_FACTOR = 20  # number of batches per bucket, increasing this further has diminishing returns
+# number of batches per bucket, increasing this further has diminishing returns
+BUCKET_FACTOR = 20
+# empty device cache every N steps; 10 is a good compromise between keeping memory down while lowering runtime overhead
+ACCELERATOR_EMPTY_CACHE_SCHEDULE = 10
 
 
 def get_generation_config(*, seq_len, generate_kwargs) -> GenerationConfig:
@@ -298,9 +295,8 @@ def train(
         }
         print_verbose(json.dumps(log_dict))
 
-        # # TODO is this needed?
-        torch_accelerator_module.empty_cache()
-        gc.collect()
+        if step % ACCELERATOR_EMPTY_CACHE_SCHEDULE == 0:
+            torch_accelerator_module.empty_cache()
 
     print_verbose(f"Training finished after {max_steps} steps, evaluation on test set follows.")
     # test set evaluation

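For context, the memory figures quoted in the commit message (max, 99th percentile, and average of reserved memory) can be gathered by sampling the allocator once per step. The snippet below is a minimal sketch of one way to collect such statistics, assuming CUDA; the function names are hypothetical and this is not the benchmark's actual instrumentation.

import statistics

import torch

reserved_samples: list[int] = []

def record_reserved_memory() -> None:
    # torch.cuda.memory_reserved() returns the number of bytes currently
    # reserved by the caching allocator on the current device; call this once
    # per training step.
    reserved_samples.append(torch.cuda.memory_reserved())

def summarize_reserved_memory() -> None:
    gb = 1024**3
    samples = sorted(reserved_samples)
    p99 = samples[int(0.99 * (len(samples) - 1))]
    print(f"memory reserved max: {samples[-1] / gb:.1f} GB")
    print(f"memory reserved 99th percentile: {p99 / gb:.1f} GB")
    print(f"memory reserved avg: {statistics.mean(samples) / gb:.1f} GB")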