@@ -594,7 +594,15 @@ class ConfigurableMLP(tf.keras.Model):
594594 """Implements a simple configurable MLP with optional residual connections and dropout."""
595595
596596 def __init__ (
597- self , input_dim , hidden_dim = 512 , num_hidden = 2 , activation = "relu" , residual = True , dropout_rate = 0.05 , ** kwargs
597+ self ,
598+ input_dim ,
599+ hidden_dim = 512 ,
600+ output_dim = None ,
601+ num_hidden = 2 ,
602+ activation = "relu" ,
603+ residual = True ,
604+ dropout_rate = 0.05 ,
605+ ** kwargs ,
598606 ):
599607 """
600608 Creates an instance of a flexible and simple MLP with optional residual connections and dropout.
@@ -605,6 +613,8 @@ def __init__(
             The input dimensionality
         hidden_dim : int, optional, default: 512
             The dimensionality of the hidden layers
+        output_dim : int, optional, default: None
+            The output dimensionality. If None is passed, `output_dim` is set to `input_dim`
         num_hidden : int, optional, default: 2
             The number of hidden layers (minimum: 1)
         activation : string, optional, default: 'relu'
@@ -618,6 +628,7 @@ def __init__(
         super().__init__(**kwargs)

         self.input_dim = input_dim
+        self.output_dim = input_dim if output_dim is None else output_dim
         self.model = tf.keras.Sequential(
             [tf.keras.layers.Dense(hidden_dim, activation=activation), tf.keras.layers.Dropout(dropout_rate)]
         )
@@ -630,7 +641,7 @@ def __init__(
                     dropout_rate=dropout_rate,
                 )
             )
-        self.model.add(tf.keras.layers.Dense(input_dim))
+        self.model.add(tf.keras.layers.Dense(self.output_dim))

     def call(self, inputs, **kwargs):
         return self.model(inputs, **kwargs)
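
For reference, a minimal usage sketch of the new `output_dim` parameter (dimensions are hypothetical; assumes TensorFlow 2.x and that `ConfigurableMLP` is in scope):

import tensorflow as tf

# With output_dim exposed, the MLP can map between different
# dimensionalities, e.g. 16 -> 8.
mlp = ConfigurableMLP(input_dim=16, output_dim=8)
x = tf.random.normal((32, 16))  # batch of 32 inputs
y = mlp(x)                      # shape: (32, 8)

# Omitting output_dim keeps the previous behavior (output_dim == input_dim):
square_mlp = ConfigurableMLP(input_dim=16)
assert square_mlp(x).shape == (32, 16)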