Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions protobuf_to_pydantic/gen_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,11 @@ def get_model(self, full_name: str) -> Type[BaseModel]:
###############
# util method #
###############
def _get_field_info_dict_by_full_name(self, full_name: str) -> Optional["FieldInfoTypedDict"]:
split_full_name = full_name.split(".")
def _get_field_info_dict_by_full_name(self, full_name: str, package_name: str) -> Optional["FieldInfoTypedDict"]:
if package_name and full_name.startswith(f"{package_name}."):
split_full_name = full_name[len(package_name) + 1 :].split(".")
else:
split_full_name = full_name.split(".")
if len(split_full_name) == 2:
message_name, *key_list = split_full_name
else:
Expand Down Expand Up @@ -315,7 +318,7 @@ def get_nested_message_dict_by_message(self, descriptor: Descriptor) -> Dict[str
####################
# field handler #
####################
def _protobuf_field_type_is_type_message_handler(self, field_dataclass: FieldDataClass) -> None:
def _protobuf_field_type_is_type_message_handler(self, field_dataclass: FieldDataClass, package_name: str) -> None:
protobuf_field = field_dataclass.protobuf_field
if protobuf_field.message_type.name in self._message_type_dict_by_type_name:
# Timestamp, Struct, Empty, Duration, Any support
Expand Down Expand Up @@ -352,7 +355,7 @@ def _protobuf_field_type_is_type_message_handler(self, field_dataclass: FieldDat
else:
# support google.protobuf.Message
field_info_dict: Union[FieldInfoTypedDict, dict] = (
self._get_field_info_dict_by_full_name(field_dataclass.protobuf_field.full_name) or {}
self._get_field_info_dict_by_full_name(field_dataclass.protobuf_field.full_name, package_name) or {}
)
skip_validate_rule = field_info_dict.get("skip", False)
full_name = protobuf_field.message_type.full_name
Expand Down Expand Up @@ -454,9 +457,9 @@ def _protobuf_field_lable_is_label_repeated_handler(self, field_dataclass: Field
if field_dataclass.field_default is not _pydantic_adapter.PydanticUndefined:
field_dataclass.field_default = _pydantic_adapter.PydanticUndefined

def _gen_field_info(self, field_dataclass: FieldDataClass, skip_validate_rule: bool) -> Optional[FieldInfo]:
def _gen_field_info(self, field_dataclass: FieldDataClass, skip_validate_rule: bool, package_name: str) -> Optional[FieldInfo]:
field_class = self._default_field
field_info_dict = self._get_field_info_dict_by_full_name(field_dataclass.protobuf_field.full_name)
field_info_dict = self._get_field_info_dict_by_full_name(field_dataclass.protobuf_field.full_name, package_name)

if field_info_dict is not None and not skip_validate_rule:
if self._parse_msg_desc_method != "PGV":
Expand Down Expand Up @@ -564,6 +567,7 @@ def _parse_msg_to_pydantic_model(
) -> Type[BaseModel]:
is_same_pkg = descriptor.file.name == root_descriptor.file.name if root_descriptor else True
class_name = class_name or descriptor.name
package_name = descriptor.file.package or "" # type: ignore

if not is_same_pkg:
class_name = replace_file_name_to_class_name(descriptor.file.name) + class_name
Expand Down Expand Up @@ -597,7 +601,7 @@ def _parse_msg_to_pydantic_model(
validators=validators,
)
if protobuf_field.type == FieldDescriptor.TYPE_MESSAGE:
self._protobuf_field_type_is_type_message_handler(field_dataclass)
self._protobuf_field_type_is_type_message_handler(field_dataclass, package_name)
elif protobuf_field.type == FieldDescriptor.TYPE_ENUM:
self._protobuf_field_type_is_type_enum_handler(field_dataclass)
if _pydantic_adapter.is_v1:
Expand All @@ -610,7 +614,7 @@ def _parse_msg_to_pydantic_model(
# At this time, the field type may be modified by the above logic, so it needs to be handled separately
if protobuf_field.label == FieldDescriptor.LABEL_REPEATED:
self._protobuf_field_lable_is_label_repeated_handler(field_dataclass)
field_info = self._gen_field_info(field_dataclass, skip_validate_rule)
field_info = self._gen_field_info(field_dataclass, skip_validate_rule, package_name)
if not field_info:
continue

Expand Down