Skip to content

Feat: let nw.Enum accept categories, map pandas ordered categorical to Enum (only in main namespace, not stable.v1) #2192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 46 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
3581985
enh nw.Enum to accept categories
camriddell Dec 11, 2024
69e9d88
add tests for nw.Enum(categories)
camriddell Mar 11, 2025
9e63e68
fix enum type checking for Enum dtype
camriddell Mar 11, 2025
3465fbf
Merge branch 'main' of github.com:narwhals-dev/narwhals into enh-enum…
camriddell Mar 11, 2025
a570751
fix enum doctest
camriddell Mar 12, 2025
6f22771
positive check for enum instance for pyright
camriddell Mar 12, 2025
04b2ee6
Merge branch 'main' into enh-enum-creation
camriddell Mar 12, 2025
1393ac7
Merge branch 'main' of github.com:narwhals-dev/narwhals into enh-enum…
camriddell Mar 27, 2025
6e85c1a
enum use implementation specific CategoricalDtype
camriddell Mar 31, 2025
70d5c67
enum preserve v1 behavior
camriddell Mar 31, 2025
e4f2d87
preserve v1 polars enum conversion
camriddell Mar 31, 2025
86eb3a2
add Enum support to dask
camriddell Apr 4, 2025
3e1db9e
modin to xfail on Enum dtype
camriddell Apr 4, 2025
5a996b1
Merge remote-tracking branch 'upstream/main' into enh-enum-creation
camriddell Apr 4, 2025
1ce56ca
Enum support outside of V1
camriddell Apr 4, 2025
a8f6e42
Fix v1 enum missing argument teset
camriddell Apr 4, 2025
342ea83
fix enum error match for py38
camriddell Apr 4, 2025
5dd7c72
add pragma: no cover to v1.Enum from aligned with DType class
camriddell Apr 4, 2025
1bf98d0
parametrize api versions for dtypes tests
camriddell Apr 4, 2025
da1a455
decouple narwhals versioned dtypes
camriddell Apr 4, 2025
a79554a
Merge branch 'main' of github.com:narwhals-dev/narwhals into enh-enum…
camriddell Apr 8, 2025
0712557
enum types to pass tests
camriddell Apr 9, 2025
4fed6d1
Merge branch 'main' of github.com:narwhals-dev/narwhals into enh-enum…
camriddell Apr 9, 2025
f06df25
fix: pyright warning on incorrect line for Enum
camriddell Apr 9, 2025
a0a8c39
test(typing): Fix redefinition
dangotbanned Apr 9, 2025
21d03dd
test: Add `test_enum_v1_is_enum_unstable`
dangotbanned Apr 9, 2025
36fd178
Merge branch 'main' into enh-enum-creation
dangotbanned Apr 9, 2025
105e394
fix: Get more of `__eq__` working
dangotbanned Apr 9, 2025
0d5b5cc
fix: Correct subtyping for `__eq__`
dangotbanned Apr 9, 2025
02b60e0
chore: Coverage
dangotbanned Apr 9, 2025
07d4a25
Merge remote-tracking branch 'upstream/main' into enh-enum-creation
MarcoGorelli Apr 13, 2025
5e53281
list -> tuple
MarcoGorelli Apr 13, 2025
0018a5b
use str in annotation but at runtime dont raise for hashable (pandas..)
MarcoGorelli Apr 13, 2025
99ad2b5
improve error message
MarcoGorelli Apr 13, 2025
000013c
move another test to v1test
MarcoGorelli Apr 13, 2025
6aed0db
update backcompat doc
MarcoGorelli Apr 13, 2025
3357c55
coverage
MarcoGorelli Apr 13, 2025
3b40237
test: fix doctest
dangotbanned Apr 13, 2025
15cac8e
Merge remote-tracking branch 'upstream/main' into pr/camriddell/2192
dangotbanned Apr 14, 2025
84ea789
test: Add `test_enum_from_series`
dangotbanned Apr 14, 2025
833bc3c
enum add non-string and duplicate checking
camriddell Apr 16, 2025
7aca1b8
Merge branch 'main' of github.com:narwhals-dev/narwhals into enh-enum…
camriddell Apr 16, 2025
eb0b328
Merge branch 'main' into enh-enum-creation
dangotbanned Apr 16, 2025
d2504a4
Update docs/backcompat.md
camriddell Apr 16, 2025
def83a3
fix(typing): Resolve some `Enum.categories` issues
dangotbanned Apr 17, 2025
f89e331
refactor: Simplify `__init__`, raise earlier
dangotbanned Apr 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/backcompat.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ Here are exceptions to our backwards compatibility policy:
need to rethink Narwhals. However, we expect such radical changes to be exceedingly unlikely.
- We may consider making some type hints more precise.
- Anything labelled "unstable".
- We may sometimes need to bump the minimum versions of supported backends.

In general, decision are driven by use-cases, and we conduct a search of public GitHub repositories
before making any change.
Expand All @@ -113,6 +114,11 @@ before making any change.

The following are differences between the main Narwhals namespace and `narwhals.stable.v1`:

- Since Narwhals 1.35:

- pandas' ordered categoricals get mapped to `nw.Enum` instead of `nw.Categorical`.
- `nw.Enum` must be provided `categories` at instantiation.

- Since Narwhals 1.29.0, `LazyFrame.gather_every` has been deprecated from the main namespace.

- Since Narwhals 1.24.1, an empty or all-null object-dtype pandas Series is inferred to
Expand Down
16 changes: 15 additions & 1 deletion narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_pyarrow
from narwhals.utils import Implementation
from narwhals.utils import Version
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass
from narwhals.utils import parse_version
Expand All @@ -24,7 +25,6 @@
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals.dtypes import DType
from narwhals.utils import Version


def maybe_evaluate_expr(df: DaskLazyFrame, obj: DaskExpr | object) -> dx.Series | object:
Expand Down Expand Up @@ -125,6 +125,20 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> An
return "object" # pragma: no cover
if isinstance_or_issubclass(dtype, dtypes.Boolean):
return "bool"
if isinstance_or_issubclass(dtype, dtypes.Enum):
if version is Version.V1:
msg = "Converting to Enum is not supported in narwhals.stable.v1"
raise NotImplementedError(msg)
if isinstance(dtype, dtypes.Enum):
import pandas as pd

# NOTE: `pandas-stubs.core.dtypes.dtypes.CategoricalDtype.categories` is too narrow
# Should be one of the `ListLike*` types
# https://github.com/pandas-dev/pandas-stubs/blob/8434bde95460b996323cc8c0fea7b0a8bb00ea26/pandas-stubs/_typing.pyi#L497-L505
return pd.CategoricalDtype(dtype.categories, ordered=True) # pyright: ignore[reportArgumentType]
Comment on lines +132 to +138
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@camriddell ignore this, I only meant to add as a comment - not the review 🫣

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MarcoGorelli gentle nudge on this, in case it was missed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hey - yeah, probably, the pandas stubs definitely don't get all the attention they probably deserve

msg = "Can not cast / initialize Enum without categories present"
raise ValueError(msg)

if isinstance_or_issubclass(dtype, dtypes.Categorical):
return "category"
if isinstance_or_issubclass(dtype, dtypes.Datetime):
Expand Down
25 changes: 20 additions & 5 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ def rename(


@functools.lru_cache(maxsize=16)
def non_object_native_to_narwhals_dtype(dtype: str, version: Version) -> DType:
def non_object_native_to_narwhals_dtype(native_dtype: Any, version: Version) -> DType:
dtype = str(native_dtype)

Comment on lines +214 to +216
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change seems to have been there since the first commit (3581985), but doesn't seem to be documented?

It looks like this part is related

https://github.com/camriddell/narwhals/blob/d2504a40efc606d8e626a5b9049ff8054417d64c/narwhals/_pandas_like/utils.py#L320-L321

Which would mean we do the str(...) call twice now. Just an observation, not sure if there is a cost to that

https://github.com/camriddell/narwhals/blob/d2504a40efc606d8e626a5b9049ff8054417d64c/narwhals/_pandas_like/utils.py#L306-L309

Are all non-object pandas data types guaranteed to be immutable?
I think str was used because it is hashable, so is safe to use in functools.lru_cache

dtypes = import_dtypes_module(version)
if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}:
return dtypes.Int64()
Expand Down Expand Up @@ -249,7 +251,13 @@ def non_object_native_to_narwhals_dtype(dtype: str, version: Version) -> DType:
return dtypes.String()
if dtype in {"bool", "boolean", "boolean[pyarrow]", "bool[pyarrow]"}:
return dtypes.Boolean()
if dtype == "category" or dtype.startswith("dictionary<"):
if dtype.startswith("dictionary<"):
return dtypes.Categorical()
if dtype == "category":
if version is Version.V1:
return dtypes.Categorical()
if native_dtype.ordered:
return dtypes.Enum(native_dtype.categories)
return dtypes.Categorical()
if (match_ := PATTERN_PD_DATETIME.match(dtype)) or (
match_ := PATTERN_PA_DATETIME.match(dtype)
Expand Down Expand Up @@ -310,7 +318,7 @@ def native_to_narwhals_dtype(
return arrow_native_to_narwhals_dtype(native_dtype.to_arrow(), version)
return arrow_native_to_narwhals_dtype(native_dtype.pyarrow_dtype, version)
if str_dtype != "object":
return non_object_native_to_narwhals_dtype(str_dtype, version)
return non_object_native_to_narwhals_dtype(native_dtype, version)
elif implementation is Implementation.DASK:
# Per conversations with their maintainers, they don't support arbitrary
# objects, so we can just return String.
Expand Down Expand Up @@ -471,8 +479,15 @@ def narwhals_to_native_dtype( # noqa: PLR0915
msg = "PyArrow>=11.0.0 is required for `Date` dtype."
return "date32[pyarrow]"
if isinstance_or_issubclass(dtype, dtypes.Enum):
msg = "Converting to Enum is not (yet) supported"
raise NotImplementedError(msg)
if version is Version.V1:
msg = "Converting to Enum is not supported in narwhals.stable.v1"
raise NotImplementedError(msg)
if isinstance(dtype, dtypes.Enum):
ns = implementation.to_native_namespace()
return ns.CategoricalDtype(dtype.categories, ordered=True)
msg = "Can not cast / initialize Enum without categories present"
raise ValueError(msg)

if isinstance_or_issubclass(
dtype, (dtypes.Struct, dtypes.Array, dtypes.List, dtypes.Time, dtypes.Binary)
):
Expand Down
19 changes: 13 additions & 6 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from narwhals.exceptions import InvalidOperationError
from narwhals.exceptions import NarwhalsError
from narwhals.exceptions import ShapeError
from narwhals.utils import Version
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass

Expand All @@ -26,7 +27,6 @@
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.series import PolarsSeries
from narwhals.dtypes import DType
from narwhals.utils import Version

T = TypeVar("T")

Expand Down Expand Up @@ -110,8 +110,10 @@ def native_to_narwhals_dtype(
return dtypes.Object()
if dtype == pl.Categorical:
return dtypes.Categorical()
if dtype == pl.Enum:
return dtypes.Enum()
if isinstance_or_issubclass(dtype, pl.Enum):
if version is Version.V1:
return dtypes.Enum() # type: ignore[call-arg]
return dtypes.Enum(dtype.categories)
if dtype == pl.Date:
return dtypes.Date()
if isinstance_or_issubclass(dtype, pl.Datetime):
Expand Down Expand Up @@ -185,9 +187,14 @@ def narwhals_to_native_dtype(
return pl.Object()
if dtype == dtypes.Categorical:
return pl.Categorical()
if dtype == dtypes.Enum:
msg = "Converting to Enum is not (yet) supported"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Enum):
if version is Version.V1:
msg = "Converting to Enum is not supported in narwhals.stable.v1"
raise NotImplementedError(msg)
if isinstance(dtype, dtypes.Enum):
return pl.Enum(dtype.categories)
msg = "Can not cast / initialize Enum without categories present"
raise ValueError(msg)
if dtype == dtypes.Date:
return pl.Date()
if dtype == dtypes.Time:
Expand Down
33 changes: 32 additions & 1 deletion narwhals/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
from collections import OrderedDict
from datetime import timezone
from itertools import starmap
Expand All @@ -9,6 +10,7 @@
from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
from typing import Iterable
from typing import Iterator
from typing import Sequence

Expand Down Expand Up @@ -443,9 +445,38 @@ class Enum(DType):
>>> data = ["beluga", "narwhal", "orca"]
>>> s_native = pl.Series(data, dtype=pl.Enum(data))
>>> nw.from_native(s_native, series_only=True).dtype
Enum
Enum(categories=['beluga', 'narwhal', 'orca'])
"""
Comment on lines 445 to 449
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about adapting this to be more like the example(s) in pl.Enum?

I believe this style was originally used because we didn't have a constructor for nw.Enum.
But, a great thing about this PR is we no longer have that limitation:

Suggested change
>>> data = ["beluga", "narwhal", "orca"]
>>> s_native = pl.Series(data, dtype=pl.Enum(data))
>>> nw.from_native(s_native, series_only=True).dtype
Enum
Enum(categories=['beluga', 'narwhal', 'orca'])
"""
>>> nw.Enum(["beluga", "narwhal", "orca"])
Enum(categories=['beluga', 'narwhal', 'orca'])
"""


categories: Sequence[str]

def __init__(self, categories: Iterable[str] | type[enum.Enum]) -> None:
if isinstance(categories, type) and issubclass(categories, enum.Enum):
categories = (v.value for v in categories)
sequence: tuple[str, ...] = tuple(categories)
seen: set[str] = set()
for cat in sequence:
if cat in seen:
msg = f"{type(self).__name__} categories must be unique; found duplicate {cat!r}"
raise ValueError(msg)
if not isinstance(cat, str):
msg = f"{type(self).__name__} categories must be strings; found data of type {type(cat).__name__!r}"
raise TypeError(msg)
seen.add(cat)
self.categories = sequence

def __eq__(self: Self, other: object) -> bool:
# allow comparing object instances to class
if type(other) is type:
return other is Enum
return isinstance(other, type(self)) and self.categories == other.categories

def __hash__(self: Self) -> int: # pragma: no cover
return hash((self.__class__, tuple(self.categories)))

def __repr__(self: Self) -> str: # pragma: no cover
return f"{type(self).__name__}(categories={list(self.categories)!r})"


class Field:
"""Definition of a single field within a `Struct` DataType.
Expand Down
31 changes: 30 additions & 1 deletion narwhals/stable/v1/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from narwhals.dtypes import Decimal
from narwhals.dtypes import DType
from narwhals.dtypes import Duration as NwDuration
from narwhals.dtypes import Enum
from narwhals.dtypes import Enum as NwEnum
from narwhals.dtypes import Field
from narwhals.dtypes import Float32
from narwhals.dtypes import Float64
Expand Down Expand Up @@ -72,6 +72,35 @@ def __hash__(self: Self) -> int:
return hash(self.__class__)


class Enum(NwEnum):
"""A fixed categorical encoding of a unique set of strings.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dangotbanned the typing here gets a bit wonky as we currently need v1._dtypes... implementations to inherit from what is defined in nw.dtypes. However nw.dtypes.Enum had its call signature changed which should not be propagated down to v1._dtypes.Enum so I implemented this functionality to skip a level of inheritance on its defined methods.

If feels like I may have something backwards here though? Would love to hear you thoughts.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the ping @camriddell, will take a look in the morning

Copy link
Member

@dangotbanned dangotbanned Apr 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely a tricky one, but I have a few ideas I'm gonna try out today.

I did a search of existing usage and looked at what we allow in tests.

I think our main concern should be preserving the behavior of isinstance(..., nw.Enum).
The cases with dtype == nw.Enum are simple to handle without subclassing.

I haven't tried out customizing-instance-and-subclass-checks yet - but have thought about it for another DType issue (#2050 (comment))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wanting to avoid subclassing, since this is a pretty clear Liskov substitution principle violation (not your fault, just how v1 inheriting from main works)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it's ok to allow Enum to accept categories in v1 as well, so long as == nw.Enum keeps working - can you check what we do for Datetime and Duration? I think something similar might work?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it's ok to allow Enum to accept categories in v1 as well, so long as == nw.Enum keeps working - can you check what we do for Datetime and Duration? I think something similar might work?

Lol @MarcoGorelli the timing on this πŸ˜… (105e394)


Polars has an Enum data type, while pandas and PyArrow do not.

Examples:
>>> import polars as pl
>>> import narwhals.stable.v1 as nw
>>> data = ["beluga", "narwhal", "orca"]
>>> s_native = pl.Series(data, dtype=pl.Enum(data))
>>> nw.from_native(s_native, series_only=True).dtype
Enum
"""

def __init__(self: Self) -> None:
super(NwEnum, self).__init__()

def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override]
if type(other) is type:
return other in {type(self), NwEnum}
return isinstance(other, type(self))

def __hash__(self: Self) -> int: # pragma: no cover
return super(NwEnum, self).__hash__()

def __repr__(self: Self) -> str: # pragma: no cover
return super(NwEnum, self).__repr__()


__all__ = [
"Array",
"Binary",
Expand Down
51 changes: 51 additions & 0 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import enum
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Literal

import numpy as np
Expand Down Expand Up @@ -400,3 +403,51 @@ def test_cast_decimal_to_native() -> None:
.with_columns(a=nw.col("a").cast(nw.Decimal()))
.to_native()
)


class FakeEnum(enum.Enum):
A = "A"
B = "B"


@pytest.mark.parametrize(
"categories", [["a", "b"], [np.str_("a"), np.str_("b")], FakeEnum]
)
def test_enum_valid(categories: Iterable[Any] | type[enum.Enum]) -> None:
dtype = nw.Enum(categories)
assert dtype == nw.Enum
assert len(dtype.categories) == len([*categories])


@pytest.mark.parametrize(
("categories", "exception", "match"),
[
(["a", None], TypeError, "categories must be strings"),
(["a", float("nan")], TypeError, "categories must be strings"),
([object()], TypeError, "categories must be strings"),
(enum.Enum("FakeEnum", "a b"), TypeError, "categories must be strings"),
(["a", "a"], ValueError, "categories must be unique"),
],
)
def test_enum_errors(
categories: Iterable[Any], exception: type[Exception], match: str
) -> None:
with pytest.raises(exception, match=match):
nw.Enum(categories)


def test_enum_from_series() -> None:
pytest.importorskip("polars")
import polars as pl

elements = "a", "d", "e", "b", "c"
categories = pl.Series(elements)
categories_nw = nw.from_native(categories, series_only=True)
assert nw.Enum(categories_nw).categories == elements
assert nw.Enum(categories).categories == elements


def test_enum_categories_immutable() -> None:
dtype = nw.Enum(["a", "b"])
with pytest.raises(TypeError, match="does not support item assignment"):
dtype.categories[0] = "c" # type: ignore[index]
13 changes: 13 additions & 0 deletions tests/frame/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,19 @@ def test_dtypes() -> None:
assert df_from_pd.schema == df_from_pd.collect_schema() == expected
assert {name: df_from_pd[name].dtype for name in df_from_pd.columns} == expected

df_from_pd = nw.from_native(df_pl.to_pandas(), eager_only=True)

pure_pd_expected = {
**expected,
"n": nw.Datetime,
"s": nw.Object,
"u": nw.Object,
}
assert df_from_pd.schema == df_from_pd.collect_schema() == pure_pd_expected
assert {
name: df_from_pd[name].dtype for name in df_from_pd.columns
} == pure_pd_expected

df_from_pa = nw.from_native(df_pl.to_arrow(), eager_only=True)

assert df_from_pa.schema == df_from_pa.collect_schema() == expected
Expand Down
37 changes: 16 additions & 21 deletions tests/series_only/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tests.utils import PANDAS_VERSION

if TYPE_CHECKING:
from tests.utils import Constructor
from tests.utils import ConstructorEager


Expand Down Expand Up @@ -113,29 +114,23 @@ def test_unknown_to_int() -> None:
assert nw.from_native(df).select(nw.col("a").cast(nw.Int64)).schema == {"a": nw.Int64}


def test_cast_to_enum_polars() -> None:
pytest.importorskip("polars")
import polars as pl

# we don't yet support metadata in dtypes, so for now disallow this
# seems like a very niche use case anyway, and allowing it later wouldn't be
# backwards-incompatible
df_pl = pl.DataFrame({"a": ["a", "b"]}, schema={"a": pl.Categorical})
with pytest.raises(
NotImplementedError, match=r"Converting to Enum is not \(yet\) supported"
def test_cast_to_enum_vmain(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
# Backends that do not (yet) support Enum dtype
if any(
backend in str(constructor)
for backend in ["pyarrow_table", "duckdb", "sqlframe", "pyspark", "modin"]
):
nw.from_native(df_pl).select(nw.col("a").cast(nw.Enum))
request.applymarker(pytest.mark.xfail)

df_nw = nw.from_native(constructor({"a": ["a", "b"]}))
col_a = nw.col("a")

def test_cast_to_enum_pandas() -> None:
pytest.importorskip("pandas")
import pandas as pd

# we don't yet support metadata in dtypes, so for now disallow this
# seems like a very niche use case anyway, and allowing it later wouldn't be
# backwards-incompatible
df_pd = pd.DataFrame({"a": ["a", "b"]}, dtype="category")
with pytest.raises(
NotImplementedError, match=r"Converting to Enum is not \(yet\) supported"
ValueError, match="Can not cast / initialize Enum without categories present"
):
nw.from_native(df_pd).select(nw.col("a").cast(nw.Enum))
df_nw.select(col_a.cast(nw.Enum))

df_nw = df_nw.select(col_a.cast(nw.Enum(["a", "b"])))
assert df_nw.collect_schema() == {"a": nw.Enum(["a", "b"])}
Loading
Loading