7 | 7 |
8 | 8 | import os
9 | 9 | import shutil
| 10 | +from pathlib import Path
10 | 11 |
11 | 12 | import numpy as np
12 | 13 | import pytest
| 14 | +import requests
13 | 15 | import torch.optim as optim
14 | 16 | from torch.utils.data import DataLoader
15 | 17 |
16 | 18 | import QEfficient
17 | 19 | import QEfficient.cloud.finetune
18 | 20 | from QEfficient.cloud.finetune import main as finetune
19 | 21 |
| 22 | +alpaca_json_path = Path.cwd() / "alpaca_data.json"
| 23 | +
20 | 24 |
21 | 25 | def clean_up(path):
22 | | -    if os.path.exists(path):
| 26 | +    if os.path.isdir(path):
23 | 27 |         shutil.rmtree(path)
| 28 | +    if os.path.isfile(path):
| 29 | +        os.remove(path)
| 30 | +
| 31 | +
| 32 | +def download_alpaca():
| 33 | +    alpaca_url = "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/refs/heads/main/alpaca_data.json"
| 34 | +    response = requests.get(alpaca_url)
| 35 | +
| 36 | +    with open(alpaca_json_path, "wb") as f:
| 37 | +        f.write(response.content)
24 | 38 |
25 | 39 |
26 | 40 | configs = [
27 | 41 |     pytest.param(
28 | 42 |         "meta-llama/Llama-3.2-1B",  # model_name
| 43 | +        "generation",  # task_type
29 | 44 |         10,  # max_eval_step
30 | 45 |         20,  # max_train_step
| 46 | +        "gsm8k_dataset",  # dataset_name
| 47 | +        None,  # data_path
31 | 48 |         1,  # intermediate_step_save
32 | 49 |         None,  # context_length
33 | 50 |         True,  # run_validation
34 | 51 |         True,  # use_peft
35 | 52 |         "qaic",  # device
36 | | -        id="llama_config",  # config name
37 | | -    )
| 53 | +        0.0043353,  # expected_train_loss
| 54 | +        1.0043447,  # expected_train_metric
| 55 | +        0.0117334,  # expected_eval_loss
| 56 | +        1.0118025,  # expected_eval_metric
| 57 | +        id="llama_config_gsm8k",  # config name
| 58 | +    ),
| 59 | +    pytest.param(
| 60 | +        "meta-llama/Llama-3.2-1B",  # model_name
| 61 | +        "generation",  # task_type
| 62 | +        10,  # max_eval_step
| 63 | +        20,  # max_train_step
| 64 | +        "alpaca_dataset",  # dataset_name
| 65 | +        alpaca_json_path,  # data_path
| 66 | +        1,  # intermediate_step_save
| 67 | +        None,  # context_length
| 68 | +        True,  # run_validation
| 69 | +        True,  # use_peft
| 70 | +        "qaic",  # device
| 71 | +        0.0006099,  # expected_train_loss
| 72 | +        1.0006101,  # expected_train_metric
| 73 | +        0.0065296,  # expected_eval_loss
| 74 | +        1.0065510,  # expected_eval_metric
| 75 | +        id="llama_config_alpaca",  # config name
| 76 | +    ),
| 77 | +    pytest.param(
| 78 | +        "google-bert/bert-base-uncased",  # model_name
| 79 | +        "seq_classification",  # task_type
| 80 | +        10,  # max_eval_step
| 81 | +        20,  # max_train_step
| 82 | +        "imdb_dataset",  # dataset_name
| 83 | +        None,  # data_path
| 84 | +        1,  # intermediate_step_save
| 85 | +        None,  # context_length
| 86 | +        True,  # run_validation
| 87 | +        False,  # use_peft
| 88 | +        "qaic",  # device
| 89 | +        0.00052981,  # expected_train_loss
| 90 | +        0.55554199,  # expected_train_metric
| 91 | +        0.00738618,  # expected_eval_loss
| 92 | +        0.70825195,  # expected_eval_metric
| 93 | +        id="bert_config_imdb",  # config name
| 94 | +    ),
38 | 95 | ]
39 | 96 |
40 | 97 |
41 | | -@pytest.mark.skip(reason="Currently CI is broken. Once it is fixed we will enable this test.")
42 | 98 | @pytest.mark.cli
43 | 99 | @pytest.mark.on_qaic
44 | 100 | @pytest.mark.finetune
45 | 101 | @pytest.mark.parametrize(
46 | | -    "model_name,max_eval_step,max_train_step,intermediate_step_save,context_length,run_validation,use_peft,device",
| 102 | +    "model_name,task_type,max_eval_step,max_train_step,dataset_name,data_path,intermediate_step_save,context_length,run_validation,use_peft,device,expected_train_loss,expected_train_metric,expected_eval_loss,expected_eval_metric",
47 | 103 |     configs,
48 | 104 | )
49 | | -def test_finetune(
| 105 | +def test_finetune_llama(
50 | 106 |     model_name,
| 107 | +    task_type,
51 | 108 |     max_eval_step,
52 | 109 |     max_train_step,
| 110 | +    dataset_name,
| 111 | +    data_path,
53 | 112 |     intermediate_step_save,
54 | 113 |     context_length,
55 | 114 |     run_validation,
56 | 115 |     use_peft,
57 | 116 |     device,
| 117 | +    expected_train_loss,
| 118 | +    expected_train_metric,
| 119 | +    expected_eval_loss,
| 120 | +    expected_eval_metric,
58 | 121 |     mocker,
59 | 122 | ):
60 | 123 |     train_config_spy = mocker.spy(QEfficient.cloud.finetune, "TrainConfig")
61 | 124 |     generate_dataset_config_spy = mocker.spy(QEfficient.cloud.finetune, "generate_dataset_config")
62 | 125 |     generate_peft_config_spy = mocker.spy(QEfficient.cloud.finetune, "generate_peft_config")
63 | | -    get_dataloader_kwargs_spy = mocker.spy(QEfficient.cloud.finetune, "get_dataloader_kwargs")
| 126 | +    get_dataloader_kwargs_spy = mocker.spy(QEfficient.finetune.utils.dataset_utils, "get_dataloader_kwargs")
64 | 127 |     update_config_spy = mocker.spy(QEfficient.cloud.finetune, "update_config")
65 | | -    get_custom_data_collator_spy = mocker.spy(QEfficient.cloud.finetune, "get_custom_data_collator")
66 | | -    get_preprocessed_dataset_spy = mocker.spy(QEfficient.cloud.finetune, "get_preprocessed_dataset")
| 128 | +    get_custom_data_collator_spy = mocker.spy(QEfficient.finetune.utils.dataset_utils, "get_custom_data_collator")
| 129 | +    get_preprocessed_dataset_spy = mocker.spy(QEfficient.finetune.utils.dataset_utils, "get_preprocessed_dataset")
67 | 130 |     get_longest_seq_length_spy = mocker.spy(QEfficient.cloud.finetune, "get_longest_seq_length")
68 | 131 |     print_model_size_spy = mocker.spy(QEfficient.cloud.finetune, "print_model_size")
69 | 132 |     train_spy = mocker.spy(QEfficient.cloud.finetune, "train")
|
71 | 134 | kwargs = {
|
72 | 135 | "model_name": model_name,
|
| 136 | + "task_type": task_type, |
73 | 137 | "max_eval_step": max_eval_step,
|
74 | 138 | "max_train_step": max_train_step,
|
| 139 | + "dataset": dataset_name, |
| 140 | + "data_path": data_path, |
75 | 141 | "intermediate_step_save": intermediate_step_save,
|
76 | 142 | "context_length": context_length,
|
77 | 143 | "run_validation": run_validation,
|
78 | 144 | "use_peft": use_peft,
|
79 | 145 | "device": device,
|
80 | 146 | }
|
81 | 147 |
|
| 148 | + if dataset_name == "alpaca_dataset": |
| 149 | + download_alpaca() |
| 150 | + |
82 | 151 | results = finetune(**kwargs)
|
83 | | -    assert np.allclose(results["avg_train_loss"], 0.00232327, atol=1e-5), "Train loss is not matching."
84 | | -    assert np.allclose(results["avg_train_metric"], 1.002326, atol=1e-5), "Train metric is not matching."
85 | | -    assert np.allclose(results["avg_eval_loss"], 0.0206124, atol=1e-5), "Eval loss is not matching."
86 | | -    assert np.allclose(results["avg_eval_metric"], 1.020826, atol=1e-5), "Eval metric is not matching."
| 152 | +    assert np.allclose(results["avg_train_loss"], expected_train_loss, atol=1e-3), "Train loss does not match."
| 153 | +    assert np.allclose(results["avg_train_metric"], expected_train_metric, atol=1e-3), "Train metric does not match."
| 154 | +    assert np.allclose(results["avg_eval_loss"], expected_eval_loss, atol=1e-3), "Eval loss does not match."
| 155 | +    assert np.allclose(results["avg_eval_metric"], expected_eval_metric, atol=1e-3), "Eval metric does not match."
87 | 156 |     assert results["avg_epoch_time"] < 60, "Training should complete within 60 seconds."
88 | 157 |
89 | 158 |     train_config_spy.assert_called_once()
90 | 159 |     generate_dataset_config_spy.assert_called_once()
91 | | -    generate_peft_config_spy.assert_called_once()
92 | | -    get_custom_data_collator_spy.assert_called_once()
| 160 | +    if task_type == "generation":
| 161 | +        generate_peft_config_spy.assert_called_once()
93 | 162 |     get_longest_seq_length_spy.assert_called_once()
94 | 163 |     print_model_size_spy.assert_called_once()
95 | 164 |     train_spy.assert_called_once()
96 | 165 |
97 | 166 |     assert update_config_spy.call_count == 2
| 167 | +    assert get_custom_data_collator_spy.call_count == 2
98 | 168 |     assert get_dataloader_kwargs_spy.call_count == 2
99 | 169 |     assert get_preprocessed_dataset_spy.call_count == 2
100 | 170 |
@@ -123,12 +193,19 @@ def test_finetune(
123 | 193 |         f"{train_config.gradient_accumulation_steps} which is gradient accumulation steps."
124 | 194 |     )
125 | 195 |
126 | | -    saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/adapter_model.safetensors")
| 196 | +    if use_peft:
| 197 | +        saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/adapter_model.safetensors")
| 198 | +    else:
| 199 | +        saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/model.safetensors")
127 | 200 |     assert os.path.isfile(saved_file)
128 | 201 |
129 | 202 |     clean_up(train_config.output_dir)
130 | 203 |     clean_up("runs")
| 204 | +    clean_up("qaic-dumps")
131 | 205 |     clean_up(train_config.dump_root_dir)
132 | 206 |
| 207 | +    if dataset_name == "alpaca_dataset":
| 208 | +        clean_up(alpaca_json_path)
| 209 | +
133 | 210 |
134 | 211 | # TODO (Meet): Add separate tests for BERT FT and Llama FT
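
A hardening note on the new download helper: requests.get is called without a timeout or a status check, so a network hiccup can hang the suite, and a 4xx/5xx response body would be written to alpaca_data.json as-is. A minimal, more defensive sketch (the cache check and the 60-second timeout are assumptions, not part of this change):

    # Hypothetical hardened variant of download_alpaca; not part of this diff.
    def download_alpaca():
        alpaca_url = "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/refs/heads/main/alpaca_data.json"
        if alpaca_json_path.exists():
            return  # reuse a copy left by a previous run (assumed acceptable for CI)
        response = requests.get(alpaca_url, timeout=60)  # 60 s budget is an assumption
        response.raise_for_status()  # fail fast on HTTP errors instead of saving an error page
        alpaca_json_path.write_bytes(response.content)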
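
For reviewers who want to exercise a single parametrization locally, the markers and ids declared above combine with pytest's standard selection flags; a sketch, assuming a qaic device is available:

    # Run only the finetune-marked tests, narrowed to one pytest.param id from this file.
    import pytest
    pytest.main(["-m", "finetune", "-k", "llama_config_gsm8k"])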