Commit
Speedup: use loss only for jac=False
mhuen committed Oct 31, 2024
1 parent 2cf3461 commit 7546fc0
Showing 2 changed files with 38 additions and 11 deletions.
8 changes: 7 additions & 1 deletion egenerator/manager/reconstruction/modules/reconstruction.py
@@ -166,11 +166,17 @@ def __init__(
         # choose reconstruction method depending on the optimizer interface
         if reco_optimizer_interface.lower() == "scipy":
 
+            # choose function according to jac (default: True)
+            scipy_loss_function = loss_and_gradients_function
+            if "jac" in scipy_optimizer_settings:
+                if not scipy_optimizer_settings["jac"]:
+                    scipy_loss_function = self.parameter_loss_function
+
             def reconstruction_method(data_batch, seed_tensor):
                 return manager.reconstruct_events(
                     data_batch,
                     loss_module,
-                    loss_and_gradients_function=loss_and_gradients_function,
+                    loss_and_gradients_function=scipy_loss_function,
                     fit_parameter_list=fit_parameter_list,
                     minimize_in_trafo_space=minimize_in_trafo_space,
                     seed=seed_tensor,
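Why the loss-only path is cheaper: scipy.optimize.minimize with jac=True expects the objective callable to return (loss, gradient), whereas with jac=False the callable returns only the loss and scipy either estimates gradients by finite differences or the chosen method does not use them at all. Computing analytic gradients in the jac=False case is therefore wasted work. A minimal, self-contained sketch of the two calling conventions (the quadratic objective is purely illustrative, not egenerator code):

import numpy as np
from scipy.optimize import minimize

def loss_and_grad(x):
    # jac=True convention: the callable returns (loss, gradient)
    return np.sum(x**2), 2.0 * x

def loss_only(x):
    # jac=False convention: the callable returns only the loss;
    # scipy falls back to finite-difference gradients if the method needs them
    return np.sum(x**2)

x0 = np.ones(3)
res_with_grad = minimize(loss_and_grad, x0, method="L-BFGS-B", jac=True)
res_loss_only = minimize(loss_only, x0, method="L-BFGS-B", jac=False)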
41 changes: 31 additions & 10 deletions egenerator/manager/source.py
@@ -743,7 +743,10 @@ def reconstruct_events(
             method.
         loss_and_gradients_function : tf.function
             The tensorflow function:
-                f(parameters, data_batch, seed_tensor) -> loss, gradients
+                if jac=True:
+                    f(parameters, data_batch, seed_tensor) -> loss, gradients
+                else:
+                    f(parameters, data_batch, seed_tensor) -> loss
             Note: it is imperative that this function uses the same settings
             for trafo space!
         fit_parameter_list : bool or list of bool, optional
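For illustration, the two flavours described in the docstring above could look roughly like the following. This is a hypothetical sketch that only mirrors the documented signatures; compute_loss is a stand-in, not the egenerator model:

import tensorflow as tf

def compute_loss(parameters, data_batch, seed):
    # stand-in for the real model/likelihood evaluation; illustrative only
    return tf.reduce_sum((parameters - data_batch) ** 2) + 0.0 * tf.reduce_sum(seed)

@tf.function
def loss_function(parameters, data_batch, seed):
    # jac=False flavour: forward pass only, no gradient bookkeeping
    return compute_loss(parameters, data_batch, seed)

@tf.function
def loss_and_gradients_function(parameters, data_batch, seed):
    # jac=True flavour: record the forward pass and differentiate w.r.t. parameters
    with tf.GradientTape() as tape:
        tape.watch(parameters)
        loss = compute_loss(parameters, data_batch, seed)
    return loss, tape.gradient(loss, parameters)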
@@ -830,16 +833,34 @@ def reconstruct_events(
         x0 = seed_array_trafo[:, fit_parameter_list]
 
         # define helper function
-        def func(x):
-            # reshape and convert to proper
-            x = np.reshape(x, param_shape).astype(param_dtype)
-            seed = np.reshape(seed_array, param_shape_full).astype(param_dtype)
-            loss, grad = loss_and_gradients_function(x, data_batch, seed=seed)
-            loss = loss.numpy().astype("float64")
-            grad = grad.numpy().astype("float64")
-
-            grad_flat = np.reshape(grad, [-1])
-            return loss, grad_flat
+        if jac:
+
+            def func(x):
+                # reshape and convert to proper
+                x = np.reshape(x, param_shape).astype(param_dtype)
+                seed = np.reshape(seed_array, param_shape_full).astype(
+                    param_dtype
+                )
+                loss, grad = loss_and_gradients_function(
+                    x, data_batch, seed=seed
+                )
+                loss = loss.numpy().astype("float64")
+                grad = grad.numpy().astype("float64")
+
+                grad_flat = np.reshape(grad, [-1])
+                return loss, grad_flat
+
+        else:
+
+            def func(x):
+                # reshape and convert to proper
+                x = np.reshape(x, param_shape).astype(param_dtype)
+                seed = np.reshape(seed_array, param_shape_full).astype(
+                    param_dtype
+                )
+                loss = loss_and_gradients_function(x, data_batch, seed=seed)
+                loss = loss.numpy().astype("float64")
+                return loss
 
         if hessian_function is not None:

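Putting both changes together, the selected helper is what ultimately gets handed to scipy along with the matching jac flag. A toy end-to-end version of that pattern (the objective and the minimize options are hypothetical; the actual call in egenerator may differ):

import numpy as np
from scipy.optimize import minimize

def loss_and_gradients(x):
    loss = np.sum((x - 1.0) ** 2)
    return loss, 2.0 * (x - 1.0)

def loss_only(x):
    # cheaper per call: the gradient computation is skipped entirely
    return np.sum((x - 1.0) ** 2)

jac = False  # e.g. taken from scipy_optimizer_settings["jac"]
func = loss_and_gradients if jac else loss_only
result = minimize(func, x0=np.zeros(3), jac=jac, method="L-BFGS-B")
print(result.x)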
