Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add setter support #11

Merged
merged 10 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ jobs:
fail-fast: false
matrix:
python-version:
- "3.13"
- "3.12"
- "3.11"
- "3.10"
Expand Down
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

# znfields

Provide a `getter` for `dataclasses.fields` to allow e.g. for lazy evaluation.
Provide a `getter` and `setter` for `dataclasses.fields` to allow e.g. for lazy
evaluation or field content validation.

```bash
pip install znfields
Expand All @@ -19,10 +20,15 @@ additional `getter` argument.
import dataclasses
import znfields

def parameter_getter(self, name):
def getter(self, name) -> str:
return f"{name}:{self.__dict__[name]}"

def setter(self, name, value) -> None:
if not isinstance(value, float):
raise ValueError(f"Value {value} is not a float")
self.__dict__[name] = value

@dataclasses.dataclass
class ClassWithParameter(znfields.Base):
parameter: float = znfields.field(getter=parameter_getter)
class MyModel(znfields.Base):
parameter: float = znfields.field(getter=getter, setter=setter)
```
27 changes: 27 additions & 0 deletions tests/test_readme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import dataclasses

import pytest

import znfields


def getter(self, name):
return f"{name}:{self.__dict__[name]}"


def setter(self, name, value):
if not isinstance(value, float):
raise ValueError(f"Value {value} is not a float")
self.__dict__[name] = value


@dataclasses.dataclass
class MyModel(znfields.Base):
parameter: float = znfields.field(getter=getter, setter=setter)


def test_readme():
model = MyModel(parameter=3.14)
assert model.parameter == "parameter:3.14"
with pytest.raises(ValueError):
model.parameter = 42
138 changes: 130 additions & 8 deletions tests/test_znfields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,52 @@
import znfields


def example1_parameter_getter(self, name):
def getter_01(self, name):
return f"{name}:{self.__dict__[name]}"


def setter_01(self, name, value):
if not isinstance(value, float):
raise ValueError(f"Value {value} is not a float")
self.__dict__[name] = value


def stringify_list(self, name):
content = self.__dict__[name]
self.__dict__[name] = [str(x) for x in content]
# Can not return a copy to append to, but must be the same object
return self.__dict__[name]


@dataclasses.dataclass
class SetterGetterNoInit(znfields.Base):
parameter: float = znfields.field(getter=getter_01, setter=setter_01, init=False)


@dataclasses.dataclass
class SetterOnly(znfields.Base):
parameter: float = znfields.field(setter=setter_01)


@dataclasses.dataclass
class Example1(znfields.Base):
parameter: float = znfields.field(getter=example1_parameter_getter)
parameter: float = znfields.field(getter=getter_01)


@dataclasses.dataclass
class Example1WithDefault(znfields.Base):
parameter: float = znfields.field(getter=example1_parameter_getter, default=1)
parameter: float = znfields.field(getter=getter_01, default=1)


@dataclasses.dataclass
class Example1WithDefaultFactory(znfields.Base):
parameter: list = znfields.field(getter=stringify_list, default_factory=list)


class NoDataClass(znfields.Base):
parameter: float = znfields.field(getter=getter_01, setter=setter_01)


def test_example1():
example = Example1(parameter=1)
assert example.parameter == "parameter:1"
Expand Down Expand Up @@ -58,14 +78,15 @@ def test_example2():

def test_wrong_metadata():
with pytest.raises(TypeError):
znfields.field(getter=example1_parameter_getter, metadata="Hello")
znfields.field(getter=getter_01, metadata="Hello")

with pytest.raises(TypeError):
znfields.field(setter=setter_01, metadata="Hello")


@dataclasses.dataclass
class Example3(znfields.Base):
parameter: float = znfields.field(
getter=example1_parameter_getter, metadata={"category": "test"}
)
parameter: float = znfields.field(getter=getter_01, metadata={"category": "test"})


def test_example3():
Expand All @@ -75,7 +96,7 @@ def test_example3():
field = dataclasses.fields(example)[0]
assert field.metadata == {
"category": "test",
znfields.ZNFIELDS_GETTER_TYPE: example1_parameter_getter,
znfields.ZNFIELDS_GETTER_TYPE: getter_01,
}


Expand Down Expand Up @@ -154,3 +175,104 @@ def test_default_factory():
assert example.parameter == []
example.parameter.append(1)
assert example.parameter == ["1"]


def test_getter_setter_no_init():
example = SetterGetterNoInit()
with pytest.raises(ValueError):
example.parameter = "text"

example.parameter = 3.14
assert example.parameter == "parameter:3.14"

# test non-field attributes
example.some_attribute = 42
assert example.some_attribute == 42


@dataclasses.dataclass
class ParentClass(znfields.Base):
parent_field: str = znfields.field(getter=getter_01)


@dataclasses.dataclass
class ChildClass(ParentClass):
child_field: str = znfields.field(getter=getter_01)


def test_inherited_getter():
instance = ChildClass(parent_field="parent", child_field="child")
assert instance.parent_field == "parent_field:parent"
assert instance.child_field == "child_field:child"


def test_setter_validation():
example = SetterGetterNoInit()

with pytest.raises(ValueError):
example.parameter = "invalid value"

with pytest.raises(KeyError):
# dict is not set, getter raises KeyError instead of AttributeError
assert example.parameter is None

example.parameter = 2.71
assert example.parameter == "parameter:2.71"


@dataclasses.dataclass
class NoDefaultField(znfields.Base):
parameter: float = znfields.field(getter=getter_01, setter=setter_01)


def test_no_default_field():
with pytest.raises(TypeError):
NoDefaultField() # should raise because no default is provided
obj = NoDefaultField(parameter=1.23)
assert obj.parameter == "parameter:1.23"


@dataclasses.dataclass
class CombinedGetterSetter(znfields.Base):
parameter: float = znfields.field(getter=getter_01, setter=setter_01)


def test_combined_getter_setter():
obj = CombinedGetterSetter(parameter=2.5)
assert obj.parameter == "parameter:2.5"
obj.parameter = 3.5
assert obj.parameter == "parameter:3.5"

with pytest.raises(ValueError):
obj.parameter = "invalid value"


@dataclasses.dataclass
class Nested(znfields.Base):
inner_field: float = znfields.field(getter=getter_01)


@dataclasses.dataclass
class Outer(znfields.Base):
outer_field: Nested = dataclasses.field(default_factory=lambda: Nested(1.0))


def test_nested_dataclass():
obj = Outer()
assert obj.outer_field.inner_field == "inner_field:1.0"


def test_no_dataclass():
x = NoDataClass()
with pytest.raises(TypeError, match="is not a dataclass"):
x.parameter = 5

with pytest.raises(TypeError, match="is not a dataclass"):
assert x.parameter is None


def test_setter_only():
x = SetterOnly(parameter=5.5)
with pytest.raises(ValueError):
x.parameter = "5"
assert x.parameter == 5.5
90 changes: 89 additions & 1 deletion znfields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,45 @@


class _ZNFIELDS_GETTER_TYPE:
"""Sentinel class to identify the getter type."""

pass


class _ZNFIELDS_SETTER_TYPE:
"""Sentinel class used to identify the setter type."""

pass


# Sentinels to identify the getter and setter types
ZNFIELDS_GETTER_TYPE = _ZNFIELDS_GETTER_TYPE()
ZNFIELDS_SETTER_TYPE = _ZNFIELDS_SETTER_TYPE()


class Base:
"""Base class to extend dataclasses with custom getter and setter behavior
through field metadata.

Methods
-------
__getattribute__(name: str) -> Any
Overrides the default behavior of attribute access to allow for
custom getter functionality defined via field metadata.
__setattr__(name: str, value: Any) -> None
Overrides the default behavior of attribute assignment to allow for
custom setter functionality defined via field metadata.
"""

def __getattribute__(self, name: str) -> Any:
"""Overrides the default behavior of attribute access.

Allow for custom getter functionality defined via field metadata.

Raises
------
TypeError: If the class is not a dataclass.
"""
if name.startswith("__") and name.endswith("__"):
return super().__getattribute__(name)
if not dataclasses.is_dataclass(self):
Expand All @@ -27,11 +58,59 @@ def __getattribute__(self, name: str) -> Any:
return lazy(self, name)
return super().__getattribute__(name)

def __setattr__(self, name: str, value: Any) -> None:
"""Overrides the default behavior of attribute assignment.

Allow for custom setter functionality defined via field metadata.

Raises
------
TypeError: If the class is not a dataclass.
"""
if not dataclasses.is_dataclass(self):
raise TypeError(f"{self} is not a dataclass")
try:
field = next(
field for field in dataclasses.fields(self) if field.name == name
)
except StopIteration:
return super().__setattr__(name, value)
setter = field.metadata.get(ZNFIELDS_SETTER_TYPE)
if setter:
setter(self, name, value)
else:
super().__setattr__(name, value)


@functools.wraps(dataclasses.field)
def field(
*, getter: Optional[Callable[[Any, str], Any]] = None, **kwargs
*,
getter: Optional[Callable[[Any, str], Any]] = None,
setter: Optional[Callable[[Any, str, Any], None]] = None,
**kwargs,
) -> dataclasses.Field:
"""Wrapper around `dataclasses.field` to allow for defining custom
getter and setter functions via metadata.

Attributes
----------
getter : Optional[Callable[[Any, str], Any]]
A function that takes the instance and attribute name as arguments
and returns the value of the attribute.
setter : Optional[Callable[[Any, str, Any], None]]
A function that takes the instance, attribute name, and value as
arguments and sets the value of the attribute.

Returns
-------
dataclasses.Field
A field object with custom getter and setter functionality defined
via metadata.

Raises
------
TypeError: If the metadata is not a dictionary.
"""
if getter is not None:
if "metadata" in kwargs:
if not isinstance(kwargs["metadata"], dict):
Expand All @@ -41,4 +120,13 @@ def field(
kwargs["metadata"][ZNFIELDS_GETTER_TYPE] = getter
else:
kwargs["metadata"] = {ZNFIELDS_GETTER_TYPE: getter}
if setter is not None:
if "metadata" in kwargs:
if not isinstance(kwargs["metadata"], dict):
raise TypeError(
f"metadata must be a dict, not {type(kwargs['metadata'])}"
)
kwargs["metadata"][ZNFIELDS_SETTER_TYPE] = setter
else:
kwargs["metadata"] = {ZNFIELDS_SETTER_TYPE: setter}
return dataclasses.field(**kwargs)
Loading