From 49bbe79045832fee7c93876dd201343aeda1b50b Mon Sep 17 00:00:00 2001 From: larskue Date: Tue, 25 Jun 2024 18:20:58 +0200 Subject: [PATCH] add check that weights have changed during training --- tests/test_two_moons/conftest.py | 7 ------- tests/test_two_moons/test_two_moons.py | 6 ++++++ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/test_two_moons/conftest.py b/tests/test_two_moons/conftest.py index a3eb7899f..264a85dc8 100644 --- a/tests/test_two_moons/conftest.py +++ b/tests/test_two_moons/conftest.py @@ -17,13 +17,6 @@ def batch_size(): return 128 -@pytest.fixture() -def inference_network(): - from bayesflow.networks import CouplingFlow - - return CouplingFlow() - - @pytest.fixture() def random_samples(batch_size, simulator): return simulator.sample((batch_size,)) diff --git a/tests/test_two_moons/test_two_moons.py b/tests/test_two_moons/test_two_moons.py index 07f10809d..c3fb1c20f 100644 --- a/tests/test_two_moons/test_two_moons.py +++ b/tests/test_two_moons/test_two_moons.py @@ -1,3 +1,4 @@ +import copy import keras import pytest @@ -19,12 +20,17 @@ def test_fit(approximator, train_dataset, validation_dataset, batch_size): approximator.build_from_data(train_dataset[0]) + untrained_weights = copy.deepcopy(approximator.weights) untrained_metrics = approximator.evaluate(validation_dataset, return_dict=True) approximator.fit(train_dataset, epochs=20) + trained_weights = approximator.weights trained_metrics = approximator.evaluate(validation_dataset, return_dict=True) + # check weights have changed during training + assert any([keras.ops.any(~keras.ops.isclose(u, t)) for u, t in zip(untrained_weights, trained_weights)]) + assert isinstance(untrained_metrics, dict) assert isinstance(trained_metrics, dict)