diff --git a/stormvogel/mapping.py b/stormvogel/mapping.py
index 4b006f4..74df716 100644
--- a/stormvogel/mapping.py
+++ b/stormvogel/mapping.py
@@ -348,9 +348,9 @@ def add_rewards(
                 else rewards.state_rewards
             ):
                 if model.supports_actions():
-                    rewardmodel.set_action_state(index, reward)
+                    rewardmodel.set_state_action_reward_at_id(index, reward)
                 else:
-                    rewardmodel.set(model.get_state_by_id(index), reward)
+                    rewardmodel.set_state_reward(model.get_state_by_id(index), reward)
 
 def map_dtmc(sparsedtmc: stormpy.storage.SparseDtmc) -> stormvogel.model.Model:
     """
diff --git a/stormvogel/model.py b/stormvogel/model.py
index 2e4e3b8..5d59002 100644
--- a/stormvogel/model.py
+++ b/stormvogel/model.py
@@ -133,7 +133,7 @@ def add_transitions(self, transitions: "Transition | TransitionShorthand"):
 
     def available_actions(self) -> list["Action"]:
         """returns the list of all available actions in this state"""
-        if self.model.supports_actions():
+        if self.model.supports_actions() and self.id in self.model.transitions.keys():
             action_list = []
             for action in self.model.transitions[self.id].transition.keys():
                 action_list.append(action)
@@ -325,7 +325,7 @@ def transition_from_shorthand(shorthand: TransitionShorthand) -> Transition:
     )
 
 
-@dataclass(order=True)
+@dataclass()
 class RewardModel:
     """Represents a state-exit reward model.
     Args:
@@ -334,20 +334,79 @@ class RewardModel:
     """
 
     name: str
+    model: "Model"
     # Hashed by the id of the state or state action pair (=number in the matrix)
     rewards: dict[int, Number]
 
-    def get(self, state: State) -> Number:
-        """Gets the reward at said state."""
+    def __init__(self, name: str, model: "Model", rewards: dict[int, Number]):
+        self.name = name
+        self.rewards = rewards
+        self.model = model
+
+        if self.model.supports_actions():
+            self.set_action_state = {}
+        else:
+            self.state_action_pair = None
+
+    def get_state_reward(self, state: State) -> Number:
+        """Gets the reward at said state."""
         return self.rewards[state.id]
 
-    def set(self, state: State, value: Number):
+    def get_state_action_reward(self, state: State, action: Action) -> Number | None:
+        """Gets the reward at said state-action pair."""
+        if self.model.supports_actions():
+            if action in state.available_actions():
+                id = self.model.get_state_action_id(state, action)
+                assert id is not None
+                return self.rewards[id]
+            else:
+                raise RuntimeError("This action is not available in this state")
+        else:
+            raise RuntimeError(
+                "The model this rewardmodel belongs to does not support actions"
+            )
+
+    def set_state_reward(self, state: State, value: Number):
         """Sets the reward at said state."""
-        self.rewards[state.id] = value
+        if self.model.supports_actions():
+            raise RuntimeError(
+                "This is a model with actions. Please call the set_state_action_reward(_at_id) function instead"
+            )
+        else:
+            self.rewards[state.id] = value
+
+    def set_state_action_reward(self, state: State, action: Action, value: Number):
+        """Sets the reward at said state-action pair (for models with actions)."""
+        if self.model.supports_actions():
+            if action in state.available_actions():
+                id = self.model.get_state_action_id(state, action)
+                assert id is not None
+                self.rewards[id] = value
+            else:
+                raise RuntimeError("This action is not available in this state")
+        else:
+            raise RuntimeError(
+                "The model this rewardmodel belongs to does not support actions"
+            )
+
+    def set_state_action_reward_at_id(self, action_state: int, value: Number):
+        """Sets the reward at the state-action pair with the given id (for models with actions)."""
+        if self.model.supports_actions():
+            self.rewards[action_state] = value
+        else:
+            raise RuntimeError(
+                "The model this rewardmodel belongs to does not support actions"
+            )
+
+    def __lt__(self, other) -> bool:
+        if not isinstance(other, RewardModel):
+            return NotImplemented
+        return self.name < other.name
 
-    def set_action_state(self, state_action_pair: int, value: Number):
-        """sets the reward at said state action pair"""
-        self.rewards[state_action_pair] = value
+    def __eq__(self, other) -> bool:
+        if isinstance(other, RewardModel):
+            return self.name == other.name and self.rewards == other.rewards
+        return False
 
 
 @dataclass
@@ -416,15 +475,15 @@ def __init__(
 
     def supports_actions(self):
         """Returns whether this model supports actions."""
-        return self.type in (ModelType.MDP, ModelType.POMDP, ModelType.MA)
+        return self.get_type() in (ModelType.MDP, ModelType.POMDP, ModelType.MA)
 
     def supports_rates(self):
         """Returns whether this model supports rates."""
-        return self.type in (ModelType.CTMC, ModelType.MA)
+        return self.get_type() in (ModelType.CTMC, ModelType.MA)
 
     def supports_observations(self):
         """Returns whether this model supports observations."""
-        return self.type == ModelType.POMDP
+        return self.get_type() == ModelType.POMDP
 
     def is_stochastic(self) -> bool:
         """For discrete models: Checks if all sums of outgoing transition probabilities for all states equal 1
@@ -515,6 +574,19 @@ def get_sub_model(self, states: list[State], normalize: bool = True) -> "Model":
             sub_model.normalize()
         return sub_model
 
+    def get_state_action_id(self, state: State, action: Action) -> int | None:
+        """Computes the state-action id for a given state and action."""
+        id = 0
+        for s in self.states.values():
+            for a in s.available_actions():
+                if (
+                    a.name == action.name
+                    and action in s.available_actions()
+                    and s == state
+                ):
+                    return id
+                id += 1
+
     def __free_state_id(self) -> int:
         """Gets a free id in the states dict."""
         # TODO: slow, not sure if that will become a problem though
@@ -824,7 +896,7 @@ def add_rewards(self, name: str) -> RewardModel:
         for model in self.rewards:
             if model.name == name:
                 raise RuntimeError(f"Reward model {name} already present in model.")
-        reward_model = RewardModel(name, {})
+        reward_model = RewardModel(name, self, {})
         self.rewards.append(reward_model)
         return reward_model
 
diff --git a/stormvogel/simulator.py b/stormvogel/simulator.py
index 068a96c..edcf3bd 100644
--- a/stormvogel/simulator.py
+++ b/stormvogel/simulator.py
@@ -240,11 +240,17 @@ def simulate(
 
         for index, reward in enumerate(model.rewards):
             reward_model = partial_model.add_rewards(model.rewards[index].name)
-            # we already set the rewards for the initial state
-            reward_model.set(
-                partial_model.get_initial_state(),
-                model.rewards[index].get(model.get_initial_state()),
-            )
+            # we already set the rewards for the initial state/state-action pair
+            if model.supports_actions():
+                reward_model.set_state_action_reward_at_id(
+                    partial_model.get_initial_state().id,
+                    model.rewards[index].get_state_reward(model.get_initial_state()),
+                )
+            else:
+                reward_model.set_state_reward(
+                    partial_model.get_initial_state(),
+                    model.rewards[index].get_state_reward(model.get_initial_state()),
+                )
 
     # now we start stepping through the model
     discovered_states = {0}
@@ -276,7 +282,7 @@ def simulate(
 
                 # we add the rewards
                 for index, rewardmodel in enumerate(partial_model.rewards):
-                    rewardmodel.set(new_state, reward[index])
+                    rewardmodel.set_state_reward(new_state, reward[index])
 
                 last_state_id = state_id
                 if simulator.is_done():
@@ -310,7 +316,9 @@ def simulate(
                         state_id
                     )
                     state_action_pair = row_group + select_action
-                    rewardmodel.set_action_state(state_action_pair, reward[index])
+                    rewardmodel.set_state_action_reward_at_id(
+                        state_action_pair, reward[index]
+                    )
 
             # we add the state
             state_id, labels = discovery[0], discovery[2]
diff --git a/stormvogel/visualization.py b/stormvogel/visualization.py
index 6fb6bbd..7a3a3bb 100644
--- a/stormvogel/visualization.py
+++ b/stormvogel/visualization.py
@@ -204,7 +204,7 @@ def __format_rewards(self, s: stormvogel.model.State) -> str:
         res = ""
         for reward_model in self.model.rewards:
             try:
-                res += f"\n{reward_model.name}: {reward_model.get(s)}"
+                res += f"\n{reward_model.name}: {reward_model.get_state_reward(s)}"
             except (
                 KeyError
             ):  # If this reward model does not have a reward for this state.
diff --git a/tests/test_model_methods.py b/tests/test_model_methods.py
index f76a31c..ff898e9 100644
--- a/tests/test_model_methods.py
+++ b/tests/test_model_methods.py
@@ -16,6 +16,10 @@ def test_available_actions():
     ]
     assert mdp.get_state_by_id(1).available_actions() == action
 
+    # we also test it for a state with no available actions
+    mdp = stormvogel.model.new_mdp()
+    assert mdp.get_initial_state().available_actions() == []
+
 
 def test_get_outgoing_transitions():
     mdp = examples.monty_hall.create_monty_hall_mdp()
@@ -319,3 +323,57 @@ def test_get_sub_model():
     new_dtmc.normalize()
 
     assert sub_model == new_dtmc
+
+
+def test_get_state_action_id():
+    # we create an mdp:
+    mdp = examples.monty_hall.create_monty_hall_mdp()
+    state = mdp.get_state_by_id(2)
+    action = state.available_actions()[1]
+
+    assert mdp.get_state_action_id(state, action) == 5
+
+
+def test_get_state_action_reward():
+    # we create an mdp:
+    mdp = examples.monty_hall.create_monty_hall_mdp()
+
+    # we add a reward model:
+    rewardmodel = mdp.add_rewards("rewardmodel")
+    for i in range(67):
+        rewardmodel.rewards[i] = i
+
+    state = mdp.get_state_by_id(2)
+    action = state.available_actions()[1]
+
+    assert rewardmodel.get_state_action_reward(state, action) == 5
+
+
+def test_set_state_action_reward():
+    # we create an mdp:
+    mdp = stormvogel.model.new_mdp()
+    action = stormvogel.model.Action("0", frozenset())
+    mdp.add_transitions(mdp.get_initial_state(), [(action, mdp.get_initial_state())])
+
+    # we make a reward model using the set_state_action_reward method:
+    rewardmodel = mdp.add_rewards("rewardmodel")
+    rewardmodel.set_state_action_reward(mdp.get_initial_state(), action, 5)
+
+    # we make a reward model manually:
+    other_rewardmodel = stormvogel.model.RewardModel("rewardmodel", mdp, {0: 5})
+
+    assert rewardmodel == other_rewardmodel
+
+    # we create an mdp:
+    mdp = examples.monty_hall.create_monty_hall_mdp()
+
+    # we add a reward model with only one reward:
+    rewardmodel = mdp.add_rewards("rewardmodel")
+    state = mdp.get_state_by_id(2)
+    action = state.available_actions()[1]
+    rewardmodel.set_state_action_reward(state, action, 3)
+
+    # we make a reward model manually:
+    other_rewardmodel = stormvogel.model.RewardModel("rewardmodel", mdp, {5: 3})
+
+    assert rewardmodel == other_rewardmodel
diff --git a/tests/test_visualization.py b/tests/test_visualization.py
index 3ae15cb..4f05f80 100644
--- a/tests/test_visualization.py
+++ b/tests/test_visualization.py
@@ -56,9 +56,9 @@ def test_rewards(mocker):
     model, one, init = simple_model()
    model.set_transitions(init, [(1, one)])
     model.add_rewards("LOL")
-    model.get_rewards("LOL").set(one, 37)
+    model.get_rewards("LOL").set_state_reward(one, 37)
     model.add_rewards("HIHI")
-    model.get_rewards("HIHI").set(one, 42)
+    model.get_rewards("HIHI").set_state_reward(one, 42)
     vis = Visualization(model=model)
     vis.show()
     MockNetwork.add_node.assert_any_call(
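
For reference, a minimal usage sketch of the renamed reward API, distilled from the new tests above. All names come from the diff itself; the setup mirrors test_set_state_action_reward, so nothing here is new API.

    import stormvogel.model

    # Build a one-state MDP with a single self-loop action.
    mdp = stormvogel.model.new_mdp()
    action = stormvogel.model.Action("0", frozenset())
    mdp.add_transitions(mdp.get_initial_state(), [(action, mdp.get_initial_state())])

    # Attach a reward model; state-action rewards now go through the renamed
    # set_state_action_reward/get_state_action_reward pair, while plain state
    # rewards use set_state_reward/get_state_reward.
    rewardmodel = mdp.add_rewards("rewardmodel")
    rewardmodel.set_state_action_reward(mdp.get_initial_state(), action, 5)
    assert rewardmodel.get_state_action_reward(mdp.get_initial_state(), action) == 5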