@@ -63,7 +63,6 @@ def __init__(
63
63
64
64
self .seed_generator = seed_generator or keras .random .SeedGenerator ()
65
65
66
- self .log_normalization_constant = None
67
66
self .dim = None
68
67
self ._loc = None
69
68
self ._scale = None
@@ -78,21 +77,16 @@ def build(self, input_shape: Shape) -> None:
78
77
self .loc = ops .cast (ops .broadcast_to (self .loc , (self .dim ,)), "float32" )
79
78
self .scale = ops .cast (ops .broadcast_to (self .scale , (self .dim ,)), "float32" )
80
79
81
- self .log_normalization_constant = (
82
- - 0.5 * self .dim * math .log (self .df )
83
- - 0.5 * self .dim * math .log (math .pi )
84
- - math .lgamma (0.5 * self .df )
85
- + math .lgamma (0.5 * (self .df + self .dim ))
86
- - ops .sum (keras .ops .log (self .scale ))
87
- )
88
-
89
80
if self .trainable_parameters :
90
81
self ._loc = self .add_weight (
91
- shape = ops .shape (self .loc ), initializer = keras .initializers .get (self .loc ), dtype = "float32" , trainable = True
82
+ shape = ops .shape (self .loc ),
83
+ initializer = keras .initializers .get (keras .ops .copy (self .loc )),
84
+ dtype = "float32" ,
85
+ trainable = True ,
92
86
)
93
87
self ._scale = self .add_weight (
94
88
shape = ops .shape (self .scale ),
95
- initializer = keras .initializers .get (self .scale ),
89
+ initializer = keras .initializers .get (keras . ops . copy ( self .scale ) ),
96
90
dtype = "float32" ,
97
91
trainable = True ,
98
92
)
@@ -105,7 +99,14 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
105
99
result = - 0.5 * (self .df + self .dim ) * ops .log1p (mahalanobis_term / self .df )
106
100
107
101
if normalize :
108
- result += self .log_normalization_constant
102
+ log_normalization_constant = (
103
+ - 0.5 * self .dim * math .log (self .df )
104
+ - 0.5 * self .dim * math .log (math .pi )
105
+ - math .lgamma (0.5 * self .df )
106
+ + math .lgamma (0.5 * (self .df + self .dim ))
107
+ - ops .sum (keras .ops .log (self ._scale ))
108
+ )
109
+ result += log_normalization_constant
109
110
110
111
return result
111
112
0 commit comments