File tree 2 files changed +5
-5
lines changed
2 files changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -86,6 +86,10 @@ def main(**kwargs):
86
86
attn_implementation = "sdpa" ,
87
87
torch_dtype = torch .float16 ,
88
88
)
89
+
90
+ if not hasattr (model , "base_model_prefix" ):
91
+ raise RuntimeError ("Given huggingface model does not have 'base_model_prefix' attribute." )
92
+
89
93
for param in getattr (model , model .base_model_prefix ).parameters ():
90
94
param .requires_grad = False
91
95
Original file line number Diff line number Diff line change @@ -306,12 +306,8 @@ def train(
306
306
train_epoch_loss = total_loss / len (train_dataloader )
307
307
308
308
if train_config .task_type == "seq_classification" :
309
- accuracy = acc_helper .compute ()
309
+ metric_val = acc_helper .compute ()
310
310
acc_helper .reset ()
311
- if train_config .enable_ddp :
312
- dist .all_reduce (accuracy , op = dist .ReduceOp .SUM )
313
- accuracy /= dist .get_world_size ()
314
- metric_val = accuracy
315
311
else :
316
312
metric_val = torch .exp (train_epoch_loss )
317
313
You can’t perform that action at this time.
0 commit comments