diff --git a/stormvogel/pgc.py b/stormvogel/pgc.py
index 68c80dd..4588c21 100644
--- a/stormvogel/pgc.py
+++ b/stormvogel/pgc.py
@@ -33,6 +33,8 @@ def __eq__(self, other):
 def build_pgc(
     delta,  # Callable[[State, Action], list[tuple[float, State]]],
     initial_state_pgc: State,  # TODO rewards function, label function
+    rewards=None,
+    labels=None,
     available_actions: Callable[[State], list[Action]] | None = None,
     modeltype: stormvogel.model.ModelType = stormvogel.model.ModelType.MDP,
 ) -> stormvogel.model.Model:
@@ -67,13 +69,13 @@ def build_pgc(
     while len(states_to_be_visited) > 0:
         state = states_to_be_visited[0]
         states_to_be_visited.remove(state)
-        # we loop over all available actions and call the delta function for each action
         transition = {}
 
         if state not in states_seen:
             states_seen.append(state)
 
         if model.supports_actions():
+            # we loop over all available actions and call the delta function for each action
             assert available_actions is not None
             for action in available_actions(state):
                 try:
@@ -98,8 +100,6 @@ def build_pgc(
                             branch.append((tuple[0], new_state))
                             states_to_be_visited.append(tuple[1])
                         else:
-                            # print(tuple[1].__dict__)
-                            # print(model.states)
                             branch.append(
                                 (tuple[0], model.get_state_by_name(str(tuple[1].__dict__)))
                             )
@@ -133,4 +133,29 @@ def build_pgc(
                 stormvogel.model.Transition(transition),
             )
 
+    # we add the rewards
+    # TODO support multiple reward models
+    if rewards is not None:
+        rewardmodel = model.add_rewards("rewards")
+        if model.supports_actions():
+            for state in states_seen:
+                assert available_actions is not None
+                for action in available_actions(state):
+                    reward = rewards(state, action)
+                    s = model.get_state_by_name(str(state.__dict__))
+                    assert s is not None
+                    rewardmodel.set_state_action_reward(
+                        s,
+                        model.get_action(str(action.labels)),
+                        reward,
+                    )
+        else:
+            for state in states_seen:
+                reward = rewards(state)
+                s = model.get_state_by_name(str(state.__dict__))
+                assert s is not None
+                rewardmodel.set_state_reward(s, reward)
+
+    # we add the labels
+
     return model
diff --git a/tests/test_pgc.py b/tests/test_pgc.py
index 107540c..b3b79ed 100644
--- a/tests/test_pgc.py
+++ b/tests/test_pgc.py
@@ -16,6 +16,9 @@ def test_pgc_mdp():
     def available_actions(s: pgc.State):
         return [left, right]
 
+    def rewards(s: pgc.State, a: pgc.Action):
+        return 1
+
     def delta(s: pgc.State, action: pgc.Action):
         if action == left:
             return (
@@ -40,6 +43,7 @@ def delta(s: pgc.State, action: pgc.Action):
         delta=delta,
         available_actions=available_actions,
         initial_state_pgc=initial_state,
+        rewards=rewards,
     )
 
     # we build the model in the regular way:
@@ -60,6 +64,10 @@ def delta(s: pgc.State, action: pgc.Action):
     model.add_transitions(state2, stormvogel.model.Transition({right: branch2}))
     model.add_transitions(state0, stormvogel.model.Transition({left: branch0}))
 
+    rewardmodel = model.add_rewards("rewards")
+    for i in range(2 * N):
+        rewardmodel.set_state_action_reward_at_id(i, 1)
+
     assert model == pgc_model
 
 
@@ -68,6 +76,9 @@ def test_pgc_dtmc():
     p = 0.5
     initial_state = pgc.State(s=0)
 
+    def rewards(s: pgc.State):
+        return 1
+
     def delta(s: pgc.State):
         match s.s:
             case 0:
@@ -96,6 +107,7 @@ def delta(s: pgc.State):
     pgc_model = stormvogel.pgc.build_pgc(
         delta=delta,
         initial_state_pgc=initial_state,
+        rewards=rewards,
         modeltype=stormvogel.model.ModelType.DTMC,
     )
 
@@ -161,4 +173,8 @@ def delta(s: pgc.State):
     model.set_transitions(model.get_state_by_id(12), [(1, model.get_state_by_id(13))])
     model.set_transitions(model.get_state_by_id(13), [(1, model.get_state_by_id(13))])
 
+    rewardmodel = model.add_rewards("rewards")
+    for state in model.states.values():
+        rewardmodel.set_state_reward(state, 1)
+
     assert pgc_model == model