
Commit 8f47dfd

Move determine_partitions and helper methods to io.pyarrow (#906)
1 parent 5aa451d commit 8f47dfd

File tree

4 files changed: +180 -189 lines changed


pyiceberg/io/pyarrow.py

Lines changed: 98 additions & 3 deletions
@@ -113,7 +113,7 @@
     DataFileContent,
     FileFormat,
 )
-from pyiceberg.partitioning import PartitionField, PartitionSpec, partition_record_value
+from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
 from pyiceberg.schema import (
     PartnerAccessor,
     PreOrderSchemaVisitor,
@@ -2125,8 +2125,6 @@ def _dataframe_to_data_files(
             ]),
         )
     else:
-        from pyiceberg.table import _determine_partitions
-
         partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)
         yield from write_file(
             io=io,
@@ -2143,3 +2141,100 @@ def _dataframe_to_data_files(
                 for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
             ]),
         )
+
+
+@dataclass(frozen=True)
+class _TablePartition:
+    partition_key: PartitionKey
+    arrow_table_partition: pa.Table
+
+
+def _get_table_partitions(
+    arrow_table: pa.Table,
+    partition_spec: PartitionSpec,
+    schema: Schema,
+    slice_instructions: list[dict[str, Any]],
+) -> list[_TablePartition]:
+    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])
+
+    partition_fields = partition_spec.fields
+
+    offsets = [inst["offset"] for inst in sorted_slice_instructions]
+    projected_and_filtered = {
+        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
+        .take(offsets)
+        .to_pylist()
+        for partition_field in partition_fields
+    }
+
+    table_partitions = []
+    for idx, inst in enumerate(sorted_slice_instructions):
+        partition_slice = arrow_table.slice(**inst)
+        fieldvalues = [
+            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
+            for partition_field in partition_fields
+        ]
+        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
+        table_partitions.append(_TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
+    return table_partitions
+
+
+def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
+    """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
+
+    Example:
+    Input:
+    An arrow table with partition key of ['n_legs', 'year'] and with data of
+    {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
+     'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
+     'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse","Brittle stars", "Centipede"]}.
+    The algorithm:
+    Firstly we group the rows into partitions by sorting with sort order [('n_legs', 'descending'), ('year', 'descending')]
+    and null_placement of "at_end".
+    This gives the same table as raw input.
+    Then we sort_indices using reverse order of [('n_legs', 'descending'), ('year', 'descending')]
+    and null_placement : "at_start".
+    This gives:
+    [8, 7, 4, 5, 6, 3, 1, 2, 0]
+    Based on this we get partition groups of indices:
+    [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
+    We then retrieve the partition keys by offsets.
+    And slice the arrow table by offsets and lengths of each partition.
+    """
+    partition_columns: List[Tuple[PartitionField, NestedField]] = [
+        (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
+    ]
+    partition_values_table = pa.table({
+        str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
+        for partition, field in partition_columns
+    })
+
+    # Sort by partitions
+    sort_indices = pa.compute.sort_indices(
+        partition_values_table,
+        sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
+        null_placement="at_end",
+    ).to_pylist()
+    arrow_table = arrow_table.take(sort_indices)
+
+    # Get slice_instructions to group by partitions
+    partition_values_table = partition_values_table.take(sort_indices)
+    reversed_indices = pa.compute.sort_indices(
+        partition_values_table,
+        sort_keys=[(col, "descending") for col in partition_values_table.column_names],
+        null_placement="at_start",
+    ).to_pylist()
+    slice_instructions: List[Dict[str, Any]] = []
+    last = len(reversed_indices)
+    reversed_indices_size = len(reversed_indices)
+    ptr = 0
+    while ptr < reversed_indices_size:
+        group_size = last - reversed_indices[ptr]
+        offset = reversed_indices[ptr]
+        slice_instructions.append({"offset": offset, "length": group_size})
+        last = reversed_indices[ptr]
+        ptr = ptr + group_size
+
+    table_partitions: List[_TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
+
+    return table_partitions
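As a quick sanity check (not part of this commit), the grouping loop in _determine_partitions can be run in isolation against the reverse-sorted indices from the docstring example above; the helper name group_slices below is hypothetical and only mirrors the loop added in this diff.

from typing import Any, Dict, List


def group_slices(reversed_indices: List[int]) -> List[Dict[str, Any]]:
    # Each time the reverse-sorted index drops, a new partition group starts;
    # the size of the drop is the length of that group.
    slice_instructions: List[Dict[str, Any]] = []
    last = len(reversed_indices)
    ptr = 0
    while ptr < len(reversed_indices):
        group_size = last - reversed_indices[ptr]
        offset = reversed_indices[ptr]
        slice_instructions.append({"offset": offset, "length": group_size})
        last = offset
        ptr += group_size
    return slice_instructions


# Indices from the docstring example above.
print(group_slices([8, 7, 4, 5, 6, 3, 1, 2, 0]))
# [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3},
#  {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]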

pyiceberg/table/__init__.py

Lines changed: 0 additions & 100 deletions
@@ -92,7 +92,6 @@
     PARTITION_FIELD_ID_START,
     UNPARTITIONED_PARTITION_SPEC,
     PartitionField,
-    PartitionFieldValue,
     PartitionKey,
     PartitionSpec,
     _PartitionNameGenerator,
@@ -4412,105 +4411,6 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
     )


-@dataclass(frozen=True)
-class TablePartition:
-    partition_key: PartitionKey
-    arrow_table_partition: pa.Table
-
-
-def _get_table_partitions(
-    arrow_table: pa.Table,
-    partition_spec: PartitionSpec,
-    schema: Schema,
-    slice_instructions: list[dict[str, Any]],
-) -> list[TablePartition]:
-    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])
-
-    partition_fields = partition_spec.fields
-
-    offsets = [inst["offset"] for inst in sorted_slice_instructions]
-    projected_and_filtered = {
-        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
-        .take(offsets)
-        .to_pylist()
-        for partition_field in partition_fields
-    }
-
-    table_partitions = []
-    for idx, inst in enumerate(sorted_slice_instructions):
-        partition_slice = arrow_table.slice(**inst)
-        fieldvalues = [
-            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
-            for partition_field in partition_fields
-        ]
-        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
-        table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
-    return table_partitions
-
-
-def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[TablePartition]:
-    """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
-
-    Example:
-    Input:
-    An arrow table with partition key of ['n_legs', 'year'] and with data of
-    {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
-     'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
-     'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse","Brittle stars", "Centipede"]}.
-    The algorithm:
-    Firstly we group the rows into partitions by sorting with sort order [('n_legs', 'descending'), ('year', 'descending')]
-    and null_placement of "at_end".
-    This gives the same table as raw input.
-    Then we sort_indices using reverse order of [('n_legs', 'descending'), ('year', 'descending')]
-    and null_placement : "at_start".
-    This gives:
-    [8, 7, 4, 5, 6, 3, 1, 2, 0]
-    Based on this we get partition groups of indices:
-    [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
-    We then retrieve the partition keys by offsets.
-    And slice the arrow table by offsets and lengths of each partition.
-    """
-    import pyarrow as pa
-
-    partition_columns: List[Tuple[PartitionField, NestedField]] = [
-        (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
-    ]
-    partition_values_table = pa.table({
-        str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
-        for partition, field in partition_columns
-    })
-
-    # Sort by partitions
-    sort_indices = pa.compute.sort_indices(
-        partition_values_table,
-        sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
-        null_placement="at_end",
-    ).to_pylist()
-    arrow_table = arrow_table.take(sort_indices)
-
-    # Get slice_instructions to group by partitions
-    partition_values_table = partition_values_table.take(sort_indices)
-    reversed_indices = pa.compute.sort_indices(
-        partition_values_table,
-        sort_keys=[(col, "descending") for col in partition_values_table.column_names],
-        null_placement="at_start",
-    ).to_pylist()
-    slice_instructions: List[Dict[str, Any]] = []
-    last = len(reversed_indices)
-    reversed_indices_size = len(reversed_indices)
-    ptr = 0
-    while ptr < reversed_indices_size:
-        group_size = last - reversed_indices[ptr]
-        offset = reversed_indices[ptr]
-        slice_instructions.append({"offset": offset, "length": group_size})
-        last = reversed_indices[ptr]
-        ptr = ptr + group_size
-
-    table_partitions: List[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
-
-    return table_partitions
-
-
 class _ManifestMergeManager(Generic[U]):
     _target_size_bytes: int
     _min_count_to_merge: int
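A hedged note for downstream code (not part of the diff): the helpers removed here were previously defined in pyiceberg.table as TablePartition and _determine_partitions; after this commit the equivalent, underscore-prefixed definitions live in pyiceberg.io.pyarrow.

# Old location (removed above):
# from pyiceberg.table import TablePartition, _determine_partitions

# New location after this commit (private to the PyArrow IO module):
from pyiceberg.io.pyarrow import _TablePartition, _determine_partitions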

tests/io/test_pyarrow.py

Lines changed: 82 additions & 2 deletions
@@ -61,6 +61,7 @@
     PyArrowFileIO,
     StatsAggregator,
     _ConvertToArrowSchema,
+    _determine_partitions,
     _primitive_to_physical,
     _read_deletes,
     bin_pack_arrow_table,
@@ -69,11 +70,12 @@
     schema_to_pyarrow,
 )
 from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
-from pyiceberg.partitioning import PartitionSpec
+from pyiceberg.partitioning import PartitionField, PartitionSpec
 from pyiceberg.schema import Schema, make_compatible_name, visit
 from pyiceberg.table import FileScanTask, TableProperties
 from pyiceberg.table.metadata import TableMetadataV2
-from pyiceberg.typedef import UTF8
+from pyiceberg.transforms import IdentityTransform
+from pyiceberg.typedef import UTF8, Record
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
@@ -1718,3 +1720,81 @@ def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None:
     # and will produce half the number of files if we double the target size
     bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=arrow_table_with_null.nbytes * 2)
     assert len(list(bin_packed)) == 5
+
+
+def test_partition_for_demo() -> None:
+    test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
+    test_schema = Schema(
+        NestedField(field_id=1, name="year", field_type=StringType(), required=False),
+        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
+        schema_id=1,
+    )
+    test_data = {
+        "year": [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021],
+        "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
+        "animal": ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"],
+    }
+    arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
+        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
+    )
+    result = _determine_partitions(partition_spec, test_schema, arrow_table)
+    assert {table_partition.partition_key.partition for table_partition in result} == {
+        Record(n_legs_identity=2, year_identity=2020),
+        Record(n_legs_identity=100, year_identity=2021),
+        Record(n_legs_identity=4, year_identity=2021),
+        Record(n_legs_identity=4, year_identity=2022),
+        Record(n_legs_identity=2, year_identity=2022),
+        Record(n_legs_identity=5, year_identity=2019),
+    }
+    assert (
+        pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).num_rows == arrow_table.num_rows
+    )
+
+
+def test_identity_partition_on_multi_columns() -> None:
+    test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
+    test_schema = Schema(
+        NestedField(field_id=1, name="born_year", field_type=StringType(), required=False),
+        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
+        schema_id=1,
+    )
+    # 5 partitions, 6 unique row values, 12 rows
+    test_rows = [
+        (2021, 4, "Dog"),
+        (2022, 4, "Horse"),
+        (2022, 4, "Another Horse"),
+        (2021, 100, "Centipede"),
+        (None, 4, "Kirin"),
+        (2021, None, "Fish"),
+    ] * 2
+    expected = {Record(n_legs_identity=test_rows[i][1], year_identity=test_rows[i][0]) for i in range(len(test_rows))}
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
+        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
+    )
+    import random
+
+    # there are 12! / ((2!)^6) = 7,484,400 permutations, too many to pick all
+    for _ in range(1000):
+        random.shuffle(test_rows)
+        test_data = {
+            "born_year": [row[0] for row in test_rows],
+            "n_legs": [row[1] for row in test_rows],
+            "animal": [row[2] for row in test_rows],
+        }
+        arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
+
+        result = _determine_partitions(partition_spec, test_schema, arrow_table)
+
+        assert {table_partition.partition_key.partition for table_partition in result} == expected
+        concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result])
+        assert concatenated_arrow_table.num_rows == arrow_table.num_rows
+        assert concatenated_arrow_table.sort_by([
+            ("born_year", "ascending"),
+            ("n_legs", "ascending"),
+            ("animal", "ascending"),
+        ]) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")])
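A minimal usage sketch of what the new tests exercise (assumptions: pyiceberg at this commit plus pyarrow installed; the schema, spec, and data below are made up for illustration): each returned _TablePartition pairs a partition key Record with the matching slice of the Arrow table.

import pyarrow as pa

from pyiceberg.io.pyarrow import _determine_partitions
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import IntegerType, NestedField, StringType

schema = Schema(
    NestedField(field_id=1, name="year", field_type=IntegerType(), required=False),
    NestedField(field_id=2, name="animal", field_type=StringType(), required=False),
    schema_id=1,
)
spec = PartitionSpec(PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"))
arrow_table = pa.table({"year": [2020, 2021, 2021], "animal": ["Flamingo", "Dog", "Horse"]})

for partition in _determine_partitions(spec, schema, arrow_table):
    # Prints each partition key Record and the number of rows in its slice.
    print(partition.partition_key.partition, partition.arrow_table_partition.num_rows)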
