Skip to content

Commit b6d0ce9

Browse files
committed
Addressed few offline comments.
Signed-off-by: Meet Patel <[email protected]>
1 parent 2713fdc commit b6d0ce9

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

QEfficient/cloud/finetune.py

+4
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ def main(**kwargs):
8686
attn_implementation="sdpa",
8787
torch_dtype=torch.float16,
8888
)
89+
90+
if not hasattr(model, "base_model_prefix"):
91+
raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
92+
8993
for param in getattr(model, model.base_model_prefix).parameters():
9094
param.requires_grad = False
9195

QEfficient/finetune/utils/train_utils.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,8 @@ def train(
306306
train_epoch_loss = total_loss / len(train_dataloader)
307307

308308
if train_config.task_type == "seq_classification":
309-
accuracy = acc_helper.compute()
309+
metric_val = acc_helper.compute()
310310
acc_helper.reset()
311-
if train_config.enable_ddp:
312-
dist.all_reduce(accuracy, op=dist.ReduceOp.SUM)
313-
accuracy /= dist.get_world_size()
314-
metric_val = accuracy
315311
else:
316312
metric_val = torch.exp(train_epoch_loss)
317313

0 commit comments

Comments
 (0)