From db29595cdc16e6156fb58d853c1fd887f55e1559 Mon Sep 17 00:00:00 2001 From: Christian Sandberg Date: Thu, 19 Sep 2024 12:15:27 +0200 Subject: [PATCH] Refined schema typing --- onedm/sdf/common.py | 3 +- onedm/sdf/data.py | 79 +++++++++++++++++++----------- onedm/sdf/definitions.py | 63 +++++++++++++++--------- tests/sdf/test_value_validation.py | 20 +++++--- 4 files changed, 104 insertions(+), 61 deletions(-) diff --git a/onedm/sdf/common.py b/onedm/sdf/common.py index 35c2cd7..ec3e89f 100644 --- a/onedm/sdf/common.py +++ b/onedm/sdf/common.py @@ -6,13 +6,12 @@ class CommonQualities(BaseModel): model_config = ConfigDict( - extra="allow", validate_assignment=True, alias_generator=to_camel + extra="allow", alias_generator=to_camel ) label: str | None = None description: str | None = None ref: str | None = Field(None, alias="sdfRef") - required: list[str | bool] = Field(default_factory=list, alias="sdfRequired") def get_extra(self) -> dict[str, Any]: return self.__pydantic_extra__ diff --git a/onedm/sdf/data.py b/onedm/sdf/data.py index e3e826b..c4ed4da 100644 --- a/onedm/sdf/data.py +++ b/onedm/sdf/data.py @@ -7,7 +7,7 @@ from __future__ import annotations from abc import ABC -from enum import Enum +import datetime from typing import Annotated, Any, Literal, Union from pydantic import Field, NonNegativeInt, model_validator @@ -16,20 +16,11 @@ from .common import CommonQualities -class DataType(str, Enum): - BOOLEAN = "boolean" - NUMBER = "number" - INTEGER = "integer" - STRING = "string" - OBJECT = "object" - ARRAY = "array" - - class DataQualities(CommonQualities, ABC): """Base class for all data qualities.""" - type: DataType - sdf_type: str | None = None + type: Literal["boolean", "number", "integer", "string", "object", "array"] + sdf_type: str | None = Field(None, pattern=r"^[a-z][\-a-z0-9]*$") nullable: bool = True const: Any | None = None default: Any | None = None @@ -65,7 +56,7 @@ def validate(self, input: Any) -> Any: class NumberData(DataQualities): - type: Literal[DataType.NUMBER] + type: Literal["number"] unit: str | None = None minimum: float | None = None maximum: float | None = None @@ -80,10 +71,33 @@ class NumberData(DataQualities): @classmethod def set_default_type(cls, data: Any): if isinstance(data, dict): - data.setdefault("type", DataType.NUMBER) + data.setdefault("type", "number") return data - def _get_base_schema(self) -> core_schema.FloatSchema: + def _get_base_schema(self) -> core_schema.FloatSchema | core_schema.DatetimeSchema: + if self.sdf_type == "unix-time": + return core_schema.datetime_schema( + ge=( + datetime.datetime.fromtimestamp(self.minimum) + if self.minimum is not None + else None + ), + le=( + datetime.datetime.fromtimestamp(self.maximum) + if self.maximum is not None + else None + ), + gt=( + datetime.datetime.fromtimestamp(self.exclusive_minimum) + if self.exclusive_minimum is not None + else None + ), + lt=( + datetime.datetime.fromtimestamp(self.exclusive_maximum) + if self.exclusive_maximum is not None + else None + ), + ) return core_schema.float_schema( ge=self.minimum, le=self.maximum, @@ -92,19 +106,18 @@ def _get_base_schema(self) -> core_schema.FloatSchema: multiple_of=self.multiple_of, ) - def validate(self, input: Any) -> int: + def validate(self, input: Any) -> float: return super().validate(input) class IntegerData(DataQualities): - type: Literal[DataType.INTEGER] + type: Literal["integer"] unit: str | None = None minimum: int | None = None maximum: int | None = None exclusive_minimum: int | None = None exclusive_maximum: int | None = None multiple_of: int | None = None - format: str | None = None choices: dict[str, IntegerData] | None = Field(None, alias="sdfChoice") const: int | None = None default: int | None = None @@ -113,7 +126,7 @@ class IntegerData(DataQualities): @classmethod def set_default_type(cls, data: Any): if isinstance(data, dict): - data.setdefault("type", DataType.INTEGER) + data.setdefault("type", "integer") return data def _get_base_schema(self) -> core_schema.IntSchema: @@ -130,7 +143,7 @@ def validate(self, input: Any) -> int: class BooleanData(DataQualities): - type: Literal[DataType.BOOLEAN] + type: Literal["boolean"] const: bool | None = None default: bool | None = None @@ -138,7 +151,7 @@ class BooleanData(DataQualities): @classmethod def set_default_type(cls, data: Any): if isinstance(data, dict): - data.setdefault("type", DataType.BOOLEAN) + data.setdefault("type", "boolean") return data def _get_base_schema(self) -> core_schema.BoolSchema: @@ -149,7 +162,7 @@ def validate(self, input: Any) -> bool: class StringData(DataQualities): - type: Literal[DataType.STRING] + type: Literal["string"] enum: list[str] | None = None min_length: NonNegativeInt = 0 max_length: NonNegativeInt | None = None @@ -164,7 +177,7 @@ class StringData(DataQualities): @classmethod def set_default_type(cls, data: Any): if isinstance(data, dict): - data.setdefault("type", DataType.STRING) + data.setdefault("type", "string") return data def _get_base_schema( @@ -176,6 +189,16 @@ def _get_base_schema( return core_schema.bytes_schema( min_length=self.min_length, max_length=self.max_length ) + if self.format == "uuid": + return core_schema.uuid_schema() + if self.format == "date-time": + return core_schema.datetime_schema() + if self.format == "date": + return core_schema.date_schema() + if self.format == "time": + return core_schema.time_schema() + if self.format == "uri": + return core_schema.url_schema() return core_schema.str_schema( min_length=self.min_length, max_length=self.max_length, @@ -187,7 +210,7 @@ def validate(self, input: Any) -> str | bytes: class ArrayData(DataQualities): - type: Literal[DataType.ARRAY] + type: Literal["array"] min_items: NonNegativeInt = 0 max_items: NonNegativeInt | None = None unique_items: bool = False @@ -199,7 +222,7 @@ class ArrayData(DataQualities): @classmethod def set_default_type(cls, data: Any): if isinstance(data, dict): - data.setdefault("type", DataType.ARRAY) + data.setdefault("type", "array") return data def _get_base_schema(self) -> core_schema.ListSchema | core_schema.SetSchema: @@ -220,8 +243,8 @@ def validate(self, input: Any) -> list | set: class ObjectData(DataQualities): - type: Literal[DataType.OBJECT] - required: list[str] | None = None + type: Literal["object"] + required: list[str] = Field(default_factory=list) properties: dict[str, Data] | None = None const: dict[str, Any] | None = None default: dict[str, Any] | None = None @@ -230,7 +253,7 @@ class ObjectData(DataQualities): @classmethod def set_default_type(cls, data: Any): if isinstance(data, dict): - data.setdefault("type", DataType.OBJECT) + data.setdefault("type", "object") return data def _get_base_schema(self) -> core_schema.TypedDictSchema: diff --git a/onedm/sdf/definitions.py b/onedm/sdf/definitions.py index 135c6f6..659acac 100644 --- a/onedm/sdf/definitions.py +++ b/onedm/sdf/definitions.py @@ -1,7 +1,7 @@ from __future__ import annotations -from typing import Annotated, Union +from typing import Annotated, Literal, Union -from pydantic import Field +from pydantic import Field, NonNegativeInt from .common import CommonQualities from .data import ( @@ -16,38 +16,53 @@ ) -class PropertyCommon: +class NumberProperty(NumberData): observable: bool = True readable: bool = True writable: bool = True + required: list[Literal[True]] | None = Field(default=None, alias="sdfRequired") -class NumberProperty(NumberData, PropertyCommon): - pass - - -class IntegerProperty(IntegerData, PropertyCommon): - pass +class IntegerProperty(IntegerData): + observable: bool = True + readable: bool = True + writable: bool = True + required: list[Literal[True]] | None = Field(default=None, alias="sdfRequired") -class BooleanProperty(BooleanData, PropertyCommon): - pass +class BooleanProperty(BooleanData): + observable: bool = True + readable: bool = True + writable: bool = True + required: list[Literal[True]] | None = Field(default=None, alias="sdfRequired") -class StringProperty(StringData, PropertyCommon): - pass +class StringProperty(StringData): + observable: bool = True + readable: bool = True + writable: bool = True + required: list[Literal[True]] | None = Field(default=None, alias="sdfRequired") -class ArrayProperty(ArrayData, PropertyCommon): - pass +class ArrayProperty(ArrayData): + observable: bool = True + readable: bool = True + writable: bool = True + required: list[Literal[True]] | None = Field(default=None, alias="sdfRequired") -class ObjectProperty(ObjectData, PropertyCommon): - pass +class ObjectProperty(ObjectData): + observable: bool = True + readable: bool = True + writable: bool = True + required: list[Literal[True]] | None = Field(default=None, alias="sdfRequired") -class AnyProperty(AnyData, PropertyCommon): - pass +class AnyProperty(AnyData): + observable: bool = True + readable: bool = True + writable: bool = True + required: list[Literal[True]] | None = Field(default=None, alias="sdfRequired") Property = Union[ @@ -78,9 +93,10 @@ class Object(CommonQualities): actions: dict[str, Action] = Field(default_factory=dict, alias="sdfAction") events: dict[str, Event] = Field(default_factory=dict, alias="sdfEvent") data: dict[str, Data] = Field(default_factory=dict, alias="sdfData") + required: list[str] = Field(default_factory=list, alias="sdfRequired") # If array of objects - min_items: int | None = None - max_items: int | None = None + min_items: NonNegativeInt | None = None + max_items: NonNegativeInt | None = None class Thing(CommonQualities): @@ -90,6 +106,7 @@ class Thing(CommonQualities): actions: dict[str, Action] = Field(default_factory=dict, alias="sdfAction") events: dict[str, Event] = Field(default_factory=dict, alias="sdfEvent") data: dict[str, Data] = Field(default_factory=dict, alias="sdfData") + required: list[str] = Field(default_factory=list, alias="sdfRequired") # If array of things - min_items: int | None = None - max_items: int | None = None + min_items: NonNegativeInt | None = None + max_items: NonNegativeInt | None = None diff --git a/tests/sdf/test_value_validation.py b/tests/sdf/test_value_validation.py index 83499ab..32ac983 100644 --- a/tests/sdf/test_value_validation.py +++ b/tests/sdf/test_value_validation.py @@ -2,13 +2,14 @@ from onedm import sdf -def test_integer_validation(test_model: sdf.SDF): - assert test_model.data["Integer"].validate(2) == 2 +def test_integer_validation(): + integer = sdf.IntegerData(maximum=2) + assert integer.validate(2) == 2 with pytest.raises(ValueError): - test_model.data["Integer"].validate(True) + integer.validate(1.5) # Out of range with pytest.raises(ValueError): - test_model.data["Integer"].validate(3) + integer.validate(3) def test_number_validation(test_model: sdf.SDF): @@ -40,9 +41,12 @@ def test_string_validation(test_model: sdf.SDF): test_model.data["Number"].validate(["0123456789"]) -def test_nullable_validation(test_model: sdf.SDF): - assert test_model.data["NullableInteger"].validate(None) == None +def test_nullable_validation(): + nullable_integer = sdf.IntegerData(nullable=True) + assert nullable_integer.validate(None) == None - # Not nullable + +def test_non_nullable_validation(): + integer = sdf.IntegerData(nullable=False) with pytest.raises(ValueError): - test_model.data["Integer"].validate(None) + integer.validate(None)