Skip to content

Commit db8d4a8

Browse files
committed
fix: ensure logout button renders in exception handlers by passing user context
Add get_user_from_request helper function to extract user from cookies in exception handlers where Depends() cannot be used. Update all exception handlers (password validation, request validation, HTTP exceptions, and general exceptions) to retrieve and pass user to templates, enabling logout button visibility on error pages. Fixes #141
1 parent 13a6f30 commit db8d4a8

File tree

2 files changed

+53
-14
lines changed

2 files changed

+53
-14
lines changed

main.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
from contextlib import asynccontextmanager
44
from dotenv import load_dotenv
55
from fastapi import FastAPI, Request, Depends, status
6-
from fastapi.responses import RedirectResponse
6+
from fastapi.responses import RedirectResponse, Response
77
from fastapi.staticfiles import StaticFiles
88
from fastapi.templating import Jinja2Templates
99
from fastapi.exceptions import RequestValidationError
1010
from starlette.exceptions import HTTPException as StarletteHTTPException
1111
from routers.core import account, dashboard, organization, role, user, static_pages, invitation
1212
from utils.core.dependencies import (
13-
get_optional_user
13+
get_optional_user,
14+
get_user_from_request
1415
)
1516
from exceptions.http_exceptions import (
1617
AuthenticationError,
@@ -91,21 +92,29 @@ async def needs_new_tokens_handler(request: Request, exc: NeedsNewTokens):
9192

9293
# Handle PasswordValidationError by rendering the validation_error page
9394
@app.exception_handler(PasswordValidationError)
94-
async def password_validation_exception_handler(request: Request, exc: PasswordValidationError):
95+
async def password_validation_exception_handler(
96+
request: Request,
97+
exc: PasswordValidationError
98+
) -> Response:
99+
user = await get_user_from_request(request)
95100
return templates.TemplateResponse(
96101
request,
97102
"errors/validation_error.html",
98103
{
99104
"status_code": 422,
100-
"errors": {"error": exc.detail}
105+
"errors": {"error": exc.detail},
106+
"user": user
101107
},
102108
status_code=422,
103109
)
104110

105111

106112
# Handle RequestValidationError by rendering the validation_error page
107113
@app.exception_handler(RequestValidationError)
108-
async def validation_exception_handler(request: Request, exc: RequestValidationError):
114+
async def validation_exception_handler(
115+
request: Request,
116+
exc: RequestValidationError
117+
):
109118
errors = {}
110119

111120
# Map error types to user-friendly message templates
@@ -129,26 +138,28 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
129138
# For JSON body, it might be (body, field_name)
130139
# For array items, it might be (field_name, array_index)
131140
field_name = location[-2] if isinstance(location[-1], int) else location[-1]
132-
141+
133142
# Format the field name to be more user-friendly
134143
display_name = field_name.replace("_", " ").title()
135-
144+
136145
# Use mapped message if available, otherwise use FastAPI's message
137146
error_type = error.get("type", "")
138147
message_template = error_templates.get(error_type, error["msg"])
139-
148+
140149
# For array items, append the index to the message
141150
if isinstance(location[-1], int):
142151
message_template = f"Item {location[-1] + 1}: {message_template}"
143-
152+
144153
errors[display_name] = message_template
145154

155+
user = await get_user_from_request(request)
146156
return templates.TemplateResponse(
147157
request,
148158
"errors/validation_error.html",
149159
{
150160
"status_code": 422,
151-
"errors": errors
161+
"errors": errors,
162+
"user": user
152163
},
153164
status_code=422,
154165
)
@@ -157,10 +168,11 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
157168
# Handle StarletteHTTPException (including 404, 405, etc.) by rendering the error page
158169
@app.exception_handler(StarletteHTTPException)
159170
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
171+
user = await get_user_from_request(request)
160172
return templates.TemplateResponse(
161173
request,
162174
"errors/error.html",
163-
{"status_code": exc.status_code, "detail": exc.detail},
175+
{"status_code": exc.status_code, "detail": exc.detail, "user": user},
164176
status_code=exc.status_code,
165177
)
166178

@@ -170,13 +182,15 @@ async def http_exception_handler(request: Request, exc: StarletteHTTPException):
170182
async def general_exception_handler(request: Request, exc: Exception):
171183
# Log the error for debugging
172184
logger.error(f"Unhandled exception: {exc}", exc_info=True)
185+
user = await get_user_from_request(request)
173186

174187
return templates.TemplateResponse(
175188
request,
176189
"errors/error.html",
177190
{
178191
"status_code": 500,
179-
"detail": "Internal Server Error"
192+
"detail": "Internal Server Error",
193+
"user": user
180194
},
181195
status_code=500,
182196
)

utils/core/dependencies.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from fastapi import Depends, Form
1+
from fastapi import Depends, Form, Request
22
from pydantic import EmailStr
33
from sqlmodel import Session, select
44
from sqlalchemy.orm import selectinload
@@ -321,4 +321,29 @@ def get_user_with_relations(
321321
)
322322
).one()
323323

324-
return eager_user
324+
return eager_user
325+
326+
327+
async def get_user_from_request(request: Request) -> Optional[User]:
328+
"""
329+
Helper function to get user from request cookies in exception handlers.
330+
Exception handlers can't use Depends(), so we manually extract tokens and get the user.
331+
"""
332+
access_token = request.cookies.get("access_token")
333+
refresh_token = request.cookies.get("refresh_token")
334+
tokens = (access_token, refresh_token)
335+
336+
# Get a database session
337+
engine = create_engine(get_connection_url())
338+
with Session(engine) as session:
339+
user, new_access_token, new_refresh_token = get_user_from_tokens(tokens, session)
340+
341+
# If we got new tokens, we'd normally raise NeedsNewTokens, but in an exception
342+
# handler we can't do that easily. For now, just return the user.
343+
# The tokens will be refreshed on the next request.
344+
if user and new_access_token and new_refresh_token:
345+
# Note: We can't easily set cookies here since we're in an exception handler.
346+
# The user will need to make another request to get new tokens.
347+
pass
348+
349+
return user

0 commit comments

Comments
 (0)