Skip to content

Commit

Permalink
we can now pass a function as scheduler to the simulator
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Nov 10, 2024
1 parent 35bd070 commit 0c5ac0c
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 6 deletions.
27 changes: 23 additions & 4 deletions stormvogel/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import stormvogel.model
import stormpy.examples.files
import stormpy.examples
from typing import Callable
import random


Expand Down Expand Up @@ -99,7 +100,9 @@ def __eq__(self, other):
def simulate_path(
model: stormvogel.model.Model,
steps: int = 1,
scheduler: stormvogel.result.Scheduler | None = None,
scheduler: stormvogel.result.Scheduler
| Callable[[stormvogel.model.State], stormvogel.model.Action]
| None = None,
seed: int | None = None,
) -> Path:
"""
Expand All @@ -108,6 +111,7 @@ def simulate_path(
model: The stormvogel model that the simulator should run on.
steps: The number of steps the simulator walks through the model.
scheduler: A stormvogel scheduler to determine what actions should be taken. Random if not provided.
(instead of a stormvogel scheduler, a function from states to actions can also be provided.)
seed: The seed for the function that determines for each state what the next state will be. Random seed if not provided.
Returns a path object.
Expand All @@ -116,7 +120,13 @@ def simulate_path(
def get_range_index(stateid: int):
"""Helper function to convert the chosen action in a state by a scheduler to a range index."""
assert scheduler is not None
action = scheduler.get_choice_of_state(model.get_state_by_id(state))
if isinstance(scheduler, stormvogel.result.Scheduler):
action = scheduler.get_choice_of_state(model.get_state_by_id(state))
elif callable(scheduler):
action = scheduler(model.get_state_by_id(state))
else:
raise TypeError("Must be of type Scheduler or a function")

available_actions = model.states[stateid].available_actions()

assert action is not None
Expand Down Expand Up @@ -173,7 +183,9 @@ def simulate(
model: stormvogel.model.Model,
steps: int = 1,
runs: int = 1,
scheduler: stormvogel.result.Scheduler | None = None,
scheduler: stormvogel.result.Scheduler
| Callable[[stormvogel.model.State], stormvogel.model.Action]
| None = None,
seed: int | None = None,
) -> stormvogel.model.Model | None:
"""
Expand All @@ -183,6 +195,7 @@ def simulate(
steps: The number of steps the simulator walks through the model
runs: The number of times the model gets simulated.
scheduler: A stormvogel scheduler to determine what actions should be taken. Random if not provided.
(instead of a stormvogel scheduler, a function from states to actions can also be provided.)
seed: The seed for the function that determines for each state what the next state will be. Random seed if not provided.
Returns the partial model discovered by all the runs of the simulator together
Expand All @@ -191,7 +204,13 @@ def simulate(
def get_range_index(stateid: int):
"""Helper function to convert the chosen action in a state by a scheduler to a range index."""
assert scheduler is not None
action = scheduler.get_choice_of_state(model.get_state_by_id(state))
if isinstance(scheduler, stormvogel.result.Scheduler):
action = scheduler.get_choice_of_state(model.get_state_by_id(state))
elif callable(scheduler):
action = scheduler(model.get_state_by_id(state))
else:
raise TypeError("Must be of type Scheduler or a function")

available_actions = model.states[stateid].available_actions()

assert action is not None
Expand Down
53 changes: 51 additions & 2 deletions tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_simulate():
rewardmodel3.rewards[stateid] = float(1)

assert partial_model == other_dtmc

######################################################################################################################
# we make a monty hall mdp and run the simulator with it
mdp = examples.monty_hall.create_monty_hall_mdp()
rewardmodel = mdp.add_rewards("rewardmodel")
Expand Down Expand Up @@ -74,6 +74,31 @@ def test_simulate():
rewardmodel2.rewards = {0: 0, 7: 7, 16: 16}

assert partial_model == other_mdp
######################################################################################################################

# we test the simulator for an mdp with a plain function (instead of a Scheduler object) as scheduler

def scheduler(state: stormvogel.model.State) -> stormvogel.model.Action:
    """Function-style scheduler for the test: pick the first action that ``state`` reports as available."""
    return state.available_actions()[0]

mdp = examples.monty_hall.create_monty_hall_mdp()

partial_model = stormvogel.simulator.simulate(
mdp, runs=1, steps=3, seed=1, scheduler=scheduler
)

# we make the partial model that should be created by the simulator
other_mdp = stormvogel.model.new_mdp()
other_mdp.get_initial_state().set_transitions(
[(1 / 3, other_mdp.new_state("carchosen"))]
)
other_mdp.get_state_by_id(1).set_transitions([(1, other_mdp.new_state("open"))])
other_mdp.get_state_by_id(2).set_transitions(
[(1, other_mdp.new_state("goatrevealed"))]
)

assert partial_model == other_mdp


def test_simulate_path():
Expand All @@ -93,7 +118,7 @@ def test_simulate_path():
)

assert path == other_path

##############################################################################################
# we make the monty hall pomdp and run simulate path with it
pomdp = examples.monty_hall_pomdp.create_monty_hall_pomdp()
taken_actions = {}
Expand All @@ -118,3 +143,27 @@ def test_simulate_path():
)

assert path == other_path

##############################################################################################
# we test the monty hall pomdp with a plain function (instead of a Scheduler object) as scheduler
def scheduler(state: stormvogel.model.State) -> stormvogel.model.Action:
    """Return the first entry of ``state.available_actions()`` so the simulated path is deterministic in its action choice."""
    return state.available_actions()[0]

pomdp = examples.monty_hall_pomdp.create_monty_hall_pomdp()
path = stormvogel.simulator.simulate_path(
pomdp, steps=4, seed=1, scheduler=scheduler
)

# we make the path that the simulate path function should create
other_path = stormvogel.simulator.Path(
{
1: (stormvogel.model.EmptyAction, pomdp.get_state_by_id(3)),
2: (pomdp.actions["open0"], pomdp.get_state_by_id(10)),
3: (stormvogel.model.EmptyAction, pomdp.get_state_by_id(21)),
4: (pomdp.actions["stay"], pomdp.get_state_by_id(41)),
},
pomdp,
)

assert path == other_path

0 comments on commit 0c5ac0c

Please sign in to comment.