Skip to content

Commit

Permalink
get_sub_model method works
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Nov 10, 2024
1 parent 67f7ff0 commit 94c744c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
14 changes: 14 additions & 0 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from enum import Enum
from fractions import Fraction
from typing import cast
import copy

Parameter = str

Expand Down Expand Up @@ -489,6 +490,19 @@ def normalize(self):
# for ctmcs and mas we currently only add self loops
self.add_self_loops()

def get_sub_model(self, states: list[State]) -> "Model":
"""returns a submodel of the model based on a collection of states"""
sub_model = copy.deepcopy(self)
remove = []
for state in sub_model.states.values():
if state not in states:
remove.append(state)
for state in remove:
sub_model.remove_state(state)

sub_model.normalize()
return sub_model

def __free_state_id(self):
"""Gets a free id in the states dict."""
# TODO: slow, not sure if that will become a problem though
Expand Down
19 changes: 18 additions & 1 deletion tests/test_model_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_transition_from_shorthand():
# Then we test it for a model with actions
mdp = stormvogel.model.new_mdp()
state = mdp.new_state()
action = mdp.new_action("0", frozenset("action"))
action = mdp.new_action("0", frozenset({"action"}))
transition_shorthand = [(action, state)]
branch = stormvogel.model.Branch(
cast(list[tuple[stormvogel.model.Number, stormvogel.model.State]], [(1, state)])
Expand Down Expand Up @@ -302,3 +302,20 @@ def test_add_transitions():
# print(mdp6.get_transitions(mdp6.get_initial_state()).transition)
# print([(action6a, state6), (action6b, state6)])
assert len(mdp6.get_transitions(mdp6.get_initial_state()).transition) == 2


def test_get_sub_model():
# we create the die dtmc and take a submodel
dtmc = examples.die.create_die_dtmc()
states = [dtmc.get_state_by_id(0), dtmc.get_state_by_id(1), dtmc.get_state_by_id(2)]
sub_model = dtmc.get_sub_model(states)

# we build what the submodel should look like
new_dtmc = stormvogel.model.new_dtmc("Die")
init = new_dtmc.get_initial_state()
init.set_transitions(
[(1 / 6, new_dtmc.new_state(f"rolled{i}", {"rolled": i})) for i in range(2)]
)
new_dtmc.normalize()

assert sub_model == new_dtmc

0 comments on commit 94c744c

Please sign in to comment.