Commit 5ac2508

quic-mamta (Mamta Singh) authored and Mamta Singh committed
Address comments
Signed-off-by: Mamta Singh <[email protected]>
1 parent 4753d9e commit 5ac2508

File tree: 5 files changed (+20 additions, -36 deletions)

QEfficient/cloud/finetune.py

Lines changed: 6 additions & 5 deletions

@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------
 
+import logging
 import random
 import warnings
 from typing import Any, Dict, Optional, Union
@@ -18,7 +19,7 @@
 import torch.utils.data
 from peft import PeftModel, get_peft_model
 from torch.optim.lr_scheduler import StepLR
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
 
 from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.utils.config_utils import (
@@ -41,8 +42,8 @@
 except ImportError as e:
     logger.warning(f"{e}. Moving ahead without these qaic modules.")
 
+logger.setLevel(logging.INFO)
 
-from transformers import AutoModelForSequenceClassification
 
 # Suppress all warnings
 warnings.filterwarnings("ignore")
@@ -245,7 +246,7 @@ def setup_dataloaders(
     # )
     ##
     train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
-    logger.info("length of dataset_train", len(dataset_train))
+    logger.info(f"length of dataset_train = {len(dataset_train)}")
 
     # FIXME (Meet): Add custom data collator registration from the outside by the user.
     custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
@@ -260,7 +261,7 @@ def setup_dataloaders(
         pin_memory=True,
         **train_dl_kwargs,
     )
-    logger.info(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
+    logger.info(f"Num of Training Set Batches loaded = {len(train_dataloader)}")
 
     eval_dataloader = None
     if train_config.run_validation:
@@ -284,7 +285,7 @@ def setup_dataloaders(
                 f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
             )
         else:
-            logger.info(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+            logger.info(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")
 
         longest_seq_length, _ = get_longest_seq_length(
            torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
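
The logging edits above fix a bug as well as tone: logger.info("length of dataset_train", len(dataset_train)) passes an extra positional argument without a %-style placeholder, so the record fails to format when it is emitted and the length never reaches the log. Below is a minimal, self-contained sketch of the pattern the diff moves to (a module logger set to INFO, values interpolated up front); the basicConfig call and the dummy dataset are assumptions added to make the sketch runnable, not part of the commit.

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig()  # assumption: the project configures handlers elsewhere

dataset_train = list(range(10))  # hypothetical stand-in for the real dataset

# Pre-diff style (broken): logging uses %-style args, so the extra positional
# argument has no placeholder and the record raises a formatting error at emit time.
# logger.info("length of dataset_train", len(dataset_train))

# Post-diff style: interpolate the value before the call.
logger.info(f"length of dataset_train = {len(dataset_train)}")

# Equivalent lazy alternative using logging's own %-formatting:
logger.info("length of dataset_train = %d", len(dataset_train))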

QEfficient/finetune/dataset/grammar_dataset.py

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ def __init__(self, tokenizer, csv_name=None, context_length=None):
             )
         except Exception as e:
             logger.error(
-                "Loading of grammar dataset failed! Please see [here](https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset."
+                "Loading of grammar dataset failed! Please check (https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset."
             )
             raise e
 
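For context, a hypothetical sketch of the try/except this message lives in; the load_dataset("csv", ...) call and the csv_name handling are assumptions for illustration, not the repository's exact loader.

import logging

from datasets import load_dataset  # Hugging Face `datasets`

logger = logging.getLogger(__name__)


def load_grammar_csv(csv_name: str):
    try:
        # Assumed layout: the grammar data is a local CSV produced by the linked notebook.
        return load_dataset("csv", data_files={"train": [csv_name]}, delimiter=",")
    except Exception as e:
        logger.error(
            "Loading of grammar dataset failed! Please check (https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset."
        )
        raise e  # surface the original failure after logging the pointer
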
QEfficient/finetune/eval.py

Lines changed: 2 additions & 2 deletions

@@ -109,13 +109,13 @@ def main(**kwargs):
         pin_memory=True,
         **val_dl_kwargs,
     )
-    logger.info(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+    logger.info(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")
     if len(eval_dataloader) == 0:
         raise ValueError(
             f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
         )
     else:
-        logger.info(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+        logger.info(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")
 
     model.to(device)
     _ = evaluation(model, train_config, eval_dataloader, None, tokenizer, device)

QEfficient/finetune/utils/config_utils.py

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ def update_config(config, **kwargs):
                     raise ValueError(f"Config '{config_name}' does not have parameter: '{param_name}'")
         else:
             config_type = type(config).__name__
-            logger.warning(f"Unknown parameter '{k}' for config type '{config_type}'")
+            logger.debug(f"Unknown parameter '{k}' for config type '{config_type}'")
 
 
 def generate_peft_config(train_config: TrainConfig, peft_config_file: str = None, **kwargs) -> Any:
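
The warning-to-debug switch only changes how loudly unknown keys are reported. A sketch of an update_config-style helper with the behavior visible in the diff: known fields are set, a dotted key naming a missing parameter on a matching config raises, and anything unrecognized is logged at DEBUG. The dataclass and the dotted-key matching rule are assumptions, not the project's implementation.

import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class TrainConfigSketch:  # hypothetical stand-in for the real TrainConfig
    num_epochs: int = 1
    lr: float = 1e-4


def update_config(config, **kwargs):
    # Set known fields, raise for a dotted key whose parameter is missing on a
    # matching config, and log anything else at DEBUG (the diff's new behavior).
    config_name = type(config).__name__.lower()
    for k, v in kwargs.items():
        if hasattr(config, k):
            setattr(config, k, v)
        elif "." in k:
            prefix, param_name = k.split(".", 1)
            if prefix.lower() == config_name:
                if hasattr(config, param_name):
                    setattr(config, param_name, v)
                else:
                    raise ValueError(f"Config '{config_name}' does not have parameter: '{param_name}'")
        else:
            config_type = type(config).__name__
            logger.debug(f"Unknown parameter '{k}' for config type '{config_type}'")


cfg = TrainConfigSketch()
update_config(cfg, num_epochs=3, unknown_flag=True)  # unknown_flag is only reported at DEBUG level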

QEfficient/finetune/utils/train_utils.py

Lines changed: 10 additions & 27 deletions

@@ -85,10 +85,7 @@ def train(
     device_type = device.split(":")[0]
 
     tensorboard_updates = None
-    if train_config.enable_ddp:
-        if local_rank == 0:
-            tensorboard_updates = SummaryWriter()
-    else:
+    if (not train_config.enable_ddp) or (train_config.enable_ddp and local_rank == 0):
         tensorboard_updates = SummaryWriter()
 
     if train_config.grad_scaler:
@@ -113,14 +110,9 @@ def train(
     # Start the training loop
     for epoch in range(train_config.num_epochs):
         if loss_0_counter.item() == train_config.convergence_counter:
-            if train_config.enable_ddp:
-                logger.info(
-                    f"Not proceeding with epoch {epoch + 1} on device {local_rank} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
-                )
-                break
-            else:
+            if (not train_config.enable_ddp) or (train_config.enable_ddp and local_rank == 0):
                 logger.info(
-                    f"Not proceeding with epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
+                    f"Skipping epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
                 )
                 break
 
@@ -161,7 +153,7 @@ def train(
             if epoch == intermediate_epoch and step == 0:
                 total_train_steps += intermediate_step
                 logger.info(
-                    f"skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for them."
+                    f"Skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for it."
                 )
             if epoch == intermediate_epoch and step < intermediate_step:
                 total_train_steps += 1
@@ -221,10 +213,7 @@ def train(
                 else:
                     loss_0_counter = torch.tensor([0]).to(device)
 
-            if train_config.enable_ddp:
-                if local_rank == 0:
-                    tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
-            else:
+            if (not train_config.enable_ddp) or (train_config.enable_ddp and local_rank == 0):
                 tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
 
             if train_config.save_metrics:
@@ -275,16 +264,10 @@ def train(
                     val_step_metric,
                     val_metric,
                 )
-            if train_config.enable_ddp:
-                if loss_0_counter.item() == train_config.convergence_counter:
-                    logger.info(
-                        f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps. Hence, stopping the fine tuning on device {local_rank}."
-                    )
-                    break
-            else:
+            if (not train_config.enable_ddp) or (train_config.enable_ddp and local_rank == 0):
                 if loss_0_counter.item() == train_config.convergence_counter:
                     logger.info(
-                        f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps. Hence, stopping the fine tuning."
+                        f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps.Hence,stopping the fine tuning."
                     )
                     break
 
@@ -457,7 +440,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     eval_metric = torch.exp(eval_epoch_loss)
 
     # Print evaluation metrics
-    logger.info(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+    logger.info(f"{eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
 
     return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric
 
@@ -487,9 +470,9 @@ def print_model_size(model, config) -> None:
         model_name (str): Name of the model.
     """
 
-    logger.info(f"--> Model {config.model_name}")
+    logger.info(f"Model : {config.model_name}")
     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    logger.info(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
+    logger.info(f"{config.model_name} has {total_params / 1e6} Million params\n")
 
 
 def save_to_json(
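
The recurring change in this file collapses the DDP and non-DDP branches into a single guard. A hedged sketch of that condition as a helper (the diff inlines the expression; the helper name is illustrative). Note that (not enable_ddp) or (enable_ddp and local_rank == 0) simplifies to (not enable_ddp) or (local_rank == 0).

def should_emit(enable_ddp: bool, local_rank: int) -> bool:
    """True for single-process runs, and for rank 0 when DDP is enabled.

    Equivalent to the diff's inline guard:
    (not enable_ddp) or (enable_ddp and local_rank == 0)
    """
    return (not enable_ddp) or (local_rank == 0)


# Usage pattern mirroring the diff (SummaryWriter / add_scalars calls shown for shape only):
# if should_emit(train_config.enable_ddp, local_rank):
#     tensorboard_updates = SummaryWriter()
# ...
# if should_emit(train_config.enable_ddp, local_rank):
#     tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)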
