[QEff. Finetune] Fixed reporting of single value of loss and ppl across devices. #496

Merged
9 commits merged on Jul 16, 2025
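The diff below covers the dataset-padding part of the fix; the loss/perplexity reduction itself is not shown in this excerpt. As background, here is a minimal sketch, under the assumption that per-device loss sums are reduced with torch.distributed, of how a single loss and ppl value can be reported across devices. The helper name reduce_loss_across_devices and its arguments are illustrative, not code from this PR.

# Illustrative sketch only, not code from this PR: sum the per-device loss
# statistics over all DDP ranks, then derive one mean loss and perplexity.
import math

import torch
import torch.distributed as dist


def reduce_loss_across_devices(loss_sum: torch.Tensor, token_count: torch.Tensor):
    """Return the mean loss and perplexity aggregated over all DDP ranks."""
    if dist.is_available() and dist.is_initialized():
        # all_reduce sums the local tensors in place on every rank.
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
    mean_loss = (loss_sum / token_count).item()
    return mean_loss, math.exp(mean_loss)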
17 changes: 13 additions & 4 deletions QEfficient/finetune/utils/dataset_utils.py
@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
import logging

import datasets
import torch
import torch.distributed as dist
@@ -66,6 +68,11 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):


def padding_dataset(train_config, dataset, batch_size):
num_replicas = get_num_ddp_devices()
remainder = len(dataset) % (num_replicas * batch_size)
if remainder == 0:
return dataset

if train_config.enable_ddp and train_config.enable_sorting_for_ddp:
if isinstance(dataset, datasets.Dataset):
# Hugging Face Dataset transformation
@@ -77,24 +84,26 @@ def padding_dataset(train_config, dataset, batch_size):

dummy_row = next(iter(dataset))
dummy_row["labels"] = torch.tensor([-100] * len(dummy_row["labels"]))
padding_size = 0
num_replicas = get_num_ddp_devices()
remainder = len(dataset) % (num_replicas * batch_size)
padding_size = (num_replicas * batch_size) - remainder

padding_size = (num_replicas * batch_size) - remainder
dummy_data = [dummy_row.copy() for _ in range(padding_size)]
dummy_dataset = datasets.Dataset.from_list(dummy_data)
if isinstance(dataset, datasets.Dataset):
combined_dataset = datasets.concatenate_datasets([dataset, dummy_dataset])
else:
combined_dataset = dataset + list(dummy_dataset)

logger.log_rank_zero("Padding dataset to make it divisible by batch_size * num_devices.", logging.DEBUG)
logger.log_rank_zero(f"Length of dataset before padding: {len(dataset)}", logging.DEBUG)
logger.log_rank_zero(f"Length of dataset after padding: {len(combined_dataset)}", logging.DEBUG)
return combined_dataset


def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
dataset = get_preprocessed_dataset(tokenizer, dataset_config, split, context_length=train_config.context_length)

batch_size = train_config.train_batch_size if split == "train" else train_config.val_batch_size

dataset = padding_dataset(train_config, dataset, batch_size)

dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)
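The early return and the single computation of num_replicas and remainder above amount to one rule: pad the dataset until its length is a multiple of batch_size * num_devices, so every rank processes the same number of full batches and the loss averaged across devices stays comparable. A standalone illustration of that arithmetic, with made-up numbers:

# Made-up numbers illustrating the padding arithmetic in padding_dataset:
# 2 devices x batch size 4 means 8 samples are consumed per global step.
num_replicas, batch_size, dataset_len = 2, 4, 103

remainder = dataset_len % (num_replicas * batch_size)  # 103 % 8 == 7
padding_size = 0 if remainder == 0 else (num_replicas * batch_size) - remainder  # 8 - 7 == 1

padded_len = dataset_len + padding_size  # 104
assert padded_len % (num_replicas * batch_size) == 0  # 13 full batches per rank

The dummy rows are given labels of -100, the default ignore index of PyTorch's cross-entropy loss, so the padding contributes nothing to the loss being reported.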
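get_dataloader now pads the dataset with the split-specific batch size before building the loader. A hypothetical call site follows; the tokenizer and the two config objects come from the surrounding fine-tuning setup and are not defined in this diff, and the name of the evaluation split is also an assumption.

# Hypothetical usage; the objects and the "val" split name are assumptions.
train_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="train")
val_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="val")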