From 15b6001d7553951e7cf195cb0ed0bc3f2561e44e Mon Sep 17 00:00:00 2001 From: gschnabel <40870991+gschnabel@users.noreply.github.com> Date: Sun, 30 Jun 2024 21:13:46 +0200 Subject: [PATCH] update test to use original LinearOperatorLowRankUpdate class --- tests/test_tf_uq_custom_distributions.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/test_tf_uq_custom_distributions.py b/tests/test_tf_uq_custom_distributions.py index 5b75b583..998b345c 100644 --- a/tests/test_tf_uq_custom_distributions.py +++ b/tests/test_tf_uq_custom_distributions.py @@ -29,7 +29,7 @@ ) from gmapy.mappings.tf.restricted_map import RestrictedMap from gmapy.mappings.helperfuns.numeric_jacobian import numeric_jacobian -from gmapy.tf_uq.custom_linear_operators import MyLinearOperatorLowRankUpdate +# from gmapy.tf_uq.custom_linear_operators import MyLinearOperatorLowRankUpdate class TestTfUQCustomDistributions(unittest.TestCase): @@ -201,8 +201,11 @@ def create_Smat(self, exptable): def create_like_cov_fun(self, expcov_linop, Smat): def like_cov_fun(u): - covop = MyLinearOperatorLowRankUpdate( - expcov_linop, Smat, u + covop = tf.linalg.LinearOperatorLowRankUpdate( + expcov_linop, Smat, u, + is_self_adjoint=expcov_linop.is_self_adjoint, + is_positive_definite=expcov_linop.is_positive_definite, + is_diag_update_positive=True ) return covop return like_cov_fun @@ -222,10 +225,14 @@ def test_hessian_of_likelihood_with_covpars(self): tf.linalg.LinearOperatorLowerTriangular(p) for p in expchol_list ] expcov_linop_list = [ - tf.linalg.LinearOperatorComposition([p, p.adjoint()]) + tf.linalg.LinearOperatorComposition( + [p, p.adjoint()], is_self_adjoint=True, is_positive_definite=True + ) for p in expchol_linop_list ] - expcov_linop = tf.linalg.LinearOperatorBlockDiag(expcov_linop_list) + expcov_linop = tf.linalg.LinearOperatorBlockDiag( + expcov_linop_list, is_self_adjoint=True, is_positive_definite=True + ) Smat = self.create_Smat(self._exptable) like_cov_fun = self.create_like_cov_fun(expcov_linop, Smat) num_params = len(priorvals)