From a000256a82826065325c8a187b8204ae302cb3e4 Mon Sep 17 00:00:00 2001 From: Ivo Melse Date: Wed, 26 Jun 2024 13:20:03 +0200 Subject: [PATCH] Start with visualization --- notebooks/die.ipynb | 90 ++++++++- notebooks/diegraph.html | 321 ++++++++++++++++++++++++++++++++ notebooks/hi.html | 361 ++++++++++++++++++++++++++++++++++++ notebooks/monty.html | 321 ++++++++++++++++++++++++++++++++ notebooks/monty_hall.ipynb | 131 +++++++++++++ stormvogel/visualization.py | 76 +++++++- 6 files changed, 1286 insertions(+), 14 deletions(-) create mode 100644 notebooks/diegraph.html create mode 100644 notebooks/hi.html create mode 100644 notebooks/monty.html create mode 100644 notebooks/monty_hall.ipynb diff --git a/notebooks/die.ipynb b/notebooks/die.ipynb index 9822be1..7137aad 100644 --- a/notebooks/die.ipynb +++ b/notebooks/die.ipynb @@ -6,41 +6,111 @@ "metadata": {}, "outputs": [], "source": [ - "from stormvogel.visualization import make_slider" + "import stormvogel.model\n", + "import stormvogel.visualization" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, + "outputs": [], + "source": [ + "# Create a new model with the name \"Die\"\n", + "dtmc = stormvogel.model.new_dtmc(\"Die\")\n", + "\n", + "init = dtmc.get_initial_state()\n", + "\n", + "# From the initial state, add the transition to 6 new states with probability 1/6th.\n", + "init.set_transitions(\n", + " [(1 / 6, dtmc.new_state(f\"rolled{i}\", {\"rolled\": i})) for i in range(6)]\n", + ")\n", + "\n", + "# Print the resulting model in dot format.\n", + "# print(dtmc.to_dot())" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "scrolled": true + }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "diegraph.html\n" + ] + }, { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9caed9de44144d50b26a7e76ef0accb8", - "version_major": 2, - "version_minor": 0 - }, + "text/html": [ + "\n", + " \n", + " " + ], "text/plain": [ - "interactive(children=(IntSlider(value=10, description='x', max=30, min=-10), Output()), _dom_classes=('widget-…" + "" ] }, "metadata": {}, "output_type": "display_data" + } + ], + "source": [ + "stormvogel.visualization.show(model=dtmc, name=\"diegraph\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hi.html\n" + ] }, { "data": { + "text/html": [ + "\n", + " \n", + " " + ], "text/plain": [ - ".(x)>" + "" ] }, - "execution_count": 2, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "make_slider()" + "from pyvis.network import Network\n", + "\n", + "g = Network(notebook=True,cdn_resources=\"remote\")\n", + "g.show_buttons()\n", + "g.show(\"hi.html\")" ] }, { diff --git a/notebooks/diegraph.html b/notebooks/diegraph.html new file mode 100644 index 0000000..c317554 --- /dev/null +++ b/notebooks/diegraph.html @@ -0,0 +1,321 @@ + + + + + + + + + + + + + + + +
+

+
+ + + + + + +
+

+
+ + + + + +
+ + +
+
+ + + + + + + \ No newline at end of file diff --git a/notebooks/hi.html b/notebooks/hi.html new file mode 100644 index 0000000..cb4eb42 --- /dev/null +++ b/notebooks/hi.html @@ -0,0 +1,361 @@ + + + + + + + + + + + + + + + +
+

+
+ + + + + + +
+

+
+ + + + + +
+ + +
+
+ + + +
+ + + + + \ No newline at end of file diff --git a/notebooks/monty.html b/notebooks/monty.html new file mode 100644 index 0000000..e1628bb --- /dev/null +++ b/notebooks/monty.html @@ -0,0 +1,321 @@ + + + + + + + + + + + + + + + +
+

+
+ + + + + + +
+

+
+ + + + + +
+ + +
+
+ + + + + + + \ No newline at end of file diff --git a/notebooks/monty_hall.ipynb b/notebooks/monty_hall.ipynb new file mode 100644 index 0000000..4be0cc9 --- /dev/null +++ b/notebooks/monty_hall.ipynb @@ -0,0 +1,131 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "bce64a65-ea2a-42b6-adac-0533b7541ca0", + "metadata": {}, + "outputs": [], + "source": [ + "import stormvogel.model\n", + "\n", + "mdp = stormvogel.model.new_mdp(\"Monty Hall\")\n", + "\n", + "init = mdp.get_initial_state()\n", + "\n", + "# first choose car position\n", + "init.set_transitions(\n", + " [(1 / 3, mdp.new_state(\"carchosen\", {\"car_pos\": i})) for i in range(3)]\n", + ")\n", + "\n", + "# we choose a door in each case\n", + "for s in mdp.get_states_with(\"carchosen\"):\n", + " s.set_transitions(\n", + " [\n", + " (\n", + " mdp.action(f\"open{i}\"),\n", + " mdp.new_state(\"open\", s.features | {\"chosen_pos\": i}),\n", + " )\n", + " for i in range(3)\n", + " ]\n", + " )\n", + "\n", + "# the other goat is revealed\n", + "for s in mdp.get_states_with(\"open\"):\n", + " car_pos = s.features[\"car_pos\"]\n", + " chosen_pos = s.features[\"chosen_pos\"]\n", + " other_pos = {0, 1, 2} - {car_pos, chosen_pos}\n", + " s.set_transitions(\n", + " [\n", + " (\n", + " 1 / len(other_pos),\n", + " mdp.new_state(\"goatrevealed\", s.features | {\"reveal_pos\": i}),\n", + " )\n", + " for i in other_pos\n", + " ]\n", + " )\n", + "\n", + "# we must choose whether we want to switch\n", + "for s in mdp.get_states_with(\"goatrevealed\"):\n", + " car_pos = s.features[\"car_pos\"]\n", + " chosen_pos = s.features[\"chosen_pos\"]\n", + " reveal_pos = s.features[\"reveal_pos\"]\n", + " other_pos = list({0, 1, 2} - {reveal_pos, chosen_pos})[0]\n", + " s.set_transitions(\n", + " [\n", + " (\n", + " mdp.action(\"stay\"),\n", + " mdp.new_state(\n", + " [\"done\"] + ([\"target\"] if chosen_pos == car_pos else []),\n", + " s.features | {\"chosen_pos\": chosen_pos},\n", + " ),\n", + " ),\n", + " (\n", + " mdp.action(\"switch\"),\n", + " mdp.new_state(\n", + " [\"done\"] + ([\"target\"] if other_pos == car_pos else []),\n", + " s.features | {\"chosen_pos\": other_pos},\n", + " ),\n", + " ),\n", + " ]\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "02b68164-00de-41b1-ba2b-65e12bef9032", + "metadata": {}, + "outputs": [ + { + "ename": "NotImplementedError", + "evalue": "Non-empty actions are not supported yet.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[4], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mstormvogel\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mvisualization\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m show\n\u001b[0;32m----> 3\u001b[0m \u001b[43mshow\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmdp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmonty\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/git/stormvogel/.venv/lib/python3.11/site-packages/stormvogel/visualization.py:75\u001b[0m, in \u001b[0;36mshow\u001b[0;34m(model, name, notebook)\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mshow\u001b[39m(model: Model, name: \u001b[38;5;28mstr\u001b[39m, notebook: \u001b[38;5;28mbool\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[1;32m 68\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Create visualization of a Model using a pyvis Network\u001b[39;00m\n\u001b[1;32m 69\u001b[0m \n\u001b[1;32m 70\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;124;03m notebook (bool, optional): Leave to true if you are using in a notebook. Defaults to True.\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 75\u001b[0m vis \u001b[38;5;241m=\u001b[39m \u001b[43mVisualization\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnotebook\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 76\u001b[0m vis\u001b[38;5;241m.\u001b[39mshow()\n", + "File \u001b[0;32m~/git/stormvogel/.venv/lib/python3.11/site-packages/stormvogel/visualization.py:24\u001b[0m, in \u001b[0;36mVisualization.__init__\u001b[0;34m(self, model, name, notebook, cdn_resources)\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mg \u001b[38;5;241m=\u001b[39m Network(notebook\u001b[38;5;241m=\u001b[39mnotebook, directed\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, cdn_resources\u001b[38;5;241m=\u001b[39mcdn_resources)\n\u001b[1;32m 23\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__add_nodes()\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__add_transitions\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__set_layout()\n", + "File \u001b[0;32m~/git/stormvogel/.venv/lib/python3.11/site-packages/stormvogel/visualization.py:62\u001b[0m, in \u001b[0;36mVisualization.__add_transitions\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mg\u001b[38;5;241m.\u001b[39madd_edge(state_id, target\u001b[38;5;241m.\u001b[39mid, color\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mred\u001b[39m\u001b[38;5;124m\"\u001b[39m, label\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__formatted_prob(prob))\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 62\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNon-empty actions are not supported yet.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mNotImplementedError\u001b[0m: Non-empty actions are not supported yet." + ] + } + ], + "source": [ + "from stormvogel.visualization import show\n", + "\n", + "show(mdp, \"monty\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ec68bb5-2728-458e-8d11-6015f80d79f7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/stormvogel/visualization.py b/stormvogel/visualization.py index 26fb39a..cd7d39d 100644 --- a/stormvogel/visualization.py +++ b/stormvogel/visualization.py @@ -1,12 +1,80 @@ """Contains stuff for visualization""" - -from stormvogel.model import Model +from pyvis.network import Network +from stormvogel.model import Model, EmptyAction from ipywidgets import interact +from IPython.display import display +from fractions import Fraction + +class Visualization: + """Handles visualization of a Model using a pyvis Network.""" + def __init__(self, model: Model, name: str, notebook: bool=True, cdn_resources: str="remote") -> None: + """Create visualization of a Model using a pyvis Network + + Args: + model (Model): The stormvogel model to be displayed. + name (str): The name of the resulting html file. + notebook (bool, optional): Leave to true if you are using in a notebook. Defaults to True. + """ + self.model = model + if name[-5:] != ".html": # We do not require the user to explicitly type .html in their names. + name += ".html" + self.name = name + self.g = Network(notebook=notebook, directed=True, cdn_resources=cdn_resources) + self.__add_nodes() + self.__add_transitions() + self.__set_layout() + + def __set_layout(self): + self.g.set_options(""" +var options = { + "nodes": { + "color": { + "background": "white", + "border": "black" + } + }, + "edge": { + "color": "blue" + } +}""") + + def __add_nodes(self): + """For each state in the model, add a node to the graph.""" + for state in self.model.states.values(): + borderWidth = 1 + if state == self.model.get_initial_state(): + borderWidth = 3 + self.g.add_node(state.id, label=",".join(state.labels), color=None, borderWidth=borderWidth) + + def __formatted_prob(self, prob: float) -> str: + """Take a probability value and format it nicely""" + return str(Fraction(prob).limit_denominator(20)) + + def __add_transitions(self): + """For each transition in the model, add a transition in the graph.""" + for state_id, transition in self.model.transitions.items(): + for action, branch in transition.transition.items(): + if action == EmptyAction: + # Only draw probabilities + for prob, target in branch.branch: + self.g.add_edge(state_id, target.id, color="red", label=self.__formatted_prob(prob)) + else: + raise NotImplementedError("Non-empty actions are not supported yet.") + def show(self): + """Show the constructed model""" + display(self.g.show(name=self.name)) -def show(m: Model): - pass +def show(model: Model, name: str, notebook: bool=True): + """Create visualization of a Model using a pyvis Network + Args: + model (Model): The stormvogel model to be displayed. + name (str): The name of the resulting html file. + notebook (bool, optional): Leave to true if you are using in a notebook. Defaults to True. + """ + vis = Visualization(model, name, notebook) + vis.show() def make_slider(): return interact(lambda x: x, x=10)