
ContinuousApproximator.sample() fails without previous adapter calls (e.g., when loading data) #255

Open
elseml opened this issue Nov 21, 2024 · 13 comments
@elseml (Member) commented Nov 21, 2024

I noticed that after switching from generating bf.datasets on-the-fly to loading pre-simulated data, ContinuousApproximator.sample() fails because the adapter is no longer called before sampling. Concretely, in line 141 of continuous_approximator.py, the adapter is called with strict=False to process the observed data (without requiring parameter keys while doing so):

conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs) 

When working with loaded data, this raises the following error in the adapter's forward() method:

ValueError: Cannot call `forward` with `strict=False` before calling `forward` with `strict=True`.

The error is easily fixed by manually calling the adapter on the data before sampling, but this is of course unexpected for the user and should therefore be handled internally.
@LarsKue @stefanradev93: what do you think would be a principled handling of this behavior?

@paul-buerkner (Contributor) commented Nov 21, 2024

Based on how I understand what you are doing, I agree with you that this should be handled differently. Just to make sure I understand you correctly, could you add a small example here that includes (only) the relevant code parts?

@paul-buerkner paul-buerkner added user interface Changes to the user interface and improvements in usability v2 labels Nov 21, 2024
@paul-buerkner paul-buerkner added this to the BayesFlow 2.0 milestone Nov 21, 2024
@paul-buerkner paul-buerkner removed the v2 label Nov 21, 2024
@elseml (Member, Author) commented Nov 21, 2024

I looked further into the issue. As far as I can see, it is caused by the OfflineDataset and the approximator no longer referring to the same adapter object in memory:

  • When the OfflineDataset is created right before training, the adapter is already called during approximator.fit() via OfflineDataset.__getitem__.
  • When pre-simulated data is loaded, the adapter passed to OfflineDataset no longer refers to the same adapter object in memory that the approximator uses. Thus, approximator.adapter is never called during training, only OfflineDataset.adapter, and sampling fails.

Here is some reduced pseudocode to keep things concise:

Simulating at the beginning does not fail:

adapter = Adapter()
data = OfflineDataset(simulate(), adapter)
approximator = ContinuousApproximator(summary_net, inference_net, adapter)
approximator.fit(data)
approximator.sample(data)

When the data is loaded from an external source (where the adapter was also supplied to OfflineDataset), sampling fails:

adapter = Adapter()
data = load_data(path)
approximator = ContinuousApproximator(summary_net, inference_net, adapter)
approximator.fit(data)
approximator.sample(data)

Calling the adapter manually before sampling fixes the error:

adapter = Adapter()
data = load_data(path)
approximator = ContinuousApproximator(summary_net, inference_net, adapter)
approximator.fit(data)
_ = adapter(data)
approximator.sample(data)

Creating data manually before sampling does not fix it (i.e., simply creating an OfflineDataset) since the adapter is not called during OfflineDataset construction:

adapter = Adapter()
data = load_data(path)
approximator = ContinuousApproximator(summary_net, inference_net, adapter)
approximator.fit(data)
data_2 = OfflineDataset(simulate(), adapter)
approximator.sample(data_2)
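
The identity mismatch above can be illustrated with a toy stand-in for the adapter's stateful strict logic (ToyAdapter is a hypothetical class for illustration only, not BayesFlow code):

```python
class ToyAdapter:
    """Mimics an adapter that must see strict=True before strict=False."""

    def __init__(self):
        self.configured = False

    def forward(self, data, strict=True):
        if strict:
            # e.g. a concatenate() transform would record the key layout here
            self.configured = True
        elif not self.configured:
            raise ValueError(
                "Cannot call `forward` with `strict=False` before "
                "calling `forward` with `strict=True`."
            )
        return data


# Two separately constructed adapters share no state:
dataset_adapter = ToyAdapter()       # the one inside the loaded dataset
approximator_adapter = ToyAdapter()  # the one passed to the approximator

# Training only exercises the dataset's adapter ...
dataset_adapter.forward({"x": 1}, strict=True)

# ... so sampling through the approximator's own adapter fails:
try:
    approximator_adapter.forward({"x": 1}, strict=False)
except ValueError as e:
    print(e)
```

Passing a single adapter object to both the dataset and the approximator avoids the error, which is exactly the data.adapter workaround discussed below.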

@paul-buerkner (Contributor) commented

Thank you! This is very helpful! @LarsKue and @stefanradev93 what are your takes on how to fix this?

@paul-buerkner paul-buerkner added bug and removed user interface Changes to the user interface and improvements in usability labels Nov 21, 2024
@elseml (Member, Author) commented Nov 21, 2024

Indeed, when passing OfflineDataset.adapter to the approximator, the error is gone (so it is not really a bug, but rather unexpected behavior). However, this is a rather unintuitive solution that users should not be required to apply.

data = load_data(path)
approximator = ContinuousApproximator(summary_net, inference_net, data.adapter)
approximator.fit(data)
approximator.sample(data)

@paul-buerkner (Contributor) commented

It will appear to users as a bug because it should just work. In any case, we should fix it before the 2.0 release.

@LarsKue (Contributor) commented Nov 21, 2024

Could be faulty serialization in the Adapter. I will investigate next week.

@paul-buerkner (Contributor) commented

@LarsKue Is this issue fixed already?

@LarsKue (Contributor) commented Feb 12, 2025

Thanks for the bump. I fail to see the issue, or it is not reproducible for me. Consider the following working snippet:

import os
os.environ["KERAS_BACKEND"] = "torch"

import bayesflow as bf
import keras
import numpy as np

data = {
    "x": np.random.standard_normal(size=(32, 2)),
    "theta": np.random.standard_normal(size=(32, 2)),
}

adapter = bf.Adapter()
adapter.to_array()
adapter.rename("x", "inference_variables")
adapter.rename("theta", "inference_conditions")

dataset = bf.OfflineDataset(data, batch_size=2, adapter=adapter)

inference_network = bf.networks.FlowMatching()

approximator = bf.ContinuousApproximator(adapter=adapter, inference_network=inference_network)

approximator.compile(optimizer="adam")
approximator.build_from_data(
    keras.tree.map_structure(keras.ops.convert_to_tensor, dataset[0])
)

# optional: approximator.fit(...)

conditions = {"inference_conditions": dataset[0]["inference_conditions"]}

samples = approximator.sample(num_samples=32, conditions=conditions)

approximator.save("m.keras")

# later:
approximator = keras.saving.load_model("m.keras")

# generate new data
conditions = {
    "theta": np.random.standard_normal(size=(32, 2)),
}

# uses the existing adapter under the hood
samples = approximator.sample(num_samples=32, conditions=conditions)

@LarsKue LarsKue removed the bug label Feb 12, 2025
@paul-buerkner (Contributor) commented

I will close this issue for now. @elseml feel free to reopen it if it occurs again.

@elseml (Member, Author) commented Feb 13, 2025

Thanks for looking into this. The issue does not relate to model loading, but to data-loading situations, and will be relevant for offline training workflows.

Here, a standard approach would be to simulate some data in script A and train a network on it in a separate script B. As I wrote above, the issue is not a technical bug but rather unexpected behavior: intuitively, users might define the same adapter at the start of each script. The adapter in script B is then called for the first time when sampling, raising the ValueError. The error does not occur in script B if the user reuses the adapter from script A via data.adapter, but this is (at least to me) not super intuitive.

@paul-buerkner could you reopen the issue?
@LarsKue do you see a way we can relax the requirements here to get to a "should just work" solution?

Your snippet can be modified as follows to reproduce the error:

Script A:

import os
os.environ["KERAS_BACKEND"] = "torch"
import bayesflow as bf
import numpy as np
import pickle

data = {
    "x": np.random.standard_normal(size=(32, 2)),
    "theta": np.random.standard_normal(size=(32, 2)),
}

adapter = bf.Adapter()
adapter.to_array()
adapter.rename("x", "inference_variables")
adapter.rename("theta", "inference_conditions")

dataset = bf.OfflineDataset(data, batch_size=2, adapter=adapter)

with open("test_dataset.pkl", "wb") as f:
    pickle.dump(dataset, f)

Script B:

import os
os.environ["KERAS_BACKEND"] = "torch"
import bayesflow as bf
import keras
import numpy as np
import pickle

with open("test_dataset.pkl", "rb") as f:
    dataset = pickle.load(f)
inference_network = bf.networks.FlowMatching()

adapter = bf.Adapter()
adapter.to_array()
adapter.rename("x", "inference_variables")
adapter.rename("theta", "inference_conditions")

approximator = bf.ContinuousApproximator(adapter=adapter, inference_network=inference_network)

approximator.compile(optimizer="adam")
approximator.build_from_data(
    keras.tree.map_structure(keras.ops.convert_to_tensor, dataset[0])
)

# optional: approximator.fit(...)

conditions = {"inference_conditions": dataset[0]["inference_conditions"]}

samples = approximator.sample(num_samples=32, conditions=conditions)

approximator.save("m.keras")

# later:
approximator = keras.saving.load_model("m.keras")

# generate new data
conditions = {
    "theta": np.random.standard_normal(size=(32, 2)),
}

# uses the existing adapter under the hood
samples = approximator.sample(num_samples=32, conditions=conditions)

This code raises a RuntimeError when calling the adapter with inverse=True:

RuntimeError: Cannot call `inverse` before calling `forward` at least once.

The specific ValueError reported above occurred for an adapter using the common concatenate() transform, in which case the error is raised four lines earlier, when the adapter is called with strict=False:

ValueError: Cannot call `forward` with `strict=False` before calling `forward` with `strict=True`.

@LarsKue (Contributor) commented Feb 13, 2025

@elseml Why not save your raw data to a file rather than wrapping it in an OfflineDataset? The dataset classes are intended to be used only (or at least primarily) for training, so it seems suboptimal to pickle them instead of saving your simulated samples with numpy or a similar library.
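
A sketch of that recommended workflow, reusing the data dict from the scripts above (the file name is arbitrary; the bayesflow lines are left as comments to mirror the earlier snippets):

```python
import numpy as np

# Script A: save the raw simulations instead of pickling the OfflineDataset
data = {
    "x": np.random.standard_normal(size=(32, 2)),
    "theta": np.random.standard_normal(size=(32, 2)),
}
np.savez("test_data.npz", **data)

# Script B: load the raw arrays and build the dataset only now, so that the
# dataset and the approximator can share the same adapter object
npz = np.load("test_data.npz")
loaded = {key: npz[key] for key in npz.files}
# adapter = bf.Adapter() ... (configured as in the scripts above)
# dataset = bf.OfflineDataset(loaded, batch_size=2, adapter=adapter)
# approximator = bf.ContinuousApproximator(adapter=adapter, ...)
```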

@paul-buerkner paul-buerkner reopened this Feb 13, 2025
@elseml (Member, Author) commented Feb 13, 2025

I agree that creating the OfflineDataset object after loading raw data is optimal here. It was not obvious to me that other approaches fail, so I think it would be good to explicitly inform users about it. What about extending the OfflineDataset docstring (currently only "A dataset that is pre-simulated and stored in memory.") with this: "When storing and loading data from disk, it is recommended to save any pre-simulated data in raw form and create the OfflineDataset object only after loading in the raw data."?

@LarsKue (Contributor) commented Feb 13, 2025

@elseml Yes, I think this would be a good addition. We could even hint at the intended use of DiskDataset here.
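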
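
Combining both suggestions, the extended docstring could look roughly like this (a sketch; only the docstring is shown and the class body is elided):

```python
class OfflineDataset:
    """A dataset that is pre-simulated and stored in memory.

    When storing and loading data from disk, it is recommended to save any
    pre-simulated data in raw form and create the OfflineDataset object only
    after loading in the raw data. See also DiskDataset for loading
    pre-simulated data from disk.
    """
```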

elseml added a commit to elseml/BayesFlow that referenced this issue Feb 14, 2025