Commit 01b0600

quic-meetkuma authored and quic-dhirajku committed
[QEff. Finetune]: Enabled FT CI tests. (quic#420)
- Enabled CI tests for Finetuning.
- Updated Jenkins file to install torch_qaic, as it is required during FT tests.
- Added finetune as a new pytest flag and updated the other existing tests not to trigger for this flag.

---------

Signed-off-by: meetkuma <[email protected]>
Co-authored-by: Meet Patel <[email protected]>
Signed-off-by: Dhiraj Kumar Sah <[email protected]>
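The diff below applies and filters on the new `finetune` pytest marker but does not show where the marker is registered. As a minimal sketch of that missing piece, assuming the standard `pytest_configure` hook in a `conftest.py` (how QEfficient actually registers its markers is outside this commit):

```python
# conftest.py -- hypothetical sketch, not part of this commit.
# Registers the "finetune" marker so that `-m finetune` and the
# `(not finetune)` expressions added to the Jenkinsfile below resolve
# without PytestUnknownMarkWarning.
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "finetune: marks fine-tuning tests that require torch_qaic"
    )
```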
1 parent e925939 commit 01b0600

File tree: 3 files changed, +119 −24 lines changed


QEfficient/finetune/dataset/samsum_dataset.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -9,7 +9,7 @@


 def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
-    dataset = datasets.load_dataset("Samsung/samsum", split=split, trust_remote_code=True)
+    dataset = datasets.load_dataset("knkarthick/samsum", split=split, trust_remote_code=True)

     prompt = "Summarize this dialog:\n{dialog}\n---\nSummary:\n"

```
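The id swap points the loader at a community mirror of the same SAMSum data. A standalone sketch of the updated call, assuming the usual SAMSum schema (`dialogue`/`summary` fields) and that `knkarthick/samsum` remains available on the Hub:

```python
# Minimal sketch exercising the swapped dataset id from this commit;
# field names below assume the standard SAMSum schema.
import datasets

dataset = datasets.load_dataset("knkarthick/samsum", split="train", trust_remote_code=True)
sample = dataset[0]
print(sample["dialogue"][:80])
print(sample["summary"])
```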

scripts/Jenkinsfile

Lines changed: 25 additions & 7 deletions
```diff
@@ -25,6 +25,7 @@ pipeline {
                pip install junitparser pytest-xdist &&
                pip install librosa==0.10.2 soundfile==0.13.1 && #packages needed to load example for whisper testing
                pip install --extra-index-url https://download.pytorch.org/whl/cpu timm==1.0.14 torchvision==0.19.1+cpu einops==0.8.1 && #packages to load VLMs
+               pip install /opt/qti-aic/integrations/torch_qaic/py310/torch_qaic-0.1.0-cp310-cp310-linux_x86_64.whl && # For finetuning tests
                rm -rf QEfficient"
                '''
            }
@@ -41,7 +42,7 @@ pipeline {
                mkdir -p $PWD/Non_cli_qaic &&
                export TOKENIZERS_PARALLELISM=false &&
                export QEFF_HOME=$PWD/Non_cli_qaic &&
-               pytest tests -m '(not cli) and (not on_qaic)' --ignore tests/vllm -n auto --junitxml=tests/tests_log1.xml &&
+               pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm -n auto --junitxml=tests/tests_log1.xml &&
                junitparser merge tests/tests_log1.xml tests/tests_log.xml &&
                deactivate"
                '''
@@ -58,7 +59,7 @@ pipeline {
                mkdir -p $PWD/Non_qaic &&
                export TOKENIZERS_PARALLELISM=false &&
                export QEFF_HOME=$PWD/Non_qaic &&
-               pytest tests -m '(not cli) and (on_qaic) and (not multimodal) and (not qnn)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log2.xml &&
+               pytest tests -m '(not cli) and (on_qaic) and (not multimodal) and (not qnn) and (not finetune)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log2.xml &&
                junitparser merge tests/tests_log2.xml tests/tests_log.xml &&
                deactivate"
                '''
@@ -77,14 +78,14 @@ pipeline {
                mkdir -p $PWD/Non_cli_qaic_multimodal &&
                export TOKENIZERS_PARALLELISM=false &&
                export QEFF_HOME=$PWD/Non_cli_qaic_multimodal &&
-               pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log6.xml &&
+               pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn) and (not finetune)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log6.xml &&
                junitparser merge tests/tests_log6.xml tests/tests_log.xml &&
                deactivate"
                '''
            }
        }
    }
-   stage('CLI Tests') {
+   stage('Inference Tests') {
        steps {
            timeout(time: 60, unit: 'MINUTES') {
                sh '''
@@ -96,7 +97,7 @@ pipeline {
                mkdir -p $PWD/cli &&
                export TOKENIZERS_PARALLELISM=false &&
                export QEFF_HOME=$PWD/cli &&
-               pytest tests -m '(cli and not qnn)' --ignore tests/vllm --junitxml=tests/tests_log3.xml &&
+               pytest tests -m '(cli and not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log3.xml &&
                junitparser merge tests/tests_log3.xml tests/tests_log.xml &&
                deactivate"
                '''
@@ -125,7 +126,7 @@ pipeline {
                mkdir -p $PWD/Qnn_cli &&
                export TOKENIZERS_PARALLELISM=false &&
                export QEFF_HOME=$PWD/Qnn_cli &&
-               pytest tests -m '(cli and qnn)' --ignore tests/vllm --junitxml=tests/tests_log4.xml &&
+               pytest tests -m '(cli and qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log4.xml &&
                junitparser merge tests/tests_log4.xml tests/tests_log.xml &&
                deactivate"
                '''
@@ -144,7 +145,7 @@ pipeline {
                mkdir -p $PWD/Qnn_non_cli &&
                export TOKENIZERS_PARALLELISM=false &&
                export QEFF_HOME=$PWD/Qnn_non_cli &&
-               pytest tests -m '(not cli) and (qnn) and (on_qaic) and (not multimodal)' --ignore tests/vllm --junitxml=tests/tests_log5.xml &&
+               pytest tests -m '(not cli) and (qnn) and (on_qaic) and (not multimodal) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log5.xml &&
                junitparser merge tests/tests_log5.xml tests/tests_log.xml &&
                deactivate"
                '''
@@ -170,6 +171,23 @@ pipeline {
            }
        }
    }
+   stage('Finetune CLI Tests') {
+       steps {
+           timeout(time: 5, unit: 'MINUTES') {
+               sh '''
+               sudo docker exec ${BUILD_TAG} bash -c "
+               cd /efficient-transformers &&
+               . preflight_qeff/bin/activate &&
+               mkdir -p $PWD/cli_qaic_finetuning &&
+               export TOKENIZERS_PARALLELISM=false &&
+               export QEFF_HOME=$PWD/cli_qaic_finetuning &&
+               pytest tests -m '(cli) and (on_qaic) and (not qnn) and (not multimodal) and (finetune)' --ignore tests/vllm --junitxml=tests/tests_log_finetune.xml &&
+               junitparser merge tests/tests_log_finetune.xml tests/tests_log.xml &&
+               deactivate"
+               '''
+           }
+       }
+   }
    }

    post {
```
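Taken together, the `-m` expressions now partition the suite: every pre-existing stage excludes `finetune`, and only the new stage selects it. A hypothetical test showing where a fully marked FT test would land (marker names are from this diff; the test body is illustrative only):

```python
# Hypothetical test: with these markers, the new stage's expression
#   '(cli) and (on_qaic) and (not qnn) and (not multimodal) and (finetune)'
# selects it, while every other stage's '(not finetune)' clause skips it.
import pytest


@pytest.mark.cli
@pytest.mark.on_qaic
@pytest.mark.finetune
def test_runs_only_in_finetune_stage():
    assert True
```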

tests/finetune/test_finetune.py

Lines changed: 93 additions & 16 deletions
```diff
@@ -7,94 +7,164 @@

 import os
 import shutil
+from pathlib import Path

 import numpy as np
 import pytest
+import requests
 import torch.optim as optim
 from torch.utils.data import DataLoader

 import QEfficient
 import QEfficient.cloud.finetune
 from QEfficient.cloud.finetune import main as finetune

+alpaca_json_path = Path.cwd() / "alpaca_data.json"
+

 def clean_up(path):
-    if os.path.exists(path):
+    if os.path.isdir(path) and os.path.exists(path):
         shutil.rmtree(path)
+    if os.path.isfile(path):
+        os.remove(path)
+
+
+def download_alpaca():
+    alpaca_url = "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/refs/heads/main/alpaca_data.json"
+    response = requests.get(alpaca_url)
+
+    with open(alpaca_json_path, "wb") as f:
+        f.write(response.content)


 configs = [
     pytest.param(
         "meta-llama/Llama-3.2-1B",  # model_name
+        "generation",  # task_type
         10,  # max_eval_step
         20,  # max_train_step
+        "gsm8k_dataset",  # dataset_name
+        None,  # data_path
         1,  # intermediate_step_save
         None,  # context_length
         True,  # run_validation
         True,  # use_peft
         "qaic",  # device
-        id="llama_config",  # config name
-    )
+        0.0043353,  # expected_train_loss
+        1.0043447,  # expected_train_metric
+        0.0117334,  # expected_eval_loss
+        1.0118025,  # expected_eval_metric
+        id="llama_config_gsm8k",  # config name
+    ),
+    pytest.param(
+        "meta-llama/Llama-3.2-1B",  # model_name
+        "generation",  # task_type
+        10,  # max_eval_step
+        20,  # max_train_step
+        "alpaca_dataset",  # dataset_name
+        alpaca_json_path,  # data_path
+        1,  # intermediate_step_save
+        None,  # context_length
+        True,  # run_validation
+        True,  # use_peft
+        "qaic",  # device
+        0.0006099,  # expected_train_loss
+        1.0006101,  # expected_train_metric
+        0.0065296,  # expected_eval_loss
+        1.0065510,  # expected_eval_metric
+        id="llama_config_alpaca",  # config name
+    ),
+    pytest.param(
+        "google-bert/bert-base-uncased",  # model_name
+        "seq_classification",  # task_type
+        10,  # max_eval_step
+        20,  # max_train_step
+        "imdb_dataset",  # dataset_name
+        None,  # data_path
+        1,  # intermediate_step_save
+        None,  # context_length
+        True,  # run_validation
+        False,  # use_peft
+        "qaic",  # device
+        0.00052981,  # expected_train_loss
+        0.55554199,  # expected_train_metric
+        0.00738618,  # expected_eval_loss
+        0.70825195,  # expected_eval_metric
+        id="bert_config_imdb",  # config name
+    ),
 ]


-@pytest.mark.skip(reason="Currently CI is broken. Once it is fixed we will enable this test.")
 @pytest.mark.cli
 @pytest.mark.on_qaic
 @pytest.mark.finetune
 @pytest.mark.parametrize(
-    "model_name,max_eval_step,max_train_step,intermediate_step_save,context_length,run_validation,use_peft,device",
+    "model_name,task_type,max_eval_step,max_train_step,dataset_name,data_path,intermediate_step_save,context_length,run_validation,use_peft,device,expected_train_loss,expected_train_metric,expected_eval_loss,expected_eval_metric",
     configs,
 )
-def test_finetune(
+def test_finetune_llama(
     model_name,
+    task_type,
     max_eval_step,
     max_train_step,
+    dataset_name,
+    data_path,
     intermediate_step_save,
     context_length,
     run_validation,
     use_peft,
     device,
+    expected_train_loss,
+    expected_train_metric,
+    expected_eval_loss,
+    expected_eval_metric,
     mocker,
 ):
     train_config_spy = mocker.spy(QEfficient.cloud.finetune, "TrainConfig")
     generate_dataset_config_spy = mocker.spy(QEfficient.cloud.finetune, "generate_dataset_config")
     generate_peft_config_spy = mocker.spy(QEfficient.cloud.finetune, "generate_peft_config")
-    get_dataloader_kwargs_spy = mocker.spy(QEfficient.cloud.finetune, "get_dataloader_kwargs")
+    get_dataloader_kwargs_spy = mocker.spy(QEfficient.finetune.utils.dataset_utils, "get_dataloader_kwargs")
     update_config_spy = mocker.spy(QEfficient.cloud.finetune, "update_config")
-    get_custom_data_collator_spy = mocker.spy(QEfficient.cloud.finetune, "get_custom_data_collator")
-    get_preprocessed_dataset_spy = mocker.spy(QEfficient.cloud.finetune, "get_preprocessed_dataset")
+    get_custom_data_collator_spy = mocker.spy(QEfficient.finetune.utils.dataset_utils, "get_custom_data_collator")
+    get_preprocessed_dataset_spy = mocker.spy(QEfficient.finetune.utils.dataset_utils, "get_preprocessed_dataset")
     get_longest_seq_length_spy = mocker.spy(QEfficient.cloud.finetune, "get_longest_seq_length")
     print_model_size_spy = mocker.spy(QEfficient.cloud.finetune, "print_model_size")
     train_spy = mocker.spy(QEfficient.cloud.finetune, "train")

     kwargs = {
         "model_name": model_name,
+        "task_type": task_type,
         "max_eval_step": max_eval_step,
         "max_train_step": max_train_step,
+        "dataset": dataset_name,
+        "data_path": data_path,
         "intermediate_step_save": intermediate_step_save,
         "context_length": context_length,
         "run_validation": run_validation,
         "use_peft": use_peft,
         "device": device,
     }

+    if dataset_name == "alpaca_dataset":
+        download_alpaca()
+
     results = finetune(**kwargs)
-    assert np.allclose(results["avg_train_loss"], 0.00232327, atol=1e-5), "Train loss is not matching."
-    assert np.allclose(results["avg_train_metric"], 1.002326, atol=1e-5), "Train metric is not matching."
-    assert np.allclose(results["avg_eval_loss"], 0.0206124, atol=1e-5), "Eval loss is not matching."
-    assert np.allclose(results["avg_eval_metric"], 1.020826, atol=1e-5), "Eval metric is not matching."
+    assert np.allclose(results["avg_train_loss"], expected_train_loss, atol=1e-3), "Train loss is not matching."
+    assert np.allclose(results["avg_train_metric"], expected_train_metric, atol=1e-3), "Train metric is not matching."
+    assert np.allclose(results["avg_eval_loss"], expected_eval_loss, atol=1e-3), "Eval loss is not matching."
+    assert np.allclose(results["avg_eval_metric"], expected_eval_metric, atol=1e-3), "Eval metric is not matching."
     assert results["avg_epoch_time"] < 60, "Training should complete within 60 seconds."

     train_config_spy.assert_called_once()
     generate_dataset_config_spy.assert_called_once()
-    generate_peft_config_spy.assert_called_once()
-    get_custom_data_collator_spy.assert_called_once()
+    if task_type == "generation":
+        generate_peft_config_spy.assert_called_once()
     get_longest_seq_length_spy.assert_called_once()
     print_model_size_spy.assert_called_once()
     train_spy.assert_called_once()

     assert update_config_spy.call_count == 2
+    assert get_custom_data_collator_spy.call_count == 2
     assert get_dataloader_kwargs_spy.call_count == 2
     assert get_preprocessed_dataset_spy.call_count == 2

@@ -123,12 +193,19 @@ def test_finetune(
         f"{train_config.gradient_accumulation_steps} which is gradient accumulation steps."
     )

-    saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/adapter_model.safetensors")
+    if use_peft:
+        saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/adapter_model.safetensors")
+    else:
+        saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/model.safetensors")
     assert os.path.isfile(saved_file)

     clean_up(train_config.output_dir)
     clean_up("runs")
+    clean_up("qaic-dumps")
     clean_up(train_config.dump_root_dir)

+    if dataset_name == "alpaca_dataset":
+        clean_up(alpaca_json_path)
+

 # TODO (Meet): Add seperate tests for BERT FT and LLama FT
```
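The committed `download_alpaca` helper writes whatever bytes the request returns. For exercising the fetch on its own, a hedged standalone sketch; the timeout and `raise_for_status()` call are defensive additions for illustration, not part of this commit:

```python
# Standalone sketch of fetching the Alpaca JSON consumed by the
# alpaca_dataset config; timeout and raise_for_status() are additions.
from pathlib import Path

import requests

ALPACA_URL = "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/refs/heads/main/alpaca_data.json"
alpaca_json_path = Path.cwd() / "alpaca_data.json"

response = requests.get(ALPACA_URL, timeout=60)
response.raise_for_status()  # fail loudly on HTTP errors instead of saving an error page
alpaca_json_path.write_bytes(response.content)
```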
