Skip to content

Fix parsing of references to nested fields #965

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pyiceberg/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Any,
Generic,
Iterable,
Optional,
Set,
Tuple,
Type,
Expand Down Expand Up @@ -109,10 +110,12 @@ class BoundReference(BoundTerm[L]):

field: NestedField
accessor: Accessor
name: str

def __init__(self, field: NestedField, accessor: Accessor):
def __init__(self, field: NestedField, accessor: Accessor, name: Optional[str] = None):
    """Store the bound field and accessor, preserving the reference name.

    ``name`` carries through the (possibly dotted) path the unbound
    Reference was created with; when omitted or empty/falsy it falls
    back to the schema field's own name.
    """
    self.field = field
    self.accessor = accessor
    # Truthiness check: an empty string also falls back to field.name.
    self.name = name if name else field.name

def eval(self, struct: StructProtocol) -> L:
"""Return the value at the referenced field's position in an object that abides by the StructProtocol.
Expand Down Expand Up @@ -185,7 +188,7 @@ def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundReference[L]
"""
field = schema.find_field(name_or_id=self.name, case_sensitive=case_sensitive)
accessor = schema.accessor_for_field(field.field_id)
return self.as_bound(field=field, accessor=accessor) # type: ignore
return self.as_bound(field=field, accessor=accessor, name=self.name) # type: ignore

@property
def as_bound(self) -> Type[BoundReference[L]]:
Expand Down
2 changes: 1 addition & 1 deletion pyiceberg/expressions/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@

@column.set_parse_action
def _(result: ParseResults) -> Reference:
return Reference(result.column[-1])
return Reference(".".join(result.column))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: extract out "." as a variable



boolean = one_of(["true", "false"], caseless=True).set_results_name("boolean")
Expand Down
31 changes: 17 additions & 14 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,51 +572,54 @@ def _convert_scalar(value: Any, iceberg_type: IcebergType) -> pa.scalar:


class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]):
def _flat_name_to_list(self, name: str) -> List[str]:
return name.split(".")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a risky route. I'd rather migrate internally to a Tuple[str, ...] representation so we can actually support field names that contain "."


def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
return pc.field(term.ref().field.name).isin(pyarrow_literals)
return pc.field(*self._flat_name_to_list(term.ref().name)).isin(pyarrow_literals)

def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
pyarrow_literals = pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))
return ~pc.field(term.ref().field.name).isin(pyarrow_literals)
return ~pc.field(*self._flat_name_to_list(term.ref().name)).isin(pyarrow_literals)

def visit_is_nan(self, term: BoundTerm[Any]) -> pc.Expression:
ref = pc.field(term.ref().field.name)
ref = pc.field(*self._flat_name_to_list(term.ref().name))
return pc.is_nan(ref)

def visit_not_nan(self, term: BoundTerm[Any]) -> pc.Expression:
ref = pc.field(term.ref().field.name)
ref = pc.field(*self._flat_name_to_list(term.ref().name))
return ~pc.is_nan(ref)

def visit_is_null(self, term: BoundTerm[Any]) -> pc.Expression:
return pc.field(term.ref().field.name).is_null(nan_is_null=False)
return pc.field(*self._flat_name_to_list(term.ref().name)).is_null(nan_is_null=False)

def visit_not_null(self, term: BoundTerm[Any]) -> pc.Expression:
return pc.field(term.ref().field.name).is_valid()
return pc.field(*self._flat_name_to_list(term.ref().name)).is_valid()

def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return pc.field(term.ref().field.name) == _convert_scalar(literal.value, term.ref().field.field_type)
return pc.field(*self._flat_name_to_list(term.ref().name)) == _convert_scalar(literal.value, term.ref().field.field_type)

def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return pc.field(term.ref().field.name) != _convert_scalar(literal.value, term.ref().field.field_type)
return pc.field(*self._flat_name_to_list(term.ref().name)) != _convert_scalar(literal.value, term.ref().field.field_type)

def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return pc.field(term.ref().field.name) >= _convert_scalar(literal.value, term.ref().field.field_type)
return pc.field(*self._flat_name_to_list(term.ref().name)) >= _convert_scalar(literal.value, term.ref().field.field_type)

def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return pc.field(term.ref().field.name) > _convert_scalar(literal.value, term.ref().field.field_type)
return pc.field(*self._flat_name_to_list(term.ref().name)) > _convert_scalar(literal.value, term.ref().field.field_type)

def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return pc.field(term.ref().field.name) < _convert_scalar(literal.value, term.ref().field.field_type)
return pc.field(*self._flat_name_to_list(term.ref().name)) < _convert_scalar(literal.value, term.ref().field.field_type)

def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return pc.field(term.ref().field.name) <= _convert_scalar(literal.value, term.ref().field.field_type)
return pc.field(*self._flat_name_to_list(term.ref().name)) <= _convert_scalar(literal.value, term.ref().field.field_type)

def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return pc.starts_with(pc.field(term.ref().field.name), literal.value)
return pc.starts_with(pc.field(*self._flat_name_to_list(term.ref().name)), literal.value)

def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
return ~pc.starts_with(pc.field(term.ref().field.name), literal.value)
return ~pc.starts_with(pc.field(*self._flat_name_to_list(term.ref().name)), literal.value)

def visit_true(self) -> pc.Expression:
return pc.scalar(True)
Expand Down
82 changes: 82 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,67 @@ def generate_snapshot(
}


# Format-version 2 table metadata whose schema contains a single nested
# struct column: "person" (field id 1) with children "id" (long, id 2)
# and "age" (float, id 3). Used to exercise nested-field references.
TABLE_METADATA_V2_WITH_STRUCT_TYPE = {
    "format-version": 2,
    "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1",
    "location": "s3://bucket/test/location",
    "last-sequence-number": 34,
    "last-updated-ms": 1602638573590,
    "last-column-id": 7,
    "current-schema-id": 1,
    "schemas": [
        {
            "type": "struct",
            "schema-id": 1,
            "identifier-field-ids": [],
            "fields": [
                {
                    "id": 1,
                    "name": "person",
                    "required": False,
                    "type": {
                        "type": "struct",
                        "fields": [
                            {"id": 2, "name": "id", "required": False, "type": "long"},
                            {"id": 3, "name": "age", "required": False, "type": "float"},
                        ],
                    },
                },
            ],
        }
    ],
    "default-spec-id": 0,
    "partition-specs": [{"spec-id": 0, "fields": []}],
    "last-partition-id": 1000,
    "properties": {"read.split.target.size": "134217728"},
    "current-snapshot-id": 3055729675574597004,
    "snapshots": [
        {
            "snapshot-id": 3051729675574597004,
            "timestamp-ms": 1515100955770,
            "sequence-number": 0,
            "summary": {"operation": "append"},
            "manifest-list": "s3://a/b/1.avro",
        },
        {
            "snapshot-id": 3055729675574597004,
            "parent-snapshot-id": 3051729675574597004,
            "timestamp-ms": 1555100955770,
            "sequence-number": 1,
            "summary": {"operation": "append"},
            "manifest-list": "s3://a/b/2.avro",
            "schema-id": 1,
        },
    ],
    "snapshot-log": [
        {"snapshot-id": 3051729675574597004, "timestamp-ms": 1515100955770},
        {"snapshot-id": 3055729675574597004, "timestamp-ms": 1555100955770},
    ],
    "metadata-log": [{"metadata-file": "s3://bucket/.../v1.json", "timestamp-ms": 1515100}],
    "refs": {"test": {"snapshot-id": 3051729675574597004, "type": "tag", "max-ref-age-ms": 10000000}},
}


@pytest.fixture
def example_table_metadata_v2() -> Dict[str, Any]:
    """Return the shared EXAMPLE_TABLE_METADATA_V2 dict used by metadata tests."""
    return EXAMPLE_TABLE_METADATA_V2
Expand All @@ -929,6 +990,11 @@ def table_metadata_v2_with_fixed_and_decimal_types() -> Dict[str, Any]:
return TABLE_METADATA_V2_WITH_FIXED_AND_DECIMAL_TYPES


@pytest.fixture
def table_metadata_v2_with_struct_type() -> Dict[str, Any]:
    """Return v2 table metadata whose schema contains a nested ``person`` struct."""
    return TABLE_METADATA_V2_WITH_STRUCT_TYPE


@pytest.fixture(scope="session")
def metadata_location(tmp_path_factory: pytest.TempPathFactory) -> str:
from pyiceberg.io.pyarrow import PyArrowFileIO
Expand Down Expand Up @@ -2158,6 +2224,22 @@ def table_v2_with_fixed_and_decimal_types(
)


@pytest.fixture
def table_v2_with_struct_type(
    table_metadata_v2_with_struct_type: Dict[str, Any],
) -> Table:
    """A Table built from the struct-typed v2 metadata fixture."""
    metadata = TableMetadataV2(**table_metadata_v2_with_struct_type)
    return Table(
        identifier=("database", "table"),
        metadata=metadata,
        metadata_location=f"{metadata.location}/uuid.metadata.json",
        io=load_file_io(),
        catalog=NoopCatalog("NoopCatalog"),
    )


@pytest.fixture
def table_v2_with_extensive_snapshots(example_table_metadata_v2_with_extensive_snapshots: Dict[str, Any]) -> Table:
table_metadata = TableMetadataV2(**example_table_metadata_v2_with_extensive_snapshots)
Expand Down
6 changes: 5 additions & 1 deletion tests/expressions/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_false() -> None:
def test_is_null() -> None:
assert IsNull("foo") == parser.parse("foo is null")
assert IsNull("foo") == parser.parse("foo IS NULL")
assert IsNull("foo") == parser.parse("table.foo IS NULL")
assert IsNull("table.foo") == parser.parse("table.foo IS NULL")


def test_not_null() -> None:
Expand Down Expand Up @@ -199,3 +199,7 @@ def test_with_function() -> None:
parser.parse("foo = 1 and lower(bar) = '2'")

assert "Expected end of text, found 'and'" in str(exc_info)


def test_nested_field_equality() -> None:
    # A dotted path parses into a single Reference("foo.first") — the full
    # path, not just the last segment.
    assert EqualTo("foo.first", "a") == parser.parse("foo.first == 'a'")
Copy link
Contributor

@Fokko Fokko Aug 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we first have to come up with a proposal for representing nested fields in a flat string that doesn't result in these edge cases, or is at least configurable (e.g. Config parameter PYICEBERG__NESTED_FIELD_DELIMITER defined at the session level that defaults to .)

I think the key to success is to have some kind of syntax for quoting literals. For example: https://spark.apache.org/docs/latest/sql-ref-literals.html

Then we can parse something like:

'a.b' -> Reference(('a.b',))
'a.b'.c -> Reference(('a.b', 'c'))
a.b.c -> Reference(('a', 'b', 'c'))

Or folks have to use:

row_filter=EqualTo(('a.b',), 123)
row_filter=EqualTo(('a.b', 'c'), 123)
row_filter=EqualTo(('a', 'b', 'c'), 123)

Copy link
Contributor

@Fokko Fokko Aug 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is also an interesting proposal on the spec side of things: apache/iceberg#10883

Related: apache/iceberg#598

Loading
Loading