Skip to content

Commit bae427e

Browse files
ezyangpytorchmergebot
authored andcommitted
Refactor maybe_evaluate_static into a worker function off of ShapeEnv (pytorch#135107)
By refactoring this way, I can put a non-expiring LRU cache here. Splitting also will make it easier for me to tell who is using up all the time. Signed-off-by: Edward Z. Yang <[email protected]> Pull Request resolved: pytorch#135107 Approved by: https://github.com/aorenste
1 parent e9bfbf7 commit bae427e

File tree

1 file changed

+111
-99
lines changed

1 file changed

+111
-99
lines changed

torch/fx/experimental/symbolic_shapes.py

Lines changed: 111 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,6 +1512,104 @@ def safe_expand(r):
15121512
else:
15131513
return r
15141514

1515+
1516+
@lru_cache(None)
1517+
def _maybe_evaluate_static_worker(
1518+
expr: sympy.Expr,
1519+
symbol_info: Tuple[Tuple[sympy.Symbol, ValueRanges, sympy.Integer, bool], ...],
1520+
unbacked_only: bool,
1521+
size_oblivious: bool
1522+
):
1523+
"""
1524+
This variant of ShapeEnv._maybe_evaluate_static has no dependence on
1525+
ShapeEnv and thus can be cached indefinitely. It does the "heavy" lifting
1526+
for static evaluation, including nontrivial reliance on Sympy simplification
1527+
that occurs when we reallocate the symbols
1528+
"""
1529+
1530+
# Simplify making use of value range lower bound
1531+
new_shape_env = {}
1532+
new_range_env = {}
1533+
for idx, sinfo in enumerate(symbol_info):
1534+
k, vr, val, is_size_like = sinfo
1535+
if isinstance(val, SingletonInt):
1536+
# Skip var_ranges logic for SingletonInt which is only used
1537+
# for jagged layout NestedTensors today
1538+
continue
1539+
if size_oblivious and is_size_like:
1540+
lower = max(2, vr.lower)
1541+
# Clamping size-oblivious to some quantity below sys.maxsize
1542+
# helps us determine that f(u0) != sys.maxsize, which is a
1543+
# test that is looking for sys.maxsize as a sentinel, but you
1544+
# don't really want to worry about it for unbacked SymInts.
1545+
# This is similar to the flavor where size oblivious omits
1546+
# 0/1, it changes semantics but in a benign way.
1547+
upper = min(2 ** 48, vr.upper)
1548+
# This is a bit dodgy: what this means is that there was a
1549+
# size-like unbacked symbol whose upper bound < 2. This
1550+
# causes... problems.
1551+
if lower <= upper:
1552+
vr = ValueRanges(lower, upper)
1553+
else:
1554+
lower = vr.lower
1555+
# Don't do anything if we don't have a nontrivial lower bound
1556+
# Also don't do anything if we asked only to simplify unbacked
1557+
# SymInt
1558+
if (
1559+
lower is -int_oo or
1560+
(unbacked_only and val is not None) or
1561+
not vr.is_int
1562+
):
1563+
new_range_env[k] = vr
1564+
continue
1565+
# The goal is to take our symbols which have various lower bounds
1566+
# and reallocate them into new symbols which are exactly positive;
1567+
# e.g., if we have s0 in [2, inf], we want to turn it into ess0 in
1568+
# [1, inf], where s0 = ess0 + 1. This gives the most information
1569+
# to sympy for subsequent simplifications.
1570+
#
1571+
# Positive means >= 1
1572+
# Positive - 1 means >= 0
1573+
# Positive + lower - 1 means >= lower
1574+
# The new symbol 's' is "too low", so when we substitute it in
1575+
# we have to increase it by offset (and conversely, the new
1576+
# variables have to have their value range bounds adjusted as
1577+
# well)
1578+
s = sympy.Symbol(f"evaluate_static_shape_{idx}", positive=True, integer=True)
1579+
1580+
# Note:
1581+
# Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers.
1582+
# Sympy might give unexepected results when comparing an integer with a non-integer
1583+
# Therefore, we cast offset to int here.
1584+
# For example:
1585+
# shape_0 = sympy.Symbol("shape_0", positive=True, integer=True)
1586+
# expr = sympy.Eq(shape_0 - 1/3, 4)
1587+
# expr.xreplace({}) # False
1588+
offset = int(lower - 1)
1589+
new_shape_env[k] = s + offset
1590+
new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset)
1591+
1592+
try:
1593+
new_expr = expr.xreplace(new_shape_env)
1594+
except RecursionError:
1595+
log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
1596+
return None
1597+
1598+
# We need to canonicalize, as after expand we may have something like `a + b = a` and
1599+
# sympy will not simplify the a. The two appeareances of the a will then make value ranges
1600+
# analysis give lose bounds
1601+
new_expr = canonicalize_bool_expr(safe_expand(new_expr))
1602+
if new_expr.is_number:
1603+
return new_expr
1604+
1605+
# Check if the range can solve it statically
1606+
out = bound_sympy(new_expr, new_range_env)
1607+
if out.is_singleton():
1608+
return out.lower
1609+
1610+
return new_expr if unbacked_only else None
1611+
1612+
15151613
def error():
15161614
raise AssertionError("shouldn't be hit")
15171615

@@ -4552,11 +4650,6 @@ def _maybe_evaluate_static(
45524650
# axioms with compute hint NYE
45534651
assert not compute_hint or not axioms
45544652

4555-
if var_to_range is None:
4556-
var_ranges = self.var_to_range
4557-
else:
4558-
var_ranges = dict(var_to_range)
4559-
45604653
expr = self.simplify(expr)
45614654

45624655
if compute_hint:
@@ -4575,104 +4668,23 @@ def _maybe_evaluate_static(
45754668

45764669
expr = expr.xreplace(subst)
45774670

4578-
symbols = tuple(expr.free_symbols)
4579-
4580-
# Simplify making use of value range lower bound
4581-
new_shape_env = {}
4582-
new_range_env = {}
4583-
for idx, k in enumerate(symbols):
4584-
if isinstance(self.var_to_val.get(k, None), SingletonInt):
4585-
# Skip var_ranges logic for SingletonInt which is only used
4586-
# for jagged layout NestedTensors today
4587-
continue
4588-
vr = var_ranges[k]
4589-
if size_oblivious and k in self.size_like:
4590-
lower = max(2, vr.lower)
4591-
# Clamping size-oblivious to some quantity below sys.maxsize
4592-
# helps us determine that f(u0) != sys.maxsize, which is a
4593-
# test that is looking for sys.maxsize as a sentinel, but you
4594-
# don't really want to worry about it for unbacked SymInts.
4595-
# This is similar to the flavor where size oblivious omits
4596-
# 0/1, it changes semantics but in a benign way.
4597-
upper = min(2 ** 48, vr.upper)
4598-
# This is a bit dodgy: what this means is that there was a
4599-
# size-like unbacked symbol whose upper bound < 2. This
4600-
# causes... problems.
4601-
if lower <= upper:
4602-
vr = ValueRanges(lower, upper)
4603-
else:
4604-
lower = vr.lower
4605-
# Don't do anything if we don't have a nontrivial lower bound
4606-
# Also don't do anything if we asked only to simplify unbacked
4607-
# SymInt
4608-
if (
4609-
lower is -int_oo or
4610-
(unbacked_only and k in self.var_to_val) or
4611-
not vr.is_int
4612-
):
4613-
new_range_env[k] = vr
4614-
continue
4615-
# The goal is to take our symbols which have various lower bounds
4616-
# and reallocate them into new symbols which are exactly positive;
4617-
# e.g., if we have s0 in [2, inf], we want to turn it into ess0 in
4618-
# [1, inf], where s0 = ess0 + 1. This gives the most information
4619-
# to sympy for subsequent simplifications.
4620-
#
4621-
# Positive means >= 1
4622-
# Positive - 1 means >= 0
4623-
# Positive + lower - 1 means >= lower
4624-
# The new symbol 's' is "too low", so when we substitute it in
4625-
# we have to increase it by offset (and conversely, the new
4626-
# variables have to have their value range bounds adjusted as
4627-
# well)
4628-
s = sympy.Symbol(f"evaluate_static_shape_{idx}", positive=True, integer=True)
4629-
4630-
# Note:
4631-
# Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers.
4632-
# Sympy might give unexepected results when comparing an integer with a non-integer
4633-
# Therefore, we cast offset to int here.
4634-
# For example:
4635-
# shape_0 = sympy.Symbol("shape_0", positive=True, integer=True)
4636-
# expr = sympy.Eq(shape_0 - 1/3, 4)
4637-
# expr.xreplace({}) # False
4638-
offset = int(lower - 1)
4639-
new_shape_env[k] = s + offset
4640-
new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset)
4671+
fs = expr.free_symbols
46414672

4642-
try:
4643-
new_expr = expr.xreplace(new_shape_env)
4644-
except RecursionError:
4645-
log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
4646-
self.counter["sympy_recursion_error"] += 1
4647-
return None
4648-
4649-
# We need to canonicalize, as after expand we may have something like `a + b = a` and
4650-
# sympy will not simplify the a. The two appeareances of the a will then make value ranges
4651-
# analysis give lose bounds
4652-
new_expr = canonicalize_bool_expr(safe_expand(new_expr))
4653-
if new_expr.is_number:
4654-
return new_expr
4673+
if not fs and (expr.is_number or expr.is_Boolean):
4674+
return expr
46554675

4656-
# This is bad to do, the replacement with division leaves us with
4657-
# rationals when atom.args[0] is addition, e.g., sympy will happily
4658-
# turn (s0 + s1) // 2 into s0 / 2 + s1 / 2. Needless complication!
4659-
"""
4660-
floor_div_replace = {}
4661-
for atom in new_expr.atoms(FloorDiv):
4662-
floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1])
4663-
new_expr = safe_expand(new_expr.xreplace(floor_div_replace))
4664-
# TODO: when unbacked_only, can sometimes early return even when there
4665-
# are still free symbols
4666-
if new_expr.is_number:
4667-
return new_expr
4668-
"""
4676+
if var_to_range is None:
4677+
var_ranges = self.var_to_range
4678+
else:
4679+
var_ranges = dict(var_to_range)
46694680

4670-
# Check if the range can solve it statically
4671-
out = bound_sympy(new_expr, new_range_env)
4672-
if out.is_singleton():
4673-
return out.lower
4681+
symbol_info = tuple(
4682+
(s, var_ranges.get(s), self.var_to_val.get(s), s in self.size_like)
4683+
for s in sorted(fs, key=lambda s: str(s)) # TODO: speed up sort?
4684+
)
46744685

4675-
return new_expr if unbacked_only else None
4686+
r = _maybe_evaluate_static_worker(expr, symbol_info, unbacked_only, size_oblivious)
4687+
return r
46764688

46774689
@_lru_cache
46784690
def replace(self, expr: "sympy.Expr") -> "sympy.Expr":

0 commit comments

Comments
 (0)