@@ -518,12 +518,35 @@ class GenerateHeatmapd(MapTransform):
518518 Dictionary-based wrapper of :py:class:`monai.transforms.GenerateHeatmap`.
519519 Converts landmark coordinates into gaussian heatmaps and optionally copies metadata from a reference image.
520520
521+ Args:
522+ keys: keys of the corresponding items in the dictionary.
523+ sigma: standard deviation for the Gaussian kernel. Can be a single value or sequence matching number of points.
524+ heatmap_keys: keys to store output heatmaps. Default: "{key}_heatmap" for each key.
525+ ref_image_keys: keys of reference images to inherit spatial metadata from. When provided, heatmaps will
526+ have the same shape, affine, and spatial metadata as the reference images.
527+ spatial_shape: spatial dimensions of output heatmaps. Can be:
528+ - Single shape (tuple): applied to all keys
529+ - List of shapes: one per key (must match keys length)
530+ truncated: truncation distance for Gaussian kernel computation (in sigmas).
531+ normalize: if True, normalize each heatmap's peak value to 1.0.
532+ dtype: output data type for heatmaps. Defaults to np.float32.
533+ allow_missing_keys: if True, don't raise error if some keys are missing in data.
534+
535+ Returns:
536+ Dictionary with original data plus generated heatmaps at specified keys.
537+
538+ Raises:
539+ ValueError: If heatmap_keys/ref_image_keys length doesn't match keys length.
540+ ValueError: If no spatial shape can be determined (need spatial_shape or ref_image_keys).
541+ ValueError: If input points have invalid shape (must be 2D or 3D).
542+
521543 Notes:
522544 - Default heatmap_keys are generated as "{key}_heatmap" for each input key
523545 - Shape inference precedence: static spatial_shape > ref_image
524546 - Output shapes:
525547 - Non-batched points (N, D): (N, H, W[, D])
526548 - Batched points (B, N, D): (B, N, H, W[, D])
549+ - When using ref_image_keys, heatmaps inherit affine and spatial metadata from reference
527550 """
528551
529552 backend = GenerateHeatmap .backend
@@ -575,7 +598,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
575598 # Copy metadata if reference is MetaTensor
576599 if isinstance (reference , MetaTensor ) and isinstance (heatmap , MetaTensor ):
577600 heatmap .affine = reference .affine
578- self ._update_spatial_metadata (heatmap , reference )
601+ self ._update_spatial_metadata (heatmap , shape )
579602 d [out_key ] = heatmap
580603 return d
581604
@@ -628,7 +651,7 @@ def _determine_shape(
628651 return static_shape
629652 points_t = convert_to_tensor (points , dtype = torch .float32 , track_meta = False )
630653 if points_t .ndim not in (2 , 3 ):
631- raise ValueError (self ._ERR_INVALID_POINTS )
654+ raise ValueError (f" { self ._ERR_INVALID_POINTS } Got { points_t . ndim } D tensor." )
632655 spatial_dims = int (points_t .shape [- 1 ])
633656 if ref_key is not None and ref_key in data :
634657 return self ._shape_from_reference (data [ref_key ], spatial_dims )
@@ -646,10 +669,8 @@ def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int,
646669 return tuple (int (v ) for v in reference .shape [- spatial_dims :])
647670 raise ValueError (self ._ERR_REF_NO_SHAPE )
648671
649- def _update_spatial_metadata (self , heatmap : MetaTensor , reference : MetaTensor ) -> None :
650- """Update spatial metadata of heatmap based on its dimensions."""
651- # trailing dims after channel are spatial regardless of batch presence
652- spatial_shape = heatmap .shape [- (reference .ndim - 1 ) :]
672+ def _update_spatial_metadata (self , heatmap : MetaTensor , spatial_shape : tuple [int , ...]) -> None :
673+ """Set spatial_shape explicitly from resolved shape."""
653674 heatmap .meta ["spatial_shape" ] = tuple (int (v ) for v in spatial_shape )
654675
655676
0 commit comments