From 1472ac77b28bc818c965532b954126c9c9cd58b9 Mon Sep 17 00:00:00 2001 From: YouGuessedMyName Date: Fri, 28 Jun 2024 14:49:16 +0200 Subject: [PATCH] Added action visualization --- notebooks/die.ipynb | 223 +++++++++++++------------------- notebooks/diegraph.html | 56 ++++----- notebooks/monty.html | 56 ++++----- notebooks/monty_hall.ipynb | 245 ++++++++++++++++-------------------- stormvogel/visualization.py | 22 ++-- 5 files changed, 269 insertions(+), 333 deletions(-) diff --git a/notebooks/die.ipynb b/notebooks/die.ipynb index 7137aad..caca1dc 100644 --- a/notebooks/die.ipynb +++ b/notebooks/die.ipynb @@ -1,145 +1,102 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import stormvogel.model\n", - "import stormvogel.visualization" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# Create a new model with the name \"Die\"\n", - "dtmc = stormvogel.model.new_dtmc(\"Die\")\n", - "\n", - "init = dtmc.get_initial_state()\n", - "\n", - "# From the initial state, add the transition to 6 new states with probability 1/6th.\n", - "init.set_transitions(\n", - " [(1 / 6, dtmc.new_state(f\"rolled{i}\", {\"rolled\": i})) for i in range(6)]\n", - ")\n", - "\n", - "# Print the resulting model in dot format.\n", - "# print(dtmc.to_dot())" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "scrolled": true - }, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "diegraph.html\n" - ] + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import stormvogel.model\n", + "import stormvogel.visualization" + ] }, { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a new model with the name \"Die\"\n", + "dtmc = stormvogel.model.new_dtmc(\"Die\")\n", + "\n", + "init = dtmc.get_initial_state()\n", + "\n", + "# From the initial state, add the transition to 6 new states with probability 1/6th.\n", + "init.set_transitions(\n", + " [(1 / 6, dtmc.new_state(f\"rolled{i}\", {\"rolled\": i})) for i in range(6)]\n", + ")\n", + "\n", + "# Print the resulting model in dot format.\n", + "# print(dtmc.to_dot())" ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "stormvogel.visualization.show(model=dtmc, name=\"diegraph\")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "hi.html\n" - ] }, { - "data": { - "text/html": [ - "\n", - " \n", - " " + "cell_type": "code", + "execution_count": 3, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "diegraph.html\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } ], - "text/plain": [ - "" + "source": [ + "stormvogel.visualization.show(model=dtmc, name=\"diegraph\")" ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" } - ], - "source": [ - "from pyvis.network import Network\n", - "\n", - "g = Network(notebook=True,cdn_resources=\"remote\")\n", - "g.show_buttons()\n", - "g.show(\"hi.html\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.2" - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/notebooks/diegraph.html b/notebooks/diegraph.html index c317554..509d77e 100644 --- a/notebooks/diegraph.html +++ b/notebooks/diegraph.html @@ -1,7 +1,7 @@ - + - - - - - - - - + + + + + + + +

@@ -234,24 +234,24 @@

float: left; } - - - + + +
- - + +
- - + + - \ No newline at end of file + diff --git a/notebooks/monty.html b/notebooks/monty.html index e1628bb..55ce12e 100644 --- a/notebooks/monty.html +++ b/notebooks/monty.html @@ -1,7 +1,7 @@ - + - - - - - - - - + + + + + + + +

@@ -234,24 +234,24 @@

float: left; } - - - + + +
- - + +
- - + + - \ No newline at end of file + diff --git a/notebooks/monty_hall.ipynb b/notebooks/monty_hall.ipynb index 4137d2f..ea92976 100644 --- a/notebooks/monty_hall.ipynb +++ b/notebooks/monty_hall.ipynb @@ -1,145 +1,116 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "bce64a65-ea2a-42b6-adac-0533b7541ca0", - "metadata": {}, - "outputs": [], - "source": [ - "import stormvogel.model\n", - "\n", - "mdp = stormvogel.model.new_mdp(\"Monty Hall\")\n", - "\n", - "init = mdp.get_initial_state()\n", - "\n", - "# first choose car position\n", - "init.set_transitions(\n", - " [(1 / 3, mdp.new_state(\"carchosen\", {\"car_pos\": i})) for i in range(3)]\n", - ")\n", - "\n", - "# we choose a door in each case\n", - "for s in mdp.get_states_with(\"carchosen\"):\n", - " s.set_transitions(\n", - " [\n", - " (\n", - " mdp.action(f\"open{i}\"),\n", - " mdp.new_state(\"open\", s.features | {\"chosen_pos\": i}),\n", - " )\n", - " for i in range(3)\n", - " ]\n", - " )\n", - "\n", - "# the other goat is revealed\n", - "for s in mdp.get_states_with(\"open\"):\n", - " car_pos = s.features[\"car_pos\"]\n", - " chosen_pos = s.features[\"chosen_pos\"]\n", - " other_pos = {0, 1, 2} - {car_pos, chosen_pos}\n", - " s.set_transitions(\n", - " [\n", - " (\n", - " 1 / len(other_pos),\n", - " mdp.new_state(\"goatrevealed\", s.features | {\"reveal_pos\": i}),\n", - " )\n", - " for i in other_pos\n", - " ]\n", - " )\n", - "\n", - "# we must choose whether we want to switch\n", - "for s in mdp.get_states_with(\"goatrevealed\"):\n", - " car_pos = s.features[\"car_pos\"]\n", - " chosen_pos = s.features[\"chosen_pos\"]\n", - " reveal_pos = s.features[\"reveal_pos\"]\n", - " other_pos = list({0, 1, 2} - {reveal_pos, chosen_pos})[0]\n", - " s.set_transitions(\n", - " [\n", - " (\n", - " mdp.action(\"stay\"),\n", - " mdp.new_state(\n", - " [\"done\"] + ([\"target\"] if chosen_pos == car_pos else []),\n", - " s.features | {\"chosen_pos\": chosen_pos},\n", - " ),\n", - " ),\n", - " (\n", - " mdp.action(\"switch\"),\n", - " mdp.new_state(\n", - " [\"done\"] + ([\"target\"] if other_pos == car_pos else []),\n", - " s.features | {\"chosen_pos\": other_pos},\n", - " ),\n", - " ),\n", - " ]\n", - " )\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "02b68164-00de-41b1-ba2b-65e12bef9032", - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "monty.html\n" - ] + "cell_type": "code", + "execution_count": null, + "id": "bce64a65-ea2a-42b6-adac-0533b7541ca0", + "metadata": {}, + "outputs": [], + "source": [ + "import stormvogel.model\n", + "\n", + "mdp = stormvogel.model.new_mdp(\"Monty Hall\")\n", + "\n", + "init = mdp.get_initial_state()\n", + "\n", + "# first choose car position\n", + "init.set_transitions(\n", + " [(1 / 3, mdp.new_state(\"carchosen\", {\"car_pos\": i})) for i in range(3)]\n", + ")\n", + "\n", + "# we choose a door in each case\n", + "for s in mdp.get_states_with(\"carchosen\"):\n", + " s.set_transitions(\n", + " [\n", + " (\n", + " mdp.action(f\"open{i}\"),\n", + " mdp.new_state(\"open\", s.features | {\"chosen_pos\": i}),\n", + " )\n", + " for i in range(3)\n", + " ]\n", + " )\n", + "\n", + "# the other goat is revealed\n", + "for s in mdp.get_states_with(\"open\"):\n", + " car_pos = s.features[\"car_pos\"]\n", + " chosen_pos = s.features[\"chosen_pos\"]\n", + " other_pos = {0, 1, 2} - {car_pos, chosen_pos}\n", + " s.set_transitions(\n", + " [\n", + " (\n", + " 1 / len(other_pos),\n", + " mdp.new_state(\"goatrevealed\", s.features | {\"reveal_pos\": i}),\n", + " )\n", + " for i in other_pos\n", + " ]\n", + " )\n", + "\n", + "# we must choose whether we want to switch\n", + "for s in mdp.get_states_with(\"goatrevealed\"):\n", + " car_pos = s.features[\"car_pos\"]\n", + " chosen_pos = s.features[\"chosen_pos\"]\n", + " reveal_pos = s.features[\"reveal_pos\"]\n", + " other_pos = list({0, 1, 2} - {reveal_pos, chosen_pos})[0]\n", + " s.set_transitions(\n", + " [\n", + " (\n", + " mdp.action(\"stay\"),\n", + " mdp.new_state(\n", + " [\"done\"] + ([\"target\"] if chosen_pos == car_pos else []),\n", + " s.features | {\"chosen_pos\": chosen_pos},\n", + " ),\n", + " ),\n", + " (\n", + " mdp.action(\"switch\"),\n", + " mdp.new_state(\n", + " [\"done\"] + ([\"target\"] if other_pos == car_pos else []),\n", + " s.features | {\"chosen_pos\": other_pos},\n", + " ),\n", + " ),\n", + " ]\n", + " )\n" + ] }, { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" + "cell_type": "code", + "execution_count": null, + "id": "02b68164-00de-41b1-ba2b-65e12bef9032", + "metadata": {}, + "outputs": [], + "source": [ + "from stormvogel.visualization import show\n", + "\n", + "show(mdp, \"monty\")" ] - }, - "metadata": {}, - "output_type": "display_data" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ec68bb5-2728-458e-8d11-6015f80d79f7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" } - ], - "source": [ - "from stormvogel.visualization import show\n", - "\n", - "show(mdp, \"monty\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3ec68bb5-2728-458e-8d11-6015f80d79f7", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/stormvogel/visualization.py b/stormvogel/visualization.py index 180986f..8b8b14c 100644 --- a/stormvogel/visualization.py +++ b/stormvogel/visualization.py @@ -10,10 +10,12 @@ class Visualization: """Handles visualization of a Model using a pyvis Network.""" + ACTION_ID_OFFSET = 20**6 + def __init__( self, model: Model, - name: str, + name: str = "model", notebook: bool = True, cdn_resources: str = "remote", ) -> None: @@ -21,7 +23,7 @@ def __init__( Args: model (Model): The stormvogel model to be displayed. - name (str): The name of the resulting html file. + name (str): The name of the resulting html file. May or may not include .html extension. notebook (bool, optional): Leave to true if you are using in a notebook. Defaults to True. """ self.model = model @@ -43,9 +45,6 @@ def __set_layout(self): "background": "white", "border": "black" } - }, - "edge": { - "color": "blue" } }""") @@ -60,6 +59,7 @@ def __add_nodes(self): label=",".join(state.labels), color=None, # type: ignore borderWidth=borderWidth, + shape="circle", ) def __formatted_prob(self, prob: Number) -> str: @@ -67,6 +67,7 @@ def __formatted_prob(self, prob: Number) -> str: return str(Fraction(prob).limit_denominator(20)) def __add_transitions(self): + no_actions = 0 """For each transition in the model, add a transition in the graph.""" for state_id, transition in self.model.transitions.items(): for action, branch in transition.transition.items(): @@ -80,9 +81,16 @@ def __add_transitions(self): label=self.__formatted_prob(prob), ) else: - raise NotImplementedError( - "Non-empty actions are not supported yet." + self.g.add_node( + n_id=self.ACTION_ID_OFFSET + no_actions, + label=action.name, + shape="box", ) + print(action.name) + no_actions += 1 + # raise NotImplementedError( + # "Non-empty actions are not supported yet." + # ) def show(self): """Show the constructed model"""