@@ -1512,6 +1512,104 @@ def safe_expand(r):
1512
1512
else :
1513
1513
return r
1514
1514
1515
+
1516
+ @lru_cache (None )
1517
+ def _maybe_evaluate_static_worker (
1518
+ expr : sympy .Expr ,
1519
+ symbol_info : Tuple [Tuple [sympy .Symbol , ValueRanges , sympy .Integer , bool ], ...],
1520
+ unbacked_only : bool ,
1521
+ size_oblivious : bool
1522
+ ):
1523
+ """
1524
+ This variant of ShapeEnv._maybe_evaluate_static has no dependence on
1525
+ ShapeEnv and thus can be cached indefinitely. It does the "heavy" lifting
1526
+ for static evaluation, including nontrivial reliance on Sympy simplification
1527
+ that occurs when we reallocate the symbols
1528
+ """
1529
+
1530
+ # Simplify making use of value range lower bound
1531
+ new_shape_env = {}
1532
+ new_range_env = {}
1533
+ for idx , sinfo in enumerate (symbol_info ):
1534
+ k , vr , val , is_size_like = sinfo
1535
+ if isinstance (val , SingletonInt ):
1536
+ # Skip var_ranges logic for SingletonInt which is only used
1537
+ # for jagged layout NestedTensors today
1538
+ continue
1539
+ if size_oblivious and is_size_like :
1540
+ lower = max (2 , vr .lower )
1541
+ # Clamping size-oblivious to some quantity below sys.maxsize
1542
+ # helps us determine that f(u0) != sys.maxsize, which is a
1543
+ # test that is looking for sys.maxsize as a sentinel, but you
1544
+ # don't really want to worry about it for unbacked SymInts.
1545
+ # This is similar to the flavor where size oblivious omits
1546
+ # 0/1, it changes semantics but in a benign way.
1547
+ upper = min (2 ** 48 , vr .upper )
1548
+ # This is a bit dodgy: what this means is that there was a
1549
+ # size-like unbacked symbol whose upper bound < 2. This
1550
+ # causes... problems.
1551
+ if lower <= upper :
1552
+ vr = ValueRanges (lower , upper )
1553
+ else :
1554
+ lower = vr .lower
1555
+ # Don't do anything if we don't have a nontrivial lower bound
1556
+ # Also don't do anything if we asked only to simplify unbacked
1557
+ # SymInt
1558
+ if (
1559
+ lower is - int_oo or
1560
+ (unbacked_only and val is not None ) or
1561
+ not vr .is_int
1562
+ ):
1563
+ new_range_env [k ] = vr
1564
+ continue
1565
+ # The goal is to take our symbols which have various lower bounds
1566
+ # and reallocate them into new symbols which are exactly positive;
1567
+ # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in
1568
+ # [1, inf], where s0 = ess0 + 1. This gives the most information
1569
+ # to sympy for subsequent simplifications.
1570
+ #
1571
+ # Positive means >= 1
1572
+ # Positive - 1 means >= 0
1573
+ # Positive + lower - 1 means >= lower
1574
+ # The new symbol 's' is "too low", so when we substitute it in
1575
+ # we have to increase it by offset (and conversely, the new
1576
+ # variables have to have their value range bounds adjusted as
1577
+ # well)
1578
+ s = sympy .Symbol (f"evaluate_static_shape_{ idx } " , positive = True , integer = True )
1579
+
1580
+ # Note:
1581
+ # Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers.
1582
+ # Sympy might give unexepected results when comparing an integer with a non-integer
1583
+ # Therefore, we cast offset to int here.
1584
+ # For example:
1585
+ # shape_0 = sympy.Symbol("shape_0", positive=True, integer=True)
1586
+ # expr = sympy.Eq(shape_0 - 1/3, 4)
1587
+ # expr.xreplace({}) # False
1588
+ offset = int (lower - 1 )
1589
+ new_shape_env [k ] = s + offset
1590
+ new_range_env [s ] = SymPyValueRangeAnalysis .add (vr , - offset )
1591
+
1592
+ try :
1593
+ new_expr = expr .xreplace (new_shape_env )
1594
+ except RecursionError :
1595
+ log .warning ("RecursionError in sympy.xreplace(%s, %s)" , expr , new_shape_env )
1596
+ return None
1597
+
1598
+ # We need to canonicalize, as after expand we may have something like `a + b = a` and
1599
+ # sympy will not simplify the a. The two appeareances of the a will then make value ranges
1600
+ # analysis give lose bounds
1601
+ new_expr = canonicalize_bool_expr (safe_expand (new_expr ))
1602
+ if new_expr .is_number :
1603
+ return new_expr
1604
+
1605
+ # Check if the range can solve it statically
1606
+ out = bound_sympy (new_expr , new_range_env )
1607
+ if out .is_singleton ():
1608
+ return out .lower
1609
+
1610
+ return new_expr if unbacked_only else None
1611
+
1612
+
1515
1613
def error ():
1516
1614
raise AssertionError ("shouldn't be hit" )
1517
1615
@@ -4552,11 +4650,6 @@ def _maybe_evaluate_static(
4552
4650
# axioms with compute hint NYE
4553
4651
assert not compute_hint or not axioms
4554
4652
4555
- if var_to_range is None :
4556
- var_ranges = self .var_to_range
4557
- else :
4558
- var_ranges = dict (var_to_range )
4559
-
4560
4653
expr = self .simplify (expr )
4561
4654
4562
4655
if compute_hint :
@@ -4575,104 +4668,23 @@ def _maybe_evaluate_static(
4575
4668
4576
4669
expr = expr .xreplace (subst )
4577
4670
4578
- symbols = tuple (expr .free_symbols )
4579
-
4580
- # Simplify making use of value range lower bound
4581
- new_shape_env = {}
4582
- new_range_env = {}
4583
- for idx , k in enumerate (symbols ):
4584
- if isinstance (self .var_to_val .get (k , None ), SingletonInt ):
4585
- # Skip var_ranges logic for SingletonInt which is only used
4586
- # for jagged layout NestedTensors today
4587
- continue
4588
- vr = var_ranges [k ]
4589
- if size_oblivious and k in self .size_like :
4590
- lower = max (2 , vr .lower )
4591
- # Clamping size-oblivious to some quantity below sys.maxsize
4592
- # helps us determine that f(u0) != sys.maxsize, which is a
4593
- # test that is looking for sys.maxsize as a sentinel, but you
4594
- # don't really want to worry about it for unbacked SymInts.
4595
- # This is similar to the flavor where size oblivious omits
4596
- # 0/1, it changes semantics but in a benign way.
4597
- upper = min (2 ** 48 , vr .upper )
4598
- # This is a bit dodgy: what this means is that there was a
4599
- # size-like unbacked symbol whose upper bound < 2. This
4600
- # causes... problems.
4601
- if lower <= upper :
4602
- vr = ValueRanges (lower , upper )
4603
- else :
4604
- lower = vr .lower
4605
- # Don't do anything if we don't have a nontrivial lower bound
4606
- # Also don't do anything if we asked only to simplify unbacked
4607
- # SymInt
4608
- if (
4609
- lower is - int_oo or
4610
- (unbacked_only and k in self .var_to_val ) or
4611
- not vr .is_int
4612
- ):
4613
- new_range_env [k ] = vr
4614
- continue
4615
- # The goal is to take our symbols which have various lower bounds
4616
- # and reallocate them into new symbols which are exactly positive;
4617
- # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in
4618
- # [1, inf], where s0 = ess0 + 1. This gives the most information
4619
- # to sympy for subsequent simplifications.
4620
- #
4621
- # Positive means >= 1
4622
- # Positive - 1 means >= 0
4623
- # Positive + lower - 1 means >= lower
4624
- # The new symbol 's' is "too low", so when we substitute it in
4625
- # we have to increase it by offset (and conversely, the new
4626
- # variables have to have their value range bounds adjusted as
4627
- # well)
4628
- s = sympy .Symbol (f"evaluate_static_shape_{ idx } " , positive = True , integer = True )
4629
-
4630
- # Note:
4631
- # Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers.
4632
- # Sympy might give unexepected results when comparing an integer with a non-integer
4633
- # Therefore, we cast offset to int here.
4634
- # For example:
4635
- # shape_0 = sympy.Symbol("shape_0", positive=True, integer=True)
4636
- # expr = sympy.Eq(shape_0 - 1/3, 4)
4637
- # expr.xreplace({}) # False
4638
- offset = int (lower - 1 )
4639
- new_shape_env [k ] = s + offset
4640
- new_range_env [s ] = SymPyValueRangeAnalysis .add (vr , - offset )
4671
+ fs = expr .free_symbols
4641
4672
4642
- try :
4643
- new_expr = expr .xreplace (new_shape_env )
4644
- except RecursionError :
4645
- log .warning ("RecursionError in sympy.xreplace(%s, %s)" , expr , new_shape_env )
4646
- self .counter ["sympy_recursion_error" ] += 1
4647
- return None
4648
-
4649
- # We need to canonicalize, as after expand we may have something like `a + b = a` and
4650
- # sympy will not simplify the a. The two appeareances of the a will then make value ranges
4651
- # analysis give lose bounds
4652
- new_expr = canonicalize_bool_expr (safe_expand (new_expr ))
4653
- if new_expr .is_number :
4654
- return new_expr
4673
+ if not fs and (expr .is_number or expr .is_Boolean ):
4674
+ return expr
4655
4675
4656
- # This is bad to do, the replacement with division leaves us with
4657
- # rationals when atom.args[0] is addition, e.g., sympy will happily
4658
- # turn (s0 + s1) // 2 into s0 / 2 + s1 / 2. Needless complication!
4659
- """
4660
- floor_div_replace = {}
4661
- for atom in new_expr.atoms(FloorDiv):
4662
- floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1])
4663
- new_expr = safe_expand(new_expr.xreplace(floor_div_replace))
4664
- # TODO: when unbacked_only, can sometimes early return even when there
4665
- # are still free symbols
4666
- if new_expr.is_number:
4667
- return new_expr
4668
- """
4676
+ if var_to_range is None :
4677
+ var_ranges = self .var_to_range
4678
+ else :
4679
+ var_ranges = dict (var_to_range )
4669
4680
4670
- # Check if the range can solve it statically
4671
- out = bound_sympy ( new_expr , new_range_env )
4672
- if out . is_singleton ():
4673
- return out . lower
4681
+ symbol_info = tuple (
4682
+ ( s , var_ranges . get ( s ), self . var_to_val . get ( s ), s in self . size_like )
4683
+ for s in sorted ( fs , key = lambda s : str ( s )) # TODO: speed up sort?
4684
+ )
4674
4685
4675
- return new_expr if unbacked_only else None
4686
+ r = _maybe_evaluate_static_worker (expr , symbol_info , unbacked_only , size_oblivious )
4687
+ return r
4676
4688
4677
4689
@_lru_cache
4678
4690
def replace (self , expr : "sympy.Expr" ) -> "sympy.Expr" :
0 commit comments