Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 18 additions & 27 deletions src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,34 +462,25 @@ class ApplyFunc:
def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover
# First column is row position, extract it for later use
row_positions = df.iloc[:, 0]

# If we have index columns, set them as the index
num_index_columns = (
0 if index_column_labels is None else len(index_column_labels)
)
if num_index_columns > 0:
# Columns after row position are index columns, then data columns
index_cols = df.iloc[:, 1 : 1 + num_index_columns]
data_cols = df.iloc[:, 1 + num_index_columns :]

# Set the index using the index columns
if num_index_columns == 1:
index = index_cols.iloc[:, 0]
if index_column_labels:
index.name = index_column_labels[0]
else:
# Multi-index case
index = native_pd.MultiIndex.from_arrays(
[index_cols.iloc[:, i] for i in range(num_index_columns)],
names=index_column_labels if index_column_labels else None,
)
data_cols.set_index(index, inplace=True)
df = data_cols
if index_column_labels is None:
# Set the row positions column as the index.
df.set_index(df.columns[0], drop=False, inplace=True)
else:
# No index columns, use row position as index (original behavior)
df = df.iloc[:, 1:]
df.set_index(row_positions, inplace=True)

# The columns after the row position column represent index
# labels. Set them as the index and remove them from the data
# columns. We don't care about the index names because `func`
# applies to each row and doesn't see or affect the index names.
df.set_index(
# If we try to select a slice of df.columns instead of
# a list of column labels, pandas treats the slice as a
# sequence of row labels rather than as a sequence of
# index column labels.
list(df.columns[1 : len(index_column_labels) + 1]),
drop=True,
inplace=True,
)
# Drop the row positions column.
df = df.iloc[:, 1:]
df.columns = column_index
df = df.apply(
func, axis=1, raw=raw, result_type=result_type, args=args, **kwargs
Expand Down
Loading