Skip to content

Commit 65ae3d5

Browse files
authored
GH-127809: Fix the JIT's understanding of ** (GH-127844)
1 parent e08b282 commit 65ae3d5

File tree

8 files changed

+199
-26
lines changed

8 files changed

+199
-26
lines changed

Lib/test/test_capi/test_opt.py

+44
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import itertools
23
import sys
34
import textwrap
45
import unittest
@@ -1511,6 +1512,49 @@ def test_jit_error_pops(self):
15111512
with self.assertRaises(TypeError):
15121513
{item for item in items}
15131514

1515+
def test_power_type_depends_on_input_values(self):
1516+
template = textwrap.dedent("""
1517+
import _testinternalcapi
1518+
1519+
L, R, X, Y = {l}, {r}, {x}, {y}
1520+
1521+
def check(actual: complex, expected: complex) -> None:
1522+
assert actual == expected, (actual, expected)
1523+
assert type(actual) is type(expected), (actual, expected)
1524+
1525+
def f(l: complex, r: complex) -> None:
1526+
expected_local_local = pow(l, r) + pow(l, r)
1527+
expected_const_local = pow(L, r) + pow(L, r)
1528+
expected_local_const = pow(l, R) + pow(l, R)
1529+
expected_const_const = pow(L, R) + pow(L, R)
1530+
for _ in range(_testinternalcapi.TIER2_THRESHOLD):
1531+
# Narrow types:
1532+
l + l, r + r
1533+
# The powers produce results, and the addition is unguarded:
1534+
check(l ** r + l ** r, expected_local_local)
1535+
check(L ** r + L ** r, expected_const_local)
1536+
check(l ** R + l ** R, expected_local_const)
1537+
check(L ** R + L ** R, expected_const_const)
1538+
1539+
# JIT for one pair of values...
1540+
f(L, R)
1541+
# ...then run with another:
1542+
f(X, Y)
1543+
""")
1544+
interesting = [
1545+
(1, 1), # int ** int -> int
1546+
(1, -1), # int ** int -> float
1547+
(1.0, 1), # float ** int -> float
1548+
(1, 1.0), # int ** float -> float
1549+
(-1, 0.5), # int ** float -> complex
1550+
(1.0, 1.0), # float ** float -> float
1551+
(-1.0, 0.5), # float ** float -> complex
1552+
]
1553+
for (l, r), (x, y) in itertools.product(interesting, repeat=2):
1554+
s = template.format(l=l, r=r, x=x, y=y)
1555+
with self.subTest(l=l, r=r, x=x, y=y):
1556+
script_helper.assert_python_ok("-c", s)
1557+
15141558

15151559
def global_identity(x):
15161560
return x
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix an issue where the experimental JIT may infer an incorrect result type
2+
for exponentiation (``**`` and ``**=``), leading to bugs or crashes.

Python/bytecodes.c

+16
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,8 @@ dummy_func(
530530
pure op(_BINARY_OP_MULTIPLY_INT, (left, right -- res)) {
531531
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
532532
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
533+
assert(PyLong_CheckExact(left_o));
534+
assert(PyLong_CheckExact(right_o));
533535

534536
STAT_INC(BINARY_OP, hit);
535537
PyObject *res_o = _PyLong_Multiply((PyLongObject *)left_o, (PyLongObject *)right_o);
@@ -543,6 +545,8 @@ dummy_func(
543545
pure op(_BINARY_OP_ADD_INT, (left, right -- res)) {
544546
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
545547
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
548+
assert(PyLong_CheckExact(left_o));
549+
assert(PyLong_CheckExact(right_o));
546550

547551
STAT_INC(BINARY_OP, hit);
548552
PyObject *res_o = _PyLong_Add((PyLongObject *)left_o, (PyLongObject *)right_o);
@@ -556,6 +560,8 @@ dummy_func(
556560
pure op(_BINARY_OP_SUBTRACT_INT, (left, right -- res)) {
557561
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
558562
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
563+
assert(PyLong_CheckExact(left_o));
564+
assert(PyLong_CheckExact(right_o));
559565

560566
STAT_INC(BINARY_OP, hit);
561567
PyObject *res_o = _PyLong_Subtract((PyLongObject *)left_o, (PyLongObject *)right_o);
@@ -593,6 +599,8 @@ dummy_func(
593599
pure op(_BINARY_OP_MULTIPLY_FLOAT, (left, right -- res)) {
594600
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
595601
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
602+
assert(PyFloat_CheckExact(left_o));
603+
assert(PyFloat_CheckExact(right_o));
596604

597605
STAT_INC(BINARY_OP, hit);
598606
double dres =
@@ -607,6 +615,8 @@ dummy_func(
607615
pure op(_BINARY_OP_ADD_FLOAT, (left, right -- res)) {
608616
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
609617
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
618+
assert(PyFloat_CheckExact(left_o));
619+
assert(PyFloat_CheckExact(right_o));
610620

611621
STAT_INC(BINARY_OP, hit);
612622
double dres =
@@ -621,6 +631,8 @@ dummy_func(
621631
pure op(_BINARY_OP_SUBTRACT_FLOAT, (left, right -- res)) {
622632
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
623633
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
634+
assert(PyFloat_CheckExact(left_o));
635+
assert(PyFloat_CheckExact(right_o));
624636

625637
STAT_INC(BINARY_OP, hit);
626638
double dres =
@@ -650,6 +662,8 @@ dummy_func(
650662
pure op(_BINARY_OP_ADD_UNICODE, (left, right -- res)) {
651663
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
652664
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
665+
assert(PyUnicode_CheckExact(left_o));
666+
assert(PyUnicode_CheckExact(right_o));
653667

654668
STAT_INC(BINARY_OP, hit);
655669
PyObject *res_o = PyUnicode_Concat(left_o, right_o);
@@ -672,6 +686,8 @@ dummy_func(
672686
op(_BINARY_OP_INPLACE_ADD_UNICODE, (left, right --)) {
673687
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
674688
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
689+
assert(PyUnicode_CheckExact(left_o));
690+
assert(PyUnicode_CheckExact(right_o));
675691

676692
int next_oparg;
677693
#if TIER_ONE

Python/executor_cases.c.h

+16
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Python/generated_cases.c.h

+16
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Python/optimizer_bytecodes.c

+45-12
Original file line numberDiff line numberDiff line change
@@ -167,23 +167,56 @@ dummy_func(void) {
167167
}
168168

169169
op(_BINARY_OP, (left, right -- res)) {
170-
PyTypeObject *ltype = sym_get_type(left);
171-
PyTypeObject *rtype = sym_get_type(right);
172-
if (ltype != NULL && (ltype == &PyLong_Type || ltype == &PyFloat_Type) &&
173-
rtype != NULL && (rtype == &PyLong_Type || rtype == &PyFloat_Type))
174-
{
175-
if (oparg != NB_TRUE_DIVIDE && oparg != NB_INPLACE_TRUE_DIVIDE &&
176-
ltype == &PyLong_Type && rtype == &PyLong_Type) {
177-
/* If both inputs are ints and the op is not division the result is an int */
178-
res = sym_new_type(ctx, &PyLong_Type);
170+
bool lhs_int = sym_matches_type(left, &PyLong_Type);
171+
bool rhs_int = sym_matches_type(right, &PyLong_Type);
172+
bool lhs_float = sym_matches_type(left, &PyFloat_Type);
173+
bool rhs_float = sym_matches_type(right, &PyFloat_Type);
174+
if (!((lhs_int || lhs_float) && (rhs_int || rhs_float))) {
175+
// There's something other than an int or float involved:
176+
res = sym_new_unknown(ctx);
177+
}
178+
else if (oparg == NB_POWER || oparg == NB_INPLACE_POWER) {
179+
// This one's fun... the *type* of the result depends on the
180+
// *values* being exponentiated. However, exponents with one
181+
// constant part are reasonably common, so it's probably worth
182+
// trying to infer some simple cases:
183+
// - A: 1 ** 1 -> 1 (int ** int -> int)
184+
// - B: 1 ** -1 -> 1.0 (int ** int -> float)
185+
// - C: 1.0 ** 1 -> 1.0 (float ** int -> float)
186+
// - D: 1 ** 1.0 -> 1.0 (int ** float -> float)
187+
// - E: -1 ** 0.5 ~> 1j (int ** float -> complex)
188+
// - F: 1.0 ** 1.0 -> 1.0 (float ** float -> float)
189+
// - G: -1.0 ** 0.5 ~> 1j (float ** float -> complex)
190+
if (rhs_float) {
191+
// Case D, E, F, or G... can't know without the sign of the LHS
192+
// or whether the RHS is whole, which isn't worth the effort:
193+
res = sym_new_unknown(ctx);
179194
}
180-
else {
181-
/* For any other op combining ints/floats the result is a float */
195+
else if (lhs_float) {
196+
// Case C:
182197
res = sym_new_type(ctx, &PyFloat_Type);
183198
}
199+
else if (!sym_is_const(right)) {
200+
// Case A or B... can't know without the sign of the RHS:
201+
res = sym_new_unknown(ctx);
202+
}
203+
else if (_PyLong_IsNegative((PyLongObject *)sym_get_const(right))) {
204+
// Case B:
205+
res = sym_new_type(ctx, &PyFloat_Type);
206+
}
207+
else {
208+
// Case A:
209+
res = sym_new_type(ctx, &PyLong_Type);
210+
}
211+
}
212+
else if (oparg == NB_TRUE_DIVIDE || oparg == NB_INPLACE_TRUE_DIVIDE) {
213+
res = sym_new_type(ctx, &PyFloat_Type);
214+
}
215+
else if (lhs_int && rhs_int) {
216+
res = sym_new_type(ctx, &PyLong_Type);
184217
}
185218
else {
186-
res = sym_new_unknown(ctx);
219+
res = sym_new_type(ctx, &PyFloat_Type);
187220
}
188221
}
189222

0 commit comments

Comments
 (0)