Skip to content

Commit

Permalink
Merge pull request #150 from moves-rwth/149-make-it-clear-which-id-co…
Browse files Browse the repository at this point in the history
…rresponds-to-which-state-action-pair

149 make it clear which id corresponds to which state action pair
  • Loading branch information
PimLeerkes authored Nov 17, 2024
2 parents 58f007a + 9a4bd28 commit 7909cca
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 25 deletions.
4 changes: 2 additions & 2 deletions stormvogel/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,9 @@ def add_rewards(
else rewards.state_rewards
):
if model.supports_actions():
rewardmodel.set_action_state(index, reward)
rewardmodel.set_state_action_reward_at_id(index, reward)
else:
rewardmodel.set(model.get_state_by_id(index), reward)
rewardmodel.set_state_reward(model.get_state_by_id(index), reward)

def map_dtmc(sparsedtmc: stormpy.storage.SparseDtmc) -> stormvogel.model.Model:
"""
Expand Down
98 changes: 85 additions & 13 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def add_transitions(self, transitions: "Transition | TransitionShorthand"):

def available_actions(self) -> list["Action"]:
"""returns the list of all available actions in this state"""
if self.model.supports_actions():
if self.model.supports_actions() and self.id in self.model.transitions.keys():
action_list = []
for action in self.model.transitions[self.id].transition.keys():
action_list.append(action)
Expand Down Expand Up @@ -325,7 +325,7 @@ def transition_from_shorthand(shorthand: TransitionShorthand) -> Transition:
)


@dataclass(order=True)
@dataclass()
class RewardModel:
"""Represents a state-exit reward model.
Args:
Expand All @@ -334,20 +334,79 @@ class RewardModel:
"""

name: str
model: "Model"
# Hashed by the id of the state or state action pair (=number in the matrix)
rewards: dict[int, Number]

def __init__(self, name: str, model: "Model", rewards: dict[int, Number]):
    """Create a reward model.

    Args:
        name: name of this reward model.
        model: the model this reward model belongs to (used to decide
            whether rewards are keyed by state id or state-action pair id).
        rewards: mapping from a state id (or, for models with actions, a
            state-action pair id) to the reward at that id.
    """
    self.name = name
    self.rewards = rewards
    self.model = model

    # NOTE(review): these attributes look vestigial — `set_action_state`
    # shadows the name of a removed method and `state_action_pair` is never
    # read anywhere visible in this class. Confirm they can be dropped.
    if self.model.supports_actions():
        self.set_action_state = {}
    else:
        self.state_action_pair = None

def get_state_reward(self, state: State) -> Number:
    """Gets the reward at the given state, looked up by state id.

    Raises KeyError if no reward was set for this state.
    """
    return self.rewards[state.id]

def get_state_action_reward(self, state: State, action: Action) -> Number | None:
    """Gets the reward at the given state-action pair.

    Only valid for models with actions; the reward is looked up by the
    state-action pair's id (its row in the transition matrix).

    Raises:
        RuntimeError: if the model has no actions, or the action is not
            available in the given state.
        KeyError: if no reward was set for this pair.
    """
    # Bug fix: the RuntimeError instances below were previously constructed
    # but never raised, so error paths silently returned None.
    if not self.model.supports_actions():
        raise RuntimeError(
            "The model this rewardmodel belongs to does not support actions"
        )
    if action not in state.available_actions():
        raise RuntimeError("This action is not available in this state")
    id = self.model.get_state_action_id(state, action)
    assert id is not None
    return self.rewards[id]

def set_state_reward(self, state: State, value: Number):
    """Sets the reward at said state.

    Raises:
        RuntimeError: if the model supports actions — use
            set_state_action_reward(_at_id) for such models.
    """
    if self.model.supports_actions():
        # Bug fix: the RuntimeError was constructed but never raised, and the
        # message referred to a non-existent "set_action_state_reward" method.
        raise RuntimeError(
            "This is a model with actions. Please call the set_state_action_reward(_at_id) function instead"
        )
    else:
        self.rewards[state.id] = value

def set_state_action_reward(self, state: State, action: Action, value: Number):
    """Sets the reward at said state-action pair (for models with actions).

    Raises:
        RuntimeError: if the model has no actions, or the action is not
            available in the given state.
    """
    # Bug fix: the RuntimeError instances below were previously constructed
    # but never raised, so invalid calls were silently ignored.
    if not self.model.supports_actions():
        raise RuntimeError(
            "The model this rewardmodel belongs to does not support actions"
        )
    if action not in state.available_actions():
        raise RuntimeError("This action is not available in this state")
    id = self.model.get_state_action_id(state, action)
    assert id is not None
    self.rewards[id] = value

def set_state_action_reward_at_id(self, action_state: int, value: Number):
    """Sets the reward for the state-action pair with the given id
    (in the case of models with actions).

    Raises:
        RuntimeError: if the model does not support actions.
    """
    if self.model.supports_actions():
        self.rewards[action_state] = value
    else:
        # Bug fix: the RuntimeError was constructed but never raised.
        raise RuntimeError(
            "The model this rewardmodel belongs to does not support actions"
        )

def __lt__(self, other) -> bool:
    """Order reward models by name (replaces the removed dataclass order=True)."""
    if not isinstance(other, RewardModel):
        return NotImplemented
    return self.name < other.name

def __eq__(self, other) -> bool:
    """Two reward models are equal iff their name and rewards mapping match.

    The owning model is deliberately excluded from the comparison.
    """
    if isinstance(other, RewardModel):
        return self.name == other.name and self.rewards == other.rewards
    return False


@dataclass
Expand Down Expand Up @@ -416,15 +475,15 @@ def __init__(

def supports_actions(self):
    """Returns whether this model supports actions (MDP, POMDP or MA)."""
    return self.get_type() in (ModelType.MDP, ModelType.POMDP, ModelType.MA)

def supports_rates(self):
    """Returns whether this model supports rates (CTMC or MA)."""
    return self.get_type() in (ModelType.CTMC, ModelType.MA)

def supports_observations(self):
    """Returns whether this model supports observations (POMDP only)."""
    return self.get_type() == ModelType.POMDP

def is_stochastic(self) -> bool:
"""For discrete models: Checks if all sums of outgoing transition probabilities for all states equal 1
Expand Down Expand Up @@ -515,6 +574,19 @@ def get_sub_model(self, states: list[State], normalize: bool = True) -> "Model":
sub_model.normalize()
return sub_model

def get_state_action_id(self, state: State, action: Action) -> int | None:
    """Calculate the state-action id for a given state and action.

    The id is the row index of the pair in the transition matrix: pairs are
    counted over states in dictionary order, and within each state over its
    available actions in order.

    Returns:
        The pair's id, or None when the pair is not found.
    """
    id = 0
    for s in self.states.values():
        for a in s.available_actions():
            # Cheapest check first; the membership test is kept because
            # matching by name alone does not imply Action equality.
            if (
                s == state
                and a.name == action.name
                and action in s.available_actions()
            ):
                return id
            id += 1
    # Explicit return instead of the previous implicit fall-through.
    return None

def __free_state_id(self) -> int:
"""Gets a free id in the states dict."""
# TODO: slow, not sure if that will become a problem though
def add_rewards(self, name: str) -> RewardModel:
    """Creates a new empty reward model with the given name, registers it on
    this model and returns it.

    Raises:
        RuntimeError: if a reward model with this name already exists.
    """
    for model in self.rewards:
        if model.name == name:
            raise RuntimeError(f"Reward model {name} already present in model.")
    reward_model = RewardModel(name, self, {})
    self.rewards.append(reward_model)
    return reward_model

Expand Down
22 changes: 15 additions & 7 deletions stormvogel/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,17 @@ def simulate(
for index, reward in enumerate(model.rewards):
reward_model = partial_model.add_rewards(model.rewards[index].name)

# we already set the rewards for the initial state
reward_model.set(
partial_model.get_initial_state(),
model.rewards[index].get(model.get_initial_state()),
)
# we already set the rewards for the initial state/stateaction
if model.supports_actions():
reward_model.set_state_action_reward_at_id(
partial_model.get_initial_state().id,
model.rewards[index].get_state_reward(model.get_initial_state()),
)
else:
reward_model.set_state_reward(
partial_model.get_initial_state(),
model.rewards[index].get_state_reward(model.get_initial_state()),
)

# now we start stepping through the model
discovered_states = {0}
Expand Down Expand Up @@ -276,7 +282,7 @@ def simulate(

# we add the rewards
for index, rewardmodel in enumerate(partial_model.rewards):
rewardmodel.set(new_state, reward[index])
rewardmodel.set_state_reward(new_state, reward[index])

last_state_id = state_id
if simulator.is_done():
Expand Down Expand Up @@ -310,7 +316,9 @@ def simulate(
state_id
)
state_action_pair = row_group + select_action
rewardmodel.set_action_state(state_action_pair, reward[index])
rewardmodel.set_state_action_reward_at_id(
state_action_pair, reward[index]
)

# we add the state
state_id, labels = discovery[0], discovery[2]
Expand Down
2 changes: 1 addition & 1 deletion stormvogel/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def __format_rewards(self, s: stormvogel.model.State) -> str:
res = ""
for reward_model in self.model.rewards:
try:
res += f"\n{reward_model.name}: {reward_model.get(s)}"
res += f"\n{reward_model.name}: {reward_model.get_state_reward(s)}"
except (
KeyError
): # If this reward model does not have a reward for this state.
Expand Down
58 changes: 58 additions & 0 deletions tests/test_model_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ def test_available_actions():
]
assert mdp.get_state_by_id(1).available_actions() == action

# we also test it for a state with no available actions
mdp = stormvogel.model.new_mdp()
assert mdp.get_initial_state().available_actions()


def test_get_outgoing_transitions():
mdp = examples.monty_hall.create_monty_hall_mdp()
Expand Down Expand Up @@ -319,3 +323,57 @@ def test_get_sub_model():
new_dtmc.normalize()

assert sub_model == new_dtmc


def test_get_state_action_id():
    """The second action of state 2 in the Monty Hall mdp maps to pair id 5."""
    model = examples.monty_hall.create_monty_hall_mdp()
    chosen_state = model.get_state_by_id(2)
    chosen_action = chosen_state.available_actions()[1]
    assert model.get_state_action_id(chosen_state, chosen_action) == 5


def test_get_state_action_reward():
    """Looking up a state-action reward returns the value stored at its pair id."""
    model = examples.monty_hall.create_monty_hall_mdp()

    # Reward model where every pair id i simply gets reward i.
    reward_model = model.add_rewards("rewardmodel")
    for pair_id in range(67):
        reward_model.rewards[pair_id] = pair_id

    target_state = model.get_state_by_id(2)
    target_action = target_state.available_actions()[1]
    # That pair has id 5, so its reward must be 5.
    assert reward_model.get_state_action_reward(target_state, target_action) == 5


def test_set_state_action_reward():
    """Setting a state-action reward matches a manually built reward model."""
    # Case 1: a fresh single-state mdp with one self-loop action.
    model = stormvogel.model.new_mdp()
    loop_action = stormvogel.model.Action("0", frozenset())
    model.add_transitions(
        model.get_initial_state(), [(loop_action, model.get_initial_state())]
    )

    reward_model = model.add_rewards("rewardmodel")
    reward_model.set_state_action_reward(model.get_initial_state(), loop_action, 5)

    expected = stormvogel.model.RewardModel("rewardmodel", model, {0: 5})
    assert reward_model == expected

    # Case 2: the Monty Hall mdp with a single reward at pair id 5.
    model = examples.monty_hall.create_monty_hall_mdp()
    reward_model = model.add_rewards("rewardmodel")
    target_state = model.get_state_by_id(2)
    target_action = target_state.available_actions()[1]
    reward_model.set_state_action_reward(target_state, target_action, 3)

    expected = stormvogel.model.RewardModel("rewardmodel", model, {5: 3})
    assert reward_model == expected
4 changes: 2 additions & 2 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def test_rewards(mocker):
model, one, init = simple_model()
model.set_transitions(init, [(1, one)])
model.add_rewards("LOL")
model.get_rewards("LOL").set(one, 37)
model.get_rewards("LOL").set_state_reward(one, 37)
model.add_rewards("HIHI")
model.get_rewards("HIHI").set(one, 42)
model.get_rewards("HIHI").set_state_reward(one, 42)
vis = Visualization(model=model)
vis.show()
MockNetwork.add_node.assert_any_call(
Expand Down

0 comments on commit 7909cca

Please sign in to comment.