@@ -271,15 +271,17 @@ def test_matmul_grad_xla_kernelparams(self):
     feature_dim = 3

     def kernel_fn(eq_params, poly_params):
-      return (exponentiated_quadratic.ExponentiatedQuadratic(**eq_params) *
-              polynomial.Polynomial(**poly_params))
+      return (exponentiated_quadratic.ExponentiatedQuadratic(*eq_params) *
+              polynomial.Polynomial(bias_amplitude=poly_params[0],
+                                    shift=poly_params[1]))

+    # TODO(b/284106340): Return this to a dictionary.
     kernel_args = (
-        dict(length_scale=tf.random.uniform([], .5, 1.5, dtype=tf.float64),
-             amplitude=tf.random.uniform([], 1.5, 2.5, dtype=tf.float64)),
-        dict(bias_amplitude=tf.random.uniform([feature_dim], .5, 1.5,
-                                              dtype=tf.float64),
-             shift=tf.random.normal([feature_dim], dtype=tf.float64)))
+        (tf.random.uniform([], 1.5, 2.5, dtype=tf.float64),  # amplitude
+         tf.random.uniform([], .5, 1.5, dtype=tf.float64)),  # length_scale
+        (tf.random.uniform([feature_dim], .5, 1.5,  # bias_amplitude
+                           dtype=tf.float64),
+         tf.random.normal([feature_dim], dtype=tf.float64)))  # shift

     x1 = tf.random.normal([5, feature_dim], dtype=tf.float64)
     x2 = tf.random.normal([7, feature_dim], dtype=tf.float64)
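
A minimal sketch of how the tuple-structured kernel_args above is consumed; this is an illustration only, not part of the change, and it assumes the kernel built by kernel_fn is evaluated through the standard PSD-kernel matrix method:

    # Each tuple's element order must match the positional/indexed use in kernel_fn:
    # eq_params = (amplitude, length_scale), poly_params = (bias_amplitude, shift).
    eq_params, poly_params = kernel_args
    kernel = kernel_fn(eq_params, poly_params)   # ExponentiatedQuadratic * Polynomial product kernel
    kernel_matrix = kernel.matrix(x1, x2)        # shape [5, 7], differentiable w.r.t. the kernel params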