
[QEff Finetune]: Enable --help for finetune CLI #392


Draft · wants to merge 6 commits into base: main
66 changes: 10 additions & 56 deletions QEfficient/cloud/finetune.py
@@ -9,7 +9,6 @@
import warnings
from typing import Any, Dict, Optional, Union

import fire
import numpy as np
import torch
import torch.distributed as dist
@@ -24,13 +23,10 @@
from QEfficient.finetune.utils.config_utils import (
generate_dataset_config,
generate_peft_config,
get_dataloader_kwargs,
update_config,
)
from QEfficient.finetune.utils.dataset_utils import (
get_custom_data_collator,
get_preprocessed_dataset,
)
from QEfficient.finetune.utils.dataset_utils import get_dataloader
from QEfficient.finetune.utils.parser import get_finetune_parser
from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
from QEfficient.utils._utils import login_and_download_hf_lm

@@ -180,11 +176,11 @@ def apply_peft(
kwargs: Additional arguments to override PEFT config params.

Returns:
Union[AutoModel, PeftModel]: If the use_peft in train_config is True
Union[AutoModel, PeftModel]: If the peft_method in train_config is set to "lora",
then a PeftModel object is returned; otherwise the original model object
(AutoModel) is returned.
"""
if not train_config.use_peft:
if train_config.peft_method != "lora":
return model

# Load the pre-trained peft model checkpoint and setup its configuration
@@ -226,58 +222,13 @@ def setup_dataloaders(
- Applies a custom data collator if provided by get_custom_data_collator.
- Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
"""
# Get the dataset utils
dataset_processer = tokenizer

# Load and preprocess the dataset for training and validation
dataset_train = get_preprocessed_dataset(
dataset_processer, dataset_config, split="train", context_length=train_config.context_length
)

dataset_val = get_preprocessed_dataset(
dataset_processer, dataset_config, split="test", context_length=train_config.context_length
)

# TODO: vbaddi, check if its necessary to do this?
# dataset_train = ConcatDataset(
# dataset_train, chunk_size=train_config.context_length
# )
##
train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
print("length of dataset_train", len(dataset_train))

# FIXME (Meet): Add custom data collator registration from the outside by the user.
custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
if custom_data_collator:
print("custom_data_collator is used")
train_dl_kwargs["collate_fn"] = custom_data_collator

# Create DataLoaders for the training and validation dataset
train_dataloader = torch.utils.data.DataLoader(
dataset_train,
num_workers=train_config.num_workers_dataloader,
pin_memory=True,
**train_dl_kwargs,
)
train_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="train")
print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")

eval_dataloader = None
if train_config.run_validation:
# if train_config.batching_strategy == "packing":
# dataset_val = ConcatDataset(
# dataset_val, chunk_size=train_config.context_length
# )

val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
if custom_data_collator:
val_dl_kwargs["collate_fn"] = custom_data_collator

eval_dataloader = torch.utils.data.DataLoader(
dataset_val,
num_workers=train_config.num_workers_dataloader,
pin_memory=True,
**val_dl_kwargs,
)
eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="test")
if len(eval_dataloader) == 0:
raise ValueError(
f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
@@ -354,4 +305,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:


if __name__ == "__main__":
fire.Fire(main)
parser = get_finetune_parser()
args = parser.parse_args()
args_dict = vars(args)
main(**args_dict)
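
The new entry point relies on get_finetune_parser from QEfficient/finetune/utils/parser.py, which is not part of this diff. Below is a minimal sketch of an argparse-based parser with the shape this entry point assumes, generating one flag per TrainConfig field so that --help lists every training option; the field handling is illustrative, not the module's actual implementation.

import argparse
from dataclasses import fields

from QEfficient.finetune.configs.training import TrainConfig


def _str2bool(value: str) -> bool:
    # Accept common spellings for boolean flags.
    return str(value).lower() in ("1", "true", "yes")


def get_finetune_parser() -> argparse.ArgumentParser:
    # Sketch only: build one CLI flag per TrainConfig dataclass field,
    # carrying the dataclass default so parse_args() mirrors TrainConfig.
    parser = argparse.ArgumentParser(description="QEfficient finetune CLI (illustrative sketch)")
    for field in fields(TrainConfig):
        arg_type = _str2bool if field.type is bool else field.type
        parser.add_argument(
            f"--{field.name}",
            type=arg_type if callable(arg_type) else str,
            default=field.default,
            help=f"default: {field.default!r}",
        )
    return parser

With a parser of this shape, python -m QEfficient.cloud.finetune --help prints one entry per training option, and vars(parser.parse_args()) feeds main(**args_dict) exactly as in the new __main__ block above.
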
7 changes: 0 additions & 7 deletions QEfficient/finetune/configs/peft_config.py
@@ -30,10 +30,3 @@ class LoraConfig:
task_type: str = "CAUSAL_LM"
lora_dropout: float = 0.05
inference_mode: bool = False # should be False for finetuning


# CAUTION prefix tuning is currently not supported
@dataclass
class PrefixConfig:
num_virtual_tokens: int = 30
task_type: str = "CAUSAL_LM"
25 changes: 7 additions & 18 deletions QEfficient/finetune/configs/training.py
@@ -4,6 +4,7 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from dataclasses import dataclass


@@ -16,7 +17,8 @@ class TrainConfig:
model_name (str): Name of the pre-trained model to fine-tune (default: "meta-llama/Llama-3.2-1B").
tokenizer_name (str): Name of the tokenizer (defaults to model_name if None).
run_validation (bool): Whether to run validation during training (default: True).
batch_size_training (int): Batch size for training (default: 1).
train_batch_size (int): Batch size for training (default: 1).
val_batch_size (int): Batch size for validation (default: 1).
context_length (Optional[int]): Maximum sequence length for inputs (default: None).
gradient_accumulation_steps (int): Steps for gradient accumulation (default: 4).
gradient_checkpointing (bool): Enable gradient checkpointing to save memory at the cost of speed (default: False).
@@ -29,17 +31,11 @@
weight_decay (float): Weight decay for optimizer (default: 0.0).
gamma (float): Learning rate decay factor (default: 0.85).
seed (int): Random seed for reproducibility (default: 42).
use_fp16 (bool): Use mixed precision training (default: True).
use_autocast (bool): Use autocast for mixed precision (default: True).
val_batch_size (int): Batch size for validation (default: 1).
dataset (str): Dataset name for training (default: "samsum_dataset").
task_type (str): Type of task to fine-tune for. Options: "generation" and "seq_classification" (default: "generation").
peft_method (str): Parameter-efficient fine-tuning method (default: "lora").
use_peft (bool): Whether to use PEFT (default: True).
from_peft_checkpoint (str): Path to PEFT checkpoint (default: "").
output_dir (str): Directory to save outputs (default: "meta-llama-samsum").
num_freeze_layers (int): Number of layers to freeze (default: 1).
one_qaic (bool): Use single QAIC device (default: False).
save_model (bool): Save the trained model (default: True).
save_metrics (bool): Save training metrics (default: True).
intermediate_step_save (int): Steps between intermediate saves (default: 1000).
@@ -50,15 +46,15 @@
use_profiler (bool): Enable profiling (default: False).
enable_ddp (bool): Enable distributed data parallel (default: False).
dist_backend (str): Backend for distributed training (default: "cpu:gloo,qaic:qccl,cuda:gloo").
grad_scaler (bool): Use gradient scaler (default: True).
dump_root_dir (str): Directory for mismatch dumps (default: "meta-llama-samsum-mismatches/step_").
opByOpVerifier (bool): Enable operation-by-operation verification (default: False).
"""

model_name: str = "meta-llama/Llama-3.2-1B"
tokenizer_name: str = None # if not passed as an argument, it uses the value of model_name
run_validation: bool = True
batch_size_training: int = 1
train_batch_size: int = 1
Contributor (review comment): If we are changing this param, maybe an internal announcement is required. The SIT team's testing commands might depend on this one.

val_batch_size: int = 1
context_length: int = None
gradient_accumulation_steps: int = 4
gradient_checkpointing: bool = False
Expand All @@ -71,17 +67,11 @@ class TrainConfig:
weight_decay: float = 0.0
gamma: float = 0.85 # multiplicatively decay the learning rate by gamma after each epoch
seed: int = 42
use_fp16: bool = True
use_autocast: bool = True
val_batch_size: int = 1
dataset = "samsum_dataset"
task_type = "generation" # "generation" / "seq_classification"
peft_method: str = "lora"
use_peft: bool = True # use parameter efficient fine tuning
from_peft_checkpoint: str = "" # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
from_peft_checkpoint: str = "" # if not empty and peft_method='lora', will load the peft checkpoint and resume the fine-tuning on that checkpoint
output_dir: str = "meta-llama-samsum"
num_freeze_layers: int = 1
one_qaic: bool = False
save_model: bool = True
save_metrics: bool = True # saves training metrics to a json file for later plotting
intermediate_step_save: int = 1000
@@ -102,6 +92,5 @@
enable_ddp: bool = False
dist_backend: str = "cpu:gloo,qaic:qccl,cuda:gloo"

grad_scaler: bool = True
dump_root_dir: str = "meta-llama-samsum-mismatches/step_"
dump_root_dir: str = "mismatches/step_"
opByOpVerifier: bool = False
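
Because batch_size_training is renamed to train_batch_size (see the reviewer's note above), any script that overrides it must switch to the new field name. A small illustrative snippet showing the programmatic equivalent of what the CLI now forwards into main(**args_dict):

from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.utils.config_utils import update_config

train_config = TrainConfig()
# update_config is the same helper finetune.py and eval.py use to apply **kwargs overrides.
update_config(train_config, train_batch_size=2, val_batch_size=2, peft_method="lora")
assert train_config.train_batch_size == 2  # the old batch_size_training field no longer exists
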
39 changes: 8 additions & 31 deletions QEfficient/finetune/eval.py
@@ -5,6 +5,7 @@
#
# -----------------------------------------------------------------------------

import os
import random
import warnings

@@ -13,15 +14,8 @@
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.config_utils import (
generate_dataset_config,
get_dataloader_kwargs,
update_config,
)
from utils.dataset_utils import (
get_custom_data_collator,
get_preprocessed_dataset,
)
from utils.config_utils import generate_dataset_config, update_config
from utils.dataset_utils import get_dataloader
from utils.train_utils import evaluation, print_model_size

from QEfficient.finetune.configs.training import TrainConfig
@@ -42,18 +36,20 @@ def main(**kwargs):
# update the configuration for the training process
train_config = TrainConfig()
update_config(train_config, **kwargs)
dataset_config = generate_dataset_config(train_config.dataset)
update_config(dataset_config, **kwargs)

# Set the seeds for reproducibility
torch.manual_seed(train_config.seed)
random.seed(train_config.seed)
np.random.seed(train_config.seed)

# Load the pre-trained model and setup its configuration
# config = AutoConfig.from_pretrained(train_config.model_name)
save_dir = "meta-llama-samsum/trained_weights/step_14000"
save_dir = os.path.join(train_config.output_dir, "complete_epoch_1")

# Load PEFT model on CPU
model_peft = AutoPeftModelForCausalLM.from_pretrained(save_dir)

# Merge LoRA and base model and save
merged_model = model_peft.merge_and_unload()
merged_model.save_pretrained(train_config.output_dir, safe_serialization=True)
@@ -82,32 +78,13 @@ def main(**kwargs):

print_model_size(model, train_config)

# Get the dataset utils
dataset_config = generate_dataset_config(train_config, kwargs)
dataset_processer = tokenizer

# Load and preprocess the dataset for training and validation
dataset_val = get_preprocessed_dataset(
dataset_processer, dataset_config, split="test", context_length=train_config.context_length
)

eval_dataloader = None
custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
if train_config.run_validation:
# TODO: vbaddi enable packing later in entire infra.
# if train_config.batching_strategy == "packing":
# dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)

val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
if custom_data_collator:
val_dl_kwargs["collate_fn"] = custom_data_collator
eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="test")

eval_dataloader = torch.utils.data.DataLoader(
dataset_val,
num_workers=train_config.num_workers_dataloader,
pin_memory=True,
**val_dl_kwargs,
)
print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
if len(eval_dataloader) == 0:
raise ValueError(
48 changes: 3 additions & 45 deletions QEfficient/finetune/utils/config_utils.py
@@ -4,26 +4,19 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import inspect
import json
import os
from dataclasses import asdict
from typing import Any, Dict

import torch.distributed as dist
import torch.utils.data as data_utils
import yaml
from peft import (
AdaptionPromptConfig,
PrefixTuningConfig,
)
from peft import LoraConfig as PeftLoraConfig
from transformers.data import DataCollatorForSeq2Seq

import QEfficient.finetune.configs.dataset_config as datasets
from QEfficient.finetune.configs.peft_config import LoraConfig, PrefixConfig
from QEfficient.finetune.configs.peft_config import LoraConfig
from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC


@@ -75,12 +68,7 @@ def generate_peft_config(train_config: TrainConfig, peft_config_file: str = None
validate_config(peft_config_data, config_type="lora")
peft_config = PeftLoraConfig(**peft_config_data)
else:
config_map = {
"lora": (LoraConfig, PeftLoraConfig),
"prefix": (PrefixConfig, PrefixTuningConfig),
"adaption_prompt": (None, AdaptionPromptConfig),
}

config_map = {"lora": (LoraConfig, PeftLoraConfig)}
if train_config.peft_method not in config_map:
raise RuntimeError(f"Peft config not found: {train_config.peft_method}")

@@ -115,36 +103,6 @@ def generate_dataset_config(dataset_name: str) -> Any:
return dataset_config


def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
kwargs = {}
batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
if train_config.enable_ddp:
if train_config.enable_sorting_for_ddp:
if train_config.context_length:
raise ValueError(
"Sorting cannot be done with padding, Please disable sorting or pass context_length as None to disable padding"
)
else:
kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
dataset,
batch_size=batch_size,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
shuffle=False,
)
else:
kwargs["sampler"] = data_utils.DistributedSampler(
dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
)
kwargs["batch_size"] = batch_size
kwargs["drop_last"] = True
else:
kwargs["batch_size"] = batch_size
kwargs["drop_last"] = True
kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
return kwargs


def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> None:
"""Validate the provided YAML/JSON configuration for required fields and types.

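
The batching logic deleted from get_dataloader_kwargs above is presumably what the new get_dataloader in QEfficient/finetune/utils/dataset_utils.py now wraps, together with the preprocessing and collator helpers that finetune.py used to call directly. Below is a rough sketch of such a consolidated helper under those assumptions (DDP sampler handling omitted; the real implementation may differ).

import torch
from transformers.data import DataCollatorForSeq2Seq

# The two helpers below are the ones whose imports were removed from finetune.py;
# they are assumed to still be defined in QEfficient.finetune.utils.dataset_utils.
from QEfficient.finetune.utils.dataset_utils import (
    get_custom_data_collator,
    get_preprocessed_dataset,
)
from QEfficient.finetune.configs.training import TrainConfig


def get_dataloader(tokenizer, dataset_config, train_config: TrainConfig, split: str = "train"):
    # Preprocess the requested split ("train" or "test") up front.
    dataset = get_preprocessed_dataset(
        tokenizer, dataset_config, split=split, context_length=train_config.context_length
    )

    # Pick the batch size that matches the renamed TrainConfig fields.
    batch_size = train_config.train_batch_size if split == "train" else train_config.val_batch_size
    dl_kwargs = {"batch_size": batch_size, "drop_last": True}

    # Default to the seq2seq collator the removed helper used (non-DDP path),
    # unless the dataset registers its own collator.
    dl_kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
    custom_collator = get_custom_data_collator(tokenizer, dataset_config)
    if custom_collator:
        dl_kwargs["collate_fn"] = custom_collator

    return torch.utils.data.DataLoader(
        dataset,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **dl_kwargs,
    )

Both call sites in this diff (finetune.py and eval.py) invoke it as get_dataloader(tokenizer, dataset_config, train_config, split=...), so only the split string differs between the training and evaluation loaders.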