From 6acd405b13f565ae1474ba68868edc0df0609f88 Mon Sep 17 00:00:00 2001 From: YouGuessedMyName Date: Tue, 2 Jul 2024 15:19:43 +0200 Subject: [PATCH] Json layouts working --- notebooks/custom_layout.json | 12 +- notebooks/die.ipynb | 47 ++--- notebooks/diegraph.html | 321 -------------------------------- notebooks/groups.html | 167 ----------------- notebooks/layouts.ipynb | 39 +++- notebooks/model.html | 321 -------------------------------- notebooks/monty.html | 321 -------------------------------- notebooks/monty_hall.ipynb | 2 +- notebooks/study.ipynb | 71 +++++-- stormvogel/layout.py | 90 ++++++--- stormvogel/layouts/default.json | 20 +- stormvogel/visualization.py | 95 ++++++---- 12 files changed, 263 insertions(+), 1243 deletions(-) delete mode 100644 notebooks/diegraph.html delete mode 100644 notebooks/groups.html delete mode 100644 notebooks/model.html delete mode 100644 notebooks/monty.html diff --git a/notebooks/custom_layout.json b/notebooks/custom_layout.json index e2922bb..d29c056 100644 --- a/notebooks/custom_layout.json +++ b/notebooks/custom_layout.json @@ -1,3 +1,13 @@ { - "color": "red" + "nodes": { + "color": { + "background": "red", + "border": "orange" + } + }, + "init": {}, + "rounding": { + "fractions": "False", + "digits": 2 + } } diff --git a/notebooks/die.ipynb b/notebooks/die.ipynb index caca1dc..d138da2 100644 --- a/notebooks/die.ipynb +++ b/notebooks/die.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -12,36 +12,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, - "outputs": [], - "source": [ - "# Create a new model with the name \"Die\"\n", - "dtmc = stormvogel.model.new_dtmc(\"Die\")\n", - "\n", - "init = dtmc.get_initial_state()\n", - "\n", - "# From the initial state, add the transition to 6 new states with probability 1/6th.\n", - "init.set_transitions(\n", - " [(1 / 6, dtmc.new_state(f\"rolled{i}\", {\"rolled\": i})) for i in range(6)]\n", - ")\n", - "\n", - "# Print the resulting model in dot format.\n", - "# print(dtmc.to_dot())" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "scrolled": true - }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "diegraph.html\n" + "model.html\n" ] }, { @@ -51,7 +29,7 @@ " " + "" ] }, "metadata": {}, @@ -67,7 +45,20 @@ } ], "source": [ - "stormvogel.visualization.show(model=dtmc, name=\"diegraph\")" + "# Create a new model with the name \"Die\"\n", + "dtmc = stormvogel.model.new_dtmc(\"Die\")\n", + "\n", + "init = dtmc.get_initial_state()\n", + "\n", + "# From the initial state, add the transition to 6 new states with probability 1/6th.\n", + "init.set_transitions(\n", + " [(1 / 6, dtmc.new_state(f\"rolled{i}\", {\"rolled\": i})) for i in range(6)]\n", + ")\n", + "\n", + "stormvogel.visualization.show(dtmc)\n", + "\n", + "# Print the resulting model in dot format.\n", + "# print(dtmc.to_dot())" ] }, { diff --git a/notebooks/diegraph.html b/notebooks/diegraph.html deleted file mode 100644 index 509d77e..0000000 --- a/notebooks/diegraph.html +++ /dev/null @@ -1,321 +0,0 @@ - - - - - - - - - - - - - - - -
-

-
- - - - - - -
-

-
- - - - - -
- - -
-
- - - - - - - diff --git a/notebooks/groups.html b/notebooks/groups.html deleted file mode 100644 index ca55f3d..0000000 --- a/notebooks/groups.html +++ /dev/null @@ -1,167 +0,0 @@ - - - - - - - - - -
-

-
- - - - - - -
-

-
- - - - - -
- - -
-
- - - -
- - - - - diff --git a/notebooks/layouts.ipynb b/notebooks/layouts.ipynb index 2bff512..637b767 100644 --- a/notebooks/layouts.ipynb +++ b/notebooks/layouts.ipynb @@ -5,7 +5,15 @@ "execution_count": 1, "id": "e0d2c1f7-3a10-4f30-a4d1-2d446de06487", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'nodes': {'color': {'background': 'white', 'border': 'black'}}, 'init': {'borderWidth': 3, 'color': 'green', 'shape': 'circle'}, 'states': {'borderWidth': 1, 'color': 'None', 'shape': 'circle'}}\n" + ] + } + ], "source": [ "import stormvogel.layout\n", "from stormvogel.layout import DEFAULT" @@ -13,15 +21,40 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "8c98f637-ed9c-430b-b414-09c665b307f8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'nodes': {'color': {'background': 'red', 'border': 'orange'}}, 'init': {'borderWidth': 3, 'color': 'green', 'shape': 'circle'}, 'states': {'borderWidth': 1, 'color': 'None', 'shape': 'circle'}}\n", + "{'nodes': {'color': {'background': 'red', 'border': 'orange'}}, 'init': {'borderWidth': 3, 'color': 'green', 'shape': 'circle'}, 'states': {'borderWidth': 1, 'color': 'None', 'shape': 'circle'}}\n" + ] + } + ], "source": [ "# Import custom layout\n", "l1 = stormvogel.layout.Layout(\"custom_layout.json\")\n", "l2 = stormvogel.layout.Layout(path=\"/home/ivo/git/stormvogel/notebooks/custom_layout.json\", path_relative=False)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c0b4979-fd00-4393-bf8a-def6d8361ec8", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c44a1ec0-1d94-431a-b62f-a32df6d35884", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/notebooks/model.html b/notebooks/model.html deleted file mode 100644 index e52367d..0000000 --- a/notebooks/model.html +++ /dev/null @@ -1,321 +0,0 @@ - - - - - - - - - - - - - - - -
-

-
- - - - - - -
-

-
- - - - - -
- - -
-
- - - - - - - diff --git a/notebooks/monty.html b/notebooks/monty.html deleted file mode 100644 index e502921..0000000 --- a/notebooks/monty.html +++ /dev/null @@ -1,321 +0,0 @@ - - - - - - - - - - - - - - - -
-

-
- - - - - - -
-

-
- - - - - -
- - -
-
- - - - - - - diff --git a/notebooks/monty_hall.ipynb b/notebooks/monty_hall.ipynb index e53ad94..e85046b 100644 --- a/notebooks/monty_hall.ipynb +++ b/notebooks/monty_hall.ipynb @@ -99,7 +99,7 @@ " " ], "text/plain": [ - "" + "" ] }, "metadata": {}, diff --git a/notebooks/study.ipynb b/notebooks/study.ipynb index b02edeb..992eded 100644 --- a/notebooks/study.ipynb +++ b/notebooks/study.ipynb @@ -7,18 +7,60 @@ "metadata": {}, "outputs": [ { - "ename": "AttributeError", - "evalue": "'Visualization' object has no attribute 'layout'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[1], line 33\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;66;03m# If you did not study, then there is only a 40% chance that you pass the test.\u001b[39;00m\n\u001b[1;32m 28\u001b[0m not_studied\u001b[38;5;241m.\u001b[39mset_transitions([\n\u001b[1;32m 29\u001b[0m (\u001b[38;5;241m4\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m10\u001b[39m, pass_test),\n\u001b[1;32m 30\u001b[0m (\u001b[38;5;241m6\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m10\u001b[39m, fail_test)\n\u001b[1;32m 31\u001b[0m ])\n\u001b[0;32m---> 33\u001b[0m \u001b[43mstormvogel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvisualization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshow\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmdp\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/git/stormvogel/.venv/lib/python3.11/site-packages/stormvogel/visualization.py:119\u001b[0m, in \u001b[0;36mshow\u001b[0;34m(model, name, notebook)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mshow\u001b[39m(model: Model, name: \u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m\"\u001b[39m, notebook: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[1;32m 112\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Create and show a visualization of a Model using a pyvis Network\u001b[39;00m\n\u001b[1;32m 113\u001b[0m \n\u001b[1;32m 114\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[38;5;124;03m notebook (bool, optional): Leave to true if you are using in a notebook. Defaults to True.\u001b[39;00m\n\u001b[1;32m 118\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 119\u001b[0m vis \u001b[38;5;241m=\u001b[39m \u001b[43mVisualization\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnotebook\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 120\u001b[0m vis\u001b[38;5;241m.\u001b[39mshow()\n", - "File \u001b[0;32m~/git/stormvogel/.venv/lib/python3.11/site-packages/stormvogel/visualization.py:43\u001b[0m, in \u001b[0;36mVisualization.__init__\u001b[0;34m(self, model, name, notebook, cdn_resources, layout)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname \u001b[38;5;241m=\u001b[39m name\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mg \u001b[38;5;241m=\u001b[39m Network(notebook\u001b[38;5;241m=\u001b[39mnotebook, directed\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, cdn_resources\u001b[38;5;241m=\u001b[39mcdn_resources)\n\u001b[0;32m---> 43\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__add_states\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 44\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__add_transitions()\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m layout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", - "File \u001b[0;32m~/git/stormvogel/.venv/lib/python3.11/site-packages/stormvogel/visualization.py:57\u001b[0m, in \u001b[0;36mVisualization.__add_states\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 55\u001b[0m borderWidth \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m state \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mget_initial_state():\n\u001b[0;32m---> 57\u001b[0m borderWidth \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayout\u001b[49m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minit\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mborderWidth\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mg\u001b[38;5;241m.\u001b[39madd_node(\n\u001b[1;32m 59\u001b[0m state\u001b[38;5;241m.\u001b[39mid,\n\u001b[1;32m 60\u001b[0m label\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m,\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(state\u001b[38;5;241m.\u001b[39mlabels),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 63\u001b[0m shape\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdot\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 64\u001b[0m )\n", - "\u001b[0;31mAttributeError\u001b[0m: 'Visualization' object has no attribute 'layout'" + "name": "stdout", + "output_type": "stream", + "text": [ + "study default layout.html\n" ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "study custom layout.html\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -53,8 +95,11 @@ " (4/10, pass_test),\n", " (6/10, fail_test)\n", "])\n", - "\n", - "stormvogel.visualization.show(mdp)" + "# We use a custom user-defined layout. The parts that it does not define will use the default.\n", + "# This particular custom layout makes states red and uses rounding instead of fractions.\n", + "layout = stormvogel.visualization.Layout(\"custom_layout.json\")\n", + "stormvogel.visualization.show(mdp, name=\"study default layout\")\n", + "stormvogel.visualization.show(mdp, name=\"study custom layout\", layout=layout)" ] }, { diff --git a/stormvogel/layout.py b/stormvogel/layout.py index 2b16df4..bc92bc8 100644 --- a/stormvogel/layout.py +++ b/stormvogel/layout.py @@ -1,42 +1,43 @@ """Contains the code responsible for saving/loading layouts and modifying them interactively.""" +from functools import reduce +from typing import Any from pyvis.network import Network import os import json +PACKAGE_ROOT_DIR = os.path.dirname(os.path.realpath(__file__)) + class Layout: layout: dict[str, str] def __init__(self, path: str, path_relative: bool = True) -> None: - """Load a new Layout from a json file. Use either a custom or a template file. + """Load a new Layout from a json file. + Whenever keys are not present in the provided json file, their default values are used instead + as specified in DEFAULTS (=layouts/default.json). Args: - path (str, optional): Relavant if custom is true. Path to your custom layout file, relative to the current working directory. Defaults to None. - path_relative (bool): Relavant if custom is true. If set to true, then stormvogel will look for a custom layout file relative to the current working directory. + path (str): Path to your custom layout file, relative to the current working directory. + path_relative (bool, optional): If set to true, then stormvogel will look for a custom layout + file relative to the current working directory. Defaults to True. """ if path_relative: complete_path = os.path.join(os.getcwd(), path) else: complete_path = path with open(complete_path) as f: - json_string = f.read() - self.layout = json.loads(json_string) + parsed_str = f.read() + parsed_dict = json.loads(parsed_str) + with open(os.path.join(PACKAGE_ROOT_DIR, "layouts/default.json")) as f: + default_str = f.read() + default_dict = json.loads(default_str) + # Combine the parsed dict with default to fill missing keys as default values. + self.layout = Layout.merge_dict(default_dict, parsed_dict) def set_nt_layout(self, nt: Network) -> None: """Set the layout of the network passed as the arugment.""" - # We here use <> instead of {} because the f-string formatting already uses them. - # option_string = f""" - # var options = < - # "nodes": < - # "color": < - # "background": "{self.layout["color"]}", - # "border": "black" - # > - # > - # >""".replace("<", "{").replace(">", "}") options = "var options = " + str(self.layout).replace("'", '"') - print(options) nt.set_options(options) def save(self) -> None: @@ -45,19 +46,58 @@ def save(self) -> None: def show_buttons(self) -> None: raise NotImplementedError() - def __str__(self) -> str: - raise NotImplementedError() + @staticmethod + def merge_dict(dict1: dict, dict2: dict): + """Merge two nested dictionaries recursively. + Args: + dict1 (dict): + dict2 (dict): + + If dict2 has a value that dict1 does not have, then the value in dict2 is chosen. + If dict1 has a DICTIONARY and dict2 has a VALUE with the same key, then dict1 gets priority. - def __getitem__(self, key: str) -> str | None: - try: - return self.layout[key] - except KeyError: - return None + Taken from StackOverflow user Anatoliy R on July 2 2024. + https://stackoverflow.com/questions/43797333/how-to-merge-two-nested-dict-in-python""" + for key, val in dict1.items(): + if isinstance(val, dict): + if key in dict2 and type(dict2[key] == dict): + Layout.merge_dict(dict1[key], dict2[key]) + else: + if key in dict2: + dict1[key] = dict2[key] + for key, val in dict2.items(): + if key not in dict1: + dict1[key] = val + + return dict1 + + def rget(self, *keys) -> Any: + """Recursively get an entry from the layout. + If a key is not present, KeyError will be thrown. + This should never happen to users because the default values will be used in the case of missing entries.""" + res = reduce( + lambda c, k: c.__getitem__(k), list(keys), self.layout + ) # Throws KeyError if key not present. + if isinstance(res, str): + if res == "None": + return None + elif res == "True": + return True + elif res == "False": + return False + elif res.isdigit(): + return int(res) + else: + return res + else: + return res + + def __str__(self) -> str: + return str(self.layout) -package_root_dir = os.path.dirname(os.path.realpath(__file__)) # Define template layouts. DEFAULT = Layout( - os.path.join(package_root_dir, "layouts/default.json"), path_relative=False + os.path.join(PACKAGE_ROOT_DIR, "layouts/default.json"), path_relative=False ) diff --git a/stormvogel/layouts/default.json b/stormvogel/layouts/default.json index be4715f..b4aec44 100644 --- a/stormvogel/layouts/default.json +++ b/stormvogel/layouts/default.json @@ -6,7 +6,23 @@ } }, "init": { - "borderWidth": 10, - "color": "green" + "borderWidth": 3, + "color": "lightgreen", + "shape": "circle" + }, + "states": { + "borderWidth": 1, + "color": "None", + "shape": "drop" + }, + "actions": { + "borderWidth": 1, + "color": "lightblue", + "shape": "box" + }, + "rounding": { + "fractions": "True", + "digits": 5, + "max_denominator": 20 } } diff --git a/stormvogel/visualization.py b/stormvogel/visualization.py index 152eb98..31a56ad 100644 --- a/stormvogel/visualization.py +++ b/stormvogel/visualization.py @@ -3,7 +3,6 @@ from pyvis.network import Network from stormvogel.model import Model, EmptyAction, Number from stormvogel.layout import Layout, DEFAULT -from ipywidgets import interact from IPython.display import display from fractions import Fraction @@ -12,9 +11,10 @@ class Visualization: """Handles visualization of a Model using a pyvis Network.""" name: str - g: Network - ACTION_ID_OFFSET = 10**8 + nt: Network layout: Layout + + ACTION_ID_OFFSET: int = 10**10 # In the visualization, both actions and states are nodes with an id. # This offset is used to keep their ids from colliding. It should be some high constant. @@ -24,7 +24,7 @@ def __init__( name: str = "model", notebook: bool = True, cdn_resources: str = "remote", - layout: Layout | None = None, + layout: Layout = DEFAULT, ) -> None: """Create visualization of a Model using a pyvis Network @@ -33,43 +33,55 @@ def __init__( name (str, optional): The name of the resulting html file. May or may not include .html extension. notebook (bool, optional): Leave to true if you are using in a notebook. Defaults to True. """ - if layout is None: - self.layout = DEFAULT - else: - self.layout = layout self.model = model + self.layout = layout if ( name[-5:] != ".html" ): # We do not require the user to explicitly type .html in their names name += ".html" self.name = name - self.g = Network(notebook=notebook, directed=True, cdn_resources=cdn_resources) + self.nt = Network(notebook=notebook, directed=True, cdn_resources=cdn_resources) self.__add_states() self.__add_transitions() - - self.layout.set_nt_layout(self.g) + self.layout.set_nt_layout(self.nt) def __add_states(self): """For each state in the model, add a node to the graph.""" for state in self.model.states.values(): - borderWidth = 1 if state == self.model.get_initial_state(): - borderWidth = self.layout["init"]["borderWidth"] # type: ignore - self.g.add_node( - state.id, - label=",".join(state.labels), - color=None, # type: ignore - borderWidth=borderWidth, - shape="dot", - ) + self.nt.add_node( + state.id, + label=",".join(state.labels), + color=self.layout.rget("init", "color"), + borderWidth=self.layout.rget("init", "borderWidth"), + shape=self.layout.rget("init", "shape"), + ) + + else: + self.nt.add_node( + state.id, + label=",".join(state.labels), + color=self.layout.rget("states", "color"), + borderWidth=self.layout.rget("states", "borderWidth"), + shape=self.layout.rget("states", "shape"), + ) def __formatted_probability(self, prob: Number) -> str: - """Take a probability value and format it nicely using a fraction.""" - return str(Fraction(prob).limit_denominator(20)) + """Take a probability value and format it nicely using a fraction or rounding it. + Which one of these to pick is specified in the layout.""" + if self.layout.rget("rounding", "fractions"): + return str( + Fraction(prob).limit_denominator( + self.layout.rget("rounding", "max_denominator") + ) + ) + else: + return str(round(float(prob), self.layout.rget("rounding", "digits"))) def __add_transitions(self): """For each transition in the model, add a transition in the graph. - Also handles actions by calling __add_action""" + Also handles creating nodes for actions and their respective transitions. + Note that an action may appear multiple times in the model with a different state as source.""" action_id = self.ACTION_ID_OFFSET # In the visualization, both actions and states are nodes, so we need to keep track of how many actions we already have. for state_id, transition in self.model.transitions.items(): @@ -77,38 +89,45 @@ def __add_transitions(self): if action == EmptyAction: # Only draw probabilities for prob, target in branch.branch: - self.g.add_edge( + self.nt.add_edge( state_id, target.id, - color="red", + color=None, # type: ignore label=self.__formatted_probability(prob), ) else: # Add the action's node - self.g.add_node( + self.nt.add_node( n_id=action_id, - color=None, # type: ignore label=action.name, - shape="box", + color=self.layout.rget("actions", "color"), # type: ignore + borderWidth=self.layout.rget("actions", "borderWidth"), + shape=self.layout.rget("actions", "shape"), ) # Add transition from this state TO the action. - self.g.add_edge(state_id, action_id, color="red") # type: ignore - # Add transition FROM the action to the values in its branch. + self.nt.add_edge(state_id, action_id, color=None) # type: ignore + # Add transition FROM the action to the states in its branch. for prob, target in branch.branch: - self.g.add_edge( + self.nt.add_edge( action_id, target.id, - color="red", + color=None, # type: ignore label=self.__formatted_probability(prob), ) action_id += 1 def show(self): - """Show the constructed graph.""" - display(self.g.show(name=self.name)) + """Show the constructed graph as a html file.""" + display(self.nt.show(name=self.name)) -def show(model: Model, name: str = "model", notebook: bool = True): +def show( + model: Model, + name: str = "model", + notebook: bool = True, + cdn_resources: str = "remote", + layout: Layout = DEFAULT, +): """Create and show a visualization of a Model using a pyvis Network Args: @@ -116,9 +135,5 @@ def show(model: Model, name: str = "model", notebook: bool = True): name (str, optional): The name of the resulting html file. notebook (bool, optional): Leave to true if you are using in a notebook. Defaults to True. """ - vis = Visualization(model, name, notebook) + vis = Visualization(model, name, notebook, cdn_resources, layout) vis.show() - - -def make_slider(): - return interact(lambda x: x, x=10)