Skip to content

Commit

Permalink
warning message for delete state function
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Feb 4, 2025
1 parent bfdd07c commit 37f8f53
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
23 changes: 16 additions & 7 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,11 @@ def __init__(
self.observation = None

if name is None:
self.name = str(id) # TODO Two states can have same name in some cases
if str(id) in used_names:
raise RuntimeError(
"You need to choose a state name because of a conflict caused by removal of states."
)
self.name = str(id)
else:
self.name = name

Expand Down Expand Up @@ -769,9 +773,15 @@ def reassign_ids(self):
}

def remove_state(
self, state: State, normalize: bool = True, reassign_ids: bool = True
self, state: State, normalize: bool = True, reassign_ids: bool = False
):
"""properly removes a state, it can optionally normalize the model and reassign ids automatically"""
"""Properly removes a state, it can optionally normalize the model and reassign ids automatically."""

if reassign_ids:
print(
"Warning: Using this can cause problems in your code if there are existing references to states by id."
)

if state in self.states.values():
# we remove the state from the transitions
# first we remove transitions that go into the state
Expand Down Expand Up @@ -885,16 +895,15 @@ def new_state(
self,
labels: list[str] | str | None = None,
features: dict[str, int] | None = None,
name: str | None = None,
) -> State:
"""Creates a new state and returns it."""
state_id = self.__free_state_id()
if isinstance(labels, list):
state = State(labels, features or {}, state_id, self, name=name)
state = State(labels, features or {}, state_id, self)
elif isinstance(labels, str):
state = State([labels], features or {}, state_id, self, name=name)
state = State([labels], features or {}, state_id, self)
elif labels is None:
state = State([], features or {}, state_id, self, name=name)
state = State([], features or {}, state_id, self)

self.states[state_id] = state

Expand Down
12 changes: 10 additions & 2 deletions tests/test_model_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_normalize():
def test_remove_state():
# we make a normal ctmc and remove a state
ctmc = examples.nuclear_fusion_ctmc.create_nuclear_fusion_ctmc()
ctmc.remove_state(ctmc.get_state_by_id(3))
ctmc.remove_state(ctmc.get_state_by_id(3), reassign_ids=True)

# we make a ctmc with the state already missing
new_ctmc = stormvogel.model.new_ctmc("Nuclear fusion")
Expand Down Expand Up @@ -193,7 +193,7 @@ def test_remove_state():
mdp.set_transitions(mdp.get_initial_state(), transition)

# we remove a state
mdp.remove_state(mdp.get_state_by_id(0))
mdp.remove_state(mdp.get_state_by_id(0), reassign_ids=True)

# we make the mdp with the state already missing
new_mdp = stormvogel.model.new_mdp(create_initial_state=False)
Expand All @@ -203,6 +203,14 @@ def test_remove_state():

assert mdp == new_mdp

# this should fail:
new_dtmc = examples.die.create_die_dtmc()
state0 = new_dtmc.get_state_by_id(0)
new_dtmc.remove_state(new_dtmc.get_initial_state(), reassign_ids=True)
state1 = new_dtmc.get_state_by_id(0)

assert state0 != state1


def test_remove_transitions_between_states():
# we make a model and remove transitions between two states
Expand Down

0 comments on commit 37f8f53

Please sign in to comment.