import torch
from datasets import load_dataset
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

"""
A simple example of using SFTTrainer and Accelerate to finetune Phi-3 models. For
a more advanced example, please follow HF alignment-handbook/scripts/run_sft.py

1. Install accelerate:
       conda install -c conda-forge accelerate
2. Set up the accelerate config:
       accelerate config
   or, to simply use all the GPUs available:
       python -c "from accelerate.utils import write_basic_config; write_basic_config(mixed_precision='bf16')"
   Check the accelerate config with:
       accelerate env
3. Run the code:
       accelerate launch sample_finetune.py
"""

###################
# Hyper-parameters
###################
args = {
    "bf16": True,
    "do_eval": False,
    "learning_rate": 5.0e-06,
    "log_level": "info",
    "logging_steps": 20,
    "logging_strategy": "steps",
    "lr_scheduler_type": "cosine",
    "num_train_epochs": 1,
    "max_steps": -1,
    "output_dir": "./output",
    "overwrite_output_dir": True,
    "per_device_eval_batch_size": 4,
    "per_device_train_batch_size": 8,
    "remove_unused_columns": True,
    "save_steps": 100,
    "save_total_limit": 1,
    "seed": 0,
    "gradient_checkpointing": True,
    "gradient_checkpointing_kwargs": {"use_reentrant": False},
    "gradient_accumulation_steps": 1,
    "warmup_ratio": 0.2,
}

training_args = TrainingArguments(**args)


################
# Model Loading
################
checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"
# checkpoint_path = "microsoft/Phi-3-mini-128k-instruct"
model_kwargs = dict(
    use_cache=False,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",  # load the model with flash-attention support
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
tokenizer.pad_token = tokenizer.unk_token  # use unk rather than eos token to prevent endless generation
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
tokenizer.padding_side = 'right'

##################
# Data Processing
##################
def apply_chat_template(
    example,
):
    # Pass-through: the examples already carry the plain "text" field that SFTTrainer
    # reads via dataset_text_field below, so no chat template is applied here.
    return example
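# A hedged sketch (not part of the original script): if the dataset instead stored conversations
# in a hypothetical "messages" column (a list of {"role": ..., "content": ...} dicts) rather than
# plain text, the function could build the "text" field with the tokenizer's chat template:
#
# def apply_chat_template(example):
#     example["text"] = tokenizer.apply_chat_template(
#         example["messages"], tokenize=False, add_generation_prompt=False
#     )
#     return example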

raw_dataset = load_dataset("yuiseki/scp-jp-plain", split="train[:10]")

processed_dataset = raw_dataset.map(
    apply_chat_template,
    num_proc=16,
    desc="Applying chat template",
)

data = processed_dataset.train_test_split(seed=42, test_size=0.2)
train_dataset = data["train"]
eval_dataset = data["test"]


###########
# Training
###########
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    max_seq_length=2048,
    dataset_text_field="text",
    tokenizer=tokenizer,
    packing=True,
)
train_result = trainer.train()
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

#############
# Evaluation
#############
tokenizer.padding_side = 'left'  # switch to left padding, the usual choice for decoder-only models at evaluation/generation time
metrics = trainer.evaluate()
metrics["eval_samples"] = len(eval_dataset)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

############
# Save model
############
trainer.save_model(training_args.output_dir)
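# Optional sanity check, sketched under the assumption that the checkpoint saved to
# `training_args.output_dir` is reloaded with the same tokenizer used above:
#
# finetuned = AutoModelForCausalLM.from_pretrained(
#     training_args.output_dir, torch_dtype=torch.bfloat16, trust_remote_code=True
# )
# prompt = tokenizer.apply_chat_template(
#     [{"role": "user", "content": "Hello!"}], tokenize=False, add_generation_prompt=True
# )
# inputs = tokenizer(prompt, return_tensors="pt").to(finetuned.device)
# print(tokenizer.decode(finetuned.generate(**inputs, max_new_tokens=64)[0]))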