-
Notifications
You must be signed in to change notification settings - Fork 59
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
194 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
|
||
import keras | ||
import keras.saving | ||
import pytest | ||
|
||
import bayesflow.experimental as bf | ||
|
||
from tests.utils import * | ||
|
||
|
||
# TODO: | ||
# current problems & TODOs: | ||
# - keras.utils.PyDataset does not exist under torch (jax?) backend ==> workaround? | ||
# - implement git workflow to test each backend | ||
# - running pytest ignores the backend env variable - | ||
|
||
@pytest.fixture() | ||
def inference_network(): | ||
return keras.Sequential([ | ||
keras.layers.Input(shape=(2,)), | ||
keras.layers.Dense(2), | ||
]) | ||
|
||
|
||
@pytest.fixture() | ||
def summary_network(): | ||
return None | ||
|
||
|
||
@pytest.fixture() | ||
def amortizer(inference_network, summary_network): | ||
return bf.Amortizer(inference_network) | ||
|
||
|
||
def test_fit(amortizer, dataset): | ||
# TODO: verify the model learns something? | ||
amortizer.fit(dataset, epochs=2) | ||
|
||
|
||
def test_interrupt_and_resume_fit(tmp_path, amortizer, dataset): | ||
# TODO: check | ||
callbacks = [ | ||
InterruptFitCallback(epochs=1), | ||
keras.callbacks.ModelCheckpoint(tmp_path / "model.keras"), | ||
] | ||
|
||
with pytest.raises(RuntimeError): | ||
# interrupted fit | ||
amortizer.fit(dataset, epochs=2, callbacks=callbacks) | ||
|
||
assert (tmp_path / "model.keras").exists(), "checkpoint has not been created" | ||
|
||
loaded_amortizer = keras.saving.load_model(tmp_path / "model.keras") | ||
|
||
# TODO: verify the fit is actually resumed (and not just started new with the existing weights) | ||
# resume fit | ||
loaded_amortizer.fit(dataset, epochs=2) | ||
|
||
|
||
def test_extended_fit(amortizer, dataset): | ||
# TODO: verify that the model state is used to actually resume the fit | ||
# initial fit | ||
amortizer.fit(dataset, epochs=2) | ||
|
||
# extended fit | ||
amortizer.fit(dataset, epochs=2) | ||
|
||
|
||
def test_save_and_load(tmp_path, amortizer): | ||
amortizer.save(tmp_path / "amortizer.keras") | ||
loaded_amortizer = keras.saving.load_model(tmp_path / "amortizer.keras") | ||
|
||
assert_models_equal(amortizer, loaded_amortizer) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
|
||
import os | ||
os.environ["KERAS_BACKEND"] = "torch" | ||
|
||
import keras | ||
import keras.saving | ||
|
||
import pytest | ||
|
||
import bayesflow.experimental as bf | ||
|
||
from tests.utils import * | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def summary_network(): | ||
# TODO: modularize over data shape | ||
# TODO: add no summary network case | ||
return keras.Sequential([ | ||
keras.layers.Input(shape=(2,)), | ||
keras.layers.Dense(2), | ||
keras.layers.Lambda(lambda x: keras.ops.mean(x, axis=1, keepdims=True)) | ||
]) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def inference_network(): | ||
# TODO: modularize over data shape | ||
class AffineSubnet(keras.Layer): | ||
def __init__(self, in_features, out_features, **kwargs): | ||
super().__init__(**kwargs) | ||
self.network = keras.Sequential([ | ||
keras.layers.Input(shape=(in_features,)), | ||
keras.layers.Dense(out_features), | ||
]) | ||
|
||
def call(self, x): | ||
scale, shift = keras.ops.split(self.network(x), 2, axis=1) | ||
return dict(scale=scale, shift=shift) | ||
|
||
return bf.networks.CouplingFlow.uniform( | ||
subnet_constructor=AffineSubnet, | ||
features=2, | ||
conditions=0, | ||
layers=1, | ||
transform="affine", | ||
base_distribution="normal", | ||
) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def model(inference_network, summary_network): | ||
return bf.Amortizer(inference_network, summary_network) | ||
|
||
|
||
def test_save_and_load(tmp_path, model): | ||
path = tmp_path / "model.keras" | ||
model.save(path) | ||
loaded_model = keras.saving.load_model(path) | ||
|
||
assert_models_equal(model, loaded_model) |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
|
||
from .assertions import * | ||
from .callbacks import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
|
||
import keras | ||
|
||
|
||
def assert_models_equal(model1: keras.Model, model2: keras.Model): | ||
for v1, v2 in zip(model1.variables, model2.variables): | ||
assert keras.ops.all(keras.ops.isclose(v1, v2)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
|
||
import keras.callbacks | ||
|
||
|
||
class InterruptFitCallback(keras.callbacks.Callback): | ||
def __init__(self, batches=None, epochs=None, error_type=RuntimeError): | ||
super().__init__() | ||
if batches is None and epochs is None: | ||
raise ValueError("Either batches or epochs must be specified.") | ||
|
||
self.batches = batches | ||
self.epochs = epochs | ||
self.error_type = error_type | ||
|
||
def on_train_batch_end(self, batch, logs=None): | ||
if self.batches is not None and self.batches <= batch: | ||
raise self.error_type() | ||
|
||
def on_epoch_end(self, epoch, logs=None): | ||
if self.epochs is not None and self.epochs <= epoch: | ||
raise self.error_type() |