Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/python-test-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
pip install dist/*.whl
- name: Test with pytest
run: |
python -m pip install typing-extensions
python -m pip install typing-extensions pytest-asyncio
pytest

build:
Expand Down Expand Up @@ -74,6 +74,6 @@ jobs:
pip install dist/*.whl
- name: Test with pytest
run: |
python -m pip install typing-extensions
python -m pip install typing-extensions pytest-asyncio
pytest

4 changes: 2 additions & 2 deletions .github/workflows/python-test-macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
pip install dist/*.whl
- name: Test with pytest
run: |
python -m pip install typing-extensions
python -m pip install typing-extensions pytest-asyncio
pytest

build:
Expand Down Expand Up @@ -74,6 +74,6 @@ jobs:
pip install dist/*.whl
- name: Test with pytest
run: |
python -m pip install typing-extensions
python -m pip install typing-extensions pytest-asyncio
pytest

2 changes: 1 addition & 1 deletion .github/workflows/python-test-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@ jobs:
}
- name: Test with pytest
run: |
python -m pip install typing-extensions
python -m pip install typing-extensions pytest-asyncio
pytest

49 changes: 40 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,16 @@ list_ext.extend()
Currently, we provide the following extensions:


| file | extended types |
|:---------------:|:----------------------------------:|
| dict_ext.py | dict_keys, dict_values, dict_items |
| float_ext.py | float |
| function_ext.py | FunctionType, LambdaType |
| int_ext.py | int |
| list_ext.py | list |
| seq_ext.py | map, filter, range, zip |
| str_ext.py | str |
| file | extended types |
|:----------------:|:----------------------------------:|
| coroutine_ext.py | coroutine (async functions) |
| dict_ext.py | dict_keys, dict_values, dict_items |
| float_ext.py | float |
| function_ext.py | FunctionType, LambdaType |
| int_ext.py | int |
| list_ext.py | list |
| seq_ext.py | map, filter, range, zip |
| str_ext.py | str |



Expand Down Expand Up @@ -214,6 +215,36 @@ list.last(self: List[T]) -> T, raise IndexError
```
Returns the last element in the list, or raises `IndexError` if the list is empty.

```py
coroutine.then(self: Awaitable[T], fn: Callable[[T], Awaitable[U] | U]) -> Awaitable[U]
```
Maps the result of the awaitable via an optionally async function. If the function is async, it is awaited in the context of the wrapped awaitable.

Example:
```py
async def get_value():
return 10

result = await get_value().then(lambda x: x * 2) # result is 20
```

```py
coroutine.catch(self: Awaitable[T], fn: Callable[[E], Awaitable[U] | U], *, exception: type[E] = Exception) -> Awaitable[T | U]
```
Catches an exception of the given type and calls the passed function with the caught exception.

If no exception was raised inside the wrapped awaitable, the function will not be called.
The passed function can optionally return a value to be returned in case of an error.
The passed function can be either sync or async. If it's async, it is awaited in the context of the wrapped awaitable.

Example:
```py
async def might_fail():
raise ValueError("error")

result = await might_fail().catch(lambda e: "default", exception=ValueError) # result is "default"
```

```py
float.round(self: float) -> int
```
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ dev = [
"meson-python>=0.17.1",
"ninja>=1.11.1.4",
]
test = ["pytest>=7.4.4", "typing-extensions>=4.7.1"]
test = ["pytest>=7.4.4", "pytest-asyncio>=0.21.0", "typing-extensions>=4.7.1"]
90 changes: 90 additions & 0 deletions src/extype/builtin_extensions/coroutine_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from inspect import iscoroutine
from typing import Awaitable, Callable, Type, TypeVar, Union

from ..extension_utils import extend_type_with, extension


__all__ = [
"extend",
"CoroutineExtension"
]


_T = TypeVar("_T")
_U = TypeVar("_U")
_E = TypeVar("_E", bound=BaseException)


class CoroutineExtension:
"""
A class that contains methods to extend coroutine objects (async functions).
"""

@extension
def then(self: Awaitable[_T], fn: Callable[[_T], Union[Awaitable[_U], _U]]) -> Awaitable[_U]:
"""
Maps the result of the awaitable via an optionally async function.

If the function is async, it is awaited in the context of the wrapped awaitable.

Args:
fn: A function that takes the result of the awaitable and returns a value or awaitable.

Returns:
An awaitable that resolves to the result of the function.
"""
async def _then():
result = fn(await self)
if iscoroutine(result):
return await result
return result

return _then()

@extension
def catch(
self: Awaitable[_T],
fn: Callable[[_E], Union[Awaitable[_U], _U]],
*,
exception: Type[_E] = Exception
) -> Awaitable[Union[_T, _U]]:
"""
Catches an exception of the given type and calls the passed function with the caught exception.

If no exception was raised inside the wrapped awaitable, the function will not be called.
The passed function can optionally return a value to be returned in case of an error.
The passed function can be either sync or async. If it's async, it is awaited.

Args:
fn: A function that takes the exception and returns a value or awaitable.
exception: The type of exception to catch (default: Exception).

Returns:
An awaitable that resolves to the original result or the result of the error handler.
"""
async def _catch():
try:
return await self
except exception as e:
result = fn(e)
if iscoroutine(result):
return await result
return result

return _catch()


def extend():
"""
Applies the coroutine extensions to coroutine objects.
"""
# Get the coroutine type by creating a coroutine and getting its type
async def _dummy():
pass

coro = _dummy()
coroutine_type = type(coro)
extend_type_with(coroutine_type, CoroutineExtension)

# Close the coroutine to avoid warnings
coro.close()
2 changes: 2 additions & 0 deletions src/extype/builtin_extensions/extend_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
function_ext,
dict_ext,
str_ext,
coroutine_ext,
)

for ext in [
Expand All @@ -16,5 +17,6 @@
function_ext,
dict_ext,
str_ext,
coroutine_ext,
]:
ext.extend()
3 changes: 2 additions & 1 deletion src/extype/builtin_extensions/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ python_sources = [
'int_ext.py',
'list_ext.py',
'seq_ext.py',
'str_ext.py'
'str_ext.py',
'coroutine_ext.py'
]


Expand Down
106 changes: 105 additions & 1 deletion tests/test_builtin_extensions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from extype.builtin_extensions import extend_all
from extype.builtin_extensions import extend_all # noqa: F401


# dict keys extension tests
Expand Down Expand Up @@ -322,3 +322,107 @@ def test_str_to_float():

###################################################


# coroutine extensions tests


@pytest.mark.asyncio
async def test_coroutine_then_sync():
async def foo():
return 10

result = await foo().then(lambda x: x + 5)
assert result == 15


@pytest.mark.asyncio
async def test_coroutine_then_async():
async def foo():
return 10

async def add_five(x):
return x + 5

result = await foo().then(add_five)
assert result == 15


@pytest.mark.asyncio
async def test_coroutine_then_chaining():
async def foo():
return 10

async def add_five(x):
return x + 5

result = await foo().then(lambda x: x * 2).then(add_five).then(lambda x: x - 3)
assert result == 22 # (10 * 2) + 5 - 3 = 22


@pytest.mark.asyncio
async def test_coroutine_catch_no_exception():
async def foo():
return 42

result = await foo().catch(lambda e: 0)
assert result == 42


@pytest.mark.asyncio
async def test_coroutine_catch_with_exception():
async def foo():
raise ValueError("test error")

result = await foo().catch(lambda e: 100, exception=ValueError)
assert result == 100


@pytest.mark.asyncio
async def test_coroutine_catch_async_handler():
async def foo():
raise ValueError("test error")

async def handle_error(e):
return 200

result = await foo().catch(handle_error, exception=ValueError)
assert result == 200


@pytest.mark.asyncio
async def test_coroutine_catch_wrong_exception_type():
async def foo():
raise ValueError("test error")

with pytest.raises(ValueError):
await foo().catch(lambda e: 0, exception=TypeError)


@pytest.mark.asyncio
async def test_coroutine_catch_default_exception():
async def foo():
raise RuntimeError("test error")

result = await foo().catch(lambda e: 300)
assert result == 300


@pytest.mark.asyncio
async def test_coroutine_then_and_catch_combined():
async def foo():
return 10

result = await foo().then(lambda x: x * 2).catch(lambda e: 0)
assert result == 20


@pytest.mark.asyncio
async def test_coroutine_catch_and_then_combined():
async def foo():
raise ValueError("error")

result = await foo().catch(lambda e: 50, exception=ValueError).then(lambda x: x + 10)
assert result == 60


###################################################