Skip to content

Commit

Permalink
Add unit tests for TrackDAO
Browse files Browse the repository at this point in the history
  • Loading branch information
m-danya committed Apr 16, 2024
1 parent 31a46ca commit e6ae572
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 25 deletions.
16 changes: 10 additions & 6 deletions accompanist/collection/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from sqlalchemy.orm import selectinload

from accompanist.collection.models import Album, Artist, Track
from accompanist.collection.schema import TrackUpdateRequest
from accompanist.dao import BaseDAO
from accompanist.database import async_session_maker
from accompanist.exceptions import NoAlbumException, NoTrackException
from accompanist.exceptions import AlbumNotFoundException, TrackNotFoundException


class ArtistDAO(BaseDAO):
Expand All @@ -15,7 +16,7 @@ class ArtistDAO(BaseDAO):

class AlbumDAO(BaseDAO):
model = Album
not_found_exception = NoAlbumException
not_found_exception = AlbumNotFoundException

@classmethod
async def get_all(cls) -> list[dict]:
Expand All @@ -40,13 +41,13 @@ async def get_by_id_with_tracks_info(cls, id_: int) -> list[dict]:
result = await session.execute(query)
album = result.scalars().one_or_none()
if not album:
raise NoAlbumException(id=id_)
raise AlbumNotFoundException(id=id_)
return album.to_json_with_tracks()


class TrackDAO(BaseDAO):
model = Track
not_found_exception = NoTrackException
not_found_exception = TrackNotFoundException

@classmethod
async def get_with_artist(cls, id_: int) -> Optional[Track]:
Expand All @@ -56,11 +57,14 @@ async def get_with_artist(cls, id_: int) -> Optional[Track]:
return result.scalars().one_or_none()

@classmethod
async def update(cls, track_id: int, update_data: dict):
async def update(cls, track_id: int, update_request: TrackUpdateRequest):
update_data = update_request.model_dump(exclude_unset=True)
async with async_session_maker() as session:
query = select(Track).filter_by(id=track_id).with_for_update()
result = await session.execute(query)
track = result.scalars().one()
track = result.scalars().one_or_none()
if not track:
raise cls.not_found_exception(id=track_id)

for field, value in update_data.items():
setattr(track, field, value)
Expand Down
7 changes: 3 additions & 4 deletions accompanist/collection/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
AlbumInfoFromUser,
TrackUpdateRequest,
)
from accompanist.exceptions import NoTrackException
from accompanist.exceptions import TrackNotFoundException

router = APIRouter(
tags=["User's music collection"],
Expand Down Expand Up @@ -40,14 +40,13 @@ async def get_all_alumbs():
async def get_track(track_id: int):
track = await TrackDAO.find_one_or_none(id=track_id)
if not track:
raise NoTrackException(id=track_id)
raise TrackNotFoundException(id=track_id)
return track


@router.patch("/track/{track_id}")
async def update_track(track_id: int, update_request: TrackUpdateRequest):
update_data = update_request.model_dump(exclude_unset=True)
updated_track = await TrackDAO.update(track_id, update_data)
updated_track = await TrackDAO.update(track_id, update_request)
return updated_track


Expand Down
4 changes: 1 addition & 3 deletions accompanist/collection/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ async def update_track_lyrics_by_id(track_id: int):
lyrics, genius_url = await loop.run_in_executor(
None, get_lyrics_from_genius, track.artist.name, track.name
)
update_request = TrackUpdateRequest(
lyrics=lyrics, genius_url=genius_url
).model_dump(exclude_unset=True)
update_request = TrackUpdateRequest(lyrics=lyrics, genius_url=genius_url)
track = await TrackDAO.update(track.id, update_request)
return track
4 changes: 2 additions & 2 deletions accompanist/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from sqlalchemy.exc import SQLAlchemyError

from accompanist.database import async_session_maker
from accompanist.exceptions import NoEntityException
from accompanist.exceptions import EntityNotFoundException


class BaseDAO:
model = None
not_found_exception = NoEntityException
not_found_exception = EntityNotFoundException

@classmethod
async def find_one_or_none(cls, **filter_by):
Expand Down
6 changes: 3 additions & 3 deletions accompanist/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, **kwargs):
)


class NoEntityException(AccompanistException):
class EntityNotFoundException(AccompanistException):
status_code = status.HTTP_404_NOT_FOUND
entity_name = "entity"

Expand All @@ -26,9 +26,9 @@ def _detail(self):
return f"There is no {self.entity_name} with {self.filter_criteria}"


class NoAlbumException(NoEntityException):
class AlbumNotFoundException(EntityNotFoundException):
entity_name = "album"


class NoTrackException(NoEntityException):
class TrackNotFoundException(EntityNotFoundException):
entity_name = "track"
7 changes: 7 additions & 0 deletions accompanist/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def test_data():
for column in row:
if column.endswith("_at"):
row[column] = datetime.strptime(row[column], TEST_DATETIME_FORMAT)
test_data["albums"].sort(key=lambda x: x["id"])
test_data["tracks"].sort(key=lambda x: x["id"])
# TODO: make this dict immutable to prevent accidental modification
return test_data

Expand All @@ -32,6 +34,11 @@ def test_album_data(test_data):
return test_data["albums"]


@pytest.fixture(scope="session")
def test_track_data(test_data):
return test_data["tracks"]


@pytest.fixture(scope="function", autouse=True)
async def prepare_database(test_data):
assert settings.MODE == "TEST"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
import pytest

from accompanist.collection.dao import AlbumDAO
from accompanist.exceptions import NoAlbumException
from accompanist.exceptions import AlbumNotFoundException


async def test_get_all_albums():
async def test_get_all_albums(test_album_data):
albums = await AlbumDAO.get_all()
assert len(albums) == 1
assert len(albums) == len(test_album_data)

for album in albums:
test_album = next(
test_album
for test_album in test_album_data
if test_album["id"] == album["id"]
)

assert album["id"] == test_album["id"]
assert album["name"] == test_album["name"]
assert album["cover_path"] == test_album["cover_path"]
assert album["artist"]["id"] == test_album["artist_id"]


@pytest.mark.parametrize(
Expand All @@ -19,7 +31,7 @@ async def test_get_all_albums():
async def test_get_by_id_with_tracks_info(id_, must_exist, test_album_data):
try:
album = await AlbumDAO.get_by_id_with_tracks_info(id_)
except NoAlbumException:
except AlbumNotFoundException:
if not must_exist:
return
if not must_exist:
Expand All @@ -29,6 +41,3 @@ async def test_get_by_id_with_tracks_info(id_, must_exist, test_album_data):
assert album[attr] == album_data[attr]
assert album["artist"]["id"] == album_data["artist_id"]
assert len(album["tracks"]) == 2


# TODO: add more tests for different DAOs
53 changes: 53 additions & 0 deletions accompanist/tests/unit_tests/test_track_dao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from copy import deepcopy

import pytest

from accompanist.collection.dao import TrackDAO
from accompanist.collection.schema import TrackUpdateRequest
from accompanist.exceptions import TrackNotFoundException


@pytest.mark.parametrize(
"id_,must_exist",
[
(698, True),
(42, False),
],
)
async def test_get_with_artist(id_, must_exist, test_track_data):
track = await TrackDAO.get_with_artist(id_)
if not must_exist:
assert not track
return
test_track = next(track for track in test_track_data if track["id"] == id_)
assert track.id == test_track["id"]


@pytest.mark.parametrize(
"id_,must_exist,update_request",
[
(698, True, TrackUpdateRequest(is_favorite=False)),
(698, True, TrackUpdateRequest(is_favorite=True)),
(698, True, TrackUpdateRequest(lyrics="Na na na", is_favorite=True)),
(698, True, TrackUpdateRequest(lyrics="Na na na", is_favorite=False)),
(698, True, TrackUpdateRequest(lyrics="Na na na", genius_url="genius.com/123")),
(42, False, TrackUpdateRequest(lyrics="Na na na")),
],
)
async def test_update_track(id_, must_exist, update_request, test_track_data):
if must_exist:
track_original_json = (await TrackDAO.get_with_artist(id_)).to_json()
try:
returned_track_json = await TrackDAO.update(id_, update_request)
updated_track_json = (await TrackDAO.get_with_artist(id_)).to_json()
except TrackNotFoundException:
if not must_exist:
return
if not must_exist:
raise AssertionError("This track must not exist")
assert updated_track_json == returned_track_json

updated_track_ground_truth = deepcopy(track_original_json)
updated_track_ground_truth.update(update_request.model_dump(exclude_unset=True))

assert updated_track_json == updated_track_ground_truth

0 comments on commit e6ae572

Please sign in to comment.