Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from fastapi import FastAPI, Request, Depends, status
from fastapi.responses import RedirectResponse
from fastapi.responses import RedirectResponse, Response
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from routers.core import account, dashboard, organization, role, user, static_pages, invitation
from utils.core.dependencies import (
get_optional_user
get_optional_user,
get_user_from_request
)
from exceptions.http_exceptions import (
AuthenticationError,
Expand Down Expand Up @@ -91,21 +92,29 @@ async def needs_new_tokens_handler(request: Request, exc: NeedsNewTokens):

# Handle PasswordValidationError by rendering the validation_error page
@app.exception_handler(PasswordValidationError)
async def password_validation_exception_handler(request: Request, exc: PasswordValidationError):
async def password_validation_exception_handler(
request: Request,
exc: PasswordValidationError
) -> Response:
user = await get_user_from_request(request)
return templates.TemplateResponse(
request,
"errors/validation_error.html",
{
"status_code": 422,
"errors": {"error": exc.detail}
"errors": {"error": exc.detail},
"user": user
},
status_code=422,
)


# Handle RequestValidationError by rendering the validation_error page
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
async def validation_exception_handler(
request: Request,
exc: RequestValidationError
):
errors = {}

# Map error types to user-friendly message templates
Expand All @@ -129,26 +138,28 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
# For JSON body, it might be (body, field_name)
# For array items, it might be (field_name, array_index)
field_name = location[-2] if isinstance(location[-1], int) else location[-1]

# Format the field name to be more user-friendly
display_name = field_name.replace("_", " ").title()

# Use mapped message if available, otherwise use FastAPI's message
error_type = error.get("type", "")
message_template = error_templates.get(error_type, error["msg"])

# For array items, append the index to the message
if isinstance(location[-1], int):
message_template = f"Item {location[-1] + 1}: {message_template}"

errors[display_name] = message_template

user = await get_user_from_request(request)
return templates.TemplateResponse(
request,
"errors/validation_error.html",
{
"status_code": 422,
"errors": errors
"errors": errors,
"user": user
},
status_code=422,
)
Expand All @@ -157,10 +168,11 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
# Handle StarletteHTTPException (including 404, 405, etc.) by rendering the error page
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
user = await get_user_from_request(request)
return templates.TemplateResponse(
request,
"errors/error.html",
{"status_code": exc.status_code, "detail": exc.detail},
{"status_code": exc.status_code, "detail": exc.detail, "user": user},
status_code=exc.status_code,
)

Expand All @@ -170,13 +182,15 @@ async def http_exception_handler(request: Request, exc: StarletteHTTPException):
async def general_exception_handler(request: Request, exc: Exception):
# Log the error for debugging
logger.error(f"Unhandled exception: {exc}", exc_info=True)
user = await get_user_from_request(request)

return templates.TemplateResponse(
request,
"errors/error.html",
{
"status_code": 500,
"detail": "Internal Server Error"
"detail": "Internal Server Error",
"user": user
},
status_code=500,
)
Expand Down
29 changes: 27 additions & 2 deletions utils/core/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import Depends, Form
from fastapi import Depends, Form, Request
from pydantic import EmailStr
from sqlmodel import Session, select
from sqlalchemy.orm import selectinload
Expand Down Expand Up @@ -321,4 +321,29 @@ def get_user_with_relations(
)
).one()

return eager_user
return eager_user


async def get_user_from_request(request: Request) -> Optional[User]:
"""
Helper function to get user from request cookies in exception handlers.
Exception handlers can't use Depends(), so we manually extract tokens and get the user.
"""
access_token = request.cookies.get("access_token")
refresh_token = request.cookies.get("refresh_token")
tokens = (access_token, refresh_token)

# Get a database session
engine = create_engine(get_connection_url())
with Session(engine) as session:
user, new_access_token, new_refresh_token = get_user_from_tokens(tokens, session)

# If we got new tokens, we'd normally raise NeedsNewTokens, but in an exception
# handler we can't do that easily. For now, just return the user.
# The tokens will be refreshed on the next request.
if user and new_access_token and new_refresh_token:
# Note: We can't easily set cookies here since we're in an exception handler.
# The user will need to make another request to get new tokens.
pass

return user