Skip to content

Commit

Permalink
Merge pull request #135 from moves-rwth/124-create-get_sub_model-method
Browse files Browse the repository at this point in the history
124 create get sub model method
  • Loading branch information
PimLeerkes authored Nov 10, 2024
2 parents 67f7ff0 + 98a5564 commit 35bd070
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 10 deletions.
34 changes: 25 additions & 9 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from enum import Enum
from fractions import Fraction
from typing import cast
import copy

Parameter = str

Expand Down Expand Up @@ -489,7 +490,22 @@ def normalize(self):
# for ctmcs and mas we currently only add self loops
self.add_self_loops()

def __free_state_id(self):
def get_sub_model(self, states: list[State], normalize: bool = True) -> "Model":
"""Returns a submodel of the model based on a collection of states.
The states in the collection are the states that stay in the model."""
sub_model = copy.deepcopy(self)
remove = []
for state in sub_model.states.values():
if state not in states:
remove.append(state)
for state in remove:
sub_model.remove_state(state)

if normalize:
sub_model.normalize()
return sub_model

def __free_state_id(self) -> int:
"""Gets a free id in the states dict."""
# TODO: slow, not sure if that will become a problem though
i = 0
Expand Down Expand Up @@ -821,7 +837,7 @@ def set_rate(self, state: State, rate: Number):
raise RuntimeError("Cannot set a rate of a deterministic-time model.")
self.exit_rates[state.id] = rate

def get_type(self):
def get_type(self) -> ModelType:
"""Gets the type of this model"""
return self.type

Expand Down Expand Up @@ -869,7 +885,7 @@ def __str__(self) -> str:

return "\n".join(res)

def __eq__(self, other):
def __eq__(self, other) -> bool:
if isinstance(other, Model):
return (
self.type == other.type
Expand All @@ -883,33 +899,33 @@ def __eq__(self, other):
return False


def new_dtmc(name: str | None = None, create_initial_state: bool = True):
def new_dtmc(name: str | None = None, create_initial_state: bool = True) -> Model:
"""Creates a DTMC."""
return Model(name, ModelType.DTMC, create_initial_state)


def new_mdp(name: str | None = None, create_initial_state: bool = True):
def new_mdp(name: str | None = None, create_initial_state: bool = True) -> Model:
"""Creates an MDP."""
return Model(name, ModelType.MDP, create_initial_state)


def new_ctmc(name: str | None = None, create_initial_state: bool = True):
def new_ctmc(name: str | None = None, create_initial_state: bool = True) -> Model:
"""Creates a CTMC."""
return Model(name, ModelType.CTMC, create_initial_state)


def new_pomdp(name: str | None = None, create_initial_state: bool = True):
def new_pomdp(name: str | None = None, create_initial_state: bool = True) -> Model:
"""Creates a POMDP."""
return Model(name, ModelType.POMDP, create_initial_state)


def new_ma(name: str | None = None, create_initial_state: bool = True):
def new_ma(name: str | None = None, create_initial_state: bool = True) -> Model:
"""Creates a MA."""
return Model(name, ModelType.MA, create_initial_state)


def new_model(
modeltype: ModelType, name: str | None = None, create_initial_state: bool = True
):
) -> Model:
"""More general model creation function"""
return Model(name, modeltype, create_initial_state)
19 changes: 18 additions & 1 deletion tests/test_model_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_transition_from_shorthand():
# Then we test it for a model with actions
mdp = stormvogel.model.new_mdp()
state = mdp.new_state()
action = mdp.new_action("0", frozenset("action"))
action = mdp.new_action("0", frozenset({"action"}))
transition_shorthand = [(action, state)]
branch = stormvogel.model.Branch(
cast(list[tuple[stormvogel.model.Number, stormvogel.model.State]], [(1, state)])
Expand Down Expand Up @@ -302,3 +302,20 @@ def test_add_transitions():
# print(mdp6.get_transitions(mdp6.get_initial_state()).transition)
# print([(action6a, state6), (action6b, state6)])
assert len(mdp6.get_transitions(mdp6.get_initial_state()).transition) == 2


def test_get_sub_model():
# we create the die dtmc and take a submodel
dtmc = examples.die.create_die_dtmc()
states = [dtmc.get_state_by_id(0), dtmc.get_state_by_id(1), dtmc.get_state_by_id(2)]
sub_model = dtmc.get_sub_model(states)

# we build what the submodel should look like
new_dtmc = stormvogel.model.new_dtmc("Die")
init = new_dtmc.get_initial_state()
init.set_transitions(
[(1 / 6, new_dtmc.new_state(f"rolled{i}", {"rolled": i})) for i in range(2)]
)
new_dtmc.normalize()

assert sub_model == new_dtmc

0 comments on commit 35bd070

Please sign in to comment.