From 8a62c5f807dcd56871ba5fb8369cb73e8c395ce1 Mon Sep 17 00:00:00 2001 From: Roman Andriushchenko Date: Wed, 15 Jan 2025 15:32:19 +0100 Subject: [PATCH] new condition for enabling don't-care action --- paynt/quotient/mdp.py | 69 +++++--- paynt/quotient/quotient.py | 1 + paynt/synthesizer/decision_tree.py | 30 ++-- paynt/synthesizer/synthesizer_ar.py | 1 - paynt/verification/property_result.py | 2 +- .../src/synthesis/quotient/ColoringSmt.cpp | 153 ++++++++++++++---- .../src/synthesis/quotient/ColoringSmt.h | 33 ++-- payntbind/src/synthesis/quotient/TreeNode.cpp | 27 +++- payntbind/src/synthesis/quotient/TreeNode.h | 6 + payntbind/src/synthesis/quotient/bindings.cpp | 1 + .../translation/choiceTransformation.cpp | 42 +++-- .../translation/choiceTransformation.h | 6 +- 12 files changed, 276 insertions(+), 95 deletions(-) diff --git a/paynt/quotient/mdp.py b/paynt/quotient/mdp.py index da20e0ca..43fd4091 100644 --- a/paynt/quotient/mdp.py +++ b/paynt/quotient/mdp.py @@ -273,10 +273,12 @@ def to_graphviz(self): class MdpQuotient(paynt.quotient.quotient.Quotient): + # label for action executing a random action selection + DONT_CARE_ACTION_LABEL = "__random__" # if true, an explicit action executing a random choice of an available action will be added to each state add_dont_care_action = False # if true, irrelevant states will not be considered for tree mapping - filter_irrelevant_states = True + filter_deterministic_states = False @classmethod def get_state_valuations(cls, model): @@ -304,7 +306,7 @@ def __init__(self, mdp, specification): # list of relevant variables: variables having at least two different options on relevant states self.variables = None - # for every state, a valuation of relevant variables; contains empty list for irrelevant states + # for every state, a valuation of relevant variables self.relevant_state_valuations = None # decision tree obtained after reset_tree self.decision_tree = None @@ -312,22 +314,24 @@ def __init__(self, mdp, specification): # deprecated # updated = payntbind.synthesis.restoreActionsInAbsorbingStates(mdp) # if updated is not None: mdp = updated - action_labels,_ = payntbind.synthesis.extractActionLabels(mdp) - if "__random__" not in action_labels and MdpQuotient.add_dont_care_action: - logger.debug("adding explicit don't-care action to every state...") - mdp = payntbind.synthesis.addDontCareAction(mdp) # identify relevant states self.state_is_relevant = [True for state in range(mdp.nr_states)] - if MdpQuotient.filter_irrelevant_states: - state_is_absorbing = self.identify_absorbing_states(mdp) - self.state_is_relevant = [self.state_is_relevant[state] and not absorbing for state,absorbing in enumerate(state_is_absorbing)] + state_is_absorbing = self.identify_absorbing_states(mdp) + self.state_is_relevant = [relevant and not state_is_absorbing[state] for state,relevant in enumerate(self.state_is_relevant)] + + if MdpQuotient.filter_deterministic_states: state_has_actions = self.identify_states_with_actions(mdp) - self.state_is_relevant = [self.state_is_relevant[state] and has_actions for state,has_actions in enumerate(state_has_actions)] + self.state_is_relevant = [relevant and state_has_actions[state] for state,relevant in enumerate(self.state_is_relevant)] self.state_is_relevant_bv = stormpy.BitVector(mdp.nr_states) [self.state_is_relevant_bv.set(state,value) for state,value in enumerate(self.state_is_relevant)] logger.debug(f"MDP has {self.state_is_relevant_bv.number_of_set_bits()}/{self.state_is_relevant_bv.size()} relevant states") + action_labels,_ = payntbind.synthesis.extractActionLabels(mdp) + if MdpQuotient.DONT_CARE_ACTION_LABEL not in action_labels and MdpQuotient.add_dont_care_action: + logger.debug("adding explicit don't-care action to relevant states...") + mdp = payntbind.synthesis.addDontCareAction(mdp,self.state_is_relevant_bv) + self.quotient_mdp = mdp self.choice_destinations = payntbind.synthesis.computeChoiceDestinations(mdp) self.action_labels,self.choice_to_action = payntbind.synthesis.extractActionLabels(mdp) @@ -359,9 +363,9 @@ def __init__(self, mdp, specification): logger.debug(f"found the following {len(self.variables)} variables: {[str(v) for v in self.variables]}") - def scheduler_json_to_choices(self, scheduler_json): + def scheduler_json_to_choices(self, scheduler_json, discard_unreachable_states=False): variable_name,state_valuations = self.get_state_valuations(self.quotient_mdp) - ndi = self.quotient_mdp.nondeterministic_choice_indices.copy() + nci = self.quotient_mdp.nondeterministic_choice_indices.copy() assert self.quotient_mdp.nr_states == len(scheduler_json) state_to_choice = self.empty_scheduler() for state_decision in scheduler_json: @@ -377,25 +381,52 @@ def scheduler_json_to_choices(self, scheduler_json): action_labels = actions[0]["labels"] assert len(action_labels) <= 1 if len(action_labels) == 0: - state_to_choice[state] = ndi[state] + state_to_choice[state] = nci[state] continue action = self.action_labels.index(action_labels[0]) # find a choice that executes this action - for choice in range(ndi[state],ndi[state+1]): + for choice in range(nci[state],nci[state+1]): if self.choice_to_action[choice] == action: state_to_choice[state] = choice break else: assert False, "action is not available in the state" - state_to_choice = self.discard_unreachable_choices(state_to_choice) + # enable implicit actions + for state,choice in enumerate(state_to_choice): + if choice is None: + logger.warning(f"WARNING: scheduler has no action for state {state}") + state_to_choice[state] = nci[state] + + if discard_unreachable_states: + state_to_choice = self.discard_unreachable_choices(state_to_choice) + # keep only relevant states + state_to_choice = [choice if self.state_is_relevant[state] else None for state,choice in enumerate(state_to_choice)] choices = self.state_to_choice_to_choices(state_to_choice) - return choices + + scheduler_json_relevant = [] + for state_decision in scheduler_json: + valuation = [state_decision["s"][name] for name in variable_name] + for state,state_valuation in enumerate(state_valuations): + if valuation == state_valuation: + break + if state_to_choice[state] is None: + continue + scheduler_json_relevant.append(state_decision) + + return choices,scheduler_json_relevant + def reset_tree(self, depth, enable_harmonization=True): ''' Rebuild the decision tree template, the design space and the coloring. ''' logger.debug(f"building tree of depth {depth}") + + num_actions = len(self.action_labels) + dont_care_action = num_actions + if MdpQuotient.DONT_CARE_ACTION_LABEL in self.action_labels: + dont_care_action = self.action_labels.index(MdpQuotient.DONT_CARE_ACTION_LABEL) + self.decision_tree = DecisionTree(self,self.variables) self.decision_tree.set_depth(depth) @@ -405,6 +436,7 @@ def reset_tree(self, depth, enable_harmonization=True): tree_list = self.decision_tree.to_list() self.coloring = payntbind.synthesis.ColoringSmt( self.quotient_mdp.nondeterministic_choice_indices, self.choice_to_action, + num_actions, dont_care_action, self.quotient_mdp.state_valuations, self.state_is_relevant_bv, variable_name, variable_domain, tree_list, enable_harmonization ) @@ -453,10 +485,7 @@ def build(self, family): choices = self.coloring.selectCompatibleChoices(family.family) else: choices = self.coloring.selectCompatibleChoices(family.family, family.parent_info.selected_choices) - if choices.number_of_set_bits() == 0: - family.mdp = None - family.analysis_result = self.build_unsat_result() - return + assert choices.number_of_set_bits() > 0 # proceed as before family.selected_choices = choices diff --git a/paynt/quotient/quotient.py b/paynt/quotient/quotient.py index c4dbfc3e..030b8745 100644 --- a/paynt/quotient/quotient.py +++ b/paynt/quotient/quotient.py @@ -96,6 +96,7 @@ def mdp_to_dtmc(mdp): def build_assignment(self, family): assert family.size == 1, "expecting family of size 1" choices = self.coloring.selectCompatibleChoices(family.family) + assert choices.number_of_set_bits() > 0 mdp,state_map,choice_map = self.restrict_quotient(choices) model = Quotient.mdp_to_dtmc(mdp) return paynt.models.models.SubMdp(model,state_map,choice_map) diff --git a/paynt/synthesizer/decision_tree.py b/paynt/synthesizer/decision_tree.py index bf0ba983..97cee31a 100644 --- a/paynt/synthesizer/decision_tree.py +++ b/paynt/synthesizer/decision_tree.py @@ -64,10 +64,8 @@ def harmonize_inconsistent_scheduler(self, family): def verify_family(self, family): self.num_families_considered += 1 self.quotient.build(family) - if family.mdp is None: - self.num_families_skipped += 1 - return + self.stat.iteration(family.mdp) if family.parent_info is not None: for choice in family.parent_info.scheduler_choices: if not family.selected_choices[choice]: @@ -86,8 +84,6 @@ def verify_family(self, family): self.check_specification(family) if not family.analysis_result.can_improve: return - if SynthesizerDecisionTree.scheduler_path is not None: - return self.harmonize_inconsistent_scheduler(family) @@ -204,10 +200,7 @@ def map_scheduler(self, scheduler_choices): if self.resource_limit_reached(): break - # self.counters_print() - def run(self, optimum_threshold=None): - scheduler_choices = None if SynthesizerDecisionTree.scheduler_path is None: paynt_mdp = paynt.models.models.Mdp(self.quotient.quotient_mdp) @@ -216,7 +209,18 @@ def run(self, optimum_threshold=None): opt_result_value = None with open(SynthesizerDecisionTree.scheduler_path, 'r') as f: scheduler_json = json.load(f) - scheduler_choices = self.quotient.scheduler_json_to_choices(scheduler_json) + scheduler_choices,scheduler_json_relevant = self.quotient.scheduler_json_to_choices(scheduler_json, discard_unreachable_states=True) + + # export transformed scheduler + # import os + # directory = os.path.dirname(SynthesizerDecisionTree.scheduler_path) + # transformed_name = f"scheduler-reachable.storm.json" + # scheduler_relevant_path = os.path.join(directory, transformed_name) + # with open(scheduler_relevant_path, 'w') as f: + # json.dump(scheduler_json_relevant, f, indent=4) + # logger.debug(f"stored transformed scheduler to {scheduler_relevant_path}") + # exit() + submdp = self.quotient.build_from_choice_mask(scheduler_choices) mc_result = submdp.model_check_property(self.quotient.get_property()) opt_result_value = mc_result.value @@ -261,9 +265,9 @@ def run(self, optimum_threshold=None): time_total = round(paynt.utils.timer.GlobalTimer.read(),2) logger.info(f"synthesis finished after {time_total} seconds") - # print() - # for name,time in self.quotient.coloring.getProfilingInfo(): - # time_percent = round(time/time_total*100,1) - # print(f"{name} = {time} s ({time_percent} %)") + print() + for name,time in self.quotient.coloring.getProfilingInfo(): + time_percent = round(time/time_total*100,1) + print(f"{name} = {time} s ({time_percent} %)") return self.best_tree diff --git a/paynt/synthesizer/synthesizer_ar.py b/paynt/synthesizer/synthesizer_ar.py index 85760bcb..2f8beb1b 100644 --- a/paynt/synthesizer/synthesizer_ar.py +++ b/paynt/synthesizer/synthesizer_ar.py @@ -115,7 +115,6 @@ def update_optimum(self, family): if isinstance(self.quotient, paynt.quotient.pomdp.PomdpQuotient): self.stat.new_fsc_found(family.analysis_result.improving_value, ia, self.quotient.policy_size(ia)) - def synthesize_one(self, family): families = [family] while families: diff --git a/paynt/verification/property_result.py b/paynt/verification/property_result.py index 2f9eb706..e51af9fa 100644 --- a/paynt/verification/property_result.py +++ b/paynt/verification/property_result.py @@ -132,7 +132,7 @@ def evaluate(self, family=None, admissible_assignment=None): else: self.improving_assignment = opt.improving_assignment self.improving_value = opt.improving_value - self.can_improve = opt.can_improve + self.can_improve = opt.can_improve return # constraints undecided diff --git a/payntbind/src/synthesis/quotient/ColoringSmt.cpp b/payntbind/src/synthesis/quotient/ColoringSmt.cpp index 8590edb3..361e19ff 100644 --- a/payntbind/src/synthesis/quotient/ColoringSmt.cpp +++ b/payntbind/src/synthesis/quotient/ColoringSmt.cpp @@ -14,13 +14,16 @@ template ColoringSmt::ColoringSmt( std::vector const& row_groups, std::vector const& choice_to_action, + uint64_t num_actions, + uint64_t dont_care_action, storm::storage::sparse::StateValuations const& state_valuations, BitVector const& state_is_relevant, std::vector const& variable_name, std::vector> const& variable_domain, std::vector> const& tree_list, bool enable_harmonization -) : state_is_relevant(state_is_relevant), row_groups(row_groups), choice_to_action(choice_to_action), +) : row_groups(row_groups), choice_to_action(choice_to_action), num_actions(num_actions), dont_care_action(dont_care_action), + state_is_relevant(state_is_relevant), variable_name(variable_name), variable_domain(variable_domain), solver(ctx), harmonizing_variable(ctx), enable_harmonization(enable_harmonization) { @@ -32,6 +35,15 @@ ColoringSmt::ColoringSmt( } } + // identify available actions + for(uint64_t state = 0; state < numStates(); ++state) { + BitVector action_available(num_actions,false); + for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) { + action_available.set(this->choice_to_action[choice],true); + } + this->state_available_actions.push_back(action_available); + } + // extract variables in the order of variable_name std::vector program_variables; auto const& valuation = state_valuations.at(0); @@ -146,6 +158,15 @@ ColoringSmt::ColoringSmt( } } + std::vector> state_dont_care_actions(numStates()); + for(uint64_t state: state_is_relevant) { + state_dont_care_actions[state].push_back(dont_care_action); + for(uint64_t action: ~state_available_actions[state]) { + state_dont_care_actions[state].push_back(action); + } + } + + for(uint64_t state = 0; state < numStates(); ++state) { for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) { choice_path_expresssion.push_back(z3::expr_vector(ctx)); @@ -154,7 +175,11 @@ ColoringSmt::ColoringSmt( } uint64_t action = choice_to_action[choice]; for(uint64_t path = 0; path < numPaths(); ++path) { - choice_path_expresssion[choice].push_back(state_path_expression[state][path] or action_path_expression[action][path]); + z3::expr action_selection = action_path_expression[action][path]; + if(action == dont_care_action) { + action_selection = getRoot()->substituteActionExpression(getRoot()->paths[path], state_dont_care_actions[state]); + } + choice_path_expresssion[choice].push_back(state_path_expression[state][path] or action_selection); } } } @@ -203,7 +228,11 @@ ColoringSmt::ColoringSmt( } uint64_t action = choice_to_action[choice]; for(uint64_t path = 0; path < numPaths(); ++path) { - choice_path_expresssion_harm[choice].push_back(state_path_expression_harmonizing[state][path] or action_path_expression_harmonizing[action][path]); + z3::expr action_selection = action_path_expression_harmonizing[action][path]; + if(action == dont_care_action) { + action_selection = getRoot()->substituteActionExpressionHarmonizing(getRoot()->paths[path], state_dont_care_actions[state], harmonizing_variable); + } + choice_path_expresssion_harm[choice].push_back(state_path_expression_harmonizing[state][path] or action_selection); } } } @@ -229,6 +258,11 @@ const uint64_t ColoringSmt::numChoices() const { return row_groups.back(); } +template +const bool ColoringSmt::dontCareActionDefined() const { + return dont_care_action < num_actions; +} + template const uint64_t ColoringSmt::numVariables() const { return variable_name.size(); @@ -333,7 +367,9 @@ BitVector ColoringSmt::selectCompatibleChoices(Family const& subfamil } } bool any_choice_enabled = false; + uint64_t num_choices_enabled = 0; for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) { + uint64_t action = choice_to_action[choice]; if(not base_choices[choice]) { continue; } @@ -342,33 +378,40 @@ BitVector ColoringSmt::selectCompatibleChoices(Family const& subfamil // enable the choice only if no choice has been enabled yet choice_enabled = not any_choice_enabled; } else { + // iterate over paths for(uint64_t path: state_path_enabled[state]) { - if(subfamily.holeContains(path_action_hole[path],choice_to_action[choice])) { - choice_enabled = true; + uint64_t path_hole = path_action_hole[path]; + choice_enabled = subfamily.holeContains(path_hole,action); + // enable the choice if this action is the family + if(not choice_enabled and action == this->dont_care_action) { + // don't-care action can also be enabled if any unavailable action is in the family + for(uint64_t unavailable_action: ~state_available_actions[state]) { + if(subfamily.holeContains(path_hole,unavailable_action)) { + choice_enabled = true; + break; + } + } + } + if(choice_enabled) { break; } } } if(choice_enabled) { any_choice_enabled = true; + num_choices_enabled++; selection.set(choice,true); visitChoice(choice,state_reached,unexplored_states); } } - if(not any_choice_enabled) { - if(subfamily.isAssignment()) { - STORM_LOG_WARN("Hole assignment does not induce a DTMC, enabling last action..."); - // uint64_t choice = row_groups[state]; // pick the first choice - uint64_t choice = row_groups[state+1]-1; // pick the last choice executing the random choice - selection.set(choice,true); - visitChoice(choice,state_reached,unexplored_states); - } else { - selection.clear(); - timers["selectCompatibleChoices::2 state exploration"].stop(); - timers[__FUNCTION__].stop(); - return selection; + STORM_LOG_THROW(any_choice_enabled, storm::exceptions::UnexpectedException, "no choice is available in the sub-MDP"); + /*if(num_choices_enabled == 1) { + if(state_path_enabled[state].getNumberOfSetBits() == 1) { + uint64_t path = *state_path_enabled[state].begin(); + uint64_t path_hole = path_action_hole[path]; + std::cout << subfamily.holeOptions(path_hole).size() << " "; } - } + }*/ } timers["selectCompatibleChoices::2 state exploration"].stop(); @@ -420,7 +463,7 @@ void ColoringSmt::loadUnsatCore(z3::expr_vector const& unsat_core_exp timers[__FUNCTION__].stop(); return; - for(uint64_t index = 0; index < this->unsat_core.size()-1; ++index) { + /*for(uint64_t index = 0; index < this->unsat_core.size()-1; ++index) { auto [choice,path] = this->unsat_core[index]; solver.push(); solver.add(choice_path_expresssion[choice][path]); @@ -436,7 +479,39 @@ void ColoringSmt::loadUnsatCore(z3::expr_vector const& unsat_core_exp this->unsat_core.pop_back(); solver.pop(); } + timers[__FUNCTION__].stop();*/ +} + +template +void ColoringSmt::loadUnsatCore(z3::expr_vector const& unsat_core_expr, Family const& subfamily, BitVector const& choices) { + timers[__FUNCTION__].start(); + this->unsat_core.clear(); + std::set critical_states; + for(z3::expr expr: unsat_core_expr) { + std::istringstream iss(expr.decl().name().str()); + char prefix; iss >> prefix; + if(prefix == 'h' or prefix == 'z') { + // uint64_t hole; iss >> prefix; iss >> hole; + continue; + } + // prefix == 'p' + uint64_t choice,path; iss >> choice; iss >> prefix; iss >> path; + uint64_t state = this->choice_to_state[choice]; + critical_states.insert(state); + } + for(uint64_t state: critical_states) { + for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) { + if(not choices[choice]) { + continue; + } + for(uint64_t path: state_path_enabled[state]) { + this->unsat_core.emplace_back(choice,path); + } + } + } + timers[__FUNCTION__].stop(); + return; } template @@ -444,6 +519,18 @@ std::pair>> ColoringSmt::areCh timers[__FUNCTION__].start(); std::vector> hole_options_vector(family.numHoles()); + /*for(uint64_t choice: choices) { + uint64_t action = choice_to_action[choice]; + uint64_t state = choice_to_state[choice]; + if(not state_is_relevant[state]) { + continue; + } + if(state_path_enabled[state].getNumberOfSetBits() == 1) { + std::cout << (action == this->dont_care_action); + } + }*/ + + timers["areChoicesConsistent::1 is scheduler consistent?"].start(); solver.push(); getRoot()->addFamilyEncoding(subfamily,solver); @@ -504,7 +591,8 @@ std::pair>> ColoringSmt::areCh } z3::expr_vector unsat_core_expr = solver.unsat_core(); solver.pop(); - loadUnsatCore(unsat_core_expr,subfamily); + // loadUnsatCore(unsat_core_expr,subfamily); + loadUnsatCore(unsat_core_expr,subfamily,choices); timers["areChoicesConsistent::2 better unsat core"].stop(); if(PRINT_UNSAT_CORE) @@ -539,21 +627,24 @@ std::pair>> ColoringSmt::areCh solver.add(0 <= harmonizing_variable and harmonizing_variable < (int)(family.numHoles()), "harmonizing_domain"); consistent = check(); - STORM_LOG_THROW(consistent, storm::exceptions::UnexpectedException, "harmonized UNSAT core is not SAT"); - model = solver.get_model(); - + if(consistent) { + model = solver.get_model(); + uint64_t harmonizing_hole = model.eval(harmonizing_variable).get_numeral_uint64(); + getRoot()->loadHoleAssignmentFromModel(model,hole_options_vector); + getRoot()->loadHoleAssignmentFromModelHarmonizing(model,hole_options_vector,harmonizing_hole); + if(hole_options_vector[harmonizing_hole][0] > hole_options_vector[harmonizing_hole][1]) { + uint64_t tmp = hole_options_vector[harmonizing_hole][0]; + hole_options_vector[harmonizing_hole][0] = hole_options_vector[harmonizing_hole][1]; + hole_options_vector[harmonizing_hole][1] = tmp; + } + } else { + STORM_LOG_THROW(consistent, storm::exceptions::UnexpectedException, "harmonized UNSAT core is not SAT"); + } solver.pop(); - uint64_t harmonizing_hole = model.eval(harmonizing_variable).get_numeral_uint64(); - getRoot()->loadHoleAssignmentFromModel(model,hole_options_vector); - getRoot()->loadHoleAssignmentFromModelHarmonizing(model,hole_options_vector,harmonizing_hole); - if(hole_options_vector[harmonizing_hole][0] > hole_options_vector[harmonizing_hole][1]) { - uint64_t tmp = hole_options_vector[harmonizing_hole][0]; - hole_options_vector[harmonizing_hole][0] = hole_options_vector[harmonizing_hole][1]; - hole_options_vector[harmonizing_hole][1] = tmp; - } if(PRINT_UNSAT_CORE) std::cout << "-- unsat core end --" << std::endl; + timers["areChoicesConsistent::3 unsat core analysis"].stop(); timers[__FUNCTION__].stop(); diff --git a/payntbind/src/synthesis/quotient/ColoringSmt.h b/payntbind/src/synthesis/quotient/ColoringSmt.h index e82fad1b..3f75d9d3 100644 --- a/payntbind/src/synthesis/quotient/ColoringSmt.h +++ b/payntbind/src/synthesis/quotient/ColoringSmt.h @@ -26,6 +26,9 @@ class ColoringSmt { * Construct an SMT coloring. * @param row groups (nondeterministic choice indices) of an underlying MDP * @param choice_to_action for every choice, the corresponding action + * @param num_actions total number of actions in the MDP + * @param dont_care_action index of the don't-care action; if the MDP has no such action, a value equal to + * \p num_actions can be given * @param state_valuations state valuation of the underlying MDP * @param variable_name list of variable names * @param variable_domain list of possible variable values @@ -35,6 +38,8 @@ class ColoringSmt { ColoringSmt( std::vector const& row_groups, std::vector const& choice_to_action, + uint64_t num_actions, + uint64_t dont_care_action, storm::storage::sparse::StateValuations const& state_valuations, BitVector const& state_is_relevant, std::vector const& variable_name, @@ -56,7 +61,10 @@ class ColoringSmt { * Get a mask of choices compatible with the family. For irrelevant states, only the first choice will be enabled. */ BitVector selectCompatibleChoices(Family const& subfamily); - /** Get a mask of sub-choices of \p base_choices compatible with the family.*/ + /** + * Get a mask of sub-choices of \p base_choices compatible with the family. If a relevant state has no enabled + * actions, it last action will be enabled. We assume here that this last action is the one executing random action. + */ BitVector selectCompatibleChoices(Family const& subfamily, BitVector const& base_choices); /** @@ -85,11 +93,6 @@ class ColoringSmt { /** Whether a check for an admissible family member is done before choice selection. */ const bool CHECK_FAMILY_CONSISTENCE = true; - /** Valuation of each state. */ - std::vector> state_valuation; - /** Only relevant states are taken into account when checking consistency. */ - const BitVector state_is_relevant; - /** Row groups of the quotient. */ const std::vector row_groups; /** For each choice, the state it comes from. */ @@ -99,10 +102,21 @@ class ColoringSmt { /** Number of choices in the quotient. */ const uint64_t numChoices() const; - /** Number of MDP actions. */ - uint64_t num_actions; /** For each choice, its unique action. */ const std::vector choice_to_action; + /** Number of MDP actions. */ + uint64_t num_actions; + /** Index of the don't care action; equal to \num_actions if no such action exists. */ + uint64_t dont_care_action; + /** For every state, a list of available actions. */ + std::vector state_available_actions; + /** Whether the don't-care action is present in the MDP. */ + const bool dontCareActionDefined() const; + + /** Valuation of each state. */ + std::vector> state_valuation; + /** Only relevant states are taken into account when checking consistency. */ + const BitVector state_is_relevant; /** For each variable, its name. */ const std::vector variable_name; @@ -144,7 +158,7 @@ class ColoringSmt { /** For each choice, its color expressed as a conjunction of all path implications. */ std::vector choice_path_expresssion; - /** TODO. */ + /** Whether harmonization is required. */ const bool enable_harmonization; /** SMT variable refering to harmonizing hole. */ z3::expr harmonizing_variable; @@ -159,6 +173,7 @@ class ColoringSmt { bool PRINT_UNSAT_CORE = false; void loadUnsatCore(z3::expr_vector const& unsat_core_expr, Family const& subfamily); + void loadUnsatCore(z3::expr_vector const& unsat_core_expr, Family const& subfamily, BitVector const& choices); }; diff --git a/payntbind/src/synthesis/quotient/TreeNode.cpp b/payntbind/src/synthesis/quotient/TreeNode.cpp index c9908f1c..eed81d7d 100644 --- a/payntbind/src/synthesis/quotient/TreeNode.cpp +++ b/payntbind/src/synthesis/quotient/TreeNode.cpp @@ -176,6 +176,14 @@ z3::expr TerminalNode::substituteActionExpression(std::vector const& path, return action_hole.solver_variable == (int)action; } +z3::expr TerminalNode::substituteActionExpression(std::vector const& path, std::vector const& actions) const { + z3::expr_vector clauses(ctx); + for(uint64_t action: actions) { + clauses.push_back(substituteActionExpression(path,action)); + } + return z3::mk_or(clauses); +} + void TerminalNode::createPrefixSubstitutionsHarmonizing(z3::expr_vector const& state_valuation) { // } @@ -185,9 +193,17 @@ void TerminalNode::substitutePrefixExpressionHarmonizing(std::vector const } z3::expr TerminalNode::substituteActionExpressionHarmonizing(std::vector const& path, uint64_t action, z3::expr const& harmonizing_variable) const { - return action_hole.solver_variable == (int)action or (harmonizing_variable == (int)action_hole.hole and action_hole.solver_variable_harm == (int)action); + return substituteActionExpression(path,action) or (harmonizing_variable == (int)action_hole.hole and action_hole.solver_variable_harm == (int)action); } +z3::expr TerminalNode::substituteActionExpressionHarmonizing(std::vector const& path, std::vector const& actions, z3::expr const& harmonizing_variable) const { + z3::expr_vector clauses(ctx); + for(uint64_t action: actions) { + clauses.push_back(action_hole.solver_variable_harm == (int)action); + } + z3::expr harmonizing_options = z3::mk_or(clauses); + return substituteActionExpression(path,actions) or (harmonizing_variable == (int)action_hole.hole and harmonizing_options); +} void TerminalNode::addFamilyEncoding(Family const& subfamily, z3::solver& solver) const { action_hole.addDomainEncoding(subfamily,solver); @@ -412,6 +428,10 @@ z3::expr InnerNode::substituteActionExpression(std::vector const& path, ui return getChild(path[depth])->substituteActionExpression(path,action); } +z3::expr InnerNode::substituteActionExpression(std::vector const& path, std::vector const& actions) const { + return getChild(path[depth])->substituteActionExpression(path,actions); +} + /*void InnerNode::substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector const& state_valuation, z3::expr_vector & substituted) const { bool step_to_true_child = path[depth]; z3::expr step = step_to_true_child ? step_true_harm : step_false_harm; @@ -433,11 +453,14 @@ void InnerNode::substitutePrefixExpressionHarmonizing(std::vector const& p } - z3::expr InnerNode::substituteActionExpressionHarmonizing(std::vector const& path, uint64_t action, z3::expr const& harmonizing_variable) const { return getChild(path[depth])->substituteActionExpressionHarmonizing(path,action,harmonizing_variable); } +z3::expr InnerNode::substituteActionExpressionHarmonizing(std::vector const& path, std::vector const& actions, z3::expr const& harmonizing_variable) const { + return getChild(path[depth])->substituteActionExpressionHarmonizing(path,actions,harmonizing_variable); +} + void InnerNode::addFamilyEncoding(Family const& subfamily, z3::solver& solver) const { decision_hole.addDomainEncoding(subfamily,solver); diff --git a/payntbind/src/synthesis/quotient/TreeNode.h b/payntbind/src/synthesis/quotient/TreeNode.h index 71860682..3b970aee 100644 --- a/payntbind/src/synthesis/quotient/TreeNode.h +++ b/payntbind/src/synthesis/quotient/TreeNode.h @@ -122,12 +122,14 @@ class TreeNode { virtual void substitutePrefixExpression(std::vector const& path, z3::expr_vector & substituted) const {}; /** Add an action expression evaluated for a given state valuation. */ virtual z3::expr substituteActionExpression(std::vector const& path, uint64_t action) const {return z3::expr(ctx);}; + virtual z3::expr substituteActionExpression(std::vector const& path, std::vector const& actions) const {return z3::expr(ctx);}; /** Add a step expression evaluated for a given state valuation (harmonizing). */ virtual void createPrefixSubstitutionsHarmonizing(z3::expr_vector const& state_valuation) {}; virtual void substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector & substituted) const {}; /** Add an action expression evaluated for a given state valuation (harmonizing). */ virtual z3::expr substituteActionExpressionHarmonizing(std::vector const& path, uint64_t action, z3::expr const& harmonizing_variable) const {return z3::expr(ctx);}; + virtual z3::expr substituteActionExpressionHarmonizing(std::vector const& path, std::vector const& actions, z3::expr const& harmonizing_variable) const {return z3::expr(ctx);}; /** Add encoding of hole option in the given family. */ virtual void addFamilyEncoding(Family const& subfamily, z3::solver & solver) const {} @@ -181,10 +183,12 @@ class TerminalNode: public TreeNode { void createPrefixSubstitutions(std::vector const& state_valuation) override; void substitutePrefixExpression(std::vector const& path, z3::expr_vector & substituted) const override; z3::expr substituteActionExpression(std::vector const& path, uint64_t action) const override; + z3::expr substituteActionExpression(std::vector const& path, std::vector const& actions) const override; void createPrefixSubstitutionsHarmonizing(z3::expr_vector const& state_valuation) override; void substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector & substituted) const override; z3::expr substituteActionExpressionHarmonizing(std::vector const& path, uint64_t action, z3::expr const& harmonizing_variable) const override; + z3::expr substituteActionExpressionHarmonizing(std::vector const& path, std::vector const& actions, z3::expr const& harmonizing_variable) const override; void addFamilyEncoding(Family const& subfamily, z3::solver & solver) const override; bool isPathEnabledInState( @@ -241,10 +245,12 @@ class InnerNode: public TreeNode { void createPrefixSubstitutions(std::vector const& state_valuation) override; void substitutePrefixExpression(std::vector const& path, z3::expr_vector & substituted) const override; z3::expr substituteActionExpression(std::vector const& path, uint64_t action) const override; + z3::expr substituteActionExpression(std::vector const& path, std::vector const& actions) const override; void createPrefixSubstitutionsHarmonizing(z3::expr_vector const& state_valuation) override; void substitutePrefixExpressionHarmonizing(std::vector const& path, z3::expr_vector & substituted) const override; z3::expr substituteActionExpressionHarmonizing(std::vector const& path, uint64_t action, z3::expr const& harmonizing_variable) const override; + z3::expr substituteActionExpressionHarmonizing(std::vector const& path, std::vector const& actions, z3::expr const& harmonizing_variable) const override; void addFamilyEncoding(Family const& subfamily, z3::solver & solver) const override; bool isPathEnabledInState( diff --git a/payntbind/src/synthesis/quotient/bindings.cpp b/payntbind/src/synthesis/quotient/bindings.cpp index bfa5954c..27e8f343 100644 --- a/payntbind/src/synthesis/quotient/bindings.cpp +++ b/payntbind/src/synthesis/quotient/bindings.cpp @@ -325,6 +325,7 @@ void bindings_coloring(py::module& m) { .def(py::init< std::vector const&, std::vector const&, + uint64_t, uint64_t, storm::storage::sparse::StateValuations const&, storm::storage::BitVector const&, std::vector const&, diff --git a/payntbind/src/synthesis/translation/choiceTransformation.cpp b/payntbind/src/synthesis/translation/choiceTransformation.cpp index ac844c52..a719da84 100644 --- a/payntbind/src/synthesis/translation/choiceTransformation.cpp +++ b/payntbind/src/synthesis/translation/choiceTransformation.cpp @@ -351,7 +351,8 @@ std::shared_ptr> restoreActionsInAbsorbi template std::shared_ptr> addDontCareAction( - storm::models::sparse::Model const& model + storm::models::sparse::Model const& model, + storm::storage::BitVector const& state_mask ) { auto [action_labels,choice_to_action] = synthesis::extractActionLabels(model); auto it = std::find(action_labels.begin(),action_labels.end(),DONT_CARE_ACTION_LABEL); @@ -371,14 +372,17 @@ std::shared_ptr> addDontCareAction( // translate choices std::vector translated_to_original_choice; std::vector row_groups_new; + std::vector const& row_groups = model.getTransitionMatrix().getRowGroupIndices(); for(uint64_t state = 0; state < num_states; ++state) { row_groups_new.push_back(translated_to_original_choice.size()); // copy existing choices - for(uint64_t choice: model.getTransitionMatrix().getRowGroupIndices(state)) { + for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) { translated_to_original_choice.push_back(choice); } - // add don't care action - translated_to_original_choice.push_back(num_choices); + if(state_mask[state]) { + // add don't care action + translated_to_original_choice.push_back(num_choices); + } } row_groups_new.push_back(translated_to_original_choice.size()); uint64_t num_translated_choices = translated_to_original_choice.size(); @@ -396,22 +400,24 @@ std::shared_ptr> addDontCareAction( storm::storage::SparseMatrixBuilder builder(num_translated_choices, num_states, 0, true, true, num_states); for(uint64_t state = 0; state < num_states; ++state) { builder.newRowGroup(row_groups_new[state]); + uint64_t state_num_choices = row_groups[state+1]-row_groups[state]; // the original number of choices // copy existing choices std::map dont_care_transitions; - uint64_t new_translated_choice = row_groups_new[state+1]-1; - uint64_t state_num_choices = new_translated_choice-row_groups_new[state]; - for(uint64_t translated_choice = row_groups_new[state]; translated_choice < new_translated_choice; ++translated_choice) { - uint64_t choice = translated_to_original_choice[translated_choice]; + uint64_t translated_choice = row_groups_new[state]; + for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) { for(auto entry: model.getTransitionMatrix().getRow(choice)) { uint64_t dst = entry.getColumn(); ValueType prob = entry.getValue(); builder.addNextValue(translated_choice, dst, prob); dont_care_transitions[dst] += prob/state_num_choices; } + ++translated_choice; } - // add don't care action - for(auto [dst,prob]: dont_care_transitions) { - builder.addNextValue(new_translated_choice,dst,prob); + if(state_mask[state]) { + // add don't care action + for(auto [dst,prob]: dont_care_transitions) { + builder.addNextValue(translated_choice,dst,prob); + } } } components.transitionMatrix = builder.build(); @@ -419,13 +425,16 @@ std::shared_ptr> addDontCareAction( for(auto & [name,reward_model]: rewardModels) { std::vector & choice_reward = reward_model.getStateActionRewardVector(); for(uint64_t state = 0; state < num_states; ++state) { + if(not state_mask[state]) { + continue; + } ValueType reward_sum = 0; - uint64_t new_translated_choice = row_groups_new[state+1]-1; - uint64_t state_num_choices = new_translated_choice-row_groups_new[state]; - for(uint64_t translated_choice = row_groups_new[state]; translated_choice < new_translated_choice; ++translated_choice) { + uint64_t dont_care_translated_choice = row_groups_new[state+1]-1; + uint64_t state_num_choices = dont_care_translated_choice-row_groups_new[state]; + for(uint64_t translated_choice = row_groups_new[state]; translated_choice < dont_care_translated_choice; ++translated_choice) { reward_sum += choice_reward[translated_choice]; } - choice_reward[new_translated_choice] = reward_sum / state_num_choices; + choice_reward[dont_care_translated_choice] = reward_sum / state_num_choices; } } components.rewardModels = rewardModels; @@ -587,7 +596,8 @@ template std::shared_ptr> removeAction> restoreActionsInAbsorbingStates( storm::models::sparse::Model const& model); template std::shared_ptr> addDontCareAction( - storm::models::sparse::Model const& model); + storm::models::sparse::Model const& model, + storm::storage::BitVector const& state_mask); template std::shared_ptr> createModelUnion( std::vector>> const& ); diff --git a/payntbind/src/synthesis/translation/choiceTransformation.h b/payntbind/src/synthesis/translation/choiceTransformation.h index 19845a19..6e8d4360 100644 --- a/payntbind/src/synthesis/translation/choiceTransformation.h +++ b/payntbind/src/synthesis/translation/choiceTransformation.h @@ -98,17 +98,19 @@ std::shared_ptr> removeAction( * these unlabeled actions. * @return an updated model or NULL if no change took place */ + template std::shared_ptr> restoreActionsInAbsorbingStates( storm::models::sparse::Model const& model ); /** - * To every state of an MDP add an explicit action that executes a random choice between available actions. + * To every state in \p state_mask, add an explicit action that executes a random choice between available actions. */ template std::shared_ptr> addDontCareAction( - storm::models::sparse::Model const& model + storm::models::sparse::Model const& model, + storm::storage::BitVector const& state_mask ); /**