Skip to content

Commit f49a88f

Browse files
authored
[mypyc] Simplify IR generated for "for" loops over strings (#19434)
Add unsafe list get item primitive. The new primitive just calls the primary get item primitive, but we could later provide an optimized primitive if this turns out to be a performance bottleneck.
1 parent db67888 commit f49a88f

File tree

5 files changed

+30
-26
lines changed

5 files changed

+30
-26
lines changed

mypyc/irbuild/for_helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from mypyc.primitives.misc_ops import stop_async_iteration_op
7777
from mypyc.primitives.registry import CFunctionDescription
7878
from mypyc.primitives.set_ops import set_add_op
79+
from mypyc.primitives.str_ops import str_get_item_unsafe_op
7980
from mypyc.primitives.tuple_ops import tuple_get_item_unsafe_op
8081

8182
GenFunc = Callable[[], None]
@@ -772,6 +773,8 @@ def unsafe_index(builder: IRBuilder, target: Value, index: Value, line: int) ->
772773
return builder.primitive_op(list_get_item_unsafe_op, [target, index], line)
773774
elif is_tuple_rprimitive(target.type):
774775
return builder.call_c(tuple_get_item_unsafe_op, [target, index], line)
776+
elif is_str_rprimitive(target.type):
777+
return builder.call_c(str_get_item_unsafe_op, [target, index], line)
775778
else:
776779
return builder.gen_method_call(target, "__getitem__", [index], None, line)
777780

mypyc/lib-rt/CPy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, Py_ssize_t size) {
727727
char CPyStr_Equal(PyObject *str1, PyObject *str2);
728728
PyObject *CPyStr_Build(Py_ssize_t len, ...);
729729
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
730+
PyObject *CPyStr_GetItemUnsafe(PyObject *str, Py_ssize_t index);
730731
CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction);
731732
CPyTagged CPyStr_FindWithEnd(PyObject *str, PyObject *substr, CPyTagged start, CPyTagged end, int direction);
732733
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split);

mypyc/lib-rt/str_ops.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) {
117117
}
118118
}
119119

120+
PyObject *CPyStr_GetItemUnsafe(PyObject *str, Py_ssize_t index) {
121+
// This is unsafe since we don't check for overflow when doing <<.
122+
return CPyStr_GetItem(str, index << 1);
123+
}
124+
120125
// A simplification of _PyUnicode_JoinArray() from CPython 3.9.6
121126
PyObject *CPyStr_Build(Py_ssize_t len, ...) {
122127
Py_ssize_t i;

mypyc/primitives/str_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,15 @@
9595
error_kind=ERR_MAGIC,
9696
)
9797

98+
# This is unsafe since it assumes that the index is within reasonable bounds.
99+
# In the future this might do no bounds checking at all.
100+
str_get_item_unsafe_op = custom_op(
101+
arg_types=[str_rprimitive, c_pyssize_t_rprimitive],
102+
return_type=str_rprimitive,
103+
c_function_name="CPyStr_GetItemUnsafe",
104+
error_kind=ERR_MAGIC,
105+
)
106+
98107
# str[begin:end]
99108
str_slice_op = custom_op(
100109
arg_types=[str_rprimitive, int_rprimitive, int_rprimitive],

mypyc/test-data/irbuild-tuple.test

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ L4:
272272
a = r6
273273
return 1
274274

275-
[case testTupleBuiltFromStr_64bit]
275+
[case testTupleBuiltFromStr]
276276
def f2(val: str) -> str:
277277
return val + "f2"
278278

@@ -292,10 +292,9 @@ def test():
292292
r2 :: bit
293293
r3 :: tuple
294294
r4, r5 :: native_int
295-
r6, r7, r8, r9 :: bit
296-
r10, r11, r12 :: int
297-
r13, x, r14 :: str
298-
r15 :: native_int
295+
r6, r7 :: bit
296+
r8, x, r9 :: str
297+
r10 :: native_int
299298
a :: tuple
300299
L0:
301300
r0 = 'abc'
@@ -308,30 +307,17 @@ L1:
308307
r5 = CPyStr_Size_size_t(source)
309308
r6 = r5 >= 0 :: signed
310309
r7 = r4 < r5 :: signed
311-
if r7 goto L2 else goto L8 :: bool
310+
if r7 goto L2 else goto L4 :: bool
312311
L2:
313-
r8 = r4 <= 4611686018427387903 :: signed
314-
if r8 goto L3 else goto L4 :: bool
312+
r8 = CPyStr_GetItemUnsafe(source, r4)
313+
x = r8
314+
r9 = f2(x)
315+
CPySequenceTuple_SetItemUnsafe(r3, r4, r9)
315316
L3:
316-
r9 = r4 >= -4611686018427387904 :: signed
317-
if r9 goto L5 else goto L4 :: bool
318-
L4:
319-
r10 = CPyTagged_FromInt64(r4)
320-
r11 = r10
321-
goto L6
322-
L5:
323-
r12 = r4 << 1
324-
r11 = r12
325-
L6:
326-
r13 = CPyStr_GetItem(source, r11)
327-
x = r13
328-
r14 = f2(x)
329-
CPySequenceTuple_SetItemUnsafe(r3, r4, r14)
330-
L7:
331-
r15 = r4 + 1
332-
r4 = r15
317+
r10 = r4 + 1
318+
r4 = r10
333319
goto L1
334-
L8:
320+
L4:
335321
a = r3
336322
return 1
337323

0 commit comments

Comments
 (0)