Skip to content

Commit 3d4abc7

Browse files
fix more tests
1 parent 15a8d7a commit 3d4abc7

File tree

7 files changed

+114
-59
lines changed

7 files changed

+114
-59
lines changed

bigframes/ml/metrics/_metrics.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -214,7 +214,7 @@ def confusion_matrix(
214214
y_true = row["y_true"]
215215
y_pred = row["y_pred"]
216216
count = row["dummy"]
217-
confusion_matrix[y_pred][y_true] = count
217+
confusion_matrix.at[y_true, y_pred] = count
218218

219219
return confusion_matrix
220220

@@ -251,7 +251,7 @@ def recall_score(
251251
/ is_accurate.groupby(y_true_series).count()
252252
).to_pandas()
253253

254-
recall_score = pd.Series(0, index=index)
254+
recall_score = pd.Series(0.0, index=index)
255255
for i in recall_score.index:
256256
recall_score.loc[i] = recall.loc[i]
257257

@@ -321,7 +321,7 @@ def _precision_score_per_label(y_true: bpd.Series, y_pred: bpd.Series) -> pd.Ser
321321
is_accurate.groupby(y_pred).sum() / is_accurate.groupby(y_pred).count()
322322
).to_pandas()
323323

324-
precision_score = pd.Series(0, index=index)
324+
precision_score = pd.Series(0.0, index=index)
325325
for i in precision.index:
326326
precision_score.loc[i] = precision.loc[i]
327327

@@ -366,7 +366,7 @@ def f1_score(
366366
recall = recall_score(y_true_series, y_pred_series, average=None)
367367
precision = precision_score(y_true_series, y_pred_series, average=None)
368368

369-
f1_score = pd.Series(0, index=recall.index)
369+
f1_score = pd.Series(0.0, index=recall.index)
370370
for index in recall.index:
371371
if precision[index] + recall[index] != 0:
372372
f1_score[index] = (

bigframes/testing/utils.py

Lines changed: 26 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
import base64
1616
import decimal
1717
import re
18-
from typing import Iterable, Optional, Sequence, Set, Union
18+
from typing import Iterable, Optional, Sequence, Set, TypeVar, Union
1919

2020
import geopandas as gpd # type: ignore
2121
import google.api_core.operation
@@ -68,6 +68,8 @@
6868
"content",
6969
]
7070

71+
SeriesOrIndexT = TypeVar("SeriesOrIndexT", pd.Series, pd.Index)
72+
7173

7274
def pandas_major_version() -> int:
7375
match = re.search(r"^v?(\d+)", pd.__version__.strip())
@@ -90,15 +92,27 @@ def assert_series_equivalent(pd_series: pd.Series, bf_series: bpd.Series, **kwar
9092

9193
def _normalize_all_nulls(col: pd.Series) -> pd.Series:
9294
# This over-normalizes probably, make more conservative later
93-
if col.hasnans and (
94-
pd_types.is_float_dtype(col.dtype) or pd_types.is_integer_dtype(col.dtype)
95-
):
96-
col = col.astype("float64")
95+
if col.hasnans and (pd_types.is_float_dtype(col.dtype)):
96+
col = col.astype("float64").astype("Float64")
9797
if pd_types.is_object_dtype(col):
98-
col = col.fillna(float("nan"))
98+
col = col.fillna(pd.NA)
9999
return col
100100

101101

102+
def _normalize_index_nulls(idx: pd.Index) -> pd.Index:
103+
if isinstance(idx, pd.MultiIndex):
104+
new_levels = [
105+
_normalize_index_nulls(idx.get_level_values(i)) for i in range(idx.nlevels)
106+
]
107+
return pd.MultiIndex.from_arrays(new_levels, names=idx.names)
108+
if idx.hasnans:
109+
if pd_types.is_float_dtype(idx.dtype) or pd_types.is_integer_dtype(idx.dtype):
110+
idx = idx.astype("float64").astype("Float64")
111+
if pd_types.is_object_dtype(idx.dtype):
112+
idx = idx.fillna(pd.NA)
113+
return idx
114+
115+
102116
def assert_frame_equal(
103117
left: pd.DataFrame,
104118
right: pd.DataFrame,
@@ -123,6 +137,8 @@ def assert_frame_equal(
123137
if nulls_are_nan:
124138
left = left.apply(_normalize_all_nulls)
125139
right = right.apply(_normalize_all_nulls)
140+
left.index = _normalize_index_nulls(left.index)
141+
right.index = _normalize_index_nulls(right.index)
126142

127143
pd.testing.assert_frame_equal(left, right, **kwargs)
128144

@@ -155,6 +171,10 @@ def assert_series_equal(
155171
if nulls_are_nan:
156172
left = _normalize_all_nulls(left)
157173
right = _normalize_all_nulls(right)
174+
left.index = _normalize_index_nulls(left.index)
175+
right.index = _normalize_index_nulls(right.index)
176+
left.name = pd.NA if pd.isna(left.name) else left.name # type: ignore
177+
right.name = pd.NA if pd.isna(right.name) else right.name # type: ignore
158178

159179
pd.testing.assert_series_equal(left, right, **kwargs)
160180

tests/system/small/core/test_reshape.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,11 +13,11 @@
1313
# limitations under the License.
1414

1515
import pandas as pd
16-
import pandas.testing
1716
import pytest
1817

1918
from bigframes import session
2019
from bigframes.core.reshape import merge
20+
import bigframes.testing
2121

2222

2323
@pytest.mark.parametrize(
@@ -56,7 +56,7 @@ def test_join_with_index(
5656
how=how,
5757
)
5858

59-
pandas.testing.assert_frame_equal(
59+
bigframes.testing.assert_frame_equal(
6060
bf_result, pd_result, check_dtype=False, check_index_type=False
6161
)
6262

tests/system/small/ml/test_utils.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -13,10 +13,10 @@
1313
# limitations under the License.
1414

1515
import pandas as pd
16-
import pandas.testing
1716
import pytest
1817

1918
import bigframes.ml.utils as utils
19+
import bigframes.testing
2020

2121
_DATA_FRAME = pd.DataFrame({"column": [1, 2, 3]})
2222
_SERIES = pd.Series([1, 2, 3], name="column")
@@ -31,7 +31,7 @@ def test_convert_to_dataframe(session, data):
3131

3232
(actual_result,) = utils.batch_convert_to_dataframe(bf_data)
3333

34-
pandas.testing.assert_frame_equal(
34+
bigframes.testing.assert_frame_equal(
3535
actual_result.to_pandas(),
3636
_DATA_FRAME,
3737
check_index_type=False,
@@ -46,7 +46,7 @@ def test_convert_to_dataframe(session, data):
4646
def test_convert_pandas_to_dataframe(data, session):
4747
(actual_result,) = utils.batch_convert_to_dataframe(data, session=session)
4848

49-
pandas.testing.assert_frame_equal(
49+
bigframes.testing.assert_frame_equal(
5050
actual_result.to_pandas(),
5151
_DATA_FRAME,
5252
check_index_type=False,
@@ -63,7 +63,7 @@ def test_convert_to_series(session, data):
6363

6464
(actual_result,) = utils.batch_convert_to_series(bf_data)
6565

66-
pandas.testing.assert_series_equal(
66+
bigframes.testing.assert_series_equal(
6767
actual_result.to_pandas(), _SERIES, check_index_type=False, check_dtype=False
6868
)
6969

@@ -75,6 +75,6 @@ def test_convert_to_series(session, data):
7575
def test_convert_pandas_to_series(data, session):
7676
(actual_result,) = utils.batch_convert_to_series(data, session=session)
7777

78-
pandas.testing.assert_series_equal(
78+
bigframes.testing.assert_series_equal(
7979
actual_result.to_pandas(), _SERIES, check_index_type=False, check_dtype=False
8080
)

tests/system/small/operations/test_timedeltas.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,9 @@
2525
from bigframes import dtypes
2626
import bigframes.testing
2727

28+
# Some methods/features used by this test don't exist in pandas 1.x
29+
pytest.importorskip("pandas", minversion="2.0.0")
30+
2831

2932
@pytest.fixture(scope="module")
3033
def temporal_dfs(session):

0 commit comments

Comments
 (0)