Commit a8eecf9

Updated variable names for eval and some more code cleanup.
Signed-off-by: meetkuma <[email protected]>
1 parent 8cedceb commit a8eecf9

File tree

3 files changed: +104 -106 lines changed


QEfficient/finetune/utils/dataset_utils.py

Lines changed: 13 additions & 4 deletions
@@ -4,6 +4,8 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+import logging
+
 import datasets
 import torch
 import torch.distributed as dist
@@ -66,6 +68,11 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):
 
 
 def padding_dataset(train_config, dataset, batch_size):
+    num_replicas = get_num_ddp_devices()
+    remainder = len(dataset) % (num_replicas * batch_size)
+    if remainder == 0:
+        return dataset
+
     if train_config.enable_ddp and train_config.enable_sorting_for_ddp:
         if isinstance(dataset, datasets.Dataset):
             # Hugging Face Dataset transformation
@@ -77,24 +84,26 @@ def padding_dataset(train_config, dataset, batch_size):
 
     dummy_row = next(iter(dataset))
     dummy_row["labels"] = torch.tensor([-100] * len(dummy_row["labels"]))
-    padding_size = 0
-    num_replicas = get_num_ddp_devices()
-    remainder = len(dataset) % (num_replicas * batch_size)
-    padding_size = (num_replicas * batch_size) - remainder
 
+    padding_size = (num_replicas * batch_size) - remainder
     dummy_data = [dummy_row.copy() for _ in range(padding_size)]
     dummy_dataset = datasets.Dataset.from_list(dummy_data)
     if isinstance(dataset, datasets.Dataset):
         combined_dataset = datasets.concatenate_datasets([dataset, dummy_dataset])
     else:
         combined_dataset = dataset + list(dummy_dataset)
+
+    logger.log_rank_zero("Padding dataset to make it divisible by batch_size * num_devices.", logging.DEBUG)
+    logger.log_rank_zero(f"Length of dataset before padding: {len(dataset)}", logging.DEBUG)
+    logger.log_rank_zero(f"Length of dataset after padding: {len(combined_dataset)}", logging.DEBUG)
     return combined_dataset
 
 
 def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
     dataset = get_preprocessed_dataset(tokenizer, dataset_config, split, context_length=train_config.context_length)
 
     batch_size = train_config.train_batch_size if split == "train" else train_config.val_batch_size
+
     dataset = padding_dataset(train_config, dataset, batch_size)
 
     dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)
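For reference, a minimal standalone sketch of the padding arithmetic that the updated padding_dataset front-loads (the new early return when the dataset is already divisible). The function name pad_to_multiple and the plain-list inputs are illustrative stand-ins, not code from this repository.

# Illustrative sketch only: mirrors the remainder / padding_size math above on a plain list.
def pad_to_multiple(samples, num_replicas, batch_size):
    remainder = len(samples) % (num_replicas * batch_size)
    if remainder == 0:
        # Early return added by this commit: nothing to pad when already divisible.
        return samples
    padding_size = (num_replicas * batch_size) - remainder
    # Repeat a dummy sample so every device sees only full batches.
    return samples + [samples[0]] * padding_size


# Example: 10 samples, 2 devices, batch size 4 -> remainder 2, so 6 dummy rows -> 16 samples.
print(len(pad_to_multiple(list(range(10)), num_replicas=2, batch_size=4)))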

QEfficient/finetune/utils/train_utils.py

Lines changed: 71 additions & 84 deletions
@@ -19,7 +19,7 @@
 from tqdm import tqdm
 
 from QEfficient.finetune.configs.training import TrainConfig
-from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx, is_rank_zero, get_num_ddp_devices
+from QEfficient.finetune.utils.helper import get_autocast_ctx, get_num_ddp_devices, get_op_verifier_ctx, is_rank_zero
 from QEfficient.finetune.utils.logging_utils import logger
 
 try:
@@ -63,8 +63,8 @@ def train(
 
     train_metric = []
     train_loss = []
-    val_metric = []
-    val_loss = []
+    eval_metric = []
+    eval_loss = []
 
     if train_config.save_metrics:
         if not os.path.exists(train_config.output_dir):
@@ -74,13 +74,13 @@
         )
         train_step_metric = []
         train_step_loss = []
-        val_step_loss = []
-        val_step_metric = []
+        eval_step_loss = []
+        eval_step_metric = []
 
     epoch_times = []
     checkpoint_times = []
     results = {}
-    best_val_loss = float("inf")
+    best_eval_loss = float("inf")
     total_train_steps = 0
     max_steps_reached = False  # Flag to indicate max training steps reached
 
@@ -130,7 +130,6 @@
             continue
 
         logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
-        logger.log_rank_zero(f"train_config.max_train_step: {train_config.max_train_step}")
         # stop when the maximum number of training steps is reached
         if max_steps_reached:
             break
@@ -207,23 +206,21 @@
             total_loss += loss.detach().float()
 
             if is_rank_zero():
+                tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
                 if loss <= train_config.convergence_loss:
                     loss_0_counter += 1
                 else:
                     loss_0_counter = torch.tensor([0]).to(device)
             if train_config.enable_ddp:
                 dist.broadcast(loss_0_counter, src=0)
 
-            if is_rank_zero():
-                tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
-
             if train_config.save_metrics:
                 train_step_loss.append(loss.detach().float().item())
                 if train_config.task_type == "seq_classification":
-                    step_metric_val = float(acc_helper.compute())
+                    step_metric_value = float(acc_helper.compute())
                 else:
-                    step_metric_val = float(torch.exp(loss.detach().float()))
-                train_step_metric.append(step_metric_val)
+                    step_metric_value = float(torch.exp(loss.detach().float()))
+                train_step_metric.append(step_metric_value)
 
             # Accumalate gradients
             complete_accum_steps = (
@@ -271,10 +268,10 @@
                     train_loss,
                     train_step_metric,
                     train_metric,
-                    val_step_loss,
-                    val_loss,
-                    val_step_metric,
-                    val_metric,
+                    eval_step_loss,
+                    eval_loss,
+                    eval_step_metric,
+                    eval_metric,
                 )
             if loss_0_counter.item() == train_config.convergence_counter:
                 logger.log_rank_zero(
@@ -286,32 +283,19 @@
         epoch_end_time = time.perf_counter() - epoch_start_time
         epoch_times.append(epoch_end_time)
 
-        if loss_0_counter.item() == train_config.convergence_counter:
-            if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
-                train_epoch_loss = (
-                    0.0
-                    if total_loss == 0.0
-                    else total_loss / (step - intermediate_step - num_dummy_samples / train_config.train_batch_size)
-                )
-            else:
-                train_epoch_loss = (
-                    0.0
-                    if total_loss == 0.0
-                    else total_loss / (step + 1 - num_dummy_samples / train_config.train_batch_size)
-                )
+        if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
+            train_epoch_loss = (
+                0.0
+                if total_loss == 0.0
+                else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size))
+            )
         else:
-            if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
-                train_epoch_loss = (
-                    0.0
-                    if total_loss == 0.0
-                    else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size))
-                )
-            else:
-                train_epoch_loss = (
-                    0.0
-                    if total_loss == 0.0
-                    else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
-                )
+            train_epoch_loss = (
+                0.0
+                if total_loss == 0.0
+                else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
+            )
+
         if train_config.task_type == "seq_classification":
             train_epoch_metric = acc_helper.compute()
             acc_helper.reset()
@@ -331,30 +315,30 @@
         lr_scheduler.step()
 
         if train_config.run_validation:
-            eval_loss, eval_metric, step_loss, step_metric = evaluation_helper(
+            eval_epoch_loss, eval_epoch_metric, step_loss, step_metric = evaluation_helper(
                 model, train_config, eval_dataloader, device
             )
 
-            if eval_loss < best_val_loss:
-                best_val_loss = eval_loss
-                logger.log_rank_zero(f"Best eval loss on epoch {epoch + 1} is {best_val_loss:.4f}")
+            if eval_epoch_loss < best_eval_loss:
+                best_eval_loss = eval_epoch_loss
+                logger.log_rank_zero(f"Best eval loss on epoch {epoch + 1} is {best_eval_loss:.4f}")
 
             if is_rank_zero():
-                tensorboard_updates.add_scalars("loss", {"eval": eval_loss}, total_train_steps)
+                tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
             if train_config.save_metrics:
-                val_step_loss.extend(step_loss)
-                val_step_metric.extend(step_metric)
-                val_loss.append(float(eval_loss))
-                val_metric.append(float(eval_metric))
+                eval_step_loss.extend(step_loss)
+                eval_step_metric.extend(step_metric)
+                eval_loss.append(float(eval_epoch_loss))
+                eval_metric.append(float(eval_epoch_metric))
 
             if train_config.enable_ddp:
-                dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
-                eval_loss /= get_num_ddp_devices()
-                dist.all_reduce(eval_metric, op=dist.ReduceOp.SUM)
-                eval_metric /= get_num_ddp_devices()
+                dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM)
+                eval_epoch_loss /= get_num_ddp_devices()
+                dist.all_reduce(eval_epoch_metric, op=dist.ReduceOp.SUM)
+                eval_epoch_metric /= get_num_ddp_devices()
 
             logger.log_rank_zero(
-                f"Epoch {epoch + 1}: Eval Loss: {eval_loss.detach().cpu():.4f}, Eval metric: {eval_metric.detach().cpu():.4f}"
+                f"Epoch {epoch + 1}: Eval Loss: {eval_epoch_loss.detach().cpu():.4f}, Eval metric: {eval_epoch_metric.detach().cpu():.4f}"
             )
 
         # saving the adapters after completion of each epoch
@@ -377,19 +361,19 @@
             train_loss,
             train_step_metric,
            train_metric,
-            val_step_loss,
-            val_loss,
-            val_step_metric,
-            val_metric,
+            eval_step_loss,
+            eval_loss,
+            eval_step_metric,
+            eval_metric,
         )
     avg_epoch_time = sum(epoch_times) / len(epoch_times)
    avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
 
-    results["last_epoch_train_loss"] = train_epoch_loss
-    results["last_epoch_train_metric"] = train_epoch_metric
+    results["last_epoch_train_loss"] = train_epoch_loss.cpu()
+    results["last_epoch_train_metric"] = train_epoch_metric.cpu()
     if train_config.run_validation:
-        results["last_epoch_eval_loss"] = eval_loss
-        results["last_epoch_eval_metric"] = eval_metric
+        results["last_epoch_eval_loss"] = eval_epoch_loss.cpu()
+        results["last_epoch_eval_metric"] = eval_epoch_metric.cpu()
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
     if train_config.save_metrics:
@@ -405,7 +389,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     model: The model to evaluate
     eval_dataloader: The dataloader containing the evaluation data
 
-    Returns: eval_epoch_loss, eval_metric, eval_step_loss, eval_step_metric
+    Returns: eval_epoch_loss, eval_epoch_metric, eval_step_loss, eval_step_metric
     """
     if train_config.enable_ddp:
         dist.barrier()
@@ -422,8 +406,8 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     # special handling for qaic device and dtype
     # model.to(device)
 
-    val_step_loss = []
-    val_step_metric = []
+    eval_step_loss = []
+    eval_step_metric = []
 
     eval_loss = torch.tensor(0.0, dtype=torch.float32, device=device)  # Initialize evaluation loss
     device_type = torch.device(device).type
@@ -459,24 +443,27 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
                 logits = outputs.logits
                 labels = batch["labels"][:, 0]
                 preds = torch.nn.functional.softmax(logits, dim=-1)
-                val_acc = acc_helper.forward(preds, labels)
-                metric_val = val_acc.detach().float().item()
+                eval_acc = acc_helper.forward(preds, labels)
+                metric_value = eval_acc.detach().float().item()
             else:
-                metric_val = float(torch.exp(loss.detach().float()))
+                metric_value = float(torch.exp(loss.detach().float()))
 
         if train_config.save_metrics:
-            val_step_loss.append(loss.detach().float().item())
-            val_step_metric.append(metric_val)
+            eval_step_loss.append(loss.detach().float().item())
+            eval_step_metric.append(metric_value)
 
         eval_loss += loss.detach().float()
+
     # Compute average loss and metric
-    eval_loss = 0.0 if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size)
+    eval_epoch_loss = (
+        0.0 if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size)
+    )
     if train_config.task_type == "seq_classification":
-        eval_metric = acc_helper.compute()
+        eval_epoch_metric = acc_helper.compute()
     else:
-        eval_metric = torch.exp(eval_loss)
+        eval_epoch_metric = torch.exp(eval_epoch_loss)
 
-    return eval_loss, eval_metric, val_step_loss, val_step_metric
+    return eval_epoch_loss, eval_epoch_metric, eval_step_loss, eval_step_metric
 
 
 def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
@@ -517,20 +504,20 @@ def save_to_json(
     train_epoch_loss,
     train_step_metric,
     train_epoch_metric,
-    val_step_loss,
-    val_epoch_loss,
-    val_step_metric,
-    val_epoch_metric,
+    eval_step_loss,
+    eval_epoch_loss,
+    eval_step_metric,
+    eval_epoch_metric,
 ):
     metrics_data = {
         "train_step_loss": train_step_loss,
         "train_epoch_loss": train_epoch_loss,
         "train_step_metric": train_step_metric,
         "train_epoch_metric": train_epoch_metric,
-        "val_step_loss": val_step_loss,
-        "val_epoch_loss": val_epoch_loss,
-        "val_step_metric": val_step_metric,
-        "val_epoch_metric": val_epoch_metric,
+        "eval_step_loss": eval_step_loss,
+        "eval_epoch_loss": eval_epoch_loss,
+        "eval_step_metric": eval_step_metric,
+        "eval_epoch_metric": eval_epoch_metric,
     }
     with open(output_filename, "w") as f:
        json.dump(metrics_data, f)
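For reference, a minimal sketch of the all-reduce averaging pattern used above for eval_epoch_loss and eval_epoch_metric under DDP: sum the per-rank tensor, then divide by the device count. dist.get_world_size() stands in for the repository's get_num_ddp_devices(), and the single-process gloo group exists only to make the snippet runnable.

import os

import torch
import torch.distributed as dist


def average_across_ranks(value: torch.Tensor) -> torch.Tensor:
    # Sum the tensor across all ranks in place, then divide by the rank count.
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    value /= dist.get_world_size()
    return value


if __name__ == "__main__":
    # Hypothetical single-process group purely for demonstration.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    eval_epoch_loss = torch.tensor(0.42)
    print(average_across_ranks(eval_epoch_loss))  # unchanged with a single rank
    dist.destroy_process_group()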
