Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/django-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install mypy flake8 pytest pytest-xdist flaky
python -m pip install mypy djangorestframework-stubs flake8 pytest pytest-xdist flaky
if [ -f testproject/requirements.txt ]; then pip install -r testproject/requirements.txt; fi
ln -s $(pwd)/trench/ $(pwd)/testproject/trench
- name: Lint trench package with flake8
Expand Down
18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
repos:
- repo: https://github.com/ambv/black
rev: 23.1.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: end-of-file-fixer
- id: check-merge-conflict
- id: mixed-line-ending
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
- id: flake8
language_version: python3
args: ['--select=E9,F63,F7,F82']
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ Local development

.. code-block:: shell

pip install black mypy
pip install black mypy djangorestframework-stubs
pip install -r testproject/requirements.txt

5. Set environment variables:
Expand Down
4 changes: 3 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
from typing import Dict

import sphinx_rtd_theme


Expand Down Expand Up @@ -114,7 +116,7 @@

# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
latex_elements: Dict[str, str] = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
Expand Down
6 changes: 6 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
[mypy]
ignore_missing_imports = True
plugins =
mypy_django_plugin.main,
mypy_drf_plugin.main

[mypy.plugins.django-stubs]
django_settings_module = "testproject.settings"

[flake8]
inline-quotes = "
Expand Down
2 changes: 1 addition & 1 deletion testproject/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
ALLOWED_HOSTS = env.list("ALLOWED_HOSTS", default=["*"])
CORS_ORIGIN_ALLOW_ALL = env.bool("CORS_ORIGIN_ALLOW_ALL", default=False)
DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
STATIC_ROOT= os.path.join(BASE_DIR, 'static/')
STATIC_ROOT = os.path.join(BASE_DIR, "static/")

INSTALLED_APPS = [
"django.contrib.admin",
Expand Down
12 changes: 5 additions & 7 deletions testproject/tests/test_add_mfa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest

from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AbstractUser

from flaky import flaky
Expand All @@ -12,9 +11,6 @@
from trench.command.create_secret import create_secret_command


User: AbstractUser = get_user_model()


@pytest.mark.django_db
def test_add_user_mfa(active_user):
client = TrenchAPIClient()
Expand All @@ -33,11 +29,13 @@ def test_add_user_mfa(active_user):


@pytest.mark.django_db
def test_should_fail_on_add_user_mfa_with_invalid_source_field(active_user: User):
def test_should_fail_on_add_user_mfa_with_invalid_source_field(
active_user: AbstractUser,
):
client = TrenchAPIClient()
client.authenticate(user=active_user)
secret = create_secret_command()
settings.TRENCH_AUTH["MFA_METHODS"]["email"]["SOURCE_FIELD"] = "email_test"
settings.TRENCH_AUTH["MFA_METHODS"]["email"]["SOURCE_FIELD"] = "email_test" # type: ignore[index]

response = client.post(
path="/auth/email/activate/",
Expand All @@ -53,7 +51,7 @@ def test_should_fail_on_add_user_mfa_with_invalid_source_field(active_user: User
response.data.get("error")
== "Field name `email_test` is not valid for model `User`."
)
settings.TRENCH_AUTH["MFA_METHODS"]["email"]["SOURCE_FIELD"] = "email"
settings.TRENCH_AUTH["MFA_METHODS"]["email"]["SOURCE_FIELD"] = "email" # type: ignore[index]


@flaky
Expand Down
15 changes: 9 additions & 6 deletions testproject/tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ def test_method_handler_missing_error():
assert settings.MFA_METHODS["method_without_handler"] is None


def test_code_missing_error():
validator = ProtectedActionValidator(mfa_method_name="yubi", user=None)
@pytest.mark.django_db
def test_code_missing_error(active_user):
validator = ProtectedActionValidator(mfa_method_name="yubi", user=active_user)
with pytest.raises(OTPCodeMissingError):
validator.validate_code(value="")

Expand All @@ -44,15 +45,17 @@ def test_request_body_validator():
validator.update(instance=MFAMethod(), validated_data=OrderedDict())


def test_protected_action_validator():
validator = ProtectedActionValidator(mfa_method_name="yubi", user=None)
@pytest.mark.django_db
def test_protected_action_validator(active_user):
validator = ProtectedActionValidator(mfa_method_name="yubi", user=active_user)
with pytest.raises(NotImplementedError):
validator._validate_mfa_method(mfa=MFAMethod())


def test_mfa_method_activation_validator():
@pytest.mark.django_db
def test_mfa_method_activation_validator(active_user):
validator = MFAMethodActivationConfirmationValidator(
mfa_method_name="yubi", user=None
mfa_method_name="yubi", user=active_user
)
with pytest.raises(MFAMethodAlreadyActiveError):
validator._validate_mfa_method(mfa=MFAMethod(is_active=True))
Expand Down
6 changes: 5 additions & 1 deletion testproject/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import cast

import pytest

from trench.backends.application import ApplicationMessageDispatcher
Expand All @@ -23,7 +25,9 @@ def test_invalid_token():
def test_create_qr_link(active_user_with_many_otp_methods):
user, _ = active_user_with_many_otp_methods
mfa_method: MFAMethod = user.mfa_methods.filter(name="app").first()
handler: ApplicationMessageDispatcher = get_mfa_handler(mfa_method)
handler: ApplicationMessageDispatcher = cast(
ApplicationMessageDispatcher, get_mfa_handler(mfa_method)
)
qr_link = handler._create_qr_link(user=user)
assert type(qr_link) == str
assert user.username in qr_link
Expand Down
9 changes: 6 additions & 3 deletions testproject/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,16 @@ def _second_factor_request(
code: Optional[str] = None,
path: str = PATH_AUTH_JWT_LOGIN_CODE,
) -> Response:
if handler is None and code is None:
raise ValueError("handler and code can't be None simultaneously")
if code is None:
if handler is None:
raise ValueError("handler and code can't be None simultaneously")
else:
code = handler.create_code()
return self.post(
path=path,
data={
"ephemeral_token": ephemeral_token,
"code": handler.create_code() if code is None else code, # type: ignore
"code": code,
},
format="json",
)
Expand Down
6 changes: 4 additions & 2 deletions trench/backends/application.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Type

from django.contrib.auth import get_user_model
from django.contrib.auth.models import AbstractUser

Expand All @@ -12,7 +14,7 @@
from trench.settings import trench_settings


User: AbstractUser = get_user_model()
User: Type[AbstractUser] = get_user_model()


class ApplicationMessageDispatcher(AbstractMessageDispatcher):
Expand All @@ -24,7 +26,7 @@ def dispatch_message(self) -> DispatchResponse:
logging.error(cause, exc_info=True) # pragma: nocover
return FailedDispatchResponse(details=str(cause)) # pragma: nocover

def _create_qr_link(self, user: User) -> str:
def _create_qr_link(self, user: AbstractUser) -> str:
return self._get_otp().provisioning_uri(
getattr(user, User.USERNAME_FIELD),
trench_settings.APPLICATION_ISSUER_NAME,
Expand Down
2 changes: 1 addition & 1 deletion trench/backends/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
import boto3
import botocore.exceptions

from trench.backends.base import AbstractMessageDispatcher
from trench.responses import (
Expand All @@ -13,6 +12,7 @@
from trench.settings import AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_REGION
from botocore.exceptions import ClientError, EndpointConnectionError


class AWSMessageDispatcher(AbstractMessageDispatcher):
_SMS_BODY = _("Your verification code is: ")
_SUCCESS_DETAILS = _("SMS message with MFA code has been sent.")
Expand Down
4 changes: 2 additions & 2 deletions trench/backends/basic_mail.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def dispatch_message(self) -> DispatchResponse:
email_html_template = self._config[EMAIL_HTML_TEMPLATE]
try:
send_mail(
subject=self._config.get(EMAIL_SUBJECT),
subject=str(self._config.get(EMAIL_SUBJECT)),
message=get_template(email_plain_template).render(context),
html_message=get_template(email_html_template).render(context),
from_email=settings.DEFAULT_FROM_EMAIL,
recipient_list=(self._to,),
recipient_list=(self._to,) if self._to else (),
fail_silently=False,
)
return SuccessfulDispatchResponse(details=self._SUCCESS_DETAILS)
Expand Down
12 changes: 4 additions & 8 deletions trench/command/authenticate_second_factor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AbstractUser

from typing import Type

from django.contrib.auth.base_user import AbstractBaseUser

from trench.backends.provider import get_mfa_handler
from trench.command.remove_backup_code import remove_backup_code_command
from trench.command.validate_backup_code import validate_backup_code_command
Expand All @@ -11,18 +10,15 @@
from trench.utils import get_mfa_model, user_token_generator


User: AbstractUser = get_user_model()


class AuthenticateSecondFactorCommand:
def __init__(self, mfa_model: Type[MFAMethod]) -> None:
self._mfa_model = mfa_model

def execute(self, code: str, ephemeral_token: str) -> User:
def execute(self, code: str, ephemeral_token: str) -> AbstractBaseUser:
user = user_token_generator.check_token(user=None, token=ephemeral_token)
if user is None:
raise InvalidTokenError()
self.is_authenticated(user_id=user.id, code=code)
self.is_authenticated(user_id=user.pk, code=code)
return user

def is_authenticated(self, user_id: int, code: str) -> None:
Expand Down
10 changes: 3 additions & 7 deletions trench/command/authenticate_user.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
from django.contrib.auth import authenticate, get_user_model
from django.contrib.auth.models import AbstractUser

from django.contrib.auth import authenticate
from django.contrib.auth.base_user import AbstractBaseUser
from rest_framework.request import Request

from trench.exceptions import UnauthenticatedError


User: AbstractUser = get_user_model()


class AuthenticateUserCommand:
@staticmethod
def execute(request: Request, username: str, password: str) -> User:
def execute(request: Request, username: str, password: str) -> AbstractBaseUser:
user = authenticate(
request=request,
username=username,
Expand Down
4 changes: 2 additions & 2 deletions trench/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from trench.exceptions import MFAMethodDoesNotExistError


class MFAUserMethodManager(Manager):
class MFAUserMethodManager(Manager["MFAMethod"]):
def get_by_name(self, user_id: Any, name: str) -> "MFAMethod":
try:
return self.get(user_id=user_id, name=name)
Expand Down Expand Up @@ -96,7 +96,7 @@ class Meta:
objects = MFAUserMethodManager()

def __str__(self) -> str:
return f"{self.name} (User id: {self.user_id})"
return f"{self.name} (User id: {self.user_id})" # type: ignore[attr-defined]

@property
def backup_codes(self) -> Iterable[str]:
Expand Down
27 changes: 14 additions & 13 deletions trench/responses.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Union

from django_stubs_ext import StrOrPromise
from rest_framework.response import Response
from rest_framework.status import (
HTTP_200_OK,
Expand All @@ -14,32 +17,30 @@ class DispatchResponse(Response):

class SuccessfulDispatchResponse(DispatchResponse):
def __init__(
self, details: str, status: str = HTTP_200_OK, *args, **kwargs
self, details: StrOrPromise, status: int = HTTP_200_OK, *args, **kwargs
) -> None:
super().__init__(
data={self._FIELD_DETAILS: details}, status=status, *args, **kwargs
)
super().__init__({self._FIELD_DETAILS: details}, status, *args, **kwargs)


class FailedDispatchResponse(DispatchResponse):
def __init__(
self, details: str, status: str = HTTP_422_UNPROCESSABLE_ENTITY, *args, **kwargs
self,
details: StrOrPromise,
status: int = HTTP_422_UNPROCESSABLE_ENTITY,
*args,
**kwargs
) -> None:
super().__init__(
data={self._FIELD_DETAILS: details}, status=status, *args, **kwargs
)
super().__init__({self._FIELD_DETAILS: details}, status, *args, **kwargs)


class ErrorResponse(Response):
_FIELD_ERROR = "error"

def __init__(
self,
error: MFAValidationError,
status: str = HTTP_400_BAD_REQUEST,
error: Union[StrOrPromise, MFAValidationError],
status: int = HTTP_400_BAD_REQUEST,
*args,
**kwargs
) -> None:
super().__init__(
data={self._FIELD_ERROR: str(error)}, status=status, *args, **kwargs
)
super().__init__({self._FIELD_ERROR: str(error)}, status, *args, **kwargs)
Loading