From 98a556421d07c65db94d06bc590c3442709d8db2 Mon Sep 17 00:00:00 2001 From: PimLeerkes Date: Sun, 10 Nov 2024 10:15:03 +0100 Subject: [PATCH] small changes --- stormvogel/model.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/stormvogel/model.py b/stormvogel/model.py index 5d239e1..12c59cb 100644 --- a/stormvogel/model.py +++ b/stormvogel/model.py @@ -490,8 +490,9 @@ def normalize(self): # for ctmcs and mas we currently only add self loops self.add_self_loops() - def get_sub_model(self, states: list[State]) -> "Model": - """returns a submodel of the model based on a collection of states""" + def get_sub_model(self, states: list[State], normalize: bool = True) -> "Model": + """Returns a submodel of the model based on a collection of states. + The states in the collection are the states that stay in the model.""" sub_model = copy.deepcopy(self) remove = [] for state in sub_model.states.values(): @@ -500,10 +501,11 @@ def get_sub_model(self, states: list[State]) -> "Model": for state in remove: sub_model.remove_state(state) - sub_model.normalize() + if normalize: + sub_model.normalize() return sub_model - def __free_state_id(self): + def __free_state_id(self) -> int: """Gets a free id in the states dict.""" # TODO: slow, not sure if that will become a problem though i = 0 @@ -835,7 +837,7 @@ def set_rate(self, state: State, rate: Number): raise RuntimeError("Cannot set a rate of a deterministic-time model.") self.exit_rates[state.id] = rate - def get_type(self): + def get_type(self) -> ModelType: """Gets the type of this model""" return self.type @@ -883,7 +885,7 @@ def __str__(self) -> str: return "\n".join(res) - def __eq__(self, other): + def __eq__(self, other) -> bool: if isinstance(other, Model): return ( self.type == other.type @@ -897,33 +899,33 @@ def __eq__(self, other): return False -def new_dtmc(name: str | None = None, create_initial_state: bool = True): +def new_dtmc(name: str | None = None, create_initial_state: bool = True) -> Model: """Creates a DTMC.""" return Model(name, ModelType.DTMC, create_initial_state) -def new_mdp(name: str | None = None, create_initial_state: bool = True): +def new_mdp(name: str | None = None, create_initial_state: bool = True) -> Model: """Creates an MDP.""" return Model(name, ModelType.MDP, create_initial_state) -def new_ctmc(name: str | None = None, create_initial_state: bool = True): +def new_ctmc(name: str | None = None, create_initial_state: bool = True) -> Model: """Creates a CTMC.""" return Model(name, ModelType.CTMC, create_initial_state) -def new_pomdp(name: str | None = None, create_initial_state: bool = True): +def new_pomdp(name: str | None = None, create_initial_state: bool = True) -> Model: """Creates a POMDP.""" return Model(name, ModelType.POMDP, create_initial_state) -def new_ma(name: str | None = None, create_initial_state: bool = True): +def new_ma(name: str | None = None, create_initial_state: bool = True) -> Model: """Creates a MA.""" return Model(name, ModelType.MA, create_initial_state) def new_model( modeltype: ModelType, name: str | None = None, create_initial_state: bool = True -): +) -> Model: """More general model creation function""" return Model(name, modeltype, create_initial_state)