Skip to content

Commit c58f9d7

Browse files
committed
WIP: Add an end-to-end GGNN test.
1 parent 068c829 commit c58f9d7

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

programl/task/dataflow/BUILD

+10
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,16 @@ py_binary(
161161
],
162162
)
163163

164+
py_test(
165+
name = "train_ggnn_test",
166+
srcs = ["train_ggnn_test.py"],
167+
data = ["//programl/test/data:reachability_dataflow_dataset"],
168+
deps = [
169+
":train_ggnn",
170+
"//third_party/py/labm8",
171+
],
172+
)
173+
164174
py_binary(
165175
name = "train_lstm",
166176
srcs = ["train_lstm.py"],
+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2019-2020 the ProGraML authors.
2+
#
3+
# Contact Chris Cummins <[email protected]>.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
import subprocess
17+
import sys
18+
19+
from labm8.py import bazelutil, test
20+
21+
TRAIN_GGNN = bazelutil.DataPath("programl/programl/task/dataflow/train_ggnn")
22+
23+
24+
REACHABILITY_DATAFLOW_DATASET = bazelutil.DataArchive(
25+
"programl/programl/test/data/reachability_dataflow_dataset.tar.bz2"
26+
)
27+
28+
29+
def test_reachability_end_to_end():
30+
with REACHABILITY_DATAFLOW_DATASET as d:
31+
p = subprocess.Popen(
32+
[
33+
TRAIN_GGNN,
34+
"--path",
35+
str(d),
36+
"--analysis",
37+
"reachability",
38+
"--max_data_flow_steps",
39+
str(10),
40+
"--val_graph_count",
41+
str(10),
42+
"--val_seed",
43+
str(0xCC),
44+
"--train_graph_counts",
45+
"10,20",
46+
"--batch_size",
47+
str(8),
48+
]
49+
)
50+
p.communicate()
51+
if p.returncode:
52+
sys.exit(1)
53+
54+
55+
if __name__ == "__main__":
56+
test.Main()

0 commit comments

Comments
 (0)