@@ -898,13 +898,16 @@ class GaussianSmoothing(nn.Module):
898
898
Apply gaussian smoothing on a
899
899
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
900
900
in the input using a depthwise convolution.
901
- Arguments:
902
- channels (int, sequence): Number of channels of the input tensors. Output will
903
- have this number of channels as well.
904
- kernel_size (int, sequence): Size of the gaussian kernel.
905
- sigma (float, sequence): Standard deviation of the gaussian kernel.
906
- dim (int, optional): The number of dimensions of the data.
907
- Default value is 2 (spatial).
901
+
902
+ Args:
903
+ channels (`int` or `sequence`):
904
+ Number of channels of the input tensors. The output will have this number of channels as well.
905
+ kernel_size (`int` or `sequence`):
906
+ Size of the Gaussian kernel.
907
+ sigma (`float` or `sequence`):
908
+ Standard deviation of the Gaussian kernel.
909
+ dim (`int`, *optional*, defaults to `2`):
910
+ The number of dimensions of the data. Default is 2 (spatial dimensions).
908
911
"""
909
912
910
913
def __init__ (self , channels , kernel_size , sigma , dim = 2 ):
@@ -944,10 +947,14 @@ def __init__(self, channels, kernel_size, sigma, dim=2):
944
947
def forward (self , input ):
945
948
"""
946
949
Apply gaussian filter to input.
947
- Arguments:
948
- input (torch.Tensor): Input to apply gaussian filter on.
950
+
951
+ Args:
952
+ input (`torch.Tensor` of shape `(N, C, H, W)`):
953
+ Input to apply Gaussian filter on.
954
+
949
955
Returns:
950
- filtered (torch.Tensor): Filtered output.
956
+ `torch.Tensor`:
957
+ The filtered output tensor with the same shape as the input.
951
958
"""
952
959
return self .conv (input , weight = self .weight .to (input .dtype ), groups = self .groups , padding = "same" )
953
960
0 commit comments