-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: finetune_hf_peft.py
74 lines (65 loc) · 2.26 KB
/
finetune_hf_peft.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import json
import os

from metaflow import (
    FlowSpec,
    IncludeFile,
    Parameter,
    S3,
    huggingface_card,
    kubernetes,
    pypi_base,
    resources,
    retry,
    secrets,
    step,
)
from metaflow.profilers import gpu_profile

from exceptions import GatedRepoError, GATED_HF_ORGS
@pypi_base(packages={
    'datasets': '',
    'torch': '',
    'transformers': '',
    'peft': '',
    'trl': '',
    'accelerate': '',
    'bitsandbytes': '',
    'sentencepiece': '',
    'safetensors': '',
    'requests': ''
})
class FinetuneLlama3LoRA(FlowSpec):
    """LoRA supervised fine-tuning flow.

    Parses ``ScriptArguments`` from an included JSON file, fails fast when
    the dataset belongs to a gated Hugging Face org and no token is
    available, runs SFT on a single GPU, and uploads the resulting LoRA
    adapter (and, when produced, the merged model) as tarballs to this
    run's S3 prefix.
    """

    # JSON file whose contents are parsed into my_peft_tools.ScriptArguments.
    script_args_file = IncludeFile(
        'script_args',
        help="JSON file containing script arguments",
        default="hf_peft_args.json"
    )

    # When True, create_trainer is asked for a shortened smoke-test run.
    smoke = Parameter(
        'smoke',
        type=bool,
        default=False,
        help="Flag for a smoke test"
    )

    @secrets(sources=["huggingface-token"])
    @step
    def start(self):
        """Parse script arguments and validate access to the dataset.

        Raises:
            GatedRepoError: the dataset's org is in GATED_HF_ORGS and no
                HF_TOKEN is present in the environment.
        """
        from my_peft_tools import ScriptArguments

        args_dict = json.loads(self.script_args_file)
        self.script_args = ScriptArguments(**args_dict)
        # Gated orgs require an authenticated download; fail here rather
        # than partway through the (expensive) GPU step.
        if (
            self.script_args.dataset_name.split("/")[0] in GATED_HF_ORGS
            and "HF_TOKEN" not in os.environ
        ):
            raise GatedRepoError(self.script_args.dataset_name)
        self.next(self.sft)

    @gpu_profile(interval=1)
    @huggingface_card
    @secrets(sources=["huggingface-token"])
    @kubernetes(gpu=1)
    @step
    def sft(self):
        """Run supervised fine-tuning on one GPU and upload artifacts to S3."""
        from my_peft_tools import create_model, create_trainer, save_model, get_tar_bytes
        import huggingface_hub

        # HF_TOKEN is injected by the huggingface-token secret.
        huggingface_hub.login(os.environ['HF_TOKEN'])
        model, tokenizer = create_model(self.script_args)
        trainer = create_trainer(self.script_args, tokenizer, model, smoke=self.smoke, card=True)
        trainer.train()
        output_dirname, merge_output_dirname = save_model(self.script_args, trainer)
        # Store results under this run's S3 prefix so they are tied to the run id.
        with S3(run=self) as s3:
            s3.put('lora_adapter.tar.gz', get_tar_bytes(output_dirname))
            if merge_output_dirname:
                s3.put('lora_merged.tar.gz', get_tar_bytes(merge_output_dirname))
        self.next(self.end)

    @step
    def end(self):
        """Terminal step: nothing to do beyond signalling success."""
        print("Training completed successfully!")
# Entry point: instantiating the FlowSpec hands control to the Metaflow CLI
# (e.g. `python finetune_hf_peft.py run --smoke True`).
if __name__ == '__main__':
    FinetuneLlama3LoRA()