Skip to content

Commit c838267

Browse files
authoredSep 7, 2022
Merge pull request #412 from p1c2u/refactor/customization-refactor
Customization refactor
2 parents 20526b3 + 0ec97d5 commit c838267

File tree

15 files changed

+213
-133
lines changed

15 files changed

+213
-133
lines changed
 

Diff for: ‎docs/customizations.rst

+26-12
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@ By default, spec dict is validated on spec creation time. Disabling the validati
1515
Deserializers
1616
-------------
1717

18-
Pass custom defined media type deserializers dictionary with supported mimetypes as a key to `RequestValidator` or `ResponseValidator` constructor:
18+
Pass custom defined media type deserializers dictionary with supported mimetypes as a key to `MediaTypeDeserializersFactory` and then pass it to `RequestValidator` or `ResponseValidator` constructor:
1919

2020
.. code-block:: python
2121
22+
from openapi_core.deserializing.media_types.factories import MediaTypeDeserializersFactory
23+
from openapi_core.unmarshalling.schemas import oas30_response_schema_unmarshallers_factory
24+
2225
def protobuf_deserializer(message):
2326
feature = route_guide_pb2.Feature()
2427
feature.ParseFromString(message)
@@ -27,9 +30,14 @@ Pass custom defined media type deserializers dictionary with supported mimetypes
2730
custom_media_type_deserializers = {
2831
'application/protobuf': protobuf_deserializer,
2932
}
33+
media_type_deserializers_factory = MediaTypeDeserializersFactory(
34+
custom_deserializers=custom_media_type_deserializers,
35+
)
3036
3137
validator = ResponseValidator(
32-
custom_media_type_deserializers=custom_media_type_deserializers)
38+
oas30_response_schema_unmarshallers_factory,
39+
media_type_deserializers_factory=media_type_deserializers_factory,
40+
)
3341
3442
result = validator.validate(spec, request, response)
3543
@@ -38,28 +46,34 @@ Formats
3846

3947
OpenAPI defines a ``format`` keyword that hints at how a value should be interpreted, e.g. a ``string`` with the type ``date`` should conform to the RFC 3339 date format.
4048

41-
Openapi-core comes with a set of built-in formatters, but it's also possible to add support for custom formatters for `RequestValidator` and `ResponseValidator`.
49+
Openapi-core comes with a set of built-in formatters, but it's also possible to add custom formatters in `SchemaUnmarshallersFactory` and pass it to `RequestValidator` or `ResponseValidator`.
4250

4351
Here's how you could add support for a ``usdate`` format that handles dates of the form MM/DD/YYYY:
4452

4553
.. code-block:: python
4654
47-
from datetime import datetime
48-
import re
55+
from openapi_core.unmarshalling.schemas.factories import SchemaUnmarshallersFactory
56+
from openapi_schema_validator import OAS30Validator
57+
from datetime import datetime
58+
import re
4959
50-
class USDateFormatter:
51-
def validate(self, value) -> bool:
52-
return bool(re.match(r"^\d{1,2}/\d{1,2}/\d{4}$", value))
60+
class USDateFormatter:
61+
def validate(self, value) -> bool:
62+
return bool(re.match(r"^\d{1,2}/\d{1,2}/\d{4}$", value))
5363
54-
def unmarshal(self, value):
55-
return datetime.strptime(value, "%m/%d/%y").date
64+
def unmarshal(self, value):
65+
return datetime.strptime(value, "%m/%d/%y").date
5666
5767
5868
custom_formatters = {
5969
'usdate': USDateFormatter(),
6070
}
61-
62-
validator = ResponseValidator(custom_formatters=custom_formatters)
71+
schema_unmarshallers_factory = SchemaUnmarshallersFactory(
72+
OAS30Validator,
73+
custom_formatters=custom_formatters,
74+
context=UnmarshalContext.RESPONSE,
75+
)
76+
validator = ResponseValidator(schema_unmarshallers_factory)
6377
6478
result = validator.validate(spec, request, response)
6579

Diff for: ‎docs/usage.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Now you can use it to validate against requests
2323
2424
from openapi_core.validation.request import openapi_request_validator
2525
26-
result = validator.validate(spec, request)
26+
result = openapi_request_validator.validate(spec, request)
2727
2828
# raise errors if request invalid
2929
result.raise_for_errors()
@@ -57,7 +57,7 @@ You can also validate against responses
5757
5858
from openapi_core.validation.response import openapi_response_validator
5959
60-
result = validator.validate(spec, request, response)
60+
result = openapi_response_validator.validate(spec, request, response)
6161
6262
# raise errors if response invalid
6363
result.raise_for_errors()

Diff for: ‎openapi_core/casting/schemas/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from openapi_core.casting.schemas.factories import SchemaCastersFactory
2+
3+
__all__ = ["schema_casters_factory"]
4+
5+
schema_casters_factory = SchemaCastersFactory()

Diff for: ‎openapi_core/deserializing/media_types/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from openapi_core.deserializing.media_types.factories import (
2+
MediaTypeDeserializersFactory,
3+
)
4+
5+
__all__ = ["media_type_deserializers_factory"]
6+
7+
media_type_deserializers_factory = MediaTypeDeserializersFactory()

Diff for: ‎openapi_core/deserializing/parameters/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from openapi_core.deserializing.parameters.factories import (
2+
ParameterDeserializersFactory,
3+
)
4+
5+
__all__ = ["parameter_deserializers_factory"]
6+
7+
parameter_deserializers_factory = ParameterDeserializersFactory()

Diff for: ‎openapi_core/security/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from openapi_core.security.factories import SecurityProviderFactory
2+
3+
__all__ = ["security_provider_factory"]
4+
5+
security_provider_factory = SecurityProviderFactory()

Diff for: ‎openapi_core/unmarshalling/schemas/__init__.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from openapi_schema_validator import OAS30Validator
2+
3+
from openapi_core.unmarshalling.schemas.enums import UnmarshalContext
4+
from openapi_core.unmarshalling.schemas.factories import (
5+
SchemaUnmarshallersFactory,
6+
)
7+
8+
__all__ = [
9+
"oas30_request_schema_unmarshallers_factory",
10+
"oas30_response_schema_unmarshallers_factory",
11+
]
12+
13+
oas30_request_schema_unmarshallers_factory = SchemaUnmarshallersFactory(
14+
OAS30Validator,
15+
context=UnmarshalContext.REQUEST,
16+
)
17+
18+
oas30_response_schema_unmarshallers_factory = SchemaUnmarshallersFactory(
19+
OAS30Validator,
20+
context=UnmarshalContext.RESPONSE,
21+
)

Diff for: ‎openapi_core/unmarshalling/schemas/factories.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from openapi_core.unmarshalling.schemas.unmarshallers import NumberUnmarshaller
1818
from openapi_core.unmarshalling.schemas.unmarshallers import ObjectUnmarshaller
1919
from openapi_core.unmarshalling.schemas.unmarshallers import StringUnmarshaller
20+
from openapi_core.unmarshalling.schemas.util import build_format_checker
2021

2122

2223
class SchemaUnmarshallersFactory:
@@ -40,13 +41,11 @@ class SchemaUnmarshallersFactory:
4041

4142
def __init__(
4243
self,
43-
resolver=None,
44-
format_checker=None,
44+
schema_validator_class,
4545
custom_formatters=None,
4646
context=None,
4747
):
48-
self.resolver = resolver
49-
self.format_checker = format_checker
48+
self.schema_validator_class = schema_validator_class
5049
if custom_formatters is None:
5150
custom_formatters = {}
5251
self.custom_formatters = custom_formatters
@@ -86,11 +85,13 @@ def get_formatter(self, type_format, default_formatters):
8685
return default_formatters.get(type_format)
8786

8887
def get_validator(self, schema):
88+
resolver = schema.accessor.dereferencer.resolver_manager.resolver
89+
format_checker = build_format_checker(**self.custom_formatters)
8990
kwargs = {
90-
"resolver": self.resolver,
91-
"format_checker": self.format_checker,
91+
"resolver": resolver,
92+
"format_checker": format_checker,
9293
}
9394
if self.context is not None:
9495
kwargs[self.CONTEXT_VALIDATION[self.context]] = True
9596
with schema.open() as schema_dict:
96-
return OAS30Validator(schema_dict, **kwargs)
97+
return self.schema_validator_class(schema_dict, **kwargs)

Diff for: ‎openapi_core/validation/request/__init__.py

+25-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
"""OpenAPI core validation request module"""
2+
from openapi_core.unmarshalling.schemas import (
3+
oas30_request_schema_unmarshallers_factory,
4+
)
25
from openapi_core.validation.request.validators import RequestBodyValidator
36
from openapi_core.validation.request.validators import (
47
RequestParametersValidator,
@@ -7,13 +10,31 @@
710
from openapi_core.validation.request.validators import RequestValidator
811

912
__all__ = [
13+
"openapi_v30_request_body_validator",
14+
"openapi_v30_request_parameters_validator",
15+
"openapi_v30_request_security_validator",
16+
"openapi_v30_request_validator",
1017
"openapi_request_body_validator",
1118
"openapi_request_parameters_validator",
1219
"openapi_request_security_validator",
1320
"openapi_request_validator",
1421
]
1522

16-
openapi_request_body_validator = RequestBodyValidator()
17-
openapi_request_parameters_validator = RequestParametersValidator()
18-
openapi_request_security_validator = RequestSecurityValidator()
19-
openapi_request_validator = RequestValidator()
23+
openapi_v30_request_body_validator = RequestBodyValidator(
24+
schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory,
25+
)
26+
openapi_v30_request_parameters_validator = RequestParametersValidator(
27+
schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory,
28+
)
29+
openapi_v30_request_security_validator = RequestSecurityValidator(
30+
schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory,
31+
)
32+
openapi_v30_request_validator = RequestValidator(
33+
schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory,
34+
)
35+
36+
# alias to the latest v3 version
37+
openapi_request_body_validator = openapi_v30_request_body_validator
38+
openapi_request_parameters_validator = openapi_v30_request_parameters_validator
39+
openapi_request_security_validator = openapi_v30_request_security_validator
40+
openapi_request_validator = openapi_v30_request_validator

Diff for: ‎openapi_core/validation/request/validators.py

+46-37
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
"""OpenAPI core validation request validators module"""
22
import warnings
33

4+
from openapi_core.casting.schemas import schema_casters_factory
45
from openapi_core.casting.schemas.exceptions import CastError
56
from openapi_core.deserializing.exceptions import DeserializeError
7+
from openapi_core.deserializing.media_types import (
8+
media_type_deserializers_factory,
9+
)
10+
from openapi_core.deserializing.parameters import (
11+
parameter_deserializers_factory,
12+
)
613
from openapi_core.schema.parameters import iter_params
14+
from openapi_core.security import security_provider_factory
715
from openapi_core.security.exceptions import SecurityError
8-
from openapi_core.security.factories import SecurityProviderFactory
916
from openapi_core.templating.media_types.exceptions import MediaTypeFinderError
1017
from openapi_core.templating.paths.exceptions import PathError
1118
from openapi_core.unmarshalling.schemas.enums import UnmarshalContext
@@ -28,6 +35,22 @@
2835

2936

3037
class BaseRequestValidator(BaseValidator):
38+
def __init__(
39+
self,
40+
schema_unmarshallers_factory,
41+
schema_casters_factory=schema_casters_factory,
42+
parameter_deserializers_factory=parameter_deserializers_factory,
43+
media_type_deserializers_factory=media_type_deserializers_factory,
44+
security_provider_factory=security_provider_factory,
45+
):
46+
super().__init__(
47+
schema_unmarshallers_factory,
48+
schema_casters_factory=schema_casters_factory,
49+
parameter_deserializers_factory=parameter_deserializers_factory,
50+
media_type_deserializers_factory=media_type_deserializers_factory,
51+
)
52+
self.security_provider_factory = security_provider_factory
53+
3154
def validate(
3255
self,
3356
spec,
@@ -36,22 +59,6 @@ def validate(
3659
):
3760
raise NotImplementedError
3861

39-
@property
40-
def schema_unmarshallers_factory(self):
41-
spec_resolver = (
42-
self.spec.accessor.dereferencer.resolver_manager.resolver
43-
)
44-
return SchemaUnmarshallersFactory(
45-
spec_resolver,
46-
self.format_checker,
47-
self.custom_formatters,
48-
context=UnmarshalContext.REQUEST,
49-
)
50-
51-
@property
52-
def security_provider_factory(self):
53-
return SecurityProviderFactory()
54-
5562
def _get_parameters(self, request, path, operation):
5663
operation_params = operation.get("parameters", [])
5764
path_params = path.get("parameters", [])
@@ -109,10 +116,10 @@ def _get_parameter(self, param, request):
109116
raise MissingRequiredParameter(name)
110117
raise MissingParameter(name)
111118

112-
def _get_security(self, request, operation):
119+
def _get_security(self, spec, request, operation):
113120
security = None
114-
if "security" in self.spec:
115-
security = self.spec / "security"
121+
if "security" in spec:
122+
security = spec / "security"
116123
if "security" in operation:
117124
security = operation / "security"
118125

@@ -122,16 +129,18 @@ def _get_security(self, request, operation):
122129
for security_requirement in security:
123130
try:
124131
return {
125-
scheme_name: self._get_security_value(scheme_name, request)
132+
scheme_name: self._get_security_value(
133+
spec, scheme_name, request
134+
)
126135
for scheme_name in list(security_requirement.keys())
127136
}
128137
except SecurityError:
129138
continue
130139

131140
raise InvalidSecurity
132141

133-
def _get_security_value(self, scheme_name, request):
134-
security_schemes = self.spec / "components#securitySchemes"
142+
def _get_security_value(self, spec, scheme_name, request):
143+
security_schemes = spec / "components#securitySchemes"
135144
if scheme_name not in security_schemes:
136145
return
137146
scheme = security_schemes[scheme_name]
@@ -174,10 +183,10 @@ def validate(
174183
request,
175184
base_url=None,
176185
):
177-
self.spec = spec
178-
self.base_url = base_url
179186
try:
180-
path, operation, _, path_result, _ = self._find_path(request)
187+
path, operation, _, path_result, _ = self._find_path(
188+
spec, request, base_url=base_url
189+
)
181190
except PathError as exc:
182191
return RequestValidationResult(errors=[exc])
183192

@@ -206,10 +215,10 @@ def validate(
206215
request,
207216
base_url=None,
208217
):
209-
self.spec = spec
210-
self.base_url = base_url
211218
try:
212-
_, operation, _, _, _ = self._find_path(request)
219+
_, operation, _, _, _ = self._find_path(
220+
spec, request, base_url=base_url
221+
)
213222
except PathError as exc:
214223
return RequestValidationResult(errors=[exc])
215224

@@ -244,15 +253,15 @@ def validate(
244253
request,
245254
base_url=None,
246255
):
247-
self.spec = spec
248-
self.base_url = base_url
249256
try:
250-
_, operation, _, _, _ = self._find_path(request)
257+
_, operation, _, _, _ = self._find_path(
258+
spec, request, base_url=base_url
259+
)
251260
except PathError as exc:
252261
return RequestValidationResult(errors=[exc])
253262

254263
try:
255-
security = self._get_security(request, operation)
264+
security = self._get_security(spec, request, operation)
256265
except InvalidSecurity as exc:
257266
return RequestValidationResult(errors=[exc])
258267

@@ -269,16 +278,16 @@ def validate(
269278
request,
270279
base_url=None,
271280
):
272-
self.spec = spec
273-
self.base_url = base_url
274281
try:
275-
path, operation, _, path_result, _ = self._find_path(request)
282+
path, operation, _, path_result, _ = self._find_path(
283+
spec, request, base_url=base_url
284+
)
276285
# don't process if operation errors
277286
except PathError as exc:
278287
return RequestValidationResult(errors=[exc])
279288

280289
try:
281-
security = self._get_security(request, operation)
290+
security = self._get_security(spec, request, operation)
282291
except InvalidSecurity as exc:
283292
return RequestValidationResult(errors=[exc])
284293

0 commit comments

Comments
 (0)
Please sign in to comment.