Skip to content

Commit 4f844df

Browse files
vkarampudiResponsible ML Infra Team
authored and
Responsible ML Infra Team
committedJan 8, 2025·
Remove unused keyword arguments to Keras Model.save and Model.load.
PiperOrigin-RevId: 707577369
1 parent fd265d4 commit 4f844df

File tree

3 files changed

+56
-39
lines changed

3 files changed

+56
-39
lines changed
 

‎tensorboard_plugin/setup.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ def select_constraint(default, nightly=None, git_master=None):
4343

4444
REQUIRED_PACKAGES = [
4545
'protobuf>=3.20.3,<5',
46-
'tensorboard>=2.15.2,<2.16.0',
47-
'tensorflow>=2.15,<2.16',
46+
'tensorboard>=2.16.2,<2.17.0',
47+
'tensorflow>=2.16,<2.17',
4848
'tensorflow-model-analysis'
4949
+ select_constraint(
50-
default='>=0.46,<0.47',
51-
nightly='>=0.47.0.dev',
50+
default='>=0.47,<0.48',
51+
nightly='>=0.48.0.dev',
5252
git_master='@git+https://github.com/tensorflow/model-analysis@master',
5353
),
5454
'werkzeug<2',

‎tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from tensorboard_plugin_fairness_indicators import metadata
2424
import six
2525
import tensorflow_model_analysis as tfma
26-
from tensorflow_model_analysis.addons.fairness.view import widget_view
26+
# from tensorflow_model_analysis.addons.fairness.view import widget_view
27+
from tensorflow_model_analysis.view import widget_view
2728
from werkzeug import wrappers
2829
from google.protobuf import json_format
2930
from tensorboard.backend import http_util

‎tensorboard_plugin/tensorboard_plugin_fairness_indicators/plugin_test.py

+50-34
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,34 @@
2525
from tensorboard_plugin_fairness_indicators import plugin
2626
from tensorboard_plugin_fairness_indicators import summary_v2
2727
import six
28-
import tensorflow.compat.v1 as tf
29-
import tensorflow.compat.v2 as tf2
28+
import tensorflow as tf2
29+
from tensorflow.keras import layers
3030
import tensorflow_model_analysis as tfma
31-
from tensorflow_model_analysis.eval_saved_model.example_trainers import linear_classifier
3231
from werkzeug import test as werkzeug_test
3332
from werkzeug import wrappers
3433

3534
from tensorboard.backend import application
3635
from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer
3736
from tensorboard.plugins import base_plugin
3837

39-
tf.enable_eager_execution()
4038
tf = tf2
4139

4240

41+
# Define keras based linear classifier.
42+
def create_linear_classifier(model_dir):
43+
44+
inputs = tf.keras.Input(shape=(2,))
45+
outputs = layers.Dense(1, activation="sigmoid")(inputs)
46+
model = tf.keras.Model(inputs=inputs, outputs=outputs)
47+
48+
model.compile(
49+
optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
50+
)
51+
52+
tf.saved_model.save(model, model_dir)
53+
return model
54+
55+
4356
class PluginTest(tf.test.TestCase):
4457
"""Tests for Fairness Indicators plugin server."""
4558

@@ -74,19 +87,19 @@ def tearDown(self):
7487
super(PluginTest, self).tearDown()
7588
shutil.rmtree(self._log_dir, ignore_errors=True)
7689

77-
def _exportEvalSavedModel(self, classifier):
90+
def _export_eval_saved_model(self):
91+
"""Export the evaluation saved model."""
7892
temp_eval_export_dir = os.path.join(self.get_temp_dir(), "eval_export_dir")
79-
_, eval_export_dir = classifier(None, temp_eval_export_dir)
80-
return eval_export_dir
93+
return create_linear_classifier(temp_eval_export_dir)
8194

82-
def _writeTFExamplesToTFRecords(self, examples):
95+
def _write_tf_examples_to_tfrecords(self, examples):
8396
data_location = os.path.join(self.get_temp_dir(), "input_data.rio")
8497
with tf.io.TFRecordWriter(data_location) as writer:
8598
for example in examples:
8699
writer.write(example.SerializeToString())
87100
return data_location
88101

89-
def _makeExample(self, age, language, label):
102+
def _make_tf_example(self, age, language, label):
90103
example = tf.train.Example()
91104
example.features.feature["age"].float_list.value[:] = [age]
92105
example.features.feature["language"].bytes_list.value[:] = [
@@ -112,14 +125,14 @@ def testRoutes(self):
112125
"foo": "".encode("utf-8")
113126
}},
114127
)
115-
def testIsActive(self, get_random_stub):
128+
def testIsActive(self):
116129
self.assertTrue(self._plugin.is_active())
117130

118131
@mock.patch.object(
119132
event_multiplexer.EventMultiplexer,
120133
"PluginRunToTagToContent",
121134
return_value={})
122-
def testIsInactive(self, get_random_stub):
135+
def testIsInactive(self):
123136
self.assertFalse(self._plugin.is_active())
124137

125138
def testIndexJsRoute(self):
@@ -134,16 +147,15 @@ def testVulcanizedTemplateRoute(self):
134147
self.assertEqual(200, response.status_code)
135148

136149
def testGetEvalResultsRoute(self):
137-
model_location = self._exportEvalSavedModel(
138-
linear_classifier.simple_linear_classifier)
150+
model_location = self._export_eval_saved_model() # Call the method
139151
examples = [
140-
self._makeExample(age=3.0, language="english", label=1.0),
141-
self._makeExample(age=3.0, language="chinese", label=0.0),
142-
self._makeExample(age=4.0, language="english", label=1.0),
143-
self._makeExample(age=5.0, language="chinese", label=1.0),
144-
self._makeExample(age=5.0, language="hindi", label=1.0)
152+
self._make_tf_example(age=3.0, language="english", label=1.0),
153+
self._make_tf_example(age=3.0, language="chinese", label=0.0),
154+
self._make_tf_example(age=4.0, language="english", label=1.0),
155+
self._make_tf_example(age=5.0, language="chinese", label=1.0),
156+
self._make_tf_example(age=5.0, language="hindi", label=1.0),
145157
]
146-
data_location = self._writeTFExamplesToTFRecords(examples)
158+
data_location = self._write_tf_examples_to_tfrecords(examples)
147159
_ = tfma.run_model_analysis(
148160
eval_shared_model=tfma.default_eval_shared_model(
149161
eval_saved_model_path=model_location, example_weight_key="age"),
@@ -155,32 +167,36 @@ def testGetEvalResultsRoute(self):
155167
self.assertEqual(200, response.status_code)
156168

157169
def testGetEvalResultsFromURLRoute(self):
158-
model_location = self._exportEvalSavedModel(
159-
linear_classifier.simple_linear_classifier)
170+
model_location = self._export_eval_saved_model() # Call the method
160171
examples = [
161-
self._makeExample(age=3.0, language="english", label=1.0),
162-
self._makeExample(age=3.0, language="chinese", label=0.0),
163-
self._makeExample(age=4.0, language="english", label=1.0),
164-
self._makeExample(age=5.0, language="chinese", label=1.0),
165-
self._makeExample(age=5.0, language="hindi", label=1.0)
172+
self._make_tf_example(age=3.0, language="english", label=1.0),
173+
self._make_tf_example(age=3.0, language="chinese", label=0.0),
174+
self._make_tf_example(age=4.0, language="english", label=1.0),
175+
self._make_tf_example(age=5.0, language="chinese", label=1.0),
176+
self._make_tf_example(age=5.0, language="hindi", label=1.0),
166177
]
167-
data_location = self._writeTFExamplesToTFRecords(examples)
178+
data_location = self._write_tf_examples_to_tfrecords(examples)
168179
_ = tfma.run_model_analysis(
169180
eval_shared_model=tfma.default_eval_shared_model(
170181
eval_saved_model_path=model_location, example_weight_key="age"),
171182
data_location=data_location,
172183
output_path=self._eval_result_output_dir)
173184

174185
response = self._server.get(
175-
"/data/plugin/fairness_indicators/" +
176-
"get_evaluation_result_from_remote_path?evaluation_output_path=" +
177-
os.path.join(self._eval_result_output_dir, tfma.METRICS_KEY))
186+
"/data/plugin/fairness_indicators/"
187+
+ "get_evaluation_result_from_remote_path?evaluation_output_path="
188+
+ self._eval_result_output_dir
189+
)
178190
self.assertEqual(200, response.status_code)
179191

180-
def testGetOutputFileFormat(self):
181-
self.assertEqual("", self._plugin._get_output_file_format("abc_path"))
182-
self.assertEqual("tfrecord",
183-
self._plugin._get_output_file_format("abc_path.tfrecord"))
192+
def test_get_output_file_format(self):
193+
evaluation_output_path = os.path.join(
194+
self._eval_result_output_dir, "eval_result.tfrecord"
195+
)
196+
self.assertEqual(
197+
self._plugin._get_output_file_format(evaluation_output_path),
198+
"tfrecord",
199+
)
184200

185201

186202
if __name__ == "__main__":

0 commit comments

Comments
 (0)
Please sign in to comment.