25
25
from tensorboard_plugin_fairness_indicators import plugin
26
26
from tensorboard_plugin_fairness_indicators import summary_v2
27
27
import six
28
- import tensorflow .compat .v1 as tf
29
- import tensorflow .compat .v2 as tf2
28
+ import tensorflow as tf2
29
+ from tensorflow .keras import layers
30
+ from tensorflow .keras import models
30
31
import tensorflow_model_analysis as tfma
31
- from tensorflow_model_analysis .eval_saved_model .example_trainers import linear_classifier
32
32
from werkzeug import test as werkzeug_test
33
33
from werkzeug import wrappers
34
34
35
35
from tensorboard .backend import application
36
36
from tensorboard .backend .event_processing import plugin_event_multiplexer as event_multiplexer
37
37
from tensorboard .plugins import base_plugin
38
38
39
- tf .enable_eager_execution ()
39
+ Sequential = models .Sequential
40
+ Dense = layers .Dense
41
+
40
42
tf = tf2
41
43
42
44
45
+ # Define keras based linear classifier.
46
+ def create_linear_classifier (model_dir ):
47
+
48
+ inputs = tf .keras .Input (shape = (2 ,))
49
+ outputs = layers .Dense (1 , activation = "sigmoid" )(inputs )
50
+ model = tf .keras .Model (inputs = inputs , outputs = outputs )
51
+
52
+ model .compile (
53
+ optimizer = "adam" , loss = "binary_crossentropy" , metrics = ["accuracy" ]
54
+ )
55
+
56
+ tf .saved_model .save (model , model_dir )
57
+ return model
58
+
59
+
43
60
class PluginTest (tf .test .TestCase ):
44
61
"""Tests for Fairness Indicators plugin server."""
45
62
@@ -74,19 +91,19 @@ def tearDown(self):
74
91
super (PluginTest , self ).tearDown ()
75
92
shutil .rmtree (self ._log_dir , ignore_errors = True )
76
93
77
- def _exportEvalSavedModel (self , classifier ):
94
+ def _export_eval_saved_model (self ):
95
+ """Export the evaluation saved model."""
78
96
temp_eval_export_dir = os .path .join (self .get_temp_dir (), "eval_export_dir" )
79
- _ , eval_export_dir = classifier (None , temp_eval_export_dir )
80
- return eval_export_dir
97
+ return create_linear_classifier (temp_eval_export_dir )
81
98
82
- def _writeTFExamplesToTFRecords (self , examples ):
99
+ def _write_tf_examples_to_tfrecords (self , examples ):
83
100
data_location = os .path .join (self .get_temp_dir (), "input_data.rio" )
84
101
with tf .io .TFRecordWriter (data_location ) as writer :
85
102
for example in examples :
86
103
writer .write (example .SerializeToString ())
87
104
return data_location
88
105
89
- def _makeExample (self , age , language , label ):
106
+ def _make_tf_example (self , age , language , label ):
90
107
example = tf .train .Example ()
91
108
example .features .feature ["age" ].float_list .value [:] = [age ]
92
109
example .features .feature ["language" ].bytes_list .value [:] = [
@@ -112,14 +129,14 @@ def testRoutes(self):
112
129
"foo" : "" .encode ("utf-8" )
113
130
}},
114
131
)
115
- def testIsActive (self , get_random_stub ):
132
+ def testIsActive (self ):
116
133
self .assertTrue (self ._plugin .is_active ())
117
134
118
135
@mock .patch .object (
119
136
event_multiplexer .EventMultiplexer ,
120
137
"PluginRunToTagToContent" ,
121
138
return_value = {})
122
- def testIsInactive (self , get_random_stub ):
139
+ def testIsInactive (self ):
123
140
self .assertFalse (self ._plugin .is_active ())
124
141
125
142
def testIndexJsRoute (self ):
@@ -134,16 +151,15 @@ def testVulcanizedTemplateRoute(self):
134
151
self .assertEqual (200 , response .status_code )
135
152
136
153
def testGetEvalResultsRoute (self ):
137
- model_location = self ._exportEvalSavedModel (
138
- linear_classifier .simple_linear_classifier )
154
+ model_location = self ._export_eval_saved_model () # Call the method
139
155
examples = [
140
- self ._makeExample (age = 3.0 , language = "english" , label = 1.0 ),
141
- self ._makeExample (age = 3.0 , language = "chinese" , label = 0.0 ),
142
- self ._makeExample (age = 4.0 , language = "english" , label = 1.0 ),
143
- self ._makeExample (age = 5.0 , language = "chinese" , label = 1.0 ),
144
- self ._makeExample (age = 5.0 , language = "hindi" , label = 1.0 )
156
+ self ._make_tf_example (age = 3.0 , language = "english" , label = 1.0 ),
157
+ self ._make_tf_example (age = 3.0 , language = "chinese" , label = 0.0 ),
158
+ self ._make_tf_example (age = 4.0 , language = "english" , label = 1.0 ),
159
+ self ._make_tf_example (age = 5.0 , language = "chinese" , label = 1.0 ),
160
+ self ._make_tf_example (age = 5.0 , language = "hindi" , label = 1.0 ),
145
161
]
146
- data_location = self ._writeTFExamplesToTFRecords (examples )
162
+ data_location = self ._write_tf_examples_to_tfrecords (examples )
147
163
_ = tfma .run_model_analysis (
148
164
eval_shared_model = tfma .default_eval_shared_model (
149
165
eval_saved_model_path = model_location , example_weight_key = "age" ),
@@ -155,32 +171,36 @@ def testGetEvalResultsRoute(self):
155
171
self .assertEqual (200 , response .status_code )
156
172
157
173
def testGetEvalResultsFromURLRoute (self ):
158
- model_location = self ._exportEvalSavedModel (
159
- linear_classifier .simple_linear_classifier )
174
+ model_location = self ._export_eval_saved_model () # Call the method
160
175
examples = [
161
- self ._makeExample (age = 3.0 , language = "english" , label = 1.0 ),
162
- self ._makeExample (age = 3.0 , language = "chinese" , label = 0.0 ),
163
- self ._makeExample (age = 4.0 , language = "english" , label = 1.0 ),
164
- self ._makeExample (age = 5.0 , language = "chinese" , label = 1.0 ),
165
- self ._makeExample (age = 5.0 , language = "hindi" , label = 1.0 )
176
+ self ._make_tf_example (age = 3.0 , language = "english" , label = 1.0 ),
177
+ self ._make_tf_example (age = 3.0 , language = "chinese" , label = 0.0 ),
178
+ self ._make_tf_example (age = 4.0 , language = "english" , label = 1.0 ),
179
+ self ._make_tf_example (age = 5.0 , language = "chinese" , label = 1.0 ),
180
+ self ._make_tf_example (age = 5.0 , language = "hindi" , label = 1.0 ),
166
181
]
167
- data_location = self ._writeTFExamplesToTFRecords (examples )
182
+ data_location = self ._write_tf_examples_to_tfrecords (examples )
168
183
_ = tfma .run_model_analysis (
169
184
eval_shared_model = tfma .default_eval_shared_model (
170
185
eval_saved_model_path = model_location , example_weight_key = "age" ),
171
186
data_location = data_location ,
172
187
output_path = self ._eval_result_output_dir )
173
188
174
189
response = self ._server .get (
175
- "/data/plugin/fairness_indicators/" +
176
- "get_evaluation_result_from_remote_path?evaluation_output_path=" +
177
- os .path .join (self ._eval_result_output_dir , tfma .METRICS_KEY ))
190
+ "/data/plugin/fairness_indicators/"
191
+ + "get_evaluation_result_from_remote_path?evaluation_output_path="
192
+ + self ._eval_result_output_dir
193
+ )
178
194
self .assertEqual (200 , response .status_code )
179
195
180
- def testGetOutputFileFormat (self ):
181
- self .assertEqual ("" , self ._plugin ._get_output_file_format ("abc_path" ))
182
- self .assertEqual ("tfrecord" ,
183
- self ._plugin ._get_output_file_format ("abc_path.tfrecord" ))
196
+ def test_get_output_file_format (self ):
197
+ evaluation_output_path = os .path .join (
198
+ self ._eval_result_output_dir , "eval_result.tfrecord"
199
+ )
200
+ self .assertEqual (
201
+ self ._plugin ._get_output_file_format (evaluation_output_path ),
202
+ "tfrecord" ,
203
+ )
184
204
185
205
186
206
if __name__ == "__main__" :
0 commit comments