88import matplotlib
99
1010import numpy as np
11+ import numpy .typing as npt
1112from matplotlib import cm , colors , pyplot as plt
1213from matplotlib .axes import Axes
1314from matplotlib .collections import LineCollection
@@ -47,11 +48,11 @@ class VisualizeSign(Enum):
4748 all = 4
4849
4950
50- def _prepare_image (attr_visual : ndarray ) -> ndarray :
51+ def _prepare_image (attr_visual : npt . NDArray ) -> npt . NDArray :
5152 return np .clip (attr_visual .astype (int ), 0 , 255 )
5253
5354
54- def _normalize_scale (attr : ndarray , scale_factor : float ) -> ndarray :
55+ def _normalize_scale (attr : npt . NDArray , scale_factor : float ) -> npt . NDArray :
5556 assert scale_factor != 0 , "Cannot normalize by scale factor = 0"
5657 if abs (scale_factor ) < 1e-5 :
5758 warnings .warn (
@@ -64,7 +65,9 @@ def _normalize_scale(attr: ndarray, scale_factor: float) -> ndarray:
6465 return np .clip (attr_norm , - 1 , 1 )
6566
6667
67- def _cumulative_sum_threshold (values : ndarray , percentile : Union [int , float ]) -> float :
68+ def _cumulative_sum_threshold (
69+ values : npt .NDArray , percentile : Union [int , float ]
70+ ) -> float :
6871 # given values should be non-negative
6972 assert percentile >= 0 and percentile <= 100 , (
7073 "Percentile for thresholding must be " "between 0 and 100 inclusive."
@@ -76,11 +79,11 @@ def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]) ->
7679
7780
7881def _normalize_attr (
79- attr : ndarray ,
82+ attr : npt . NDArray ,
8083 sign : str ,
8184 outlier_perc : Union [int , float ] = 2 ,
8285 reduction_axis : Optional [int ] = None ,
83- ) -> ndarray :
86+ ) -> npt . NDArray :
8487 attr_combined = attr
8588 if reduction_axis is not None :
8689 attr_combined = np .sum (attr , axis = reduction_axis )
@@ -130,7 +133,7 @@ def _initialize_cmap_and_vmin_vmax(
130133
131134def _visualize_original_image (
132135 plt_axis : Axes ,
133- original_image : Optional [ndarray ],
136+ original_image : Optional [npt . NDArray ],
134137 ** kwargs : Any ,
135138) -> None :
136139 assert (
@@ -143,7 +146,7 @@ def _visualize_original_image(
143146
144147def _visualize_heat_map (
145148 plt_axis : Axes ,
146- norm_attr : ndarray ,
149+ norm_attr : npt . NDArray ,
147150 cmap : Union [str , Colormap ],
148151 vmin : float ,
149152 vmax : float ,
@@ -155,8 +158,8 @@ def _visualize_heat_map(
155158
156159def _visualize_blended_heat_map (
157160 plt_axis : Axes ,
158- original_image : ndarray ,
159- norm_attr : ndarray ,
161+ original_image : npt . NDArray ,
162+ norm_attr : npt . NDArray ,
160163 cmap : Union [str , Colormap ],
161164 vmin : float ,
162165 vmax : float ,
@@ -176,8 +179,8 @@ def _visualize_blended_heat_map(
176179def _visualize_masked_image (
177180 plt_axis : Axes ,
178181 sign : str ,
179- original_image : ndarray ,
180- norm_attr : ndarray ,
182+ original_image : npt . NDArray ,
183+ norm_attr : npt . NDArray ,
181184 ** kwargs : Any ,
182185) -> None :
183186 assert VisualizeSign [sign ].value != VisualizeSign .all .value , (
@@ -190,8 +193,8 @@ def _visualize_masked_image(
190193def _visualize_alpha_scaling (
191194 plt_axis : Axes ,
192195 sign : str ,
193- original_image : ndarray ,
194- norm_attr : ndarray ,
196+ original_image : npt . NDArray ,
197+ norm_attr : npt . NDArray ,
195198 ** kwargs : Any ,
196199) -> None :
197200 assert VisualizeSign [sign ].value != VisualizeSign .all .value , (
@@ -210,8 +213,8 @@ def _visualize_alpha_scaling(
210213
211214
212215def visualize_image_attr (
213- attr : ndarray ,
214- original_image : Optional [ndarray ] = None ,
216+ attr : npt . NDArray ,
217+ original_image : Optional [npt . NDArray ] = None ,
215218 method : str = "heat_map" ,
216219 sign : str = "absolute_value" ,
217220 plt_fig_axis : Optional [Tuple [Figure , Axes ]] = None ,
@@ -417,8 +420,8 @@ def visualize_image_attr(
417420
418421
419422def visualize_image_attr_multiple (
420- attr : ndarray ,
421- original_image : Union [None , ndarray ],
423+ attr : npt . NDArray ,
424+ original_image : Union [None , npt . NDArray ],
422425 methods : List [str ],
423426 signs : List [str ],
424427 titles : Optional [List [str ]] = None ,
@@ -526,9 +529,9 @@ def visualize_image_attr_multiple(
526529
527530
528531def visualize_timeseries_attr (
529- attr : ndarray ,
530- data : ndarray ,
531- x_values : Optional [ndarray ] = None ,
532+ attr : npt . NDArray ,
533+ data : npt . NDArray ,
534+ x_values : Optional [npt . NDArray ] = None ,
532535 method : str = "overlay_individual" ,
533536 sign : str = "absolute_value" ,
534537 channel_labels : Optional [List [str ]] = None ,
0 commit comments