81 changes: 78 additions & 3 deletions fast_llm/config.py
@@ -12,7 +12,7 @@
 
 import yaml
 
-from fast_llm.utils import Assert, Tag, compare_nested, get_type_name, header, log
+from fast_llm.utils import Assert, Registry, Tag, compare_nested, get_type_name, header, log
 
 logger = logging.getLogger(__name__)
 
@@ -243,7 +243,9 @@ def _process_config_class(cls: type["Config"]):
     return cls
 
 
-def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]:
+def config_class[
+    T: Config
+](registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]:
     """
     Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
     """
@@ -253,7 +255,7 @@ def wrap(cls):
         if hasattr(cls, "__post_init__"):
             raise TypeError(f"`__post_init__` should not be implemented for `Config` classes")
 
-        wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True))
+        wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True, repr=False))
 
         wrapped_init = cls.__init__
@@ -267,6 +269,14 @@ def __init__(self, **kwargs):
             self.validate()
 
         wrapped.__init__ = __init__
+
+        wrapped._registry = Registry[str, type[wrapped]](wrapped.__name__, {}) if registry else None
+
+        if dynamic_type is not None:
+            for cls_, name in dynamic_type.items():
+                print(cls_, name, wrapped)
+                cls_.register_subclass(name, wrapped)
+
         return wrapped
 
     return wrap
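The decorator now owns registry setup: `registry=True` attaches a fresh `Registry` to the wrapped class, and `dynamic_type` registers the wrapped class under a name in one or more base-class registries. A minimal sketch of the intended pattern (the `LossConfig` classes are illustrative, not part of this PR):

# Sketch only: hypothetical classes showing the two new decorator parameters.
@config_class(registry=True)
class LossConfig(Config):
    pass


@config_class(dynamic_type={LossConfig: "cross_entropy"})
class CrossEntropyLossConfig(LossConfig):
    _abstract = False

With this in place, `LossConfig.from_dict({"type": "cross_entropy"})` should resolve to `CrossEntropyLossConfig` through the `_from_dict` hook added below.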
@@ -305,6 +315,9 @@ class Config(metaclass=ConfigMeta):
     # without them being automatically added to `_explicit_fields`.
     _setting_implicit_default: bool | None = Field(init=False)
 
+    # A registry for all the config classes.
+    _registry: typing.ClassVar[Registry[str, type[typing.Self]] | None] = None
+
     def __setattr__(self, key: str, value: typing.Any) -> None:
         """
         Make the class read-only after validation.
@@ -358,6 +371,17 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T:
         Validate a class and mark it as read-only
         This should not be overridden in derived classes.
         """
+        # Should be handled in `from_dict`, but can fail if instantiating directly.
+        try:
+            expected_class = self.get_subclass(self.type)
+        except KeyError as e:
+            # Delayed instantiation error in `from_dict`.
+            raise ValidationError(*e.args)
+
+        if expected_class is not None:
+            # Should be handled in `from_dict`, but can fail if instantiating directly.
+            Assert.is_(self.__class__, expected_class)
+
         if not self._validated:
             try:
                 self._validate()
@@ -738,6 +762,14 @@ def _from_dict(
         if "__class__" in default:
             del default["__class__"]
 
+        try:
+            actual_cls = cls.get_subclass(default.get("type"))
+            if actual_cls is not None and actual_cls is not cls:
+                return actual_cls._from_dict(default, strict=strict, flat=flat)
+        except KeyError:
+            # Postpone error to validation.
+            pass
+
         # Do not validate yet in case the root class sets cross-dependencies in validation.
         with NoAutoValidate():
             for name, field in cls.fields():
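With this hook, deserialization dispatches on the serialized `type` key before any field is parsed, so callers can stay on the base class. A hedged sketch using the dataset configs updated below (the path value is illustrative):

# Sketch: `from_dict` resolves the registered subclass from the `type` key.
config = GPTSampledDatasetConfig.from_dict({"type": "memmap", "path": "/data/my_dataset"})
assert type(config) is GPTMemmapDatasetConfig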
@@ -864,6 +896,42 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ
         )
         return None
 
+    @classmethod
+    def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None:
+        Assert.custom(issubclass, cls_, cls)
+        if cls._registry is None:
+            raise NotImplementedError(f"Subclass `{cls.__name__}` doesn't have a registry.")
+        if name in cls._registry:
+            old_cls = cls._registry[name]
+            if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__:
+                del cls._registry[name]
+            else:
+                raise KeyError(f"{cls.__name__} class registry already has an entry {name} from class {old_cls.__name__}.")
+        cls._registry[name] = cls_
+
+    @classmethod
+    def get_subclass(cls, name: str | None):
+        # TODO: Make it case-insensitive?
+        if name is None:
+            return None
+        cls_ = None
+        for base_class in cls.__mro__:
+            if issubclass(base_class, Config) and base_class._registry is not None and name in base_class._registry:
+                if cls_ is None:
+                    cls_ = base_class._registry[name]
+                    if not issubclass(cls_, cls):
+                        raise KeyError(f"{cls_.__name__} is not a subclass of {cls.__name__} (from type {name})")
+                elif base_class._registry[name] is not cls_:
+                    # We explicitly prevent ambiguous classes to ensure safe and unambiguous serialization.
+                    # TODO: Only really need to avoid conflict with `Config`'s registry, relax this a bit?
+                    raise KeyError(
+                        f"Ambiguous type `{name}` for base class {cls.__name__}."
+                        f" ({cls_.__name__} vs {base_class._registry[name]})"
+                    )
+        if cls_ is None:
+            raise KeyError(f"Unknown type {name} for base class {cls.__name__}")
+        return cls_
+
     def __init_subclass__(cls):
         """
         We need to postpone validation until the class has been processed by the dataclass wrapper.
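`get_subclass` walks the MRO so a name registered on any registry-carrying base class is found, while ambiguous or incompatible matches raise `KeyError` (converted to `ValidationError` in `validate`). A small sketch of the failure mode, assuming the dataset registry defined further down:

# Sketch: unknown names raise KeyError rather than silently falling back.
try:
    GPTSampledDatasetConfig.get_subclass("no_such_type")
except KeyError as e:
    print(e)  # Unknown type no_such_type for base class GPTSampledDatasetConfig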
@@ -913,6 +981,13 @@ def __init_subclass__(cls):
                 # dataclasses expects an annotation, so we use the one from the base class.
                 cls.__annotations__[name] = base_class_field.type
 
+    # Type for the field. Defined at the end of the class to avoid shadowing the builtin.
+    type: str | None = Field(
+        default=None,
+        desc="The config class name.",
+        hint=FieldHint.feature,
+    )
+
 
 class Configurable[ConfigType: Config]:
     config_class: typing.ClassVar[type[Config]] = Config
92 changes: 15 additions & 77 deletions fast_llm/data/dataset/gpt/config.py
@@ -23,7 +23,7 @@
     SamplingParameters,
 )
 from fast_llm.engine.distributed.config import PhaseType
-from fast_llm.utils import Assert, Registry, normalize_probabilities, padded_cumsum
+from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
 
 if typing.TYPE_CHECKING:
     from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset
@@ -93,61 +93,9 @@ class GPTSamplingData(SamplingData):
     truncate_documents: bool = True
 
 
-@config_class()
+@config_class(registry=True)
 class GPTSampledDatasetConfig(SampledDatasetConfig):
-
-    # TODO: Generalize dynamic types?
-    _registry: typing.ClassVar[Registry[str, type["GPTSampledDatasetConfig"]]] = Registry[
-        str, type["GPTDatasetConfig"]
-    ]("gpt_dataset_class", {})
-    type_: typing.ClassVar[str | None] = None
-    type: str | None = Field(
-        default=None,
-        desc="The type of dataset.",
-        hint=FieldHint.core,
-    )
-
-    def _validate(self) -> None:
-        if self.type is None:
-            self.type = self.type_
-        # Should be handled in `from_dict`, but can fail if instantiating directly.
-        Assert.eq(self.type, self.__class__.type_)
-        super()._validate()
-
-    @classmethod
-    def _from_dict(
-        cls,
-        default: dict[str, typing.Any],
-        strict: bool = True,
-        flat: bool = False,
-    ) -> typing.Self:
-        type_ = default.get("type")
-        if type_ is None:
-            actual_cls = cls
-        else:
-            if type_ not in cls._registry:
-                raise ValueError(
-                    f"Unknown {cls._registry.name} type {type_}." f" Available types: {list(cls._registry.keys())}"
-                )
-            actual_cls = cls._registry[type_]
-        Assert.custom(issubclass, actual_cls, cls)
-        if actual_cls == cls:
-            return super()._from_dict(default, strict=strict, flat=flat)
-        else:
-            return actual_cls._from_dict(default, strict=strict, flat=flat)
-
-    def __init_subclass__(cls) -> None:
-        if cls._abstract and cls.type_ is not None:
-            # Abstract classes should not have a `type_`
-            raise ValueError(f"Abstract class {cls.__name__} has type = {cls.type_}, expected None.")
-        if cls.type_ is not None:
-            if cls.type_ in cls._registry:
-                raise ValueError(
-                    f"Registry {cls._registry.name} already contains type {cls.type_}."
-                    f" Make sure all classes either have a unique or `None` type."
-                )
-            GPTSampledDatasetConfig._registry[cls.type_] = cls
-        super().__init_subclass__()
+    pass
 
 
 @config_class()
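The hand-rolled `_registry`, `_from_dict` override, and `__init_subclass__` hook collapse into `registry=True`; each concrete subclass now declares its name via `dynamic_type` instead of a `type_` class variable. A sketch of the unchanged user-facing behavior (field values are illustrative):

# Sketch: the `type` key still selects the subclass, now via the shared registry.
dataset_config = {
    "type": "slice",
    "dataset": {"type": "memmap", "path": "/data/my_dataset"},
    "begin": 0.0,
    "end": 0.9,
}
config = GPTSampledDatasetConfig.from_dict(dataset_config)  # -> GPTDatasetSliceConfig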
@@ -161,10 +109,9 @@ def build(self) -> "GPTIndexedDataset":
         raise NotImplementedError()
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "random"})
 class GPTRandomDatasetConfig(GPTSamplableDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "random"
     name: str = Field(
         default="dummy",
         desc="The name of the dataset.",
@@ -177,10 +124,9 @@ def build(self) -> "GPTRandomDataset":
         return GPTRandomDataset(self.name)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "memmap"})
 class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "memmap"
     path: pathlib.Path = Field(
         default=None,
         desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.",
@@ -203,10 +149,9 @@ def build(self) -> "GPTMemmapDataset":
         return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"})
 class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "concatenated"
     datasets: list[GPTIndexedDatasetConfig] = FieldUpdate()
 
     def build(self) -> "GPTConcatenatedDataset":
@@ -215,10 +160,9 @@ def build(self) -> "GPTConcatenatedDataset":
         return self._build(GPTConcatenatedDataset)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "slice"})
 class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "slice"
     dataset: GPTIndexedDatasetConfig = FieldUpdate()
 
     def build(self) -> "GPTDatasetSlice":
@@ -227,25 +171,22 @@ def build(self) -> "GPTDatasetSlice":
         return self._build(GPTDatasetSlice)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "sampled"})
 class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig):
     _abstract = False
-    type_: typing.ClassVar[str | None] = "sampled"
     sampling: GPTSamplingConfig = FieldUpdate()
     dataset: GPTSampledDatasetConfig = FieldUpdate()
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "blended"})
 class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "blended"
     datasets: list[GPTSampledDatasetConfig] = FieldUpdate()
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "file"})
 class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "file"
     path: pathlib.Path = Field(
         default=None,
         desc="The path to a dataset config file.",
@@ -281,11 +222,11 @@ def _convert_paths(self, config):
         return config
 
 
-@config_class()
+# Add user-friendly names for the configs.
+@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated_memmap"})
 class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig):
     # TODO v0.3: Remove.
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "concatenated_memmap"
     path: pathlib.Path = Field(
         default=None,
         desc="The path to a dataset directory.",
@@ -388,14 +329,13 @@ class FimConfig(Config):
     )
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "fim"})
 class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig):
     """
     Configuration for FIM.
     """
 
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "fim"
 
     dataset: GPTSampledDatasetConfig = Field(
         default=None,
@@ -456,10 +396,9 @@ class GPTLegacyConfig(Config):
     )
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "legacy"})
 class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "legacy"
 
     def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset:
 
@@ -538,15 +477,14 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset:
         return GPTSampledDatasetConfig.from_dict(dataset_config).build_and_sample(sampling)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"})
 class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig):
     """
     A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout.
     """
 
     # TODO: This belongs to a testing plugin.
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "test_slow"
     sleep: float = Field(
         default=1,
         desc="Sleep time during build, in seconds.",
8 changes: 7 additions & 1 deletion fast_llm/layers/common/config.py
@@ -33,7 +33,7 @@ class NormalizationType(str, enum.Enum):
     rms_norm = "rms_norm"
 
 
-@config_class()
+@config_class(registry=True)
 class NormalizationConfig(BaseModelConfig):
     _abstract = False
 
@@ -107,6 +107,12 @@ def _from_dict(
         return super()._from_dict(default, strict, flat)
 
 
+for name in NormalizationType:
+    # We need this because we are using the reserved field name `type`.
+    # TODO: Implement proper dynamic typing.
+    NormalizationConfig.register_subclass(name.value, NormalizationConfig)
+
+
 class PeftType(str, enum.Enum):
     # TODO : Use a dynamic config type instead.
     none = "none"