diff --git a/curated_transformers/layers/activations.py b/curated_transformers/layers/activations.py
index d3503133..4867a77e 100644
--- a/curated_transformers/layers/activations.py
+++ b/curated_transformers/layers/activations.py
@@ -1,5 +1,5 @@
 import math
-from enum import Enum, EnumMeta
+from enum import Enum
 from typing import Type
 
 import torch
@@ -7,46 +7,7 @@
 from torch.nn import Module
 
 
-class _ActivationMeta(EnumMeta):
-    """
-    ``Enum`` metaclass to override the class ``__call__`` method with a more
-    fine-grained exception for unknown activation functions.
-    """
-
-    def __call__(
-        cls,
-        value,
-        names=None,
-        *,
-        module=None,
-        qualname=None,
-        type=None,
-        start=1,
-    ):
-        # Wrap superclass __call__ to give a nicer error message when
-        # an unknown activation is used.
-        if names is None:
-            try:
-                return EnumMeta.__call__(
-                    cls,
-                    value,
-                    names,
-                    module=module,
-                    qualname=qualname,
-                    type=type,
-                    start=start,
-                )
-            except ValueError:
-                supported_activations = ", ".join(sorted(v.value for v in cls))
-                raise ValueError(
-                    f"Invalid activation function `{value}`. "
-                    f"Supported functions: {supported_activations}"
-                )
-        else:
-            return EnumMeta.__call__(cls, value, names, module, qualname, type, start)
-
-
-class Activation(Enum, metaclass=_ActivationMeta):
+class Activation(Enum):
     """
     Activation functions.
 
@@ -71,6 +32,14 @@ class Activation(Enum, metaclass=_ActivationMeta):
     #: Sigmoid Linear Unit (`Hendrycks et al., 2016`_).
     SiLU = "silu"
 
+    @classmethod
+    def _missing_(cls, value):
+        supported_activations = ", ".join(sorted(v.value for v in cls))
+        raise ValueError(
+            f"Invalid activation function `{value}`. "
+            f"Supported functions: {supported_activations}"
+        )
+
     @property
     def module(self) -> Type[torch.nn.Module]:
         """
diff --git a/curated_transformers/tests/layers/test_embeddings.py b/curated_transformers/tests/layers/test_embeddings.py
index 085db7ad..8e741110 100644
--- a/curated_transformers/tests/layers/test_embeddings.py
+++ b/curated_transformers/tests/layers/test_embeddings.py
@@ -24,7 +24,8 @@ def test_rotary_embeddings_against_hf(device):
 
     X = torch.rand(16, 12, 64, 768, device=device)
     Y = re(X)
-    hf_re_cos, hf_re_sin = hf_re(X, seq_len=X.shape[-2])
+    positions = torch.arange(X.shape[2], device=device).view([1, -1])
+    hf_re_cos, hf_re_sin = hf_re(X, positions)
     Y_hf = hf_re_cos * X + hf_re_sin * rotate_half(X)
     torch_assertclose(Y, Y_hf)
 
diff --git a/curated_transformers/tests/tokenizers/test_hf_hub.py b/curated_transformers/tests/tokenizers/test_hf_hub.py
index 183145e4..b6859dce 100644
--- a/curated_transformers/tests/tokenizers/test_hf_hub.py
+++ b/curated_transformers/tests/tokenizers/test_hf_hub.py
@@ -51,6 +51,7 @@ def test_from_hf_hub_to_cache_legacy():
     )
 
 
+@pytest.mark.xfail(reason="HfFileSystem calls safetensors with incorrect arguments")
 @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
 def test_fsspec(sample_texts):
     # We only test one model, since using fsspec downloads the model
diff --git a/setup.cfg b/setup.cfg
index 52b0d016..654d45a8 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
 [metadata]
-version = 1.3.1
+version = 1.3.2
 description = A PyTorch library of transformer models and components
 url = https://github.com/explosion/curated-transformers
 author = Explosion
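
For readers unfamiliar with the `_missing_` hook used in the activations change: it is a standard `enum.Enum` extension point that the enum machinery invokes when lookup by value fails, so raising from it gives the same friendly error as the removed metaclass without overriding `EnumMeta.__call__`. Below is a minimal standalone sketch of the pattern; the toy `Color` enum is illustrative only and not part of the library.

    from enum import Enum


    class Color(Enum):
        RED = "red"
        GREEN = "green"

        @classmethod
        def _missing_(cls, value):
            # Enum calls _missing_ when no member has this value; raising here
            # replaces the generic "'blue' is not a valid Color" message.
            supported = ", ".join(sorted(v.value for v in cls))
            raise ValueError(f"Invalid color `{value}`. Supported colors: {supported}")


    Color("red")   # -> Color.RED
    Color("blue")  # -> ValueError: Invalid color `blue`. Supported colors: green, red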