Skip to content

Commit 53c4dc6

Browse files
ColCarrolltensorflower-gardener
authored andcommitted
Remove pin on typing_extensions.
Fixes #1753 See also tensorflow/tensorflow#60687 and tensorflow/tensorflow#61387 PiperOrigin-RevId: 572657594
1 parent 2577f7a commit 53c4dc6

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

required_packages.py

-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
'cloudpickle>=1.3',
2727
'gast>=0.3.2', # For autobatching
2828
'dm-tree', # For NumPy/JAX backends (hence, also for prefer_static)
29-
'typing-extensions<4.6.0', # TODO(b/284106340): Remove this pin
3029
]
3130

3231
if __name__ == '__main__':

tensorflow_probability/python/experimental/linalg/linear_operator_psd_kernel_test.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -271,15 +271,17 @@ def test_matmul_grad_xla_kernelparams(self):
271271
feature_dim = 3
272272

273273
def kernel_fn(eq_params, poly_params):
274-
return (exponentiated_quadratic.ExponentiatedQuadratic(**eq_params) *
275-
polynomial.Polynomial(**poly_params))
274+
return (exponentiated_quadratic.ExponentiatedQuadratic(*eq_params) *
275+
polynomial.Polynomial(bias_amplitude=poly_params[0],
276+
shift=poly_params[1]))
276277

278+
# TODO(b/284106340): Return this to a dictionary.
277279
kernel_args = (
278-
dict(length_scale=tf.random.uniform([], .5, 1.5, dtype=tf.float64),
279-
amplitude=tf.random.uniform([], 1.5, 2.5, dtype=tf.float64)),
280-
dict(bias_amplitude=tf.random.uniform([feature_dim], .5, 1.5,
281-
dtype=tf.float64),
282-
shift=tf.random.normal([feature_dim], dtype=tf.float64)))
280+
(tf.random.uniform([], 1.5, 2.5, dtype=tf.float64), # amplitude
281+
tf.random.uniform([], .5, 1.5, dtype=tf.float64)), # length_scale
282+
(tf.random.uniform([feature_dim], .5, 1.5, # bias_amplitude
283+
dtype=tf.float64),
284+
tf.random.normal([feature_dim], dtype=tf.float64))) # shift
283285

284286
x1 = tf.random.normal([5, feature_dim], dtype=tf.float64)
285287
x2 = tf.random.normal([7, feature_dim], dtype=tf.float64)

0 commit comments

Comments
 (0)