Skip to content

Commit

Permalink
Allow subclassing in hydra config classes (#342)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 authored Mar 23, 2024
2 parents 641af83 + 8402529 commit 57e70d7
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 94 deletions.
92 changes: 60 additions & 32 deletions ranzen/hydra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
from collections.abc import Iterator, MutableMapping, Sequence
from contextlib import contextmanager
import dataclasses
from dataclasses import MISSING, Field, asdict, is_dataclass
from dataclasses import MISSING, asdict, is_dataclass
from enum import Enum
import shlex
from typing import Any, Final, Union, cast
from typing import Any, Final, cast, get_args, get_type_hints
from typing_extensions import deprecated

import attrs
from attrs import NOTHING, Attribute
from hydra.core.config_store import ConfigStore
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
Expand Down Expand Up @@ -174,39 +172,69 @@ class Config:
register_hydra_config(Config, groups)
"""
assert isinstance(main_cls, type), "`main_cls` has to be a type."
configs: Union[tuple[Attribute, ...], tuple[Field, ...]]
is_dc = is_dataclass(main_cls)
if is_dc:
configs = dataclasses.fields(main_cls)
elif attrs.has(main_cls):
configs = attrs.fields(main_cls)
else:
if not is_dataclass(main_cls):
raise ValueError(f"The config class {main_cls.__name__} should be a dataclass.")
entries = dataclasses.fields(main_cls)
try:
types = get_type_hints(main_cls)
except NameError as exc:
raise ValueError(
f"The given class {main_cls.__name__} is neither a dataclass nor an attrs class."
)
ABSENT = MISSING if is_dc else NOTHING

for config in configs:
if config.type == Any or (isinstance(typ := config.type, str) and typ == "Any"):
if config.name not in groups:
raise ValueError(f"{IF} type Any, {NEED} variants: `{config.name}`")
if config.default is not ABSENT or (
isinstance(config, Field) and config.default_factory is not ABSENT
):
raise ValueError(f"{IF} type Any, {NEED} no default value: `{config.name}`")
f"Can't resolve type hints from the config class: `{main_cls.__name__}`."
) from exc

for entry in entries:
typ = types[entry.name]
if typ == Any:
if (group := groups.get(entry.name)) is not None:
for var_name, var_class in group.items():
if not is_dataclass(var_class):
raise ValueError(
f"All variants should be dataclasses: type `{var_class.__name__}` "
f"of variant `{entry.name}={var_name}` is not a dataclass."
)
else:
raise ValueError(f"{IF} type `Any`, {NEED} variants: `{entry.name}`")
if entry.default is not MISSING or entry.default_factory is not MISSING:
raise ValueError(f"{IF} type `Any`, {NEED} no default value: `{entry.name}`")
else:
if config.name in groups:
raise ValueError(f"{IF} a real type, {NEED} no variants: `{config.name}`")
if config.default is ABSENT and not (
isinstance(config, Field) and config.default_factory is not ABSENT
):
raise ValueError(f"{IF} a real type, {NEED} a default value: `{config.name}`")
if is_dataclass(typ):
if entry.default is MISSING and entry.default_factory is MISSING:
if (group := groups.get(entry.name)) is not None:
for var_name, var_class in group.items():
if not issubclass(var_class, typ):
raise ValueError(
f"All variants should be subclasses of their entry's type: type"
f" `{var_class.__name__}` of variant `{entry.name}={var_name}` "
f"is not a subclass of `{typ.__name__}`."
)
else:
raise ValueError(
f"{IF} a dataclass type, "
f"{NEED} a default value or registered variants: `{entry.name}`. "
"You can specify a default value with `field(default_factory=...)`."
)
else:
if entry.name in groups:
raise ValueError(
f"Can't have both a default value and variants: `{entry.name}`."
)
elif entry.name in groups:
raise ValueError(
f"Entry `{entry.name}` has registered variants, but its type, "
f"`{entry.type.__name__}`, is not a dataclass."
)

cs = ConfigStore.instance()
cs.store(node=main_cls, name=schema_name)
for group, entries in groups.items():
for name, node in entries.items():
for var_name, var_type in entries.items():
if (bases := getattr(var_type, "__orig_bases__", None)) is not None:
if len(bases) > 0 and len(get_args(bases[0])) > 0:
raise ValueError(
f"Can't register a dataclass with generic base class: `{var_type.__name__}`"
f" with base class `{bases[0].__name__}`."
)
try:
cs.store(node=node, name=name, group=group)
cs.store(node=var_type, name=var_name, group=group)
except Exception as exc:
raise RuntimeError(f"{main_cls=}, {node=}, {name=}, {group=}") from exc
raise RuntimeError(f"{main_cls=}, {var_type=}, {var_name=}, {group=}") from exc
145 changes: 83 additions & 62 deletions tests/hydra_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Any
from typing import Any, Generic, TypeVar

import attrs
from attrs import define
from omegaconf import MISSING, DictConfig, MissingMandatoryValue, OmegaConf
import pytest

from ranzen.hydra import prepare_for_logging, register_hydra_config


def test_dataclass_no_default() -> None:
def test_config_no_default() -> None:
"""This isn't so much wrong as just clumsy."""

@dataclass
Expand All @@ -28,11 +26,18 @@ class Config:
register_hydra_config(Config, options)

options = {"dm": {"base": DataModule}}
register_hydra_config(Config, options)

@dataclass
class UnrelatedClass:
root: Path

options = {"dm": {"base": UnrelatedClass}}
with pytest.raises(ValueError):
register_hydra_config(Config, options)


def test_dataclass_any() -> None:
def test_config_any() -> None:
@dataclass
class DataModule:
root: Path
Expand All @@ -50,124 +55,140 @@ class Config:
options = {"dm": {"base": DataModule}}
register_hydra_config(Config, options)

class NotDC:
x: int

options = {"dm": {"base": NotDC}}
with pytest.raises(ValueError):
register_hydra_config(Config, options)

def test_dataclass_any_with_default() -> None:
"""An Any field with default is completely out."""

def test_config_any_string() -> None:
@dataclass
class Model:
layers: int = 1
class DataModule:
root: Path

@dataclass
class Config:
model: Any = dataclasses.field(default_factory=Model)
dm: "Any"

# we're assuming that the only reason you want to use Any is that
# you want to use variants
options = {}
with pytest.raises(ValueError):
register_hydra_config(Config, options)

options = {"model": {"base": Model}}
with pytest.raises(ValueError):
register_hydra_config(Config, options)
options = {"dm": {"base": DataModule}}
register_hydra_config(Config, options)


def test_dataclass_with_default() -> None:
"""A normal field with a default should not have variants."""
def test_config_base_class() -> None:
@dataclass
class DataModule:
root: Path

@dataclass
class Model:
layers: int = 1
class CMnist(DataModule):
colorize: bool

@dataclass
class CelebA(DataModule):
target: str

@dataclass
class Config:
model: Model = dataclasses.field(default_factory=Model)
dm: DataModule

options = {}
with pytest.raises(ValueError):
register_hydra_config(Config, options)

options = {"dm": {"cmnist": CMnist, "celeba": CelebA}}
register_hydra_config(Config, options)

options = {"model": {"base": Model}}
@dataclass
class NotSubclass:
root: Path

options = {"dm": {"base": NotSubclass}}
with pytest.raises(ValueError):
register_hydra_config(Config, options)


def test_attrs_no_default() -> None:
"""This isn't so much wrong as just clumsy."""
def test_config_any_with_default() -> None:
"""An Any field with default is completely out."""

@define
class DataModule:
root: Path
@dataclass
class Model:
layers: int = 1

@define
@dataclass
class Config:
dm: DataModule
model: Any = dataclasses.field(default_factory=Model)

options = {}
with pytest.raises(ValueError):
register_hydra_config(Config, options)

options = {"dm": {"base": DataModule}}
options = {"model": {"base": Model}}
with pytest.raises(ValueError):
register_hydra_config(Config, options)


def test_attrs_any() -> None:
@define
class DataModule:
root: Path
def test_config_with_default() -> None:
"""A normal field with a default should not have variants."""

@dataclass
class Model:
layers: int = 1

@define
@dataclass
class Config:
dm: Any
model: Model = dataclasses.field(default_factory=Model)

# we're assuming that the only reason you want to use Any is that
# you want to use variants
options = {}
register_hydra_config(Config, options)

options = {"model": {"base": Model}}
with pytest.raises(ValueError):
register_hydra_config(Config, options)

options = {"dm": {"base": DataModule}}
register_hydra_config(Config, options)

@dataclass
class _GlobalModel:
layers: int = 1

def test_attrs_any_with_default() -> None:
"""An Any field with default is completely out."""

@define
class Model:
layers: int = 1
def test_config_with_default_string() -> None:
"""Need to use a global class here because otherwise the type annotations can't be resolved."""

@define
@dataclass
class Config:
# it should of course be `factory` and not `default` here,
# but OmegaConf is stupid as always
model: Any = attrs.field(default=Model)
model: "_GlobalModel" = dataclasses.field(default_factory=_GlobalModel)

options = {}
with pytest.raises(ValueError):
register_hydra_config(Config, options)
register_hydra_config(Config, options)

options = {"model": {"base": Model}}
options = {"model": {"base": _GlobalModel}}
with pytest.raises(ValueError):
register_hydra_config(Config, options)


def test_attrs_with_default() -> None:
"""A normal field with a default should not have variants."""
T = TypeVar("T")

@define
class Model:
layers: int = 1

@define
class Config:
# it should of course be `factory` and not `default` here,
# but OmegaConf is stupid as always
model: Model = attrs.field(default=Model)
def test_config_generic() -> None:
class Base(Generic[T]): ...

options = {}
register_hydra_config(Config, options)
@dataclass
class DataModule(Base):
root: Path

options = {"model": {"base": Model}}
@dataclass
class Config:
dm: DataModule

options = {"dm": {"asdf": DataModule}}
with pytest.raises(ValueError):
register_hydra_config(Config, options)

Expand Down

0 comments on commit 57e70d7

Please sign in to comment.