Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add formatter module #1746

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
7 changes: 7 additions & 0 deletions datamodel_code_generator/formatter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .black import BlackCodeFormatter
from .isort import IsortCodeFormatter

__all__ = [
'IsortCodeFormatter',
'BlackCodeFormatter',
]
91 changes: 91 additions & 0 deletions datamodel_code_generator/formatter/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from importlib import import_module
from typing import Any, ClassVar, Dict

from datamodel_code_generator.imports import Import


class BaseCodeFormatter:
"""An abstract class for representing a code formatter.

All formatters that format a generated code should subclass
it. All subclass should override `apply` method which
has a string with code in input and returns a formatted code in string.
We also need to determine a `formatter_name` field
which is unique name of formatter.

Example:
>>> class CustomHeaderCodeFormatter(BaseCodeFormatter):
... formatter_name: ClassVar[str] = "custom"
... def __init__(self, formatter_kwargs: Dict[str, Any]) -> None:
... super().__init__(formatter_kwargs=formatter_kwargs)
...
... default_header = "my header"
... self.header: str = self.formatter_kwargs.get("header", default_header)
... def apply(self, code: str) -> str:
... return f'# {self.header}\\n{code}'
...
... formatter_kwargs = {"header": "formatted with CustomHeaderCodeFormatter"}
... formatter = CustomHeaderCodeFormatter(formatter_kwargs)
... code = '''x = 1\ny = 2'''
... print(formatter.apply(code))
# formatted with CustomHeaderCodeFormatter
x = 1
y = 2

"""

formatter_name: ClassVar[str] = ''

def __init__(self, formatter_kwargs: Dict[str, Any]) -> None:
if self.formatter_name == '':
raise ValueError('`formatter_name` should be not empty string')

self.formatter_kwargs = formatter_kwargs

def apply(self, code: str) -> str:
raise NotImplementedError


def load_code_formatter(
custom_formatter_import: str, custom_formatters_kwargs: Dict[str, Any]
) -> BaseCodeFormatter:
"""Load a formatter by import path as string.

Args:
custom_formatter_import: custom formatter module.
custom_formatters_kwargs: kwargs for custom formatters from config.

Examples:
for default formatters use
>>> custom_formatter_import = "datamodel_code_generator.formatter.BlackCodeFormatter"
this is equivalent to code
>>> from datamodel_code_generator.formatter import BlackCodeFormatter

custom formatter
>>> custom_formatter_import = "my_package.my_sub_package.FormatterName"
this is equivalent to code
>>> from my_package.my_sub_package import FormatterName

"""

import_ = Import.from_full_path(custom_formatter_import)
imported_module_ = import_module(import_.from_)
Comment on lines +72 to +73
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't think the uses case 😅

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also had the same reaction. But the Import class is very suitable in this case.


if not hasattr(imported_module_, import_.import_):
raise NameError(
f'Custom formatter module `{import_.from_}` not contains formatter with name `{import_.import_}`'
)

formatter_class = imported_module_.__getattribute__(import_.import_)

if not issubclass(formatter_class, BaseCodeFormatter):
raise TypeError(
f'The custom module `{custom_formatter_import}` must inherit from '
'`datamodel-code-generator.formatter.BaseCodeFormatter`'
)

custom_formatter_kwargs = custom_formatters_kwargs.get(
formatter_class.formatter_name, {}
)

return formatter_class(formatter_kwargs=custom_formatter_kwargs)
108 changes: 108 additions & 0 deletions datamodel_code_generator/formatter/black.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Dict

import black

from datamodel_code_generator.util import cached_property

from .base import BaseCodeFormatter


class PythonVersion(Enum):
PY_36 = '3.6'
PY_37 = '3.7'
PY_38 = '3.8'
PY_39 = '3.9'
PY_310 = '3.10'
PY_311 = '3.11'
PY_312 = '3.12'

@cached_property
def _is_py_38_or_later(self) -> bool: # pragma: no cover
return self.value not in {self.PY_36.value, self.PY_37.value} # type: ignore

@cached_property
def _is_py_39_or_later(self) -> bool: # pragma: no cover
return self.value not in {self.PY_36.value, self.PY_37.value, self.PY_38.value} # type: ignore

@cached_property
def _is_py_310_or_later(self) -> bool: # pragma: no cover
return self.value not in {
self.PY_36.value,
self.PY_37.value,
self.PY_38.value,
self.PY_39.value,
} # type: ignore

@cached_property
def _is_py_311_or_later(self) -> bool: # pragma: no cover
return self.value not in {
self.PY_36.value,
self.PY_37.value,
self.PY_38.value,
self.PY_39.value,
self.PY_310.value,
} # type: ignore

@property
def has_literal_type(self) -> bool:
return self._is_py_38_or_later

@property
def has_union_operator(self) -> bool: # pragma: no cover
return self._is_py_310_or_later

@property
def has_annotated_type(self) -> bool:
return self._is_py_39_or_later

@property
def has_typed_dict(self) -> bool:
return self._is_py_38_or_later

@property
def has_typed_dict_non_required(self) -> bool:
return self._is_py_311_or_later


if TYPE_CHECKING:

class _TargetVersion(Enum):
...
Fixed Show fixed Hide fixed

BLACK_PYTHON_VERSION: Dict[PythonVersion, _TargetVersion]
else:
BLACK_PYTHON_VERSION: Dict[PythonVersion, black.TargetVersion] = {
v: getattr(black.TargetVersion, f'PY{v.name.split("_")[-1]}')
for v in PythonVersion
if hasattr(black.TargetVersion, f'PY{v.name.split("_")[-1]}')
}


class BlackCodeFormatter(BaseCodeFormatter):
formatter_name: ClassVar[str] = 'black'

def __init__(self, formatter_kwargs: Dict[str, Any]) -> None:
super().__init__(formatter_kwargs=formatter_kwargs)

if TYPE_CHECKING:
self.black_mode: black.FileMode
else:
self.black_mode = black.FileMode(
target_versions={
BLACK_PYTHON_VERSION[formatter_kwargs.get('target-version', '3.7')]
},
line_length=formatter_kwargs.get(
'line-length', black.DEFAULT_LINE_LENGTH
),
string_normalization=not formatter_kwargs.get(
'skip-string-normalization', True
),
**formatter_kwargs,
)

def apply(self, code: str) -> str:
return black.format_str(
code,
mode=self.black_mode,
)
47 changes: 47 additions & 0 deletions datamodel_code_generator/formatter/isort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from pathlib import Path
from typing import Any, ClassVar, Dict

import isort

from .base import BaseCodeFormatter


class IsortCodeFormatter(BaseCodeFormatter):
formatter_name: ClassVar[str] = 'isort'

def __init__(self, formatter_kwargs: Dict[str, Any]) -> None:
super().__init__(formatter_kwargs=formatter_kwargs)

if 'settings_path' not in self.formatter_kwargs:
settings_path = Path().resolve()
else:
settings_path = Path(self.formatter_kwargs['settings_path'])

self.settings_path: str = str(settings_path)
self.isort_config_kwargs: Dict[str, Any] = {}

if 'known_third_party' in self.formatter_kwargs:
self.isort_config_kwargs['known_third_party'] = self.formatter_kwargs[
'known_third_party'
]

if isort.__version__.startswith('4.'):
self.isort_config = None
else:
self.isort_config = isort.Config(
settings_path=self.settings_path, **self.isort_config_kwargs
)

if isort.__version__.startswith('4.'):

def apply(self, code: str) -> str:
return isort.SortImports(
file_contents=code,
settings_path=self.settings_path,
**self.isort_config_kwargs,
).output

else:

def apply(self, code: str) -> str:
return isort.code(code, config=self.isort_config)
17 changes: 17 additions & 0 deletions tests/data/python/custom_formatters/add_license_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Any, Dict, ClassVar

from datamodel_code_generator.formatter.base import BaseCodeFormatter


class LicenseFormatter(BaseCodeFormatter):
"""Add a license to file from license file path."""
formatter_name: ClassVar[str] = "license_formatter"

def __init__(self, formatter_kwargs: Dict[str, Any]) -> None:
super().__init__(formatter_kwargs)

license_txt = formatter_kwargs.get('license_txt', "a license")
self.license_header = '\n'.join([f'# {line}' for line in license_txt.split('\n')])

def apply(self, code: str) -> str:
return f'{self.license_header}\n{code}'
Empty file added tests/formatter/__init__.py
Empty file.
84 changes: 84 additions & 0 deletions tests/formatter/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import ClassVar

import pytest

from datamodel_code_generator.formatter.base import (
BaseCodeFormatter,
load_code_formatter,
)

UN_EXIST_FORMATTER = 'tests.data.python.custom_formatters.un_exist.CustomFormatter'
WRONG_FORMATTER = 'tests.data.python.custom_formatters.wrong.WrongFormatterName_'
NOT_SUBCLASS_FORMATTER = (
'tests.data.python.custom_formatters.not_subclass.CodeFormatter'
)
ADD_LICENSE_FORMATTER = (
'tests.data.python.custom_formatters.add_license_formatter.LicenseFormatter'
)


def test_incorrect_from_base_not_implemented_apply():
class CustomFormatter(BaseCodeFormatter):
formatter_name: ClassVar[str] = 'formatter'

with pytest.raises(NotImplementedError):
formatter = CustomFormatter({})
formatter.apply('')


def test_incorrect_from_base():
class CustomFormatter(BaseCodeFormatter):
def apply(self, code: str) -> str:
return code

with pytest.raises(ValueError):
_ = CustomFormatter({})


def test_load_code_formatter_un_exist_custom_formatter():
with pytest.raises(ModuleNotFoundError):
load_code_formatter(UN_EXIST_FORMATTER, {})


def test_load_code_formatter_invalid_formatter_name():
with pytest.raises(NameError):
load_code_formatter(WRONG_FORMATTER, {})


def test_load_code_formatter_is_not_subclass():
with pytest.raises(TypeError):
load_code_formatter(NOT_SUBCLASS_FORMATTER, {})


def test_add_license_formatter_without_kwargs():
formatter = load_code_formatter(ADD_LICENSE_FORMATTER, {})
formatted_code = formatter.apply('x = 1\ny = 2')

assert (
formatted_code
== """# a license
x = 1
y = 2"""
)


def test_add_license_formatter_with_kwargs():
formatter = load_code_formatter(
ADD_LICENSE_FORMATTER,
{
'license_formatter': {
'license_txt': 'MIT License\n\nCopyright (c) 2023 Blah-blah\n'
}
},
)
formatted_code = formatter.apply('x = 1\ny = 2')

assert (
formatted_code
== """# MIT License
#
# Copyright (c) 2023 Blah-blah
#
x = 1
y = 2"""
)
15 changes: 15 additions & 0 deletions tests/formatter/test_black.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from datamodel_code_generator.formatter.base import (
BaseCodeFormatter,
load_code_formatter,
)
from datamodel_code_generator.formatter.black import BlackCodeFormatter


def test_black_formatter_is_subclass_if_base():
assert issubclass(BlackCodeFormatter, BaseCodeFormatter)
assert BlackCodeFormatter.formatter_name == 'black'
assert hasattr(BlackCodeFormatter, 'apply')


def test_load_black_formatter():
_ = load_code_formatter('datamodel_code_generator.formatter.BlackCodeFormatter', {})
15 changes: 15 additions & 0 deletions tests/formatter/test_isort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from datamodel_code_generator.formatter.base import (
BaseCodeFormatter,
load_code_formatter,
)
from datamodel_code_generator.formatter.isort import IsortCodeFormatter


def test_isort_formatter_is_subclass_if_base():
assert issubclass(IsortCodeFormatter, BaseCodeFormatter)
assert IsortCodeFormatter.formatter_name == 'isort'
assert hasattr(IsortCodeFormatter, 'apply')


def test_load_isort_formatter():
_ = load_code_formatter('datamodel_code_generator.formatter.IsortCodeFormatter', {})
Loading