@@ -32,7 +32,27 @@ def test_stop_grad(self):
32
32
model = keras .Model ([x_input , y_input , z_input ], [loss ])
33
33
model .add_loss (K .mean (loss ))
34
34
model .compile ('nadam' )
35
- model .fit ([np .array ([[1 ]]), np .array ([[2 ]]), np .array ([[0 ]])], [])
35
+ model .fit ([np .array ([[1 ]]), np .array ([[2 ]]), np .array ([[0 ]])])
36
+
37
+ def test_mog_loss (self ):
38
+ inputs = [keras .layers .Input (shape = s ) for s in [(3 ,), (3 , 2 ), (3 ,), (2 ,)]]
39
+ ll_model = keras .engine .Model (inputs , mog_loss_model (3 , 2 )(inputs ))
40
+
41
+ for n in range (10 ):
42
+ ps = - np .log (np .random .uniform (size = (3 ,)))
43
+ pi = ps / np .sum (ps )
44
+ mu = np .random .normal (size = (3 , 2 ))
45
+ sig = np .exp (np .random .normal (size = 3 ,))
46
+ t = np .random .normal (size = (2 ,))
47
+
48
+ pred = ll_model .predict ([pi .reshape (1 , 3 ), mu .reshape (1 , 3 , 2 ), sig .reshape (1 , 3 ), t .reshape (1 , 2 )])
49
+
50
+ # LL = C - log(sum(pi_i/sig^d * exp(-d2/(2*sig^2))))
51
+ d = mu - t .reshape (- 1 , 2 )
52
+ d2 = np .sum (d * d , axis = - 1 )
53
+ ll = - np .log (np .sum (pi / (sig * sig ) * np .exp (- d2 / (2 * sig * sig )), axis = 0 ))
54
+
55
+ assert np .allclose (ll , pred [0 ])
36
56
37
57
@pytest .mark .slow
38
58
def test_deepiv_shape (self ):
@@ -500,7 +520,7 @@ def norm(lr):
500
520
model = keras .engine .Model ([x_input , t_input ], [ll ])
501
521
model .add_loss (K .mean (ll ))
502
522
model .compile ('nadam' )
503
- model .fit ([x , t ], [], epochs = 5 )
523
+ model .fit ([x , t ], epochs = 5 )
504
524
505
525
# For some reason this doesn't work at all when run against the CNTK backend...
506
526
# model.compile('nadam', loss=lambda _,l:l)
@@ -559,7 +579,7 @@ def sample(n):
559
579
model = keras .engine .Model ([x_input , t_input ], [ll ])
560
580
model .add_loss (K .mean (ll ))
561
581
model .compile ('nadam' )
562
- model .fit ([x , t ], [], epochs = 100 )
582
+ model .fit ([x , t ], epochs = 100 )
563
583
564
584
model2 = keras .engine .Model ([x_input ], [pi , mu , sig ])
565
585
import matplotlib
0 commit comments