from tqdm import tqdm

from QEfficient.finetune.configs.training import TrainConfig
- from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx, is_rank_zero, get_num_ddp_devices
+ from QEfficient.finetune.utils.helper import get_autocast_ctx, get_num_ddp_devices, get_op_verifier_ctx, is_rank_zero
from QEfficient.finetune.utils.logging_utils import logger

try:
@@ -63,8 +63,8 @@ def train(

train_metric = []
train_loss = []
- val_metric = []
- val_loss = []
+ eval_metric = []
+ eval_loss = []

if train_config.save_metrics:
if not os.path.exists(train_config.output_dir):
@@ -74,13 +74,13 @@ def train(
)
train_step_metric = []
train_step_loss = []
- val_step_loss = []
- val_step_metric = []
+ eval_step_loss = []
+ eval_step_metric = []

epoch_times = []
checkpoint_times = []
results = {}
- best_val_loss = float("inf")
+ best_eval_loss = float("inf")
total_train_steps = 0
max_steps_reached = False  # Flag to indicate max training steps reached

@@ -130,7 +130,6 @@ def train(
continue

logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
- logger.log_rank_zero(f"train_config.max_train_step: {train_config.max_train_step}")
# stop when the maximum number of training steps is reached
if max_steps_reached:
break
@@ -207,23 +206,21 @@ def train(
total_loss += loss.detach().float()

if is_rank_zero():
+ tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
if loss <= train_config.convergence_loss:
loss_0_counter += 1
else:
loss_0_counter = torch.tensor([0]).to(device)
if train_config.enable_ddp:
dist.broadcast(loss_0_counter, src=0)

- if is_rank_zero():
- tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
-
if train_config.save_metrics:
train_step_loss.append(loss.detach().float().item())
if train_config.task_type == "seq_classification":
- step_metric_val = float(acc_helper.compute())
+ step_metric_value = float(acc_helper.compute())
else:
- step_metric_val = float(torch.exp(loss.detach().float()))
- train_step_metric.append(step_metric_val)
+ step_metric_value = float(torch.exp(loss.detach().float()))
+ train_step_metric.append(step_metric_value)

# Accumalate gradients
complete_accum_steps = (
@@ -271,10 +268,10 @@ def train(
train_loss,
train_step_metric,
train_metric,
- val_step_loss,
- val_loss,
- val_step_metric,
- val_metric,
+ eval_step_loss,
+ eval_loss,
+ eval_step_metric,
+ eval_metric,
)
if loss_0_counter.item() == train_config.convergence_counter:
logger.log_rank_zero(
@@ -286,32 +283,19 @@ def train(
epoch_end_time = time.perf_counter() - epoch_start_time
epoch_times.append(epoch_end_time)

- if loss_0_counter.item() == train_config.convergence_counter:
- if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
- train_epoch_loss = (
- 0.0
- if total_loss == 0.0
- else total_loss / (step - intermediate_step - num_dummy_samples / train_config.train_batch_size)
- )
- else:
- train_epoch_loss = (
- 0.0
- if total_loss == 0.0
- else total_loss / (step + 1 - num_dummy_samples / train_config.train_batch_size)
- )
+ if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
+ train_epoch_loss = (
+ 0.0
+ if total_loss == 0.0
+ else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size))
+ )
else:
- if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
- train_epoch_loss = (
- 0.0
- if total_loss == 0.0
- else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size))
- )
- else:
- train_epoch_loss = (
- 0.0
- if total_loss == 0.0
- else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
- )
+ train_epoch_loss = (
+ 0.0
+ if total_loss == 0.0
+ else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
+ )
+

if train_config.task_type == "seq_classification":
train_epoch_metric = acc_helper.compute()
acc_helper.reset()
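The change above hoists the epoch-loss computation out of the convergence-counter branch: in both paths the accumulated loss is divided by the number of effective steps, with padded "dummy" samples discounted. A minimal standalone sketch of that correction (function and argument names are hypothetical, not from the repo):

# Hypothetical sketch of the dummy-sample-corrected epoch loss used above.
# total_loss is the summed per-step loss, steps the number of batches processed,
# and num_dummy_samples counts padded samples that contributed no real loss.
def average_epoch_loss(total_loss, steps, num_dummy_samples, batch_size):
    if total_loss == 0.0:
        return 0.0
    effective_steps = steps - (num_dummy_samples / batch_size)
    return total_loss / effective_steps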
@@ -331,30 +315,30 @@ def train(
lr_scheduler.step()

if train_config.run_validation:
- eval_loss, eval_metric, step_loss, step_metric = evaluation_helper(
+ eval_epoch_loss, eval_epoch_metric, step_loss, step_metric = evaluation_helper(
model, train_config, eval_dataloader, device
)

- if eval_loss < best_val_loss:
- best_val_loss = eval_loss
- logger.log_rank_zero(f"Best eval loss on epoch {epoch + 1} is {best_val_loss:.4f}")
+ if eval_epoch_loss < best_eval_loss:
+ best_eval_loss = eval_epoch_loss
+ logger.log_rank_zero(f"Best eval loss on epoch {epoch + 1} is {best_eval_loss:.4f}")

if is_rank_zero():
- tensorboard_updates.add_scalars("loss", {"eval": eval_loss}, total_train_steps)
+ tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
if train_config.save_metrics:
- val_step_loss.extend(step_loss)
- val_step_metric.extend(step_metric)
- val_loss.append(float(eval_loss))
- val_metric.append(float(eval_metric))
+ eval_step_loss.extend(step_loss)
+ eval_step_metric.extend(step_metric)
+ eval_loss.append(float(eval_epoch_loss))
+ eval_metric.append(float(eval_epoch_metric))

if train_config.enable_ddp:
- dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
- eval_loss /= get_num_ddp_devices()
- dist.all_reduce(eval_metric, op=dist.ReduceOp.SUM)
- eval_metric /= get_num_ddp_devices()
+ dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM)
+ eval_epoch_loss /= get_num_ddp_devices()
+ dist.all_reduce(eval_epoch_metric, op=dist.ReduceOp.SUM)
+ eval_epoch_metric /= get_num_ddp_devices()

logger.log_rank_zero(
- f"Epoch {epoch + 1}: Eval Loss: {eval_loss.detach().cpu():.4f}, Eval metric: {eval_metric.detach().cpu():.4f}"
+ f"Epoch {epoch + 1}: Eval Loss: {eval_epoch_loss.detach().cpu():.4f}, Eval metric: {eval_epoch_metric.detach().cpu():.4f}"
)

# saving the adapters after completion of each epoch
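The DDP branch above sums the per-rank eval tensors and divides by the device count to obtain a global average. A minimal sketch of that pattern with torch.distributed, assuming the process group is already initialized and using get_world_size() in place of the repo helper get_num_ddp_devices():

import torch
import torch.distributed as dist

def all_reduce_mean(value: torch.Tensor) -> torch.Tensor:
    # Sum the tensor across every rank, then divide by the rank count to get the mean.
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    value /= dist.get_world_size()
    return value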
@@ -377,19 +361,19 @@ def train(
train_loss,
train_step_metric,
train_metric,
- val_step_loss,
- val_loss,
- val_step_metric,
- val_metric,
+ eval_step_loss,
+ eval_loss,
+ eval_step_metric,
+ eval_metric,
)
avg_epoch_time = sum(epoch_times) / len(epoch_times)
avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0

- results["last_epoch_train_loss"] = train_epoch_loss
- results["last_epoch_train_metric"] = train_epoch_metric
+ results["last_epoch_train_loss"] = train_epoch_loss.cpu()
+ results["last_epoch_train_metric"] = train_epoch_metric.cpu()
if train_config.run_validation:
- results["last_epoch_eval_loss"] = eval_loss
- results["last_epoch_eval_metric"] = eval_metric
+ results["last_epoch_eval_loss"] = eval_epoch_loss.cpu()
+ results["last_epoch_eval_metric"] = eval_epoch_metric.cpu()
results["avg_epoch_time"] = avg_epoch_time
results["avg_checkpoint_time"] = avg_checkpoint_time
if train_config.save_metrics:
@@ -405,7 +389,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
model: The model to evaluate
eval_dataloader: The dataloader containing the evaluation data

- Returns: eval_epoch_loss, eval_metric, eval_step_loss, eval_step_metric
+ Returns: eval_epoch_loss, eval_epoch_metric, eval_step_loss, eval_step_metric
"""
if train_config.enable_ddp:
dist.barrier()
@@ -422,8 +406,8 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
# special handling for qaic device and dtype
# model.to(device)

- val_step_loss = []
- val_step_metric = []
+ eval_step_loss = []
+ eval_step_metric = []

eval_loss = torch.tensor(0.0, dtype=torch.float32, device=device)  # Initialize evaluation loss
device_type = torch.device(device).type
@@ -459,24 +443,27 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
logits = outputs.logits
labels = batch["labels"][:, 0]
preds = torch.nn.functional.softmax(logits, dim=-1)
- val_acc = acc_helper.forward(preds, labels)
- metric_val = val_acc.detach().float().item()
+ eval_acc = acc_helper.forward(preds, labels)
+ metric_value = eval_acc.detach().float().item()
else:
- metric_val = float(torch.exp(loss.detach().float()))
+ metric_value = float(torch.exp(loss.detach().float()))

if train_config.save_metrics:
- val_step_loss.append(loss.detach().float().item())
- val_step_metric.append(metric_val)
+ eval_step_loss.append(loss.detach().float().item())
+ eval_step_metric.append(metric_value)

eval_loss += loss.detach().float()
+
# Compute average loss and metric
- eval_loss = 0.0 if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size)
+ eval_epoch_loss = (
+ 0.0 if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size)
+ )
if train_config.task_type == "seq_classification":
- eval_metric = acc_helper.compute()
+ eval_epoch_metric = acc_helper.compute()
else:
- eval_metric = torch.exp(eval_loss)
+ eval_epoch_metric = torch.exp(eval_epoch_loss)

- return eval_loss, eval_metric, val_step_loss, val_step_metric
+ return eval_epoch_loss, eval_epoch_metric, eval_step_loss, eval_step_metric


def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
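In the non-classification branch above, the eval metric is perplexity, i.e. the exponential of the mean cross-entropy loss. An illustrative one-liner (not part of the diff):

import torch

def perplexity(mean_loss: torch.Tensor) -> torch.Tensor:
    # Perplexity is exp(mean cross-entropy loss); a mean loss of 2.0 gives ~7.39.
    return torch.exp(mean_loss)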
@@ -517,20 +504,20 @@ def save_to_json(
train_epoch_loss,
train_step_metric,
train_epoch_metric,
- val_step_loss,
- val_epoch_loss,
- val_step_metric,
- val_epoch_metric,
+ eval_step_loss,
+ eval_epoch_loss,
+ eval_step_metric,
+ eval_epoch_metric,
):
metrics_data = {
"train_step_loss": train_step_loss,
"train_epoch_loss": train_epoch_loss,
"train_step_metric": train_step_metric,
"train_epoch_metric": train_epoch_metric,
- "val_step_loss": val_step_loss,
- "val_epoch_loss": val_epoch_loss,
- "val_step_metric": val_step_metric,
- "val_epoch_metric": val_epoch_metric,
+ "eval_step_loss": eval_step_loss,
+ "eval_epoch_loss": eval_epoch_loss,
+ "eval_step_metric": eval_step_metric,
+ "eval_epoch_metric": eval_epoch_metric,
}
with open(output_filename, "w") as f:
json.dump(metrics_data, f)
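A small sketch of reading back the metrics file written above, assuming the renamed eval_* keys; the path metrics.json is hypothetical, since the trainer builds the actual filename elsewhere:

import json

with open("metrics.json") as f:  # hypothetical path
    metrics = json.load(f)

# Per-step training loss and per-epoch eval loss under the renamed keys.
print(len(metrics["train_step_loss"]), "training steps recorded")
print(metrics["eval_epoch_loss"])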