Skip to content

Commit df6208e

Browse files
committed
feat: add fill_nan method to DataFrame for handling NaN values
1 parent 4cf7496 commit df6208e

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

python/datafusion/dataframe.py

+55
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545

4646
import pandas as pd
4747
import polars as pl
48+
import pyarrow as pa
4849

4950
from enum import Enum
5051

@@ -909,3 +910,57 @@ def fill_null(self, value: Any, subset: list[str] | None = None) -> "DataFrame":
909910
exprs.append(f.col(col_name))
910911

911912
return self.select(*exprs)
913+
914+
def fill_nan(self, value: float | int, subset: list[str] | None = None) -> "DataFrame":
915+
"""Fill NaN values in specified numeric columns with a value.
916+
917+
Args:
918+
value: Numeric value to replace NaN values with
919+
subset: Optional list of column names to fill. If None, fills all numeric columns.
920+
921+
Returns:
922+
DataFrame with NaN values replaced in numeric columns
923+
924+
Examples:
925+
>>> df = df.fill_nan(0) # Fill all NaNs with 0 in numeric columns
926+
>>> df = df.fill_nan(99.9, subset=["price", "score"]) # Fill specific columns
927+
928+
Notes:
929+
- Only fills NaN values in numeric columns (float32, float64)
930+
- Non-numeric columns are kept unchanged
931+
- For columns not in subset, the original column is kept unchanged
932+
- Value must be numeric (int or float)
933+
"""
934+
import pyarrow as pa
935+
from datafusion import functions as f
936+
937+
if not isinstance(value, (int, float)):
938+
raise ValueError("Value must be numeric (int or float)")
939+
940+
# Get columns to process
941+
if subset is None:
942+
# Only get numeric columns if no subset specified
943+
subset = [
944+
field.name for field in self.schema()
945+
if pa.types.is_floating(field.type)
946+
]
947+
else:
948+
schema_cols = self.schema().names
949+
for col in subset:
950+
if col not in schema_cols:
951+
raise ValueError(f"Column '{col}' not found in DataFrame")
952+
if not pa.types.is_floating(self.schema().field(col).type):
953+
raise ValueError(f"Column '{col}' is not a numeric column")
954+
955+
# Build expressions for select
956+
exprs = []
957+
for col_name in self.schema().names:
958+
if col_name in subset:
959+
# Use nanvl function to replace NaN values
960+
expr = f.nanvl(f.col(col_name), f.lit(value))
961+
exprs.append(expr.alias(col_name))
962+
else:
963+
# Keep columns not in subset unchanged
964+
exprs.append(f.col(col_name))
965+
966+
return self.select(*exprs)

python/tests/test_dataframe.py

+48
Original file line number | Diff line number | Diff line change
@@ -1264,3 +1264,51 @@ def test_fill_null(df):
12641264
)
12651265
with pytest.raises(ValueError, match="Column 'f' not found in DataFrame"):
12661266
df_with_nulls.fill_null("missing", subset=["e", "f"])
1267+
1268+
def test_fill_nan(df):
    """Check DataFrame.fill_nan with int/float fill values and column subsets."""
    import math  # local import: only needed for the NaN checks in this test

    def nan_expr():
        # Float64 expression that evaluates to NaN on every row.
        return literal(float("nan")).cast(pa.float64())

    def all_nan(values):
        # NaN never compares equal to itself, so comparing against a list of
        # float("nan") always fails; test each element with math.isnan.
        return len(values) == 3 and all(math.isnan(v) for v in values)

    # Integer fill value
    df_with_nans = df.with_column("d", nan_expr())
    result = df_with_nans.fill_nan(0).to_pydict()
    assert result["d"] == [0, 0, 0]

    # Float fill value
    df_with_nans = df.with_column("d", nan_expr())
    result = df_with_nans.fill_nan(99.9).to_pydict()
    assert result["d"] == [99.9, 99.9, 99.9]

    # Subset with a single column: the other NaN column must stay NaN
    df_with_nans = df.with_columns(
        nan_expr().alias("d"),
        nan_expr().alias("e"),
    )
    result = df_with_nans.fill_nan(99.9, subset=["e"]).to_pydict()
    assert all_nan(result["d"])
    assert result["e"] == [99.9, 99.9, 99.9]

    # Non-numeric fill value is rejected
    df_with_nans = df.with_column("d", nan_expr())
    with pytest.raises(ValueError, match="Value must be numeric"):
        df_with_nans.fill_nan("invalid")

    # Unknown column name in subset is rejected
    with pytest.raises(
        ValueError, match="Column 'does_not_exist' not found in DataFrame"
    ):
        df_with_nans.fill_nan(0, subset=["does_not_exist"])

    # Subset naming every NaN column: all of them are filled
    df_with_nans = df.with_columns(
        nan_expr().alias("d"),
        nan_expr().alias("e"),
    )
    result = df_with_nans.fill_nan(0, subset=["d", "e"]).to_pydict()
    assert result["d"] == [0, 0, 0]
    assert result["e"] == [0, 0, 0]

    # Float fill value restricted to one of two NaN columns
    df_with_nans = df.with_columns(
        nan_expr().alias("d"),
        nan_expr().alias("e"),
    )
    result = df_with_nans.fill_nan(99.9, subset=["e"]).to_pydict()
    assert all_nan(result["d"])
    assert result["e"] == [99.9, 99.9, 99.9]

0 commit comments

Comments
 (0)