Skip to content

Commit

Permalink
progress removing actions
Browse files Browse the repository at this point in the history
  • Loading branch information
YouGuessedMyName committed Dec 9, 2024
1 parent 74fd328 commit f0974c8
Show file tree
Hide file tree
Showing 11 changed files with 264 additions and 260 deletions.
6 changes: 3 additions & 3 deletions stormvogel/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def map_mdp(sparsemdp: stormpy.storage.SparseDtmc) -> stormvogel.model.Model:
else:
actionlabels = frozenset()
# TODO assign the correct action name and not only an index
action = model.new_action(str(i), actionlabels)
action = model.new_action(actionlabels)
branch = [(x.value(), model.get_state_by_id(x.column)) for x in row]
transition[action] = stormvogel.model.Branch(branch)
transitions = stormvogel.model.Transition(transition)
Expand Down Expand Up @@ -492,7 +492,7 @@ def map_pomdp(sparsepomdp: stormpy.storage.SparsePomdp) -> stormvogel.model.Mode
actionlabels = frozenset()

# TODO assign the correct action name and not only an index
action = model.new_action(str(i), actionlabels)
action = model.new_action(actionlabels)
branch = [(x.value(), model.get_state_by_id(x.column)) for x in row]
transition[action] = stormvogel.model.Branch(branch)
transitions = stormvogel.model.Transition(transition)
Expand Down Expand Up @@ -540,7 +540,7 @@ def map_ma(sparsema: stormpy.storage.SparseMA) -> stormvogel.model.Model:
actionlabels = frozenset()

# TODO assign the correct action name and not only an index
action = model.new_action(str(i), actionlabels)
action = model.new_action(actionlabels)
branch = [(x.value(), model.get_state_by_id(x.column)) for x in row]
transition[action] = stormvogel.model.Branch(branch)
transitions = stormvogel.model.Transition(transition)
Expand Down
112 changes: 54 additions & 58 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,36 +199,29 @@ class Action:
"""Represents an action, e.g., in MDPs.
Note that this action object is completely independent of its corresponding branch.
Their relation is managed by Transitions.
Two actions with the same labels are considered equal.
Args:
name: A name for this action.
labels: The labels of this action. Corresponds to Storm labels.
"""
@staticmethod
def create(labels: frozenset[str] | str | None = None) -> 'Action':
if isinstance(labels, str):
return Action(frozenset({labels}))
elif isinstance(labels, frozenset):
return Action(labels)
else:
return Action(frozenset())

name: str
labels: frozenset[str]

def __str__(self):
return f"Action {self.name} with labels {self.labels}"

# TODO remove these after modifying the whole code base to remove names.
def __eq__(self, other):
if isinstance(other, Action):
return self.labels == other.labels
return False

def strict_eq(self, other):
"""Also requires the names to be equal."""
if isinstance(other, Action):
return self.name == other.name and self.labels == other.labels
return False

# def __hash__(self):
# return self.labels.__hash__()
return f"Action with labels {self.labels}"


# The empty action. Used for DTMCs and empty action transitions in mdps.
EmptyAction = Action("empty", frozenset())
EmptyAction = Action(frozenset())


@dataclass(order=True)
Expand Down Expand Up @@ -371,21 +364,24 @@ def set_from_rewards_vector(self, vector: list[Number]) -> None:
self.rewards[s.id, a] = vector[combined_id]
combined_id += 1

def get_state_reward(self, state: State) -> Number:
"""Gets the reward at said state or state action pair"""
def get_state_reward(self, state: State) -> Number | None:
"""Gets the reward at said state or state action pair. Return None if no reward is present."""
if self.model.supports_actions():
RuntimeError(
"This is a model with actions. Please call the get_action_state_reward(_at_id) function instead"
)
return self.rewards[state.id, EmptyAction]
if (state.id, EmptyAction) in self.rewards:
return self.rewards[state.id, EmptyAction]
else:
return None

def get_state_action_reward(self, state: State, action: Action) -> Number | None:
"""Gets the reward at said state or state action pair. Returns None if no reward was found."""
if self.model.supports_actions():
if action in state.available_actions():
try:
if (state.id, action) in self.rewards:
return self.rewards[state.id, action]
except KeyError:
else:
return None
else:
RuntimeError("This action is not available in this state")
Expand Down Expand Up @@ -474,7 +470,7 @@ class Model:
# Both of these are hashed by the id of the state (=number in the matrix)
states: dict[int, State]
transitions: dict[int, Transition]
actions: dict[str, Action] | None
actions: set[Action] | None
rewards: list[RewardModel]
# In ctmcs we work with rate transitions but additionally we can optionally store exit rates (hashed by id of the state)
exit_rates: dict[int, Number] | None
Expand All @@ -492,7 +488,7 @@ def __init__(

# Initialize actions if those are supported by the model type
if self.supports_actions():
self.actions = {}
self.actions = set()
else:
self.actions = None

Expand Down Expand Up @@ -625,7 +621,7 @@ def get_state_action_id(self, state: State, action: Action) -> int | None:
for s in self.states.values():
for a in s.available_actions():
if (
a.name == action.name
a == action
and action in s.available_actions()
and s == state
):
Expand Down Expand Up @@ -668,6 +664,8 @@ def set_transitions(
"""Set the transition from a state."""
if not isinstance(transitions, Transition):
transitions = transition_from_shorthand(transitions)
if not self.actions is None and EmptyAction in transitions.transition.keys():
self.actions.add(EmptyAction)
self.transitions[s.id] = transitions

def add_transitions(
Expand Down Expand Up @@ -714,8 +712,8 @@ def add_transitions(
else:
for action, branch in transitions.transition.items():
assert self.actions is not None
if action not in self.actions.values():
self.actions[action.name] = action
if action not in self.actions:
self.actions.add(action)
self.transitions[s.id].transition[action] = branch

def get_transitions(self, state_or_id: State | int) -> Transition:
Expand All @@ -733,19 +731,15 @@ def get_branch(self, state_or_id: State | int) -> Branch:
raise RuntimeError("Called get_branch on a non-empty transition.")
return transition[EmptyAction]

def new_action(self, name: str, labels: frozenset[str] | None = None) -> Action:
def new_action(self, labels: frozenset[str] | str | None = None) -> Action:
"""Creates a new action and returns it."""
if not self.supports_actions():
raise RuntimeError(
"Called new_action on a model that does not support actions"
)
assert self.actions is not None
if name in self.actions:
raise RuntimeError(
f"Tried to add action {name} but that action already exists"
)
action = Action(name, labels if labels else frozenset())
self.actions[name] = action
action = Action.create(labels)
self.actions.add(action)
return action

def reassign_ids(self):
Expand Down Expand Up @@ -845,30 +839,32 @@ def remove_transitions_between_states(
"This method only works for models that don't support actions."
)

def get_action(self, name: str) -> Action:
"""Gets an existing action."""
if not self.supports_actions():
raise RuntimeError(
"Called get_action on a model that does not support actions"
)
assert self.actions is not None
if name not in self.actions:
raise RuntimeError(
f"Tried to get action {name} but that action does not exist"
)
return self.actions[name]

def action(self, name: str) -> Action:
# TODO possibly obsolete?
# def get_action(self, name: str) -> Action:
# """Gets an existing action."""
# if not self.supports_actions():
# raise RuntimeError(
# "Called get_action on a model that does not support actions"
# )
# assert self.actions is not None
# if name not in self.actions:
# raise RuntimeError(
# f"Tried to get action {name} but that action does not exist"
# )
# return self.actions[name]

def action(self, labels: frozenset[str] | str | None) -> Action:
"""New action or get action if it exists."""
if not self.supports_actions():
raise RuntimeError(
"Called get_action on a model that does not support actions"
"Called method action on a model that does not support actions"
)
assert self.actions is not None
if name in self.actions:
return self.get_action(name)
else:
return self.new_action(name)
action = Action.create(labels)

if not action in self.actions:
self.new_action(labels)
return action

def new_state(
self,
Expand Down Expand Up @@ -980,7 +976,7 @@ def to_dot(self) -> str:
for state_id, transition in self.transitions.items():
for action, branch in transition.transition.items():
if action != EmptyAction:
dot += f'{action.name.replace(" ", "_")}{state_id} [ label = "", shape=point ];\n'
dot += f'{state_id} [ label = "", shape=point ];\n'
for state_id, transition in self.transitions.items():
for action, branch in transition.transition.items():
if action == EmptyAction:
Expand All @@ -989,9 +985,9 @@ def to_dot(self) -> str:
dot += f'{state_id} -> {target.id} [ label = "{prob}" ];\n'
else:
# Draw actions, then probabilities
dot += f'{state_id} -> {action.name.replace(" ", "_")}{state_id} [ label = "{action.name}" ];\n'
dot += f'{state_id} -> {state_id} [ label = "{action.name}" ];\n'
for prob, target in branch.branch:
dot += f'{action.name.replace(" ", "_")}{state_id} -> {target.id} [ label = "{prob}" ];\n'
dot += f'{state_id} -> {target.id} [ label = "{prob}" ];\n'

dot += "}"
return dot
Expand Down Expand Up @@ -1021,7 +1017,7 @@ def __eq__(self, other) -> bool:
if self.supports_actions():
assert self.actions is not None and other.actions is not None
for action, other_action in zip(
sorted(self.actions.values()), sorted(other.actions.values())
sorted(self.actions), sorted(other.actions)
):
if not action == other_action:
return False
Expand Down
2 changes: 1 addition & 1 deletion stormvogel/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def add_scheduler(self, stormpy_scheduler: stormpy.storage.Scheduler):
self.stormpy_scheduler = stormpy_scheduler
taken_actions = {}
for state in self.model.states.values():
taken_actions[state.id] = self.model.get_action(
taken_actions[state.id] = stormvogel.model.Action.create(
str(stormpy_scheduler.get_choice(state.id))
)
self.scheduler = Scheduler(self.model, taken_actions)
Expand Down
9 changes: 5 additions & 4 deletions stormvogel/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import stormpy.examples
from typing import Callable
import random

from stormvogel.model import EmptyAction

class Path:
"""
Expand Down Expand Up @@ -242,9 +242,10 @@ def simulate(

# we already set the rewards for the initial state/stateaction
if model.supports_actions():
reward_model.set_state_action_reward_at_id(
partial_model.get_initial_state().id,
model.rewards[index].get_state_reward(model.get_initial_state()),
reward_model.set_state_action_reward(
partial_model.get_initial_state(),
EmptyAction,
model.rewards[index].get_state_reward(model.get_initial_state())
)
else:
reward_model.set_state_reward(
Expand Down
6 changes: 3 additions & 3 deletions stormvogel/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __add_transitions(self) -> None:
# In the visualization, both actions and states are nodes, so we need to keep track of how many actions we already have.
for state_id, transition in self.model.transitions.items():
for action, branch in transition.transition.items():
if action.strict_eq(stormvogel.model.EmptyAction):
if action == stormvogel.model.EmptyAction:
# Only draw probabilities
for prob, target in branch.branch:
self.nt.add_edge(
Expand All @@ -175,7 +175,7 @@ def __add_transitions(self) -> None:
choice = self.scheduler.get_choice_of_state(
state=self.model.get_state_by_id(state_id)
)
if action.strict_eq(choice):
if action == choice:
group = "scheduled_actions"

reward = self.__format_rewards(
Expand All @@ -185,7 +185,7 @@ def __add_transitions(self) -> None:
# Add the action's node
self.nt.add_node(
id=action_id,
label=action.name + reward,
label=",".join(action.labels) + reward,
group=group,
position_dict=self.positions,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/saved_test_layout.json
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,4 @@
"init": {
"color": "TEST_COLOR"
}
}
}
Loading

0 comments on commit f0974c8

Please sign in to comment.