Skip to content

Commit 64a6687

Browse files
committed
pass db session for testcase
with depend db session, it will be override from conftest
1 parent 1cd7819 commit 64a6687

File tree

5 files changed

+51
-43
lines changed

5 files changed

+51
-43
lines changed

app/api/v1/user.py

+21-14
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from __future__ import annotations
22

3-
from fastapi import APIRouter, Request, status
3+
from fastapi import APIRouter, Request, status, Depends
44
from fastapi.encoders import jsonable_encoder
55
from fastapi.responses import JSONResponse
66
from fastapi_pundra.rest.helpers import the_query
77
from fastapi_pundra.rest.validation import dto
8-
8+
from sqlalchemy.orm import Session
9+
from app.database.database import get_db_session
910
from app.schemas.user_schema import UserCreateSchema, UserUpdateSchema
1011
from app.services.user_service import UserService
1112

@@ -19,47 +20,53 @@
1920
# Registration route
2021
@router.post("/users/registration")
2122
@dto(UserCreateSchema)
22-
async def registration(request: Request) -> JSONResponse:
23+
async def registration(request: Request, db: Session = Depends(get_db_session)) -> JSONResponse:
2324
"""Register a new user."""
2425
# Retrieve data from the request
2526
request_data = await the_query(request)
2627
data = UserCreateSchema(**request_data)
2728

28-
output = await user_service.s_registration(request, data)
29+
output = await user_service.s_registration(request, db, data)
2930
return JSONResponse(content=output, status_code=status.HTTP_201_CREATED)
3031

3132

3233
@router.post("/users/login")
33-
async def login(request: Request) -> JSONResponse:
34+
async def login(request: Request, db: Session = Depends(get_db_session)) -> JSONResponse:
3435
"""Login a user."""
35-
data = await user_service.s_login(request)
36+
data = await user_service.s_login(request, db)
3637
return JSONResponse(content=jsonable_encoder(data), status_code=status.HTTP_200_OK)
3738

3839

3940
@router.get("/users")
40-
async def get_users(request: Request) -> JSONResponse:
41+
async def get_users(request: Request, db: Session = Depends(get_db_session)) -> JSONResponse:
4142
"""Get all users."""
42-
data = await user_service.s_get_users(request)
43+
data = await user_service.s_get_users(request, db)
4344
return JSONResponse(content=jsonable_encoder(data), status_code=status.HTTP_200_OK)
4445

4546

4647
@router.get("/users/{user_id}")
47-
async def get_user(request: Request, user_id: int | str) -> JSONResponse:
48+
async def get_user(
49+
request: Request, user_id: int | str, db: Session = Depends(get_db_session)
50+
) -> JSONResponse:
4851
"""Get a user by id."""
49-
data = await user_service.s_get_user_by_id(request, user_id=user_id)
52+
data = await user_service.s_get_user_by_id(request, db, user_id=user_id)
5053
return JSONResponse(content=jsonable_encoder(data), status_code=status.HTTP_200_OK)
5154

5255

5356
@router.put("/users/{user_id}/update")
5457
@dto(UserUpdateSchema)
55-
async def update_user(request: Request, user_id: int | str) -> JSONResponse:
58+
async def update_user(
59+
request: Request, user_id: int | str, db: Session = Depends(get_db_session)
60+
) -> JSONResponse:
5661
"""Update a user by id."""
57-
data = await user_service.s_update_user(request, user_id=user_id)
62+
data = await user_service.s_update_user(request, db, user_id=user_id)
5863
return JSONResponse(content=jsonable_encoder(data), status_code=status.HTTP_200_OK)
5964

6065

6166
@router.delete("/users/{user_id}/delete")
62-
async def delete_user(request: Request, user_id: int | str) -> JSONResponse:
67+
async def delete_user(
68+
request: Request, user_id: int | str, db: Session = Depends(get_db_session)
69+
) -> JSONResponse:
6370
"""Delete a user by id."""
64-
data = await user_service.s_delete_user(request, user_id=user_id)
71+
data = await user_service.s_delete_user(request, db, user_id=user_id)
6572
return JSONResponse(content=jsonable_encoder(data), status_code=status.HTTP_200_OK)

app/database/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Database module."""
22

3-
from app.database.database import Base, engine, get_db
3+
from app.database.database import Base, engine, get_db_session

app/database/database.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
2+
from collections.abc import Generator
33
from dotenv import load_dotenv
44
from sqlalchemy import create_engine
55
from sqlalchemy.orm import Session, declarative_base, sessionmaker
@@ -29,6 +29,10 @@
2929

3030

3131
# Dependency function to get a database session
32-
def get_db() -> Session:
32+
def get_db_session() -> Generator[Session, None, None]:
3333
"""Get a database session."""
34-
return SessionLocal()
34+
db = SessionLocal()
35+
try:
36+
yield db
37+
finally:
38+
db.close()

app/services/user_service.py

+19-24
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from fastapi_pundra.rest.paginate import paginate
1111
from sqlalchemy.orm import Session
1212

13-
from app.database import get_db
1413
from app.models.users import User
1514
from app.schemas.user_schema import UserCreateSchema
1615
from app.serializers.user_serializer import UserLoginSerializer, UserSerializer
@@ -19,13 +18,9 @@
1918
class UserService:
2019
"""User service."""
2120

22-
def __init__(self, db: Session | None = None) -> None:
23-
"""Initialize the user service."""
24-
self.db = db or get_db()
25-
26-
async def s_registration(self, request: Request, data: UserCreateSchema) -> dict:
21+
async def s_registration(self, request: Request, db: Session, data: UserCreateSchema) -> dict:
2722
"""Register a new user."""
28-
db_user = self.db.query(User).filter(User.email == data.email).first()
23+
db_user = db.query(User).filter(User.email == data.email).first()
2924

3025
if db_user:
3126
raise BaseAPIException(message="Email already registered", status_code=400)
@@ -36,9 +31,9 @@ async def s_registration(self, request: Request, data: UserCreateSchema) -> dict
3631
new_user.name = data.name
3732
new_user.status = "active"
3833

39-
self.db.add(new_user)
40-
self.db.commit()
41-
self.db.refresh(new_user)
34+
db.add(new_user)
35+
db.commit()
36+
db.refresh(new_user)
4237

4338
user_data = UserSerializer(**new_user.as_dict())
4439

@@ -59,9 +54,9 @@ async def s_registration(self, request: Request, data: UserCreateSchema) -> dict
5954

6055
return {"message": "Registration successful", "user": user_data.model_dump()}
6156

62-
async def s_get_users(self, request: Request) -> dict:
57+
async def s_get_users(self, request: Request, db: Session) -> dict:
6358
"""Get users."""
64-
query = self.db.query(User)
59+
query = db.query(User)
6560

6661
# TODO: add logic here if you want to filter users
6762

@@ -80,22 +75,22 @@ def additional_data(data: list) -> dict:
8075
additional_data=additional_data,
8176
)
8277

83-
async def s_get_user_by_id(self, request: Request, user_id: str) -> User:
78+
async def s_get_user_by_id(self, request: Request, db: Session, user_id: str) -> User:
8479
"""Get user by id."""
85-
user = self.db.query(User).filter(User.id == user_id).first()
80+
user = db.query(User).filter(User.id == user_id).first()
8681
if user is None:
8782
raise ItemNotFoundException(message="User not found")
8883
return user
8984

90-
async def s_login(self, request: Request) -> dict:
85+
async def s_login(self, request: Request, db: Session) -> dict:
9186
"""Login a user."""
9287
# Get data from request
9388
the_data = await the_query(request)
9489
email = the_data.get("email")
9590
password = the_data.get("password")
9691

9792
# Find user by email
98-
user = self.db.query(User).filter(User.email == email).first()
93+
user = db.query(User).filter(User.email == email).first()
9994
if not user:
10095
raise UnauthorizedException(message="Invalid credentials")
10196

@@ -125,10 +120,10 @@ async def s_login(self, request: Request) -> dict:
125120
"refresh_token": refresh_token,
126121
}
127122

128-
async def s_update_user(self, request: Request, user_id: str) -> dict:
123+
async def s_update_user(self, request: Request, db: Session, user_id: str) -> dict:
129124
"""Update a user."""
130125
the_data = await the_query(request)
131-
user = self.db.query(User).filter(User.id == user_id).first()
126+
user = db.query(User).filter(User.id == user_id).first()
132127

133128
if not user:
134129
raise ItemNotFoundException(message="User not found")
@@ -139,17 +134,17 @@ async def s_update_user(self, request: Request, user_id: str) -> dict:
139134
user.email = the_data.get("email")
140135
if the_data.get("password"):
141136
user.password = generate_password_hash(the_data.get("password"))
142-
self.db.commit()
143-
self.db.refresh(user)
137+
db.commit()
138+
db.refresh(user)
144139

145140
user_data = UserSerializer(**user.as_dict())
146141
return {"message": "User updated successfully", "user": user_data.model_dump()}
147142

148-
async def s_delete_user(self, request: Request, user_id: str) -> dict:
143+
async def s_delete_user(self, request: Request, db: Session, user_id: str) -> dict:
149144
"""Delete a user."""
150-
user = self.db.query(User).filter(User.id == user_id).first()
145+
user = db.query(User).filter(User.id == user_id).first()
151146
if not user:
152147
raise ItemNotFoundException(message="User not found")
153-
self.db.delete(user)
154-
self.db.commit()
148+
db.delete(user)
149+
db.commit()
155150
return {"message": "User deleted successfully"}

ruff.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ ignore = [
5252
"FIX002", # Line contains TODO, consider resolving the issue
5353
"COM812", # Missing trailing comma in Python 3.6+
5454
"N805", # Instance method first argument name should be 'self'
55-
"ERA001",
55+
"ERA001", # Found commented-out code
56+
"FAST002", # FastAPI dependency injection error annotation
57+
"B008", # Do not perform function call in argument defaults
5658
]
5759

5860
# Allow unused variables when underscore-prefixed

0 commit comments

Comments
 (0)