From a10107beff4f166efbd930e952036439e4b12fa3 Mon Sep 17 00:00:00 2001
From: gschnabel <40870991+gschnabel@users.noreply.github.com>
Date: Tue, 2 Jul 2024 16:43:30 +0200
Subject: [PATCH] set `experimental_use_pfor=False` in jacobian to avoid
 excessive memory usage

---
 gmapy/tf_uq/custom_distributions.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/gmapy/tf_uq/custom_distributions.py b/gmapy/tf_uq/custom_distributions.py
index 99cacd9e..03001964 100644
--- a/gmapy/tf_uq/custom_distributions.py
+++ b/gmapy/tf_uq/custom_distributions.py
@@ -314,7 +314,7 @@ def _log_prob_hessian_offdiag_part(self, x, covpars):
         like_data = tf.reshape(self._like_data, (-1, 1))
         propvals = tf.reshape(self._propfun(x), (-1, 1))
         d = like_data - propvals
-        with tf.GradientTape(persistent=False) as tape:
+        with tf.GradientTape(persistent=True) as tape:
             tape.watch(covpars)
             j = self._jacfun(x)
             like_cov = like_cov_fun(covpars, tf.stop_gradient(propvals))
@@ -322,7 +322,7 @@
             u = tf.sparse.sparse_dense_matmul(j, constvec, adjoint_a=True)
             u = tf.reshape(u, (-1,))
         g = tape.jacobian(
-            u, covpars, experimental_use_pfor=True,
+            u, covpars, experimental_use_pfor=False,
             unconnected_gradients=tf.UnconnectedGradients.ZERO
         )
         return g
@@ -336,7 +336,7 @@ def _log_prob_hessian_chisqr_wrt_covpars(self, x, covpars):
         propvals = tf.reshape(self._propfun(x), (-1, 1))
         d = like_data - propvals
         d = tf.reshape(d, (-1, 1))
-        with tf.GradientTape(persistent=False) as tape1:
+        with tf.GradientTape(persistent=True) as tape1:
             tape1.watch(covpars)
             with tf.GradientTape() as tape2:
                 tape2.watch(covpars)
@@ -346,7 +346,7 @@
                 u, covpars, unconnected_gradients=tf.UnconnectedGradients.ZERO
             )
         h = tape1.jacobian(
-            g, covpars, experimental_use_pfor=True,
+            g, covpars, experimental_use_pfor=False,
             unconnected_gradients=tf.UnconnectedGradients.ZERO
         )
         return h
@@ -357,7 +357,7 @@ def _log_prob_hessian_logdet_wrt_covpars(self, x, covpars):
         covpars = tf.constant(covpars, dtype=tf.float64)
         like_cov_fun = self._like_cov_fun
         propvals = self._propfun(x)
-        with tf.GradientTape(persistent=False) as tape1:
+        with tf.GradientTape(persistent=True) as tape1:
             tape1.watch(covpars)
             with tf.GradientTape() as tape2:
                 tape2.watch(covpars)
@@ -367,7 +367,7 @@
                 u, covpars, unconnected_gradients=tf.UnconnectedGradients.ZERO
             )
         h = tape1.jacobian(
-            g, covpars, experimental_use_pfor=True,
+            g, covpars, experimental_use_pfor=False,
             unconnected_gradients=tf.UnconnectedGradients.ZERO
         )
         return h
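
Why both flags move together (a minimal standalone sketch follows; the toy
function and variable names are illustrative, not gmapy code): with
`experimental_use_pfor=True`, `tape.jacobian` vectorizes the per-output-element
gradients via a parallel-for transform, which is fast but can materialize large
intermediate tensors; with `experimental_use_pfor=False` it falls back to a
loop of `tape.gradient` calls, trading speed for the lower memory footprint the
commit message targets. Because that loop reads the tape once per output
element, the tape must be created with `persistent=True`; in eager mode
TensorFlow raises an error otherwise.

    import tensorflow as tf

    # Illustrative stand-in for the covariance-parameter vector.
    covpars = tf.constant([1.0, 2.0, 3.0], dtype=tf.float64)

    # persistent=True lets tape.jacobian call tape.gradient repeatedly,
    # which the pfor=False fallback needs.
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(covpars)
        # Toy output vector depending on covpars (hypothetical example).
        u = tf.stack([covpars[0] * covpars[1], covpars[1] ** 2, covpars[2]])

    # pfor=False computes the Jacobian row by row instead of vectorizing,
    # avoiding the large intermediates the parallel-for path can allocate.
    j = tape.jacobian(
        u, covpars, experimental_use_pfor=False,
        unconnected_gradients=tf.UnconnectedGradients.ZERO
    )
    print(j)  # 3x3 matrix of du_i / dcovpars_j
    del tape  # release the persistent tape's resources

The same reasoning applies to the nested tapes in the Hessian routines above:
only the outer tape, whose `jacobian` method is called with
`experimental_use_pfor=False`, needs `persistent=True`; the inner tapes still
compute a single gradient and stay non-persistent.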