Skip to content

Commit f72da19

Browse files
authored
Merge pull request #57 from randriu/init-optimized
optimizing formulae costruction in ColoringSmt
2 parents cc6f9cd + cfbcec5 commit f72da19

File tree

8 files changed

+329
-322
lines changed

8 files changed

+329
-322
lines changed

paynt/quotient/mdp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,7 @@ def reset_tree(self, depth, enable_harmonization=True):
459459
self.quotient_mdp.state_valuations, self.state_is_relevant_bv,
460460
variable_name, variable_domain, tree_list, enable_harmonization
461461
)
462+
# return
462463
self.coloring.enableStateExploration(self.quotient_mdp)
463464

464465
# reconstruct the family

paynt/synthesizer/decision_tree.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def map_scheduler(self, scheduler_choices):
205205
break
206206

207207
def run(self, optimum_threshold=None):
208+
# self.quotient.reset_tree(SynthesizerDecisionTree.tree_depth,enable_harmonization=True)
208209
scheduler_choices = None
209210
if SynthesizerDecisionTree.scheduler_path is None:
210211
paynt_mdp = paynt.models.models.Mdp(self.quotient.quotient_mdp)
@@ -278,6 +279,7 @@ def run(self, optimum_threshold=None):
278279

279280
if self.export_synthesis_filename_base is not None:
280281
self.export_decision_tree(self.best_tree, self.export_synthesis_filename_base)
282+
281283
time_total = round(paynt.utils.timer.GlobalTimer.read(),2)
282284
logger.info(f"synthesis finished after {time_total} seconds")
283285

payntbind/src/synthesis/quotient/ColoringSmt.cpp

Lines changed: 89 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ ColoringSmt<ValueType>::ColoringSmt(
2828
solver(ctx), harmonizing_variable(ctx), enable_harmonization(enable_harmonization) {
2929

3030
timers[__FUNCTION__].start();
31+
timers["ColoringSmt::0"].start();
3132

3233
for(uint64_t state = 0; state < numStates(); ++state) {
3334
for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) {
@@ -60,17 +61,6 @@ ColoringSmt<ValueType>::ColoringSmt(
6061
STORM_LOG_THROW(variable_found, storm::exceptions::UnexpectedException, "Unexpected variable name.");
6162
}
6263

63-
// create substitution variables
64-
z3::expr_vector state_substitution_variables(ctx);
65-
z3::expr_vector choice_substitution_variables(ctx);
66-
for(auto const& name: variable_name) {
67-
z3::expr variable = ctx.int_const(name.c_str());
68-
state_substitution_variables.push_back(variable);
69-
choice_substitution_variables.push_back(variable);
70-
}
71-
z3::expr action_substitution_variable = ctx.int_const("act");
72-
choice_substitution_variables.push_back(action_substitution_variable);
73-
7464
// create the tree
7565
uint64_t num_nodes = tree_list.size();
7666
this->num_actions = *std::max_element(choice_to_action.begin(),choice_to_action.end())+1;
@@ -81,13 +71,12 @@ ColoringSmt<ValueType>::ColoringSmt(
8171
"Inner node has only one child."
8272
);
8373
if(child_true != num_nodes) {
84-
tree.push_back(std::make_shared<InnerNode>(node,ctx,this->variable_name,this->variable_domain,state_substitution_variables));
74+
tree.push_back(std::make_shared<InnerNode>(node,ctx,this->variable_name,this->variable_domain));
8575
} else {
86-
tree.push_back(std::make_shared<TerminalNode>(node,ctx,this->variable_name,this->variable_domain,this->num_actions,action_substitution_variable));
76+
tree.push_back(std::make_shared<TerminalNode>(node,ctx,this->variable_name,this->variable_domain,this->num_actions));
8777
}
8878
}
8979
getRoot()->createTree(tree_list,tree);
90-
9180
getRoot()->createHoles(family);
9281
harmonizing_variable = ctx.int_const("__harm__");
9382
getRoot()->createPaths(harmonizing_variable);
@@ -118,14 +107,15 @@ ColoringSmt<ValueType>::ColoringSmt(
118107
STORM_LOG_THROW(domain_option_found, storm::exceptions::UnexpectedException, "Hole option not found.");
119108
}
120109
}
110+
timers["ColoringSmt::0"].stop();
121111

122112
// create choice colors
123-
timers["ColoringSmt:: create choice colors"].start();
113+
timers["ColoringSmt::1 create choice colors"].start();
114+
// std::cout << "ColoringSmt::1 create choice colors" << std::endl << std::flush;
124115

125116
for(std::vector<bool> const& path: getRoot()->paths) {
126117
path_action_hole.push_back(getRoot()->getPathActionHole(path));
127118
}
128-
129119
choice_path_label.resize(numChoices());
130120
for(uint64_t state: state_is_relevant) {
131121
for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) {
@@ -136,111 +126,121 @@ ColoringSmt<ValueType>::ColoringSmt(
136126
}
137127
}
138128

139-
std::vector<z3::expr_vector> state_path_expression;
129+
std::vector<const TerminalNode*> terminals;
130+
for(uint64_t path = 0; path < numPaths(); ++path) {
131+
terminals.push_back(getRoot()->getTerminal(getRoot()->paths[path]));
132+
}
133+
134+
// allocate array for path expressions
135+
uint64_t longest_path = 0;
136+
for(uint64_t path = 0; path < numPaths(); ++path) {
137+
longest_path = std::max(longest_path,getRoot()->paths[path].size());
138+
}
139+
z3::expr_vector state_valuation_int(ctx);
140+
z3::array<Z3_ast> clause_array(longest_path-1+num_actions);
141+
142+
getRoot()->substituteActionExpressions();
143+
choice_path_expresssion.resize(numChoices());
140144
for(uint64_t state = 0; state < numStates(); ++state) {
141-
state_path_expression.push_back(z3::expr_vector(ctx));
142145
if(not state_is_relevant[state]) {
143146
continue;
144147
}
145-
getRoot()->createPrefixSubstitutions(state_valuation[state]);
146-
for(uint64_t path = 0; path < numPaths(); ++path) {
147-
z3::expr_vector evaluated(ctx);
148-
getRoot()->substitutePrefixExpression(getRoot()->paths[path], evaluated);
149-
state_path_expression[state].push_back(z3::mk_or(evaluated));
150-
}
151-
}
152-
std::vector<z3::expr_vector> action_path_expression;
153-
for(uint64_t action = 0; action < this->num_actions; ++action) {
154-
action_path_expression.push_back(z3::expr_vector(ctx));
155-
for(uint64_t path = 0; path < numPaths(); ++path) {
156-
z3::expr evaluated = getRoot()->substituteActionExpression(getRoot()->paths[path], action);
157-
action_path_expression[action].push_back(evaluated);
158-
}
159-
}
160148

161-
std::vector<std::vector<uint64_t>> state_dont_care_actions(numStates());
162-
for(uint64_t state: state_is_relevant) {
163-
state_dont_care_actions[state].push_back(dont_care_action);
164-
for(uint64_t action: ~state_available_actions[state]) {
165-
state_dont_care_actions[state].push_back(action);
149+
for(uint64_t value: state_valuation[state]) {
150+
state_valuation_int.push_back(ctx.int_val(value));
166151
}
167-
}
168-
152+
timers["ColoringSmt::1-2 createPrefixSubstitutions"].start();
153+
getRoot()->createPrefixSubstitutions(state_valuation[state], state_valuation_int);
154+
timers["ColoringSmt::1-2 createPrefixSubstitutions"].stop();
155+
state_valuation_int.resize(0);
169156

170-
for(uint64_t state = 0; state < numStates(); ++state) {
171-
for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) {
172-
choice_path_expresssion.push_back(z3::expr_vector(ctx));
173-
if(not state_is_relevant[state]) {
174-
continue;
175-
}
176-
uint64_t action = choice_to_action[choice];
177-
for(uint64_t path = 0; path < numPaths(); ++path) {
178-
z3::expr action_selection = action_path_expression[action][path];
157+
timers["ColoringSmt::1-3"].start();
158+
for(uint64_t path = 0; path < numPaths(); ++path) {
159+
timers["ColoringSmt::1-3-1"].start();
160+
getRoot()->substitutePrefixExpression(getRoot()->paths[path], clause_array);
161+
timers["ColoringSmt::1-3-1"].stop();
162+
163+
timers["ColoringSmt::1-3-2"].start();
164+
for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) {
165+
timers["ColoringSmt::1-3-2-1"].start();
166+
uint64_t num_clauses = getRoot()->paths[path].size()-1;
167+
uint64_t action = choice_to_action[choice];
168+
clause_array[num_clauses++] = terminals[path]->action_expression[action];
169+
timers["ColoringSmt::1-3-2-1"].stop();
170+
timers["ColoringSmt::1-3-2-2"].start();
179171
if(action == dont_care_action) {
180-
action_selection = getRoot()->substituteActionExpression(getRoot()->paths[path], state_dont_care_actions[state]);
172+
for(uint64_t unavailable_action: ~state_available_actions[state]) {
173+
clause_array[num_clauses++] = terminals[path]->action_expression[unavailable_action];
174+
}
181175
}
182-
choice_path_expresssion[choice].push_back(state_path_expression[state][path] or action_selection);
176+
timers["ColoringSmt::1-3-2-2"].stop();
177+
choice_path_expresssion[choice].push_back(z3::expr(ctx, Z3_mk_or(ctx, num_clauses, clause_array.ptr())));
178+
// choice_path_expresssion[choice].push_back(Z3_mk_or(ctx, num_clauses, clause_array.ptr()));
183179
}
180+
timers["ColoringSmt::1-3-2"].stop();
184181
}
182+
timers["ColoringSmt::1-3"].stop();
185183
}
186-
timers["ColoringSmt:: create choice colors"].stop();
184+
timers["ColoringSmt::1 create choice colors"].stop();
187185

188186
if(not this->enable_harmonization) {
189187
timers[__FUNCTION__].stop();
190188
return;
191189
}
192190

193-
timers["ColoringSmt:: create harmonizing variants"].start();
194-
191+
timers["ColoringSmt::2 create harmonizing variants"].start();
192+
// std::cout << "ColoringSmt::2 create harmonizing variants" << std::endl << std::flush;
195193

196-
// create harmonizing expressions
197-
std::vector<z3::expr_vector> state_path_expression_harmonizing;
194+
getRoot()->substituteActionExpressionsHarmonizing(harmonizing_variable);
195+
choice_path_expresssion_harm.resize(numChoices());
198196
for(uint64_t state = 0; state < numStates(); ++state) {
199-
state_path_expression_harmonizing.push_back(z3::expr_vector(ctx));
200197
if(not state_is_relevant[state]) {
201198
continue;
202199
}
203-
// create state substitution
204-
z3::expr_vector substitution_expr(ctx);
200+
205201
for(uint64_t value: state_valuation[state]) {
206-
substitution_expr.push_back(ctx.int_val(value));
207-
}
208-
getRoot()->createPrefixSubstitutionsHarmonizing(substitution_expr);
209-
for(uint64_t path = 0; path < numPaths(); ++path) {
210-
z3::expr_vector evaluated(ctx);
211-
getRoot()->substitutePrefixExpressionHarmonizing(getRoot()->paths[path], evaluated);
212-
state_path_expression_harmonizing[state].push_back(z3::mk_or(evaluated));
202+
state_valuation_int.push_back(ctx.int_val(value));
213203
}
214-
}
215-
std::vector<z3::expr_vector> action_path_expression_harmonizing;
216-
for(uint64_t action = 0; action < num_actions; ++action) {
217-
action_path_expression_harmonizing.push_back(z3::expr_vector(ctx));
204+
timers["ColoringSmt::2-2 createPrefixSubstitutionsHarmonizing"].start();
205+
getRoot()->createPrefixSubstitutionsHarmonizing(state_valuation[state], state_valuation_int, harmonizing_variable);
206+
timers["ColoringSmt::2-2 createPrefixSubstitutionsHarmonizing"].stop();
207+
state_valuation_int.resize(0);
208+
209+
timers["ColoringSmt::2-3"].start();
218210
for(uint64_t path = 0; path < numPaths(); ++path) {
219-
z3::expr evaluated = getRoot()->substituteActionExpressionHarmonizing(getRoot()->paths[path], action, harmonizing_variable);
220-
action_path_expression_harmonizing[action].push_back(evaluated);
221-
}
222-
}
223-
for(uint64_t state = 0; state < numStates(); ++state) {
224-
for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) {
225-
choice_path_expresssion_harm.push_back(z3::expr_vector(ctx));
226-
if(not state_is_relevant[state]) {
227-
continue;
228-
}
229-
uint64_t action = choice_to_action[choice];
230-
for(uint64_t path = 0; path < numPaths(); ++path) {
231-
z3::expr action_selection = action_path_expression_harmonizing[action][path];
211+
timers["ColoringSmt::2-3-1"].start();
212+
getRoot()->substitutePrefixExpressionHarmonizing(getRoot()->paths[path], clause_array);
213+
timers["ColoringSmt::2-3-1"].stop();
214+
215+
timers["ColoringSmt::2-3-2"].start();
216+
for(uint64_t choice = row_groups[state]; choice < row_groups[state+1]; ++choice) {
217+
uint64_t action = choice_to_action[choice];
218+
uint64_t num_clauses = getRoot()->paths[path].size()-1;
219+
clause_array[num_clauses++] = terminals[path]->action_expression_harmonizing[action];
232220
if(action == dont_care_action) {
233-
action_selection = getRoot()->substituteActionExpressionHarmonizing(getRoot()->paths[path], state_dont_care_actions[state], harmonizing_variable);
221+
for(uint64_t unavailable_action: ~state_available_actions[state]) {
222+
clause_array[num_clauses++] = terminals[path]->action_expression_harmonizing[unavailable_action];
223+
}
234224
}
235-
choice_path_expresssion_harm[choice].push_back(state_path_expression_harmonizing[state][path] or action_selection);
225+
choice_path_expresssion_harm[choice].push_back(z3::expr(ctx, Z3_mk_or(ctx, num_clauses, clause_array.ptr())));
226+
// choice_path_expresssion_harm[choice].push_back(Z3_mk_or(ctx, num_clauses, clause_array.ptr()));
236227
}
228+
timers["ColoringSmt::2-3-2"].stop();
237229
}
230+
timers["ColoringSmt::2-3"].stop();
238231
}
239-
timers["ColoringSmt:: create harmonizing variants"].stop();
232+
timers["ColoringSmt::2 create harmonizing variants"].stop();
233+
234+
getRoot()->clearCache();
240235

241236
timers[__FUNCTION__].stop();
242237
}
243238

239+
template<typename ValueType>
240+
ColoringSmt<ValueType>::~ColoringSmt() {
241+
tree.clear();
242+
}
243+
244244
template<typename ValueType>
245245
void ColoringSmt<ValueType>::enableStateExploration(storm::models::sparse::NondeterministicModel<ValueType> const& model) {
246246
this->state_exploration_enabled = true;
@@ -530,7 +530,6 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh
530530
}
531531
}*/
532532

533-
534533
timers["areChoicesConsistent::1 is scheduler consistent?"].start();
535534
solver.push();
536535
getRoot()->addFamilyEncoding(subfamily,solver);
@@ -543,6 +542,7 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh
543542
for(uint64_t path: state_path_enabled[state]) {
544543
const char *label = choice_path_label[choice][path].c_str();
545544
solver.add(choice_path_expresssion[choice][path], label);
545+
// Z3_solver_assert_and_track(ctx, solver.operator Z3_solver(), choice_path_expresssion[choice][path], ctx.bool_const(label));
546546
}
547547
}
548548
bool consistent = check();
@@ -584,6 +584,7 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh
584584
for(uint64_t path: state_path_enabled[state]) {
585585
const char *label = choice_path_label[choice][path].c_str();
586586
solver.add(choice_path_expresssion[choice][path], label);
587+
// Z3_solver_assert_and_track(ctx, solver.operator Z3_solver(), choice_path_expresssion[choice][path], ctx.bool_const(label));
587588
}
588589
consistent = check();
589590
}
@@ -603,6 +604,7 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh
603604
for(auto [choice,path]: this->unsat_core) {
604605
const char *label = choice_path_label[choice][path].c_str();
605606
solver.add(choice_path_expresssion_harm[choice][path], label);
607+
// Z3_solver_assert_and_track(ctx, solver.operator Z3_solver(), choice_path_expresssion_harm[choice][path], ctx.bool_const(label));
606608
}
607609

608610
z3::model model(ctx);

payntbind/src/synthesis/quotient/ColoringSmt.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class ColoringSmt {
4848
bool enable_harmonization
4949
);
5050

51+
~ColoringSmt();
52+
5153
/**
5254
* Enable efficient state exploration of reachable states.
5355
* @note this is required for harmonization
@@ -156,14 +158,17 @@ class ColoringSmt {
156158
/** For each choice and path, a label passed to SMT solver. */
157159
std::vector<std::vector<std::string>> choice_path_label;
158160
/** For each choice, its color expressed as a conjunction of all path implications. */
159-
std::vector<z3::expr_vector> choice_path_expresssion;
161+
std::vector<std::vector<z3::expr>> choice_path_expresssion;
162+
// std::vector<std::vector<Z3_ast>> choice_path_expresssion;
163+
160164

161165
/** Whether harmonization is required. */
162166
const bool enable_harmonization;
163167
/** SMT variable refering to harmonizing hole. */
164168
z3::expr harmonizing_variable;
165169
/** For each choice, its color expressed as a conjunction of all path implications. */
166-
std::vector<z3::expr_vector> choice_path_expresssion_harm;
170+
std::vector<std::vector<z3::expr>> choice_path_expresssion_harm;
171+
// std::vector<std::vector<Z3_ast>> choice_path_expresssion_harm;
167172

168173
/** For each state, whether (in the last subfamily) the path was enabled. */
169174
std::vector<BitVector> state_path_enabled;

0 commit comments

Comments
 (0)