66from __future__ import annotations
77
88from functools import cache
9- from typing import TYPE_CHECKING
9+ from typing import TYPE_CHECKING , Literal , cast
1010
1111from typing_extensions import assert_never
1212
1313import polars as pl
1414
1515import pylibcudf as plc
1616
17+ from cudf_polars .utils .versions import POLARS_VERSION_LT_136
18+
1719if TYPE_CHECKING :
1820 from cudf_polars .typing import (
1921 DataTypeHeader ,
22+ PolarsDataType ,
2023 )
2124
2225__all__ = ["DataType" ]
@@ -46,7 +49,18 @@ def _dtype_to_header(dtype: pl.DataType) -> DataTypeHeader:
4649 if name in SCALAR_NAME_TO_POLARS_TYPE_MAP :
4750 return {"kind" : "scalar" , "name" : name }
4851 if isinstance (dtype , pl .Decimal ):
49- return {"kind" : "decimal" , "precision" : dtype .precision , "scale" : dtype .scale }
52+ # Workaround for incorrect polars stubs where precision is typed as int | None
53+ # Fixed upstream: https://github.com/pola-rs/polars/pull/25227
54+ # TODO: Remove this workaround once the minimum supported polars version is >= 1.36
55+ if POLARS_VERSION_LT_136 :
56+ assert (
57+ dtype .precision is not None
58+ ) # Decimal always has precision at runtime
59+ return {
60+ "kind" : "decimal" ,
61+ "precision" : cast (int , dtype .precision ),
62+ "scale" : dtype .scale ,
63+ }
5064 if isinstance (dtype , pl .Datetime ):
5165 return {
5266 "kind" : "datetime" ,
@@ -56,12 +70,17 @@ def _dtype_to_header(dtype: pl.DataType) -> DataTypeHeader:
5670 if isinstance (dtype , pl .Duration ):
5771 return {"kind" : "duration" , "time_unit" : dtype .time_unit }
5872 if isinstance (dtype , pl .List ):
59- return {"kind" : "list" , "inner" : _dtype_to_header (dtype .inner )}
73+ # isinstance narrows dtype to pl.List, but .inner returns DataTypeClass | DataType
74+ return {
75+ "kind" : "list" ,
76+ "inner" : _dtype_to_header (cast (pl .DataType , dtype .inner )),
77+ }
6078 if isinstance (dtype , pl .Struct ):
79+ # isinstance narrows dtype to pl.Struct, but field.dtype returns DataTypeClass | DataType
6180 return {
6281 "kind" : "struct" ,
6382 "fields" : [
64- {"name" : f .name , "dtype" : _dtype_to_header (f .dtype )}
83+ {"name" : f .name , "dtype" : _dtype_to_header (cast ( pl . DataType , f .dtype ) )}
6584 for f in dtype .fields
6685 ],
6786 }
@@ -78,9 +97,14 @@ def _dtype_from_header(header: DataTypeHeader) -> pl.DataType:
7897 if header ["kind" ] == "decimal" :
7998 return pl .Decimal (header ["precision" ], header ["scale" ])
8099 if header ["kind" ] == "datetime" :
81- return pl .Datetime (time_unit = header ["time_unit" ], time_zone = header ["time_zone" ])
100+ return pl .Datetime (
101+ time_unit = cast (Literal ["ns" , "us" , "ms" ], header ["time_unit" ]),
102+ time_zone = header ["time_zone" ],
103+ )
82104 if header ["kind" ] == "duration" :
83- return pl .Duration (time_unit = header ["time_unit" ])
105+ return pl .Duration (
106+ time_unit = cast (Literal ["ns" , "us" , "ms" ], header ["time_unit" ])
107+ )
84108 if header ["kind" ] == "list" :
85109 return pl .List (_dtype_from_header (header ["inner" ]))
86110 if header ["kind" ] == "struct" :
@@ -182,9 +206,14 @@ class DataType:
182206 polars_type : pl .datatypes .DataType
183207 plc_type : plc .DataType
184208
185- def __init__ (self , polars_dtype : pl .DataType ) -> None :
186- self .polars_type = polars_dtype
187- self .plc_type = _from_polars (polars_dtype )
209+ def __init__ (self , polars_dtype : PolarsDataType ) -> None :
210+ # Convert DataTypeClass to DataType instance if needed
211+ # polars allows both pl.Int64 (class) and pl.Int64() (instance)
212+ if isinstance (polars_dtype , type ):
213+ polars_dtype = polars_dtype ()
214+ # After conversion, it's guaranteed to be a DataType instance
215+ self .polars_type = cast (pl .DataType , polars_dtype )
216+ self .plc_type = _from_polars (self .polars_type )
188217
189218 def id (self ) -> plc .TypeId :
190219 """The pylibcudf.TypeId of this DataType."""
@@ -193,12 +222,16 @@ def id(self) -> plc.TypeId:
193222 @property
194223 def children (self ) -> list [DataType ]:
195224 """The children types of this DataType."""
196- # these type ignores are needed because the type checker doesn't
197- # see that these equality checks passing imply a specific type for each child field.
225+ # The type checker cannot narrow polars_type from the plc_type.id() comparisons, hence the explicit casts below
198226 if self .plc_type .id () == plc .TypeId .STRUCT :
199- return [DataType (field .dtype ) for field in self .polars_type .fields ]
227+ # field.dtype returns DataTypeClass | DataType, need to cast to DataType
228+ return [
229+ DataType (cast (pl .DataType , field .dtype ))
230+ for field in cast (pl .Struct , self .polars_type ).fields
231+ ]
200232 elif self .plc_type .id () == plc .TypeId .LIST :
201- return [DataType (self .polars_type .inner )]
233+ # .inner returns DataTypeClass | DataType, need to cast to DataType
234+ return [DataType (cast (pl .DataType , cast (pl .List , self .polars_type ).inner ))]
202235 return []
203236
204237 def scale (self ) -> int :
0 commit comments