1- from functools import partial
2-
31import keras
42
5- from bayesflow .utils .serialization import serializable
3+ from bayesflow .utils .serialization import deserialize , serializable , serialize
64from .functional import maximum_mean_discrepancy
75
86
@@ -17,10 +15,22 @@ def __init__(
1715 ):
1816 super ().__init__ (name = name , ** kwargs )
1917 self .mmd = self .add_variable (shape = (), initializer = "zeros" , name = "mmd" )
20- self .mmd_fn = partial (maximum_mean_discrepancy , kernel = kernel , unbiased = unbiased )
18+ self .kernel = kernel
19+ self .unbiased = unbiased
2120
2221 def update_state (self , x , y ):
23- self .mmd .assign (keras .ops .cast (self .mmd_fn (x , y ), self .dtype ))
22+ self .mmd .assign (
23+ keras .ops .cast (maximum_mean_discrepancy (x , y , kernel = self .kernel , unbiased = self .unbiased ), self .dtype )
24+ )
2425
2526 def result (self ):
2627 return self .mmd .value
28+
29+ def get_config (self ):
30+ base_config = super ().get_config ()
31+ config = {"kernel" : self .kernel , "unbiased" : self .unbiased }
32+ return base_config | serialize (config )
33+
34+ @classmethod
35+ def from_config (cls , config , custom_objects = None ):
36+ return cls (** deserialize (config , custom_objects = custom_objects ))
0 commit comments