From 37f8f534159a1393da5b50083043809fa1e0fb73 Mon Sep 17 00:00:00 2001 From: PimLeerkes Date: Tue, 4 Feb 2025 14:00:41 +0100 Subject: [PATCH] warning message for delete state function --- stormvogel/model.py | 23 ++++++++++++++++------- tests/test_model_methods.py | 12 ++++++++++-- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/stormvogel/model.py b/stormvogel/model.py index 6bef8de..0a3a6e5 100644 --- a/stormvogel/model.py +++ b/stormvogel/model.py @@ -92,7 +92,11 @@ def __init__( self.observation = None if name is None: - self.name = str(id) # TODO Two states can have same name in some cases + if str(id) in used_names: + raise RuntimeError( + "You need to choose a state name because of a conflict caused by removal of states." + ) + self.name = str(id) else: self.name = name @@ -769,9 +773,15 @@ def reassign_ids(self): } def remove_state( - self, state: State, normalize: bool = True, reassign_ids: bool = True + self, state: State, normalize: bool = True, reassign_ids: bool = False ): - """properly removes a state, it can optionally normalize the model and reassign ids automatically""" + """Properly removes a state, it can optionally normalize the model and reassign ids automatically.""" + + if reassign_ids: + print( + "Warning: Using this can cause problems in your code if there are existing references to states by id." + ) + if state in self.states.values(): # we remove the state from the transitions # first we remove transitions that go into the state @@ -885,16 +895,15 @@ def new_state( self, labels: list[str] | str | None = None, features: dict[str, int] | None = None, - name: str | None = None, ) -> State: """Creates a new state and returns it.""" state_id = self.__free_state_id() if isinstance(labels, list): - state = State(labels, features or {}, state_id, self, name=name) + state = State(labels, features or {}, state_id, self) elif isinstance(labels, str): - state = State([labels], features or {}, state_id, self, name=name) + state = State([labels], features or {}, state_id, self) elif labels is None: - state = State([], features or {}, state_id, self, name=name) + state = State([], features or {}, state_id, self) self.states[state_id] = state diff --git a/tests/test_model_methods.py b/tests/test_model_methods.py index 4ab92e3..943d71f 100644 --- a/tests/test_model_methods.py +++ b/tests/test_model_methods.py @@ -155,7 +155,7 @@ def test_normalize(): def test_remove_state(): # we make a normal ctmc and remove a state ctmc = examples.nuclear_fusion_ctmc.create_nuclear_fusion_ctmc() - ctmc.remove_state(ctmc.get_state_by_id(3)) + ctmc.remove_state(ctmc.get_state_by_id(3), reassign_ids=True) # we make a ctmc with the state already missing new_ctmc = stormvogel.model.new_ctmc("Nuclear fusion") @@ -193,7 +193,7 @@ def test_remove_state(): mdp.set_transitions(mdp.get_initial_state(), transition) # we remove a state - mdp.remove_state(mdp.get_state_by_id(0)) + mdp.remove_state(mdp.get_state_by_id(0), reassign_ids=True) # we make the mdp with the state already missing new_mdp = stormvogel.model.new_mdp(create_initial_state=False) @@ -203,6 +203,14 @@ def test_remove_state(): assert mdp == new_mdp + # this should fail: + new_dtmc = examples.die.create_die_dtmc() + state0 = new_dtmc.get_state_by_id(0) + new_dtmc.remove_state(new_dtmc.get_initial_state(), reassign_ids=True) + state1 = new_dtmc.get_state_by_id(0) + + assert state0 != state1 + def test_remove_transitions_between_states(): # we make a model and remove transitions between two states