Skip to content

Commit 891b4c7

Browse files
committed
Fix equality of bound expressions (#95)
1 parent a09de69 commit 891b4c7

File tree

4 files changed

+20
-10
lines changed

4 files changed

+20
-10
lines changed

pyiceberg/expressions/__init__.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def __init__(self, term: BoundTerm[L]):
346346

347347
def __eq__(self, other: Any) -> bool:
348348
"""Return the equality of two instances of the BoundPredicate class."""
349-
if isinstance(other, BoundPredicate):
349+
if isinstance(other, self.__class__):
350350
return self.term == other.term
351351
return False
352352

@@ -567,7 +567,7 @@ def __repr__(self) -> str:
567567

568568
def __eq__(self, other: Any) -> bool:
569569
"""Return the equality of two instances of the BoundSetPredicate class."""
570-
return self.term == other.term and self.literals == other.literals if isinstance(other, BoundSetPredicate) else False
570+
return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False
571571

572572
def __getnewargs__(self) -> Tuple[BoundTerm[L], Set[Literal[L]]]:
573573
"""Pickle the BoundSetPredicate class."""
@@ -595,7 +595,7 @@ def __invert__(self) -> BoundNotIn[L]:
595595

596596
def __eq__(self, other: Any) -> bool:
597597
"""Return the equality of two instances of the BoundIn class."""
598-
return self.term == other.term and self.literals == other.literals if isinstance(other, BoundIn) else False
598+
return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False
599599

600600
@property
601601
def as_unbound(self) -> Type[In[L]]:
@@ -725,7 +725,7 @@ def __init__(self, term: BoundTerm[L], literal: Literal[L]): # pylint: disable=
725725

726726
def __eq__(self, other: Any) -> bool:
727727
"""Return the equality of two instances of the BoundLiteralPredicate class."""
728-
if isinstance(other, BoundLiteralPredicate):
728+
if isinstance(other, self.__class__):
729729
return self.term == other.term and self.literal == other.literal
730730
return False
731731

tests/conftest.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from pyiceberg import schema
5858
from pyiceberg.catalog import Catalog
5959
from pyiceberg.catalog.noop import NoopCatalog
60+
from pyiceberg.expressions import BoundReference
6061
from pyiceberg.io import (
6162
GCS_ENDPOINT,
6263
GCS_PROJECT_ID,
@@ -69,7 +70,7 @@
6970
)
7071
from pyiceberg.io.fsspec import FsspecFileIO
7172
from pyiceberg.manifest import DataFile, FileFormat
72-
from pyiceberg.schema import Schema
73+
from pyiceberg.schema import Accessor, Schema
7374
from pyiceberg.serializers import ToOutputFile
7475
from pyiceberg.table import FileScanTask, Table
7576
from pyiceberg.table.metadata import TableMetadataV2
@@ -1659,3 +1660,8 @@ def table(example_table_metadata_v2: Dict[str, Any]) -> Table:
16591660
io=load_file_io(),
16601661
catalog=NoopCatalog("NoopCatalog"),
16611662
)
1663+
1664+
1665+
@pytest.fixture
1666+
def bound_reference_str() -> BoundReference[str]:
1667+
return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None))

tests/expressions/test_expressions.py

+9
Original file line numberDiff line numberDiff line change
@@ -1149,6 +1149,15 @@ def test_above_long_bounds_greater_than_or_equal(
11491149
assert GreaterThanOrEqual[int]("a", below_long_min).bind(long_schema) is AlwaysTrue()
11501150

11511151

1152+
def test_eq_bound_expression(bound_reference_str: BoundReference[str]) -> None:
1153+
assert BoundEqualTo(term=bound_reference_str, literal=literal('a')) != BoundGreaterThanOrEqual(
1154+
term=bound_reference_str, literal=literal('a')
1155+
)
1156+
assert BoundEqualTo(term=bound_reference_str, literal=literal('a')) == BoundEqualTo(
1157+
term=bound_reference_str, literal=literal('a')
1158+
)
1159+
1160+
11521161
# __ __ ___
11531162
# | \/ |_ _| _ \_ _
11541163
# | |\/| | || | _/ || |

tests/test_transforms.py

-5
Original file line numberDiff line numberDiff line change
@@ -559,11 +559,6 @@ def test_datetime_transform_repr(transform: TimeTransform[Any], transform_repr:
559559
assert repr(transform) == transform_repr
560560

561561

562-
@pytest.fixture
563-
def bound_reference_str() -> BoundReference[str]:
564-
return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None))
565-
566-
567562
@pytest.fixture
568563
def bound_reference_date() -> BoundReference[int]:
569564
return BoundReference(field=NestedField(1, "field", DateType(), required=False), accessor=Accessor(position=0, inner=None))

0 commit comments

Comments
 (0)