@@ -62,9 +62,9 @@ def train(
62
62
63
63
Returns: results dictionary containing average training and validation perplexity and loss
64
64
"""
65
- train_prep = []
65
+ train_metric = []
66
66
train_loss = []
67
- val_prep = []
67
+ val_metric = []
68
68
val_loss = []
69
69
70
70
if train_config .save_metrics :
@@ -73,10 +73,10 @@ def train(
73
73
metrics_filename = (
74
74
f"{ train_config .output_dir } /metrics_data_{ local_rank } -{ datetime .now ().strftime ('%Y-%m-%d_%H-%M-%S' )} .json"
75
75
)
76
- train_step_perplexity = []
76
+ train_step_metric = []
77
77
train_step_loss = []
78
78
val_step_loss = []
79
- val_step_perplexity = []
79
+ val_step_metric = []
80
80
81
81
epoch_times = []
82
82
checkpoint_times = []
@@ -106,10 +106,10 @@ def train(
106
106
107
107
acc_helper = None
108
108
if train_config .task_type == "seq_classification" :
109
- if local_rank is None :
110
- num_classes = model .classifier .out_features
111
- else :
109
+ if train_config .enable_ddp :
112
110
num_classes = model .module .classifier .out_features
111
+ else :
112
+ num_classes = model .classifier .out_features
113
113
acc_helper = torchmetrics .classification .MulticlassAccuracy (num_classes = num_classes ).to (device )
114
114
115
115
# Start the training loop
@@ -231,7 +231,11 @@ def train(
231
231
232
232
if train_config .save_metrics :
233
233
train_step_loss .append (loss .detach ().float ().item ())
234
- train_step_perplexity .append (float (torch .exp (loss .detach ().float ())))
234
+ if train_config .task_type == "seq_classification" :
235
+ step_metric_val = float (acc_helper .compute ())
236
+ else :
237
+ step_metric_val = float (torch .exp (loss .detach ().float ()))
238
+ train_step_metric .append (step_metric_val )
235
239
236
240
if train_config .grad_scaler :
237
241
scaler .scale (loss ).backward () # backward pass
@@ -266,12 +270,12 @@ def train(
266
270
metrics_filename ,
267
271
train_step_loss ,
268
272
train_loss ,
269
- train_step_perplexity ,
270
- train_prep ,
273
+ train_step_metric ,
274
+ train_metric ,
271
275
val_step_loss ,
272
276
val_loss ,
273
- val_step_perplexity ,
274
- val_prep ,
277
+ val_step_metric ,
278
+ val_metric ,
275
279
)
276
280
if train_config .enable_ddp :
277
281
if loss_0_counter .item () == train_config .convergence_counter :
@@ -307,11 +311,11 @@ def train(
307
311
if train_config .enable_ddp :
308
312
dist .all_reduce (accuracy , op = dist .ReduceOp .SUM )
309
313
accuracy /= dist .get_world_size ()
310
- train_perplexity = accuracy
314
+ train_epoch_metric = accuracy
311
315
else :
312
- train_perplexity = torch .exp (train_epoch_loss )
316
+ train_epoch_metric = torch .exp (train_epoch_loss )
313
317
314
- train_prep .append (float (train_perplexity ))
318
+ train_metric .append (float (train_epoch_metric ))
315
319
train_loss .append (float (train_epoch_loss ))
316
320
317
321
# Update the learning rate as needed
@@ -320,21 +324,21 @@ def train(
320
324
if train_config .run_validation :
321
325
if train_config .enable_ddp :
322
326
dist .barrier ()
323
- eval_ppl , eval_epoch_loss , temp_val_loss , temp_step_perplexity = evaluation_helper (
324
- model , train_config , eval_dataloader , local_rank , tokenizer , device
327
+ eval_epoch_loss , eval_metric , temp_val_loss , temp_step_metric = evaluation_helper (
328
+ model , train_config , eval_dataloader , device
325
329
)
326
330
if local_rank == 0 :
327
331
tensorboard_updates .add_scalars ("loss" , {"eval" : eval_epoch_loss }, total_train_steps )
328
332
329
333
else :
330
- eval_ppl , eval_epoch_loss , temp_val_loss , temp_step_perplexity = evaluation_helper (
331
- model , train_config , eval_dataloader , local_rank , tokenizer , device
334
+ eval_epoch_loss , eval_metric , temp_val_loss , temp_step_metric = evaluation_helper (
335
+ model , train_config , eval_dataloader , device
332
336
)
333
337
tensorboard_updates .add_scalars ("loss" , {"eval" : eval_epoch_loss }, total_train_steps )
334
338
335
339
if train_config .save_metrics :
336
340
val_step_loss .extend (temp_val_loss )
337
- val_step_perplexity .extend (temp_step_perplexity )
341
+ val_step_metric .extend (temp_step_metric )
338
342
339
343
# saving the adapters after completion of each epoch
340
344
if train_config .save_model :
@@ -349,14 +353,14 @@ def train(
349
353
best_val_loss = eval_epoch_loss
350
354
print (f"best eval loss on epoch { epoch + 1 } is { best_val_loss } " )
351
355
val_loss .append (float (eval_epoch_loss ))
352
- val_prep .append (float (eval_ppl ))
356
+ val_metric .append (float (eval_metric ))
353
357
if train_config .task_type == "seq_classification" :
354
358
print (
355
- f"Epoch { epoch + 1 } : train_acc={ train_perplexity :.4f} , train_epoch_loss={ train_epoch_loss :.4f} , epoch time { epoch_end_time } s"
359
+ f"Epoch { epoch + 1 } : train_acc={ train_epoch_metric :.4f} , train_epoch_loss={ train_epoch_loss :.4f} , epoch time { epoch_end_time } s"
356
360
)
357
361
else :
358
362
print (
359
- f"Epoch { epoch + 1 } : train_perplexity= { train_perplexity :.4f} , train_epoch_loss={ train_epoch_loss :.4f} , epoch time { epoch_end_time } s"
363
+ f"Epoch { epoch + 1 } : train_metric= { train_epoch_metric :.4f} , train_epoch_loss={ train_epoch_loss :.4f} , epoch time { epoch_end_time } s"
360
364
)
361
365
362
366
# Saving the results every epoch to plot later
@@ -365,31 +369,25 @@ def train(
365
369
metrics_filename ,
366
370
train_step_loss ,
367
371
train_loss ,
368
- train_step_perplexity ,
369
- train_prep ,
372
+ train_step_metric ,
373
+ train_metric ,
370
374
val_step_loss ,
371
375
val_loss ,
372
- val_step_perplexity ,
373
- val_prep ,
376
+ val_step_metric ,
377
+ val_metric ,
374
378
)
375
379
avg_epoch_time = sum (epoch_times ) / len (epoch_times )
376
380
avg_checkpoint_time = sum (checkpoint_times ) / len (checkpoint_times ) if len (checkpoint_times ) > 0 else 0
377
- avg_train_prep = sum (train_prep ) / len (train_prep )
381
+ avg_train_metric = sum (train_metric ) / len (train_metric )
378
382
avg_train_loss = sum (train_loss ) / len (train_loss )
379
383
if train_config .run_validation :
380
- avg_eval_prep = sum (val_prep ) / len (val_prep )
384
+ avg_eval_metric = sum (val_metric ) / len (val_metric )
381
385
avg_eval_loss = sum (val_loss ) / len (val_loss )
382
386
383
- if train_config .task_type == "seq_classification" :
384
- results ["avg_train_acc" ] = avg_train_prep
385
- else :
386
- results ["avg_train_prep" ] = avg_train_prep
387
+ results ["avg_train_metric" ] = avg_train_metric
387
388
results ["avg_train_loss" ] = avg_train_loss
388
389
if train_config .run_validation :
389
- if train_config .task_type == "seq_classification" :
390
- results ["avg_eval_acc" ] = avg_eval_prep
391
- else :
392
- results ["avg_eval_prep" ] = avg_eval_prep
390
+ results ["avg_eval_metric" ] = avg_eval_metric
393
391
results ["avg_eval_loss" ] = avg_eval_loss
394
392
results ["avg_epoch_time" ] = avg_epoch_time
395
393
results ["avg_checkpoint_time" ] = avg_checkpoint_time
@@ -399,39 +397,40 @@ def train(
399
397
return results
400
398
401
399
402
- def evaluation_ppl (model , train_config , eval_dataloader , local_rank , tokenizer , device ):
400
+ def evaluation_helper (model , train_config , eval_dataloader , device ):
403
401
"""
404
402
Evaluates the model on the given dataloader
405
403
406
404
Args:
407
405
model: The model to evaluate
408
406
eval_dataloader: The dataloader containing the evaluation data
409
- local_rank: The rank of the current node in a distributed setting
410
- tokenizer: The tokenizer used to decode predictions
411
407
412
- Returns: eval_ppl, eval_epoch_loss
408
+ Returns: eval_epoch_loss, eval_metric, eval_step_loss, eval_step_metric
413
409
"""
414
410
model .eval ()
415
411
412
+ if train_config .task_type == "seq_classification" :
413
+ if train_config .enable_ddp :
414
+ num_classes = model .module .classifier .out_features
415
+ else :
416
+ num_classes = model .classifier .out_features
417
+ acc_helper = torchmetrics .classification .MulticlassAccuracy (num_classes = num_classes ).to (device )
418
+
416
419
# special handling for qaic device and dtype
417
420
# model.to(device)
418
421
419
- eval_preds = []
420
422
val_step_loss = []
421
- val_step_perplexity = []
423
+ val_step_metric = []
422
424
423
425
eval_loss = 0.0 # Initialize evaluation loss
424
- total_eval_steps = 0
425
- # max_steps_reached = False # Flag to indicate max eval steps reached
426
426
427
427
for step , batch in enumerate (tqdm (eval_dataloader , colour = "green" , desc = "evaluating Epoch" , dynamic_ncols = True )):
428
- total_eval_steps += 1
429
428
# stop when the maximum number of eval steps is reached
430
- if train_config .max_eval_step > 0 and total_eval_steps > train_config .max_eval_step :
431
- # max_steps_reached = True
429
+ if train_config .max_eval_step > 0 and step >= train_config .max_eval_step :
432
430
break
433
431
for key in batch .keys ():
434
432
batch [key ] = batch [key ].to (device )
433
+
435
434
# Ensure no gradients are computed for this scope to save memory
436
435
with torch .no_grad ():
437
436
# Forward pass and compute loss
@@ -441,100 +440,32 @@ def evaluation_ppl(model, train_config, eval_dataloader, local_rank, tokenizer,
441
440
outputs = model (** batch )
442
441
loss = outputs .loss
443
442
444
- if train_config .save_metrics :
445
- val_step_loss .append (loss .detach ().float ().item ())
446
- val_step_perplexity .append (float (torch .exp (loss .detach ().float ())))
447
-
448
- eval_loss += loss .detach ().float ()
449
- # Decode predictions and add to evaluation predictions list
450
- preds = torch .argmax (outputs .logits , - 1 )
451
- eval_preds .extend (tokenizer .batch_decode (preds .detach ().cpu ().numpy (), skip_special_tokens = True ))
452
-
453
- # Compute average loss and perplexity
454
- eval_epoch_loss = eval_loss / len (eval_dataloader )
455
- eval_ppl = torch .exp (eval_epoch_loss )
456
-
457
- # Print evaluation metrics
458
- print (f" { eval_ppl .detach ().cpu ()= } { eval_epoch_loss .detach ().cpu ()= } " )
459
-
460
- return eval_ppl , eval_epoch_loss , val_step_loss , val_step_perplexity
461
-
462
-
463
- def evaluation_acc (model , train_config , eval_dataloader , local_rank , tokenizer , device ):
464
- """
465
- Evaluates the model on the given dataloader
466
-
467
- Args:
468
- model: The model to evaluate
469
- eval_dataloader: The dataloader containing the evaluation data
470
- local_rank: The rank of the current node in a distributed setting
471
- tokenizer: The tokenizer used to decode predictions
472
-
473
- Returns: eval_acc, eval_epoch_loss
474
- """
475
- model .eval ()
476
- if train_config .enable_ddp :
477
- num_classes = model .module .classifier .out_features
478
- else :
479
- num_classes = model .classifier .out_features
480
-
481
- acc_helper = torchmetrics .classification .MulticlassAccuracy (num_classes = num_classes ).to (device )
482
-
483
- # special handling for qaic device and dtype
484
- # model.to(device)
485
-
486
- # eval_preds = []
487
- val_step_loss = []
488
- val_step_acc = []
489
-
490
- eval_loss = 0.0 # Initialize evaluation loss
491
- total_eval_steps = 0
492
- # max_steps_reached = False # Flag to indicate max eval steps reached
443
+ if train_config .task_type == "seq_classification" :
444
+ logits = outputs .logits
445
+ labels = batch ["labels" ][:, 0 ]
446
+ preds = torch .nn .functional .softmax (logits , dim = - 1 )
447
+ val_acc = acc_helper .forward (preds , labels )
448
+ metric_val = val_acc .detach ().float ().item ()
449
+ else :
450
+ metric_val = float (torch .exp (loss .detach ().float ()))
493
451
494
- for step , batch in enumerate (tqdm (eval_dataloader , colour = "green" , desc = "evaluating Epoch" , dynamic_ncols = True )):
495
- total_eval_steps += 1
496
- # stop when the maximum number of eval steps is reached
497
- if train_config .max_eval_step > 0 and total_eval_steps > train_config .max_eval_step :
498
- # max_steps_reached = True
499
- break
500
- for key in batch .keys ():
501
- batch [key ] = batch [key ].to (device )
502
- # Ensure no gradients are computed for this scope to save memory
503
- with torch .no_grad ():
504
- # Forward pass and compute loss
505
- with (
506
- torch .autocast (device_type = device , dtype = torch .float16 ) if train_config .use_autocast else nullcontext ()
507
- ):
508
- outputs = model (** batch )
509
- loss = outputs .loss
510
- logits = outputs .logits
511
- labels = batch ["labels" ][:, 0 ]
512
452
if train_config .save_metrics :
513
453
val_step_loss .append (loss .detach ().float ().item ())
514
- preds = torch .nn .functional .softmax (logits , dim = - 1 )
515
- val_acc = acc_helper .forward (preds , labels )
516
- val_step_acc .append (val_acc .detach ().float ().item ())
454
+ val_step_metric .append (metric_val )
517
455
518
456
eval_loss += loss .detach ().float ()
519
- # Decode predictions and add to evaluation predictions list
520
- # preds = torch.argmax(outputs.logits, -1)
521
- # eval_preds.extend(tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True))
522
457
523
- # Compute average loss and perplexity
458
+ # Compute average loss and metric
524
459
eval_epoch_loss = eval_loss / len (eval_dataloader )
525
- eval_acc = acc_helper .compute ()
460
+ if train_config .task_type == "seq_classification" :
461
+ eval_metric = acc_helper .compute ()
462
+ else :
463
+ eval_metric = torch .exp (eval_epoch_loss )
526
464
527
465
# Print evaluation metrics
528
- print (f" { eval_acc .detach ().cpu ()= } { eval_epoch_loss .detach ().cpu ()= } " )
466
+ print (f" { eval_metric .detach ().cpu ()= } { eval_epoch_loss .detach ().cpu ()= } " )
529
467
530
- return eval_acc , eval_epoch_loss , val_step_loss , val_step_acc
531
-
532
-
533
- def evaluation_helper (model , train_config , eval_dataloader , local_rank , tokenizer , device ):
534
- if train_config .task_type == "seq_classification" :
535
- return evaluation_acc (model , train_config , eval_dataloader , local_rank , tokenizer , device )
536
- else :
537
- return evaluation_ppl (model , train_config , eval_dataloader , local_rank , tokenizer , device )
468
+ return eval_epoch_loss , eval_metric , val_step_loss , val_step_metric
538
469
539
470
540
471
def get_longest_seq_length (data : List [Dict ]) -> Tuple [int , int ]:
@@ -571,22 +502,22 @@ def save_to_json(
571
502
output_filename ,
572
503
train_step_loss ,
573
504
train_epoch_loss ,
574
- train_step_ppl ,
575
- train_epoch_ppl ,
505
+ train_step_metric ,
506
+ train_epoch_metric ,
576
507
val_step_loss ,
577
508
val_epoch_loss ,
578
- val_step_ppl ,
579
- val_epoch_ppl ,
509
+ val_step_metric ,
510
+ val_epoch_metric ,
580
511
):
581
512
metrics_data = {
582
513
"train_step_loss" : train_step_loss ,
583
514
"train_epoch_loss" : train_epoch_loss ,
584
- "train_step_perplexity " : train_step_ppl ,
585
- "train_epoch_perplexity " : train_epoch_ppl ,
515
+ "train_step_metric" : train_step_metric ,
516
+ "train_epoch_metric" : train_epoch_metric ,
586
517
"val_step_loss" : val_step_loss ,
587
518
"val_epoch_loss" : val_epoch_loss ,
588
- "val_step_perplexity " : val_step_ppl ,
589
- "val_epoch_perplexity " : val_epoch_ppl ,
519
+ "val_step_metric" : val_step_metric ,
520
+ "val_epoch_metric" : val_epoch_metric ,
590
521
}
591
522
with open (output_filename , "w" ) as f :
592
523
json .dump (metrics_data , f )
0 commit comments