
Commit a446360

jburnim authored and tensorflower-gardener committed
In open source, import tf_keras instead of setting TF_USE_LEGACY_KERAS=1.

With this change, other imported libraries are free to use Keras 3 while TFP uses Keras 2.

PiperOrigin-RevId: 578625610
1 parent ef27f46 commit a446360
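
Per the commit message, the point of the change is to stop forcing Keras 2 process-wide via the TF_USE_LEGACY_KERAS=1 environment variable and instead have TFP import the standalone tf_keras package through one internal module (referenced in the BUILD change below as //tensorflow_probability/python/internal:tf_keras). The following is only an illustrative sketch of what such a shim might look like, assuming the tf-keras (Keras 2) package is available; the fallback branch and the comments about call sites are assumptions, not the module's actual contents.

# Illustrative sketch of a Keras-2 shim module (assumption; the real
# tensorflow_probability/python/internal/tf_keras may differ).
try:
  # Preferred path after this commit: the separately packaged Keras 2.
  import tf_keras
except ImportError:
  # Fallback if the tf_keras package is absent; note that on newer
  # TensorFlow releases, tf.keras may resolve to Keras 3.
  from tensorflow import keras as tf_keras

# Call sites then depend on this one module instead of tf.keras, e.g.:
#   from tensorflow_probability.python.internal import tf_keras
#   model = tf_keras.models.Sequential([...])
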

File tree

110 files changed: +835, -659 lines (only a subset of the changed files is shown below)


STYLE_GUIDE.md

Lines changed: 1 addition & 1 deletion
@@ -187,7 +187,7 @@ they supersede all previous conventions.
 1. Submodule names should be singular, except where they overlap to TF.

    Justification: Having plural looks strange in user code, ie,
-   tf.optimizer.Foo reads nicer than tf.keras.optimizers.Foo since submodules
+   tf.optimizer.Foo reads nicer than tf_keras.optimizers.Foo since submodules
    are only used to access a single, specific thing (at a time).

 1. Use `tf.newaxis` rather than `None` to `tf.expand_dims`.

SUBSTRATES.md

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ vmap, etc.), we will special-case using an `if JAX_MODE:` block.
   tests, TFP impl, etc), with `tfp.math.value_and_gradient` or similar. Then,
   we can special-case `JAX_MODE` inside the body of `value_and_gradient`.

-* __`tf.Variable`, `tf.keras.optimizers.Optimizer`__
+* __`tf.Variable`, `tf_keras.optimizers.Optimizer`__

   TF provides a `Variable` abstraction so that graph functions may modify
   state, including using the Keras `Optimizer` subclasses like `Adam`. JAX,
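
As an aside on the SUBSTRATES.md bullet above, the pattern it refers to can be pictured with a short, hypothetical TF-only snippet (not part of the diff): a tf.Variable whose state is updated by a Keras 2 Optimizer imported from tf_keras rather than tf.keras.

# Hypothetical snippet, assuming the tf-keras (Keras 2) package is installed.
import tensorflow as tf
import tf_keras

v = tf.Variable(1.0)
opt = tf_keras.optimizers.Adam(learning_rate=0.1)
with tf.GradientTape() as tape:
  loss = (v - 3.0) ** 2
grads = tape.gradient(loss, [v])
opt.apply_gradients(zip(grads, [v]))  # Mutates the Variable in place.
print(v.numpy())
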

tensorflow_probability/examples/BUILD

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ py_library(
         # six dep,
         # tensorflow dep,
         "//tensorflow_probability",
+        "//tensorflow_probability/python/internal:tf_keras",
     ],
 )

tensorflow_probability/examples/bayesian_neural_network.py

Lines changed: 9 additions & 8 deletions
@@ -37,6 +37,7 @@
 import numpy as np
 import tensorflow.compat.v2 as tf
 import tensorflow_probability as tfp
+from tensorflow_probability.python.internal import tf_keras

 tf.enable_v2_behavior()

@@ -174,26 +175,26 @@ def create_model():
   # and two fully connected dense layers. We use the Flipout
   # Monte Carlo estimator for these layers, which enables lower variance
   # stochastic gradients than naive reparameterization.
-  model = tf.keras.models.Sequential([
+  model = tf_keras.models.Sequential([
       tfp.layers.Convolution2DFlipout(
           6, kernel_size=5, padding='SAME',
           kernel_divergence_fn=kl_divergence_function,
           activation=tf.nn.relu),
-      tf.keras.layers.MaxPooling2D(
+      tf_keras.layers.MaxPooling2D(
          pool_size=[2, 2], strides=[2, 2],
          padding='SAME'),
       tfp.layers.Convolution2DFlipout(
           16, kernel_size=5, padding='SAME',
           kernel_divergence_fn=kl_divergence_function,
           activation=tf.nn.relu),
-      tf.keras.layers.MaxPooling2D(
+      tf_keras.layers.MaxPooling2D(
          pool_size=[2, 2], strides=[2, 2],
          padding='SAME'),
       tfp.layers.Convolution2DFlipout(
           120, kernel_size=5, padding='SAME',
           kernel_divergence_fn=kl_divergence_function,
           activation=tf.nn.relu),
-      tf.keras.layers.Flatten(),
+      tf_keras.layers.Flatten(),
       tfp.layers.DenseFlipout(
           84, kernel_divergence_fn=kl_divergence_function,
           activation=tf.nn.relu),
@@ -203,7 +204,7 @@ def create_model():
   ])

   # Model compilation.
-  optimizer = tf.keras.optimizers.Adam(lr=FLAGS.learning_rate)
+  optimizer = tf_keras.optimizers.Adam(lr=FLAGS.learning_rate)
   # We use the categorical_crossentropy loss since the MNIST dataset contains
   # ten labels. The Keras API will then automatically add the
   # Kullback-Leibler divergence (contained on the individual layers of
@@ -214,7 +215,7 @@ def create_model():
   return model


-class MNISTSequence(tf.keras.utils.Sequence):
+class MNISTSequence(tf_keras.utils.Sequence):
   """Produces a sequence of MNIST digits with labels."""

   def __init__(self, data=None, batch_size=128, fake_data_size=None):
@@ -272,7 +273,7 @@ def __preprocessing(images, labels):
     images = 2 * (images / 255.) - 1.
     images = images[..., tf.newaxis]

-    labels = tf.keras.utils.to_categorical(labels)
+    labels = tf_keras.utils.to_categorical(labels)
     return images, labels

   def __len__(self):
@@ -298,7 +299,7 @@ def main(argv):
     heldout_seq = MNISTSequence(batch_size=FLAGS.batch_size,
                                 fake_data_size=NUM_HELDOUT_EXAMPLES)
   else:
-    train_set, heldout_set = tf.keras.datasets.mnist.load_data()
+    train_set, heldout_set = tf_keras.datasets.mnist.load_data()
     train_seq = MNISTSequence(data=train_set, batch_size=FLAGS.batch_size)
     heldout_seq = MNISTSequence(data=heldout_set, batch_size=FLAGS.batch_size)

tensorflow_probability/examples/cifar10_bnn.py

Lines changed: 3 additions & 1 deletion
@@ -47,6 +47,8 @@
 from tensorflow_probability.examples.models.bayesian_resnet import bayesian_resnet
 from tensorflow_probability.examples.models.bayesian_vgg import bayesian_vgg

+from tensorflow_probability.python.internal import tf_keras
+
 matplotlib.use("Agg")
 warnings.simplefilter(action="ignore")
 tfd = tfp.distributions
@@ -169,7 +171,7 @@ def main(argv):
   if FLAGS.fake_data:
     (x_train, y_train), (x_test, y_test) = build_fake_data()
   else:
-    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
+    (x_train, y_train), (x_test, y_test) = tf_keras.datasets.cifar10.load_data()

   (images, labels, handle,
    training_iterator,
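
Taken together with the example changes above, the effect can be summarized with a small hypothetical end-user script (assuming both the keras 3.x and tf-keras 2.x packages are installed alongside TFP): the application is free to use Keras 3 directly while TFP resolves its Keras-facing code against tf_keras, with no TF_USE_LEGACY_KERAS=1 needed in the environment.

# Hypothetical script, not part of the commit.
import keras       # Keras 3.x, used by the application itself.
import tf_keras    # Keras 2.x, used internally by TFP.
import tensorflow_probability as tfp

print('application keras:', keras.__version__)     # e.g. 3.x
print('keras used by tfp:', tf_keras.__version__)  # 2.x
print(tfp.layers.DenseFlipout)  # TFP layer; per this commit, built on tf_keras.
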
