|
| 1 | +import threading |
1 | 2 | import unittest |
2 | 3 |
|
3 | 4 | from threading import Thread, Barrier |
@@ -171,5 +172,125 @@ def size_function(): |
171 | 172 | pass |
172 | 173 |
|
173 | 174 |
|
| 175 | + def test_richcompare_stale_element_list1(self) -> None: |
| 176 | + # gh-148442: list_richcompare_impl must keep references to the |
| 177 | + # captured items for the final ordering comparison, not re-read the |
| 178 | + # list slots after the critical section may have been suspended. |
| 179 | + # |
| 180 | + # list1 = [x, 1], list2 = [0, 0] with x > 0. |
| 181 | + # During list1 > list2, list1[0] is mutated from x to 0. |
| 182 | + # The result must be True (x > 0), not False (0 > 0). |
| 183 | + |
| 184 | + class Value: |
| 185 | + def __init__(self, v: int) -> None: |
| 186 | + self.v = v |
| 187 | + |
| 188 | + def __eq__(self, other: object) -> bool: |
| 189 | + if isinstance(other, Value): |
| 190 | + return self.v == other.v |
| 191 | + return NotImplemented |
| 192 | + |
| 193 | + def __gt__(self, other: object) -> bool: |
| 194 | + if isinstance(other, Value): |
| 195 | + return self.v > other.v |
| 196 | + return NotImplemented |
| 197 | + |
| 198 | + def __hash__(self) -> int: |
| 199 | + return hash(self.v) |
| 200 | + |
| 201 | + eq_started = threading.Event() |
| 202 | + swap_done = threading.Event() |
| 203 | + |
| 204 | + class SignalingValue(Value): |
| 205 | + # Value whose __eq__ signals eq_started then waits for swap_done, |
| 206 | + # giving another thread the window to mutate the list slot. |
| 207 | + def __eq__(self, other: object) -> bool: |
| 208 | + eq_started.set() |
| 209 | + swap_done.wait() |
| 210 | + return super().__eq__(other) |
| 211 | + |
| 212 | + x = SignalingValue(5) # list1[0]; x > 0 |
| 213 | + list1: list[Value] = [x, Value(1)] |
| 214 | + list2: list[Value] = [Value(0), Value(0)] |
| 215 | + |
| 216 | + result: list[bool] = [] |
| 217 | + |
| 218 | + def compare() -> None: |
| 219 | + result.append(list1 > list2) |
| 220 | + |
| 221 | + def swap() -> None: |
| 222 | + eq_started.wait() |
| 223 | + list1[0] = Value(0) # replace x(5) with 0 -- must not change result |
| 224 | + swap_done.set() |
| 225 | + |
| 226 | + t_cmp = Thread(target=compare) |
| 227 | + t_swp = Thread(target=swap) |
| 228 | + t_cmp.start() |
| 229 | + t_swp.start() |
| 230 | + t_cmp.join() |
| 231 | + t_swp.join() |
| 232 | + |
| 233 | + # x(5) > Value(0) is True; old code would compare Value(0) > Value(0) -> False |
| 234 | + self.assertTrue(result[0]) |
| 235 | + |
| 236 | + def test_richcompare_stale_element_list2(self) -> None: |
| 237 | + # Same as test_richcompare_stale_element_list1 but list2[0] is mutated |
| 238 | + # to a value larger than x, which would flip the result under the old code. |
| 239 | + # |
| 240 | + # list1 = [x, 1], list2 = [0, 0] with x > 0. |
| 241 | + # During list1 > list2, list2[0] is mutated from 0 to 100 (> x). |
| 242 | + # The result must be True (x > 0), not False (x > 100). |
| 243 | + |
| 244 | + class Value: |
| 245 | + def __init__(self, v: int) -> None: |
| 246 | + self.v = v |
| 247 | + |
| 248 | + def __eq__(self, other: object) -> bool: |
| 249 | + if isinstance(other, Value): |
| 250 | + return self.v == other.v |
| 251 | + return NotImplemented |
| 252 | + |
| 253 | + def __gt__(self, other: object) -> bool: |
| 254 | + if isinstance(other, Value): |
| 255 | + return self.v > other.v |
| 256 | + return NotImplemented |
| 257 | + |
| 258 | + def __hash__(self) -> int: |
| 259 | + return hash(self.v) |
| 260 | + |
| 261 | + eq_started = threading.Event() |
| 262 | + swap_done = threading.Event() |
| 263 | + |
| 264 | + class SignalingValue(Value): |
| 265 | + def __eq__(self, other: object) -> bool: |
| 266 | + eq_started.set() |
| 267 | + swap_done.wait() |
| 268 | + return super().__eq__(other) |
| 269 | + |
| 270 | + x = SignalingValue(5) # list1[0]; x > 0 |
| 271 | + list1: list[Value] = [x, Value(1)] |
| 272 | + list2: list[Value] = [Value(0), Value(0)] |
| 273 | + |
| 274 | + result: list[bool] = [] |
| 275 | + |
| 276 | + def compare() -> None: |
| 277 | + result.append(list1 > list2) |
| 278 | + |
| 279 | + def swap() -> None: |
| 280 | + eq_started.wait() |
| 281 | + list2[0] = Value(100) # replace 0 with 100 (> x) -- must not change result |
| 282 | + swap_done.set() |
| 283 | + |
| 284 | + t_cmp = Thread(target=compare) |
| 285 | + t_swp = Thread(target=swap) |
| 286 | + t_cmp.start() |
| 287 | + t_swp.start() |
| 288 | + t_cmp.join() |
| 289 | + t_swp.join() |
| 290 | + |
| 291 | + # x(5) > Value(0) is True; old code would compare Value(5) > Value(100) -> False |
| 292 | + self.assertTrue(result[0]) |
| 293 | + |
| 294 | + |
174 | 295 | if __name__ == "__main__": |
175 | 296 | unittest.main() |
0 commit comments