-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
136 lines (96 loc) · 3.28 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#! /usr/bin/python
from __future__ import print_function
from workflow import WorkFlow
def print_delimeter(c="=", n=20, title="", leading="\n", ending="\n"):
    """Print a horizontal delimiter line, optionally with a title in the middle.

    The line is made of ``n // 2`` repetitions of ``c`` on each side of the
    title (or back to back when no title is given), so an odd ``n`` yields
    ``n - 1`` repetitions in total.

    Args:
        c: Fill string that is repeated to form the line.
        n: Total number of fill repetitions, split evenly between both sides.
        title: Optional text placed between the two halves, padded with one
            space on each side. An empty title produces a plain line.
        leading: Text printed immediately before the line.
        ending: Text printed immediately after the line (``print`` then adds
            its own trailing newline).
    """
    # Floor division keeps this working on Python 3, where ``n / 2`` is a
    # float and the original ``range(n / 2)`` raised TypeError.
    half = c * (n // 2)

    if title:
        s = half + " " + title + " " + half
    else:
        s = half + half

    print("%s%s%s" % (leading, s, ending))
# Template for custom WorkFlow object.
class MyWF(WorkFlow.WorkFlow):
    """Example ``WorkFlow.WorkFlow`` subclass that records dummy loss values.

    Every overridden stage first delegates to the base-class implementation
    and then runs its own custom code.

    Args:
        workingDir: Working directory forwarded to the base class.
    """

    def __init__(self, workingDir):
        super(MyWF, self).__init__(workingDir, prefix="", suffix="")

        # === Create the AccumulatedObjects. ===
        self.add_accumulated_value("lossTest")

        # This should raise an exception.
        # NOTE(review): presumably because the base class already registers
        # "loss" (the script reads wf.AV["loss"] without adding it) — confirm.
        # self.add_accumulated_value("loss")

        # === Custom member variables. ===
        self.countTrain = 0  # Index passed to push_back() for "loss".
        self.countTest = 0   # Index passed to push_back() for "lossTest".

    # Overload the function initialize().
    def initialize(self):
        """Run the base-class initialization, then log completion."""
        super(MyWF, self).initialize()

        # === Custom code. ===
        self.logger.info("Initialized.")

    # Overload the function train().
    def train(self):
        """Run one training step and record a dummy "loss" value.

        The value is pushed at index ``countTrain``, which is then advanced.
        If no "loss" accumulated value exists, a message is logged instead.
        """
        super(MyWF, self).train()

        # === Custom code. ===
        self.logger.info("Train loop #%d" % self.countTrain)

        # Test the existence of an AccumulatedValue object.
        if self.have_accumulated_value("loss"):
            self.AV["loss"].push_back(0.01, self.countTrain)
        else:
            self.logger.info("Could not find \"loss\"")

        self.countTrain += 1

        self.logger.info("Trained.")

    # Overload the function test().
    def test(self):
        """Run one test step and record a dummy "lossTest" value.

        The value is pushed at index ``countTest``, which is then advanced,
        mirroring ``train()``. If no "lossTest" accumulated value exists, a
        message is logged instead.
        """
        super(MyWF, self).test()

        # === Custom code. ===
        # Test the existence of an AccumulatedValue object.
        if self.have_accumulated_value("lossTest"):
            self.AV["lossTest"].push_back(0.01, self.countTest)
        else:
            self.logger.info("Could not find \"lossTest\"")

        # Advance the index like train() does; previously countTest was never
        # incremented, so every test() call pushed at index 0.
        self.countTest += 1

        self.logger.info("Tested.")

    # Overload the function finalize().
    def finalize(self):
        """Run the base-class finalization, then log completion."""
        super(MyWF, self).finalize()

        # === Custom code. ===
        self.logger.info("Finalized.")
if __name__ == "__main__":
    print("Hello WorkFlow.")

    print_delimeter(title="Before initialization.")

    # Instantiate an object for MyWF.
    wf = MyWF("/tmp/WorkFlowDir")
    wf.verbose = True

    # Calling these stages before initialize() is expected to trigger a
    # WFException; print the description of each one.
    for stage in (wf.train, wf.test, wf.finalize):
        try:
            stage()
        except WorkFlow.WFException as exp:
            print(exp.describe())

    # Actual initialization.
    print_delimeter(title="Initialize.")
    wf.initialize()

    # A second initialize() is expected to trigger an exception.
    try:
        wf.initialize()
    except WorkFlow.WFException as exp:
        print(exp.describe())

    # Training loop.
    print_delimeter(title="Loop.")
    for _ in range(5):
        wf.train()

    # Test and finalize.
    print_delimeter(title="Test and finalize.")
    wf.test()
    wf.finalize()

    # Show the accumulated values.
    print_delimeter(title="Accumulated values.")
    wf.AV["loss"].show_raw_data()

    print_delimeter()

    wf.AV["lossTest"].show_raw_data()

    print("Done.")