Skip to content

Commit 99a803a

Browse files
committed
Solve bug exposed by RandSizeListOrder test
Apply all list length constraints when creating a multi-variable problem at the start. Functional fix and small performance boost for some cases.
1 parent bbc13cf commit 99a803a

File tree

4 files changed

+98
-67
lines changed

4 files changed

+98
-67
lines changed

constrainedrandom/internal/multivar.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# SPDX-License-Identifier: MIT
22
# Copyright (c) 2023 Imagination Technologies Ltd. All Rights Reserved
33

4+
import random
45
from collections import defaultdict
5-
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union
6+
from typing import Any, Callable, Dict, Iterable, List, Optional, TYPE_CHECKING, Union
67

78
from .vargroup import VarGroup
89

910
from .. import utils
1011
from ..debug import RandomizationDebugInfo
1112

1213
if TYPE_CHECKING:
13-
from ..randobj import RandObj
1414
from ..internal.randvar import RandVar
1515

1616

@@ -20,7 +20,7 @@ class MultiVarProblem:
2020
Represents one problem concerning multiple random variables,
2121
where those variables all share dependencies on one another.
2222
23-
:param parent: The :class:`RandObj` instance that owns this instance.
23+
:param random_getter: A callable returning the random instance to use within this instance.
2424
:param vars: The dictionary of names and :class:`RandVar` instances this problem consists of.
2525
:param constraints: An iterable of tuples of (constraint, (variables,...)) denoting
2626
the constraints and the variables they apply to.
@@ -35,20 +35,21 @@ class MultiVarProblem:
3535

3636
def __init__(
3737
self,
38-
parent: 'RandObj',
38+
*,
39+
random_getter: Callable[[], random.Random],
3940
vars: List['RandVar'],
4041
constraints: Iterable[utils.ConstraintAndVars],
4142
max_iterations: int,
4243
max_domain_size: int,
4344
) -> None:
44-
self.parent = parent
45+
self.random_getter = random_getter
4546
self.vars = vars
4647
self.constraints = constraints
4748
self.max_iterations = max_iterations
4849
self.max_domain_size = max_domain_size
49-
self.order = None
50+
self.order: Optional[List[List['RandVar']]] = None
5051
self.debug = False
51-
self.debug_info = None
52+
self.debug_info: Optional[RandomizationDebugInfo] = None
5253

5354
def determine_order(self, with_values: Dict[str, Any]) -> List[List['RandVar']]:
5455
'''
@@ -149,7 +150,7 @@ def solve_groups(
149150
# For each group, construct a problem and solve it.
150151
for group in groups:
151152
group_solutions = None
152-
group_problem = None
153+
group_problem: VarGroup = None
153154
attempts = 0
154155
while group_solutions is None or len(group_solutions) == 0:
155156
# Early loop exit cases
@@ -169,7 +170,7 @@ def solve_groups(
169170
if solutions_per_group >= len(solutions):
170171
solution_subset = list(solutions)
171172
else:
172-
solution_subset = self.parent._get_random().choices(
173+
solution_subset = self.random_getter().choices(
173174
solutions,
174175
k=solutions_per_group
175176
)
@@ -209,11 +210,11 @@ def solve_groups(
209210
if solutions_per_group == 1:
210211
# This means we have exactly one solution for the variables considered so far,
211212
# meaning we don't need to re-apply solved constraints for future groups.
212-
constraints = group_problem.skipped_constraints
213+
constraints = group_problem.get_remaining_constraints()
213214
solved_vars += group_problem.group_vars
214215
solutions = group_solutions
215216

216-
return self.parent._get_random().choice(solutions)
217+
return self.random_getter().choice(solutions)
217218

218219
def solve(
219220
self,

constrainedrandom/internal/vargroup.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ class VarGroup:
3131
size, we don't use the ``constraint`` package, but just use ``random`` instead.
3232
:param debug: ``True`` to run in debug mode. Slower, but collects
3333
all debug info along the way and not just the final failure.
34-
:return: A tuple of 1) a list the names of the variables in the group,
35-
2) a list of variables that must be randomized rather than solved
36-
via a constraint problem,
37-
3) a list of constraints and variables that won't be applied for this group.
3834
'''
3935

4036
def __init__(
@@ -52,6 +48,7 @@ def __init__(
5248
self.problem = constraint.Problem()
5349
self.max_domain_size = max_domain_size
5450
self.debug = debug
51+
self.remaining_constraints: List[utils.ConstraintAndVars] = []
5552

5653
# Respect already-solved values when solving the constraint problem.
5754
for var_name, values in self.solution_space.items():
@@ -66,20 +63,13 @@ def __init__(
6663

6764
# Consider whether this variable has random length.
6865
possible_lengths = None
69-
if var.rand_length is not None:
66+
if var.rand_length is not None and var.rand_length in self.solution_space:
7067
# Variable has a random length.
7168
# We guarantee that the random length variable will be solved
7269
# before this one, if it is even part of the problem.
7370
# If it's not in solution_space, we've already chosen the value
7471
# for it and set the random length based on it.
75-
if var.rand_length in self.solution_space:
76-
# Deal with potential values.
77-
possible_lengths = self.solution_space[var.rand_length]
78-
# Create a constraint that the length must be defined
79-
# by the other variable.
80-
len_constr = lambda listvar, length : len(listvar) == length
81-
self.problem.addConstraint(len_constr, (var.name, var.rand_length))
82-
self.raw_constraints.append((len_constr, (var.name, var.rand_length)))
72+
possible_lengths = self.solution_space[var.rand_length]
8373

8474
# Either add to constraint problem with full domain,
8575
# or treat it as a variable to be randomized.
@@ -107,7 +97,6 @@ def __init__(
10797
self.rand_vars.append(var)
10898

10999
# Add all pertinent constraints
110-
self.skipped_constraints = []
111100
for (con, vars) in constraints:
112101
skip = False
113102
for var in vars:
@@ -116,7 +105,7 @@ def __init__(
116105
skip = True
117106
break
118107
if skip:
119-
self.skipped_constraints.append((con, vars))
108+
self.remaining_constraints.append((con, vars))
120109
continue
121110
self.problem.addConstraint(con, vars)
122111
self.raw_constraints.append((con, vars))
@@ -143,6 +132,16 @@ def can_retry(self) -> bool:
143132
'''
144133
return len(self.rand_vars) > 0
145134

135+
def get_remaining_constraints(self) -> List[utils.ConstraintAndVars]:
136+
'''
137+
Call this to get the constraints that must
138+
still be applied to other variables in future.
139+
140+
:return: A list of tuples, each tuple containing a
141+
constraint and a tuple of its variables.
142+
'''
143+
return self.remaining_constraints
144+
146145
def concretize_rand_length(
147146
self,
148147
rand_list_var: 'RandVar',
@@ -166,7 +165,7 @@ def concretize_rand_length(
166165
if rand_list_var.rand_length in concrete_values:
167166
rand_length_val = concrete_values[rand_list_var.rand_length]
168167
# Otherwise pick a value and save it.
169-
if rand_list_var.rand_length in self.solution_space:
168+
elif rand_list_var.rand_length in self.solution_space:
170169
options = self.solution_space[rand_list_var.rand_length]
171170
rand_length_val = rand_list_var._get_random().choice(options)
172171
concrete_values[rand_list_var.rand_length] = rand_length_val
@@ -176,8 +175,8 @@ def concretize_rand_length(
176175
else:
177176
# If we haven't got a value for the random var
178177
# in this problem, it must have already been set.
179-
assert rand_list_var.rand_length_val is not None, \
180-
"Rand length must be concretized, but wasn't"
178+
if rand_list_var.rand_length_val is None:
179+
raise RuntimeError(f"Internal error: Rand length must be concretized for variable {rand_list_var.name}, but wasn't")
181180

182181
def solve(
183182
self,

constrainedrandom/randobj.py

Lines changed: 65 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import constraint
55
import random
66
from collections import defaultdict
7-
from typing import Any, Callable, Dict, Iterable, List, Optional
7+
from typing import Any, Callable, Dict, Iterable, List, Optional, Set
88

99
from . import utils
1010
from .internal.multivar import MultiVarProblem
@@ -59,18 +59,19 @@ def __init__(
5959
max_domain_size: int=utils.CONSTRAINT_MAX_DOMAIN_SIZE,
6060
) -> None:
6161
# Prefix 'internal use' variables with '_', as randomized results are populated to the class
62-
self._random = _random
63-
self._random_vars = {}
64-
self._rand_list_lengths = defaultdict(list)
65-
self._constraints : List[utils.ConstraintAndVars] = []
66-
self._constrained_vars = set()
67-
self._max_iterations = max_iterations
68-
self._max_domain_size = max_domain_size
69-
self._naive_solve = True
70-
self._sparse_solve = True
71-
self._sparsities = [1, 10, 100, 1000]
72-
self._thorough_solve = True
73-
self._problem_changed = False
62+
self._random: Optional[random.Random] = _random
63+
self._random_vars: Dict[str, RandVar] = {}
64+
self._rand_list_lengths: Dict[str, List[str]]= defaultdict(list)
65+
self._constraints: List[utils.ConstraintAndVars] = []
66+
self._constrained_vars : Set[str] = set()
67+
self._max_iterations: int = max_iterations
68+
self._max_domain_size: int = max_domain_size
69+
self._naive_solve: bool = True
70+
self._sparse_solve: bool = True
71+
self._sparsities: List[int] = [1, 10, 100, 1000]
72+
self._thorough_solve: bool = True
73+
self._problem_changed: bool = False
74+
self._multi_var_problem: Optional[MultiVarProblem] = None
7475

7576
def _get_random(self) -> random.Random:
7677
'''
@@ -85,6 +86,29 @@ def _get_random(self) -> random.Random:
8586
return random
8687
return self._random
8788

89+
def _get_list_length_constraints(self, var_names: Set[str]) -> List[utils.ConstraintAndVars]:
90+
'''
91+
Internal function to get constraints to describe
92+
the relationship between random list lengths and the variables
93+
that define them.
94+
95+
:param var_names: List of variable names that we want to
96+
constrain. Only consider variables from within
97+
this list in the result. Both the list length
98+
variable and the list variable it constrains must
99+
be in ``var_names`` to return a constraint.
100+
:return: List of constraints with variables, describing
101+
relationship between random list variables and lengths.
102+
'''
103+
result: List[utils.ConstraintAndVars] = []
104+
for rand_list_length, list_vars in self._rand_list_lengths.items():
105+
if rand_list_length in var_names:
106+
for list_var in list_vars:
107+
if list_var in var_names:
108+
len_constr = lambda _list_var, _length : len(_list_var) == _length
109+
result.append((len_constr, (list_var, rand_list_length)))
110+
return result
111+
88112
def set_solver_mode(
89113
self,
90114
*,
@@ -100,7 +124,7 @@ def set_solver_mode(
100124
1. Naive solve - randomizing and checking constraints.
101125
For some problems, it is more expedient to skip this
102126
step and go straight to a MultiVarProblem.
103-
2. Sparse solve - graph-based exporation of state space.
127+
2. Sparse solve - graph-based exploration of state space.
104128
Start with depth-first search, move to wider subsets
105129
of each level of state space until valid solution
106130
found.
@@ -352,38 +376,38 @@ def randomize(
352376
result = {}
353377

354378
# Copy always-on constraints, ready to add any temporary ones
355-
constraints = list(self._constraints)
356-
constrained_vars = set(self._constrained_vars)
379+
constraints: Set[utils.ConstraintAndVars] = list(self._constraints)
380+
constrained_var_names: Set[str] = set(self._constrained_vars)
357381

358382
# Process temporary constraints
359-
tmp_single_var_constraints = defaultdict(list)
383+
tmp_single_var_constraints: Dict[str, List[utils.Constraint]] = defaultdict(list)
360384
# Set to True if the problem is different from the base problem
361385
problem_changed = False
362386
if with_constraints is not None:
363-
for constr, vars in with_constraints:
364-
if not isinstance(vars, Iterable):
387+
for constr, var_names in with_constraints:
388+
if not isinstance(var_names, Iterable):
365389
raise TypeError("with_constraints should specify a list of tuples of (constraint, Iterable[variables])")
366-
if not len(vars) > 0:
390+
if not len(var_names) > 0:
367391
raise ValueError("Cannot add a constraint that applies to no variables")
368-
if len(vars) == 1:
392+
if len(var_names) == 1:
369393
# Single-variable constraint
370-
tmp_single_var_constraints[vars[0]].append(constr)
394+
tmp_single_var_constraints[var_names[0]].append(constr)
371395
problem_changed = True
372396
else:
373397
# Multi-variable constraint
374-
constraints.append((constr, vars))
375-
for var in vars:
376-
constrained_vars.add(var)
398+
constraints.append((constr, var_names))
399+
for var_name in var_names:
400+
constrained_var_names.add(var_name)
377401
problem_changed = True
378402
# If a variable becomes constrained due to temporary multi-variable
379403
# constraints, we must respect single var temporary constraints too.
380-
for var, constrs in sorted(tmp_single_var_constraints.items()):
381-
if var in constrained_vars:
404+
for var_name, constrs in sorted(tmp_single_var_constraints.items()):
405+
if var_name in constrained_var_names:
382406
for constr in constrs:
383-
constraints.append((constr, (var,)))
407+
constraints.append((constr, (var_name,)))
384408

385409
# Don't allow non-determinism when iterating over a set
386-
constrained_vars = sorted(constrained_vars)
410+
constrained_var_names = sorted(constrained_var_names)
387411
# Don't allow non-determinism when iterating over a dict
388412
random_var_names = sorted(self._random_vars.keys())
389413
list_length_names = sorted(self._rand_list_lengths.keys())
@@ -423,8 +447,8 @@ def randomize(
423447
if attempts == max:
424448
break
425449
problem = constraint.Problem()
426-
for var in constrained_vars:
427-
problem.addVariable(var, (result[var],))
450+
for var_name in constrained_var_names:
451+
problem.addVariable(var_name, (result[var_name],))
428452
for _constraint, variables in constraints:
429453
problem.addConstraint(_constraint, variables)
430454
solutions = problem.getSolutions()
@@ -439,7 +463,7 @@ def randomize(
439463
for list_length_name in list_length_names:
440464
# If the length-defining variable is constrained,
441465
# re-randomize it and all its dependent vars.
442-
if list_length_name not in with_values and list_length_name in constrained_vars:
466+
if list_length_name not in with_values and list_length_name in constrained_var_names:
443467
tmp_constraints = tmp_single_var_constraints.get(list_length_name, [])
444468
length_result = self._random_vars[list_length_name].randomize(tmp_constraints, debug)
445469
result[list_length_name] = length_result
@@ -449,7 +473,7 @@ def randomize(
449473
self._random_vars[dependent_var_name].set_rand_length(length_result)
450474
tmp_constraints = tmp_single_var_constraints.get(dependent_var_name, [])
451475
result[dependent_var_name] = self._random_vars[dependent_var_name].randomize(tmp_constraints, debug)
452-
for var in constrained_vars:
476+
for var in constrained_var_names:
453477
# Don't re-randomize if we've specified a concrete value
454478
if var in with_values:
455479
continue
@@ -458,7 +482,7 @@ def randomize(
458482
continue
459483
# Don't re-randomize list vars which have been re-randomized once already.
460484
rand_length = self._random_vars[var].rand_length
461-
if rand_length is not None and rand_length in constrained_vars:
485+
if rand_length is not None and rand_length in constrained_var_names:
462486
continue
463487
tmp_constraints = tmp_single_var_constraints.get(var, [])
464488
result[var] = self._random_vars[var].randomize(tmp_constraints, debug)
@@ -473,10 +497,14 @@ def randomize(
473497
' There is no way to solve the problem.'
474498
)
475499
if problem_changed or self._problem_changed or self._multi_var_problem is None:
500+
# Add list length constraints here.
501+
# By this point, we have failed to get a naive solution,
502+
# so we need the list lengths as proper constraints.
503+
constraints += self._get_list_length_constraints(constrained_var_names)
476504
multi_var_problem = MultiVarProblem(
477-
self,
478-
[self._random_vars[var_name] for var_name in constrained_vars],
479-
constraints,
505+
random_getter=self._get_random,
506+
vars=[self._random_vars[var_name] for var_name in constrained_var_names],
507+
constraints=constraints,
480508
max_iterations=self._max_iterations,
481509
max_domain_size=self._max_domain_size,
482510
)

tests/features/user.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class RandSizeListOrder(testutils.RandObjTestBase):
5959
- more than one list depending on the same random length
6060
- a small total state space
6161
- the lists are constrained based on one another and the random length
62+
- the lists have a different order value from one another
6263
- a constraint that uses the random length to index the lists
6364
- fails with naive solver, or skips it
6465
- fails with sparsities == 1, or manually run with sparsities > 1
@@ -82,7 +83,7 @@ def var_not_in_list(length, list1, list2):
8283
return False
8384
return True
8485
r.add_constraint(var_not_in_list, ('length', 'list1', 'list2'))
85-
r.set_solver_mode(naive=False, sparsities=[10])
86+
r.set_solver_mode(naive=False, sparsities=[2])
8687
return r
8788

8889
def check(self, results):
@@ -94,3 +95,5 @@ def check(self, results):
9495
self.assertEqual(len(list2), length, "list2 length was wrong")
9596
for x1, x2 in zip(list1, list2):
9697
self.assertFalse(x1 == x2)
98+
self.assertIn(x1, range(-10, 11))
99+
self.assertIn(x2, range(-10, 11))

0 commit comments

Comments
 (0)