Skip to content

Commit 5ec2bcd

Browse files
Unnest2 (infiniflow#2514)
### What problem does this PR solve? Support group by unnest column and filter unnest column. Add unnest complext type. Add more unnest test. Add unnest pysdk test and docs. Fix test case in infiniflow#2509. ### Type of change - [x] Fix bug. - [x] New Feature (non-breaking change which adds functionality) - [x] Test cases
1 parent 8607e3d commit 5ec2bcd

24 files changed

+790
-203
lines changed

docs/references/pysdk_api_reference.md

+8
Original file line numberDiff line numberDiff line change
@@ -1718,6 +1718,7 @@ A non-empty list of strings specifying the columns to include in the output. Eac
17181718
- `sum`
17191719
- `avg`
17201720
- An arithmetic function: Apply an arithmetic operation on specified columns (e.g., `c1+5`).
1721+
- An unnest function: Unnest an array column to multiple rows (e.g., `unnest(c1)`).
17211722

17221723
:::tip NOTE
17231724
The list must contain at least one element. Empty lists are not allowed.
@@ -1750,6 +1751,13 @@ table_object.output(["num", "body"]).to_df()
17501751
table_object.output(["_row_id"]).to_pl()
17511752
```
17521753

1754+
##### Select unnest columns to
1755+
1756+
```python
1757+
# Select column "c1" and unnest its cells
1758+
table_object.output(["unnest(c1)"]).to_pl()
1759+
```
1760+
17531761
##### Perform aggregation or arithmetic operations on selected columns
17541762

17551763
```python

python/test_pysdk/test_unnest.py

+161
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import importlib
2+
import sys
3+
import os
4+
import os
5+
import pandas as pd
6+
import pytest
7+
from common import common_values
8+
import infinity
9+
import infinity.index as index
10+
import infinity_embedded
11+
from numpy import dtype
12+
from infinity.errors import ErrorCode
13+
from infinity.common import ConflictType, SortType, Array
14+
15+
current_dir = os.path.dirname(os.path.abspath(__file__))
16+
parent_dir = os.path.dirname(current_dir)
17+
if parent_dir not in sys.path:
18+
sys.path.insert(0, parent_dir)
19+
from infinity_http import infinity_http
20+
from common.utils import copy_data
21+
from datetime import date, time, datetime
22+
23+
24+
@pytest.fixture(scope="class")
25+
def local_infinity(request):
26+
return request.config.getoption("--local-infinity")
27+
28+
29+
@pytest.fixture(scope="class")
30+
def http(request):
31+
return request.config.getoption("--http")
32+
33+
34+
@pytest.fixture(scope="class")
35+
def setup_class(request, local_infinity, http):
36+
if local_infinity:
37+
module = importlib.import_module("infinity_embedded.index")
38+
globals()["index"] = module
39+
module = importlib.import_module("infinity_embedded.common")
40+
func = getattr(module, "ConflictType")
41+
globals()["ConflictType"] = func
42+
func = getattr(module, "InfinityException")
43+
globals()["InfinityException"] = func
44+
uri = common_values.TEST_LOCAL_PATH
45+
request.cls.infinity_obj = infinity_embedded.connect(uri)
46+
elif http:
47+
uri = common_values.TEST_LOCAL_HOST
48+
request.cls.infinity_obj = infinity_http()
49+
else:
50+
uri = common_values.TEST_LOCAL_HOST
51+
request.cls.infinity_obj = infinity.connect(uri)
52+
request.cls.uri = uri
53+
yield
54+
request.cls.infinity_obj.disconnect()
55+
56+
57+
@pytest.mark.usefixtures("setup_class")
58+
@pytest.mark.usefixtures("suffix")
59+
class TestInfinity:
60+
61+
# test/sql/dql/unnest/test_unnest.slt
62+
@pytest.mark.usefixtures("skip_if_http")
63+
@pytest.mark.usefixtures("skip_if_local_infinity")
64+
def test_unnest(self, suffix):
65+
db_obj = self.infinity_obj.get_database("default_db")
66+
67+
table_name = "test_unnest" + suffix
68+
db_obj.drop_table(table_name, conflict_type=ConflictType.Ignore)
69+
table_obj = db_obj.create_table(
70+
table_name, {"c1": {"type": "int"}, "c2": {"type": "array, int"}}
71+
)
72+
table_obj.insert(
73+
[
74+
{"c1": 1, "c2": Array(0, 1)},
75+
{"c1": 2, "c2": Array(2, 3)},
76+
{"c1": 3, "c2": Array(0, 1, 2)},
77+
{"c1": 4, "c2": Array(0, 2, 3)},
78+
{"c1": 5, "c2": Array()},
79+
]
80+
)
81+
82+
res, extra_result = table_obj.output(["unnest(c2)"]).to_df()
83+
gt = pd.DataFrame({"unnest(c2)": [0, 0, 0, 1, 1, 2, 2, 2, 3, 3]}).astype(
84+
dtype("int32")
85+
)
86+
pd.testing.assert_frame_equal(
87+
res.sort_values(by=res.columns.tolist()).reset_index(drop=True), gt
88+
)
89+
90+
res, extra_result = table_obj.output(["c1", "unnest(c2)"]).to_df()
91+
gt = pd.DataFrame(
92+
{
93+
"c1": [1, 1, 2, 2, 3, 3, 3, 4, 4, 4],
94+
"unnest(c2)": [0, 1, 2, 3, 0, 1, 2, 0, 2, 3],
95+
}
96+
).astype({"c1": dtype("int32"), "unnest(c2)": dtype("int32")})
97+
pd.testing.assert_frame_equal(
98+
res.sort_values(by=res.columns.tolist()).reset_index(drop=True), gt
99+
)
100+
101+
res, extra_result = table_obj.output(["c1", "c2", "unnest(c2)"]).to_df()
102+
gt = pd.DataFrame(
103+
{
104+
"c1": [1, 1, 2, 2, 3, 3, 3, 4, 4, 4],
105+
"c2": [
106+
[0, 1],
107+
[0, 1],
108+
[2, 3],
109+
[2, 3],
110+
[0, 1, 2],
111+
[0, 1, 2],
112+
[0, 1, 2],
113+
[0, 2, 3],
114+
[0, 2, 3],
115+
[0, 2, 3],
116+
],
117+
"unnest(c2)": [0, 1, 2, 3, 0, 1, 2, 0, 2, 3],
118+
}
119+
).astype({"c1": dtype("int32"), "unnest(c2)": dtype("int32")})
120+
pd.testing.assert_frame_equal(res, gt)
121+
122+
res, extra_result = (
123+
table_obj.output(["c1", "unnest(c2) as uc2"]).filter("c1 > 2").to_df()
124+
)
125+
gt = pd.DataFrame(
126+
{
127+
"c1": [3, 3, 3, 4, 4, 4],
128+
"uc2": [0, 1, 2, 0, 2, 3],
129+
}
130+
).astype({"c1": dtype("int32"), "uc2": dtype("int32")})
131+
pd.testing.assert_frame_equal(
132+
res.sort_values(by=res.columns.tolist()).reset_index(drop=True), gt
133+
)
134+
135+
res, extra_result = (
136+
table_obj.output(["c1", "unnest(c2) as uc2"]).filter("uc2 > 1").to_df()
137+
)
138+
gt = pd.DataFrame(
139+
{
140+
"c1": [2, 2, 3, 4, 4],
141+
"uc2": [2, 3, 2, 2, 3],
142+
}
143+
).astype({"c1": dtype("int32"), "uc2": dtype("int32")})
144+
pd.testing.assert_frame_equal(
145+
res.sort_values(by=res.columns.tolist()).reset_index(drop=True), gt
146+
)
147+
148+
res, extra_result = (
149+
table_obj.output(["unnest(c2) as uc2", "sum(c1)"]).group_by("uc2").to_df()
150+
)
151+
gt = pd.DataFrame(
152+
{
153+
"uc2": [0, 1, 2, 3],
154+
"sum(c1)": [8, 4, 9, 6],
155+
}
156+
).astype({"uc2": dtype("int32"), "sum(c1)": dtype("int64")})
157+
pd.testing.assert_frame_equal(
158+
res.sort_values(by=res.columns.tolist()).reset_index(drop=True), gt
159+
)
160+
161+
db_obj.drop_table(table_name)

0 commit comments

Comments
 (0)