import json
import os
import time
- from contextlib import nullcontext
from datetime import datetime
+ from functools import partial
from typing import Dict, List, Tuple

import torch
from tqdm import tqdm

from QEfficient.finetune.configs.training import TrainConfig
+ from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx

try:
    import torch_qaic  # noqa: F401
@@ -110,6 +111,9 @@ def train(
        num_classes = model.classifier.out_features
        acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)

+     autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
+     op_verifier_ctx = partial(get_op_verifier_ctx, train_config.opByOpVerifier, device, train_config.dump_root_dir)
+
    # Start the training loop
    for epoch in range(train_config.num_epochs):
        if loss_0_counter.item() == train_config.convergence_counter:
@@ -174,60 +178,38 @@ def train(
                break
            batch = {k: v.to(device) for k, v in batch.items()}  # move the batch elements to qaic device

-             with (
-                 torch.autocast(device_type=device_type, dtype=torch.float16)
-                 if train_config.use_autocast
-                 else nullcontext()
-             ):
-                 # an additional condition can be put here to avoid opByOpVerifier getting triggered for each step
-                 if train_config.opByOpVerifier:
-                     with qaic_debug.OpByOpVerifierMode(
-                         ref_device="cpu",
-                         ref_dtype=torch.float32,
-                         # adjust atol & rtol this as required
-                         atol=1e-1,
-                         use_ref_output_on_mismatch=True,
-                         filter_config=qaic_debug.DispatchFilterConfig.default(device),
-                         dump_root_dir=train_config.dump_root_dir + str(step),
-                     ) as verifier:
-                         model_outputs = model(**batch)
-                         loss = model_outputs.loss  # Forward call
-                         if (batch["labels"] != -100).sum() == 0:
-                             loss = loss.nan_to_num(nan=0.0)
-                             num_dummy_samples += train_config.train_batch_size
-                         else:
-                             num_dummy_samples_per_batch = (
-                                 (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
-                             )
-                             if num_dummy_samples_per_batch > 0:
-                                 num_dummy_samples += num_dummy_samples_per_batch
-                                 loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
-
-                         if train_config.task_type == "seq_classification":
-                             logits = model_outputs.logits
-                             labels = batch["labels"][:, 0]
-                             preds = torch.nn.functional.softmax(logits, dim=-1)
-                             acc_helper.forward(preds, labels)
-                     print("Mismatches detected:", verifier.get_perop_mismatch_count())
+             is_optimizer_step = (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(
+                 train_dataloader
+             ) - 1
+             if train_config.enable_ddp:
+                 # Below block derived from: https://github.com/karpathy/nanoGPT/blob/93a43d9a5c22450bbf06e78da2cb6eeef084b717/train.py#L293
+                 # in DDP training we only need to sync gradients at the last micro step.
+                 # the official way to do this is with model.no_sync() context manager, but
+                 # using too many context managers may bloat the code and forces us to repeat code
+                 # looking at the source of that context manager, it just toggles this variable
+                 model.require_backward_grad_sync = is_optimizer_step
+
+             with autocast_ctx, op_verifier_ctx(step) as verifier:
+                 model_outputs = model(**batch)
+                 loss = model_outputs.loss  # Forward call
+                 if (batch["labels"] != -100).sum() == 0:
+                     loss = loss.nan_to_num(nan=0.0)
+                     num_dummy_samples += train_config.train_batch_size
                else:
-                     model_outputs = model(**batch)
-                     loss = model_outputs.loss  # Forward call
-                     if (batch["labels"] != -100).sum() == 0:
-                         loss = loss.nan_to_num(nan=0.0)
-                         num_dummy_samples += train_config.train_batch_size
-                     else:
-                         num_dummy_samples_per_batch = (
-                             (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
-                         )
-                         if num_dummy_samples_per_batch > 0:
-                             num_dummy_samples += num_dummy_samples_per_batch
-                             loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
+                     num_dummy_samples_per_batch = (
+                         (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
+                     )
+                     if num_dummy_samples_per_batch > 0:
+                         num_dummy_samples += num_dummy_samples_per_batch
+                         loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch

-                     if train_config.task_type == "seq_classification":
-                         logits = model_outputs.logits
-                         labels = batch["labels"][:, 0]
-                         preds = torch.nn.functional.softmax(logits, dim=-1)
-                         acc_helper.forward(preds, labels)
+                 if train_config.task_type == "seq_classification":
+                     logits = model_outputs.logits
+                     labels = batch["labels"][:, 0]
+                     preds = torch.nn.functional.softmax(logits, dim=-1)
+                     acc_helper.forward(preds, labels)
+                 if train_config.opByOpVerifier:
+                     print("Mismatches detected:", verifier.get_perop_mismatch_count())

            total_loss += loss.detach().float()
@@ -274,7 +256,7 @@ def train(
            else:
                loss.backward()  # backward pass

-             if (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+             if is_optimizer_step:
                if train_config.grad_scaler:
                    scaler.step(optimizer)
                    scaler.update()
@@ -468,6 +450,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
    device_type = torch.device(device).type

    num_dummy_samples = 0
+     autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
    for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
        # stop when the maximum number of eval steps is reached
        if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
@@ -478,11 +461,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
        # Ensure no gradients are computed for this scope to save memory
        with torch.no_grad():
            # Forward pass and compute loss
-             with (
-                 torch.autocast(device_type=device_type, dtype=torch.float16)
-                 if train_config.use_autocast
-                 else nullcontext()
-             ):
+             with autocast_ctx:
                outputs = model(**batch)
                loss = outputs.loss
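For context, the two helpers imported from QEfficient.finetune.utils.helper are not shown in this diff. Below is a minimal sketch of what they plausibly look like, reconstructed from the inline context managers the diff removes; the function bodies and the qaic_debug import path are assumptions, not the repository's actual implementation.

from contextlib import nullcontext

import torch


def get_autocast_ctx(use_autocast, device_type, dtype=torch.float16):
    # torch.autocast when mixed precision is enabled, otherwise a no-op context.
    return torch.autocast(device_type=device_type, dtype=dtype) if use_autocast else nullcontext()


def get_op_verifier_ctx(use_op_by_op_verifier, device, dump_root_dir, step):
    # Op-by-op verification against a CPU float32 reference for the given step,
    # or a no-op context when the verifier is disabled.
    if not use_op_by_op_verifier:
        return nullcontext()
    import torch_qaic.debug as qaic_debug  # assumed import path for the QAIC debug utilities

    return qaic_debug.OpByOpVerifierMode(
        ref_device="cpu",
        ref_dtype=torch.float32,
        atol=1e-1,
        use_ref_output_on_mismatch=True,
        filter_config=qaic_debug.DispatchFilterConfig.default(device),
        dump_root_dir=dump_root_dir + str(step),
    )


# Usage mirrors the diff: functools.partial binds everything except the step once,
#   op_verifier_ctx = partial(get_op_verifier_ctx, train_config.opByOpVerifier, device, train_config.dump_root_dir)
# and the training loop then enters both contexts per iteration:
#   with autocast_ctx, op_verifier_ctx(step) as verifier:
#       ...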