-
Notifications
You must be signed in to change notification settings - Fork 112
/
Copy pathtest_datachain_merge.py
94 lines (79 loc) · 2.9 KB
/
test_datachain_merge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import pytest
from datachain.lib.dc import DataChain
@pytest.mark.parametrize(
    "cloud_type,version_aware",
    [("s3", True)],
    indirect=True,
)
@pytest.mark.parametrize("inner", [True, False])
def test_merge_union(cloud_test_catalog, inner, cloud_type):
    """Merge a union of two chains with a third chain on ``file.path``.

    The left side is the union of the dogs and cats listings, each carrying
    ``sig1``; the right side is the dogs listing carrying ``sig2``. With
    ``inner=True`` only the dog rows are expected; otherwise the cat rows
    are expected as well, with ``sig2`` set to ``None``.
    """
    sess = cloud_test_catalog.session
    root = cloud_test_catalog.src_uri

    dog_files = DataChain.from_storage(f"{root}/dogs/*", session=sess)
    cat_files = DataChain.from_storage(f"{root}/cats/*", session=sess)

    dogs_sig1 = dog_files.map(sig1=lambda: 1, output={"sig1": int})
    dogs_sig2 = dog_files.map(sig2=lambda: 2, output={"sig2": int})
    cats_sig1 = cat_files.map(sig1=lambda: 1, output={"sig1": int})

    left_side = dogs_sig1 | cats_sig1
    joined = left_side.merge(dogs_sig2, "file.path", inner=inner)
    projected = joined.select("file.path", "sig1", "sig2")
    rows = projected.order_by("file.path").results()

    # Rows are ordered by path, so the cat rows precede the dog rows.
    dog_rows = [
        ("dogs/dog1", 1, 2),
        ("dogs/dog2", 1, 2),
        ("dogs/dog3", 1, 2),
        ("dogs/others/dog4", 1, 2),
    ]
    cat_rows = [
        ("cats/cat1", 1, None),
        ("cats/cat2", 1, None),
    ]
    expected = dog_rows if inner else cat_rows + dog_rows

    assert rows == expected
@pytest.mark.parametrize(
    "cloud_type,version_aware",
    [("s3", True)],
    indirect=True,
)
@pytest.mark.parametrize("inner1", [True, False])
@pytest.mark.parametrize("inner2", [True, False])
@pytest.mark.parametrize("inner3", [True, False])
def test_merge_multiple(cloud_test_catalog, inner1, inner2, inner3):
    """Chain three merges on ``file.path`` over every inner-flag combination.

    The dogs+cats union is merged once with dogs carrying ``sig1``
    (``inner1``) and once with cats carrying ``sig2`` (``inner2``); the two
    results are then merged with each other (``inner3``). The expected rows
    depend on which of the three flags are set.
    """
    sess = cloud_test_catalog.session
    root = cloud_test_catalog.src_uri

    dog_files = DataChain.from_storage(f"{root}/dogs/*", session=sess)
    cat_files = DataChain.from_storage(f"{root}/cats/*", session=sess)
    everything = dog_files | cat_files

    dogs_sig1 = dog_files.map(sig1=lambda: 1, output={"sig1": int})
    cats_sig2 = cat_files.map(sig2=lambda: 2, output={"sig2": int})

    with_sig1 = everything.merge(dogs_sig1, "file.path", inner=inner1)
    with_sig2 = everything.merge(cats_sig2, "file.path", inner=inner2)
    combined = with_sig1.merge(with_sig2, "file.path", inner=inner3)

    projected = combined.select("file.path", "sig1", "sig2")
    rows = projected.order_by("file.path").results()

    # Rows are ordered by path, so the cat rows precede the dog rows.
    cat_rows = [
        ("cats/cat1", None, 2),
        ("cats/cat2", None, 2),
    ]
    dog_rows = [
        ("dogs/dog1", 1, None),
        ("dogs/dog2", 1, None),
        ("dogs/dog3", 1, None),
        ("dogs/others/dog4", 1, None),
    ]

    if inner1:
        # Only dog rows can remain; with all three flags set nothing remains.
        expected = [] if (inner2 and inner3) else dog_rows
    elif inner2 and inner3:
        expected = cat_rows
    else:
        expected = cat_rows + dog_rows

    assert rows == expected