Skip to content

Commit

Permalink
Improved from_type
Browse files Browse the repository at this point in the history
  • Loading branch information
christiansandberg committed Sep 22, 2024
1 parent cfaebbc commit 947e46c
Show file tree
Hide file tree
Showing 5 changed files with 409 additions and 103 deletions.
4 changes: 3 additions & 1 deletion onedm/sdf/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@


class CommonQualities(BaseModel):
model_config = ConfigDict(extra="allow", alias_generator=to_camel)
model_config = ConfigDict(
extra="allow", alias_generator=to_camel, populate_by_name=True
)

label: str | None = None
description: str | None = None
Expand Down
15 changes: 9 additions & 6 deletions onedm/sdf/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import datetime
from abc import ABC
from enum import EnumMeta, IntEnum
from re import Pattern
from typing import Annotated, Any, Literal, Union

from pydantic import Field, NonNegativeInt, field_serializer
Expand Down Expand Up @@ -187,7 +188,7 @@ class StringData(DataQualities):
enum: list[str] | None = None
min_length: NonNegativeInt = 0
max_length: NonNegativeInt | None = None
pattern: str | None = None
pattern: str | Pattern[str] | None = None
format: str | None = None
content_format: str | None = None
choices: Annotated[dict[str, StringData] | None, Field(alias="sdfChoice")] = (
Expand Down Expand Up @@ -226,7 +227,7 @@ def _get_base_schema(self) -> core_schema.CoreSchema:

class ArrayData(DataQualities):
type: Literal["array"] = "array"
items: Data
items: Data | None = None
min_items: NonNegativeInt = 0
max_items: NonNegativeInt | None = None
unique_items: bool = False
Expand All @@ -240,20 +241,20 @@ def always_include_type(self, type: str, _):
def _get_base_schema(self) -> core_schema.ListSchema | core_schema.SetSchema:
if self.unique_items:
return core_schema.set_schema(
self.items.get_pydantic_schema(),
self.items.get_pydantic_schema() if self.items is not None else None,
min_length=self.min_items,
max_length=self.max_items,
)
return core_schema.list_schema(
self.items.get_pydantic_schema(),
self.items.get_pydantic_schema() if self.items is not None else None,
min_length=self.min_items,
max_length=self.max_items,
)


class ObjectData(DataQualities):
type: Literal["object"] = "object"
properties: dict[str, Data]
properties: dict[str, Data] | None = None
required: list[str] = Field(default_factory=list)
const: dict[str, Any] | None = None
default: dict[str, Any] | None = None
Expand All @@ -262,7 +263,9 @@ class ObjectData(DataQualities):
def always_include_type(self, type: str, _):
return type

def _get_base_schema(self) -> core_schema.TypedDictSchema:
def _get_base_schema(self) -> core_schema.CoreSchema:
if self.properties is None:
return core_schema.dict_schema()
required = self.required or []
fields = {
name: core_schema.typed_dict_field(
Expand Down
248 changes: 234 additions & 14 deletions onedm/sdf/from_type.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,252 @@
"""Conversion from native types to sdfData."""

from enum import Enum
from typing import Type

from pydantic import TypeAdapter
from pydantic_core import core_schema

from .data import Data, IntegerData
from .json_schema import from_json_schema
from . import data


def data_from_type(type_: Type) -> Data | None:
def data_from_type(type_: Type) -> data.Data | None:
"""Create from a native Python or Pydantic type.
None or null is not a supported type in SDF. In this case the return value
will be None.
"""
schema = TypeAdapter(type_).json_schema()
return data_from_schema(TypeAdapter(type_).core_schema)

if schema.get("type") == "null":
# Null types not supported

def data_from_schema(schema: core_schema.CoreSchema) -> data.Data | None:
schema_type = schema["type"]
data_type: data.Data
if schema_type == "none":
return None
if schema_type == "int":
data_type = data_from_int_schema(schema) # type: ignore
elif schema_type == "float":
data_type = data_from_float_schema(schema) # type: ignore
elif schema_type == "bool":
data_type = data_from_bool_schema(schema) # type: ignore
elif schema_type == "str":
data_type = data_from_str_schema(schema) # type: ignore
elif schema_type == "bytes":
data_type = data_from_bytes_schema(schema) # type: ignore
elif schema_type == "model":
data_type = data_from_model_schema(schema) # type: ignore
elif schema_type == "model-fields":
data_type = data_from_model_fields_schema(schema) # type: ignore
elif schema_type == "dataclass":
data_type = data_from_dataclass_schema(schema) # type: ignore
elif schema_type == "list":
data_type = data_from_list_schema(schema) # type: ignore
elif schema_type == "set":
data_type = data_from_set_schema(schema) # type: ignore
elif schema_type == "dict":
data_type = data_from_dict_schema(schema) # type: ignore
elif schema_type == "typed-dict":
data_type = data_from_typed_dict_schema(schema) # type: ignore
elif schema_type == "enum":
data_type = data_from_enum_schema(schema) # type: ignore
elif schema_type == "literal":
data_type = data_from_literal_schema(schema) # type: ignore
elif schema_type == "any":
data_type = data_from_any_schema(schema) # type: ignore
elif schema_type == "nullable":
data_type = data_from_nullable_schema(schema) # type: ignore
elif schema_type == "default":
data_type = data_from_default_schema(schema) # type: ignore
elif schema_type == "datetime":
data_type = data_from_datetime_schema(schema) # type: ignore
else:
raise NotImplementedError(f"Unsupported schema '{schema['type']}'")

# data_type.label = schema["metadata"].get("title")
return data_type


def data_from_any_schema(schema: core_schema.AnySchema):
return data.AnyData(nullable=False)


def data_from_nullable_schema(schema: core_schema.NullableSchema):
data_type = data_from_schema(schema["schema"])
data_type.nullable = True
return data_type


def data_from_default_schema(schema: core_schema.WithDefaultSchema):
data_type = data_from_schema(schema["schema"])
data_type.default = schema["default"]
return data_type


def data_from_model_schema(schema: core_schema.ModelSchema):
data_type = data_from_schema(schema["schema"])
return data_type


def data_from_model_fields_schema(schema: core_schema.ModelFieldsSchema):
return data.ObjectData(
label=schema.get("model_name"),
properties={
prop_schema.get("serialization_alias", name): data_from_schema(
prop_schema["schema"]
)
for name, prop_schema in schema["fields"].items()
},
nullable=False,
)


def data_from_dataclass_args_schema(schema: core_schema.DataclassArgsSchema):
return data.ObjectData(
properties={
field.get("serialization_alias", field["name"]): data_from_schema(
field["schema"]
)
for field in schema["fields"]
},
nullable=False,
)


def data_from_dataclass_schema(schema: core_schema.DataclassSchema):
return data_from_dataclass_args_schema(schema["schema"]) # type: ignore


def data_from_typed_dict_schema(schema: core_schema.TypedDictSchema):
return data.ObjectData(
properties={
field.get("serialization_alias", name): data_from_schema(field["schema"])
for name, field in schema["fields"].items()
},
required=[
field.get("serialization_alias", name)
for name, field in schema["fields"].items()
if field.get("required", False)
],
nullable=False,
)


def data_from_list_schema(schema: core_schema.ListSchema):
return data.ArrayData(
items=(
data_from_schema(schema["items_schema"])
if "items_schema" in schema
else None
),
min_items=schema.get("min_length", 0),
max_items=schema.get("max_length"),
nullable=False,
)


def data_from_set_schema(schema: core_schema.SetSchema):
return data.ArrayData(
items=(
data_from_schema(schema["items_schema"])
if "items_schema" in schema
else None
),
min_items=schema.get("min_length", 0),
max_items=schema.get("max_length"),
unique_items=True,
nullable=False,
)


def data_from_dict_schema(schema: core_schema.DictSchema):
return data.ObjectData(nullable=False)


def data_from_int_schema(schema: core_schema.IntSchema):
return data.IntegerData(
minimum=schema.get("ge"),
maximum=schema.get("le"),
exclusive_minimum=schema.get("gt"),
exclusive_maximum=schema.get("lt"),
multiple_of=schema.get("multiple_of"),
nullable=False,
)


def data_from_float_schema(schema: core_schema.FloatSchema):
return data.NumberData(
minimum=schema.get("ge"),
maximum=schema.get("le"),
exclusive_minimum=schema.get("gt"),
exclusive_maximum=schema.get("lt"),
multiple_of=schema.get("multiple_of"),
nullable=False,
)


def data_from_bool_schema(schema: core_schema.BoolSchema):
return data.BooleanData(nullable=False)


def data_from_str_schema(schema: core_schema.StringSchema):
return data.StringData(
pattern=schema.get("pattern"),
min_length=schema.get("min_length", 0),
max_length=schema.get("max_length"),
nullable=False,
)


def data_from_bytes_schema(schema: core_schema.BytesSchema):
return data.StringData(
sdf_type="byte-string",
format="bytes",
min_length=schema.get("min_length", 0),
max_length=schema.get("max_length"),
nullable=False,
)


def data_from_literal_schema(schema: core_schema.LiteralSchema):
choices = schema["expected"]
if len(choices) == 1:
return data.AnyData(
const=choices[0],
nullable=False,
)
if all(isinstance(choice, str) for choice in choices):
return data.StringData(
enum=choices,
nullable=False,
)
raise NotImplementedError(f"Literal with {choices} not supported")


data = from_json_schema(schema)
def data_from_enum_schema(schema: core_schema.EnumSchema):
if "sub_type" not in schema:
return data.AnyData(
choices={
member.name: data.AnyData(const=member.value)
for member in schema["members"]
},
nullable=False,
)
if schema["sub_type"] == "int":
return data.IntegerData(
choices={
member.name: data.IntegerData(const=member.value)
for member in schema["members"]
},
nullable=False,
)
if schema["sub_type"] == "str":
return data.StringData(
choices={
member.name: data.StringData(const=member.value)
for member in schema["members"]
},
nullable=False,
)

if isinstance(data, IntegerData) and data.enum and issubclass(type_, Enum):
data.choices = {
member.name: IntegerData(const=member.value) for member in type_
}
data.enum = None

return data
def data_from_datetime_schema(schema: core_schema.DatetimeSchema):
return data.StringData(nullable=False, format="date-time")
Loading

0 comments on commit 947e46c

Please sign in to comment.