3434from libs .nnet3 .train .dropout_schedule import _get_dropout_proportions
3535from model import get_chain_model
3636from options import get_args
37+ from sgd_max_change import SgdMaxChange
3738
3839def get_objf (batch , model , device , criterion , opts , den_graph , training , optimizer = None , dropout = 0. ):
3940 feature , supervision = batch
@@ -67,20 +68,20 @@ def get_objf(batch, model, device, criterion, opts, den_graph, training, optimiz
6768 supervision , nnet_output ,
6869 xent_output )
6970 objf = objf_l2_term_weight [0 ]
71+ change = 0
7072 if training :
7173 optimizer .zero_grad ()
7274 objf .backward ()
73- clip_grad_value_ (model .parameters (), 5.0 )
74- optimizer .step ()
75+ # clip_grad_value_(model.parameters(), 5.0)
76+ _ , change = optimizer .step ()
7577
7678 objf_l2_term_weight = objf_l2_term_weight .detach ().cpu ()
7779
7880 total_objf = objf_l2_term_weight [0 ].item ()
7981 total_weight = objf_l2_term_weight [2 ].item ()
8082 total_frames = nnet_output .shape [0 ]
8183
82- return total_objf , total_weight , total_frames
83-
84+ return total_objf , total_weight , total_frames , change
8485
8586def get_validation_objf (dataloader , model , device , criterion , opts , den_graph ):
8687 total_objf = 0.
@@ -90,7 +91,7 @@ def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
9091 model .eval ()
9192
9293 for batch_idx , (pseudo_epoch , batch ) in enumerate (dataloader ):
93- objf , weight , frames = get_objf (
94+ objf , weight , frames , _ = get_objf (
9495 batch , model , device , criterion , opts , den_graph , False )
9596 total_objf += objf
9697 total_weight += weight
@@ -116,7 +117,7 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
116117 len (dataloader )) / (len (dataloader ) * num_epochs )
117118 _ , dropout = _get_dropout_proportions (
118119 dropout_schedule , data_fraction )[0 ]
119- curr_batch_objf , curr_batch_weight , curr_batch_frames = get_objf (
120+ curr_batch_objf , curr_batch_weight , curr_batch_frames , curr_batch_change = get_objf (
120121 batch , model , device , criterion , opts , den_graph , True , optimizer , dropout = dropout )
121122
122123 total_objf += curr_batch_objf
@@ -127,13 +128,13 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
127128 logging .info (
128129 'Device ({}) processing batch {}, current pseudo-epoch is {}/{}({:.6f}%), '
129130 'global average objf: {:.6f} over {} '
130- 'frames, current batch average objf: {:.6f} over {} frames, epoch {}'
131+ 'frames, current batch average objf: {:.6f} over {} frames, minibatch change: {:.6f}, epoch {}'
131132 .format (
132133 device .index , batch_idx , pseudo_epoch , len (dataloader ),
133134 float (pseudo_epoch ) / len (dataloader ) * 100 ,
134135 total_objf / total_weight , total_frames ,
135136 curr_batch_objf / curr_batch_weight ,
136- curr_batch_frames , current_epoch ))
137+ curr_batch_frames , curr_batch_change , current_epoch ))
137138
138139 if valid_dataloader and batch_idx % 1000 == 0 :
139140 total_valid_objf , total_valid_weight , total_valid_frames = get_validation_objf (
@@ -167,6 +168,11 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
167168 dropout ,
168169 pseudo_epoch + current_epoch * len (dataloader ))
169170
171+ tf_writer .add_scalar (
172+ 'train/current_batch_change' ,
173+ curr_batch_change ,
174+ pseudo_epoch + current_epoch * len (dataloader ))
175+
170176 state_dict = model .state_dict ()
171177 for key , value in state_dict .items ():
172178 # skip batchnorm parameters
@@ -301,7 +307,7 @@ def process_job(learning_rate, device_id=None, local_rank=None):
301307 else :
302308 valid_dataloader = None
303309
304- optimizer = optim . Adam (model .parameters (),
310+ optimizer = SgdMaxChange (model .parameters (),
305311 lr = learning_rate ,
306312 weight_decay = 5e-4 )
307313
0 commit comments