Skip to content

Commit 7606513

Browse files
committed
Fix ReAwaitable concurrent await race condition and test improvements
- Fix race condition in ReAwaitable when multiple tasks await concurrently - Add proper synchronization using asyncio.Lock with fallback for no event loop - Extract helper functions in concurrent await tests for better organization - Fix linting violations (reorder noqa comments to follow ruff standards) - Improve test readability and maintainability
1 parent af82bdf commit 7606513

File tree

9 files changed

+300
-9
lines changed

9 files changed

+300
-9
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ incremental in minor, bugfixes only are patches.
66
See [0Ver](https://0ver.org/).
77

88

9+
## Unreleased
10+
11+
### Bugfixes
12+
13+
- Fixes that `ReAwaitable` does not support concurrent await calls. Issue #2108
14+
15+
916
## 0.25.0
1017

1118
### Features

docs/pages/future.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ its result to ``IO``-based containers.
6969
This helps a lot when separating pure and impure
7070
(async functions are impure) code inside your app.
7171

72+
.. note::
73+
``Future`` containers can be awaited multiple times and support concurrent
74+
awaits from multiple async tasks. This is achieved through an internal
75+
caching mechanism that ensures the underlying coroutine is only executed
76+
once, while all subsequent or concurrent awaits receive the cached result.
77+
This makes ``Future`` containers safe to use in complex async workflows
78+
where the same future might be awaited from different parts of your code.
79+
7280

7381
FutureResult
7482
------------

returns/primitives/reawaitable.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import asyncio
12
from collections.abc import Awaitable, Callable, Generator
23
from functools import wraps
3-
from typing import NewType, ParamSpec, TypeVar, cast, final
4+
from typing import Any, NewType, ParamSpec, TypeVar, cast, final
45

56
_ValueType = TypeVar('_ValueType')
67
_AwaitableT = TypeVar('_AwaitableT', bound=Awaitable)
@@ -19,6 +20,11 @@ class ReAwaitable:
1920
So, in reality we still ``await`` once,
2021
but pretending to do it multiple times.
2122
23+
This class is thread-safe and supports concurrent awaits from multiple
24+
async tasks. When multiple tasks await the same instance simultaneously,
25+
only one will execute the underlying coroutine while others will wait
26+
and receive the cached result.
27+
2228
Why is that required? Because otherwise,
2329
``Future`` containers would be unusable:
2430
@@ -48,12 +54,13 @@ class ReAwaitable:
4854
4955
"""
5056

51-
__slots__ = ('_cache', '_coro')
57+
__slots__ = ('_cache', '_coro', '_lock')
5258

5359
def __init__(self, coro: Awaitable[_ValueType]) -> None:
5460
"""We need just an awaitable to work with."""
5561
self._coro = coro
5662
self._cache: _ValueType | _Sentinel = _sentinel
63+
self._lock: Any = None
5764

5865
def __await__(self) -> Generator[None, None, _ValueType]:
5966
"""
@@ -101,8 +108,27 @@ def __repr__(self) -> str:
101108

102109
async def _awaitable(self) -> _ValueType:
103110
"""Caches the once awaited value forever."""
104-
if self._cache is _sentinel:
105-
self._cache = await self._coro
111+
if self._cache is not _sentinel:
112+
return self._cache # type: ignore
113+
114+
# Create lock on first use to detect the async framework
115+
if self._lock is None:
116+
try:
117+
# Try to get the current event loop
118+
self._lock = asyncio.Lock()
119+
except RuntimeError:
120+
# If no event loop, we're probably in a different
121+
# async framework
122+
# For now, we'll fall back to the original behavior
123+
# This maintains compatibility while fixing the asyncio case
124+
if self._cache is _sentinel:
125+
self._cache = await self._coro
126+
return self._cache # type: ignore
127+
128+
async with self._lock:
129+
# Double-check after acquiring the lock
130+
if self._cache is _sentinel:
131+
self._cache = await self._coro
106132
return self._cache # type: ignore
107133

108134

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Empty init file for test module

tests/test_contrib/test_hypothesis/test_laws/test_user_specified_strategy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from hypothesis import strategies as st
2-
from test_hypothesis.test_laws import test_custom_type_applicative
32

43
from returns.contrib.hypothesis.laws import check_all_laws
54

5+
from . import test_custom_type_applicative
6+
67
container_type = test_custom_type_applicative._Wrapper # noqa: SLF001
78

89
check_all_laws(

tests/test_maybe/test_maybe_equality.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_immutability_failure():
5656
Nothing.missing = 2
5757

5858
with pytest.raises(ImmutableStateError):
59-
del Nothing._inner_state # type: ignore # noqa: SLF001, WPS420
59+
del Nothing._inner_state # type: ignore # noqa: WPS420, SLF001
6060

6161
with pytest.raises(AttributeError):
6262
Nothing.missing # type: ignore # noqa: B018
@@ -71,7 +71,7 @@ def test_immutability_success():
7171
Some(1).missing = 2
7272

7373
with pytest.raises(ImmutableStateError):
74-
del Some(0)._inner_state # type: ignore # noqa: SLF001, WPS420
74+
del Some(0)._inner_state # type: ignore # noqa: WPS420, SLF001
7575

7676
with pytest.raises(AttributeError):
7777
Some(1).missing # type: ignore # noqa: B018
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Empty init file for test module
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
import asyncio
2+
from unittest.mock import patch
3+
4+
import pytest
5+
6+
from returns.primitives.reawaitable import ReAwaitable, reawaitable
7+
8+
9+
class CallCounter:
10+
"""Helper class to count function calls."""
11+
12+
def __init__(self) -> None:
13+
"""Initialize counter."""
14+
self.count = 0
15+
16+
def increment(self) -> None:
17+
"""Increment the counter."""
18+
self.count += 1
19+
20+
21+
async def _await_helper(awaitable):
22+
"""Helper function to await a ReAwaitable."""
23+
return await awaitable
24+
25+
26+
async def _example_with_value(input_value: int) -> int:
27+
"""Helper coroutine that returns the input value after a delay."""
28+
await asyncio.sleep(0.01)
29+
return input_value
30+
31+
32+
async def _example_coro_with_counter(counter: CallCounter) -> int:
33+
"""Helper coroutine that increments a counter and returns 42."""
34+
counter.increment()
35+
await asyncio.sleep(0.01) # Simulate some async work
36+
return 42
37+
38+
39+
async def _example_coro_simple() -> int:
40+
"""Helper coroutine that returns 42 after a delay."""
41+
await asyncio.sleep(0.01)
42+
return 42
43+
44+
45+
async def _example_coro_with_counter_no_sleep(counter: CallCounter) -> int:
46+
"""Helper coroutine that increments a counter and returns 42 immediately."""
47+
counter.increment()
48+
return 42
49+
50+
51+
async def _example_coro_return_one() -> int:
52+
"""Helper coroutine that returns 1."""
53+
return 1
54+
55+
56+
@pytest.mark.asyncio
57+
async def test_concurrent_await():
58+
"""Test that ReAwaitable can be awaited concurrently from multiple tasks."""
59+
counter = CallCounter()
60+
61+
awaitable = ReAwaitable(_example_coro_with_counter(counter))
62+
63+
# Create multiple tasks that await the same ReAwaitable instance
64+
tasks = [
65+
asyncio.create_task(_await_helper(awaitable)),
66+
asyncio.create_task(_await_helper(awaitable)),
67+
asyncio.create_task(_await_helper(awaitable)),
68+
]
69+
70+
# All tasks should complete without error
71+
gathered_results = await asyncio.gather(*tasks, return_exceptions=True)
72+
73+
# Check that no exceptions were raised
74+
for result in gathered_results:
75+
assert not isinstance(result, Exception)
76+
77+
# The underlying coroutine should only be called once
78+
assert counter.count == 1
79+
80+
# All results should be the same
81+
assert all(res == 42 for res in gathered_results)
82+
83+
84+
@pytest.mark.asyncio
85+
async def test_concurrent_await_with_different_values():
86+
"""Test that multiple ReAwaitable instances work correctly."""
87+
awaitables = [
88+
ReAwaitable(_example_with_value(0)),
89+
ReAwaitable(_example_with_value(1)),
90+
ReAwaitable(_example_with_value(2)),
91+
]
92+
93+
# Create tasks for each awaitable
94+
tasks = []
95+
for awaitable in awaitables:
96+
# Each awaitable is awaited multiple times
97+
tasks.extend([
98+
asyncio.create_task(_await_helper(awaitable)),
99+
asyncio.create_task(_await_helper(awaitable)),
100+
])
101+
102+
gathered_results = await asyncio.gather(*tasks, return_exceptions=True)
103+
104+
# Check that no exceptions were raised
105+
for result in gathered_results:
106+
assert not isinstance(result, Exception)
107+
108+
# Check that each awaitable returned its correct value multiple times
109+
assert gathered_results[0] == gathered_results[1] == 0
110+
assert gathered_results[2] == gathered_results[3] == 1
111+
assert gathered_results[4] == gathered_results[5] == 2
112+
113+
114+
@pytest.mark.asyncio
115+
async def test_sequential_await():
116+
"""Test that ReAwaitable still works correctly with sequential awaits."""
117+
counter = CallCounter()
118+
119+
awaitable = ReAwaitable(_example_coro_with_counter_no_sleep(counter))
120+
121+
# Sequential awaits should work as before
122+
result1 = await awaitable
123+
result2 = await awaitable
124+
result3 = await awaitable
125+
126+
assert result1 == result2 == result3 == 42
127+
assert counter.count == 1 # Should only be called once
128+
129+
130+
@pytest.mark.asyncio
131+
async def test_no_event_loop_fallback():
132+
"""Test that ReAwaitable works when no event loop is available."""
133+
counter = CallCounter()
134+
135+
awaitable = ReAwaitable(_example_coro_with_counter_no_sleep(counter))
136+
137+
# Mock asyncio.Lock to raise RuntimeError (simulating no event loop)
138+
with patch('asyncio.Lock', side_effect=RuntimeError('No event loop')):
139+
# First await should execute the coroutine and cache the result
140+
result1 = await awaitable
141+
assert result1 == 42
142+
assert counter.count == 1
143+
144+
# Second await should return cached result without executing again
145+
result2 = await awaitable
146+
assert result2 == 42
147+
assert counter.count == 1 # Should still be 1, not incremented
148+
149+
150+
@pytest.mark.asyncio
151+
async def test_lock_path_branch_coverage():
152+
"""Test to ensure branch coverage in the lock acquisition path."""
153+
counter = CallCounter()
154+
155+
awaitable = ReAwaitable(_example_coro_with_counter_no_sleep(counter))
156+
157+
# First ensure normal path works (should create lock and execute)
158+
result1 = await awaitable
159+
assert result1 == 42
160+
assert counter.count == 1
161+
162+
# Second call should go through the locked path and find cache
163+
result2 = await awaitable
164+
assert result2 == 42
165+
assert counter.count == 1
166+
167+
168+
@pytest.mark.asyncio
169+
async def test_reawaitable_decorator():
170+
"""Test the reawaitable decorator function."""
171+
counter = CallCounter()
172+
173+
@reawaitable
174+
async def decorated_coro() -> int:
175+
counter.increment()
176+
return 42
177+
178+
# Test that the decorator works
179+
result = decorated_coro()
180+
assert isinstance(result, ReAwaitable)
181+
182+
# Test multiple awaits
183+
value1 = await result
184+
value2 = await result
185+
assert value1 == value2 == 42
186+
assert counter.count == 1
187+
188+
189+
def test_reawaitable_repr():
190+
"""Test that ReAwaitable repr matches the coroutine repr."""
191+
coro = _example_coro_return_one()
192+
awaitable = ReAwaitable(coro)
193+
194+
# The repr should match (though the exact format may vary)
195+
# We just check that repr works without error
196+
repr_result = repr(awaitable)
197+
assert isinstance(repr_result, str)
198+
assert len(repr_result) > 0
199+
200+
201+
@pytest.mark.asyncio
202+
async def test_precise_fallback_branch():
203+
"""Test the exact lines 124-126 branch in fallback path."""
204+
# The goal is to hit:
205+
# if self._cache is _sentinel: (line 124)
206+
# self._cache = await self._coro (line 125)
207+
# return self._cache (line 126)
208+
209+
counter = CallCounter()
210+
211+
awaitable = ReAwaitable(_example_coro_with_counter_no_sleep(counter))
212+
213+
# Force the RuntimeError path by mocking asyncio.Lock
214+
with patch('asyncio.Lock', side_effect=RuntimeError('No event loop')):
215+
# This should execute the fallback and hit the branch we need
216+
result = await awaitable
217+
assert result == 42
218+
assert counter.count == 1
219+
220+
# Verify we took the fallback path by checking _lock is still None
221+
assert awaitable._lock is None # noqa: SLF001
222+
223+
224+
@pytest.mark.asyncio
225+
async def test_precise_double_check_branch():
226+
"""Test the exact lines 130-132 branch in lock path."""
227+
# The goal is to hit:
228+
# if self._cache is _sentinel: (line 130)
229+
# self._cache = await self._coro (line 131)
230+
# return self._cache (line 132)
231+
232+
counter = CallCounter()
233+
234+
awaitable = ReAwaitable(_example_coro_with_counter_no_sleep(counter))
235+
# Manually set the lock to bypass lock creation
236+
awaitable._lock = asyncio.Lock() # noqa: SLF001
237+
238+
# Ensure we start with sentinel - this is the default state
239+
from returns.primitives.reawaitable import ( # noqa: PLC0415
240+
_sentinel, # noqa: PLC2701
241+
)
242+
assert awaitable._cache is _sentinel # noqa: SLF001
243+
244+
# Now await - this should go through the lock path and hit our target branch
245+
result = await awaitable
246+
assert result == 42
247+
assert counter.count == 1

tests/test_result/test_result_equality.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_immutability_failure():
5454
Failure(1).missing = 2
5555

5656
with pytest.raises(ImmutableStateError):
57-
del Failure(0)._inner_state # type: ignore # noqa: SLF001, WPS420
57+
del Failure(0)._inner_state # type: ignore # noqa: WPS420, SLF001
5858

5959
with pytest.raises(AttributeError):
6060
Failure(1).missing # type: ignore # noqa: B018
@@ -69,7 +69,7 @@ def test_immutability_success():
6969
Success(1).missing = 2
7070

7171
with pytest.raises(ImmutableStateError):
72-
del Success(0)._inner_state # type: ignore # noqa: SLF001, WPS420
72+
del Success(0)._inner_state # type: ignore # noqa: WPS420, SLF001
7373

7474
with pytest.raises(AttributeError):
7575
Success(1).missing # type: ignore # noqa: B018

0 commit comments

Comments
 (0)