From 55eeeb6bac3ea870f75fe8fab72fdb86a38ad756 Mon Sep 17 00:00:00 2001 From: Roman Andriushchenko Date: Fri, 6 Dec 2024 16:16:06 +0100 Subject: [PATCH] handle POMDPs with unreachable observations --- paynt/quotient/fsc.py | 13 ++++++++----- paynt/quotient/pomdp.py | 3 +++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/paynt/quotient/fsc.py b/paynt/quotient/fsc.py index b9e88e04..3452ebce 100644 --- a/paynt/quotient/fsc.py +++ b/paynt/quotient/fsc.py @@ -66,12 +66,15 @@ def check_action_function(self, observation_to_actions): assert len(self.action_function[node]) == self.num_observations, \ "in memory node {}, FSC action function is not defined for all observations".format(node) for obs in range(self.num_observations): + if observation_to_actions[obs] == []: + assert self.action_function[node][obs] is None + continue if self.is_deterministic: - action = self.action_function[node][obs] - assert action in observation_to_actions[obs], "in observation {} FSC chooses invalid action {}".format(obs,action) + action_support = [self.action_function[node][obs]] else: - for action,_ in self.action_function[node][obs].items(): - assert action in observation_to_actions[obs], "in observation {} FSC chooses invalid action {}".format(obs,action) + action_support = self.action_function[node][obs].keys() + for action in action_support: + assert action in observation_to_actions[obs], "in observation {} FSC chooses invalid action {}".format(obs,action) def check_update_function(self): assert len(self.update_function) == self.num_nodes, "FSC update function is not defined for all memory nodes" @@ -91,7 +94,7 @@ def check(self, observation_to_actions): def fill_trivial_actions(self, observation_to_actions): ''' For each observation with 1 available action, set gamma(n,z) to that action. ''' for obs,actions in enumerate(observation_to_actions): - if len(actions)>1: + if len(actions) != 1: continue action = actions[0] if not self.is_deterministic: diff --git a/paynt/quotient/pomdp.py b/paynt/quotient/pomdp.py index 3bf1eb38..c725fe23 100644 --- a/paynt/quotient/pomdp.py +++ b/paynt/quotient/pomdp.py @@ -106,6 +106,9 @@ def __init__(self, pomdp, specification, decpomdp_manager=None): else : label = list(labels)[0] self.action_labels_at_observation[obs].append(label) + for obs,labels in enumerate(self.action_labels_at_observation): + if len(labels) == 0: + logger.warning(f"WARNING: POMDP has no action for observation {obs}") # mark perfect observations self.observation_states = [0 for obs in range(self.observations)]