Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

119 fix the action dicts equality in models #147

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions stormvogel/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,12 @@ def map_mdp(sparsemdp: stormpy.storage.SparseDtmc) -> stormvogel.model.Model:
for i in range(row_group_start, row_group_end):
row = matrix.get_row(i)

actionlabels = frozenset(
sparsemdp.choice_labeling.get_labels_of_choice(i)
if sparsemdp.has_choice_labeling()
else str(i)
)
if sparsemdp.has_choice_labeling():
actionlabels = frozenset(
sparsemdp.choice_labeling.get_labels_of_choice(i)
)
else:
actionlabels = frozenset()
# TODO assign the correct action name and not only an index
action = model.new_action(str(i), actionlabels)
branch = [(x.value(), model.get_state_by_id(x.column)) for x in row]
Expand Down Expand Up @@ -482,11 +483,13 @@ def map_pomdp(sparsepomdp: stormpy.storage.SparsePomdp) -> stormvogel.model.Mode
for i in range(row_group_start, row_group_end):
row = matrix.get_row(i)

actionlabels = frozenset(
sparsepomdp.choice_labeling.get_labels_of_choice(i)
if sparsepomdp.has_choice_labeling()
else str(i)
)
if sparsepomdp.has_choice_labeling():
actionlabels = frozenset(
sparsepomdp.choice_labeling.get_labels_of_choice(i)
)
else:
actionlabels = frozenset()

# TODO assign the correct action name and not only an index
action = model.new_action(str(i), actionlabels)
branch = [(x.value(), model.get_state_by_id(x.column)) for x in row]
Expand Down Expand Up @@ -528,11 +531,13 @@ def map_ma(sparsema: stormpy.storage.SparseMA) -> stormvogel.model.Model:
for i in range(row_group_start, row_group_end):
row = matrix.get_row(i)

actionlabels = frozenset(
sparsema.choice_labeling.get_labels_of_choice(i)
if sparsema.has_choice_labeling()
else str(i)
)
if sparsema.has_choice_labeling():
actionlabels = frozenset(
sparsema.choice_labeling.get_labels_of_choice(i)
)
else:
actionlabels = frozenset()

# TODO assign the correct action name and not only an index
action = model.new_action(str(i), actionlabels)
branch = [(x.value(), model.get_state_by_id(x.column)) for x in row]
Expand Down
46 changes: 31 additions & 15 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ class Action:
def __str__(self):
return f"Action {self.name} with labels {self.labels}"

def __eq__(self, other):
if isinstance(other, Action):
return self.labels == other.labels
return False


# The empty action. Used for DTMCs and empty action transitions in mdps.
EmptyAction = Action("empty", frozenset())
Expand Down Expand Up @@ -270,17 +275,22 @@ def __str__(self):
parts.append(f"{action} => {branch}")
return "; ".join(parts + [])

def __eq__(self, other):
if isinstance(other, Transition):
return sorted(list(self.transition.values())) == sorted(
list(other.transition.values())
)
return False

def has_empty_action(self) -> bool:
# Note that we don't have to deal with the corner case where there are both empty and non-empty transitions. This is dealt with at __init__.
return self.transition.keys() == {EmptyAction}

def __eq__(self, other):
if isinstance(other, Transition):
if len(self.transition) != len(other.transition):
return False
for item, other_item in zip(
sorted(self.transition.items()), sorted(other.transition.items())
):
if not (item[0] == other_item[0] and item[1] == other_item[1]):
return False
return True
return False


TransitionShorthand = list[tuple[Number, State]] | list[tuple[Action, State]]

Expand Down Expand Up @@ -542,7 +552,7 @@ def set_transitions(self, s: State, transitions: Transition | TransitionShorthan
self.transitions[s.id] = transitions

def add_transitions(self, s: State, transitions: Transition | TransitionShorthand):
"""Add new transitions from a state. If no transition currently exists, the result will be the same as set_transitions."""
"""Add new transitions from a state to the model. If no transition currently exists, the result will be the same as set_transitions."""

if not isinstance(transitions, Transition):
transitions = transition_from_shorthand(transitions)
Expand Down Expand Up @@ -581,8 +591,11 @@ def add_transitions(self, s: State, transitions: Transition | TransitionShorthan
transitions.transition[EmptyAction]
)
else:
for choice, branch in transitions.transition.items():
self.transitions[s.id].transition[choice] = branch
for action, branch in transitions.transition.items():
assert self.actions is not None
if action not in self.actions.values():
self.actions[action.name] = action
self.transitions[s.id].transition[action] = branch

def get_transitions(self, state_or_id: State | int) -> Transition:
"""Get the transition at state s. Throws a KeyError if not present."""
Expand Down Expand Up @@ -610,10 +623,7 @@ def new_action(self, name: str, labels: frozenset[str] | None = None) -> Action:
raise RuntimeError(
f"Tried to add action {name} but that action already exists"
)
if labels:
action = Action(name, labels)
else:
action = Action(name, frozenset())
action = Action(name, labels if labels else frozenset())
self.actions[name] = action
return action

Expand Down Expand Up @@ -887,14 +897,20 @@ def __str__(self) -> str:

def __eq__(self, other) -> bool:
if isinstance(other, Model):
if self.supports_actions():
assert self.actions is not None and other.actions is not None
for action, other_action in zip(
sorted(self.actions.values()), sorted(other.actions.values())
):
if not action == other_action:
return False
return (
self.type == other.type
and self.states == other.states
and self.transitions == other.transitions
and sorted(self.rewards) == sorted(other.rewards)
and self.exit_rates == other.exit_rates
and self.markovian_states == other.markovian_states
# TODO: and self.actions == other.actions
)
return False

Expand Down
16 changes: 15 additions & 1 deletion stormvogel/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,21 @@ def __str__(self) -> str:

def __eq__(self, other):
if isinstance(other, Path):
return self.path == other.path and self.model == other.model
if not self.model.supports_actions():
return self.path == other.path and self.model == other.model
else:
if len(self.path) != len(other.path):
return False
for tuple, other_tuple in zip(
sorted(self.path.values()), sorted(other.path.values())
):
assert not (
isinstance(tuple, stormvogel.model.State)
or isinstance(other_tuple, stormvogel.model.State)
)
if not (tuple[0] == other_tuple[0] and tuple[1] == other_tuple[1]):
return False
return self.model == other.model
else:
return False

Expand Down
2 changes: 1 addition & 1 deletion tests/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ def test_convert_model_checker_results_mdp():
0,
0,
] == [
int(list(action.labels)[0])
int(list(action.name)[0])
for action in stormvogel_result.scheduler.taken_actions.values()
]

Expand Down
6 changes: 3 additions & 3 deletions tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def test_simulate_path():
other_path = stormvogel.simulator.Path(
{
1: (stormvogel.model.EmptyAction, pomdp.get_state_by_id(3)),
2: (pomdp.actions["open2"], pomdp.get_state_by_id(12)),
3: (stormvogel.model.EmptyAction, pomdp.get_state_by_id(23)),
4: (pomdp.actions["switch"], pomdp.get_state_by_id(46)),
2: (pomdp.actions["open0"], pomdp.get_state_by_id(10)),
3: (stormvogel.model.EmptyAction, pomdp.get_state_by_id(21)),
4: (pomdp.actions["stay"], pomdp.get_state_by_id(41)),
},
pomdp,
)
Expand Down