-
Notifications
You must be signed in to change notification settings - Fork 43
Add Litestar framework #483
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -85,3 +85,4 @@ uvicorn==0.18.2 | |
Werkzeug==2.0.3 | ||
wrapt==1.13.3 | ||
zipp==3.7.0 | ||
litestar==2.8.1 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from __future__ import annotations | ||
from typing import Any, Callable, Coroutine, TYPE_CHECKING | ||
|
||
from litestar import Request | ||
|
||
from supertokens_python.framework.litestar.litestar_request import LitestarRequest | ||
from .litestar_middleware import get_middleware | ||
|
||
if TYPE_CHECKING: | ||
from ...recipe.session import SessionRecipe, SessionContainer | ||
from ...recipe.session.interfaces import SessionClaimValidator | ||
from ...types import MaybeAwaitable | ||
|
||
__all__ = ['get_middleware'] | ||
|
||
|
||
def verify_session( | ||
anti_csrf_check: bool | None = None, | ||
session_required: bool = True, | ||
override_global_claim_validators: Callable[ | ||
[list[SessionClaimValidator], SessionContainer, dict[str, Any]], | ||
MaybeAwaitable[list[SessionClaimValidator]], | ||
] | ||
| None = None, | ||
user_context: None | dict[str, Any] = None, | ||
) -> Callable[..., Coroutine[Any, Any, SessionContainer | None]]: | ||
async def func(request: Request[Any, Any, Any]) -> SessionContainer | None: | ||
custom_request = LitestarRequest(request) | ||
recipe = SessionRecipe.get_instance() | ||
session = await recipe.verify_session( | ||
custom_request, | ||
anti_csrf_check, | ||
session_required, | ||
user_context=user_context or {} | ||
) | ||
|
||
if session: | ||
custom_request.set_session(session) | ||
elif session_required: | ||
raise RuntimeError("Should never come here") | ||
else: | ||
custom_request.set_session_as_none() | ||
|
||
return custom_request.get_session() | ||
|
||
return func |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Any | ||
|
||
from supertokens_python.framework.types import Framework | ||
|
||
if TYPE_CHECKING: | ||
from litestar import Request | ||
|
||
|
||
class LitestarFramework(Framework): | ||
def wrap_request(self, unwrapped: Request[Any, Any, Any]): | ||
from supertokens_python.framework.litestar.litestar_request import ( | ||
LitestarRequest, | ||
) | ||
|
||
return LitestarRequest(unwrapped) |
57 changes: 57 additions & 0 deletions
57
supertokens_python/framework/litestar/litestar_middleware.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from __future__ import annotations | ||
|
||
from functools import lru_cache | ||
from typing import TYPE_CHECKING, Any | ||
|
||
from litestar.response.base import ASGIResponse | ||
|
||
if TYPE_CHECKING: | ||
from litestar.middleware.base import AbstractMiddleware | ||
|
||
|
||
@lru_cache | ||
def get_middleware() -> type[AbstractMiddleware]: | ||
from supertokens_python import Supertokens | ||
from supertokens_python.exceptions import SuperTokensError | ||
from supertokens_python.framework.litestar.litestar_request import LitestarRequest | ||
from supertokens_python.framework.litestar.litestar_response import LitestarResponse | ||
from supertokens_python.recipe.session import SessionContainer | ||
from supertokens_python.supertokens import manage_session_post_response | ||
from supertokens_python.utils import default_user_context | ||
|
||
from litestar import Response, Request | ||
from litestar.middleware.base import AbstractMiddleware | ||
from litestar.types import Scope, Receive, Send | ||
|
||
class SupertokensMiddleware(AbstractMiddleware): | ||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
st = Supertokens.get_instance() | ||
request = Request[Any, Any, Any](scope, receive, send) | ||
user_context = default_user_context(request) | ||
|
||
try: | ||
result = await st.middleware( | ||
LitestarRequest(request), | ||
LitestarResponse(Response[Any](content=None)), | ||
user_context | ||
) | ||
except SuperTokensError as e: | ||
result = await st.handle_supertokens_error( | ||
LitestarRequest(request), | ||
e, | ||
LitestarResponse(Response[Any](content=None)), | ||
user_context | ||
) | ||
|
||
if isinstance(result, LitestarResponse): | ||
if ( | ||
session_container := request.state.get("supertokens") | ||
) and isinstance(session_container, SessionContainer): | ||
manage_session_post_response(session_container, result, user_context) | ||
|
||
await result.response.to_asgi_response(app=None, request=request)(scope, receive, send) | ||
return | ||
|
||
await self.app(scope, receive, send) | ||
|
||
return SupertokensMiddleware |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Any | ||
|
||
from supertokens_python.framework.request import BaseRequest | ||
|
||
if TYPE_CHECKING: | ||
from litestar import Request | ||
from supertokens_python.recipe.session.interfaces import SessionContainer | ||
|
||
try: | ||
from litestar.exceptions import SerializationException | ||
except ImportError: | ||
SerializationException = Exception # type: ignore | ||
|
||
|
||
class LitestarRequest(BaseRequest): | ||
def __init__(self, request: Request[Any, Any, Any]): | ||
super().__init__() | ||
self.request = request | ||
|
||
def get_original_url(self) -> str: | ||
return self.request.url | ||
|
||
def get_query_param(self, key: str, default: str | None = None) -> Any: | ||
return self.request.query_params.get(key, default) # pyright: ignore | ||
|
||
def get_query_params(self) -> dict[str, list[Any]]: | ||
return self.request.query_params.dict() # pyright: ignore | ||
|
||
async def json(self) -> Any: | ||
try: | ||
return await self.request.json() | ||
except SerializationException: | ||
return {} | ||
|
||
def method(self) -> str: | ||
return self.request.method | ||
|
||
def get_cookie(self, key: str) -> str | None: | ||
return self.request.cookies.get(key) | ||
|
||
def get_header(self, key: str) -> str | None: | ||
return self.request.headers.get(key, None) | ||
|
||
def get_session(self) -> SessionContainer | None: | ||
return self.request.state.supertokens | ||
|
||
def set_session(self, session: SessionContainer): | ||
self.request.state.supertokens = session | ||
|
||
def set_session_as_none(self): | ||
self.request.state.supertokens = None | ||
|
||
def get_path(self) -> str: | ||
return self.request.url.path | ||
|
||
async def form_data(self) -> dict[str, list[Any]]: | ||
return (await self.request.form()).dict() |
72 changes: 72 additions & 0 deletions
72
supertokens_python/framework/litestar/litestar_response.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from __future__ import annotations | ||
from typing import Any, TYPE_CHECKING, cast | ||
from typing_extensions import Literal | ||
from supertokens_python.framework.response import BaseResponse | ||
|
||
if TYPE_CHECKING: | ||
from litestar import Response | ||
|
||
|
||
class LitestarResponse(BaseResponse): | ||
def __init__(self, response: Response[Any]): | ||
super().__init__({}) | ||
self.response = response | ||
self.original = response | ||
self.parser_checked = False | ||
self.response_sent = False | ||
self.status_set = False | ||
|
||
def set_html_content(self, content: str): | ||
if not self.response_sent: | ||
body = bytes(content, "utf-8") | ||
self.set_header("Content-Length", str(len(body))) | ||
self.set_header("Content-Type", "text/html") | ||
self.response.content = body | ||
self.response_sent = True | ||
|
||
def set_cookie( | ||
self, | ||
key: str, | ||
value: str, | ||
expires: int, | ||
path: str = "/", | ||
domain: str | None = None, | ||
secure: bool = False, | ||
httponly: bool = False, | ||
samesite: str = "lax", | ||
): | ||
self.response.set_cookie( | ||
key=key, | ||
value=value, | ||
expires=expires, | ||
path=path, | ||
domain=domain, | ||
secure=secure, | ||
httponly=httponly, | ||
samesite=cast(Literal["lax", "strict", "none"], samesite), | ||
) | ||
|
||
def set_header(self, key: str, value: str): | ||
self.response.set_header(key, value) | ||
|
||
def get_header(self, key: str) -> str | None: | ||
return self.response.headers.get(key, None) | ||
|
||
def remove_header(self, key: str): | ||
del self.response.headers[key] | ||
|
||
def set_status_code(self, status_code: int): | ||
if not self.status_set: | ||
self.response.status_code = status_code | ||
self.status_code = status_code | ||
self.status_set = True | ||
|
||
def set_json_content(self, content: dict[str, Any]): | ||
if not self.response_sent: | ||
from litestar.serialization import encode_json | ||
|
||
body = encode_json(content) | ||
self.set_header("Content-Type", "application/json; charset=utf-8") | ||
self.set_header("Content-Length", str(len(body))) | ||
self.response.content = body | ||
self.response_sent = True |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
import nest_asyncio # type: ignore | ||
|
||
nest_asyncio.apply() # type: ignore |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
SessionRecipe
class is imported inside theTYPE_CHECKING
block, but it's actually used at runtime in theverify_session
function. This will cause a runtime error when the function is called. Please addfrom ...recipe.session import SessionRecipe
outside theTYPE_CHECKING
block to make it available during execution.Spotted by Diamond
Is this helpful? React 👍 or 👎 to let us know.