-
Notifications
You must be signed in to change notification settings - Fork 59
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Two moons test fit #175
Two moons test fit #175
Changes from all commits
83c1f3f
7f96ce1
73442f8
fea06f9
db6480c
2c5ac37
bde8ff2
6048146
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
import keras | ||
import pytest | ||
|
||
from tests.utils import assert_models_equal | ||
from tests.utils import assert_models_equal, max_mean_discrepancy | ||
from tests.utils import InterruptFitCallback, FitInterruptedError | ||
|
||
|
||
|
@@ -11,14 +11,38 @@ def test_compile(approximator, random_samples, jit_compile): | |
|
||
|
||
@pytest.mark.parametrize("jit_compile", [False, True]) | ||
def test_fit(approximator, train_dataset, validation_dataset, jit_compile): | ||
# TODO: verify the model learns something by comparing a metric before and after training | ||
approximator.compile(jit_compile=jit_compile) | ||
approximator.fit( | ||
def test_fit(approximator, train_dataset, validation_dataset, test_dataset, jit_compile): | ||
# TODO: Refactor to use approximator.sample() when implemented (instead of calling the inference network directly) | ||
|
||
approximator.compile(jit_compile=jit_compile, loss=keras.losses.KLDivergence()) | ||
inf_vars = approximator.configurator.configure_inference_variables(test_dataset.data) | ||
inf_conds = approximator.configurator.configure_inference_conditions(test_dataset.data) | ||
y = test_dataset.data["x"] | ||
|
||
pre_loss = approximator.compute_metrics(train_dataset.data)["loss"] | ||
pre_val_loss = approximator.compute_metrics(validation_dataset.data)["loss"] | ||
x_before = approximator.inference_network(inf_vars, conditions=inf_conds) | ||
mmd_before = max_mean_discrepancy(x_before, y) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. inconsistent naming |
||
|
||
history = approximator.fit( | ||
train_dataset, | ||
validation_data=validation_dataset, | ||
epochs=2, | ||
) | ||
epochs=3, | ||
).history | ||
x_after = approximator.inference_network(inf_vars, conditions=inf_conds) | ||
mmd_after = max_mean_discrepancy(x_after, y) | ||
|
||
# Test model weights have not vanished | ||
for layer in approximator.layers: | ||
for weight in layer.weights: | ||
assert not keras.ops.any(keras.ops.isnan(weight)).numpy() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
# Test KLD loss and validation loss decrease after training | ||
assert history["loss"][-1] < pre_loss | ||
assert history["val_loss"][-1] < pre_val_loss | ||
|
||
# Test MMD improved after training | ||
assert mmd_after < mmd_before | ||
|
||
|
||
@pytest.mark.parametrize("jit_compile", [False, True]) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,30 @@ def isclose(x1, x2, rtol=1e-5, atol=1e-8): | |
|
||
def allclose(x1, x2, rtol=1e-5, atol=1e-8):
    """Return True iff every element of ``x1`` is close to ``x2`` within the given tolerances."""
    elementwise = isclose(x1, x2, rtol, atol)
    return keras.ops.all(elementwise)
|
||
|
||
def max_mean_discrepancy(x, y):
    """Estimate the (biased) maximum mean discrepancy between two samples.

    Compares the empirical distributions of ``x`` and ``y`` using a
    mixture of RBF kernels at several fixed bandwidths (a multi-scale
    kernel). Smaller values indicate more similar distributions.

    Parameters
    ----------
    x : tensor of shape (n, d)
        Samples from the first distribution.
    y : tensor of shape (m, d)
        Samples from the second distribution. ``m`` may differ from ``n``.

    Returns
    -------
    Scalar tensor with the biased MMD^2 estimate.
    """
    # Gram matrices of inner products within and across the two samples.
    xx = keras.ops.matmul(x, keras.ops.transpose(x))  # (n, n)
    yy = keras.ops.matmul(y, keras.ops.transpose(y))  # (m, m)
    xy = keras.ops.matmul(x, keras.ops.transpose(y))  # (n, m)

    # Squared norms as column / row vectors so plain broadcasting yields
    # pairwise squared Euclidean distances. Unlike broadcast_to(..., xx.shape),
    # this also works when n != m.
    col_x = keras.ops.expand_dims(keras.ops.diag(xx), 1)  # (n, 1)
    row_x = keras.ops.expand_dims(keras.ops.diag(xx), 0)  # (1, n)
    col_y = keras.ops.expand_dims(keras.ops.diag(yy), 1)  # (m, 1)
    row_y = keras.ops.expand_dims(keras.ops.diag(yy), 0)  # (1, m)

    dxx = col_x + row_x - 2.0 * xx  # (n, n) pairwise squared distances in x
    dyy = col_y + row_y - 2.0 * yy  # (m, m) pairwise squared distances in y
    dxy = col_x + row_y - 2.0 * xy  # (n, m) cross pairwise squared distances

    # Accumulate RBF kernel evaluations over the fixed bandwidth set.
    k_xx = k_yy = k_xy = 0.0
    for a in [10, 15, 20, 50]:
        k_xx = k_xx + keras.ops.exp(-0.5 * dxx / a)
        k_yy = k_yy + keras.ops.exp(-0.5 * dyy / a)
        k_xy = k_xy + keras.ops.exp(-0.5 * dxy / a)

    # Biased MMD^2 estimate. Equals mean(XX + YY - 2 * XY) when n == m,
    # but remains well-defined for unequal sample sizes.
    return keras.ops.mean(k_xx) + keras.ops.mean(k_yy) - 2.0 * keras.ops.mean(k_xy)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
relies on the internal structure of the dataset fixture, which is not good. Use
test_batch = test_dataset[0]; observables = test_batch["x"]
instead.