Skip to content

Commit

Permalink
fix mmd for batch sizes n != m
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Aug 21, 2024
1 parent 5ab83c4 commit be27554
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
9 changes: 5 additions & 4 deletions bayesflow/metrics/functional/maximum_mean_discrepancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

def gaussian_kernel(x1: Tensor, x2: Tensor, scales: Tensor = keras.ops.logspace(-6, 6, 11)) -> Tensor:
residuals = x1[:, None] - x2[None, :]
norms = keras.ops.norm(residuals, axis=tuple(range(2, keras.ops.ndim(residuals))))
residuals = keras.ops.reshape(residuals, keras.ops.shape(residuals)[:2] + (-1,))
norms = keras.ops.norm(residuals, axis=2)
exponent = norms[:, :, None] / (2.0 * scales[None, None, :])
return keras.ops.mean(keras.ops.exp(-exponent), axis=2)

Expand Down Expand Up @@ -44,8 +45,8 @@ def maximum_mean_discrepancy(x1: Tensor, x2: Tensor, kernel: str = "gaussian", *
x1 = keras.ops.reshape(x1, (keras.ops.shape(x1)[0], -1))
x2 = keras.ops.reshape(x2, (keras.ops.shape(x2)[0], -1))

k1 = kernel_fn(x1, x1, **kwargs)
k2 = kernel_fn(x2, x2, **kwargs)
k3 = kernel_fn(x1, x2, **kwargs)
k1 = keras.ops.mean(kernel_fn(x1, x1, **kwargs), axis=1)
k2 = keras.ops.mean(kernel_fn(x2, x2, **kwargs), axis=1)
k3 = keras.ops.mean(kernel_fn(x1, x2, **kwargs), axis=1)

return k1 + k2 - 2.0 * k3
6 changes: 1 addition & 5 deletions bayesflow/metrics/maximum_mean_discrepancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,4 @@

class MaximumMeanDiscrepancy(keras.metrics.MeanMetricWrapper):
def __init__(self, name="maximum_mean_discrepancy", dtype=None, **kwargs):
def fn(y_true, y_pred):
mmd = maximum_mean_discrepancy(y_true, y_pred, **kwargs)
return keras.ops.mean(mmd, axis=1)

super().__init__(fn, name=name, dtype=dtype)
super().__init__(maximum_mean_discrepancy, name=name, dtype=dtype)

0 comments on commit be27554

Please sign in to comment.