Skip to content

Commit

Permalink
fixed ma remove_state bug and added: is_absorbing function
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Nov 6, 2024
1 parent 5b57a3b commit 4f141f1
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 28 deletions.
5 changes: 1 addition & 4 deletions examples/simple_ma.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def create_simple_ma():

init = ma.get_initial_state()

# We have 2 actions
# We have 5 actions
init.set_transitions(
[
(
Expand All @@ -29,9 +29,6 @@ def create_simple_ma():
# we add self loops to all states with no outgoing transitions
ma.add_self_loops()

# we delete a state
ma.remove_state(ma.get_state_by_id(3), True)

return ma


Expand Down
28 changes: 15 additions & 13 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,14 @@ def get_outgoing_transitions(
return branch.branch
return None

def has_outgoing_transition(self, action: "Action | None" = None) -> bool:
"""returns if the state has a nonzero outgoing transition or not"""
def is_absorbing(self, action: "Action | None" = None) -> bool:
"""returns if the state has a nonzero transition going to another state or not"""
transitions = self.get_outgoing_transitions(action)
if transitions is not None:
for transition in transitions:
if float(transition[0]) > 0:
return True
return False
if float(transition[0]) > 0 and transition[1] != self:
return False
return True

def __str__(self):
res = f"State {self.id} with labels {self.labels} and features {self.features}"
Expand Down Expand Up @@ -630,20 +630,22 @@ def remove_state(
# first we remove transitions that go into the state
remove_actions_index = []
for index, transition in self.transitions.items():
for action in transition.transition.items():
for index_tuple, tuple in enumerate(action[1].branch):
for action, branch in transition.transition.items():
for index_tuple, tuple in enumerate(branch.branch):
# remove the tuple if it goes to the state
if tuple[1].id == state.id:
self.transitions[index].transition[action[0]].branch.pop(
self.transitions[index].transition[action].branch.pop(
index_tuple
)

# if we have empty objects we need to remove those as well
if self.transitions[index].transition[action[0]].branch == []:
remove_actions_index.append((action[0], index))
# here we remove those empty objects
# if we have empty actions we need to remove those as well (later)
if branch.branch == []:
remove_actions_index.append((action, index))
# here we remove those empty actions (this needs to happen after the other for loops)
for action, index in remove_actions_index:
self.transitions[index].transition.pop(action)
if self.transitions[index].transition == {}:
# if we have no actions at all anymore, delete the transition
if self.transitions[index].transition == {} and not index == state.id:
self.transitions.pop(index)

# we remove transitions that come out of the state
Expand Down
7 changes: 2 additions & 5 deletions stormvogel/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,7 @@ def get_range_index(stateid: int):
if not model.supports_actions():
for i in range(steps):
# for each step we add a state to the path
if (
model.states[state].has_outgoing_transition()
and not simulator.is_done()
):
if not model.states[state].is_absorbing() and not simulator.is_done():
state, reward, labels = simulator.step()
path[i + 1] = model.states[state]
else:
Expand All @@ -159,7 +156,7 @@ def get_range_index(stateid: int):
stormvogel_action = model.states[state].available_actions()[select_action]

if (
model.states[state].has_outgoing_transition(stormvogel_action)
not model.states[state].is_absorbing(stormvogel_action)
and not simulator.is_done()
):
state, reward, labels = simulator.step(actions[select_action])
Expand Down
7 changes: 1 addition & 6 deletions tests/test_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def sparse_equal(
)


"""
def test_stormpy_to_stormvogel_and_back_dtmc():
# we test it for an example stormpy representation of a dtmc
stormpy_dtmc = examples.stormpy_dtmc.example_building_dtmcs_01()
Expand Down Expand Up @@ -159,7 +158,6 @@ def test_stormvogel_to_stormpy_and_back_mdp():
# print(new_stormvogel_mdp)

assert new_stormvogel_mdp == stormvogel_mdp
"""


def test_stormvogel_to_stormpy_and_back_ctmc():
Expand All @@ -174,7 +172,6 @@ def test_stormvogel_to_stormpy_and_back_ctmc():
assert new_stormvogel_ctmc == stormvogel_ctmc


"""
def test_stormpy_to_stormvogel_and_back_ctmc():
# we create a stormpy representation of an example ctmc
stormpy_ctmc = examples.stormpy_ctmc.example_building_ctmcs_01()
Expand Down Expand Up @@ -211,9 +208,8 @@ def test_stormpy_to_stormvogel_and_back_pomdp():
# print(new_stormpy_pomdp)

assert sparse_equal(stormpy_pomdp, new_stormpy_pomdp)
"""

"""

def test_stormvogel_to_stormpy_and_back_ma():
# we create a stormpy representation of an example ma
stormvogel_ma = examples.simple_ma.create_simple_ma()
Expand All @@ -237,4 +233,3 @@ def test_stormpy_to_stormvogel_and_back_ma():
# print(new_stormpy_ma)

assert sparse_equal(stormpy_ma, new_stormpy_ma)
"""

0 comments on commit 4f141f1

Please sign in to comment.