@@ -588,3 +588,66 @@ def _multi_conv(self, x, **kwargs):
         """Applies the convolutions with different sizes and concatenates outputs."""
 
         return tf.concat([conv(x, **kwargs) for conv in self.convs], axis=-1)
+
+
+class ConfigurableMLP(tf.keras.Model):
+    """Implements a simple configurable MLP with optional residual connections and dropout."""
+
+    def __init__(
+        self, input_dim, hidden_dim=512, num_hidden=2, activation="relu", residual=True, dropout_rate=0.05, **kwargs
+    ):
+        """
+        Creates an instance of a flexible and simple MLP with optional residual connections and dropout.
+
+        Parameters:
+        -----------
+        input_dim    : int
+            The input dimensionality
+        hidden_dim   : int, optional, default: 512
+            The dimensionality of the hidden layers
+        num_hidden   : int, optional, default: 2
+            The number of hidden layers (minimum: 1)
+        activation   : string, optional, default: 'relu'
+            The activation function of the dense layers
+        residual     : bool, optional, default: True
+            Use residual connections in the MLP
+        dropout_rate : float, optional, default: 0.05
+            Dropout rate for the hidden layers in the MLP
+        """
+
+        super().__init__(**kwargs)
+
+        self.input_dim = input_dim
+        self.model = tf.keras.Sequential(
+            [tf.keras.layers.Dense(hidden_dim, activation=activation), tf.keras.layers.Dropout(dropout_rate)]
+        )
+        for _ in range(num_hidden):
+            self.model.add(
+                ConfigurableHiddenBlock(
+                    hidden_dim,
+                    activation=activation,
+                    residual=residual,
+                    dropout_rate=dropout_rate,
+                )
+            )
+        self.model.add(tf.keras.layers.Dense(input_dim))
+
+    def call(self, inputs, **kwargs):
+        return self.model(inputs, **kwargs)
+
+
+class ConfigurableHiddenBlock(tf.keras.Model):
+    def __init__(self, num_units, activation="relu", residual=True, dropout_rate=0.0):
+        super().__init__()
+
+        self.act_fn = tf.keras.activations.get(activation)
+        self.residual = residual
+        self.dense_with_dropout = tf.keras.Sequential(
+            [tf.keras.layers.Dense(num_units, activation=None), tf.keras.layers.Dropout(dropout_rate)]
+        )
+
+    def call(self, inputs, **kwargs):
+        x = self.dense_with_dropout(inputs, **kwargs)
+        if self.residual:
+            x = x + inputs
+        return self.act_fn(x)
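
For reference, here is a minimal usage sketch of the two classes this commit adds. It is not part of the diff itself: the batch size, dimensions, and hyperparameter values are illustrative assumptions, and it presumes ConfigurableMLP is importable from the patched module.

import tensorflow as tf

# Hypothetical instantiation of the new class; all values are arbitrary.
mlp = ConfigurableMLP(input_dim=8, hidden_dim=64, num_hidden=3, dropout_rate=0.1)

x = tf.random.normal((32, 8))      # batch of 32 input vectors
y_train = mlp(x, training=True)    # Keras routes `training` through **kwargs to the Dropout layers
y_eval = mlp(x, training=False)    # dropout disabled at inference
assert y_eval.shape == (32, 8)     # the final Dense layer projects back to input_dim

Note the design choice in ConfigurableHiddenBlock: the skip connection is added before the activation, i.e. act_fn(dense_dropout(x) + x), so every block must keep its width fixed at hidden_dim. The first Dense layer projects the input up to that width, and the final Dense layer projects back down to input_dim.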