From 31d1db3af69d2c3237513ce5c4722c87e6b01a52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20G=C3=B3recki?= Date: Fri, 3 Jan 2025 09:26:47 +0100 Subject: [PATCH] migrate to async --- .pre-commit-config.yaml | 8 +--- poetry.lock | 16 +++---- pyproject.toml | 2 +- src/api/dependencies.py | 8 ++-- src/api/main.py | 30 ++++++++++-- src/api/routers/bidding.py | 9 ++-- src/api/routers/catalog.py | 13 +++-- src/api/tests/test_bidding.py | 19 +++++--- src/api/tests/test_catalog.py | 29 +++++++----- src/api/tests/test_common.py | 6 +++ src/config/api_config.py | 2 +- src/config/container.py | 47 ++++++++++++------- ..._create_listing_when_draft_is_published.py | 15 +++--- .../command/create_listing_draft.py | 6 +-- .../command/publish_listing_draft.py | 4 +- .../application/query/get_all_listings.py | 3 +- .../application/test_create_listing_draft.py | 13 +++-- .../tests/application/test_publish_listing.py | 10 ++-- src/seedwork/domain/value_objects.py | 18 +++---- .../tests/domain/test_value_objects.py | 7 ++- 20 files changed, 153 insertions(+), 112 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 09d896a..22e1b87 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,13 +1,7 @@ default_language_version: - python: python3.10 + python: python3.11 repos: - # native hints instead of `from typing` | List -> list - - repo: https://github.com/sondrelg/pep585-upgrade - rev: 'v1.0' # Version to check - hooks: - - id: upgrade-type-hints - # Only for removing unused imports > Other staff done by Black - repo: https://github.com/myint/autoflake rev: "v1.4" # Version to check diff --git a/poetry.lock b/poetry.lock index 467bcad..5db8e0e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -685,13 +685,13 @@ files = [ [[package]] name = "lato" -version = "0.10.0" +version = "0.12.0" description = "Lato is a Python microframework designed for building modular monoliths and loosely coupled applications." optional = false -python-versions = ">=3.9,<4.0" +python-versions = "<4.0,>=3.9" files = [ - {file = "lato-0.10.0-py3-none-any.whl", hash = "sha256:9a2efa02d8cf28503c53a8229aa2f787ae46a7123759b0474cfda9a9c44ac446"}, - {file = "lato-0.10.0.tar.gz", hash = "sha256:d53b212dddb902067fa66969f181457e3fb72c34f8fabe942e4ca9fe632bd606"}, + {file = "lato-0.12.0-py3-none-any.whl", hash = "sha256:0ef3512ff0d3f7912623595ecc21e8f23dfce8e0ff385f17d847b8bbfed27253"}, + {file = "lato-0.12.0.tar.gz", hash = "sha256:c3d36cdae0d5ca2db8868423f780c42bff22ce30bd6f3783ee5576ef223cec7a"}, ] [package.dependencies] @@ -1200,13 +1200,13 @@ testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygm [[package]] name = "pytest-asyncio" -version = "0.23.5.post1" +version = "0.23.8" description = "Pytest support for asyncio" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-asyncio-0.23.5.post1.tar.gz", hash = "sha256:b9a8806bea78c21276bc34321bbf234ba1b2ea5b30d9f0ce0f2dea45e4685813"}, - {file = "pytest_asyncio-0.23.5.post1-py3-none-any.whl", hash = "sha256:30f54d27774e79ac409778889880242b0403d09cabd65b727ce90fe92dd5d80e"}, + {file = "pytest_asyncio-0.23.8-py3-none-any.whl", hash = "sha256:50265d892689a5faefb84df80819d1ecef566eb3549cf915dfb33569359d1ce2"}, + {file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"}, ] [package.dependencies] @@ -1743,4 +1743,4 @@ tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} [metadata] lock-version = "2.0" python-versions = "^3.10.0" -content-hash = "5ac4859b98df7fa83e2b336a4a189531aa1bda42073eff3f6c9229d234d16590" +content-hash = "44490e84888709bc4e640d8a067c1f262a4e51cc2c768934804860103c11692b" diff --git a/pyproject.toml b/pyproject.toml index 1498837..0b65b39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,8 +27,8 @@ requests = "^2.28.1" bcrypt = "^4.0.1" mypy = "^1.4.1" fastapi = "^0.110.0" -lato = "^0.10.0" pydantic-settings = "^2.2.1" +lato = "^0.12.0" [tool.poetry.dev-dependencies] poethepoet = "^0.10.0" diff --git a/src/api/dependencies.py b/src/api/dependencies.py index e49cbab..f94a82d 100644 --- a/src/api/dependencies.py +++ b/src/api/dependencies.py @@ -2,25 +2,23 @@ from fastapi import Depends, Request from fastapi.security import OAuth2PasswordBearer +from lato import Application, TransactionContext from modules.iam.application.services import IamService from modules.iam.domain.entities import User -from lato import Application, TransactionContext oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") async def get_application(request: Request) -> Application: - application = request.app.container.application() - return application + return request.state.lato_application async def get_transaction_context( app: Annotated[Application, Depends(get_application)], ) -> TransactionContext: """Creates a new transaction context for each request""" - - with app.transaction_context() as ctx: + async with app.transaction_context() as ctx: yield ctx diff --git a/src/api/main.py b/src/api/main.py index 23076bc..55492c6 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -2,11 +2,12 @@ from fastapi import FastAPI, Request from fastapi.responses import JSONResponse +from pydantic import ValidationError from api.dependencies import oauth2_scheme # noqa from api.routers import bidding, catalog, diagnostics, iam from config.api_config import ApiConfig -from config.container import create_application, ApplicationContainer +from config.container import ApplicationContainer from seedwork.domain.exceptions import DomainException, EntityNotFoundException from seedwork.infrastructure.database import Base from seedwork.infrastructure.logging import LoggerFactory, logger @@ -28,16 +29,27 @@ app.include_router(bidding.router) app.include_router(iam.router) app.include_router(diagnostics.router) -app.container = container +app.container = container # type: ignore +@app.exception_handler(ValidationError) +async def pydantic_validation_exception_handler(request: Request, exc: ValidationError): + return JSONResponse( + status_code=422, + content={ + "detail": exc.errors(), + }, + ) + + +# startup try: import uuid from modules.iam.application.services import IamService - with app.container.application().transaction_context() as ctx: + with container.application().transaction_context() as ctx: iam_service = ctx[IamService] iam_service.create_user( user_id=uuid.UUID(int=1), @@ -50,7 +62,7 @@ @app.exception_handler(DomainException) -async def unicorn_exception_handler(request: Request, exc: DomainException): +async def domain_exception_handler(request: Request, exc: DomainException): if container.config.DEBUG: raise exc @@ -61,7 +73,9 @@ async def unicorn_exception_handler(request: Request, exc: DomainException): @app.exception_handler(EntityNotFoundException) -async def unicorn_exception_handler(request: Request, exc: EntityNotFoundException): +async def entity_not_found_exception_handler( + request: Request, exc: EntityNotFoundException +): return JSONResponse( status_code=404, content={ @@ -70,6 +84,12 @@ async def unicorn_exception_handler(request: Request, exc: EntityNotFoundExcepti ) +@app.middleware("http") +async def add_lato_application(request: Request, call_next): + request.state.lato_application = container.application() + return await call_next(request) + + @app.middleware("http") async def add_process_time(request: Request, call_next): start_time = time.time() diff --git a/src/api/routers/bidding.py b/src/api/routers/bidding.py index 4603f3d..68dfbdf 100644 --- a/src/api/routers/bidding.py +++ b/src/api/routers/bidding.py @@ -1,13 +1,13 @@ from typing import Annotated from fastapi import APIRouter, Depends +from lato import Application from api.dependencies import get_application from api.models.bidding import BiddingResponse, PlaceBidRequest from config.container import inject from modules.bidding.application.command import PlaceBidCommand, RetractBidCommand from modules.bidding.application.query.get_bidding_details import GetBiddingDetails -from lato import Application router = APIRouter() @@ -25,7 +25,7 @@ async def get_bidding_details_of_listing( Shows listing details """ query = GetBiddingDetails(listing_id=listing_id) - result = app.execute(query) + result = await app.execute_async(query) return BiddingResponse( listing_id=result.id, auction_end_date=result.ends_at, @@ -52,10 +52,11 @@ async def place_bid( bidder_id=request_body.bidder_id, amount=request_body.amount, ) - app.execute(command) + await app.execute_async(command) + # execute_async, or execute? query = GetBiddingDetails(listing_id=listing_id) - result = app.execute(query) + result = await app.execute_async(query) return BiddingResponse( listing_id=result.id, auction_end_date=result.ends_at, diff --git a/src/api/routers/catalog.py b/src/api/routers/catalog.py index cbe92d9..eae8ff2 100644 --- a/src/api/routers/catalog.py +++ b/src/api/routers/catalog.py @@ -22,13 +22,12 @@ @router.get("/catalog", tags=["catalog"], response_model=ListingIndexModel) -@inject -def get_all_listings(app: Annotated[Application, Depends(get_application)]): +async def get_all_listings(app: Annotated[Application, Depends(get_application)]): """ Shows all published listings in the catalog """ query = GetAllListings() - result = app.execute(query) + result = await app.execute_async(query) return dict(data=result) @@ -41,7 +40,7 @@ async def get_listing_details( Shows listing details """ query = GetListingDetails(listing_id=listing_id) - query_result = app.execute_query(query) + query_result = await app.execute_async(query) return dict(data=query_result.payload) @@ -87,7 +86,7 @@ async def delete_listing( listing_id=listing_id, seller_id=current_user.id, ) - app.execute(command) + await app.execute_async(command) @router.post( @@ -109,8 +108,8 @@ async def publish_listing( listing_id=listing_id, seller_id=current_user.id, ) - app.execute(command) + await app.execute_async(command) query = GetListingDetails(listing_id=listing_id) - response = app.execute(query) + response = await app.execute_async(query) return response diff --git a/src/api/tests/test_bidding.py b/src/api/tests/test_bidding.py index 00bef31..a4c949a 100644 --- a/src/api/tests/test_bidding.py +++ b/src/api/tests/test_bidding.py @@ -8,7 +8,7 @@ from seedwork.infrastructure.logging import logger -def setup_app_for_bidding_tests(app, listing_id, seller_id, bidder_id): +async def setup_app_for_bidding_tests(app, listing_id, seller_id, bidder_id): logger.info("Adding users") with app.transaction_context() as ctx: iam_service = ctx["iam_service"] @@ -19,15 +19,17 @@ def setup_app_for_bidding_tests(app, listing_id, seller_id, bidder_id): password="password", access_token="token1", ) + ctx["logger"].debug(f"Added seller: {seller_id}") + iam_service.create_user( user_id=bidder_id, email="bidder@example.com", password="password", access_token="token2", ) + ctx["logger"].debug(f"Added bidder: {bidder_id}") - logger.info("Adding listing") - app.execute( + await app.execute_async( CreateListingDraftCommand( listing_id=listing_id, title="Foo", @@ -36,18 +38,23 @@ def setup_app_for_bidding_tests(app, listing_id, seller_id, bidder_id): seller_id=seller_id, ) ) - app.execute( + logger.info(f"Created listing draft: {listing_id}") + + await app.execute_async( PublishListingDraftCommand(listing_id=listing_id, seller_id=seller_id) ) + logger.info(f"Published listing draft {listing_id} by seller {seller_id}") + logger.info("test setup complete") @pytest.mark.integration -def test_place_bid(app, api_client): +@pytest.mark.asyncio +async def test_place_bid(app, api_client): listing_id = GenericUUID(int=1) seller_id = GenericUUID(int=2) bidder_id = GenericUUID(int=3) - setup_app_for_bidding_tests(app, listing_id, seller_id, bidder_id) + await setup_app_for_bidding_tests(app, listing_id, seller_id, bidder_id) url = f"/bidding/{listing_id}/place_bid" diff --git a/src/api/tests/test_catalog.py b/src/api/tests/test_catalog.py index 4959088..8ccaa52 100644 --- a/src/api/tests/test_catalog.py +++ b/src/api/tests/test_catalog.py @@ -15,9 +15,10 @@ def test_empty_catalog_list(api_client): @pytest.mark.integration -def test_catalog_list_with_one_item(app, api_client): +@pytest.mark.asyncio +async def test_catalog_list_with_one_item(app, api_client): # arrange - app.execute( + await app.execute_async( CreateListingDraftCommand( listing_id=GenericUUID(int=1), title="Foo", @@ -48,9 +49,10 @@ def test_catalog_list_with_one_item(app, api_client): @pytest.mark.integration -def test_catalog_list_with_two_items(app, api_client): +@pytest.mark.asyncio +async def test_catalog_list_with_two_items(app, api_client): # arrange - app.execute( + await app.execute_async( CreateListingDraftCommand( listing_id=GenericUUID(int=1), title="Foo #1", @@ -59,7 +61,7 @@ def test_catalog_list_with_two_items(app, api_client): seller_id=GenericUUID(int=2), ) ) - app.execute( + await app.execute_async( CreateListingDraftCommand( listing_id=GenericUUID(int=2), title="Foo #2", @@ -86,9 +88,10 @@ def test_catalog_create_draft_fails_due_to_incomplete_data( @pytest.mark.integration -def test_catalog_delete_draft(app, authenticated_api_client): +@pytest.mark.asyncio +async def test_catalog_delete_draft(app, authenticated_api_client): current_user = authenticated_api_client.current_user - app.execute( + await app.execute_async( CreateListingDraftCommand( listing_id=GenericUUID(int=1), title="Listing to be deleted", @@ -111,11 +114,12 @@ def test_catalog_delete_non_existing_draft_returns_404(authenticated_api_client) @pytest.mark.integration -def test_catalog_publish_listing_draft(app, authenticated_api_client): +@pytest.mark.asyncio +async def test_catalog_publish_listing_draft(app, authenticated_api_client): # arrange current_user = authenticated_api_client.current_user listing_id = GenericUUID(int=1) - app.execute( + await app.execute_async( CreateListingDraftCommand( listing_id=listing_id, title="Listing to be published", @@ -132,11 +136,12 @@ def test_catalog_publish_listing_draft(app, authenticated_api_client): assert response.status_code == 200 -def test_published_listing_appears_in_biddings(app, authenticated_api_client): +@pytest.mark.asyncio +async def test_published_listing_appears_in_biddings(app, authenticated_api_client): # arrange listing_id = GenericUUID(int=1) current_user = authenticated_api_client.current_user - app.execute( + await app.execute_async( CreateListingDraftCommand( listing_id=listing_id, title="Listing to be published", @@ -145,7 +150,7 @@ def test_published_listing_appears_in_biddings(app, authenticated_api_client): seller_id=current_user.id, ) ) - app.execute( + await app.execute_async( PublishListingDraftCommand( listing_id=listing_id, seller_id=current_user.id, diff --git a/src/api/tests/test_common.py b/src/api/tests/test_common.py index c950f24..5bec52d 100644 --- a/src/api/tests/test_common.py +++ b/src/api/tests/test_common.py @@ -11,3 +11,9 @@ def test_homepage_returns_200(api_client): def test_docs_page_returns_200(api_client): response = api_client.get("/docs") assert response.status_code == 200 + + +@pytest.mark.integration +def test_openapi_schema_returns_200(api_client): + response = api_client.get("/openapi.json") + assert response.status_code == 200 diff --git a/src/config/api_config.py b/src/config/api_config.py index 4ceb6db..eb22405 100644 --- a/src/config/api_config.py +++ b/src/config/api_config.py @@ -11,7 +11,7 @@ class ApiConfig(BaseSettings): APP_NAME: str = "Online Auctions API" DEBUG: bool = Field(default=True) - DATABASE_ECHO: bool = Field(default=True) + DATABASE_ECHO: bool = Field(default=False) DATABASE_URL: str = Field( default="postgresql://postgres:password@localhost:5432/postgres", ) diff --git a/src/config/container.py b/src/config/container.py index 7966641..59dfbd0 100644 --- a/src/config/container.py +++ b/src/config/container.py @@ -1,18 +1,19 @@ +import asyncio +import copy import inspect import json import uuid from typing import Optional from uuid import UUID -from pydantic_settings import BaseSettings from dependency_injector import containers, providers from dependency_injector.containers import Container from dependency_injector.providers import Dependency, Factory, Provider, Singleton from dependency_injector.wiring import Provide, inject # noqa +from lato import Application, DependencyProvider, TransactionContext +from pydantic_settings import BaseSettings from sqlalchemy import create_engine from sqlalchemy.orm import Session -from lato import Application, TransactionContext, DependencyProvider -import copy from modules.bidding.application import bidding_module from modules.bidding.infrastructure.listing_repository import ( @@ -24,9 +25,8 @@ ) from modules.iam.application.services import IamService from modules.iam.infrastructure.repository import PostgresJsonUserRepository - from seedwork.application.inbox_outbox import InMemoryOutbox -from seedwork.infrastructure.logging import logger, Logger +from seedwork.infrastructure.logging import Logger, logger def _default(val): @@ -79,16 +79,22 @@ def on_create_transaction_context(**kwargs): return TransactionContext(dependency_provider) @application.on_enter_transaction_context - def on_enter_transaction_context(ctx: TransactionContext): + async def on_enter_transaction_context(ctx: TransactionContext): ctx.set_dependencies(publish=ctx.publish) logger.debug("Entering transaction") @application.on_exit_transaction_context - def on_exit_transaction_context(ctx: TransactionContext, exception: Optional[Exception] = None): + async def on_exit_transaction_context( + ctx: TransactionContext, exception: Optional[Exception] = None + ): session = ctx["db_session"] if exception: session.rollback() logger.warning(f"rollback due to {exception}") + + # from pydantic import ValidationError + # if type(exception) not in [ValidationError]: + # raise exception else: session.commit() logger.debug(f"committed") @@ -97,7 +103,7 @@ def on_exit_transaction_context(ctx: TransactionContext, exception: Optional[Exc logger.correlation_id.set(uuid.UUID(int=0)) # type: ignore @application.transaction_middleware - def logging_middleware(ctx: TransactionContext, call_next): + async def logging_middleware(ctx: TransactionContext, call_next): description = ( f"{ctx.current_action[1]} -> {repr(ctx.current_action[0])}" if ctx.current_action @@ -105,27 +111,31 @@ def logging_middleware(ctx: TransactionContext, call_next): ) logger.debug(f"Executing {description}...") result = call_next() + if asyncio.iscoroutine(result): + result = await result logger.debug(f"Finished executing {description}") return result - + @application.transaction_middleware - def event_collector_middleware(ctx: TransactionContext, call_next): + async def event_collector_middleware(ctx: TransactionContext, call_next): handler_kwargs = call_next.keywords result = call_next() - + if asyncio.iscoroutine(result): + result = await result + logger.debug(f"Collecting event from {ctx['message'].__class__}") - + domain_events = [] repositories = filter( - lambda x: hasattr(x, 'collect_events'), handler_kwargs.values() + lambda x: hasattr(x, "collect_events"), handler_kwargs.values() ) for repo in repositories: domain_events.extend(repo.collect_events()) for event in domain_events: logger.debug(f"Publishing {event}") - ctx.publish(event) - + await ctx.publish_async(event) + return result return application @@ -133,13 +143,14 @@ def event_collector_middleware(ctx: TransactionContext, call_next): class ApplicationContainer(containers.DeclarativeContainer): """Dependency Injection container for the application (application-level dependencies) - see https://github.com/ets-labs/python-dependency-injector for more details + see https://github.com/ets-labs/python-dependency-injector for more details """ + __self__ = providers.Self() config = providers.Dependency(instance_of=BaseSettings) db_engine = providers.Singleton(create_db_engine, config) application = providers.Singleton(create_application, db_engine) - + class TransactionContainer(containers.DeclarativeContainer): """Dependency Injection container for the transaction context (transaction-level dependencies) @@ -193,6 +204,8 @@ def inspect_provider(provider: Provider) -> bool: class ContainerProvider(DependencyProvider): + """A dependency provider that uses a dependency injector container under the hood""" + def __init__(self, container: Container): self.container = container self.counter = 0 diff --git a/src/modules/bidding/tests/application/test_create_listing_when_draft_is_published.py b/src/modules/bidding/tests/application/test_create_listing_when_draft_is_published.py index aa900e2..4c77891 100644 --- a/src/modules/bidding/tests/application/test_create_listing_when_draft_is_published.py +++ b/src/modules/bidding/tests/application/test_create_listing_when_draft_is_published.py @@ -8,15 +8,16 @@ @pytest.mark.integration -def test_create_listing_on_draft_published_event(app, engine): +@pytest.mark.asyncio +async def test_create_listing_on_draft_published_event(app, engine): listing_id = GenericUUID(int=1) - app.publish( - ListingPublishedEvent( - listing_id=listing_id, - seller_id=GenericUUID.next_id(), - ask_price=Money(10), - ) + await app.publish_async( + ListingPublishedEvent( + listing_id=listing_id, + seller_id=GenericUUID.next_id(), + ask_price=Money(10), ) + ) with app.transaction_context() as ctx: listing_repository = ctx[BiddingListingRepository] diff --git a/src/modules/catalog/application/command/create_listing_draft.py b/src/modules/catalog/application/command/create_listing_draft.py index adc9598..2c9ef11 100644 --- a/src/modules/catalog/application/command/create_listing_draft.py +++ b/src/modules/catalog/application/command/create_listing_draft.py @@ -1,11 +1,9 @@ -from dataclasses import dataclass +from lato import Command from modules.catalog.application import catalog_module from modules.catalog.domain.entities import Listing from modules.catalog.domain.events import ListingDraftCreatedEvent from modules.catalog.domain.repositories import ListingRepository -from seedwork.application.command_handlers import CommandResult -from lato import Command, TransactionContext from seedwork.domain.value_objects import GenericUUID, Money @@ -20,7 +18,7 @@ class CreateListingDraftCommand(Command): @catalog_module.handler(CreateListingDraftCommand) -def create_listing_draft( +async def create_listing_draft( command: CreateListingDraftCommand, repository: ListingRepository, publish ): listing = Listing( diff --git a/src/modules/catalog/application/command/publish_listing_draft.py b/src/modules/catalog/application/command/publish_listing_draft.py index 18ad314..ae51b4c 100644 --- a/src/modules/catalog/application/command/publish_listing_draft.py +++ b/src/modules/catalog/application/command/publish_listing_draft.py @@ -1,5 +1,3 @@ -from dataclasses import dataclass - from modules.catalog.application import catalog_module from modules.catalog.domain.entities import Listing from modules.catalog.domain.repositories import ListingRepository @@ -17,7 +15,7 @@ class PublishListingDraftCommand(Command): @catalog_module.handler(PublishListingDraftCommand) -def publish_listing_draft( +async def publish_listing_draft( command: PublishListingDraftCommand, listing_repository: ListingRepository, ): diff --git a/src/modules/catalog/application/query/get_all_listings.py b/src/modules/catalog/application/query/get_all_listings.py index ab552ae..64880d1 100644 --- a/src/modules/catalog/application/query/get_all_listings.py +++ b/src/modules/catalog/application/query/get_all_listings.py @@ -4,7 +4,6 @@ from modules.catalog.application.query.model_mappers import map_listing_model_to_dao from modules.catalog.infrastructure.listing_repository import ListingModel from seedwork.application.queries import Query -from seedwork.application.query_handlers import QueryResult class GetAllListings(Query): @@ -12,7 +11,7 @@ class GetAllListings(Query): @catalog_module.handler(GetAllListings) -def get_all_listings( +async def get_all_listings( query: GetAllListings, session: Session, ) -> list[ListingModel]: diff --git a/src/modules/catalog/tests/application/test_create_listing_draft.py b/src/modules/catalog/tests/application/test_create_listing_draft.py index 97455ac..a451430 100644 --- a/src/modules/catalog/tests/application/test_create_listing_draft.py +++ b/src/modules/catalog/tests/application/test_create_listing_draft.py @@ -1,6 +1,3 @@ -from uuid import UUID -from seedwork.domain.value_objects import GenericUUID - import pytest from modules.catalog.application.command.create_listing_draft import ( @@ -8,12 +5,14 @@ create_listing_draft, ) from modules.catalog.domain.entities import Seller -from seedwork.domain.value_objects import Money +from seedwork.domain.value_objects import GenericUUID, Money from seedwork.infrastructure.repository import InMemoryRepository from seedwork.tests.application.test_utils import FakeEventPublisher + @pytest.mark.unit -def test_create_listing_draft(): +@pytest.mark.asyncio +async def test_create_listing_draft(): # arrange listing_id = GenericUUID(int=1) command = CreateListingDraftCommand( @@ -27,8 +26,8 @@ def test_create_listing_draft(): repository = InMemoryRepository() # act - create_listing_draft(command, repository, publish) + await create_listing_draft(command, repository, publish) # assert assert repository.get_by_id(listing_id).title == "foo" - assert publish.contains('ListingDraftCreatedEvent') + assert publish.contains("ListingDraftCreatedEvent") diff --git a/src/modules/catalog/tests/application/test_publish_listing.py b/src/modules/catalog/tests/application/test_publish_listing.py index d049c4d..c85314d 100644 --- a/src/modules/catalog/tests/application/test_publish_listing.py +++ b/src/modules/catalog/tests/application/test_publish_listing.py @@ -12,7 +12,8 @@ @pytest.mark.unit -def test_publish_listing(): +@pytest.mark.asyncio +async def test_publish_listing(): # arrange seller_repository = InMemoryRepository() seller = Seller(id=Seller.next_id()) @@ -34,7 +35,7 @@ def test_publish_listing(): ) # act - publish_listing_draft( + await publish_listing_draft( command, listing_repository=listing_repository, ) @@ -44,7 +45,8 @@ def test_publish_listing(): @pytest.mark.unit -def test_publish_listing_and_break_business_rule(): +@pytest.mark.asyncio +async def test_publish_listing_and_break_business_rule(): # arrange seller_repository = InMemoryRepository() seller = Seller(id=Seller.next_id()) @@ -69,7 +71,7 @@ def test_publish_listing_and_break_business_rule(): # assert with pytest.raises(BusinessRuleValidationException): - publish_listing_draft( + await publish_listing_draft( command, listing_repository=listing_repository, ) diff --git a/src/seedwork/domain/value_objects.py b/src/seedwork/domain/value_objects.py index 1cb5e3d..4383105 100644 --- a/src/seedwork/domain/value_objects.py +++ b/src/seedwork/domain/value_objects.py @@ -1,24 +1,20 @@ import uuid -from pydantic import BaseModel, ConfigDict from dataclasses import dataclass +from typing import Any + +from pydantic import GetCoreSchemaHandler class GenericUUID(uuid.UUID): @classmethod def next_id(cls): return cls(int=uuid.uuid4().int) - - @classmethod - def __get_validators__(cls): - yield cls.validate @classmethod - def validate(cls, value, validation_info): - if isinstance(value, str): - return cls(value) - if not isinstance(value, uuid.UUID): - raise ValueError('Invalid UUID') - return cls(value.hex) + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ): + return handler.generate_schema(uuid.UUID) class ValueObject: diff --git a/src/seedwork/tests/domain/test_value_objects.py b/src/seedwork/tests/domain/test_value_objects.py index 7ffb5ec..68ce045 100644 --- a/src/seedwork/tests/domain/test_value_objects.py +++ b/src/seedwork/tests/domain/test_value_objects.py @@ -1,6 +1,11 @@ import pytest +from pydantic import BaseModel -from seedwork.domain.value_objects import Money +from seedwork.domain.value_objects import GenericUUID, Money + + +class CustomPydanticModel(BaseModel): + uuid: GenericUUID @pytest.mark.unit