Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 975a065

Browse files
authoredOct 18, 2024··
Extend TrivialTaskletElimination for map scope (#1650)
Extend the transformation `TrivialTaskletElimination` for the case where the input or output of the copy-tasklet is a map node. The following SDFG: <img width="266" alt="image" src="https://github.com/user-attachments/assets/6e231bbf-d736-4dcf-b132-2e7d59c26ad5"> is transformed to this SDFG: <img width="343" alt="image" src="https://github.com/user-attachments/assets/82ec07b1-6b3d-421f-bca7-5c4b3bd1f320">
1 parent 4fbeba4 commit 975a065

File tree

2 files changed

+160
-17
lines changed

2 files changed

+160
-17
lines changed
 

‎dace/transformation/dataflow/trivial_tasklet_elimination.py‎

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,48 +17,62 @@ class TrivialTaskletElimination(transformation.SingleStateTransformation):
1717
"""
1818

1919
read = transformation.PatternNode(nodes.AccessNode)
20+
read_map = transformation.PatternNode(nodes.MapEntry)
2021
tasklet = transformation.PatternNode(nodes.Tasklet)
2122
write = transformation.PatternNode(nodes.AccessNode)
23+
write_map = transformation.PatternNode(nodes.MapExit)
2224

2325
@classmethod
2426
def expressions(cls):
25-
return [sdutil.node_path_graph(cls.read, cls.tasklet, cls.write)]
27+
return [
28+
sdutil.node_path_graph(cls.read, cls.tasklet, cls.write),
29+
sdutil.node_path_graph(cls.read_map, cls.tasklet, cls.write),
30+
sdutil.node_path_graph(cls.read, cls.tasklet, cls.write_map),
31+
]
2632

2733
def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
28-
read = self.read
34+
read = self.read_map if expr_index == 1 else self.read
2935
tasklet = self.tasklet
30-
write = self.write
31-
# Do not apply on Streams
32-
if isinstance(sdfg.arrays[read.data], data.Stream):
33-
return False
34-
if isinstance(sdfg.arrays[write.data], data.Stream):
36+
write = self.write_map if expr_index == 2 else self.write
37+
if len(tasklet.in_connectors) != 1:
3538
return False
3639
if len(graph.in_edges(tasklet)) != 1:
3740
return False
38-
if len(graph.out_edges(tasklet)) != 1:
39-
return False
40-
if graph.edges_between(tasklet, write)[0].data.wcr:
41-
return False
42-
if len(tasklet.in_connectors) != 1:
43-
return False
4441
if len(tasklet.out_connectors) != 1:
4542
return False
43+
if len(graph.out_edges(tasklet)) != 1:
44+
return False
4645
in_conn = list(tasklet.in_connectors.keys())[0]
4746
out_conn = list(tasklet.out_connectors.keys())[0]
4847
if tasklet.code.as_string != f'{out_conn} = {in_conn}':
4948
return False
50-
49+
read_memlet = graph.edges_between(read, tasklet)[0].data
50+
read_desc = sdfg.arrays[read_memlet.data]
51+
write_memlet = graph.edges_between(tasklet, write)[0].data
52+
if write_memlet.wcr:
53+
return False
54+
write_desc = sdfg.arrays[write_memlet.data]
55+
# Do not apply on streams
56+
if isinstance(read_desc, data.Stream):
57+
return False
58+
if isinstance(write_desc, data.Stream):
59+
return False
60+
# Keep copy-tasklet connected to map node if source and destination nodes
61+
# have different data type (implicit type cast)
62+
if expr_index != 0 and read_desc.dtype != write_desc.dtype:
63+
return False
64+
5165
return True
5266

5367
def apply(self, graph, sdfg):
54-
read = self.read
68+
read = self.read_map if self.expr_index == 1 else self.read
5569
tasklet = self.tasklet
56-
write = self.write
70+
write = self.write_map if self.expr_index == 2 else self.write
5771

5872
in_edge = graph.edges_between(read, tasklet)[0]
5973
out_edge = graph.edges_between(tasklet, write)[0]
6074
graph.remove_edge(in_edge)
6175
graph.remove_edge(out_edge)
6276
out_edge.data.other_subset = in_edge.data.subset
63-
graph.add_nedge(read, write, out_edge.data)
77+
graph.add_edge(read, in_edge.src_conn, write, out_edge.dst_conn, out_edge.data)
6478
graph.remove_node(tasklet)
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
2+
import dace
3+
from dace.transformation.dataflow.trivial_tasklet_elimination import TrivialTaskletElimination
4+
5+
6+
N = 10
7+
8+
9+
def test_trivial_tasklet():
10+
ty_ = dace.int32
11+
sdfg = dace.SDFG("trivial_tasklet")
12+
sdfg.add_symbol("s", ty_)
13+
sdfg.add_array("v", (N,), ty_)
14+
st = sdfg.add_state()
15+
16+
tmp1_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty_, transient=True)
17+
tmp1_node = st.add_access(tmp1_name)
18+
init_tasklet = st.add_tasklet("init", {}, {"out"}, "out = s")
19+
st.add_edge(init_tasklet, "out", tmp1_node, None, dace.Memlet(tmp1_node.data))
20+
21+
tmp2_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty_, transient=True)
22+
tmp2_node = st.add_access(tmp2_name)
23+
copy_tasklet = st.add_tasklet("copy", {"inp"}, {"out"}, "out = inp")
24+
st.add_edge(tmp1_node, None, copy_tasklet, "inp", dace.Memlet(tmp1_node.data))
25+
st.add_edge(copy_tasklet, "out", tmp2_node, None, dace.Memlet(tmp2_node.data))
26+
27+
bcast_tasklet, _, _ = st.add_mapped_tasklet(
28+
"bcast",
29+
dict(i=f"0:{N}"),
30+
inputs={"inp": dace.Memlet(f"{tmp2_node.data}[0]")},
31+
input_nodes={tmp2_node.data: tmp2_node},
32+
code="out = inp",
33+
outputs={"out": dace.Memlet("v[i]")},
34+
external_edges=True,
35+
)
36+
37+
sdfg.validate()
38+
tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)}
39+
assert tasklet_nodes == {init_tasklet, copy_tasklet, bcast_tasklet}
40+
41+
count = sdfg.apply_transformations_repeated(TrivialTaskletElimination)
42+
assert count == 1
43+
44+
assert len(st.out_edges(tmp1_node)) == 1
45+
assert st.out_edges(tmp1_node)[0].dst == tmp2_node
46+
47+
tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)}
48+
assert tasklet_nodes == {init_tasklet, bcast_tasklet}
49+
50+
51+
def test_trivial_tasklet_with_map():
52+
ty_ = dace.int32
53+
sdfg = dace.SDFG("trivial_tasklet_with_map")
54+
sdfg.add_symbol("s", ty_)
55+
sdfg.add_array("v", (N,), ty_)
56+
st = sdfg.add_state()
57+
58+
tmp1_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty_, transient=True)
59+
tmp1_node = st.add_access(tmp1_name)
60+
init_tasklet = st.add_tasklet("init", {}, {"out"}, "out = s")
61+
st.add_edge(init_tasklet, "out", tmp1_node, None, dace.Memlet(tmp1_node.data))
62+
63+
me, mx = st.add_map("bcast", dict(i=f"0:{N}"))
64+
65+
copy_tasklet = st.add_tasklet("copy", {"inp"}, {"out"}, "out = inp")
66+
st.add_memlet_path(tmp1_node, me, copy_tasklet, dst_conn="inp", memlet=dace.Memlet(f"{tmp1_node.data}[0]"))
67+
tmp2_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty_, transient=True)
68+
tmp2_node = st.add_access(tmp2_name)
69+
st.add_edge(copy_tasklet, "out", tmp2_node, None, dace.Memlet(tmp2_node.data))
70+
71+
bcast_tasklet = st.add_tasklet("bcast", {"inp"}, {"out"}, "out = inp")
72+
st.add_edge(tmp2_node, None, bcast_tasklet, "inp", dace.Memlet(tmp2_node.data))
73+
st.add_memlet_path(bcast_tasklet, mx, st.add_access("v"), src_conn="out", memlet=dace.Memlet("v[i]"))
74+
75+
sdfg.validate()
76+
tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)}
77+
assert tasklet_nodes == {init_tasklet, copy_tasklet, bcast_tasklet}
78+
79+
count = sdfg.apply_transformations_repeated(TrivialTaskletElimination)
80+
assert count == 2
81+
82+
tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)}
83+
assert tasklet_nodes == {init_tasklet}
84+
85+
assert len(st.in_edges(tmp2_node)) == 1
86+
assert st.in_edges(tmp2_node)[0].src == me
87+
88+
assert len(st.out_edges(tmp2_node)) == 1
89+
assert st.out_edges(tmp2_node)[0].dst == mx
90+
91+
92+
def test_trivial_tasklet_with_implicit_cast():
93+
ty32_ = dace.int32
94+
ty64_ = dace.int64
95+
sdfg = dace.SDFG("trivial_tasklet_with_implicit_cast")
96+
sdfg.add_symbol("s", ty32_)
97+
sdfg.add_array("v", (N,), ty32_)
98+
st = sdfg.add_state()
99+
100+
tmp1_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty32_, transient=True)
101+
tmp1_node = st.add_access(tmp1_name)
102+
init_tasklet = st.add_tasklet("init", {}, {"out"}, "out = s")
103+
st.add_edge(init_tasklet, "out", tmp1_node, None, dace.Memlet(tmp1_node.data))
104+
105+
me, mx = st.add_map("bcast", dict(i=f"0:{N}"))
106+
107+
copy_tasklet = st.add_tasklet("copy", {"inp"}, {"out"}, "out = inp")
108+
st.add_memlet_path(tmp1_node, me, copy_tasklet, dst_conn="inp", memlet=dace.Memlet(f"{tmp1_node.data}[0]"))
109+
tmp2_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty64_, transient=True)
110+
tmp2_node = st.add_access(tmp2_name)
111+
st.add_edge(copy_tasklet, "out", tmp2_node, None, dace.Memlet(tmp2_node.data))
112+
113+
bcast_tasklet = st.add_tasklet("bcast", {"inp"}, {"out"}, "out = inp")
114+
st.add_edge(tmp2_node, None, bcast_tasklet, "inp", dace.Memlet(tmp2_node.data))
115+
st.add_memlet_path(bcast_tasklet, mx, st.add_access("v"), src_conn="out", memlet=dace.Memlet("v[i]"))
116+
117+
sdfg.validate()
118+
tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)}
119+
assert tasklet_nodes == {init_tasklet, copy_tasklet, bcast_tasklet}
120+
121+
# not applied because of data types mismatch on read/write nodes
122+
count = sdfg.apply_transformations_repeated(TrivialTaskletElimination)
123+
assert count == 0
124+
125+
126+
if __name__ == '__main__':
127+
test_trivial_tasklet()
128+
test_trivial_tasklet_with_map()
129+
test_trivial_tasklet_with_implicit_cast()

0 commit comments

Comments
 (0)
Please sign in to comment.