Skip to content

Commit

Permalink
added more tests and fixed available actions case where state has no …
Browse files Browse the repository at this point in the history
…actions
  • Loading branch information
PimLeerkes committed Nov 17, 2024
1 parent 5c58880 commit 9a4bd28
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
8 changes: 4 additions & 4 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def add_transitions(self, transitions: "Transition | TransitionShorthand"):

def available_actions(self) -> list["Action"]:
"""returns the list of all available actions in this state"""
if self.model.supports_actions():
if self.model.supports_actions() and self.id in self.model.transitions.keys():
action_list = []
for action in self.model.transitions[self.id].transition.keys():
action_list.append(action)
Expand Down Expand Up @@ -360,7 +360,7 @@ def get_state_action_reward(self, state: State, action: Action) -> Number | None
assert id is not None
return self.rewards[id]
else:
RuntimeError("This action is not a choice for this state")
RuntimeError("This action is not available in this state")
else:
RuntimeError(
"The model this rewardmodel belongs to does not support actions"
Expand All @@ -383,7 +383,7 @@ def set_state_action_reward(self, state: State, action: Action, value: Number):
assert id is not None
self.rewards[id] = value
else:
RuntimeError("This action is not a choice for this state")
RuntimeError("This action is not available in this state")
else:
RuntimeError(
"The model this rewardmodel belongs to does not support actions"
Expand Down Expand Up @@ -575,7 +575,7 @@ def get_sub_model(self, states: list[State], normalize: bool = True) -> "Model":
return sub_model

def get_state_action_id(self, state: State, action: Action) -> int | None:
"""we calculate the appropriate state_action_id for a given state and action"""
"""we calculate the appropriate state action id for a given state and action"""
id = 0
for s in self.states.values():
for a in s.available_actions():
Expand Down
18 changes: 18 additions & 0 deletions tests/test_model_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ def test_available_actions():
]
assert mdp.get_state_by_id(1).available_actions() == action

# we also test it for a state with no available actions
mdp = stormvogel.model.new_mdp()
assert mdp.get_initial_state().available_actions()


def test_get_outgoing_transitions():
mdp = examples.monty_hall.create_monty_hall_mdp()
Expand Down Expand Up @@ -359,3 +363,17 @@ def test_set_state_action_reward():
other_rewardmodel = stormvogel.model.RewardModel("rewardmodel", mdp, {0: 5})

assert rewardmodel == other_rewardmodel

# we create an mdp:
mdp = examples.monty_hall.create_monty_hall_mdp()

# we add a reward model with only one reward
rewardmodel = mdp.add_rewards("rewardmodel")
state = mdp.get_state_by_id(2)
action = state.available_actions()[1]
rewardmodel.set_state_action_reward(state, action, 3)

# we make a reward model manually:
other_rewardmodel = stormvogel.model.RewardModel("rewardmodel", mdp, {5: 3})

assert rewardmodel == other_rewardmodel

0 comments on commit 9a4bd28

Please sign in to comment.