-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathvalidator.py
More file actions
126 lines (106 loc) · 4.35 KB
/
validator.py
File metadata and controls
126 lines (106 loc) · 4.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Copyright 2023-2025 Buf Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from google.protobuf import message
from buf.validate import validate_pb2 # type: ignore
from protovalidate.config import Config
from protovalidate.internal import extra_func
from protovalidate.internal import rules as _rules
# Module-level aliases: CompilationError and Violation refer to the classes
# defined in protovalidate.internal.rules, and Violations to the generated
# protobuf message in buf.validate.validate_pb2.
CompilationError = _rules.CompilationError
Violations = validate_pb2.Violations
Violation = _rules.Violation
class Validator:
    """
    Validates protobuf messages against static rules.

    Each validator instance caches internal state generated from the static
    rules, so reusing the same instance for multiple validations
    significantly improves performance.
    """

    _factory: _rules.RuleFactory
    _cfg: Config

    def __init__(self, config: typing.Optional[Config] = None):
        """
        Creates a validator.

        Parameters:
            config: The configuration handed to every validation run by this
                instance. When omitted, a default `Config()` is used.
        """
        self._factory = _rules.RuleFactory(extra_func.EXTRA_FUNCS)
        self._cfg = config if config is not None else Config()

    def validate(
        self,
        message: message.Message,
    ) -> None:
        """
        Validates the given message against the static rules defined in
        the message's descriptor.

        Parameters:
            message: The message to validate.
        Raises:
            CompilationError: If the static rules could not be compiled.
            ValidationError: If the message is invalid. The violations raised as part of this error should
                always be equal to the list of violations returned by `collect_violations`.
        """
        violations = self.collect_violations(message)
        if violations:
            msg = f"invalid {message.DESCRIPTOR.name}"
            raise ValidationError(msg, violations)

    def collect_violations(
        self,
        message: message.Message,
        *,
        into: typing.Optional[list[Violation]] = None,
    ) -> list[Violation]:
        """
        Validates the given message against the static rules defined in
        the message's descriptor. Compared to `validate`, `collect_violations` simply
        returns the violations as a list and puts the burden of raising an appropriate
        exception on the caller.

        The violations returned from this method should always be equal to the violations
        raised as part of the ValidationError in the call to `validate`.

        Parameters:
            message: The message to validate.
            into: If provided, any violations will be appended to this
                list and the same list will be returned. Violations already
                present in the list are left untouched.
        Raises:
            CompilationError: If the static rules could not be compiled.
        """
        # Entries already in `into` had their paths finalized by whichever
        # call produced them. Remember where this call's violations begin so
        # the post-processing below does not reverse those paths a second time.
        # NOTE(review): relies on RuleContext appending to `into` in place, as
        # the documented contract above states - confirm against _rules.
        first_new = len(into) if into is not None else 0

        ctx = _rules.RuleContext(config=self._cfg, violations=into)
        for rule in self._factory.get(message.DESCRIPTOR):
            rule.validate(ctx, message)
            if ctx.done:
                break

        # Reverse the path elements of the freshly collected violations.
        # Presumably rules record path elements innermost-first, so reversing
        # yields root-to-leaf order - verify against _rules.
        for violation in ctx.violations[first_new:]:
            if violation.proto.HasField("field"):
                violation.proto.field.elements.reverse()
            if violation.proto.HasField("rule"):
                violation.proto.rule.elements.reverse()
        return ctx.violations
class ValidationError(ValueError):
    """
    Raised when a message does not pass validation.

    Carries the list of violations that caused the failure.
    """

    _violations: list[_rules.Violation]

    def __init__(self, msg: str, violations: list[_rules.Violation]):
        """
        Parameters:
            msg: The human-readable error message passed to ValueError.
            violations: The violations that caused this error; stored as given.
        """
        super().__init__(msg)
        self._violations = violations

    def to_proto(self) -> validate_pb2.Violations:
        """
        Builds a Violations protobuf message containing the proto form of
        every violation held by this error.
        """
        converted = validate_pb2.Violations()
        converted.violations.extend([entry.proto for entry in self._violations])
        return converted

    @property
    def violations(self) -> list[Violation]:
        """
        The violations as a plain Python list, as opposed to the
        Protobuf-specific collection type used by Violations.
        """
        return self._violations