@@ -412,41 +412,69 @@ def _initialize_networks(self):
         else:
             self.task_ = check_network(self.task,
                                        copy=self.copy,
+                                       force_copy=True,
                                        name="task")
+
+
+    def _initialize_weights(self, shape_X):
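+        # Build the task network and this wrapper model once the input shape is known.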
+        if hasattr(self, "task_"):
+            self.task_.build((None,) + shape_X)
+        self.build((None,) + shape_X)
         self._add_regularization()
 
 
-    def _get_regularizer(self, old_weight, weight, lambda_=1.):
+    def _get_regularizer(self, old_weight, weight, lambda_):
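+        # Return the weighted distance between the frozen source weight and the current weight.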
         if self.regularizer == "l2":
-            def regularizer():
-                return lambda_ * tf.reduce_mean(tf.square(old_weight - weight))
+            return lambda_ * tf.reduce_mean(tf.square(old_weight - weight))
         if self.regularizer == "l1":
-            def regularizer():
-                return lambda_ * tf.reduce_mean(tf.abs(old_weight - weight))
+            return lambda_ * tf.reduce_mean(tf.abs(old_weight - weight))
-        return regularizer
 
 
+    def train_step(self, data):
+        # Unpack the data.
+        Xs, Xt, ys, yt = self._unpack_data(data)
+
+        # Run forward pass.
+        with tf.GradientTape() as tape:
+            y_pred = self.task_(Xt, training=True)
+            if hasattr(self, "_compile_loss") and self._compile_loss is not None:
+                loss = self._compile_loss(yt, y_pred)
+            else:
+                loss = self.compiled_loss(yt, y_pred)
+
+            loss = tf.reduce_mean(loss)
+            loss += sum(self.losses)
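+            # Add the regularization pulling each weight toward its frozen source value.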
+            reg_loss = 0.
+            for i in range(len(self.task_.trainable_variables)):
+                reg_loss += self._get_regularizer(self.old_weights_[i],
+                                                  self.task_.trainable_variables[i],
+                                                  self.lambdas_[i])
+            loss += reg_loss
+
+        # Run backwards pass.
+        gradients = tape.gradient(loss, self.task_.trainable_variables)
+        self.optimizer.apply_gradients(zip(gradients, self.task_.trainable_variables))
+        return self._update_logs(yt, y_pred)
+
+
     def _add_regularization(self):
-        i = 0
+        self.old_weights_ = []
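+        # Broadcast the trade-off parameter(s) into one lambda per weight tensor.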
         if not hasattr(self.lambdas, "__iter__"):
-            lambdas = [self.lambdas]
+            self.lambdas_ = [self.lambdas] * len(self.task_.weights)
         else:
-            lambdas = self.lambdas
+            self.lambdas_ = (self.lambdas +
+                             [self.lambdas[-1]] * (len(self.task_.weights) - len(self.lambdas)))
+            self.lambdas_ = self.lambdas_[::-1]
 
-        for layer in reversed(self.task_.layers):
-            if (hasattr(layer, "weights") and
-                    layer.weights is not None and
-                    len(layer.weights) != 0):
-                if i >= len(lambdas):
-                    lambda_ = lambdas[-1]
-                else:
-                    lambda_ = lambdas[i]
-                for weight in reversed(layer.weights):
-                    old_weight = tf.identity(weight)
-                    old_weight.trainable = False
-                    self.add_loss(self._get_regularizer(
-                        old_weight, weight, lambda_))
-                i += 1
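+        # Snapshot the source weights; train_step regularizes toward these frozen copies.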
+        for weight in self.task_.trainable_variables:
+            old_weight = tf.identity(weight)
+            old_weight.trainable = False
+            self.old_weights_.append(old_weight)
 
 
 
     def call(self, inputs):
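
For context, a minimal standalone sketch of the fine-tuning scheme the new train_step implements: snapshot the source network's weights, then penalize the distance between the current and source weights while training on target data. The model architecture, data shapes, optimizer, and lambda_ value below are illustrative assumptions, not part of this commit.

    import tensorflow as tf

    # Stand-in for self.task_: any built Keras model works here.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
        tf.keras.layers.Dense(1),
    ])

    # Freeze a copy of the source weights, as _add_regularization does.
    old_weights = [tf.identity(w) for w in model.trainable_variables]
    lambda_ = 0.1  # illustrative trade-off parameter

    optimizer = tf.keras.optimizers.Adam()
    loss_fn = tf.keras.losses.MeanSquaredError()

    @tf.function
    def train_step(Xt, yt):
        with tf.GradientTape() as tape:
            y_pred = model(Xt, training=True)
            loss = loss_fn(yt, y_pred)
            # L2 penalty pulling each weight toward its source value.
            for old_w, w in zip(old_weights, model.trainable_variables):
                loss += lambda_ * tf.reduce_mean(tf.square(old_w - w))
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss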