@@ -626,7 +626,7 @@ function (g::GlobalMeanPool)(x)
626
626
end
627
627
628
628
"""
629
- GlobalLPNormPool(p::Float64 )
629
+ GlobalLPNormPool(p::T )
630
630
631
631
Global lp norm pooling layer.
632
632
@@ -636,16 +636,16 @@ by performing lp norm pooling on the complete (w,h)-shaped feature maps.
636
636
See also [`LPNormPool`](@ref).
637
637
638
638
```jldoctest
639
- julia> xs = rand(Float32, 100, 100, 3, 50)
639
+ julia> xs = rand(Float32, 100, 100, 3, 50);
640
640
641
- julia> m = Chain(Conv((3,3), 3 => 7), GlobalLPNormPool(2.0))
641
+ julia> m = Chain(Conv((3,3), 3 => 7), GlobalLPNormPool(2.0));
642
642
643
643
julia> m(xs) |> size
644
644
(1, 1, 7, 50)
645
645
```
646
646
"""
647
- struct GlobalLPNormPool
648
- p:: Float64
647
+ struct GlobalLPNormPool{T <: Number }
648
+ p:: T
649
649
end
650
650
651
651
function (g:: GlobalLPNormPool )(x)
@@ -778,7 +778,7 @@ function Base.show(io::IO, m::MeanPool)
778
778
end
779
779
780
780
"""
781
- LPNormPool(window::NTuple, p::Float64 ; pad=0, stride=window)
781
+ LPNormPool(window::NTuple, p::T ; pad=0, stride=window)
782
782
783
783
Lp norm pooling layer, calculating p-norm distance for each window,
784
784
also known as LPPool in pytorch.
@@ -801,7 +801,7 @@ julia> xs = rand(Float32, 100, 100, 3, 50);
801
801
julia> m = Chain(Conv((5,5), 3 => 7), LPNormPool((5,5), 2.0; pad=SamePad()))
802
802
Chain(
803
803
Conv((5, 5), 3 => 7), # 532 parameters
804
- LPNormPool((5, 5), p=2 , pad=2),
804
+ LPNormPool((5, 5), 2.0 , pad=2),
805
805
)
806
806
807
807
julia> m[1](xs) |> size
@@ -811,20 +811,20 @@ julia> m(xs) |> size
811
811
(20, 20, 7, 50)
812
812
813
813
julia> layer = LPNormPool((5,), 2.0, pad=2, stride=(3,)) # one-dimensional window
814
- LPNormPool((5,), p=2 , pad=2, stride=3)
814
+ LPNormPool((5,), 2.0 , pad=2, stride=3)
815
815
816
816
julia> layer(rand(Float32, 100, 7, 50)) |> size
817
817
(34, 7, 50)
818
818
```
819
819
"""
820
- struct LPNormPool{N,M}
820
+ struct LPNormPool{N,M,T <: Number }
821
821
k:: NTuple{N,Int}
822
- p:: Float64
822
+ p:: T
823
823
pad:: NTuple{M,Int}
824
824
stride:: NTuple{N,Int}
825
825
end
826
826
827
- function LPNormPool (k:: NTuple{N,Integer} , p:: Float64 ; pad = 0 , stride = k) where N
827
+ function LPNormPool (k:: NTuple{N,Integer} , p:: T ; pad = 0 , stride = k) where {N,T}
828
828
stride = expand (Val (N), stride)
829
829
pad = calc_padding (LPNormPool, pad, k, 1 , stride)
830
830
return LPNormPool (k, p, pad, stride)
0 commit comments