Skip to content

Commit 1060d8b

Browse files
authored
Add Polars to mypy environment and fix errors (#20563)
#19072 made a lot of the necessary fixes, but polars was not actually added to the pre-commit mypy environment, so we haven't been checking since then. As a result, some new issues have crept in, and #20272 removed various ignores that are required for polars type safety, but mypy didn't know that without polars available. Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) - Bradley Dice (https://github.com/bdice) - Tom Augspurger (https://github.com/TomAugspurger) - Matthew Murray (https://github.com/Matt711) URL: #20563
1 parent ce4b22f commit 1060d8b

File tree

16 files changed

+132
-47
lines changed

16 files changed

+132
-47
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ repos:
3434
rev: 'v1.13.0'
3535
hooks:
3636
- id: mypy
37-
additional_dependencies: [types-cachetools, pyarrow-stubs, numpy, pytest]
37+
additional_dependencies: [numpy, polars, pyarrow-stubs, pytest, types-cachetools]
3838
args: ["--config-file=pyproject.toml",
3939
"python/cudf/cudf",
4040
"python/pylibcudf/pylibcudf",

python/cudf_polars/cudf_polars/containers/datatype.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,20 @@
66
from __future__ import annotations
77

88
from functools import cache
9-
from typing import TYPE_CHECKING
9+
from typing import TYPE_CHECKING, Literal, cast
1010

1111
from typing_extensions import assert_never
1212

1313
import polars as pl
1414

1515
import pylibcudf as plc
1616

17+
from cudf_polars.utils.versions import POLARS_VERSION_LT_136
18+
1719
if TYPE_CHECKING:
1820
from cudf_polars.typing import (
1921
DataTypeHeader,
22+
PolarsDataType,
2023
)
2124

2225
__all__ = ["DataType"]
@@ -46,7 +49,18 @@ def _dtype_to_header(dtype: pl.DataType) -> DataTypeHeader:
4649
if name in SCALAR_NAME_TO_POLARS_TYPE_MAP:
4750
return {"kind": "scalar", "name": name}
4851
if isinstance(dtype, pl.Decimal):
49-
return {"kind": "decimal", "precision": dtype.precision, "scale": dtype.scale}
52+
# Workaround for incorrect polars stubs where precision is typed as int | None
53+
# Fixed upstream: https://github.com/pola-rs/polars/pull/25227
54+
# TODO: Remove this workaround when polars >= 1.36
55+
if POLARS_VERSION_LT_136:
56+
assert (
57+
dtype.precision is not None
58+
) # Decimal always has precision at runtime
59+
return {
60+
"kind": "decimal",
61+
"precision": cast(int, dtype.precision),
62+
"scale": dtype.scale,
63+
}
5064
if isinstance(dtype, pl.Datetime):
5165
return {
5266
"kind": "datetime",
@@ -56,12 +70,17 @@ def _dtype_to_header(dtype: pl.DataType) -> DataTypeHeader:
5670
if isinstance(dtype, pl.Duration):
5771
return {"kind": "duration", "time_unit": dtype.time_unit}
5872
if isinstance(dtype, pl.List):
59-
return {"kind": "list", "inner": _dtype_to_header(dtype.inner)}
73+
# isinstance narrows dtype to pl.List, but .inner returns DataTypeClass | DataType
74+
return {
75+
"kind": "list",
76+
"inner": _dtype_to_header(cast(pl.DataType, dtype.inner)),
77+
}
6078
if isinstance(dtype, pl.Struct):
79+
# isinstance narrows dtype to pl.Struct, but field.dtype returns DataTypeClass | DataType
6180
return {
6281
"kind": "struct",
6382
"fields": [
64-
{"name": f.name, "dtype": _dtype_to_header(f.dtype)}
83+
{"name": f.name, "dtype": _dtype_to_header(cast(pl.DataType, f.dtype))}
6584
for f in dtype.fields
6685
],
6786
}
@@ -78,9 +97,14 @@ def _dtype_from_header(header: DataTypeHeader) -> pl.DataType:
7897
if header["kind"] == "decimal":
7998
return pl.Decimal(header["precision"], header["scale"])
8099
if header["kind"] == "datetime":
81-
return pl.Datetime(time_unit=header["time_unit"], time_zone=header["time_zone"])
100+
return pl.Datetime(
101+
time_unit=cast(Literal["ns", "us", "ms"], header["time_unit"]),
102+
time_zone=header["time_zone"],
103+
)
82104
if header["kind"] == "duration":
83-
return pl.Duration(time_unit=header["time_unit"])
105+
return pl.Duration(
106+
time_unit=cast(Literal["ns", "us", "ms"], header["time_unit"])
107+
)
84108
if header["kind"] == "list":
85109
return pl.List(_dtype_from_header(header["inner"]))
86110
if header["kind"] == "struct":
@@ -182,9 +206,14 @@ class DataType:
182206
polars_type: pl.datatypes.DataType
183207
plc_type: plc.DataType
184208

185-
def __init__(self, polars_dtype: pl.DataType) -> None:
186-
self.polars_type = polars_dtype
187-
self.plc_type = _from_polars(polars_dtype)
209+
def __init__(self, polars_dtype: PolarsDataType) -> None:
210+
# Convert DataTypeClass to DataType instance if needed
211+
# polars allows both pl.Int64 (class) and pl.Int64() (instance)
212+
if isinstance(polars_dtype, type):
213+
polars_dtype = polars_dtype()
214+
# After conversion, it's guaranteed to be a DataType instance
215+
self.polars_type = cast(pl.DataType, polars_dtype)
216+
self.plc_type = _from_polars(self.polars_type)
188217

189218
def id(self) -> plc.TypeId:
190219
"""The pylibcudf.TypeId of this DataType."""
@@ -193,12 +222,16 @@ def id(self) -> plc.TypeId:
193222
@property
194223
def children(self) -> list[DataType]:
195224
"""The children types of this DataType."""
196-
# these type ignores are needed because the type checker doesn't
197-
# see that these equality checks passing imply a specific type for each child field.
225+
# Type checker doesn't narrow polars_type through plc_type.id() checks
198226
if self.plc_type.id() == plc.TypeId.STRUCT:
199-
return [DataType(field.dtype) for field in self.polars_type.fields]
227+
# field.dtype returns DataTypeClass | DataType, need to cast to DataType
228+
return [
229+
DataType(cast(pl.DataType, field.dtype))
230+
for field in cast(pl.Struct, self.polars_type).fields
231+
]
200232
elif self.plc_type.id() == plc.TypeId.LIST:
201-
return [DataType(self.polars_type.inner)]
233+
# .inner returns DataTypeClass | DataType, need to cast to DataType
234+
return [DataType(cast(pl.DataType, cast(pl.List, self.polars_type).inner))]
202235
return []
203236

204237
def scale(self) -> int:

python/cudf_polars/cudf_polars/dsl/expressions/boolean.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
from enum import IntEnum, auto
1010
from functools import partial, reduce
11-
from typing import TYPE_CHECKING, Any, ClassVar
11+
from typing import TYPE_CHECKING, Any, ClassVar, cast
12+
13+
import polars as pl
1214

1315
import pylibcudf as plc
1416

@@ -350,9 +352,14 @@ def do_evaluate(
350352
needles, haystack = columns
351353
if haystack.obj.type().id() == plc.TypeId.LIST:
352354
# Unwrap values from the list column
355+
# .inner returns DataTypeClass | DataType, need to cast to DataType
353356
haystack = Column(
354357
haystack.obj.children()[1],
355-
dtype=DataType(haystack.dtype.polars_type.inner),
358+
dtype=DataType(
359+
cast(
360+
pl.DataType, cast(pl.List, haystack.dtype.polars_type).inner
361+
)
362+
),
356363
).astype(needles.dtype, stream=df.stream)
357364
if haystack.size:
358365
return Column(

python/cudf_polars/cudf_polars/dsl/expressions/string.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
import re
1111
from datetime import datetime
1212
from enum import IntEnum, auto
13-
from typing import TYPE_CHECKING, Any, ClassVar
13+
from typing import TYPE_CHECKING, Any, ClassVar, cast
1414

15+
import polars as pl
1516
from polars.exceptions import InvalidOperationError
1617
from polars.polars import dtype_str_repr
1718

@@ -37,12 +38,12 @@
3738

3839
def _dtypes_for_json_decode(dtype: DataType) -> JsonDecodeType:
3940
"""Get the dtypes for json decode."""
40-
# the type checker doesn't know that this equality check implies a struct dtype.
41+
# Type checker doesn't narrow polars_type through dtype.id() check
4142
if dtype.id() == plc.TypeId.STRUCT:
4243
return [
4344
(field.name, child.plc_type, _dtypes_for_json_decode(child))
4445
for field, child in zip(
45-
dtype.polars_type.fields,
46+
cast(pl.Struct, dtype.polars_type).fields,
4647
dtype.children,
4748
strict=True,
4849
)

python/cudf_polars/cudf_polars/dsl/expressions/struct.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
from enum import IntEnum, auto
1010
from io import StringIO
11-
from typing import TYPE_CHECKING, Any, ClassVar
11+
from typing import TYPE_CHECKING, Any, ClassVar, cast
12+
13+
import polars as pl
1214

1315
import pylibcudf as plc
1416

@@ -87,13 +89,14 @@ def do_evaluate(
8789
"""Evaluate this expression given a dataframe for context."""
8890
columns = [child.evaluate(df, context=context) for child in self.children]
8991
(column,) = columns
90-
# these type ignores are needed because the type checker doesn't
91-
# know that polars only calls StructFunction with struct types.
92+
# Type checker doesn't know polars only calls StructFunction with struct types
9293
if self.name == StructFunction.Name.FieldByName:
9394
field_index = next(
9495
(
9596
i
96-
for i, field in enumerate(self.children[0].dtype.polars_type.fields)
97+
for i, field in enumerate(
98+
cast(pl.Struct, self.children[0].dtype.polars_type).fields
99+
)
97100
if field.name == self.options[0]
98101
),
99102
None,
@@ -113,7 +116,9 @@ def do_evaluate(
113116
table,
114117
[
115118
(field.name, [])
116-
for field in self.children[0].dtype.polars_type.fields
119+
for field in cast(
120+
pl.Struct, self.children[0].dtype.polars_type
121+
).fields
117122
],
118123
)
119124
options = (

python/cudf_polars/cudf_polars/dsl/ir.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,8 @@ def add_file_paths(
598598
plc.Table(
599599
[
600600
plc.Column.from_arrow(
601-
pl.Series(values=map(str, paths)), stream=df.stream
601+
pl.Series(values=map(str, paths)),
602+
stream=df.stream,
602603
)
603604
]
604605
),

python/cudf_polars/cudf_polars/dsl/to_ast.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from __future__ import annotations
77

88
from functools import partial, reduce, singledispatch
9-
from typing import TYPE_CHECKING, TypeAlias, TypedDict
9+
from typing import TYPE_CHECKING, TypeAlias, TypedDict, cast
10+
11+
import polars as pl
1012

1113
import pylibcudf as plc
1214
from pylibcudf import expressions as plc_expr
@@ -226,10 +228,10 @@ def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression:
226228
if haystack.dtype.id() == plc.TypeId.LIST:
227229
# Because we originally translated pl_expr.Literal with a list scalar
228230
# to a expr.LiteralColumn, so the actual type is in the inner type
229-
#
230-
# the type-ignore is safe because the for plc.TypeID.LIST, we know
231-
# we have a polars.List type, which has an inner attribute.
232-
plc_dtype = DataType(haystack.dtype.polars_type.inner).plc_type
231+
# .inner returns DataTypeClass | DataType, need to cast to DataType
232+
plc_dtype = DataType(
233+
cast(pl.DataType, cast(pl.List, haystack.dtype.polars_type).inner)
234+
).plc_type
233235
else:
234236
plc_dtype = haystack.dtype.plc_type # pragma: no cover
235237
values = (

python/cudf_polars/cudf_polars/experimental/sort.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ def find_sort_splits(
9898
stream=stream,
9999
)
100100
# And convert to list for final processing
101+
# The type ignores are for cross-library boundaries: plc.Column -> pl.Series
102+
# These work at runtime via the Arrow C Data Interface protocol
103+
# TODO: Find a way for pylibcudf types to show they export the Arrow protocol
104+
# (mypy wasn't happy with a custom protocol)
101105
split_first_list = pl.Series(split_first_col).to_list()
102106
split_last_list = pl.Series(split_last_col).to_list()
103107
split_part_id_list = pl.Series(split_part_id).to_list()

python/cudf_polars/cudf_polars/testing/asserts.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def assert_gpu_result_equal(
112112

113113
# These keywords are correct, but mypy doesn't see that.
114114
# the 'misc' is for 'error: Keywords must be strings'
115-
expect = lazydf.collect(**final_polars_collect_kwargs)
116-
got = lazydf.collect(**final_cudf_collect_kwargs, engine=engine)
115+
expect = lazydf.collect(**final_polars_collect_kwargs) # type: ignore[misc, call-overload]
116+
got = lazydf.collect(**final_cudf_collect_kwargs, engine=engine) # type: ignore[misc, call-overload]
117117

118118
assert_kwargs_bool: dict[str, bool] = {
119119
"check_row_order": check_row_order,
@@ -136,7 +136,7 @@ def assert_gpu_result_equal(
136136
expect,
137137
got,
138138
**assert_kwargs_bool,
139-
**tol_kwargs,
139+
**tol_kwargs, # type: ignore[arg-type]
140140
)
141141

142142

@@ -294,7 +294,7 @@ def assert_collect_raises(
294294
)
295295

296296
try:
297-
lazydf.collect(**final_polars_collect_kwargs)
297+
lazydf.collect(**final_polars_collect_kwargs) # type: ignore[misc, call-overload]
298298
except polars_except:
299299
pass
300300
except Exception as e:
@@ -307,7 +307,7 @@ def assert_collect_raises(
307307

308308
engine = GPUEngine(raise_on_fail=True)
309309
try:
310-
lazydf.collect(**final_cudf_collect_kwargs, engine=engine)
310+
lazydf.collect(**final_cudf_collect_kwargs, engine=engine) # type: ignore[misc, call-overload]
311311
except cudf_except:
312312
pass
313313
except Exception as e:

python/cudf_polars/cudf_polars/testing/plugin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,11 @@ def pytest_configure(config: pytest.Config) -> None:
5757
collect = polars.LazyFrame.collect
5858
engine = polars.GPUEngine(raise_on_fail=no_fallback)
5959
# https://github.com/python/mypy/issues/2427
60-
polars.LazyFrame.collect = partialmethod(collect, engine=engine)
60+
polars.LazyFrame.collect = partialmethod(collect, engine=engine) # type: ignore[method-assign, assignment]
6161
elif executor == "in-memory":
6262
collect = polars.LazyFrame.collect
6363
engine = polars.GPUEngine(executor=executor)
64-
polars.LazyFrame.collect = partialmethod(collect, engine=engine)
64+
polars.LazyFrame.collect = partialmethod(collect, engine=engine) # type: ignore[method-assign, assignment]
6565
elif executor == "streaming" and blocksize_mode == "small":
6666
executor_options: dict[str, Any] = {}
6767
executor_options["max_rows_per_partition"] = 4
@@ -70,7 +70,7 @@ def pytest_configure(config: pytest.Config) -> None:
7070
executor_options["fallback_mode"] = StreamingFallbackMode.SILENT
7171
collect = polars.LazyFrame.collect
7272
engine = polars.GPUEngine(executor=executor, executor_options=executor_options)
73-
polars.LazyFrame.collect = partialmethod(collect, engine=engine)
73+
polars.LazyFrame.collect = partialmethod(collect, engine=engine) # type: ignore[method-assign, assignment]
7474
else:
7575
# run with streaming executor and default blocksize
7676
polars.Config.set_engine_affinity("gpu")

0 commit comments

Comments (0)