Skip to content

Commit

Permalink
set experimental_use_pfor=False in jacobian to avoid excessive memory usage
Browse files Browse the repository at this point in the history
  • Loading branch information
gschnabel committed Jul 2, 2024
1 parent fef6b33 commit a10107b
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions gmapy/tf_uq/custom_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,15 +314,15 @@ def _log_prob_hessian_offdiag_part(self, x, covpars):
like_data = tf.reshape(self._like_data, (-1, 1))
propvals = tf.reshape(self._propfun(x), (-1, 1))
d = like_data - propvals
with tf.GradientTape(persistent=False) as tape:
with tf.GradientTape(persistent=True) as tape:
tape.watch(covpars)
j = self._jacfun(x)
like_cov = like_cov_fun(covpars, tf.stop_gradient(propvals))
constvec = like_cov.solve(d)
u = tf.sparse.sparse_dense_matmul(j, constvec, adjoint_a=True)
u = tf.reshape(u, (-1,))
g = tape.jacobian(
u, covpars, experimental_use_pfor=True,
u, covpars, experimental_use_pfor=False,
unconnected_gradients=tf.UnconnectedGradients.ZERO
)
return g
Expand All @@ -336,7 +336,7 @@ def _log_prob_hessian_chisqr_wrt_covpars(self, x, covpars):
propvals = tf.reshape(self._propfun(x), (-1, 1))
d = like_data - propvals
d = tf.reshape(d, (-1, 1))
with tf.GradientTape(persistent=False) as tape1:
with tf.GradientTape(persistent=True) as tape1:
tape1.watch(covpars)
with tf.GradientTape() as tape2:
tape2.watch(covpars)
Expand All @@ -346,7 +346,7 @@ def _log_prob_hessian_chisqr_wrt_covpars(self, x, covpars):
u, covpars, unconnected_gradients=tf.UnconnectedGradients.ZERO
)
h = tape1.jacobian(
g, covpars, experimental_use_pfor=True,
g, covpars, experimental_use_pfor=False,
unconnected_gradients=tf.UnconnectedGradients.ZERO
)
return h
Expand All @@ -357,7 +357,7 @@ def _log_prob_hessian_logdet_wrt_covpars(self, x, covpars):
covpars = tf.constant(covpars, dtype=tf.float64)
like_cov_fun = self._like_cov_fun
propvals = self._propfun(x)
with tf.GradientTape(persistent=False) as tape1:
with tf.GradientTape(persistent=True) as tape1:
tape1.watch(covpars)
with tf.GradientTape() as tape2:
tape2.watch(covpars)
Expand All @@ -367,7 +367,7 @@ def _log_prob_hessian_logdet_wrt_covpars(self, x, covpars):
u, covpars, unconnected_gradients=tf.UnconnectedGradients.ZERO
)
h = tape1.jacobian(
g, covpars, experimental_use_pfor=True,
g, covpars, experimental_use_pfor=False,
unconnected_gradients=tf.UnconnectedGradients.ZERO
)
return h
Expand Down

0 comments on commit a10107b

Please sign in to comment.