Skip to content

Commit 74aaa62

Browse files
committed
WIP: Add an end-to-end GGNN test.
1 parent b12bc1c commit 74aaa62

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

programl/task/dataflow/BUILD

+10
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,16 @@ py_binary(
161161
],
162162
)
163163

164+
py_test(
165+
name = "train_ggnn_test",
166+
srcs = ["train_ggnn_test.py"],
167+
data = ["//programl/test/data:reachability_dataflow_dataset"],
168+
deps = [
169+
":train_ggnn",
170+
"//third_party/py/labm8",
171+
],
172+
)
173+
164174
py_binary(
165175
name = "train_lstm",
166176
srcs = ["train_lstm.py"],
+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2019-2020 the ProGraML authors.
2+
#
3+
# Contact Chris Cummins <[email protected]>.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
import subprocess
17+
import sys
18+
19+
from labm8.py import bazelutil, test
20+
21+
TRAIN_GGNN = bazelutil.DataPath("programl/programl/task/dataflow/train_ggnn")
22+
23+
24+
REACHABILITY_DATAFLOW_DATASET = bazelutil.DataArchive(
25+
"programl/programl/test/data/reachability_dataflow_dataset.tar.bz2"
26+
)
27+
28+
29+
def test_reachability_end_to_end():
30+
with REACHABILITY_DATAFLOW_DATASET as d:
31+
p = subprocess.Popen(
32+
[
33+
TRAIN_GGNN,
34+
f"--path={d}",
35+
"--analysis",
36+
"reachability",
37+
"--limit_max_data_flow_steps",
38+
"--layer_timesteps=10",
39+
"--val_graph_count=10",
40+
"--val_seed=204",
41+
"--train_graph_counts=10,20",
42+
"--batch_size=8",
43+
]
44+
)
45+
p.communicate()
46+
if p.returncode:
47+
sys.exit(1)
48+
49+
50+
if __name__ == "__main__":
51+
test.Main()

0 commit comments

Comments
 (0)