Skip to content

Commit 85a6eff

Browse files
committed
fix: concrete type Number => parameterization
1 parent bddbb4b commit 85a6eff

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

src/layers/conv.jl

+11-11
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ function (g::GlobalMeanPool)(x)
626626
end
627627

628628
"""
629-
GlobalLPNormPool(p::Float64)
629+
GlobalLPNormPool(p::T)
630630
631631
Global lp norm pooling layer.
632632
@@ -636,16 +636,16 @@ by performing lp norm pooling on the complete (w,h)-shaped feature maps.
636636
See also [`LPNormPool`](@ref).
637637
638638
```jldoctest
639-
julia> xs = rand(Float32, 100, 100, 3, 50)
639+
julia> xs = rand(Float32, 100, 100, 3, 50);
640640
641-
julia> m = Chain(Conv((3,3), 3 => 7), GlobalLPNormPool(2.0))
641+
julia> m = Chain(Conv((3,3), 3 => 7), GlobalLPNormPool(2.0));
642642
643643
julia> m(xs) |> size
644644
(1, 1, 7, 50)
645645
```
646646
"""
647-
struct GlobalLPNormPool
648-
p::Float64
647+
struct GlobalLPNormPool{T<:Number}
648+
p::T
649649
end
650650

651651
function (g::GlobalLPNormPool)(x)
@@ -778,7 +778,7 @@ function Base.show(io::IO, m::MeanPool)
778778
end
779779

780780
"""
781-
LPNormPool(window::NTuple, p::Float64; pad=0, stride=window)
781+
LPNormPool(window::NTuple, p::T; pad=0, stride=window)
782782
783783
Lp norm pooling layer, calculating p-norm distance for each window,
784784
also known as LPPool in pytorch.
@@ -801,7 +801,7 @@ julia> xs = rand(Float32, 100, 100, 3, 50);
801801
julia> m = Chain(Conv((5,5), 3 => 7), LPNormPool((5,5), 2.0; pad=SamePad()))
802802
Chain(
803803
Conv((5, 5), 3 => 7), # 532 parameters
804-
LPNormPool((5, 5), p=2, pad=2),
804+
LPNormPool((5, 5), 2.0, pad=2),
805805
)
806806
807807
julia> m[1](xs) |> size
@@ -811,20 +811,20 @@ julia> m(xs) |> size
811811
(20, 20, 7, 50)
812812
813813
julia> layer = LPNormPool((5,), 2.0, pad=2, stride=(3,)) # one-dimensional window
814-
LPNormPool((5,), p=2, pad=2, stride=3)
814+
LPNormPool((5,), 2.0, pad=2, stride=3)
815815
816816
julia> layer(rand(Float32, 100, 7, 50)) |> size
817817
(34, 7, 50)
818818
```
819819
"""
820-
struct LPNormPool{N,M}
820+
struct LPNormPool{N,M,T<:Number}
821821
k::NTuple{N,Int}
822-
p::Float64
822+
p::T
823823
pad::NTuple{M,Int}
824824
stride::NTuple{N,Int}
825825
end
826826

827-
function LPNormPool(k::NTuple{N,Integer}, p::Float64; pad = 0, stride = k) where N
827+
function LPNormPool(k::NTuple{N,Integer}, p::T; pad = 0, stride = k) where {N,T}
828828
stride = expand(Val(N), stride)
829829
pad = calc_padding(LPNormPool, pad, k, 1, stride)
830830
return LPNormPool(k, p, pad, stride)

0 commit comments

Comments
 (0)