Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
sabuhibrahim committed Oct 22, 2023
1 parent 5326cb4 commit 7c0766e
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 40 deletions.
4 changes: 3 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ FROM python:3.11

WORKDIR /code

COPY . /code
COPY ./requirements.txt ./requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

COPY . .

RUN chmod 755 /code/start.sh

CMD ["sh", "start.sh"]
32 changes: 16 additions & 16 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,22 @@ services:
networks:
- default

backend:
container_name: backend
build:
context: .
dockerfile: ./Dockerfile
restart: always
env_file:
- .env
ports:
- "3000:80"
links:
- postgres
depends_on:
- postgres
networks:
- default
# backend:
# container_name: backend
# build:
# context: .
# dockerfile: ./Dockerfile
# restart: always
# env_file:
# - .env
# ports:
# - "3000:80"
# links:
# - postgres
# depends_on:
# - postgres
# networks:
# - default

networks:
default:
Expand Down
40 changes: 27 additions & 13 deletions src/core/jwt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import uuid
from datetime import timedelta, datetime
from datetime import timedelta, datetime, timezone
from jose import jwt, JWTError
from fastapi import Response
from sqlalchemy.ext.asyncio import AsyncSession

from . import config
Expand All @@ -9,13 +10,19 @@
from src.models import BlackListToken


REFRESH_COOKIE_NAME = "refresh"
SUB = "sub"
EXP = "exp"
IAT = "iat"
JTI = "jti"


def _create_access_token(payload: dict, minutes: int | None = None) -> JwtTokenSchema:
expire = datetime.utcnow() + timedelta(
minutes=minutes or config.ACCESS_TOKEN_EXPIRES_MINUTES
)

payload["exp"] = expire
payload["frs"] = False
payload[EXP] = expire

token = JwtTokenSchema(
token=jwt.encode(payload, config.SECRET_KEY, algorithm=config.ALGORITHM),
Expand All @@ -29,8 +36,7 @@ def _create_access_token(payload: dict, minutes: int | None = None) -> JwtTokenS
def _create_refresh_token(payload: dict) -> JwtTokenSchema:
expire = datetime.utcnow() + timedelta(minutes=config.REFRESH_TOKEN_EXPIRES_MINUTES)

payload["exp"] = expire
payload["frs"] = True
payload[EXP] = expire

token = JwtTokenSchema(
token=jwt.encode(payload, config.SECRET_KEY, algorithm=config.ALGORITHM),
Expand All @@ -42,7 +48,7 @@ def _create_refresh_token(payload: dict) -> JwtTokenSchema:


def create_token_pair(user: User) -> TokenPair:
payload = {"sub": str(user.id), "name": user.full_name, "jti": str(uuid.uuid4())}
payload = {SUB: str(user.id), JTI: str(uuid.uuid4()), IAT: datetime.utcnow()}

return TokenPair(
access=_create_access_token(payload={**payload}),
Expand All @@ -53,9 +59,7 @@ def create_token_pair(user: User) -> TokenPair:
async def decode_access_token(token: str, db: AsyncSession):
try:
payload = jwt.decode(token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
if payload.get("frs"):
raise JWTError("Access token need")
black_list_token = await BlackListToken.find_by_id(db=db, id=payload["jti"])
black_list_token = await BlackListToken.find_by_id(db=db, id=payload[JTI])
if black_list_token:
raise JWTError("Token is blacklisted")
except JWTError:
Expand All @@ -67,16 +71,26 @@ async def decode_access_token(token: str, db: AsyncSession):
def refresh_token_state(token: str):
try:
payload = jwt.decode(token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
if not payload.get("frs"):
raise JWTError("Refresh token need")
except JWTError as ex:
print(str(ex))
raise AuthFailedException()

return {"access": _create_access_token(payload=payload).token}
return {"token": _create_access_token(payload=payload).token}


def mail_token(user: User):
"""Return 2 hour lifetime access_token"""
payload = {"sub": str(user.id), "name": user.full_name, "jti": str(uuid.uuid4())}
payload = {SUB: str(user.id), JTI: str(uuid.uuid4()), IAT: datetime.utcnow()}
return _create_access_token(payload=payload, minutes=2 * 60).token


def add_refresh_token_cookie(response: Response, token: str):
exp = datetime.utcnow() + timedelta(minutes=config.REFRESH_TOKEN_EXPIRES_MINUTES)
exp.replace(tzinfo=timezone.utc)

response.set_cookie(
key="refresh",
value=token,
expires=int(exp.timestamp()),
httponly=True,
)
30 changes: 20 additions & 10 deletions src/routers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated
from datetime import datetime

from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks, Response, Cookie
from fastapi.exceptions import RequestValidationError
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
Expand All @@ -16,6 +16,10 @@
refresh_token_state,
decode_access_token,
mail_token,
add_refresh_token_cookie,
SUB,
JTI,
EXP,
)
from src.exceptions import BadRequestException, NotFoundException, ForbiddenException
from src.tasks import (
Expand Down Expand Up @@ -61,6 +65,7 @@ async def register(
@router.post("/login")
async def login(
data: schemas.UserLogin,
response: Response,
db: AsyncSession = Depends(get_db),
):
user = await models.User.authenticate(
Expand All @@ -77,18 +82,23 @@ async def login(

token_pair = create_token_pair(user=user)

return {"access": token_pair.access.token, "refresh": token_pair.refresh.token}
add_refresh_token_cookie(response=response, token=token_pair.refresh.token)

return {"token": token_pair.access.token}


@router.post("/refresh")
async def refresh(data: schemas.RefreshToken):
return refresh_token_state(data.refresh)
async def refresh(refresh: Annotated[str | None, Cookie()] = None):
print(refresh)
if not refresh:
raise BadRequestException(detail="refresh token required")
return refresh_token_state(token=refresh)


@router.get("/verify", response_model=schemas.SuccessResponseScheme)
async def verify(token: str, db: AsyncSession = Depends(get_db)):
payload = await decode_access_token(token=token, db=db)
user = await models.User.find_by_id(db=db, id=payload["sub"])
user = await models.User.find_by_id(db=db, id=payload[SUB])
if not user:
raise NotFoundException(detail="User not found")

Expand All @@ -104,7 +114,7 @@ async def logout(
):
payload = await decode_access_token(token=token, db=db)
black_listed = models.BlackListToken(
id=payload["jti"], expire=datetime.utcfromtimestamp(payload["exp"])
id=payload[JTI], expire=datetime.utcfromtimestamp(payload[EXP])
)
await black_listed.save(db=db)

Expand Down Expand Up @@ -138,7 +148,7 @@ async def password_reset_token(
db: AsyncSession = Depends(get_db),
):
payload = await decode_access_token(token=token, db=db)
user = await models.User.find_by_id(db=db, id=payload["sub"])
user = await models.User.find_by_id(db=db, id=payload[SUB])
if not user:
raise NotFoundException(detail="User not found")

Expand All @@ -155,7 +165,7 @@ async def password_update(
db: AsyncSession = Depends(get_db),
):
payload = await decode_access_token(token=token, db=db)
user = await models.User.find_by_id(db=db, id=payload["sub"])
user = await models.User.find_by_id(db=db, id=payload[SUB])
if not user:
raise NotFoundException(detail="User not found")

Expand All @@ -177,7 +187,7 @@ async def articles(
db: AsyncSession = Depends(get_db),
):
payload = await decode_access_token(token=token, db=db)
user = await models.User.find_by_id(db=db, id=payload["sub"])
user = await models.User.find_by_id(db=db, id=payload[SUB])
if not user:
raise NotFoundException(detail="User not found")

Expand All @@ -193,7 +203,7 @@ async def articles(
db: AsyncSession = Depends(get_db),
):
payload = await decode_access_token(token=token, db=db)
user = await models.User.find_by_id(db=db, id=payload["sub"])
user = await models.User.find_by_id(db=db, id=payload[SUB])
if not user:
raise NotFoundException(detail="User not found")

Expand Down

0 comments on commit 7c0766e

Please sign in to comment.