Skip to content

Commit b4ebf34

Browse files
test: add regression test
1 parent 51ddb47 commit b4ebf34

1 file changed

Lines changed: 121 additions & 0 deletions

File tree

Lib/test/test_free_threading/test_list.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import threading
12
import unittest
23

34
from threading import Thread, Barrier
@@ -171,5 +172,125 @@ def size_function():
171172
pass
172173

173174

175+
def test_richcompare_stale_element_list1(self) -> None:
176+
# gh-148442: list_richcompare_impl must keep references to the
177+
# captured items for the final ordering comparison, not re-read the
178+
# list slots after the critical section may have been suspended.
179+
#
180+
# list1 = [x, 1], list2 = [0, 0] with x > 0.
181+
# During list1 > list2, list1[0] is mutated from x to 0.
182+
# The result must be True (x > 0), not False (0 > 0).
183+
184+
class Value:
185+
def __init__(self, v: int) -> None:
186+
self.v = v
187+
188+
def __eq__(self, other: object) -> bool:
189+
if isinstance(other, Value):
190+
return self.v == other.v
191+
return NotImplemented
192+
193+
def __gt__(self, other: object) -> bool:
194+
if isinstance(other, Value):
195+
return self.v > other.v
196+
return NotImplemented
197+
198+
def __hash__(self) -> int:
199+
return hash(self.v)
200+
201+
eq_started = threading.Event()
202+
swap_done = threading.Event()
203+
204+
class SignalingValue(Value):
205+
# Value whose __eq__ signals eq_started then waits for swap_done,
206+
# giving another thread the window to mutate the list slot.
207+
def __eq__(self, other: object) -> bool:
208+
eq_started.set()
209+
swap_done.wait()
210+
return super().__eq__(other)
211+
212+
x = SignalingValue(5) # list1[0]; x > 0
213+
list1: list[Value] = [x, Value(1)]
214+
list2: list[Value] = [Value(0), Value(0)]
215+
216+
result: list[bool] = []
217+
218+
def compare() -> None:
219+
result.append(list1 > list2)
220+
221+
def swap() -> None:
222+
eq_started.wait()
223+
list1[0] = Value(0) # replace x(5) with 0 -- must not change result
224+
swap_done.set()
225+
226+
t_cmp = Thread(target=compare)
227+
t_swp = Thread(target=swap)
228+
t_cmp.start()
229+
t_swp.start()
230+
t_cmp.join()
231+
t_swp.join()
232+
233+
# x(5) > Value(0) is True; old code would compare Value(0) > Value(0) -> False
234+
self.assertTrue(result[0])
235+
236+
def test_richcompare_stale_element_list2(self) -> None:
237+
# Same as test_richcompare_stale_element_list1 but list2[0] is mutated
238+
# to a value larger than x, which would flip the result under the old code.
239+
#
240+
# list1 = [x, 1], list2 = [0, 0] with x > 0.
241+
# During list1 > list2, list2[0] is mutated from 0 to 100 (> x).
242+
# The result must be True (x > 0), not False (x > 100).
243+
244+
class Value:
245+
def __init__(self, v: int) -> None:
246+
self.v = v
247+
248+
def __eq__(self, other: object) -> bool:
249+
if isinstance(other, Value):
250+
return self.v == other.v
251+
return NotImplemented
252+
253+
def __gt__(self, other: object) -> bool:
254+
if isinstance(other, Value):
255+
return self.v > other.v
256+
return NotImplemented
257+
258+
def __hash__(self) -> int:
259+
return hash(self.v)
260+
261+
eq_started = threading.Event()
262+
swap_done = threading.Event()
263+
264+
class SignalingValue(Value):
265+
def __eq__(self, other: object) -> bool:
266+
eq_started.set()
267+
swap_done.wait()
268+
return super().__eq__(other)
269+
270+
x = SignalingValue(5) # list1[0]; x > 0
271+
list1: list[Value] = [x, Value(1)]
272+
list2: list[Value] = [Value(0), Value(0)]
273+
274+
result: list[bool] = []
275+
276+
def compare() -> None:
277+
result.append(list1 > list2)
278+
279+
def swap() -> None:
280+
eq_started.wait()
281+
list2[0] = Value(100) # replace 0 with 100 (> x) -- must not change result
282+
swap_done.set()
283+
284+
t_cmp = Thread(target=compare)
285+
t_swp = Thread(target=swap)
286+
t_cmp.start()
287+
t_swp.start()
288+
t_cmp.join()
289+
t_swp.join()
290+
291+
# x(5) > Value(0) is True; old code would compare Value(5) > Value(100) -> False
292+
self.assertTrue(result[0])
293+
294+
174295
if __name__ == "__main__":
175296
unittest.main()

0 commit comments

Comments
 (0)