diff --git a/docs/getting_started/layouts/grid.json b/docs/getting_started/layouts/grid.json index d8aeca8..13d39ba 100644 --- a/docs/getting_started/layouts/grid.json +++ b/docs/getting_started/layouts/grid.json @@ -126,6 +126,7 @@ "result_symbol": "\u2606", "show_rewards": true, "reward_symbol": "\u20ac", + "show_zero_rewards": false, "show_observations": true, "observation_symbol": "\u0298" }, @@ -301,4 +302,4 @@ "width": 800, "height": 600, "physics": true -} +} \ No newline at end of file diff --git a/docs/getting_started/layouts/monty.json b/docs/getting_started/layouts/monty.json new file mode 100644 index 0000000..69c933a --- /dev/null +++ b/docs/getting_started/layouts/monty.json @@ -0,0 +1,447 @@ +{ + "__fake_macros": { + "__group_macro": { + "borderWidth": 1, + "color": { + "background": "white", + "border": "black", + "highlight": { + "background": "white", + "border": "red" + } + }, + "shape": "ellipse", + "mass": 1, + "font": { + "color": "black", + "size": 14 + } + } + }, + "groups": { + "states": { + "borderWidth": 1, + "color": { + "background": "white", + "border": "black", + "highlight": { + "background": "white", + "border": "red" + } + }, + "shape": "ellipse", + "mass": 1, + "font": { + "color": "black", + "size": 33 + } + }, + "actions": { + "borderWidth": 1, + "color": { + "background": "lightblue", + "border": "black", + "highlight": { + "background": "white", + "border": "red" + } + }, + "shape": "box", + "mass": 1, + "font": { + "color": "black", + "size": 36 + } + }, + "scheduled_actions": { + "borderWidth": 1, + "color": { + "background": "pink", + "border": "black", + "highlight": { + "background": "white", + "border": "red" + } + }, + "shape": "box", + "mass": 1, + "font": { + "color": "black", + "size": 36 + }, + "schedColor": false + } + }, + "reload_button": false, + "edges": { + "arrows": "to", + "font": { + "color": "black", + "size": 14 + }, + "color": { + "color": "black" + } + }, + "numbers": { + "fractions": true, + "digits": 5 + }, + "state_properties": { + "show_results": true, + "result_symbol": "\u2606", + "show_rewards": true, + "reward_symbol": "\u20ac", + "show_zero_rewards": true, + "show_observations": true, + "observation_symbol": "\u0298" + }, + "layout": { + "randomSeed": 5 + }, + "misc": { + "enable_physics": true, + "width": 929, + "height": 745, + "explore": false + }, + "saving": { + "relative_path": true, + "filename": "layouts/monty.json", + "save_button": false, + "load_button": false + }, + "positions": { + "0": { + "x": 13, + "y": -7 + }, + "1": { + "x": -162, + "y": 44 + }, + "2": { + "x": 153, + "y": 116 + }, + "3": { + "x": 62, + "y": -189 + }, + "4": { + "x": -500, + "y": 175 + }, + "5": { + "x": -416, + "y": -150 + }, + "6": { + "x": -176, + "y": 322 + }, + "7": { + "x": 347, + "y": 369 + }, + "8": { + "x": 511, + "y": 88 + }, + "9": { + "x": 77, + "y": 413 + }, + "10": { + "x": 342, + "y": -310 + }, + "11": { + "x": 173, + "y": -510 + }, + "12": { + "x": -120, + "y": -474 + }, + "13": { + "x": -587, + "y": 325 + }, + "14": { + "x": -653, + "y": 99 + }, + "15": { + "x": -547, + "y": -255 + }, + "16": { + "x": -234, + "y": 463 + }, + "17": { + "x": 446, + "y": 496 + }, + "18": { + "x": 592, + "y": -52 + }, + "19": { + "x": 673, + "y": 159 + }, + "20": { + "x": 83, + "y": 576 + }, + "21": { + "x": 476, + "y": -407 + }, + "22": { + "x": 228, + "y": -670 + }, + "23": { + "x": -265, + "y": -433 + }, + "24": { + "x": -146, + "y": -650 + }, + "25": { + "x": -647, + "y": 606 + }, + "26": { + "x": -859, + "y": 431 + }, + "27": { + 
"x": -941, + "y": 146 + }, + "28": { + "x": -856, + "y": -91 + }, + "29": { + "x": -669, + "y": -514 + }, + "30": { + "x": -826, + "y": -327 + }, + "31": { + "x": -432, + "y": 648 + }, + "32": { + "x": -242, + "y": 743 + }, + "33": { + "x": 709, + "y": 595 + }, + "34": { + "x": 473, + "y": 779 + }, + "35": { + "x": 842, + "y": -208 + }, + "36": { + "x": 345, + "y": -100 + }, + "37": { + "x": 875, + "y": 357 + }, + "38": { + "x": 949, + "y": 59 + }, + "39": { + "x": 214, + "y": 826 + }, + "40": { + "x": -38, + "y": 836 + }, + "41": { + "x": 602, + "y": -658 + }, + "42": { + "x": 755, + "y": -462 + }, + "43": { + "x": 184, + "y": -954 + }, + "44": { + "x": 446, + "y": -860 + }, + "45": { + "x": -474, + "y": -630 + }, + "46": { + "x": -169, + "y": -181 + }, + "47": { + "x": -341, + "y": -870 + }, + "48": { + "x": -73, + "y": -922 + }, + "10000000000": { + "x": -329, + "y": 115 + }, + "10000000001": { + "x": -282, + "y": -60 + }, + "10000000002": { + "x": -134, + "y": 184 + }, + "10000000003": { + "x": 250, + "y": 245 + }, + "10000000004": { + "x": 328, + "y": 92 + }, + "10000000005": { + "x": 83, + "y": 255 + }, + "10000000006": { + "x": 213, + "y": -217 + }, + "10000000007": { + "x": 120, + "y": -352 + }, + "10000000008": { + "x": -39, + "y": -321 + }, + "10000000009": { + "x": -597, + "y": 477 + }, + "10000000010": { + "x": -735, + "y": 371 + }, + "10000000011": { + "x": -803, + "y": 137 + }, + "10000000012": { + "x": -738, + "y": -22 + }, + "10000000013": { + "x": -592, + "y": -399 + }, + "10000000014": { + "x": -698, + "y": -273 + }, + "10000000015": { + "x": -359, + "y": 533 + }, + "10000000016": { + "x": -211, + "y": 609 + }, + "10000000017": { + "x": 590, + "y": 528 + }, + "10000000018": { + "x": 443, + "y": 645 + }, + "10000000019": { + "x": 725, + "y": -133 + }, + "10000000020": { + "x": 473, + "y": -119 + }, + "10000000021": { + "x": 762, + "y": 280 + }, + "10000000022": { + "x": 817, + "y": 102 + }, + "10000000023": { + "x": 174, + "y": 695 + }, + "10000000024": { + "x": 3, + "y": 704 + }, + "10000000025": { + "x": 519, + "y": -549 + }, + "10000000026": { + "x": 627, + "y": -413 + }, + "10000000027": { + "x": 187, + "y": -816 + }, + "10000000028": { + "x": 354, + "y": -757 + }, + "10000000029": { + "x": -381, + "y": -529 + }, + "10000000030": { + "x": -242, + "y": -292 + }, + "10000000031": { + "x": -259, + "y": -758 + }, + "10000000032": { + "x": -84, + "y": -786 + } + }, + "width": 929, + "height": 745, + "physics": true +} \ No newline at end of file diff --git a/docs/getting_started/layouts/small_monty.json b/docs/getting_started/layouts/small_monty.json new file mode 100644 index 0000000..f46ec23 --- /dev/null +++ b/docs/getting_started/layouts/small_monty.json @@ -0,0 +1,139 @@ +{ + "__fake_macros": { + "__group_macro": { + "borderWidth": 1, + "color": { + "background": "white", + "border": "black", + "highlight": { + "background": "white", + "border": "red" + } + }, + "shape": "ellipse", + "mass": 1, + "font": { + "color": "black", + "size": 14 + } + } + }, + "groups": { + "states": { + "borderWidth": 1, + "color": { + "background": "white", + "border": "black", + "highlight": { + "background": "white", + "border": "red" + } + }, + "shape": "ellipse", + "mass": 1, + "font": { + "color": "black", + "size": 14 + } + }, + "actions": { + "borderWidth": 1, + "color": { + "background": "lightblue", + "border": "black", + "highlight": { + "background": "white", + "border": "red" + } + }, + "shape": "box", + "mass": 1, + "font": { + "color": "black", + "size": 14 + } + }, + 
"scheduled_actions": { + "borderWidth": 1, + "color": { + "background": "pink", + "border": "black", + "highlight": { + "background": "white", + "border": "red" + } + }, + "shape": "box", + "mass": 1, + "font": { + "color": "black", + "size": 14 + }, + "schedColor": false + } + }, + "reload_button": false, + "edges": { + "arrows": "to", + "font": { + "color": "black", + "size": 14 + }, + "color": { + "color": "black" + } + }, + "numbers": { + "fractions": true, + "digits": 5 + }, + "state_properties": { + "show_results": true, + "result_symbol": "\u2606", + "show_rewards": true, + "reward_symbol": "\u20ac", + "show_zero_rewards": true, + "show_observations": true, + "observation_symbol": "\u0298" + }, + "layout": { + "randomSeed": 5 + }, + "misc": { + "enable_physics": true, + "width": 575, + "height": 148, + "explore": false + }, + "saving": { + "relative_path": true, + "filename": "layouts/small_monty.json", + "save_button": false, + "load_button": false + }, + "positions": { + "0": { + "x": -203, + "y": -60 + }, + "1": { + "x": -99, + "y": 2 + }, + "2": { + "x": 0, + "y": 0 + }, + "3": { + "x": 94, + "y": -13 + }, + "4": { + "x": 204, + "y": 38 + } + }, + "width": 575, + "height": 148, + "physics": true +} \ No newline at end of file diff --git a/docs/getting_started/model.html b/docs/getting_started/model.html index 0ee85f4..978cc9f 100644 --- a/docs/getting_started/model.html +++ b/docs/getting_started/model.html @@ -1,8 +1,8 @@ + > \ No newline at end of file diff --git a/docs/getting_started/pomdp-maze.ipynb b/docs/getting_started/pomdp-maze.ipynb index 77c0cba..c543065 100644 --- a/docs/getting_started/pomdp-maze.ipynb +++ b/docs/getting_started/pomdp-maze.ipynb @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "id": "2f4d4394-8a2a-424e-b16e-8ed202bdf493", "metadata": {}, "outputs": [], @@ -132,13 +132,13 @@ " if not observation == OUT_OF_BOUNDS:\n", " took_dir = pomdp.new_state([d, f\"({x},{y})\"])\n", " grid[y][x].add_transitions([(action, took_dir)])\n", - " reward_model.set_state_action_reward(grid[y][x], action, -1)\n", + " #reward_model.set_state_action_reward(grid[y][x], action, -1)\n", " # print(took_dir)\n", " took_dir.add_transitions([(1, grid[res_y][res_x])])\n", - " #reward_model.set_state_action_reward(took_dir, EmptyAction, -1)\n", + " reward_model.set_state_action_reward(took_dir, EmptyAction, -1)\n", " took_dir.set_observation(observation)\n", " pomdp.add_self_loops()\n", - " #reward_model.set_unset_rewards(0)\n", + " reward_model.set_unset_rewards(0)\n", " return pomdp, positions" ] }, @@ -152,7 +152,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "id": "ebdc5973-ff3f-4056-9593-511837966101", "metadata": {}, "outputs": [ @@ -188,125 +188,28 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "id": "d0400922-7b90-43c5-a01e-cc051089521d", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "{1: -1, 2: -1, 3: -1, 4: -1, 6: -1, 7: -1, 8: -1, 9: -1, 10: -1, 11: -1, 12: -1, 13: -1, 14: -1, 15: -1, 16: -1, 17: -1}\n", - "format rewards. ['escaped'] empty 0\n", - "0\n", - "format rewards. ['t', '(1,1)'] empty None\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['e', '(1,2)'] empty 5\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['t', '(2,1)'] empty None\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['t', '(3,1)'] empty None\n", - "result: \n", - "€\t: -1\n", - "format rewards. 
['t', '(3,2)'] empty None\n", - "5\n", - "format rewards. ['↑', '(1,1)'] empty 18\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['↓', '(1,1)'] empty 19\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['←', '(1,1)'] empty 20\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['→', '(1,1)'] empty 21\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['↑', '(2,1)'] empty 22\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['↓', '(2,1)'] empty 23\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['←', '(2,1)'] empty 24\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['→', '(2,1)'] empty 25\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['↑', '(3,1)'] empty 26\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['↓', '(3,1)'] empty 27\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['←', '(3,1)'] empty 28\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['→', '(3,1)'] empty 29\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['↑', '(3,2)'] empty 30\n", - "18\n", - "format rewards. ['↓', '(3,2)'] empty 31\n", - "19\n", - "format rewards. ['←', '(3,2)'] empty 32\n", - "20\n", - "format rewards. ['→', '(3,2)'] empty 33\n", - "21\n", - "format rewards. ['t', '(1,1)'] ↑ 1\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['t', '(1,1)'] ↓ 2\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['t', '(1,1)'] ← 3\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['t', '(1,1)'] → 4\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['t', '(2,1)'] ↑ 6\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['t', '(2,1)'] ↓ 7\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['t', '(2,1)'] ← 8\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['t', '(2,1)'] → 9\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['t', '(3,1)'] ↑ 10\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['t', '(3,1)'] ↓ 11\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['t', '(3,1)'] ← 12\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['t', '(3,1)'] → 13\n", - "result: \n", - "€\t: -1\n", - "format rewards. ['t', '(3,2)'] ↑ 14\n", - "5\n", - "format rewards. ['t', '(3,2)'] ↓ 15\n", - "5\n", - "format rewards. ['t', '(3,2)'] ← 16\n", - "5\n", - "format rewards. 
['t', '(3,2)'] → 17\n", - "5\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f4bebfadc6b3474d93fd948b47d30845", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d4333a09feed4d3ead275709f23aa5ce", + "model_id": "4ecb6189263e4af380ebed894d770208", "version_major": 2, "version_minor": 0 }, @@ -316,14 +219,26 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d8ea2641bc684837abddeb9cf3ff72e0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(Output(), Output()))" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ "pomdp, positions = grid_world(LEVEL)\n", "\n", - "print(pomdp.rewards[0].rewards)\n", - "\n", - "vis = show(pomdp, layout=Layout(\"layouts/grid.json\"), separate_labels=[\"t\", \"e\"], show_editor=False)" + "vis = show(pomdp, layout=Layout(\"layouts/grid.json\"), separate_labels=[\"t\", \"e\"], show_editor=True)" ] }, { @@ -344,7 +259,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "id": "8d5b0e8e-47db-4dfb-a533-a8181ec04751", "metadata": {}, "outputs": [ @@ -352,28 +267,82 @@ "name": "stdout", "output_type": "stream", "text": [ - "ERROR (Model.cpp:71): Invalid size (16) of state action reward vector (expected:34).\n" + "-------------------------------------------------------------- \n", + "Model type: \tPOMDP (sparse)\n", + "States: \t22\n", + "Transitions: \t34\n", + "Choices: \t34\n", + "Observations: \t5\n", + "Reward Models: (default)\n", + "State Labels: \t12 labels\n", + " * (1,2) -> 1 item(s)\n", + " * ↑ -> 4 item(s)\n", + " * t -> 4 item(s)\n", + " * (3,1) -> 5 item(s)\n", + " * ↓ -> 4 item(s)\n", + " * (1,1) -> 5 item(s)\n", + " * ← -> 4 item(s)\n", + " * (2,1) -> 5 item(s)\n", + " * (3,2) -> 5 item(s)\n", + " * e -> 1 item(s)\n", + " * → -> 4 item(s)\n", + " * escaped -> 1 item(s)\n", + "Choice Labels: \t0 labels\n", + "-------------------------------------------------------------- \n", + "\n" ] }, { - "ename": "RuntimeError", - "evalue": "IllegalArgumentException: Invalid size (16) of state action reward vector (expected:34).", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[5], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mstormvogel\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmapping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m stormvogel_to_stormpy\n\u001b[0;32m----> 3\u001b[0m stormpy_model \u001b[38;5;241m=\u001b[39m \u001b[43mstormvogel_to_stormpy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpomdp\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(stormpy_model)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;66;03m# TODO use stormpy to find the best policy/schedule, i.e. 
escape the maze as quickly as possible.\u001b[39;00m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# Ask Pim or Linus for help?\u001b[39;00m\n", - "File \u001b[0;32m~/git/env/lib/python3.11/site-packages/stormvogel/mapping.py:300\u001b[0m, in \u001b[0;36mstormvogel_to_stormpy\u001b[0;34m(model)\u001b[0m\n\u001b[1;32m 298\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m map_ctmc(model)\n\u001b[1;32m 299\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m model\u001b[38;5;241m.\u001b[39mget_type() \u001b[38;5;241m==\u001b[39m stormvogel\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mModelType\u001b[38;5;241m.\u001b[39mPOMDP:\n\u001b[0;32m--> 300\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmap_pomdp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 301\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m model\u001b[38;5;241m.\u001b[39mget_type() \u001b[38;5;241m==\u001b[39m stormvogel\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mModelType\u001b[38;5;241m.\u001b[39mMA:\n\u001b[1;32m 302\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m map_ma(model)\n", - "File \u001b[0;32m~/git/env/lib/python3.11/site-packages/stormvogel/mapping.py:225\u001b[0m, in \u001b[0;36mstormvogel_to_stormpy..map_pomdp\u001b[0;34m(model)\u001b[0m\n\u001b[1;32m 223\u001b[0m components\u001b[38;5;241m.\u001b[39mobservability_classes \u001b[38;5;241m=\u001b[39m observations\n\u001b[1;32m 224\u001b[0m components\u001b[38;5;241m.\u001b[39mchoice_labeling \u001b[38;5;241m=\u001b[39m choice_labeling\n\u001b[0;32m--> 225\u001b[0m pomdp \u001b[38;5;241m=\u001b[39m \u001b[43mstormpy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstorage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mSparsePomdp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcomponents\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 227\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m pomdp\n", - "\u001b[0;31mRuntimeError\u001b[0m: IllegalArgumentException: Invalid size (16) of state action reward vector (expected:34)." - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "82f50ed1d2eb4766985660f1e36adc61", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "35d598f4a40d4267940f270d70d053ad", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e06d4b6b1b0c4f05a55eba449d643844", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(Output(), Output()))" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "from stormvogel.mapping import stormvogel_to_stormpy\n", + "from stormvogel.mapping import stormvogel_to_stormpy, stormpy_to_stormvogel\n", "\n", "stormpy_model = stormvogel_to_stormpy(pomdp)\n", "print(stormpy_model)\n", + "pomdp2 = stormpy_to_stormvogel(stormpy_model)\n", + "vis2 = show(pomdp2, layout=Layout(\"layouts/grid.json\"), separate_labels=[\"t\", \"e\"], show_editor=True)\n", + "\n", "\n", "# TODO use stormpy to find the best policy/schedule, i.e. escape the maze as quickly as possible.\n", "# Ask Pim or Linus for help?" 
@@ -382,27 +351,25 @@ { "cell_type": "code", "execution_count": 9, - "id": "a6cfdb51-bd07-45a7-9772-635c213f235e", + "id": "b1171d26-20f6-40e5-af57-721db56be040", "metadata": {}, "outputs": [ { - "ename": "NameError", - "evalue": "name 'reward_model' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mreward_model\u001b[49m)\n", - "\u001b[0;31mNameError\u001b[0m: name 'reward_model' is not defined" + "name": "stdout", + "output_type": "stream", + "text": [ + "{'0': Action(name='0', labels=frozenset()), '1': Action(name='1', labels=frozenset()), '2': Action(name='2', labels=frozenset()), '3': Action(name='3', labels=frozenset()), '4': Action(name='4', labels=frozenset()), '5': Action(name='5', labels=frozenset()), '6': Action(name='6', labels=frozenset()), '7': Action(name='7', labels=frozenset()), '8': Action(name='8', labels=frozenset()), '9': Action(name='9', labels=frozenset()), '10': Action(name='10', labels=frozenset()), '11': Action(name='11', labels=frozenset()), '12': Action(name='12', labels=frozenset()), '13': Action(name='13', labels=frozenset()), '14': Action(name='14', labels=frozenset()), '15': Action(name='15', labels=frozenset()), '16': Action(name='16', labels=frozenset()), '17': Action(name='17', labels=frozenset()), '18': Action(name='18', labels=frozenset()), '19': Action(name='19', labels=frozenset()), '20': Action(name='20', labels=frozenset()), '21': Action(name='21', labels=frozenset()), '22': Action(name='22', labels=frozenset()), '23': Action(name='23', labels=frozenset()), '24': Action(name='24', labels=frozenset()), '25': Action(name='25', labels=frozenset()), '26': Action(name='26', labels=frozenset()), '27': Action(name='27', labels=frozenset()), '28': Action(name='28', labels=frozenset()), '29': Action(name='29', labels=frozenset()), '30': Action(name='30', labels=frozenset()), '31': Action(name='31', labels=frozenset()), '32': Action(name='32', labels=frozenset()), '33': Action(name='33', labels=frozenset())}\n" ] } ], - "source": [] + "source": [ + "print(pomdp2.actions)\n" + ] }, { "cell_type": "code", "execution_count": null, - "id": "d0642839-3412-4418-93b0-f6a73a40715b", + "id": "d7b41d4e-b9db-40fb-8aaf-cf3ce5e31cea", "metadata": {}, "outputs": [], "source": [] diff --git a/docs/getting_started/simulator.ipynb b/docs/getting_started/simulator.ipynb index f7fafbe..7ee24d6 100644 --- a/docs/getting_started/simulator.ipynb +++ b/docs/getting_started/simulator.ipynb @@ -1,449 +1,2915 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "a7245ed2", - "metadata": {}, - "source": [ - "# The simulator" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "a8ddc37c-66d2-43e4-8162-6be19a1d70a1", - "metadata": {}, - "outputs": [], - "source": [ - "from stormvogel import show\n", - "import stormvogel.model" - ] - }, + "cells": [ + { + "cell_type": "markdown", + "id": "a7245ed2", + "metadata": {}, + "source": [ + "# The simulator" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a8ddc37c-66d2-43e4-8162-6be19a1d70a1", + "metadata": {}, + "outputs": [], + "source": [ + "from stormvogel.show import show\n", + "from stormvogel.layout import Layout\n", + "import stormvogel.model\n", + "import stormvogel.simulator" + ] + }, + { + "cell_type": 
"code", + "execution_count": 2, + "id": "cab40f99-3460-4497-8b9f-3d669eee1e11", + "metadata": {}, + "outputs": [], + "source": [ + "# We create the monty hall mdp\n", + "mdp = stormvogel.model.new_mdp(\"Monty Hall\")\n", + "\n", + "init = mdp.get_initial_state()\n", + "\n", + "# first choose car position\n", + "init.set_transitions(\n", + " [(1 / 3, mdp.new_state(\"carchosen\", {\"car_pos\": i})) for i in range(3)]\n", + ")\n", + "\n", + "# we choose a door in each case\n", + "for s in mdp.get_states_with_label(\"carchosen\"):\n", + " s.set_transitions(\n", + " [\n", + " (\n", + " mdp.action(f\"open{i}\"),\n", + " mdp.new_state(\"open\", s.features | {\"chosen_pos\": i}),\n", + " )\n", + " for i in range(3)\n", + " ]\n", + " )\n", + "\n", + "# the other goat is revealed\n", + "for s in mdp.get_states_with_label(\"open\"):\n", + " car_pos = s.features[\"car_pos\"]\n", + " chosen_pos = s.features[\"chosen_pos\"]\n", + " other_pos = {0, 1, 2} - {car_pos, chosen_pos}\n", + " s.set_transitions(\n", + " [\n", + " (\n", + " 1 / len(other_pos),\n", + " mdp.new_state(\"goatrevealed\", s.features | {\"reveal_pos\": i}),\n", + " )\n", + " for i in other_pos\n", + " ]\n", + " )\n", + "\n", + "# we must choose whether we want to switch\n", + "for s in mdp.get_states_with_label(\"goatrevealed\"):\n", + " car_pos = s.features[\"car_pos\"]\n", + " chosen_pos = s.features[\"chosen_pos\"]\n", + " reveal_pos = s.features[\"reveal_pos\"]\n", + " other_pos = list({0, 1, 2} - {reveal_pos, chosen_pos})[0]\n", + " s.set_transitions(\n", + " [\n", + " (\n", + " mdp.action(\"stay\"),\n", + " mdp.new_state(\n", + " [\"done\"] + ([\"target\"] if chosen_pos == car_pos else []),\n", + " s.features | {\"chosen_pos\": chosen_pos},\n", + " ),\n", + " ),\n", + " (\n", + " mdp.action(\"switch\"),\n", + " mdp.new_state(\n", + " [\"done\"] + ([\"target\"] if other_pos == car_pos else []),\n", + " s.features | {\"chosen_pos\": other_pos},\n", + " ),\n", + " ),\n", + " ]\n", + " )\n", + "\n", + "# we add self loops to all states with no outgoing transitions\n", + "mdp.add_self_loops()" + ] + }, + { + "cell_type": "markdown", + "id": "d1f90374-dc85-4f31-b59f-aaf5e48a32f7", + "metadata": {}, + "source": [ + "We show what our mdp model looks like." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c129cf62-40ca-4246-8718-5c859744e7f8", + "metadata": { + "scrolled": true + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 8, - "id": "cab40f99-3460-4497-8b9f-3d669eee1e11", - "metadata": {}, - "outputs": [], - "source": [ - "# We create the monty hall mdp\n", - "mdp = stormvogel.model.new_mdp(\"Monty Hall\")\n", - "\n", - "init = mdp.get_initial_state()\n", - "\n", - "# first choose car position\n", - "init.set_transitions(\n", - " [(1 / 3, mdp.new_state(\"carchosen\", {\"car_pos\": i})) for i in range(3)]\n", - ")\n", - "\n", - "# we choose a door in each case\n", - "for s in mdp.get_states_with_label(\"carchosen\"):\n", - " s.set_transitions(\n", - " [\n", - " (\n", - " mdp.action(f\"open{i}\"),\n", - " mdp.new_state(\"open\", s.features | {\"chosen_pos\": i}),\n", - " )\n", - " for i in range(3)\n", - " ]\n", - " )\n", - "\n", - "# the other goat is revealed\n", - "for s in mdp.get_states_with_label(\"open\"):\n", - " car_pos = s.features[\"car_pos\"]\n", - " chosen_pos = s.features[\"chosen_pos\"]\n", - " other_pos = {0, 1, 2} - {car_pos, chosen_pos}\n", - " s.set_transitions(\n", - " [\n", - " (\n", - " 1 / len(other_pos),\n", - " mdp.new_state(\"goatrevealed\", s.features | {\"reveal_pos\": i}),\n", - " )\n", - " for i in other_pos\n", - " ]\n", - " )\n", - "\n", - "# we must choose whether we want to switch\n", - "for s in mdp.get_states_with_label(\"goatrevealed\"):\n", - " car_pos = s.features[\"car_pos\"]\n", - " chosen_pos = s.features[\"chosen_pos\"]\n", - " reveal_pos = s.features[\"reveal_pos\"]\n", - " other_pos = list({0, 1, 2} - {reveal_pos, chosen_pos})[0]\n", - " s.set_transitions(\n", - " [\n", - " (\n", - " mdp.action(\"stay\"),\n", - " mdp.new_state(\n", - " [\"done\"] + ([\"target\"] if chosen_pos == car_pos else []),\n", - " s.features | {\"chosen_pos\": chosen_pos},\n", - " ),\n", - " ),\n", - " (\n", - " mdp.action(\"switch\"),\n", - " mdp.new_state(\n", - " [\"done\"] + ([\"target\"] if other_pos == car_pos else []),\n", - " s.features | {\"chosen_pos\": other_pos},\n", - " ),\n", - " ),\n", - " ]\n", - " )\n", - "\n", - "# we add self loops to all states with no outgoing transitions\n", - "mdp.add_self_loops()" + "data": { + "text/html": [ + "\n", + " " + ], + "text/plain": [ + "" ] - }, + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "vis = show(mdp, layout=Layout(\"layouts/monty.json\"), save_and_embed=True)" + ] + }, + { + "cell_type": "markdown", + "id": "b5b2990c-65ed-4d7b-a4b8-f303843622e5", + "metadata": {}, + "source": [ + "We want to simulate this model. That is, we start at the initial state and then we walk through the model by choosing random actions.\n", + "\n", + "When we do this, we get a partial model as a result that contains everything we discovered during this walk. \n", + "\n", + "Try running this multiple times, and observe that sometimes we get to the target and sometimes we do not." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "eb0fadc0-7bb6-4c1d-ae3e-9e16527726ab", + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 9, - "id": "eb0fadc0-7bb6-4c1d-ae3e-9e16527726ab", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ModelType.MDP with name None\n", - "\n", - "States:\n", - "State 0 with labels ['init'] and features {}\n", - "State 1 with labels ['carchosen'] and features {}\n", - "State 2 with labels ['open'] and features {}\n", - "State 3 with labels ['goatrevealed'] and features {}\n", - "State 4 with labels ['done', 'target'] and features {}\n", - "\n", - "Transitions:\n", - "0.3333333333333333 -> State 1 with labels ['carchosen'] and features {}\n", - "1.0 -> State 2 with labels ['open'] and features {}\n", - "1.0 -> State 3 with labels ['goatrevealed'] and features {}\n", - "1.0 -> State 4 with labels ['done', 'target'] and features {}\n" - ] - } + "data": { + "text/html": [ + "\n", + " " ], - "source": [ - "#we want to simulate this model. That is, we start at the initial state and then\n", - "#we walk through the model according to transition probabilities.\n", - "#When we do this, we get a partial model as a result that contains everything we discovered\n", - "#during this walk.\n", - "\n", - "#we can choose how many steps we take:\n", - "steps = 4\n", - "\n", - "#and we can specify a seed if we want:\n", - "seed = 123456\n", - "\n", - "#then we run the simulator:\n", - "partial_model = stormvogel.simulator.simulate(mdp, steps=steps, seed=seed)\n", - "print(partial_model)" + "text/plain": [ + "" ] - }, + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# we can choose how many steps we take:\n", + "steps = 4\n", + "\n", + "# and we can specify a seed if we want:\n", + "seed = 12345676346\n", + "\n", + "# then we run the simulator:\n", + "partial_model = stormvogel.simulator.simulate(mdp, steps=steps, seed=seed)\n", + "# We could also provide a seed.\n", + "#partial_model = stormvogel.simulator.simulate(mdp, steps=steps, seed=seed)\n", + "\n", + "vis = show(partial_model, save_and_embed=True, layout=Layout(\"layouts/small_monty.json\"))" + ] + }, + { + "cell_type": "markdown", + "id": "49e3893d-bc35-4648-87eb-74a6a222ebf0", + "metadata": {}, + "source": [ + "We can also provide a scheduler (i.e. policy) which chooses what actions we should take at all time.\n", + "\n", + "In this case, we always take the first action, which means that we open door 0, and don't switch doors." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "59ac1e34-866c-42c4-b19b-c2a15c830e2e", + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 10, - "id": "59ac1e34-866c-42c4-b19b-c2a15c830e2e", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ModelType.MDP with name None\n", - "\n", - "States:\n", - "State 0 with labels ['init'] and features {}\n", - "State 1 with labels ['carchosen'] and features {}\n", - "State 2 with labels ['open'] and features {}\n", - "State 3 with labels ['goatrevealed'] and features {}\n", - "State 4 with labels ['done'] and features {}\n", - "\n", - "Transitions:\n", - "0.3333333333333333 -> State 1 with labels ['carchosen'] and features {}\n", - "1.0 -> State 2 with labels ['open'] and features {}\n", - "1.0 -> State 3 with labels ['goatrevealed'] and features {}\n", - "1.0 -> State 4 with labels ['done'] and features {}\n" - ] - } + "data": { + "text/html": [ + "\n", + " " ], - "source": [ - "#it still chooses random actions but we can prevent this by providing a scheduler:\n", - "taken_actions = {}\n", - "for id, state in mdp.states.items():\n", - " taken_actions[id] = state.available_actions()[0]\n", - "scheduler = stormvogel.result.Scheduler(mdp, taken_actions)\n", - "\n", - "partial_model = stormvogel.simulator.simulate(mdp, steps=steps, scheduler=scheduler, seed=seed)\n", - "print(partial_model)" + "text/plain": [ + "" ] - }, + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "#it still chooses random actions but we can prevent this by providing a scheduler:\n", + "taken_actions = {}\n", + "for id, state in mdp.states.items():\n", + " taken_actions[id] = state.available_actions()[0]\n", + "scheduler = stormvogel.result.Scheduler(mdp, taken_actions)\n", + "\n", + "partial_model = stormvogel.simulator.simulate(mdp, steps=steps, scheduler=scheduler, seed=seed)\n", + "vis = show(partial_model, save_and_embed=True, layout=Layout(\"layouts/small_monty.json\"))" + ] + }, + { + "cell_type": "markdown", + "id": "57a9b77d-4a75-42e4-8006-0bb11f2b345c", + "metadata": {}, + "source": [ + "We can highlight the scheduled states in the visualization of the entire model." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7e23fc38-b2af-4f02-b0a2-5d06151d2ca5", + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 11, - "id": "22871288-755c-463f-9150-f207c2f5c211", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } + "data": { + "text/html": [ + "\n", + " " ], - "source": [ - "#we can also visualize the partial model that we get from the simulator:\n", - "vis = show.show(partial_model, save_and_embed=True)" + "text/plain": [ + "" ] - }, + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "vis = show(mdp, save_and_embed=True, layout=Layout(\"layouts/monty.json\"), scheduler=scheduler)" + ] + }, + { + "cell_type": "markdown", + "id": "e4f388d8-d08b-40f5-a61b-1f5f29d004c9", + "metadata": {}, + "source": [ + "We can also get a path from the simulator function." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "34d0c293-d090-4e3d-9e80-4351f5fcba62", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initial state --(action: empty)--> state: 2 --(action: open0)--> state: 7 --(action: empty)--> state: 17 --(action: stay)--> state: 33\n" + ] + } + ], + "source": [ + "#we can also use another simulator function that returns a path instead of a partial model:\n", + "path = stormvogel.simulator.simulate_path(mdp, steps=4, scheduler=scheduler, seed=123456)\n", + "\n", + "print(path)" + ] + }, + { + "cell_type": "markdown", + "id": "1e0f6fea-6cd3-43e0-beea-84dc26eeca0b", + "metadata": {}, + "source": [ + "We can even visualize this path interactively! This works with any Path, not just a scheduler path. TODO." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "afbb3234-99e4-49d0-b259-f598e895f600", + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 12, - "id": "34d0c293-d090-4e3d-9e80-4351f5fcba62", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "initial state --(action: empty)--> state: 2 --(action: open0)--> state: 7 --(action: empty)--> state: 17 --(action: stay)--> state: 33\n" - ] - } + "data": { + "text/html": [ + "\n", + " " ], - "source": [ - "#we can also use another simulator function that returns a path instead of a partial model:\n", - "path = stormvogel.simulator.simulate_path(mdp, steps=4, scheduler=scheduler, seed=123456)\n", - "\n", - "print(path)" + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "id": "99c763fa-82ea-42ff-8833-79c640f14518", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.6" + "ename": "AttributeError", + "evalue": "'Visualization' object has no attribute 'show_path'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[16], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtime\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m sleep\n\u001b[1;32m 4\u001b[0m vis \u001b[38;5;241m=\u001b[39m show(mdp, save_and_embed\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, layout\u001b[38;5;241m=\u001b[39mLayout(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlayouts/monty.json\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[0;32m----> 5\u001b[0m \u001b[43mvis\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshow_path\u001b[49m(path)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m state \u001b[38;5;129;01min\u001b[39;00m path:\n\u001b[1;32m 7\u001b[0m vis\u001b[38;5;241m.\u001b[39mhighlight_state(state, color\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mred\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mAttributeError\u001b[0m: 'Visualization' object has no attribute 'show_path'" + ] } + ], + "source": [ + "from stormvogel.show 
import show\n", + "from time import sleep\n", + "\n", + "vis = show(mdp, save_and_embed=True, layout=Layout(\"layouts/monty.json\"))\n", + "vis.show_path(path)\n", + "for state in path:\n", + " vis.highlight_state(state, color=\"red\")\n", + " sleep(1)\n", + " # TODO should crash\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f6fe7f4-cc9e-4c1d-9850-3799ca47a903", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 5 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/getting_started/study.html b/docs/getting_started/study.html index df1046d..e3abc67 100644 --- a/docs/getting_started/study.html +++ b/docs/getting_started/study.html @@ -1,6 +1,6 @@ " + "" ], "text/plain": [ "" @@ -419,10 +210,38 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fdbfe0ff407a4a308031b5e7e783cf40", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "78d798387dbc4a6b8896a4d6125884ee", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(Output(), Output()))" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "vis = show(mdp, layout=Layout(\"layouts/pinkgreen.json\"), name=\"study\", save_and_embed=True)" + "vis = show(mdp, layout=Layout(\"layouts/pinkgreen.json\"), name=\"study\", show_editor=True)" ] }, { diff --git a/stormvogel/layouts/default.json b/stormvogel/layouts/default.json index ffc9633..e80e93a 100644 --- a/stormvogel/layouts/default.json +++ b/stormvogel/layouts/default.json @@ -92,6 +92,7 @@ "result_symbol": "\u2606", "show_rewards": true, "reward_symbol": "\u20ac", + "show_zero_rewards": true, "show_observations": true, "observation_symbol": "\u0298" }, diff --git a/stormvogel/layouts/explore.json b/stormvogel/layouts/explore.json index b5f3411..39e2e08 100644 --- a/stormvogel/layouts/explore.json +++ b/stormvogel/layouts/explore.json @@ -92,6 +92,7 @@ "result_symbol": "\u2606", "show_rewards": true, "reward_symbol": "\u20ac", + "show_zero_rewards": true, "show_observations": true, "observation_symbol": "\u0298" }, diff --git a/stormvogel/layouts/schema.json b/stormvogel/layouts/schema.json index d904e20..9bbaa63 100644 --- a/stormvogel/layouts/schema.json +++ b/stormvogel/layouts/schema.json @@ -117,6 +117,10 @@ "__description": "Symbol", "__widget": "Text" }, + "show_zero_rewards": { + "__description": "Show zero rewards", + "__widget": "Checkbox" + }, "show_observations": { "__description": "Show observations", "__widget": "Checkbox" diff --git a/stormvogel/mapping.py b/stormvogel/mapping.py index 74df716..c97d36b 100644 --- a/stormvogel/mapping.py +++ b/stormvogel/mapping.py @@ -79,7 +79,7 @@ def add_rewards( reward_models = {} for rewardmodel in model.rewards: reward_models[rewardmodel.name] = stormpy.SparseRewardModel( - optional_state_action_reward_vector=list(rewardmodel.rewards.values()) + 
optional_state_action_reward_vector=list(rewardmodel.reward_vector())
         )
     return reward_models
 
@@ -339,18 +339,23 @@ def add_rewards(
     """
     adds the rewards from the sparsemodel to either the states or the state action pairs of the model
     """
-    for reward_model in sparsemodel.reward_models:
-        rewards = sparsemodel.get_reward_model(reward_model)
-        rewardmodel = model.add_rewards(reward_model)
-        for index, reward in enumerate(
-            rewards.state_action_rewards
-            if rewards.has_state_action_rewards
-            else rewards.state_rewards
-        ):
-            if model.supports_actions():
-                rewardmodel.set_state_action_reward_at_id(index, reward)
-            else:
-                rewardmodel.set_state_reward(model.get_state_by_id(index), reward)
+    for reward_model_name in sparsemodel.reward_models:
+        rewards = sparsemodel.get_reward_model(reward_model_name)
+        rewardmodel = model.add_rewards(reward_model_name)
+        if model.supports_actions():
+            rewardmodel.set_from_rewards_vector(rewards.state_action_rewards)
+        else:
+            rewardmodel.set_from_rewards_vector(rewards.state_rewards)
+
+        # for index, reward in enumerate(
+        #     rewards.state_action_rewards
+        #     if rewards.has_state_action_rewards
+        #     else rewards.state_rewards
+        # ):
+        #     if model.supports_actions():
+        #         rewardmodel.set_state_action_reward_at_id(index, reward)
+        #     else:
+        #         rewardmodel.set_state_reward(model.get_state_by_id(index), reward)
 
 
 def map_dtmc(sparsedtmc: stormpy.storage.SparseDtmc) -> stormvogel.model.Model:
     """
diff --git a/stormvogel/model.py b/stormvogel/model.py
index dedd8f4..ffbf086 100644
--- a/stormvogel/model.py
+++ b/stormvogel/model.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from fractions import Fraction
-from typing import cast
+from typing import Tuple, cast
 
 import copy
 
 Parameter = str
@@ -341,10 +341,12 @@ class RewardModel:
     name: str
     model: "Model"
-    # Hashed by the id of the state or state action pair (=number in the matrix)
-    rewards: dict[int, Number]
+    rewards: dict[Tuple[int, Action], Number]
+    """Rewards dict, keyed by (state id, Action).
+    In models without actions, EmptyAction is used as the action part of the key."""
 
-    def __init__(self, name: str, model: "Model", rewards: dict[int, Number]):
+    def __init__(self, name: str, model: "Model", rewards: dict[Tuple[int, Action], Number]):
         self.name = name
         self.rewards = rewards
         self.model = model
@@ -353,18 +355,30 @@ def __init__(self, name: str, model: "Model", rewards: dict[int, Number]):
             self.set_action_state = {}
         else:
             self.state_action_pair = None
+
+    def set_from_rewards_vector(self, vector: list[Number]) -> None:
+        """Set the rewards of this model according to a stormpy rewards vector."""
+        combined_id = 0
+        self.rewards = dict()
+        for s in self.model.states.values():
+            for a in s.available_actions():
+                self.rewards[s.id, a] = vector[combined_id]
+                combined_id += 1
 
     def get_state_reward(self, state: State) -> Number:
         """Gets the reward at said state or state action pair"""
-        return self.rewards[state.id]
+        if self.model.supports_actions():
+            raise RuntimeError(
+                "This is a model with actions. Please call get_state_action_reward instead"
+            )
+        return self.rewards[state.id, EmptyAction]
 
     def get_state_action_reward(self, state: State, action: Action) -> Number | None:
-        """Gets the reward at said state or state action pair"""
+        """Gets the reward at said state action pair. Returns None if no reward was found."""
         if self.model.supports_actions():
             if action in state.available_actions():
-                id = self.model.get_state_action_id(state, action)
-                assert id is not None
-                return self.rewards[id]
+                try:
+                    return self.rewards[state.id, action]
+                except KeyError:
+                    return None
             else:
                 RuntimeError("This action is not available in this state")
         else:
@@ -379,40 +393,50 @@ def set_state_reward(self, state: State, value: Number):
                 "This is a model with actions. Please call the set_action_state_reward(_at_id) function instead"
             )
         else:
-            self.rewards[state.id] = value
+            self.rewards[state.id, EmptyAction] = value
 
     def set_state_action_reward(self, state: State, action: Action, value: Number):
         """sets the reward at said state action pair (in case of models with actions)"""
         if self.model.supports_actions():
             if action in state.available_actions():
-                id = self.model.get_state_action_id(state, action)
-                assert id is not None
-                self.rewards[id] = value
+                self.rewards[state.id, action] = value
             else:
                 RuntimeError("This action is not available in this state")
         else:
             RuntimeError(
                 "The model this rewardmodel belongs to does not support actions"
             )
 
-    def set_state_action_reward_at_id(self, action_state: int, value: Number):
-        """sets the reward at said state action pair for a given id (in the case of models with actions)"""
-        if self.model.supports_actions():
-            self.rewards[action_state] = value
-        else:
-            RuntimeError(
-                "The model this rewardmodel belongs to does not support actions"
-            )
+    def reward_vector(self) -> list[Number]:
+        """Return the rewards as a vector, in the order stormpy expects (per state, per available action)."""
+        vector = []
+        for s in self.model.states.values():
+            for a in s.available_actions():
+                reward = self.rewards.get((s.id, a))
+                if reward is None:
+                    raise RuntimeError(
+                        "A reward was not set. You might want to call set_unset_rewards."
+                    )
+                vector.append(reward)
+        return vector
+
+    # def set_state_action_reward_at_id(self, action_state: int, value: Number):
+    #     """sets the reward at said state action pair for a given id (in the case of models with actions).
+    #     WARNING This function is only intended for internal use within stormvogel. Use set_state_action_reward instead!"""
+    #     if self.model.supports_actions():
+    #         self.rewards[action_state] = value
+    #     else:
+    #         RuntimeError(
+    #             "The model this rewardmodel belongs to does not support actions"
+    #         )
 
     def set_unset_rewards(self, value: Number):
         """Fills up rewards that were not set yet with the specified value.
         Use this if converting to stormpy doesn't work because the reward vector does not have the expected length."""
-        expected_length = 0
-        for s_id, state in self.model.states.items():
-            expected_length += len(state.available_actions())
-        for i in range(expected_length):
-            if expected_length not in self.rewards:
-                self.rewards[i] = value
+        for s in self.model.states.values():
+            for a in s.available_actions():
+                if (s.id, a) not in self.rewards:
+                    self.rewards[s.id, a] = value
 
     def __lt__(self, other) -> bool:
         if not isinstance(other, RewardModel):
diff --git a/stormvogel/show.py b/stormvogel/show.py
index 56591d5..4b457b4 100644
--- a/stormvogel/show.py
+++ b/stormvogel/show.py
@@ -14,6 +14,7 @@ def show(
     model: stormvogel.model.Model,
     result: stormvogel.result.Result | None = None,
+    scheduler: stormvogel.result.Scheduler | None = None,
     name: str = "model",
     layout: stormvogel.layout.Layout | None = None,
     positions: dict[str, dict[str, int]] | None = None,
@@ -46,6 +47,7 @@ def show(
         model=model,
         name=name,
         result=result,
+        scheduler=scheduler,
         layout=layout,
         positions=positions,
         separate_labels=separate_labels,
diff --git a/stormvogel/visualization.py b/stormvogel/visualization.py
index b44f50e..c33b685 100644
--- a/stormvogel/visualization.py
+++ b/stormvogel/visualization.py
@@ -40,6 +40,7 @@ def __init__(
         model: stormvogel.model.Model,
         name: str | None = None,
         result: stormvogel.result.Result | None = None,
+        scheduler: stormvogel.result.Scheduler | None = None,
         layout: stormvogel.layout.Layout = stormvogel.layout.DEFAULT(),
         separate_labels: list[str] = [],
         positions: dict[str, dict[str, int]] | None = None,
@@ -49,13 +50,12 @@ def __init__(
         do_init_server: bool = True,
     ) -> None:
         """Create visualization of a Model using a pyvis Network
-
-        NEVER CREATE TWO VISUALIZATIONS WITH THE SAME NAME, STUFF MIGHT BREAK.
-
         Args:
             model (Model): The stormvogel model to be displayed.
             name (str, optional): Internally used name. Will be randomly generated if left as None.
             result (Result, optional): Result corresponding to the model.
+            scheduler (Scheduler, optional): Scheduler. The scheduled actions can be given a distinct layout.
+                If not set, the scheduler from the result will be used.
             layout (Layout, optional): Layout used for the visualization.
             separate_labels (list[str], optional): Labels that should be edited separately according to the layout.
             positions (dict[int, dict[str, int]] | None): A dictionary from state ids to positions.
@@ -66,12 +66,18 @@ def __init__(
             do_init_server (bool): Enable if you would like to start the server which is required for some visualization features. Defaults to True.
         """
         super().__init__(output, do_display, debug_output)
+        # Having two visualizations with the same name might break some of the interactive HTML, so a random word is appended to the name.
         if name is None:
             self.name: str = random_word(10)
         else:
             self.name: str = name + random_word(10)
         self.model: stormvogel.model.Model = model
-        self.result: stormvogel.result.Result = result
+        self.result: stormvogel.result.Result | None = result
+        self.scheduler: stormvogel.result.Scheduler | None = scheduler
+        # If a scheduler was not set explicitly but a result was given, take the scheduler from the result.
+ if self.scheduler is None: + if not self.result is None: + self.scheduler = self.result.scheduler self.layout: stormvogel.layout.Layout = layout self.separate_labels: set[str] = set(map(und, separate_labels)).union( self.layout.layout["groups"].keys() @@ -151,7 +157,6 @@ def __add_transitions(self) -> None: if self.nt is None: return action_id = self.ACTION_ID_OFFSET - scheduler = self.result.scheduler if self.result is not None else None # In the visualization, both actions and states are nodes, so we need to keep track of how many actions we already have. for state_id, transition in self.model.transitions.items(): for action, branch in transition.transition.items(): @@ -166,11 +171,11 @@ def __add_transitions(self) -> None: else: # Put the action in the group scheduled_actions if appropriate. group = "actions" - if scheduler is not None: - choice = scheduler.get_choice_of_state( + if self.scheduler is not None: + choice = self.scheduler.get_choice_of_state( state=self.model.get_state_by_id(state_id) ) - if choice == action: + if action.strict_eq(choice): group = "scheduled_actions" reward = self.__format_rewards(self.model.get_state_by_id(state_id), action) @@ -214,24 +219,20 @@ def __format_rewards(self, s: stormvogel.model.State, a: stormvogel.model.Action """Create a string that contains either the state exit reward (if actions are not supported) or the reward of taking this action from this state. (if actions ARE supported) Starts with newline""" - - if len(self.model.rewards) == 0 or not self.layout.layout["state_properties"]["show_rewards"]: + if not self.layout.layout["state_properties"]["show_rewards"]: return "" - res = "\n" + self.layout.layout["state_properties"]["reward_symbol"] + EMPTY_RES = "\n" + self.layout.layout["state_properties"]["reward_symbol"] + res = EMPTY_RES for reward_model in self.model.rewards: - print("format rewards.", s.labels, a.name, self.model.get_state_action_id(s,a)) - try: - reward = 4269 - if self.model.supports_actions() and a != stormvogel.model.EmptyAction: - reward = reward_model.get_state_action_reward(s, a) - else: - reward = reward_model.get_state_reward(s) + if self.model.supports_actions(): + reward = reward_model.get_state_action_reward(s, a) + else: + reward = reward_model.get_state_reward(s) + if not reward is None and not\ + (not self.layout.layout["state_properties"]["show_zero_rewards"] and reward == 0): res += f"\t{reward_model.name}: {reward}" - except KeyError as e: # If this reward model does not have a reward for this state. - print("keyerror with", e) - print(self.model.get_state_action_id(s,a)) - return "" - print("result:", res) + if res == EMPTY_RES: + return "" return res def __format_result(self, s: stormvogel.model.State) -> str: diff --git a/tests/saved_test_layout.json b/tests/saved_test_layout.json index 46eaedf..f358561 100644 --- a/tests/saved_test_layout.json +++ b/tests/saved_test_layout.json @@ -117,4 +117,4 @@ "init": { "color": "TEST_COLOR" } -} +} \ No newline at end of file
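
A minimal sketch of the reworked reward API from stormvogel/model.py above: rewards are now keyed by (state id, Action), unset pairs can be filled with set_unset_rewards, and reward_vector() produces the vector that stormvogel_to_stormpy expects. The model-building calls mirror the notebooks in this diff; the model name "tiny", the action name "go", and the reward values are illustrative.

    import stormvogel.model
    from stormvogel.mapping import stormvogel_to_stormpy

    # Build a two-state MDP with a single labelled action.
    mdp = stormvogel.model.new_mdp("tiny")
    init = mdp.get_initial_state()
    init.set_transitions([(mdp.action("go"), mdp.new_state("done"))])
    mdp.add_self_loops()

    # Rewards are stored per (state, action) pair instead of per matrix row id.
    reward_model = mdp.add_rewards("default")
    reward_model.set_state_action_reward(init, mdp.action("go"), -1)
    reward_model.set_unset_rewards(0)  # fill the remaining pairs so the vector has the expected length
    print(reward_model.reward_vector())  # one entry per state-action pair, in stormpy order

    stormpy_model = stormvogel_to_stormpy(mdp)
    print(stormpy_model)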
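
A similar sketch for the scheduler argument threaded through show() and Visualization above. It assumes the Monty Hall mdp built in the simulator notebook of this diff, and the layout path added there; the first-action policy is only an example. Chosen actions are drawn in the "scheduled_actions" group, and when no scheduler is passed explicitly, the one attached to a result is used instead.

    from stormvogel.show import show
    from stormvogel.layout import Layout
    import stormvogel.result

    # A trivial policy: always take the first available action in every state.
    taken_actions = {id: state.available_actions()[0] for id, state in mdp.states.items()}
    scheduler = stormvogel.result.Scheduler(mdp, taken_actions)

    # The scheduled actions are highlighted according to the "scheduled_actions" group of the layout.
    vis = show(mdp, layout=Layout("layouts/monty.json"), scheduler=scheduler)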