1515import base64
1616import decimal
1717import re
18- from typing import Iterable , Optional , Sequence , Set , Union
18+ from typing import Iterable , Optional , Sequence , Set , TypeVar , Union
1919
2020import geopandas as gpd # type: ignore
2121import google .api_core .operation
6868 "content" ,
6969]
7070
71+ SeriesOrIndexT = TypeVar ("SeriesOrIndexT" , pd .Series , pd .Index )
72+
7173
7274def pandas_major_version () -> int :
7375 match = re .search (r"^v?(\d+)" , pd .__version__ .strip ())
@@ -90,15 +92,27 @@ def assert_series_equivalent(pd_series: pd.Series, bf_series: bpd.Series, **kwar
9092
9193def _normalize_all_nulls (col : pd .Series ) -> pd .Series :
9294 # This over-normalizes probably, make more conservative later
93- if col .hasnans and (
94- pd_types .is_float_dtype (col .dtype ) or pd_types .is_integer_dtype (col .dtype )
95- ):
96- col = col .astype ("float64" )
95+ if col .hasnans and (pd_types .is_float_dtype (col .dtype )):
96+ col = col .astype ("float64" ).astype ("Float64" )
9797 if pd_types .is_object_dtype (col ):
98- col = col .fillna (float ( "nan" ) )
98+ col = col .fillna (pd . NA )
9999 return col
100100
101101
102+ def _normalize_index_nulls (idx : pd .Index ) -> pd .Index :
103+ if isinstance (idx , pd .MultiIndex ):
104+ new_levels = [
105+ _normalize_index_nulls (idx .get_level_values (i )) for i in range (idx .nlevels )
106+ ]
107+ return pd .MultiIndex .from_arrays (new_levels , names = idx .names )
108+ if idx .hasnans :
109+ if pd_types .is_float_dtype (idx .dtype ) or pd_types .is_integer_dtype (idx .dtype ):
110+ idx = idx .astype ("float64" ).astype ("Float64" )
111+ if pd_types .is_object_dtype (idx .dtype ):
112+ idx = idx .fillna (pd .NA )
113+ return idx
114+
115+
102116def assert_frame_equal (
103117 left : pd .DataFrame ,
104118 right : pd .DataFrame ,
@@ -123,6 +137,8 @@ def assert_frame_equal(
123137 if nulls_are_nan :
124138 left = left .apply (_normalize_all_nulls )
125139 right = right .apply (_normalize_all_nulls )
140+ left .index = _normalize_index_nulls (left .index )
141+ right .index = _normalize_index_nulls (right .index )
126142
127143 pd .testing .assert_frame_equal (left , right , ** kwargs )
128144
@@ -155,6 +171,10 @@ def assert_series_equal(
155171 if nulls_are_nan :
156172 left = _normalize_all_nulls (left )
157173 right = _normalize_all_nulls (right )
174+ left .index = _normalize_index_nulls (left .index )
175+ right .index = _normalize_index_nulls (right .index )
176+ left .name = pd .NA if pd .isna (left .name ) else left .name # type: ignore
177+ right .name = pd .NA if pd .isna (right .name ) else right .name # type: ignore
158178
159179 pd .testing .assert_series_equal (left , right , ** kwargs )
160180
0 commit comments