@@ -21,7 +21,7 @@ def __init__(
21
21
scoring_rules : dict [str , ScoringRule ],
22
22
body_subnet : str | type = "mlp" , # naming: shared_subnet / body / subnet ?
23
23
heads_subnet : dict [str , str | keras .Layer ] = None , # TODO: `type` instead of `keras.Layer` ? Too specific ?
24
- activations : dict [str , keras .layers . Activation | Callable | str ] = None ,
24
+ activations : dict [str , keras .Layer | Callable | str ] = None ,
25
25
** kwargs ,
26
26
):
27
27
super ().__init__ (
@@ -36,17 +36,17 @@ def __init__(
36
36
37
37
self .body_subnet = find_network (body_subnet , ** kwargs .get ("body_subnet_kwargs" , {}))
38
38
39
- if heads_subnet :
39
+ if heads_subnet is not None :
40
40
self .heads = {
41
41
key : [find_network (value , ** kwargs .get ("heads_subnet_kwargs" , {}).get (key , {}))]
42
42
for key , value in heads_subnet .items ()
43
43
}
44
44
else :
45
45
self .heads = {key : [] for key in self .scoring_rules .keys ()}
46
46
47
- if activations :
47
+ if activations is not None :
48
48
self .activations = {
49
- key : (value if isinstance (value , keras .layers . Activation ) else keras .layers .Activation (value ))
49
+ key : (value if isinstance (value , keras .Layer ) else keras .layers .Activation (value ))
50
50
for key , value in activations .items ()
51
51
} # make sure that each value is a keras Layer (non-Layer values are wrapped in an Activation)
52
52
else :
@@ -64,16 +64,16 @@ def __init__(
64
64
65
65
assert set (self .scoring_rules .keys ()) == set (self .heads .keys ()) == set (self .activations .keys ())
66
66
67
- def build (self , xz_shape : Shape , conditions_shape : Shape = None ) -> None :
67
+ def build (self , xz_shape : Shape , conditions_shape : Shape ) -> None :
68
68
# build the shared body network
69
69
input_shape = conditions_shape
70
70
self .body_subnet .build (input_shape )
71
71
body_output_shape = self .body_subnet .compute_output_shape (input_shape )
72
72
73
73
for key in self .heads .keys ():
74
- # head_output_shape (excluding batch_size) convention is (*prediction_shape , *parameter_block_shape)
75
- prediction_shape = self .scoring_rules [key ].prediction_shape
76
- head_output_shape = prediction_shape + xz_shape [1 :]
74
+ # head_output_shape (excluding batch_size) convention is (*target_shape , *parameter_block_shape)
75
+ target_shape = self .scoring_rules [key ].target_shape
76
+ head_output_shape = target_shape + xz_shape [1 :]
77
77
78
78
# set correct head shape
79
79
self .heads [key ][- 3 ].units = prod (head_output_shape )
@@ -91,13 +91,18 @@ def call(
91
91
conditions : Tensor = None ,
92
92
training : bool = False ,
93
93
** kwargs ,
94
- ) -> Tensor | tuple [ Tensor , Tensor ]:
94
+ ) -> dict [ str , Tensor ]:
95
95
# TODO: remove unnecessary similarity with InferenceNetwork
96
96
return self ._forward (xz , conditions = conditions , training = training , ** kwargs )
97
97
98
98
def _forward (
99
- self , x : Tensor , conditions : Tensor = None , training : bool = False , ** kwargs
100
- ) -> Tensor | tuple [Tensor , Tensor ]:
99
+ self ,
100
+ x : Tensor ,
101
+ conditions : Tensor = None ,
102
+ training : bool = False ,
103
+ ** kwargs ,
104
+ # TODO: propagate training flag
105
+ ) -> dict [str , Tensor ]:
101
106
body_output = self .body_subnet (conditions )
102
107
103
108
output = dict ()
0 commit comments