@@ -83,6 +83,7 @@ def train(
     best_val_loss = float("inf")
     total_train_steps = 0
     max_steps_reached = False  # Flag to indicate max training steps reached
+    device_type = device.split(":")[0]
 
     tensorboard_updates = None
     if train_config.enable_ddp:
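The `device_type` added above strips the optional device index from a torch-style device string (e.g. `"cuda:0"` becomes `"cuda"`, `"qaic:1"` becomes `"qaic"`). A minimal sketch of that parsing, assuming device strings follow the `<type>[:<index>]` convention:

```python
# Sketch: derive the bare device type from a "<type>[:<index>]" device string.
for device in ["cuda:0", "cuda", "qaic:1", "cpu"]:
    device_type = device.split(":")[0]
    print(device, "->", device_type)  # "cuda:0" -> "cuda", "qaic:1" -> "qaic", ...
```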
@@ -95,7 +96,7 @@ def train(
         if device.startswith("qaic"):
             scaler = QAicGradScaler()
         else:
-            scaler = GradScaler()
+            scaler = GradScaler(device_type)
 
     loss_0_counter = torch.tensor([0]).to(device)
 
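Passing `device_type` into `GradScaler` matches the device-aware `torch.amp.GradScaler(device)` API (PyTorch 2.3+), which supersedes the CUDA-only `torch.cuda.amp.GradScaler()`. A hedged sketch of the intended usage, assuming `GradScaler` here is imported from `torch.amp`:

```python
# Sketch, assuming GradScaler is torch.amp.GradScaler (PyTorch >= 2.3);
# the legacy torch.cuda.amp.GradScaler takes no device argument.
import torch
from torch.amp import GradScaler, autocast

device_type = "cuda" if torch.cuda.is_available() else "cpu"
scaler = GradScaler(device_type)  # first positional argument is the device type

# Typical mixed-precision step with the scaler:
# with autocast(device_type):
#     loss = model(**batch).loss
# scaler.scale(loss).backward()
# scaler.step(optimizer)
# scaler.update()
```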
@@ -177,10 +178,7 @@ def train(
                     # adjust atol & rtol as required
                     atol=1e-1,
                     use_ref_output_on_mismatch=True,
-                    # report all mismatches
-                    max_failures=None,
-                    # generate unittest for each op once
-                    repeat_same_op=True,
+                    filter_config=qaic_debug.DispatchFilterConfig.default(device),
                     dump_root_dir=train_config.dump_root_dir + str(step),
                 ) as verifier:
                     loss = model(**batch).loss  # Forward call
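This hunk folds the individual verifier flags into a single `filter_config` object. `qaic_debug.DispatchFilterConfig` belongs to the Qualcomm QAIC debug tooling, and its full API is not shown in this diff; the snippet below is only a hypothetical illustration of the consolidation pattern, with invented names:

```python
# Hypothetical illustration only -- these names are invented, not the real
# qaic_debug.DispatchFilterConfig API. It shows the pattern of bundling
# individual verifier flags into one config object with a default() factory.
from dataclasses import dataclass
from typing import Optional

@dataclass
class FilterConfig:
    max_failures: Optional[int] = None  # None -> report all mismatches
    repeat_same_op: bool = True         # generate a unittest per op only once

    @classmethod
    def default(cls, device: str) -> "FilterConfig":
        # Device-specific defaults could be selected here.
        return cls()

filter_config = FilterConfig.default("qaic:0")
```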
@@ -296,8 +294,6 @@ def train(
                 eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(
                     model, train_config, eval_dataloader, local_rank, tokenizer, device
                 )
-                dist.barrier()
-                dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM)
                 if local_rank == 0:
                     tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
 
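With the `dist.barrier()` and `dist.all_reduce` removed, the `eval_epoch_loss` logged on rank 0 is that rank's local value unless `evaluation` itself aggregates across ranks. For reference, a guarded version of the removed reduction, using standard `torch.distributed` calls and active only when a process group is initialized:

```python
# Sketch of the removed cross-rank reduction (standard torch.distributed APIs).
import torch
import torch.distributed as dist

def reduce_eval_loss(eval_epoch_loss: torch.Tensor) -> torch.Tensor:
    if dist.is_available() and dist.is_initialized():
        dist.barrier()  # synchronize ranks before reducing
        dist.all_reduce(eval_epoch_loss, op=dist.ReduceOp.SUM)  # sum over ranks
    return eval_epoch_loss
```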