From 086403a2b0f290b29a8bcc98096e98d41570183d Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Tue, 10 Sep 2024 10:24:55 +0200 Subject: [PATCH] * add caller info --- debug/add_caller_of_caller_info.py | 24 +++++++++++++++++++ dsa2000_cal/dsa2000_cal/common/alert_utils.py | 24 +++++++++++++++++++ .../dsa2000_cal/common/jvp_linear_op.py | 8 +++++-- 3 files changed, 54 insertions(+), 2 deletions(-) create mode 100644 debug/add_caller_of_caller_info.py diff --git a/debug/add_caller_of_caller_info.py b/debug/add_caller_of_caller_info.py new file mode 100644 index 00000000..2c3b2a1f --- /dev/null +++ b/debug/add_caller_of_caller_info.py @@ -0,0 +1,24 @@ +import inspect + + +def get_grandparent_info(): + # Get the grandparent frame (caller of the caller) + depth = min(2, len(inspect.stack()) - 1) + caller_frame = inspect.stack()[depth] + caller_file = caller_frame.filename + caller_line = caller_frame.lineno + caller_func = caller_frame.function + + return f"at {caller_file}:{caller_line} in {caller_func}" + + +def foo(): + def bar(): + print(get_grandparent_info()) + + bar() +def main(): + foo() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/dsa2000_cal/dsa2000_cal/common/alert_utils.py b/dsa2000_cal/dsa2000_cal/common/alert_utils.py index 1104bcfd..b324c1b1 100644 --- a/dsa2000_cal/dsa2000_cal/common/alert_utils.py +++ b/dsa2000_cal/dsa2000_cal/common/alert_utils.py @@ -1,3 +1,4 @@ +import inspect import json import os import socket @@ -7,6 +8,29 @@ import requests +def get_grandparent_info(relative_depth: int = 0): + """ + Get the file, line number and function name of the caller of the caller of this function. + + Args: + relative_depth: the number of frames to go back from the caller of this function. Default is 0. + This is interpreted as the number of frames to go back from the caller of the caller of this function. + 0 means the caller of the caller of this function, 1 means the caller of the caller of the caller of this + function, and so on. + + Returns: + str: a string with the file, line number and function name of the caller of the caller of this function. + """ + # Get the grandparent frame (caller of the caller) + depth = min(1 + relative_depth, len(inspect.stack()) - 1) + caller_frame = inspect.stack()[depth] + caller_file = caller_frame.filename + caller_line = caller_frame.lineno + caller_func = caller_frame.function + + return f"at {caller_file}:{caller_line} in {caller_func}" + + def post_completed_forward_modelling_run(run_dir: str, start_time: datetime, duration: timedelta, hook_url: str | None = None): """ diff --git a/dsa2000_cal/dsa2000_cal/common/jvp_linear_op.py b/dsa2000_cal/dsa2000_cal/common/jvp_linear_op.py index 319e8c67..26a12cf8 100644 --- a/dsa2000_cal/dsa2000_cal/common/jvp_linear_op.py +++ b/dsa2000_cal/dsa2000_cal/common/jvp_linear_op.py @@ -6,6 +6,8 @@ import jax.numpy as jnp import numpy as np +from dsa2000_cal.common.alert_utils import get_grandparent_info + def isinstance_namedtuple(obj) -> bool: """ @@ -171,7 +173,9 @@ def _get_results_type(primal_out: jax.Array): def _adjoint_promote_dtypes(co_tangent: jax.Array, dtype: jnp.dtype): if co_tangent.dtype != dtype: - warnings.warn(f"Promoting co-tangent dtype from {co_tangent.dtype} to {dtype}.") + warnings.warn( + f"Promoting co-tangent dtype from {co_tangent.dtype} to {dtype}, {get_grandparent_info(2)}." + ) return co_tangent.astype(dtype) # v @ J @@ -197,7 +201,7 @@ def _adjoint_promote_dtypes(co_tangent: jax.Array, dtype: jnp.dtype): def _promote_dtype(primal: jax.Array, dtype: jnp.dtype): if primal.dtype != dtype: - warnings.warn(f"Promoting primal dtype from {primal.dtype} to {dtype}.") + warnings.warn(f"Promoting primal dtype from {primal.dtype} to {dtype}, at {get_grandparent_info(2)}.") return primal.astype(dtype) def _get_result_type(primal: jax.Array):