Skip to content

Commit

Permalink
refactoring the simulator
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Nov 10, 2024
1 parent 0c5ac0c commit 3d6ca95
Showing 1 changed file with 56 additions and 64 deletions.
120 changes: 56 additions & 64 deletions stormvogel/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,26 @@ def __eq__(self, other):
return False


def get_range_index(
    state: stormvogel.model.State,
    scheduler: stormvogel.result.Scheduler
    | Callable[[stormvogel.model.State], stormvogel.model.Action],
) -> int:
    """Translate a scheduler's choice for ``state`` into a positional index.

    The scheduler is either a ``stormvogel.result.Scheduler`` (queried via
    ``get_choice_of_state``) or any callable that is invoked with the state.
    The returned integer is the position of the chosen action within
    ``state.available_actions()``.

    Raises:
        TypeError: if ``scheduler`` is neither a ``Scheduler`` nor callable.
        AssertionError: if ``scheduler`` is ``None`` or no action was chosen
            (only when assertions are enabled).
    """
    assert scheduler is not None

    if isinstance(scheduler, stormvogel.result.Scheduler):
        chosen = scheduler.get_choice_of_state(state)
    else:
        # Anything that is not a Scheduler must at least be callable.
        if not callable(scheduler):
            raise TypeError("Must be of type Scheduler or a function")
        chosen = scheduler(state)

    # Fetch the candidates before validating the choice, so the order of
    # calls on ``state`` stays the same as before.
    candidates = state.available_actions()
    assert chosen is not None

    # ``index`` is delegated to whatever sequence ``available_actions``
    # returns; a chosen action that is not among the candidates is reported
    # by that lookup.
    return candidates.index(chosen)


def simulate_path(
model: stormvogel.model.Model,
steps: int = 1,
Expand All @@ -117,21 +137,6 @@ def simulate_path(
Returns a path object.
"""

def get_range_index(stateid: int):
"""Helper function to convert the chosen action in a state by a scheduler to a range index."""
assert scheduler is not None
if isinstance(scheduler, stormvogel.result.Scheduler):
action = scheduler.get_choice_of_state(model.get_state_by_id(state))
elif callable(scheduler):
action = scheduler(model.get_state_by_id(state))
else:
raise TypeError("Must be of type Scheduler or a function")

available_actions = model.states[stateid].available_actions()

assert action is not None
return available_actions.index(action)

# we initialize the simulator
stormpy_model = stormvogel.mapping.stormvogel_to_stormpy(model)
if seed:
Expand All @@ -141,36 +146,38 @@ def get_range_index(stateid: int):
assert simulator is not None

# we start adding states or state action pairs to the path
state = 0
state_id = 0
path = {}
simulator.restart()
if not model.supports_actions():
for i in range(steps):
# for each step we add a state to the path
if not model.states[state].is_absorbing() and not simulator.is_done():
state, reward, labels = simulator.step()
path[i + 1] = model.states[state]
if not model.states[state_id].is_absorbing() and not simulator.is_done():
state_id, reward, labels = simulator.step()
path[i + 1] = model.states[state_id]
else:
break
else:
for i in range(steps):
# we first choose an action (randomly or according to scheduler)
actions = simulator.available_actions()
select_action = (
random.randint(0, len(actions) - 1)
if not scheduler
else get_range_index(state)
get_range_index(model.get_state_by_id(state_id), scheduler)
if scheduler
else random.randint(0, len(actions) - 1)
)

# we add the state action pair to the path
stormvogel_action = model.states[state].available_actions()[select_action]
stormvogel_action = model.states[state_id].available_actions()[
select_action
]

if (
not model.states[state].is_absorbing(stormvogel_action)
not model.states[state_id].is_absorbing(stormvogel_action)
and not simulator.is_done()
):
state, reward, labels = simulator.step(actions[select_action])
path[i + 1] = (stormvogel_action, model.states[state])
state_id, reward, labels = simulator.step(actions[select_action])
path[i + 1] = (stormvogel_action, model.states[state_id])
else:
break

Expand Down Expand Up @@ -201,21 +208,6 @@ def simulate(
Returns the partial model discovered by all the runs of the simulator together
"""

def get_range_index(stateid: int):
"""Helper function to convert the chosen action in a state by a scheduler to a range index."""
assert scheduler is not None
if isinstance(scheduler, stormvogel.result.Scheduler):
action = scheduler.get_choice_of_state(model.get_state_by_id(state))
elif callable(scheduler):
action = scheduler(model.get_state_by_id(state))
else:
raise TypeError("Must be of type Scheduler or a function")

available_actions = model.states[stateid].available_actions()

assert action is not None
return available_actions.index(action)

# we initialize the simulator
stormpy_model = stormvogel.mapping.stormvogel_to_stormpy(model)
assert stormpy_model is not None
Expand Down Expand Up @@ -245,53 +237,54 @@ def get_range_index(stateid: int):
if not partial_model.supports_actions():
for i in range(runs):
simulator.restart()
last_state = 0
last_state_id = 0
for j in range(steps):
state, reward, labels = simulator.step()
state_id, reward, labels = simulator.step()
# we get the rewards in reversed order
reward.reverse()

# we add to the partial model what we discovered (if new)
if state not in discovered_states:
discovered_states.add(state)
if state_id not in discovered_states:
discovered_states.add(state_id)

# we also add the transitions that we travelled through, so we need to keep track of the last state
probability = 0
transitions = model.get_transitions(last_state)
transitions = model.get_transitions(last_state_id)
for tuple in transitions.transition[
stormvogel.model.EmptyAction
].branch:
if tuple[1].id == state:
if tuple[1].id == state_id:
probability += float(tuple[0])

new_state = partial_model.new_state(list(labels))
partial_model.get_state_by_id(last_state).add_transitions(
partial_model.get_state_by_id(last_state_id).add_transitions(
[(probability, new_state)]
)

# we add the rewards
for index, rewardmodel in enumerate(partial_model.rewards):
rewardmodel.set(new_state, reward[index])

last_state = state
last_state_id = state_id
if simulator.is_done():
break
else:
state = 0
last_state_partial = partial_model.get_initial_state()
last_state_id = 0
for i in range(runs):
state_id = 0
last_state_partial = partial_model.get_initial_state()
last_state_id = 0
simulator.restart()
for j in range(steps):
# we first choose an action
actions = simulator.available_actions()
select_action = (
random.randint(0, len(actions) - 1)
if not scheduler
else get_range_index(state)
get_range_index(model.get_state_by_id(state_id), scheduler)
if scheduler
else random.randint(0, len(actions) - 1)
)

# we add the action to the partial model
assert partial_model.actions is not None
action = model.states[state].available_actions()[select_action]
action = model.states[state_id].available_actions()[select_action]
if action not in partial_model.actions.values():
partial_model.new_action(action.name)

Expand All @@ -300,28 +293,27 @@ def get_range_index(stateid: int):
reward = discovery[1]
for index, rewardmodel in enumerate(partial_model.rewards):
row_group = stormpy_model.transition_matrix.get_row_group_start(
state
state_id
)
state_action_pair = row_group + select_action
rewardmodel.set_action_state(state_action_pair, reward[index])

# we add the state
state, labels = discovery[0], discovery[2]
if state not in discovered_states:
discovered_states.add(state)
state_id, labels = discovery[0], discovery[2]
if state_id not in discovered_states:
discovered_states.add(state_id)

# we also add the transitions that we travelled through, so we need to keep track of the last state
probability = 0
transitions = model.get_transitions(last_state_id)
for tuple in transitions.transition[action].branch:
if tuple[1].id == state:
if tuple[1].id == state_id:
probability += float(tuple[0])

new_state = partial_model.new_state(list(labels))
last_state_partial.add_transitions([(probability, new_state)])

last_state_partial = new_state
last_state_id = state
last_state_id = state_id
if simulator.is_done():
break

Expand Down

0 comments on commit 3d6ca95

Please sign in to comment.