diff --git a/dsa2000_cal/dsa2000_cal/common/jvp_linear_op.py b/dsa2000_cal/dsa2000_cal/common/jvp_linear_op.py index 9ea1a6a9..baaa4969 100644 --- a/dsa2000_cal/dsa2000_cal/common/jvp_linear_op.py +++ b/dsa2000_cal/dsa2000_cal/common/jvp_linear_op.py @@ -132,7 +132,7 @@ def _get_results_type(primal_out: jax.Array): def _adjoint_promote_dtypes(co_tangent: jax.Array, dtype: jnp.dtype): if co_tangent.dtype != dtype: - warnings.warn(f"Promoting co-tangent dtype from {co_tangent.dtype} to {primal_out.dtype}.") + warnings.warn(f"Promoting co-tangent dtype from {co_tangent.dtype} to {dtype}.") return co_tangent.astype(dtype) # v @ J