Use batchreader in upsert #1995

Draft: wants to merge 11 commits into main
pyiceberg/table/__init__.py: 58 changes (37 additions, 21 deletions)
@@ -774,39 +774,55 @@ def upsert(
         matched_predicate = upsert_util.create_match_filter(df, join_cols)
 
         # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.
-        matched_iceberg_table = DataScan(
+        matched_iceberg_record_batches = DataScan(
             table_metadata=self.table_metadata,
             io=self._table.io,
             row_filter=matched_predicate,
             case_sensitive=case_sensitive,
-        ).to_arrow()
+        ).to_arrow_batch_reader()
 
-        update_row_cnt = 0
-        insert_row_cnt = 0
+        batches_to_overwrite = []
+        overwrite_predicates = []
+        rows_to_insert = df
 
-        if when_matched_update_all:
-            # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed
-            # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed
-            # this extra step avoids unnecessary IO and writes
-            rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols)
+        for batch in matched_iceberg_record_batches:
+            rows = pa.Table.from_batches([batch])
 
-            update_row_cnt = len(rows_to_update)
+            if when_matched_update_all:
+                # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed
+                # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed
+                # this extra step avoids unnecessary IO and writes
+                rows_to_update = upsert_util.get_rows_to_update(df, rows, join_cols)
 
-            if len(rows_to_update) > 0:
-                # build the match predicate filter
-                overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)
+                if len(rows_to_update) > 0:
+                    # build the match predicate filter
+                    overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)
 
-                self.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate)
+                    batches_to_overwrite.append(rows_to_update)
+                    overwrite_predicates.append(overwrite_mask_predicate)
 
-        if when_not_matched_insert_all:
-            expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols)
-            expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive)
-            expr_match_arrow = expression_to_pyarrow(expr_match_bound)
-            rows_to_insert = df.filter(~expr_match_arrow)
+            if when_not_matched_insert_all:
+                expr_match = upsert_util.create_match_filter(rows, join_cols)
+                expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive)
+                expr_match_arrow = expression_to_pyarrow(expr_match_bound)
 
-            insert_row_cnt = len(rows_to_insert)
+                # Filter rows per batch.
+                rows_to_insert = rows_to_insert.filter(~expr_match_arrow)
 
-            if insert_row_cnt > 0:
+        update_row_cnt = 0
+        insert_row_cnt = 0
+
+        if batches_to_overwrite:
+            rows_to_update = pa.concat_tables(batches_to_overwrite)
+            update_row_cnt = len(rows_to_update)
+            self.overwrite(
+                rows_to_update,
+                overwrite_filter=Or(*overwrite_predicates) if len(overwrite_predicates) > 1 else overwrite_predicates[0],
+            )
+
+        if when_not_matched_insert_all:
+            insert_row_cnt = len(rows_to_insert)
+            if rows_to_insert:
                 self.append(rows_to_insert)
 
         return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)
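
For readers unfamiliar with the PyArrow streaming API this diff switches to, the sketch below shows the general pattern rather than the PR's exact code: iterate a RecordBatchReader and convert each RecordBatch into a small Table, instead of materializing the whole scan with .to_arrow(). The table handle `tbl` and the row filter are assumptions for illustration only.

    import pyarrow as pa

    # Assumes `tbl` is a pyiceberg Table already loaded from a catalog (illustrative).
    reader = tbl.scan(row_filter="id >= 0").to_arrow_batch_reader()

    per_batch_results = []
    for batch in reader:
        # Work on one RecordBatch at a time so memory stays bounded by batch size.
        rows = pa.Table.from_batches([batch])
        per_batch_results.append(rows)

    # Defer the expensive combined operation until all batches are processed,
    # mirroring how the diff collects batches_to_overwrite and writes once at the end.
    combined = pa.concat_tables(per_batch_results) if per_batch_results else None

The design choice in the diff follows the same shape: per-batch work inside the loop, then a single overwrite (with the predicates combined via Or) and a single append after the loop, so only one commit per operation is produced.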
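A hedged usage sketch of the upsert API whose memory profile this change targets; the catalog name, table identifier, and columns below are illustrative assumptions, not part of the PR.

    import pyarrow as pa
    from pyiceberg.catalog import load_catalog

    # Illustrative names: the "default" catalog and "db.customers" table are assumptions.
    catalog = load_catalog("default")
    tbl = catalog.load_table("db.customers")

    # The incoming dataframe must match the table schema; two columns shown for brevity.
    df = pa.Table.from_pylist(
        [
            {"id": 1, "name": "Alice"},   # key already present -> candidate update
            {"id": 42, "name": "Zoe"},    # new key -> candidate insert
        ]
    )

    # join_cols names the key columns used to match existing rows.
    result = tbl.upsert(df, join_cols=["id"])
    print(result.rows_updated, result.rows_inserted)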