Skip to content

Commit 4484c2b

Browse files
authored
Allow passing a base_schema to load_config() (#1103)
This PR adds a new argument to `load_config()` to allow passing a `base_schema` to be used when loading the configuration. This is necessary to be able to use custom fields in the schema, like the ones provided by `frequenz.quantities`. It also drops support for using `marshmallow_dataclass.dataclass` directly. `marshmallow_dataclass.dataclass` is intended to be used only when using `my_dataclass.Schema` to get the schema. But using this is not very convenient when using type hints, as they are not well-supported by `marshmallow`: the `load()` function can't have hints. This is actually why `load_config()` exists in the first place, so we are using `class_schema()` instead. This means we don't really need our types to be decorated with `marshmallow_dataclass`; we can use the built-in `dataclass` instead, and we just need to add the appropriate metadata if we want more complex validation. Using `class_schema()` is also necessary to be able to pass a `base_schema`, which we'll need when we want to use schemas with custom fields, like the ones provided by `frequenz.quantities`. Finally, it improves the `load_config` documentation to make it more explicit that this is just a wrapper around external libraries and which functions are used exactly, so users should read the documentation of those libraries in full.
2 parents 47adc26 + 1be3d2c commit 4484c2b

File tree

3 files changed

+65
-35
lines changed

3 files changed

+65
-35
lines changed

RELEASE_NOTES.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66

77
## Upgrading
88

9-
<!-- Here goes notes on how to upgrade from previous versions, including deprecations and what they should be replaced with -->
9+
- `frequenz.sdk.config.load_config()` doesn't accept classes decorated with `marshmallow_dataclass.dataclass` anymore. You should use the built-in `dataclasses.dataclass` directly instead. No other changes should be needed, as the metadata in the `dataclass` fields will still be used.
1010

1111
## New Features
1212

13-
<!-- Here goes the main new features and examples or instructions on how to use them -->
13+
14+
- `frequenz.sdk.config.load_config()` can now use a base schema to customize even further how data is loaded.
1415

1516
## Bug Fixes
1617

src/frequenz/sdk/config/_util.py

+39-10
Original file line numberDiff line numberDiff line change
@@ -4,42 +4,71 @@
44
"""Utilities to deal with configuration."""
55

66
from collections.abc import Mapping
7-
from typing import Any, TypeVar, cast
7+
from typing import Any, ClassVar, Protocol, TypeVar, cast
88

9+
from marshmallow import Schema
910
from marshmallow_dataclass import class_schema
1011

11-
T = TypeVar("T")
12+
13+
# This is a hack that relies on identifying dataclasses by looking into an undocumented
14+
# property of dataclasses[1], so it might break in the future. Nevertheless, it seems to
15+
# be widely used in the community, for example `mypy` and `pyright` seem to rely on
16+
# it[2].
17+
#
18+
# [1]: https://github.com/python/mypy/issues/15974#issuecomment-1694781006
19+
# [2]: https://github.com/python/mypy/issues/15974#issuecomment-1694993493
20+
class Dataclass(Protocol):
21+
"""A protocol for dataclasses."""
22+
23+
__dataclass_fields__: ClassVar[dict[str, Any]]
24+
"""The fields of the dataclass."""
25+
26+
27+
DataclassT = TypeVar("DataclassT", bound=Dataclass)
1228
"""Type variable for configuration classes."""
1329

1430

1531
def load_config(
16-
cls: type[T],
32+
cls: type[DataclassT],
1733
config: Mapping[str, Any],
1834
/,
35+
base_schema: type[Schema] | None = None,
1936
**marshmallow_load_kwargs: Any,
20-
) -> T:
37+
) -> DataclassT:
2138
"""Load a configuration from a dictionary into an instance of a configuration class.
2239
2340
The configuration class is expected to be a [`dataclasses.dataclass`][], which is
2441
used to create a [`marshmallow.Schema`][] schema to validate the configuration
25-
dictionary.
42+
dictionary using [`marshmallow_dataclass.class_schema`][] (which in turn uses the
43+
[`marshmallow.Schema.load`][] method to do the validation and deserialization).
2644
27-
To customize the schema derived from the configuration dataclass, you can use
28-
[`marshmallow_dataclass.dataclass`][] to specify extra metadata.
45+
To customize the schema derived from the configuration dataclass, you can use the
46+
`metadata` key in [`dataclasses.field`][] to pass extra options to
47+
[`marshmallow_dataclass`][] to be used during validation and deserialization.
2948
3049
Additional arguments can be passed to [`marshmallow.Schema.load`][] using keyword
31-
arguments.
50+
arguments `marshmallow_load_kwargs`.
51+
52+
Note:
53+
This method will raise [`marshmallow.ValidationError`][] if the configuration
54+
dictionary is invalid. Keep in mind that all of the gotchas of
55+
[`marshmallow`][] and [`marshmallow_dataclass`][] apply when using this
56+
function. It is recommended to carefully read the documentation of these
57+
libraries.
3258
3359
Args:
3460
cls: The configuration class.
3561
config: The configuration dictionary.
62+
base_schema: An optional class to be used as a base schema for the configuration
63+
class. This allows using custom fields, for example. Will be passed to
64+
[`marshmallow_dataclass.class_schema`][].
3665
**marshmallow_load_kwargs: Additional arguments to be passed to
3766
[`marshmallow.Schema.load`][].
3867
3968
Returns:
4069
The loaded configuration as an instance of the configuration class.
4170
"""
42-
instance = class_schema(cls)().load(config, **marshmallow_load_kwargs)
71+
instance = class_schema(cls, base_schema)().load(config, **marshmallow_load_kwargs)
4372
# We need to cast because `.load()` comes from marshmallow and doesn't know which
4473
# type is returned.
45-
return cast(T, instance)
74+
return cast(DataclassT, instance)

tests/config/test_util.py

+23-23
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import Any
88

99
import marshmallow
10-
import marshmallow_dataclass
1110
import pytest
1211
from pytest_mock import MockerFixture
1312

@@ -18,14 +17,6 @@
1817
class SimpleConfig:
1918
"""A simple configuration class for testing."""
2019

21-
name: str
22-
value: int
23-
24-
25-
@marshmallow_dataclass.dataclass
26-
class MmSimpleConfig:
27-
"""A simple configuration class for testing."""
28-
2920
name: str = dataclasses.field(metadata={"validate": lambda s: s.startswith("test")})
3021
value: int
3122

@@ -37,27 +28,36 @@ def test_load_config_dataclass() -> None:
3728
loaded_config = load_config(SimpleConfig, config)
3829
assert loaded_config == SimpleConfig(name="test", value=42)
3930

40-
config["name"] = "not test"
41-
loaded_config = load_config(SimpleConfig, config)
42-
assert loaded_config == SimpleConfig(name="not test", value=42)
43-
44-
45-
def test_load_config_marshmallow_dataclass() -> None:
46-
"""Test that load_config loads a configuration into a configuration class."""
47-
config: dict[str, Any] = {"name": "test", "value": 42}
48-
loaded_config = load_config(MmSimpleConfig, config)
49-
assert loaded_config == MmSimpleConfig(name="test", value=42)
50-
5131
config["name"] = "not test"
5232
with pytest.raises(marshmallow.ValidationError):
53-
_ = load_config(MmSimpleConfig, config)
33+
_ = load_config(SimpleConfig, config)
5434

5535

5636
def test_load_config_load_None() -> None:
5737
"""Test that load_config raises ValidationError if the configuration is None."""
5838
config: dict[str, Any] = {}
5939
with pytest.raises(marshmallow.ValidationError):
60-
_ = load_config(MmSimpleConfig, config.get("loggers", None))
40+
_ = load_config(SimpleConfig, config.get("loggers", None))
41+
42+
43+
def test_load_config_with_base_schema() -> None:
44+
"""Test that load_config loads a configuration using a base schema."""
45+
46+
class _MyBaseSchema(marshmallow.Schema):
47+
"""A base schema for testing."""
48+
49+
class Meta:
50+
"""Meta options for the schema."""
51+
52+
unknown = marshmallow.EXCLUDE
53+
54+
config: dict[str, Any] = {"name": "test", "value": 42, "extra": "extra"}
55+
56+
loaded_config = load_config(SimpleConfig, config, base_schema=_MyBaseSchema)
57+
assert loaded_config == SimpleConfig(name="test", value=42)
58+
59+
with pytest.raises(marshmallow.ValidationError):
60+
_ = load_config(SimpleConfig, config)
6161

6262

6363
def test_load_config_type_hints(mocker: MockerFixture) -> None:
@@ -70,7 +70,7 @@ def test_load_config_type_hints(mocker: MockerFixture) -> None:
7070
config: dict[str, Any] = {}
7171

7272
# We add the type hint to test that the return type (hint) is correct
73-
_: MmSimpleConfig = load_config(MmSimpleConfig, config, marshmallow_arg=1)
73+
_: SimpleConfig = load_config(SimpleConfig, config, marshmallow_arg=1)
7474
mock_class_schema.return_value.load.assert_called_once_with(
7575
config, marshmallow_arg=1
7676
)

0 commit comments

Comments
 (0)