Skip to content

Commit

Permalink
handle POMDPs with unreachable observations
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Andriushchenko committed Dec 6, 2024
1 parent 66968d5 commit 55eeeb6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
13 changes: 8 additions & 5 deletions paynt/quotient/fsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,15 @@ def check_action_function(self, observation_to_actions):
assert len(self.action_function[node]) == self.num_observations, \
"in memory node {}, FSC action function is not defined for all observations".format(node)
for obs in range(self.num_observations):
if observation_to_actions[obs] == []:
assert self.action_function[node][obs] is None
continue
if self.is_deterministic:
action = self.action_function[node][obs]
assert action in observation_to_actions[obs], "in observation {} FSC chooses invalid action {}".format(obs,action)
action_support = [self.action_function[node][obs]]
else:
for action,_ in self.action_function[node][obs].items():
assert action in observation_to_actions[obs], "in observation {} FSC chooses invalid action {}".format(obs,action)
action_support = self.action_function[node][obs].keys()
for action in action_support:
assert action in observation_to_actions[obs], "in observation {} FSC chooses invalid action {}".format(obs,action)

def check_update_function(self):
assert len(self.update_function) == self.num_nodes, "FSC update function is not defined for all memory nodes"
Expand All @@ -91,7 +94,7 @@ def check(self, observation_to_actions):
def fill_trivial_actions(self, observation_to_actions):
''' For each observation with 1 available action, set gamma(n,z) to that action. '''
for obs,actions in enumerate(observation_to_actions):
if len(actions)>1:
if len(actions) != 1:
continue
action = actions[0]
if not self.is_deterministic:
Expand Down
3 changes: 3 additions & 0 deletions paynt/quotient/pomdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ def __init__(self, pomdp, specification, decpomdp_manager=None):
else :
label = list(labels)[0]
self.action_labels_at_observation[obs].append(label)
for obs,labels in enumerate(self.action_labels_at_observation):
if len(labels) == 0:
logger.warning(f"WARNING: POMDP has no action for observation {obs}")

# mark perfect observations
self.observation_states = [0 for obs in range(self.observations)]
Expand Down

0 comments on commit 55eeeb6

Please sign in to comment.