
Commit 079c6dc

Refactored the evaluation function and renamed variables to generic names.

1 parent ad3c27f · commit 079c6dc

File tree: 1 file changed, +71 −140 lines

QEfficient/finetune/utils/train_utils.py

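The change collapses the two task-specific evaluators, evaluation_ppl and evaluation_acc, into a single evaluation_helper, and renames perplexity-specific variables (train_prep, val_prep, *_perplexity, *_ppl) to generic *_metric names so the same bookkeeping serves both perplexity and classification accuracy. Short sketches follow the hunks where an example helps.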
@@ -62,9 +62,9 @@ def train(
 
     Returns: results dictionary containing average training and validation perplexity and loss
     """
-    train_prep = []
+    train_metric = []
     train_loss = []
-    val_prep = []
+    val_metric = []
     val_loss = []
 
     if train_config.save_metrics:
@@ -73,10 +73,10 @@ def train(
         metrics_filename = (
             f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
         )
-        train_step_perplexity = []
+        train_step_metric = []
         train_step_loss = []
         val_step_loss = []
-        val_step_perplexity = []
+        val_step_metric = []
 
     epoch_times = []
     checkpoint_times = []
@@ -106,10 +106,10 @@ def train(
 
     acc_helper = None
     if train_config.task_type == "seq_classification":
-        if local_rank is None:
-            num_classes = model.classifier.out_features
-        else:
+        if train_config.enable_ddp:
             num_classes = model.module.classifier.out_features
+        else:
+            num_classes = model.classifier.out_features
         acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
 
     # Start the training loop
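For context on the branch above: DistributedDataParallel wraps the user's module, so the classifier head lives one level down at model.module.classifier; the commit switches on train_config.enable_ddp to pick the right path. A minimal sketch of the same lookup (the isinstance check and the toy SimpleClassifier are illustrative, not from this file):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class SimpleClassifier(nn.Module):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.classifier = nn.Linear(16, num_classes)

def get_num_classes(model: nn.Module) -> int:
    # DDP wraps the original module, so `classifier` sits on `model.module`.
    base = model.module if isinstance(model, DDP) else model
    return base.classifier.out_features

print(get_num_classes(SimpleClassifier()))  # 3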
@@ -231,7 +231,11 @@ def train(
 
             if train_config.save_metrics:
                 train_step_loss.append(loss.detach().float().item())
-                train_step_perplexity.append(float(torch.exp(loss.detach().float())))
+                if train_config.task_type == "seq_classification":
+                    step_metric_val = float(acc_helper.compute())  # cast to float so the value stays JSON-serializable
+                else:
+                    step_metric_val = float(torch.exp(loss.detach().float()))
+                train_step_metric.append(step_metric_val)
 
             if train_config.grad_scaler:
                 scaler.scale(loss).backward()  # backward pass
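The non-classification branch above records perplexity, which is simply the exponential of the mean cross-entropy loss, mirroring the expression in the hunk; a quick arithmetic check:

import torch

loss = torch.tensor(2.0)  # mean cross-entropy over tokens
perplexity = float(torch.exp(loss.detach().float()))
print(perplexity)  # ~7.389, i.e. e**2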
@@ -266,12 +270,12 @@ def train(
                     metrics_filename,
                     train_step_loss,
                     train_loss,
-                    train_step_perplexity,
-                    train_prep,
+                    train_step_metric,
+                    train_metric,
                     val_step_loss,
                     val_loss,
-                    val_step_perplexity,
-                    val_prep,
+                    val_step_metric,
+                    val_metric,
                 )
             if train_config.enable_ddp:
                 if loss_0_counter.item() == train_config.convergence_counter:
@@ -307,11 +311,11 @@ def train(
         if train_config.enable_ddp:
             dist.all_reduce(accuracy, op=dist.ReduceOp.SUM)
             accuracy /= dist.get_world_size()
-            train_perplexity = accuracy
+            train_metric_val = accuracy  # distinct name so the train_metric list is not shadowed
         else:
-            train_perplexity = torch.exp(train_epoch_loss)
+            train_metric_val = torch.exp(train_epoch_loss)
 
-        train_prep.append(float(train_perplexity))
+        train_metric.append(float(train_metric_val))
         train_loss.append(float(train_epoch_loss))
 
         # Update the learning rate as needed
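In the DDP branch above, each rank's accuracy is summed across processes and then divided by the world size, so every rank ends up with the global mean. A runnable single-process sketch of that collective (the gloo group with world_size=1 exists only to make the call executable):

import torch
import torch.distributed as dist

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

accuracy = torch.tensor(0.85)                    # this rank's local accuracy
dist.all_reduce(accuracy, op=dist.ReduceOp.SUM)  # sum across all ranks
accuracy /= dist.get_world_size()                # -> global mean (0.85 here)
print(float(accuracy))
dist.destroy_process_group()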
@@ -320,21 +324,21 @@ def train(
         if train_config.run_validation:
             if train_config.enable_ddp:
                 dist.barrier()
-                eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation_helper(
-                    model, train_config, eval_dataloader, local_rank, tokenizer, device
+                eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
+                    model, train_config, eval_dataloader, device
                 )
                 if local_rank == 0:
                     tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
 
             else:
-                eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation_helper(
-                    model, train_config, eval_dataloader, local_rank, tokenizer, device
+                eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
+                    model, train_config, eval_dataloader, device
                 )
                 tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
 
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
-                val_step_perplexity.extend(temp_step_perplexity)
+                val_step_metric.extend(temp_step_metric)
 
         # saving the adapters after completion of each epoch
         if train_config.save_model:
@@ -349,14 +353,14 @@ def train(
                 best_val_loss = eval_epoch_loss
                 print(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
             val_loss.append(float(eval_epoch_loss))
-            val_prep.append(float(eval_ppl))
+            val_metric.append(float(eval_metric))
         if train_config.task_type == "seq_classification":
             print(
-                f"Epoch {epoch + 1}: train_acc={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+                f"Epoch {epoch + 1}: train_acc={train_metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
             )
         else:
             print(
-                f"Epoch {epoch + 1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+                f"Epoch {epoch + 1}: train_metric={train_metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
             )
 
         # Saving the results every epoch to plot later
@@ -365,31 +369,25 @@ def train(
                 metrics_filename,
                 train_step_loss,
                 train_loss,
-                train_step_perplexity,
-                train_prep,
+                train_step_metric,
+                train_metric,
                 val_step_loss,
                 val_loss,
-                val_step_perplexity,
-                val_prep,
+                val_step_metric,
+                val_metric,
             )
     avg_epoch_time = sum(epoch_times) / len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
-    avg_train_prep = sum(train_prep) / len(train_prep)
+    avg_train_metric = sum(train_metric) / len(train_metric)
     avg_train_loss = sum(train_loss) / len(train_loss)
     if train_config.run_validation:
-        avg_eval_prep = sum(val_prep) / len(val_prep)
+        avg_eval_metric = sum(val_metric) / len(val_metric)
         avg_eval_loss = sum(val_loss) / len(val_loss)
 
-    if train_config.task_type == "seq_classification":
-        results["avg_train_acc"] = avg_train_prep
-    else:
-        results["avg_train_prep"] = avg_train_prep
+    results["avg_train_metric"] = avg_train_metric
     results["avg_train_loss"] = avg_train_loss
     if train_config.run_validation:
-        if train_config.task_type == "seq_classification":
-            results["avg_eval_acc"] = avg_eval_prep
-        else:
-            results["avg_eval_prep"] = avg_eval_prep
+        results["avg_eval_metric"] = avg_eval_metric
         results["avg_eval_loss"] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
@@ -399,39 +397,40 @@ def train(
     return results
 
 
-def evaluation_ppl(model, train_config, eval_dataloader, local_rank, tokenizer, device):
+def evaluation_helper(model, train_config, eval_dataloader, device):
     """
     Evaluates the model on the given dataloader
 
     Args:
         model: The model to evaluate
         eval_dataloader: The dataloader containing the evaluation data
-        local_rank: The rank of the current node in a distributed setting
-        tokenizer: The tokenizer used to decode predictions
 
-    Returns: eval_ppl, eval_epoch_loss
+    Returns: eval_epoch_loss, eval_metric, eval_step_loss, eval_step_metric
     """
     model.eval()
 
+    if train_config.task_type == "seq_classification":
+        if train_config.enable_ddp:
+            num_classes = model.module.classifier.out_features
+        else:
+            num_classes = model.classifier.out_features
+        acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
+
     # special handling for qaic device and dtype
     # model.to(device)
 
-    eval_preds = []
     val_step_loss = []
-    val_step_perplexity = []
+    val_step_metric = []
 
     eval_loss = 0.0  # Initialize evaluation loss
-    total_eval_steps = 0
-    # max_steps_reached = False  # Flag to indicate max eval steps reached
 
     for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
-        total_eval_steps += 1
         # stop when the maximum number of eval steps is reached
-        if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
-            # max_steps_reached = True
+        if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
             break
         for key in batch.keys():
             batch[key] = batch[key].to(device)
+
         # Ensure no gradients are computed for this scope to save memory
         with torch.no_grad():
             # Forward pass and compute loss
@@ -441,100 +440,32 @@ def evaluation_ppl(model, train_config, eval_dataloader, local_rank, tokenizer,
                 outputs = model(**batch)
             loss = outputs.loss
 
-            if train_config.save_metrics:
-                val_step_loss.append(loss.detach().float().item())
-                val_step_perplexity.append(float(torch.exp(loss.detach().float())))
-
-        eval_loss += loss.detach().float()
-        # Decode predictions and add to evaluation predictions list
-        preds = torch.argmax(outputs.logits, -1)
-        eval_preds.extend(tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True))
-
-    # Compute average loss and perplexity
-    eval_epoch_loss = eval_loss / len(eval_dataloader)
-    eval_ppl = torch.exp(eval_epoch_loss)
-
-    # Print evaluation metrics
-    print(f" {eval_ppl.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
-
-    return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
-
-
-def evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer, device):
-    """
-    Evaluates the model on the given dataloader
-
-    Args:
-        model: The model to evaluate
-        eval_dataloader: The dataloader containing the evaluation data
-        local_rank: The rank of the current node in a distributed setting
-        tokenizer: The tokenizer used to decode predictions
-
-    Returns: eval_acc, eval_epoch_loss
-    """
-    model.eval()
-    if train_config.enable_ddp:
-        num_classes = model.module.classifier.out_features
-    else:
-        num_classes = model.classifier.out_features
-
-    acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
-
-    # special handling for qaic device and dtype
-    # model.to(device)
-
-    # eval_preds = []
-    val_step_loss = []
-    val_step_acc = []
-
-    eval_loss = 0.0  # Initialize evaluation loss
-    total_eval_steps = 0
-    # max_steps_reached = False  # Flag to indicate max eval steps reached
+            if train_config.task_type == "seq_classification":
+                logits = outputs.logits
+                labels = batch["labels"][:, 0]
+                preds = torch.nn.functional.softmax(logits, dim=-1)
+                val_acc = acc_helper.forward(preds, labels)
+                metric_val = val_acc.detach().float().item()
+            else:
+                metric_val = float(torch.exp(loss.detach().float()))
 
-    for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
-        total_eval_steps += 1
-        # stop when the maximum number of eval steps is reached
-        if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
-            # max_steps_reached = True
-            break
-        for key in batch.keys():
-            batch[key] = batch[key].to(device)
-        # Ensure no gradients are computed for this scope to save memory
-        with torch.no_grad():
-            # Forward pass and compute loss
-            with (
-                torch.autocast(device_type=device, dtype=torch.float16) if train_config.use_autocast else nullcontext()
-            ):
-                outputs = model(**batch)
-            loss = outputs.loss
-            logits = outputs.logits
-            labels = batch["labels"][:, 0]
             if train_config.save_metrics:
                 val_step_loss.append(loss.detach().float().item())
-                preds = torch.nn.functional.softmax(logits, dim=-1)
-                val_acc = acc_helper.forward(preds, labels)
-                val_step_acc.append(val_acc.detach().float().item())
+                val_step_metric.append(metric_val)
 
         eval_loss += loss.detach().float()
-        # Decode predictions and add to evaluation predictions list
-        # preds = torch.argmax(outputs.logits, -1)
-        # eval_preds.extend(tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True))
 
-    # Compute average loss and perplexity
+    # Compute average loss and metric
     eval_epoch_loss = eval_loss / len(eval_dataloader)
-    eval_acc = acc_helper.compute()
+    if train_config.task_type == "seq_classification":
+        eval_metric = acc_helper.compute()
+    else:
+        eval_metric = torch.exp(eval_epoch_loss)
 
     # Print evaluation metrics
-    print(f" {eval_acc.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+    print(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
 
-    return eval_acc, eval_epoch_loss, val_step_loss, val_step_acc
-
-
-def evaluation_helper(model, train_config, eval_dataloader, local_rank, tokenizer, device):
-    if train_config.task_type == "seq_classification":
-        return evaluation_acc(model, train_config, eval_dataloader, local_rank, tokenizer, device)
-    else:
-        return evaluation_ppl(model, train_config, eval_dataloader, local_rank, tokenizer, device)
+    return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric  # order matches the docstring and call sites
 
 
 def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
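With the old dispatch removed, there is one evaluation entry point for both task types. A hedged usage sketch of the new signature (model, train_config, eval_dataloader, and device are assumed to be set up as elsewhere in this file):

# Return order follows the docstring: epoch loss first, then the metric
# (accuracy for seq_classification, perplexity otherwise).
eval_epoch_loss, eval_metric, step_losses, step_metrics = evaluation_helper(
    model, train_config, eval_dataloader, device
)
print(f"eval loss={float(eval_epoch_loss):.4f}, eval metric={float(eval_metric):.4f}")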
@@ -571,22 +502,22 @@ def save_to_json(
     output_filename,
     train_step_loss,
     train_epoch_loss,
-    train_step_ppl,
-    train_epoch_ppl,
+    train_step_metric,
+    train_epoch_metric,
     val_step_loss,
     val_epoch_loss,
-    val_step_ppl,
-    val_epoch_ppl,
+    val_step_metric,
+    val_epoch_metric,
 ):
     metrics_data = {
         "train_step_loss": train_step_loss,
         "train_epoch_loss": train_epoch_loss,
-        "train_step_perplexity": train_step_ppl,
-        "train_epoch_perplexity": train_epoch_ppl,
+        "train_step_metric": train_step_metric,
+        "train_epoch_metric": train_epoch_metric,
         "val_step_loss": val_step_loss,
         "val_epoch_loss": val_epoch_loss,
-        "val_step_perplexity": val_step_ppl,
-        "val_epoch_perplexity": val_epoch_ppl,
+        "val_step_metric": val_step_metric,
+        "val_epoch_metric": val_epoch_metric,
     }
     with open(output_filename, "w") as f:
         json.dump(metrics_data, f)
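With the renamed keys, a saved metrics file can be inspected as below (the filename is illustrative; real files carry the local rank and a timestamp as shown in the diff above):

import json

with open("metrics_data_0-2025-01-01_00-00-00.json") as f:
    metrics = json.load(f)

# Generic *_metric keys replace the old *_perplexity ones.
print(sorted(metrics))
# ['train_epoch_loss', 'train_epoch_metric', 'train_step_loss', 'train_step_metric',
#  'val_epoch_loss', 'val_epoch_metric', 'val_step_loss', 'val_step_metric']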
