@@ -594,7 +594,15 @@ class ConfigurableMLP(tf.keras.Model):
     """Implements a simple configurable MLP with optional residual connections and dropout."""
 
     def __init__(
-        self, input_dim, hidden_dim=512, num_hidden=2, activation="relu", residual=True, dropout_rate=0.05, **kwargs
+        self,
+        input_dim,
+        hidden_dim=512,
+        output_dim=None,
+        num_hidden=2,
+        activation="relu",
+        residual=True,
+        dropout_rate=0.05,
+        **kwargs,
     ):
         """
         Creates an instance of a flexible and simple MLP with optional residual connections and dropout.
@@ -605,6 +613,8 @@ def __init__(
             The input dimensionality
         hidden_dim : int, optional, default: 512
             The dimensionality of the hidden layers
+        output_dim : int, optional, default: None
+            The output dimensionality. If None is passed, `output_dim` is set to `input_dim`
         num_hidden : int, optional, default: 2
             The number of hidden layers (minimum: 1)
         activation : string, optional, default: 'relu'
@@ -618,6 +628,7 @@ def __init__(
         super().__init__(**kwargs)
 
         self.input_dim = input_dim
+        self.output_dim = input_dim if output_dim is None else output_dim
         self.model = tf.keras.Sequential(
             [tf.keras.layers.Dense(hidden_dim, activation=activation), tf.keras.layers.Dropout(dropout_rate)]
         )
@@ -630,7 +641,7 @@ def __init__(
                 dropout_rate=dropout_rate,
             )
         )
-        self.model.add(tf.keras.layers.Dense(input_dim))
+        self.model.add(tf.keras.layers.Dense(self.output_dim))
 
     def call(self, inputs, **kwargs):
         return self.model(inputs, **kwargs)
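To illustrate the effect of the change: before this patch, the final Dense layer was hard-coded to `input_dim`, so the MLP could only map inputs back to their own dimensionality; now `output_dim` controls the final projection, defaulting to `input_dim` for backward compatibility. The following is a minimal, simplified sketch of that defaulting pattern, using plain Dense layers in place of the residual hidden blocks that are not shown in this diff; `build_mlp` is a hypothetical stand-in, not part of the codebase.

```python
import tensorflow as tf


def build_mlp(input_dim, hidden_dim=512, output_dim=None, num_hidden=2, dropout_rate=0.05):
    # Same defaulting rule as the patch: fall back to input_dim when no
    # explicit output dimensionality is requested.
    output_dim = input_dim if output_dim is None else output_dim
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(hidden_dim, activation="relu"), tf.keras.layers.Dropout(dropout_rate)]
    )
    # Plain Dense layers stand in for the configurable hidden blocks here.
    for _ in range(num_hidden):
        model.add(tf.keras.layers.Dense(hidden_dim, activation="relu"))
        model.add(tf.keras.layers.Dropout(dropout_rate))
    # The final projection now uses output_dim instead of a hard-coded input_dim.
    model.add(tf.keras.layers.Dense(output_dim))
    return model


x = tf.random.normal((8, 16))
print(build_mlp(16)(x).shape)                # (8, 16) -- default preserves old behavior
print(build_mlp(16, output_dim=4)(x).shape)  # (8, 4)  -- new explicit output size
```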