Skip to content

Commit

Permalink
saving + test updates
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Apr 30, 2024
1 parent e0969c2 commit 9e3cbf5
Show file tree
Hide file tree
Showing 11 changed files with 194 additions and 19 deletions.
10 changes: 1 addition & 9 deletions bayesflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,4 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from . import (
amortizers,
default_settings,
diagnostics,
losses,
networks,
sensitivity,
trainers,
)
# TODO: reintroduce imports
16 changes: 8 additions & 8 deletions bayesflow/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@

from . import (
amortizers,
datasets,
diagnostics,
networks,
simulation,
)
# from . import (
# amortizers,
# datasets,
# diagnostics,
# networks,
# simulation,
#)

from .amortizers import (
AmortizedLikelihood,
Expand All @@ -15,6 +15,6 @@

from .simulation import(
distribution,
GenerativeModel,
GenerativeModel as JointDistribution,
)

22 changes: 20 additions & 2 deletions bayesflow/experimental/amortizers/amortizer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@

import keras
import keras.saving


@keras.saving.register_keras_serializable(package="bayesflow.amortizers")
class Amortizer(keras.Model):
def __init__(self, inference_network, summary_network=None):
super().__init__()
def __init__(self, inference_network, summary_network=None, **kwargs):
super().__init__(**kwargs)
self.inference_network = inference_network
self.summary_network = summary_network

@classmethod
def from_config(cls, config, custom_objects=None):
inference_network = keras.saving.deserialize_keras_object(config.pop("inference_network"), custom_objects)
summary_network = keras.saving.deserialize_keras_object(config.pop("summary_network"), custom_objects)
return cls(inference_network, summary_network, **config)

def get_config(self):
base_config = super().get_config()

config = {
"inference_network": keras.saving.serialize_keras_object(self.inference_network),
"summary_network": keras.saving.serialize_keras_object(self.summary_network),
}

return base_config | config

def build(self, input_shape):
if self.summary_network is not None:
self.summary_network.build(input_shape)
Expand Down
73 changes: 73 additions & 0 deletions tests/test_experimental/test_amortizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@

import keras
import keras.saving
import pytest

import bayesflow.experimental as bf

from tests.utils import *


# TODO:
# current problems & TODOs:
# - keras.utils.PyDataset does not exist under torch (jax?) backend ==> workaround?
# - implement git workflow to test each backend
# - running pytest ignores the backend env variable -

@pytest.fixture()
def inference_network():
return keras.Sequential([
keras.layers.Input(shape=(2,)),
keras.layers.Dense(2),
])


@pytest.fixture()
def summary_network():
return None


@pytest.fixture()
def amortizer(inference_network, summary_network):
return bf.Amortizer(inference_network)


def test_fit(amortizer, dataset):
# TODO: verify the model learns something?
amortizer.fit(dataset, epochs=2)


def test_interrupt_and_resume_fit(tmp_path, amortizer, dataset):
# TODO: check
callbacks = [
InterruptFitCallback(epochs=1),
keras.callbacks.ModelCheckpoint(tmp_path / "model.keras"),
]

with pytest.raises(RuntimeError):
# interrupted fit
amortizer.fit(dataset, epochs=2, callbacks=callbacks)

assert (tmp_path / "model.keras").exists(), "checkpoint has not been created"

loaded_amortizer = keras.saving.load_model(tmp_path / "model.keras")

# TODO: verify the fit is actually resumed (and not just started new with the existing weights)
# resume fit
loaded_amortizer.fit(dataset, epochs=2)


def test_extended_fit(amortizer, dataset):
# TODO: verify that the model state is used to actually resume the fit
# initial fit
amortizer.fit(dataset, epochs=2)

# extended fit
amortizer.fit(dataset, epochs=2)


def test_save_and_load(tmp_path, amortizer):
amortizer.save(tmp_path / "amortizer.keras")
loaded_amortizer = keras.saving.load_model(tmp_path / "amortizer.keras")

assert_models_equal(amortizer, loaded_amortizer)
61 changes: 61 additions & 0 deletions tests/test_experimental/test_amortizers/test_saving.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@

import os
os.environ["KERAS_BACKEND"] = "torch"

import keras
import keras.saving

import pytest

import bayesflow.experimental as bf

from tests.utils import *


@pytest.fixture(scope="module")
def summary_network():
# TODO: modularize over data shape
# TODO: add no summary network case
return keras.Sequential([
keras.layers.Input(shape=(2,)),
keras.layers.Dense(2),
keras.layers.Lambda(lambda x: keras.ops.mean(x, axis=1, keepdims=True))
])


@pytest.fixture(scope="module")
def inference_network():
# TODO: modularize over data shape
class AffineSubnet(keras.Layer):
def __init__(self, in_features, out_features, **kwargs):
super().__init__(**kwargs)
self.network = keras.Sequential([
keras.layers.Input(shape=(in_features,)),
keras.layers.Dense(out_features),
])

def call(self, x):
scale, shift = keras.ops.split(self.network(x), 2, axis=1)
return dict(scale=scale, shift=shift)

return bf.networks.CouplingFlow.uniform(
subnet_constructor=AffineSubnet,
features=2,
conditions=0,
layers=1,
transform="affine",
base_distribution="normal",
)


@pytest.fixture(scope="module")
def model(inference_network, summary_network):
return bf.Amortizer(inference_network, summary_network)


def test_save_and_load(tmp_path, model):
path = tmp_path / "model.keras"
model.save(path)
loaded_model = keras.saving.load_model(path)

assert_models_equal(model, loaded_model)
Empty file.
Empty file.
3 changes: 3 additions & 0 deletions tests/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

from .assertions import *
from .callbacks import *
7 changes: 7 additions & 0 deletions tests/utils/assertions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

import keras


def assert_models_equal(model1: keras.Model, model2: keras.Model):
for v1, v2 in zip(model1.variables, model2.variables):
assert keras.ops.all(keras.ops.isclose(v1, v2))
21 changes: 21 additions & 0 deletions tests/utils/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

import keras.callbacks


class InterruptFitCallback(keras.callbacks.Callback):
def __init__(self, batches=None, epochs=None, error_type=RuntimeError):
super().__init__()
if batches is None and epochs is None:
raise ValueError("Either batches or epochs must be specified.")

self.batches = batches
self.epochs = epochs
self.error_type = error_type

def on_train_batch_end(self, batch, logs=None):
if self.batches is not None and self.batches <= batch:
raise self.error_type()

def on_epoch_end(self, epoch, logs=None):
if self.epochs is not None and self.epochs <= epoch:
raise self.error_type()

0 comments on commit 9e3cbf5

Please sign in to comment.