Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions method_comparison/MetaMathQA/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import argparse
import datetime as dt
import gc
import json
import os
import random
Expand Down Expand Up @@ -58,12 +57,10 @@
from peft.utils import CONFIG_NAME, infer_device


# # suppress all warnings
# warnings.filterwarnings("ignore") # FIXME?

dtype_to_bytes_linear = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1, "int4": 0.5}
# if lr scheduler with warmup is used, the ratio of warmup steps to total steps
BUCKET_FACTOR = 20 # number of batches per bucket, increasing this further has diminishing returns
# number of batches per bucket, increasing this further has diminishing returns
BUCKET_FACTOR = 20
# empty device cache every N steps; 10 is a good compromise between keeping memory down while lowering runtime overhead
ACCELERATOR_EMPTY_CACHE_SCHEDULE = 10


def get_generation_config(*, seq_len, generate_kwargs) -> GenerationConfig:
Expand Down Expand Up @@ -298,9 +295,8 @@ def train(
}
print_verbose(json.dumps(log_dict))

# # TODO is this needed?
torch_accelerator_module.empty_cache()
gc.collect()
if step % ACCELERATOR_EMPTY_CACHE_SCHEDULE == 0:
torch_accelerator_module.empty_cache()

print_verbose(f"Training finished after {max_steps} steps, evaluation on test set follows.")
# test set evaluation
Expand Down
Loading