Skip to content

Commit 4e9c66d

Browse files
omkengeFokko
andauthored
Fix TypeError in create_match_filter for Composite Keys (#1693)
**Old Code Behavior:** Even if there's only one such condition, the code wraps it in an Or() operator. But Or() is meant to combine two or more conditions (like “condition A OR condition B”). If you give it only one condition, it complains because it expects a second condition. **New Code Behavior:** The new change checks how many conditions you have. If there's only one condition, it simply returns that condition. If there are more than one, it uses Or() to combine them. --------- Co-authored-by: Fokko Driesprong <[email protected]>
1 parent c7fe114 commit 4e9c66d

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

pyiceberg/table/upsert_util.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
# under the License.
1717
import functools
1818
import operator
19+
from typing import List, cast
1920

2021
import pyarrow as pa
2122
from pyarrow import Table as pyarrow_table
2223
from pyarrow import compute as pc
2324

2425
from pyiceberg.expressions import (
26+
AlwaysFalse,
2527
And,
2628
BooleanExpression,
2729
EqualTo,
@@ -36,7 +38,16 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpre
3638
if len(join_cols) == 1:
3739
return In(join_cols[0], unique_keys[0].to_pylist())
3840
else:
39-
return Or(*[And(*[EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()])
41+
filters: List[BooleanExpression] = [
42+
cast(BooleanExpression, And(*[EqualTo(col, row[col]) for col in join_cols])) for row in unique_keys.to_pylist()
43+
]
44+
45+
if len(filters) == 0:
46+
return AlwaysFalse()
47+
elif len(filters) == 1:
48+
return filters[0]
49+
else:
50+
return functools.reduce(lambda a, b: Or(a, b), filters)
4051

4152

4253
def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
@@ -86,7 +97,7 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
8697
if rows_to_update:
8798
rows_to_update_table = pa.concat_tables(rows_to_update)
8899
else:
89-
rows_to_update_table = pa.Table.from_arrays([], names=source_table.column_names)
100+
rows_to_update_table = source_table.schema.empty_table()
90101

91102
common_columns = set(source_table.column_names).intersection(set(target_table.column_names))
92103
rows_to_update_table = rows_to_update_table.select(list(common_columns))

tests/table/test_upsert.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@
2323

2424
from pyiceberg.catalog import Catalog
2525
from pyiceberg.exceptions import NoSuchTableError
26+
from pyiceberg.expressions import And, EqualTo, Reference
27+
from pyiceberg.expressions.literals import LongLiteral
2628
from pyiceberg.schema import Schema
2729
from pyiceberg.table import UpsertResult
30+
from pyiceberg.table.upsert_util import create_match_filter
2831
from pyiceberg.types import IntegerType, NestedField, StringType
2932
from tests.catalog.test_base import InMemoryCatalog, Table
3033

@@ -366,3 +369,22 @@ def test_upsert_with_identifier_fields(catalog: Catalog) -> None:
366369

367370
assert upd.rows_updated == 1
368371
assert upd.rows_inserted == 1
372+
373+
374+
def test_create_match_filter_single_condition() -> None:
375+
"""
376+
Test create_match_filter with a composite key where the source yields exactly one unique key.
377+
Expected: The function returns the single And condition directly.
378+
"""
379+
380+
data = [
381+
{"order_id": 101, "order_line_id": 1, "extra": "x"},
382+
{"order_id": 101, "order_line_id": 1, "extra": "x"}, # duplicate
383+
]
384+
schema = pa.schema([pa.field("order_id", pa.int32()), pa.field("order_line_id", pa.int32()), pa.field("extra", pa.string())])
385+
table = pa.Table.from_pylist(data, schema=schema)
386+
expr = create_match_filter(table, ["order_id", "order_line_id"])
387+
assert expr == And(
388+
EqualTo(term=Reference(name="order_id"), literal=LongLiteral(101)),
389+
EqualTo(term=Reference(name="order_line_id"), literal=LongLiteral(1)),
390+
)

0 commit comments

Comments
 (0)