From 93a5c9b22e5b246f6161cc14027278dd94e37c9c Mon Sep 17 00:00:00 2001 From: YouGuessedMyName Date: Thu, 19 Dec 2024 18:22:12 +0100 Subject: [PATCH] simulator WIP --- stormvogel/model.py | 5 ++++- stormvogel/simulator.py | 33 +++++++++++++++++++-------------- tests/test_simulator.py | 18 ++++++++++++++++++ 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/stormvogel/model.py b/stormvogel/model.py index 6be2e10..bf5cc91 100644 --- a/stormvogel/model.py +++ b/stormvogel/model.py @@ -411,7 +411,9 @@ def set_state_action_reward( if self.model.supports_actions(): if action in state.available_actions(): self.rewards[state.id, action] = value + print("rewards", self.rewards) else: + print('FAILED', state.available_actions(), action) RuntimeError("This action is not available in this state") else: RuntimeError( @@ -873,6 +875,7 @@ def new_state( ) -> State: """Creates a new state and returns it.""" state_id = self.__free_state_id() + print("free state id!", state_id) if isinstance(labels, list): state = State(labels, features or {}, state_id, self) elif isinstance(labels, str): @@ -893,7 +896,7 @@ def get_states_with_label(self, label: str) -> list[State]: collected_states.append(state) return collected_states - def get_state_by_id(self, state_id) -> State: + def get_state_by_id(self, state_id: int) -> State: """Get a state by its id.""" if state_id not in self.states: raise RuntimeError("Requested a non-existing state") diff --git a/stormvogel/simulator.py b/stormvogel/simulator.py index de4f3b6..dcdf2fa 100644 --- a/stormvogel/simulator.py +++ b/stormvogel/simulator.py @@ -306,22 +306,12 @@ def simulate( # we add the action to the partial model assert partial_model.actions is not None action = model.states[state_id].available_actions()[select_action] - if action not in partial_model.actions.values(): - partial_model.new_action(action.name) - - # we add the reward model to the partial model - discovery = simulator.step(actions[select_action]) - reward = discovery[1] - for index, rewardmodel in enumerate(partial_model.rewards): - row_group = stormpy_model.transition_matrix.get_row_group_start( - state_id - ) - state_action_pair = row_group + select_action - rewardmodel.set_state_action_reward_at_id( - state_action_pair, reward[index] - ) + if action not in partial_model.actions: + partial_model.new_action(action.labels) # we add the state + print("\nAdding state!") + discovery = simulator.step(actions[select_action]) state_id, labels = discovery[0], discovery[2] if state_id not in discovered_states: discovered_states.add(state_id) @@ -337,7 +327,22 @@ def simulate( last_state_partial = new_state last_state_id = state_id + + # we add the rewards. + reward = discovery[1] + for index, rewardmodel in enumerate(partial_model.rewards): + # row_group = stormpy_model.transition_matrix.get_row_group_start( + # state_id + # ) + # state_action_pair = row_group + select_action + state = model.get_state_by_id(state_id) + rewardmodel.set_state_action_reward(state, action, reward[index]) + print("ADD REWARD:", state.id, state.name, action.labels, reward[index]) + print("INTER:", rewardmodel.rewards) + #print(rewardmodel.rewards.items()) + print("RESULT", partial_model.rewards[1].rewards) if simulator.is_done(): break + return partial_model diff --git a/tests/test_simulator.py b/tests/test_simulator.py index 8eaaa83..77be716 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -74,6 +74,24 @@ def test_simulate(): rewardmodel2 = other_mdp.add_rewards("rewardmodel2") rewardmodel2.rewards = {0: 0, 7: 7, 16: 16} + # print(partial_model) + # print(other_mdp) + + + self = partial_model + other = other_mdp + print(sorted(self.rewards) == sorted(other.rewards)) + + print(sorted(self.rewards[0].rewards)) + print(sorted(other.rewards[0].rewards)) + + print(self.type == other.type + and self.states == other.states + and self.transitions == other.transitions + and sorted(self.rewards) == sorted(other.rewards) + and self.exit_rates == other.exit_rates + and self.markovian_states == other.markovian_states) + assert partial_model == other_mdp ######################################################################################################################