From 5f3ff518602d3d9107f449ac7a185d2700e3b39d Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Thu, 31 Oct 2024 23:57:04 -0600 Subject: [PATCH] add few validations that num checkpoints equals num epoch Signed-off-by: Anh Uong --- tests/test_sft_trainer.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 5aeb26725..82a7883eb 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -91,6 +91,7 @@ def test_resume_training_from_checkpoint(): sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None) _validate_training(tempdir) + _validate_num_checkpoints(tempdir, train_args.num_train_epochs) # Get trainer state of latest checkpoint init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir) @@ -100,6 +101,7 @@ def test_resume_training_from_checkpoint(): train_args.num_train_epochs += 5 sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None) _validate_training(tempdir) + _validate_num_checkpoints(tempdir, train_args.num_train_epochs) # Get trainer state of latest checkpoint final_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir) @@ -415,6 +417,7 @@ def test_run_causallm_pt_and_inference(): # validate peft tuning configs _validate_training(tempdir) + _validate_num_checkpoints(tempdir, train_args.num_train_epochs) checkpoint_path = _get_checkpoint(tempdir) adapter_config = _get_adapter_config(checkpoint_path) @@ -638,6 +641,7 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected): # validate lora tuning configs _validate_training(tempdir) + _validate_num_checkpoints(tempdir, train_args.num_train_epochs) checkpoint_path = _get_checkpoint(tempdir) adapter_config = _get_adapter_config(checkpoint_path) _validate_adapter_config(adapter_config, "LORA") @@ -709,6 +713,7 @@ def test_run_causallm_ft_and_inference(dataset_path): data_args.training_data_path = dataset_path _test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, data_args, tempdir) + _validate_num_checkpoints(tempdir, TRAIN_ARGS.num_train_epochs) _test_run_inference(checkpoint_path=_get_checkpoint(tempdir)) @@ -813,6 +818,12 @@ def _validate_logfile(log_file_path, check_eval=False): if check_eval: assert "validation_loss" in train_log_contents +def _validate_num_checkpoints(dir_path, expected_num): + checkpoints = [ + d for d in os.listdir(dir_path) + if d.startswith("checkpoint") + ] + assert len(checkpoints) == expected_num def _get_adapter_config(dir_path): with open(os.path.join(dir_path, "adapter_config.json"), encoding="utf-8") as f: