-
Notifications
You must be signed in to change notification settings - Fork 59
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example for training and sampling with summary network and offline mode
- Loading branch information
Showing
1 changed file
with
348 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,348 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"id": "initial_id", | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"source": [ | ||
"import os\n", | ||
"\n", | ||
"# choose the Keras backend through the KERAS_BACKEND environment variable;\n", | ||
"# this is done here, in the first cell, before keras and bayesflow are imported below\n", | ||
"os.environ[\"KERAS_BACKEND\"] = \"torch\"" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"import bayesflow as bf\n", | ||
"import keras\n", | ||
"import matplotlib.patches\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import numpy as np" | ||
], | ||
"id": "48c74bd15ec629c3", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"class GaussianSimulator(bf.simulators.Simulator):\n", | ||
"    \"\"\"Simulate diagonal Gaussians with random parameters plus observations drawn from them.\n", | ||
"\n", | ||
"    Per dimension, mean ~ Normal(0, 2) and std ~ Uniform(0.5, 2) are drawn independently.\n", | ||
"    \"\"\"\n", | ||
"\n", | ||
"    def __init__(self, dim=2):\n", | ||
"        # number of dimensions of each observation\n", | ||
"        self.dim = dim\n", | ||
"\n", | ||
"    def sample(self, batch_shape, num_obs=512):\n", | ||
"        \"\"\"Draw one parameter set per batch entry and num_obs observations for each.\n", | ||
"\n", | ||
"        Returns a dict with \"mean\" and \"std\" of shape batch_shape + (dim,)\n", | ||
"        and \"obs\" of shape batch_shape + (num_obs, dim).\n", | ||
"        \"\"\"\n", | ||
"        # accept any sequence (list or tuple) as batch shape\n", | ||
"        batch_shape = tuple(batch_shape)\n", | ||
"\n", | ||
"        mean = np.random.normal(0.0, 2.0, size=batch_shape + (self.dim,))\n", | ||
"        std = np.random.uniform(0.5, 2.0, size=batch_shape + (self.dim,))\n", | ||
"\n", | ||
"        # insert the observation axis right before the last axis so the parameters\n", | ||
"        # broadcast against the observations for a batch shape of any rank\n", | ||
"        obs = np.random.normal(mean[..., None, :], std[..., None, :], size=batch_shape + (num_obs, self.dim))\n", | ||
"\n", | ||
"        return dict(mean=mean, std=std, obs=obs)" | ||
], | ||
"id": "dc5ec708d3d99ab6", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": "# default-constructed, i.e. dim=2; the plotting cell at the end draws 2-D ellipses\nsimulator = GaussianSimulator()", | ||
"id": "e5f5f552f76e6eb9", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# assign roles to the simulator outputs: \"mean\" and \"std\" are the inference variables,\n", | ||
"# \"obs\" is the summary variable\n", | ||
"data_adapter = bf.ContinuousApproximator.build_data_adapter(\n", | ||
" inference_variables=[\"mean\", \"std\"],\n", | ||
" summary_variables=[\"obs\"]\n", | ||
")" | ||
], | ||
"id": "a02c67c2e6dee96", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# flow matching inference network using an \"mlp\" subnet\n", | ||
"# NOTE(review): subnet_kwargs is presumably forwarded to the MLP constructor - confirm in the BayesFlow docs\n", | ||
"inference_network = bf.networks.FlowMatching(\n", | ||
" subnet=\"mlp\",\n", | ||
" subnet_kwargs=dict(\n", | ||
" depth=4,\n", | ||
" width=256,\n", | ||
" dropout=None,\n", | ||
" activation=\"relu\"\n", | ||
" ),\n", | ||
")" | ||
], | ||
"id": "5c1a803710507943", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# DeepSet summary network with summary_dim=32; it is handed to the approximator\n", | ||
"# together with the data adapter that marks \"obs\" as the summary variable\n", | ||
"summary_network = bf.networks.DeepSet(\n", | ||
" summary_dim=32,\n", | ||
" depth=4,\n", | ||
" dropout=None,\n", | ||
" activation=\"relu\",\n", | ||
")" | ||
], | ||
"id": "a7fa52fe4e73478b", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# bundle the data adapter, the inference network and the summary network into one approximator;\n", | ||
"# this object is compiled, fitted and sampled from in the cells below\n", | ||
"approximator = bf.ContinuousApproximator(\n", | ||
" data_adapter=data_adapter,\n", | ||
" inference_network=inference_network,\n", | ||
" summary_network=summary_network,\n", | ||
")" | ||
], | ||
"id": "6f36864935015663", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# training configuration shared by the cells below:\n", | ||
"# - epochs and training_batches together set the learning rate decay horizon (epochs * training_batches)\n", | ||
"# - training_batches also sets the TensorBoard update frequency\n", | ||
"# - training_batches / validation_batches size the simulated data sets\n", | ||
"# - batch_size is used by both offline datasets and by fit()\n", | ||
"epochs = 100\n", | ||
"training_batches = 1000\n", | ||
"validation_batches = 10\n", | ||
"batch_size = 256" | ||
], | ||
"id": "753e78d120f6c296", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# cosine decay starting at 1e-4 over epochs * training_batches optimizer steps;\n", | ||
"# alpha is the final learning rate expressed as a fraction of the initial one\n", | ||
"learning_rate = keras.optimizers.schedules.CosineDecay(\n", | ||
" initial_learning_rate=1e-4,\n", | ||
" decay_steps=epochs * training_batches,\n", | ||
" alpha=1e-3,\n", | ||
")" | ||
], | ||
"id": "573920445dfa5e64", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# AdamW driven by the cosine decay schedule defined above\n", | ||
"optimizer = keras.optimizers.AdamW(\n", | ||
" learning_rate=learning_rate,\n", | ||
" \n", | ||
" # clip gradients by their global norm; a good threshold can be found by logging the gradient norms,\n", | ||
" # which the LogGradientNorm callback does (currently only supported with the torch backend)\n", | ||
" global_clipnorm=100,\n", | ||
" \n", | ||
" # use some weight decay since we train offline, i.e. on a fixed set of simulations\n", | ||
" weight_decay=0.01,\n", | ||
")" | ||
], | ||
"id": "aaa2846d72a3bcfe", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# write TensorBoard logs to ./logs; an integer update_freq is counted in batches,\n", | ||
"# so this logs every quarter of training_batches\n", | ||
"callbacks = [\n", | ||
" keras.callbacks.TensorBoard(log_dir=\"logs\", update_freq=int(0.25 * training_batches)),\n", | ||
"]" | ||
], | ||
"id": "14534267c53cbb84", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# gradient norm logging is only added when the torch backend is active;\n", | ||
"# the logged norms help to choose global_clipnorm for the optimizer\n", | ||
"if keras.backend.backend() == \"torch\":\n", | ||
" callbacks.append(bf.callbacks.LogGradientNorm())" | ||
], | ||
"id": "f7d575b83491f51c", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": "# attach the AdamW optimizer configured above; only the optimizer is passed to compile()\napproximator.compile(optimizer=optimizer)", | ||
"id": "341a61e8ad3307db", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# training_batches / validation_batches count batches, so draw batch_size simulations per batch;\n", | ||
"# one epoch then spans training_batches optimizer steps, which is what decay_steps in the\n", | ||
"# learning rate schedule and update_freq in the TensorBoard callback are based on\n", | ||
"# NOTE: with the defaults, the training \"obs\" array holds 1000 * 256 * 512 * 2 float64 values (~2 GB);\n", | ||
"# lower training_batches or num_obs if memory is tight\n", | ||
"training_data = simulator.sample((training_batches * batch_size,))\n", | ||
"validation_data = simulator.sample((validation_batches * batch_size,))" | ||
], | ||
"id": "ca767fa2f06be37d", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# wrap the pre-simulated training data in an offline dataset that uses\n", | ||
"# the same data adapter as the approximator\n", | ||
"training_dataset = bf.datasets.OfflineDataset(\n", | ||
" batch_size=batch_size,\n", | ||
" data=training_data,\n", | ||
" data_adapter=data_adapter,\n", | ||
")" | ||
], | ||
"id": "37983a076f6db16a", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# validation counterpart of the training dataset: same batch size and data adapter,\n", | ||
"# built from the separately simulated validation data\n", | ||
"validation_dataset = bf.datasets.OfflineDataset(\n", | ||
" batch_size=batch_size,\n", | ||
" data=validation_data,\n", | ||
" data_adapter=data_adapter,\n", | ||
")" | ||
], | ||
"id": "a37a1be02273ee4e", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# sanity check: run the summary network on a dummy all-zero input of shape (16, 100, 2)\n", | ||
"# and display the output shape\n", | ||
"# NOTE(review): presumably (16, 32) given summary_dim=32 - confirm when running\n", | ||
"x = keras.ops.zeros((16, 100, 2))\n", | ||
"summary_network(x).shape" | ||
], | ||
"id": "6cd805134d740c29", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# open TensorBoard inside the notebook, reading the \"logs\" directory\n", | ||
"# that the TensorBoard callback writes to\n", | ||
"%load_ext tensorboard\n", | ||
"%tensorboard --logdir logs" | ||
], | ||
"id": "7a15808de5f97768", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# train on the offline training dataset for the configured number of epochs,\n", | ||
"# validating on the validation dataset and running the callbacks collected above\n", | ||
"# NOTE(review): batch_size is passed here as well as to the datasets - confirm fit() needs it\n", | ||
"approximator.fit(\n", | ||
" batch_size=batch_size,\n", | ||
" epochs=epochs,\n", | ||
" dataset=training_dataset,\n", | ||
" validation_data=validation_dataset,\n", | ||
" callbacks=callbacks,\n", | ||
")" | ||
], | ||
"id": "fee9ee0a9b958bbf", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# simulate 16 test data sets with 1024 observations each\n", | ||
"# (the training data used the simulator default of 512 observations)\n", | ||
"test_data = simulator.sample((16,), num_obs=1024)\n", | ||
"\n", | ||
"# condition on the observations only; the true mean and std stay in test_data for comparison\n", | ||
"conditions = {\n", | ||
" \"obs\": test_data[\"obs\"]\n", | ||
"}\n", | ||
"\n", | ||
"# NOTE(review): confirm how approximator.sample interprets the (16,) batch shape\n", | ||
"# and which array shapes it returns for \"mean\" and \"std\"\n", | ||
"samples = approximator.sample((16,), data=conditions)" | ||
], | ||
"id": "bd8f018107cfcaab", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# one panel per test data set (4 x 4 = 16): observations as a scatter plot,\n", | ||
"# with the predicted and the true Gaussian overlaid as ellipses\n", | ||
"fig, axes = plt.subplots(4, 4, figsize=(16, 16))\n", | ||
"\n", | ||
"for i, ax in enumerate(axes.flat):\n", | ||
" # first two dimensions of the observations of the i-th data set\n", | ||
" x, y = conditions[\"obs\"][i, :, 0:2].T\n", | ||
" ax.scatter(x, y, s=5, color=\"black\", alpha=0.5, label=\"Samples\")\n", | ||
" \n", | ||
" # predictions in red; the ellipse is centered on the mean with width and height of 2 * std\n", | ||
" # NOTE(review): the [i, 0:2] indexing assumes samples[\"mean\"] and samples[\"std\"]\n", | ||
" # have shape (16, dim) - confirm against the output of approximator.sample\n", | ||
" mean, std = samples[\"mean\"][i, 0:2], samples[\"std\"][i, 0:2]\n", | ||
" ax.add_artist(matplotlib.patches.Ellipse(mean, 2 * std[0], 2 * std[1], color=\"red\", fill=False, ls=\":\", lw=3, label=\"Prediction\"))\n", | ||
" \n", | ||
" # true values in green, drawn the same way from the simulator's ground truth\n", | ||
" mean, std = test_data[\"mean\"][i, 0:2], test_data[\"std\"][i, 0:2]\n", | ||
" ax.add_artist(matplotlib.patches.Ellipse(mean, 2 * std[0], 2 * std[1], color=\"green\", fill=False, ls=\"--\", lw=3, label=\"Ground Truth\"))\n", | ||
" \n", | ||
" # identical limits and equal aspect ratio so the panels are comparable\n", | ||
" ax.set_xlim(-5, 5)\n", | ||
" ax.set_ylim(-5, 5)\n", | ||
" ax.set_aspect(\"equal\", adjustable=\"box\")\n", | ||
"\n", | ||
"# plt.legend acts on the current axes only - presumably the last panel; confirm\n", | ||
"plt.legend(loc=\"lower right\")\n", | ||
"plt.show()" | ||
], | ||
"id": "92755ef87fae628c", | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": "", | ||
"id": "8ec13de03538936e", | ||
"outputs": [], | ||
"execution_count": null | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |