Commit

Add example for training and sampling with summary network and offline mode
LarsKue committed Aug 22, 2024
1 parent 4647fec commit a8bec97
Showing 1 changed file with 348 additions and 0 deletions.
348 changes: 348 additions & 0 deletions examples/2024-08-22 Summary Network.ipynb
@@ -0,0 +1,348 @@
{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true
},
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"torch\""
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"import bayesflow as bf\n",
"import keras\n",
"import matplotlib.patches\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np"
],
"id": "48c74bd15ec629c3",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"class GaussianSimulator(bf.simulators.Simulator):\n",
" def __init__(self, dim=2):\n",
" self.dim = dim\n",
" \n",
" def sample(self, batch_shape, num_obs=512):\n",
" mean = np.random.normal(0.0, 2.0, size=batch_shape + (self.dim,))\n",
" std = np.random.uniform(0.5, 2.0, size=batch_shape + (self.dim,)) \n",
" obs = np.random.normal(mean[:, None], std[:, None], size=batch_shape + (num_obs, self.dim))\n",
" \n",
" return dict(mean=mean, std=std, obs=obs)"
],
"id": "dc5ec708d3d99ab6",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "simulator = GaussianSimulator()",
"id": "e5f5f552f76e6eb9",
"outputs": [],
"execution_count": null
},
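{
"metadata": {},
"cell_type": "code",
"source": [
"# Not part of the original notebook: a quick sanity check of the simulator,\n",
"# drawing a small batch to confirm the shapes returned by GaussianSimulator.sample\n",
"# (mean and std are per-dimension parameters, obs stacks num_obs draws per dataset).\n",
"batch = simulator.sample((4,), num_obs=8)\n",
"{k: v.shape for k, v in batch.items()}"
],
"id": "simulator_shape_check",
"outputs": [],
"execution_count": null
},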
{
"metadata": {},
"cell_type": "code",
"source": [
"data_adapter = bf.ContinuousApproximator.build_data_adapter(\n",
" inference_variables=[\"mean\", \"std\"],\n",
" summary_variables=[\"obs\"]\n",
")"
],
"id": "a02c67c2e6dee96",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"inference_network = bf.networks.FlowMatching(\n",
" subnet=\"mlp\",\n",
" subnet_kwargs=dict(\n",
" depth=4,\n",
" width=256,\n",
" dropout=None,\n",
" activation=\"relu\"\n",
" ),\n",
")"
],
"id": "5c1a803710507943",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"summary_network = bf.networks.DeepSet(\n",
" summary_dim=32,\n",
" depth=4,\n",
" dropout=None,\n",
" activation=\"relu\",\n",
")"
],
"id": "a7fa52fe4e73478b",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"approximator = bf.ContinuousApproximator(\n",
" data_adapter=data_adapter,\n",
" inference_network=inference_network,\n",
" summary_network=summary_network,\n",
")"
],
"id": "6f36864935015663",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"epochs = 100\n",
"training_batches = 1000\n",
"validation_batches = 10\n",
"batch_size = 256"
],
"id": "753e78d120f6c296",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"learning_rate = keras.optimizers.schedules.CosineDecay(\n",
" initial_learning_rate=1e-4,\n",
" decay_steps=epochs * training_batches,\n",
" alpha=1e-3,\n",
")"
],
"id": "573920445dfa5e64",
"outputs": [],
"execution_count": null
},
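{
"metadata": {},
"cell_type": "code",
"source": [
"# Not part of the original notebook: a quick look at the schedule endpoints.\n",
"# CosineDecay starts at initial_learning_rate and bottoms out at\n",
"# alpha * initial_learning_rate = 1e-7 once decay_steps optimizer steps have passed.\n",
"float(learning_rate(0)), float(learning_rate(epochs * training_batches))"
],
"id": "lr_schedule_check",
"outputs": [],
"execution_count": null
},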
{
"metadata": {},
"cell_type": "code",
"source": [
"optimizer = keras.optimizers.AdamW(\n",
" learning_rate=learning_rate,\n",
" \n",
" # you can find good values for this by logging the gradient norms\n",
" # we do this with the LogGradientNorm callback (currently only supported in torch)\n",
" global_clipnorm=100,\n",
" \n",
" # use some weight decay since we train offline\n",
" weight_decay=0.01,\n",
")"
],
"id": "aaa2846d72a3bcfe",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"callbacks = [\n",
" keras.callbacks.TensorBoard(log_dir=\"logs\", update_freq=int(0.25 * training_batches)),\n",
"]"
],
"id": "14534267c53cbb84",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"if keras.backend.backend() == \"torch\":\n",
" callbacks.append(bf.callbacks.LogGradientNorm())"
],
"id": "f7d575b83491f51c",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "approximator.compile(optimizer=optimizer)",
"id": "341a61e8ad3307db",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"training_data = simulator.sample((training_batches,))\n",
"validation_data = simulator.sample((validation_batches,))"
],
"id": "ca767fa2f06be37d",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"training_dataset = bf.datasets.OfflineDataset(\n",
" batch_size=batch_size,\n",
" data=training_data,\n",
" data_adapter=data_adapter,\n",
")"
],
"id": "37983a076f6db16a",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"validation_dataset = bf.datasets.OfflineDataset(\n",
" batch_size=batch_size,\n",
" data=validation_data,\n",
" data_adapter=data_adapter,\n",
")"
],
"id": "a37a1be02273ee4e",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"x = keras.ops.zeros((16, 100, 2))\n",
"summary_network(x).shape"
],
"id": "6cd805134d740c29",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"%load_ext tensorboard\n",
"%tensorboard --logdir logs"
],
"id": "7a15808de5f97768",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"approximator.fit(\n",
" batch_size=batch_size,\n",
" epochs=epochs,\n",
" dataset=training_dataset,\n",
" validation_data=validation_dataset,\n",
" callbacks=callbacks,\n",
")"
],
"id": "fee9ee0a9b958bbf",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# generate 16 samples, each with 1024 observations\n",
"test_data = simulator.sample((16,), num_obs=1024)\n",
"\n",
"conditions = {\n",
" \"obs\": test_data[\"obs\"]\n",
"}\n",
"\n",
"samples = approximator.sample((16,), data=conditions)"
],
"id": "bd8f018107cfcaab",
"outputs": [],
"execution_count": null
},
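{
"metadata": {},
"cell_type": "code",
"source": [
"# Not part of the original notebook: a quick look at what approximator.sample returned,\n",
"# a dict keyed by the inference variables (\"mean\" and \"std\") used for plotting below.\n",
"{k: v.shape for k, v in samples.items()}"
],
"id": "posterior_sample_shapes",
"outputs": [],
"execution_count": null
},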
{
"metadata": {},
"cell_type": "code",
"source": [
"fig, axes = plt.subplots(4, 4, figsize=(16, 16))\n",
"\n",
"for i, ax in enumerate(axes.flat):\n",
" x, y = conditions[\"obs\"][i, :, 0:2].T\n",
" ax.scatter(x, y, s=5, color=\"black\", alpha=0.5, label=\"Samples\")\n",
" \n",
" # predictions in red\n",
" mean, std = samples[\"mean\"][i, 0:2], samples[\"std\"][i, 0:2]\n",
" ax.add_artist(matplotlib.patches.Ellipse(mean, 2 * std[0], 2 * std[1], color=\"red\", fill=False, ls=\":\", lw=3, label=\"Prediction\"))\n",
" \n",
" # true values in green\n",
" mean, std = test_data[\"mean\"][i, 0:2], test_data[\"std\"][i, 0:2]\n",
" ax.add_artist(matplotlib.patches.Ellipse(mean, 2 * std[0], 2 * std[1], color=\"green\", fill=False, ls=\"--\", lw=3, label=\"Ground Truth\"))\n",
" \n",
" ax.set_xlim(-5, 5)\n",
" ax.set_ylim(-5, 5)\n",
" ax.set_aspect(\"equal\", adjustable=\"box\")\n",
"\n",
"plt.legend(loc=\"lower right\")\n",
"plt.show()"
],
"id": "92755ef87fae628c",
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
