Skip to content

Commit

Permalink
* add caller info
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshuaalbert committed Sep 10, 2024
1 parent db4e79f commit 086403a
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 2 deletions.
24 changes: 24 additions & 0 deletions debug/add_caller_of_caller_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import inspect


def get_grandparent_info():
# Get the grandparent frame (caller of the caller)
depth = min(2, len(inspect.stack()) - 1)
caller_frame = inspect.stack()[depth]
caller_file = caller_frame.filename
caller_line = caller_frame.lineno
caller_func = caller_frame.function

return f"at {caller_file}:{caller_line} in {caller_func}"


def foo():
def bar():
print(get_grandparent_info())

bar()
def main():
foo()

if __name__ == '__main__':
main()
24 changes: 24 additions & 0 deletions dsa2000_cal/dsa2000_cal/common/alert_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import json
import os
import socket
Expand All @@ -7,6 +8,29 @@
import requests


def get_grandparent_info(relative_depth: int = 0):
"""
Get the file, line number and function name of the caller of the caller of this function.
Args:
relative_depth: the number of frames to go back from the caller of this function. Default is 0.
This is interpreted as the number of frames to go back from the caller of the caller of this function.
0 means the caller of the caller of this function, 1 means the caller of the caller of the caller of this
function, and so on.
Returns:
str: a string with the file, line number and function name of the caller of the caller of this function.
"""
# Get the grandparent frame (caller of the caller)
depth = min(1 + relative_depth, len(inspect.stack()) - 1)
caller_frame = inspect.stack()[depth]
caller_file = caller_frame.filename
caller_line = caller_frame.lineno
caller_func = caller_frame.function

return f"at {caller_file}:{caller_line} in {caller_func}"


def post_completed_forward_modelling_run(run_dir: str, start_time: datetime, duration: timedelta,
hook_url: str | None = None):
"""
Expand Down
8 changes: 6 additions & 2 deletions dsa2000_cal/dsa2000_cal/common/jvp_linear_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import jax.numpy as jnp
import numpy as np

from dsa2000_cal.common.alert_utils import get_grandparent_info


def isinstance_namedtuple(obj) -> bool:
"""
Expand Down Expand Up @@ -171,7 +173,9 @@ def _get_results_type(primal_out: jax.Array):

def _adjoint_promote_dtypes(co_tangent: jax.Array, dtype: jnp.dtype):
if co_tangent.dtype != dtype:
warnings.warn(f"Promoting co-tangent dtype from {co_tangent.dtype} to {dtype}.")
warnings.warn(
f"Promoting co-tangent dtype from {co_tangent.dtype} to {dtype}, {get_grandparent_info(2)}."
)
return co_tangent.astype(dtype)

# v @ J
Expand All @@ -197,7 +201,7 @@ def _adjoint_promote_dtypes(co_tangent: jax.Array, dtype: jnp.dtype):

def _promote_dtype(primal: jax.Array, dtype: jnp.dtype):
if primal.dtype != dtype:
warnings.warn(f"Promoting primal dtype from {primal.dtype} to {dtype}.")
warnings.warn(f"Promoting primal dtype from {primal.dtype} to {dtype}, at {get_grandparent_info(2)}.")
return primal.astype(dtype)

def _get_result_type(primal: jax.Array):
Expand Down

0 comments on commit 086403a

Please sign in to comment.