import os
import sys

- import wandb
+ import torch
import yaml
+ from accelerate import PartialState
from datasets.arrow_dataset import Dataset
from datasets.load import load_dataset
from peft import LoraConfig
@@ -14,6 +15,8 @@
)
from trl import SFTTrainer

+ import wandb
+

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"

@@ -176,24 +179,35 @@ def load_model_and_tokenizer(model_id):
# tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# NOTE: not needed here if tokenizer.add_special_tokens is used
tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "right"

# Define the quantization configuration for memory-efficient training.
bnb_config = BitsAndBytesConfig(
    # Load the model weights in 4-bit quantized format.
    load_in_4bit=True,
+     # Specify whether to use double quantization for 4-bit quantization.
+     bnb_4bit_use_double_quant=True,
    # Specify the quantization type to use for 4-bit quantization.
    bnb_4bit_quant_type="nf4",
    # Specify the data type to use for computations during training.
-     bnb_4bit_compute_dtype="float16",
-     # Specify whether to use double quantization for 4-bit quantization.
-     bnb_4bit_use_double_quant=True,
+     bnb_4bit_compute_dtype=torch.float16,
)
# Load the model from the specified model ID and apply the quantization configuration.
+
model = AutoModelForCausalLM.from_pretrained(
+     # Base model id
    model_id,
+     # BitsAndBytes configuration
    quantization_config=bnb_config,
+     # Set torch dtype
+     torch_dtype=torch.float16,
+     # Trust remote code
    trust_remote_code=True,
-     device_map="auto",
+     # Set the device map
+     # device_map="auto",
+     device_map={"": PartialState().process_index},
+     # Set the attention implementation
+     attn_implementation="flash_attention_2",
)
# Disable cache to improve training speed.
model.config.use_cache = False
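A note on the device_map change above: mapping "" (the root of the module tree) to PartialState().process_index pins the entire quantized model to the GPU owned by the current process, which is the usual pattern when the script is launched with one process per GPU (for example via accelerate) and each rank trains a full replica, whereas device_map="auto" would spread a single copy across every visible GPU. A minimal sketch of just that placement, using an illustrative model id that is not taken from this commit:

import torch
from accelerate import PartialState
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Same 4-bit NF4 + double-quantization setup as in the hunk above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# "" -> this process's device index, so rank 0 loads onto cuda:0, rank 1 onto cuda:1, and so on.
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # illustrative model id, not from this diff
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map={"": PartialState().process_index},
)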
@@ -222,12 +236,6 @@ def load_model_and_tokenizer(model_id):
os.environ["WANDB_PROJECT"] = "infinite-tinyllama"
os.environ["WANDB_LOG_MODEL"] = "false"
os.environ["WANDB_WATCH"] = "all"
- wandb.init(
-     project="infinite-tinyllama",
-     name=train_config["model_name"],
-     group=train_config["model_name"],
-     config=train_config,
- )

#
# Define LoRA and PEFT config
@@ -249,12 +257,11 @@ def load_model_and_tokenizer(model_id):
    optim="paged_adamw_32bit",
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
-     save_strategy="steps",
-     save_steps=100,
+     save_strategy="epoch",
    logging_steps=10,
    num_train_epochs=int(train_config["train_num_train_epochs"]),
-     max_steps=int(train_config["train_max_steps"]),
    fp16=True,
+     run_name=train_config["model_name"],
)

trainer = SFTTrainer(
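For the logging and checkpointing changes: the explicit wandb.init(...) call is removed, checkpoints are now written once per epoch instead of every 100 steps, the max_steps cap is dropped, and the W&B run name is handed to the Trainer through run_name while the project still comes from the WANDB_PROJECT environment variable. A minimal, hypothetical sketch of that pattern; the output_dir, report_to value, and inline config dict below are assumptions for illustration, not taken from this diff:

import os
from transformers import TrainingArguments

os.environ["WANDB_PROJECT"] = "infinite-tinyllama"  # picked up by the Trainer's W&B callback
os.environ["WANDB_LOG_MODEL"] = "false"
os.environ["WANDB_WATCH"] = "all"

# Stand-in for the YAML-driven train_config the real script loads.
train_config = {"model_name": "tinyllama-example", "train_num_train_epochs": "3"}

training_args = TrainingArguments(
    output_dir=train_config["model_name"],   # assumed; not shown in the hunk
    optim="paged_adamw_32bit",
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    save_strategy="epoch",                   # checkpoint once per epoch instead of every 100 steps
    logging_steps=10,
    num_train_epochs=int(train_config["train_num_train_epochs"]),
    fp16=True,
    run_name=train_config["model_name"],     # becomes the W&B run name via the logging integration
    report_to="wandb",                       # assumed to be configured somewhere in the real script
)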