Skip to content

Commit

Permalink
small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Nov 10, 2024
1 parent 94c744c commit 98a5564
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,9 @@ def normalize(self):
# for ctmcs and mas we currently only add self loops
self.add_self_loops()

def get_sub_model(self, states: list[State]) -> "Model":
"""returns a submodel of the model based on a collection of states"""
def get_sub_model(self, states: list[State], normalize: bool = True) -> "Model":
"""Returns a submodel of the model based on a collection of states.
The states in the collection are the states that stay in the model."""
sub_model = copy.deepcopy(self)
remove = []
for state in sub_model.states.values():
Expand All @@ -500,10 +501,11 @@ def get_sub_model(self, states: list[State]) -> "Model":
for state in remove:
sub_model.remove_state(state)

sub_model.normalize()
if normalize:
sub_model.normalize()
return sub_model

def __free_state_id(self):
def __free_state_id(self) -> int:
"""Gets a free id in the states dict."""
# TODO: slow, not sure if that will become a problem though
i = 0
Expand Down Expand Up @@ -835,7 +837,7 @@ def set_rate(self, state: State, rate: Number):
raise RuntimeError("Cannot set a rate of a deterministic-time model.")
self.exit_rates[state.id] = rate

def get_type(self):
def get_type(self) -> ModelType:
"""Gets the type of this model"""
return self.type

Expand Down Expand Up @@ -883,7 +885,7 @@ def __str__(self) -> str:

return "\n".join(res)

def __eq__(self, other):
def __eq__(self, other) -> bool:
if isinstance(other, Model):
return (
self.type == other.type
Expand All @@ -897,33 +899,33 @@ def __eq__(self, other):
return False


def new_dtmc(name: str | None = None, create_initial_state: bool = True):
def new_dtmc(name: str | None = None, create_initial_state: bool = True) -> Model:
"""Creates a DTMC."""
return Model(name, ModelType.DTMC, create_initial_state)


def new_mdp(name: str | None = None, create_initial_state: bool = True):
def new_mdp(name: str | None = None, create_initial_state: bool = True) -> Model:
"""Creates an MDP."""
return Model(name, ModelType.MDP, create_initial_state)


def new_ctmc(name: str | None = None, create_initial_state: bool = True):
def new_ctmc(name: str | None = None, create_initial_state: bool = True) -> Model:
"""Creates a CTMC."""
return Model(name, ModelType.CTMC, create_initial_state)


def new_pomdp(name: str | None = None, create_initial_state: bool = True):
def new_pomdp(name: str | None = None, create_initial_state: bool = True) -> Model:
"""Creates a POMDP."""
return Model(name, ModelType.POMDP, create_initial_state)


def new_ma(name: str | None = None, create_initial_state: bool = True):
def new_ma(name: str | None = None, create_initial_state: bool = True) -> Model:
"""Creates a MA."""
return Model(name, ModelType.MA, create_initial_state)


def new_model(
modeltype: ModelType, name: str | None = None, create_initial_state: bool = True
):
) -> Model:
"""More general model creation function"""
return Model(name, modeltype, create_initial_state)

0 comments on commit 98a5564

Please sign in to comment.