25
25
from tensorboard_plugin_fairness_indicators import plugin
26
26
from tensorboard_plugin_fairness_indicators import summary_v2
27
27
import six
28
- import tensorflow . compat . v1 as tf
29
- import tensorflow .compat . v2 as tf2
28
+ import tensorflow as tf2
29
+ from tensorflow .keras import layers
30
30
import tensorflow_model_analysis as tfma
31
- from tensorflow_model_analysis .eval_saved_model .example_trainers import linear_classifier
32
31
from werkzeug import test as werkzeug_test
33
32
from werkzeug import wrappers
34
33
35
34
from tensorboard .backend import application
36
35
from tensorboard .backend .event_processing import plugin_event_multiplexer as event_multiplexer
37
36
from tensorboard .plugins import base_plugin
38
37
39
- tf .enable_eager_execution ()
40
38
tf = tf2
41
39
42
40
41
+ # Define keras based linear classifier.
42
+ def create_linear_classifier (model_dir ):
43
+
44
+ inputs = tf .keras .Input (shape = (2 ,))
45
+ outputs = layers .Dense (1 , activation = "sigmoid" )(inputs )
46
+ model = tf .keras .Model (inputs = inputs , outputs = outputs )
47
+
48
+ model .compile (
49
+ optimizer = "adam" , loss = "binary_crossentropy" , metrics = ["accuracy" ]
50
+ )
51
+
52
+ tf .saved_model .save (model , model_dir )
53
+ return model
54
+
55
+
43
56
class PluginTest (tf .test .TestCase ):
44
57
"""Tests for Fairness Indicators plugin server."""
45
58
@@ -74,19 +87,19 @@ def tearDown(self):
74
87
super (PluginTest , self ).tearDown ()
75
88
shutil .rmtree (self ._log_dir , ignore_errors = True )
76
89
77
- def _exportEvalSavedModel (self , classifier ):
90
+ def _export_eval_saved_model (self ):
91
+ """Export the evaluation saved model."""
78
92
temp_eval_export_dir = os .path .join (self .get_temp_dir (), "eval_export_dir" )
79
- _ , eval_export_dir = classifier (None , temp_eval_export_dir )
80
- return eval_export_dir
93
+ return create_linear_classifier (temp_eval_export_dir )
81
94
82
- def _writeTFExamplesToTFRecords (self , examples ):
95
+ def _write_tf_examples_to_tfrecords (self , examples ):
83
96
data_location = os .path .join (self .get_temp_dir (), "input_data.rio" )
84
97
with tf .io .TFRecordWriter (data_location ) as writer :
85
98
for example in examples :
86
99
writer .write (example .SerializeToString ())
87
100
return data_location
88
101
89
- def _makeExample (self , age , language , label ):
102
+ def _make_tf_example (self , age , language , label ):
90
103
example = tf .train .Example ()
91
104
example .features .feature ["age" ].float_list .value [:] = [age ]
92
105
example .features .feature ["language" ].bytes_list .value [:] = [
@@ -112,14 +125,14 @@ def testRoutes(self):
112
125
"foo" : "" .encode ("utf-8" )
113
126
}},
114
127
)
115
- def testIsActive (self , get_random_stub ):
128
+ def testIsActive (self ):
116
129
self .assertTrue (self ._plugin .is_active ())
117
130
118
131
@mock .patch .object (
119
132
event_multiplexer .EventMultiplexer ,
120
133
"PluginRunToTagToContent" ,
121
134
return_value = {})
122
- def testIsInactive (self , get_random_stub ):
135
+ def testIsInactive (self ):
123
136
self .assertFalse (self ._plugin .is_active ())
124
137
125
138
def testIndexJsRoute (self ):
@@ -134,16 +147,15 @@ def testVulcanizedTemplateRoute(self):
134
147
self .assertEqual (200 , response .status_code )
135
148
136
149
def testGetEvalResultsRoute (self ):
137
- model_location = self ._exportEvalSavedModel (
138
- linear_classifier .simple_linear_classifier )
150
+ model_location = self ._export_eval_saved_model () # Call the method
139
151
examples = [
140
- self ._makeExample (age = 3.0 , language = "english" , label = 1.0 ),
141
- self ._makeExample (age = 3.0 , language = "chinese" , label = 0.0 ),
142
- self ._makeExample (age = 4.0 , language = "english" , label = 1.0 ),
143
- self ._makeExample (age = 5.0 , language = "chinese" , label = 1.0 ),
144
- self ._makeExample (age = 5.0 , language = "hindi" , label = 1.0 )
152
+ self ._make_tf_example (age = 3.0 , language = "english" , label = 1.0 ),
153
+ self ._make_tf_example (age = 3.0 , language = "chinese" , label = 0.0 ),
154
+ self ._make_tf_example (age = 4.0 , language = "english" , label = 1.0 ),
155
+ self ._make_tf_example (age = 5.0 , language = "chinese" , label = 1.0 ),
156
+ self ._make_tf_example (age = 5.0 , language = "hindi" , label = 1.0 ),
145
157
]
146
- data_location = self ._writeTFExamplesToTFRecords (examples )
158
+ data_location = self ._write_tf_examples_to_tfrecords (examples )
147
159
_ = tfma .run_model_analysis (
148
160
eval_shared_model = tfma .default_eval_shared_model (
149
161
eval_saved_model_path = model_location , example_weight_key = "age" ),
@@ -155,32 +167,36 @@ def testGetEvalResultsRoute(self):
155
167
self .assertEqual (200 , response .status_code )
156
168
157
169
def testGetEvalResultsFromURLRoute (self ):
158
- model_location = self ._exportEvalSavedModel (
159
- linear_classifier .simple_linear_classifier )
170
+ model_location = self ._export_eval_saved_model () # Call the method
160
171
examples = [
161
- self ._makeExample (age = 3.0 , language = "english" , label = 1.0 ),
162
- self ._makeExample (age = 3.0 , language = "chinese" , label = 0.0 ),
163
- self ._makeExample (age = 4.0 , language = "english" , label = 1.0 ),
164
- self ._makeExample (age = 5.0 , language = "chinese" , label = 1.0 ),
165
- self ._makeExample (age = 5.0 , language = "hindi" , label = 1.0 )
172
+ self ._make_tf_example (age = 3.0 , language = "english" , label = 1.0 ),
173
+ self ._make_tf_example (age = 3.0 , language = "chinese" , label = 0.0 ),
174
+ self ._make_tf_example (age = 4.0 , language = "english" , label = 1.0 ),
175
+ self ._make_tf_example (age = 5.0 , language = "chinese" , label = 1.0 ),
176
+ self ._make_tf_example (age = 5.0 , language = "hindi" , label = 1.0 ),
166
177
]
167
- data_location = self ._writeTFExamplesToTFRecords (examples )
178
+ data_location = self ._write_tf_examples_to_tfrecords (examples )
168
179
_ = tfma .run_model_analysis (
169
180
eval_shared_model = tfma .default_eval_shared_model (
170
181
eval_saved_model_path = model_location , example_weight_key = "age" ),
171
182
data_location = data_location ,
172
183
output_path = self ._eval_result_output_dir )
173
184
174
185
response = self ._server .get (
175
- "/data/plugin/fairness_indicators/" +
176
- "get_evaluation_result_from_remote_path?evaluation_output_path=" +
177
- os .path .join (self ._eval_result_output_dir , tfma .METRICS_KEY ))
186
+ "/data/plugin/fairness_indicators/"
187
+ + "get_evaluation_result_from_remote_path?evaluation_output_path="
188
+ + self ._eval_result_output_dir
189
+ )
178
190
self .assertEqual (200 , response .status_code )
179
191
180
- def testGetOutputFileFormat (self ):
181
- self .assertEqual ("" , self ._plugin ._get_output_file_format ("abc_path" ))
182
- self .assertEqual ("tfrecord" ,
183
- self ._plugin ._get_output_file_format ("abc_path.tfrecord" ))
192
+ def test_get_output_file_format (self ):
193
+ evaluation_output_path = os .path .join (
194
+ self ._eval_result_output_dir , "eval_result.tfrecord"
195
+ )
196
+ self .assertEqual (
197
+ self ._plugin ._get_output_file_format (evaluation_output_path ),
198
+ "tfrecord" ,
199
+ )
184
200
185
201
186
202
if __name__ == "__main__" :
0 commit comments