
Commit 09764f8

Merge #1379
1379: Fix some issues with Zeros option 2 r=CarloLucibello a=DrChainsaw

Fixes #1332, #1277.

This is take two on fixing the above issues and is intended to be mutually exclusive to #1374. In this version, `Zeros` is no longer an `AbstractArray` and is thus more of a there-is-no-such-parameter marker than a this-parameter-is-an-immutable-array-of-zeros type.

I made the change so that the advertised way of disabling bias is through `bias=false` rather than `bias=Zeros()`. `Zeros` itself is a lot less capable here compared to the previous attempt, but that also keeps the number of moving parts down, i.e. there is no need to test that both the 0-dim version and the full-size version behave the same. All in all there are fewer specializations compared to #1374.

I think that a rename could definitely be in place. Better names I can think of are `bias=Off`, `bias=None`, or, as suggested by @mcabbott, `bias=False()`. The best argument against renaming I can think of is to be a little less breaking.

### PR Checklist

- [X] Tests are added
- [X] Entry in NEWS.md
- [ ] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: DrChainsaw <[email protected]>
2 parents d01d5ba + c09657a commit 09764f8
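To make the advertised change concrete, here is a minimal usage sketch (my illustration, not part of the commit; it assumes Flux as of this merge, where `Dense` stores its bias in the field `b`):

```julia
using Flux

# Disable the bias of a Dense layer with the new keyword.
layer = Dense(5, 2; bias=false)

layer.b                     # Flux.Zeros() — a marker, no longer an AbstractArray
length(Flux.params(layer))  # 1: only the weight matrix is trainable

# The forward pass still works; the Zeros marker drops out of W*x .+ b.
x = rand(Float32, 5)
y = layer(x)                # 2-element Array{Float32,1}
```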

File tree

11 files changed: +267 −127 lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions

@@ -1,3 +1,6 @@
+# v0.11.3
+* Added option to set `bias` to [false](https://github.com/FluxML/Flux.jl/pull/1379) to eliminate `bias` from being trained.
+
 # v0.11.2
 
 * Adds the [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser.

docs/src/models/layers.md

Lines changed: 0 additions & 1 deletion

@@ -24,7 +24,6 @@ ConvTranspose
 CrossCor
 SamePad
 flatten
-Flux.Zeros
 Flux.convfilter
 Flux.depthwiseconvfilter
 ```

src/layers/basic.jl

Lines changed: 9 additions & 4 deletions

@@ -83,7 +83,7 @@ extraChain(::Tuple{}, x) = ()
 
 
 """
-    Dense(in::Integer, out::Integer, σ = identity)
+    Dense(in::Integer, out::Integer, σ = identity; bias=true)
 
 Create a traditional `Dense` layer with parameters `W` and `b`.
 
@@ -92,6 +92,8 @@ Create a traditional `Dense` layer with parameters `W` and `b`.
 The input `x` must be a vector of length `in`, or a batch of vectors represented
 as an `in × N` matrix. The output `y` will be a vector or batch of length `out`.
 
+Setting `bias` to `false` will switch bias off for the layer.
+
 # Example
 ```
 julia> d = Dense(5, 2)
@@ -101,9 +103,12 @@ julia> d(rand(5))
 2-element Array{Float32,1}:
 -0.16210233
  0.123119034
+
+julia> d = Dense(5, 2; bias=false)
+Dense(5, 2)
 ```
 """
-struct Dense{F,S<:AbstractArray,T<:AbstractArray}
+struct Dense{F,S<:AbstractArray,T<:Union{Zeros, AbstractVector}}
   W::S
   b::T
   σ::F
@@ -112,8 +117,8 @@ end
 Dense(W, b) = Dense(W, b, identity)
 
 function Dense(in::Integer, out::Integer, σ = identity;
-               initW = glorot_uniform, initb = zeros)
-  return Dense(initW(out, in), initb(out), σ)
+               initW = glorot_uniform, initb = zeros, bias=true)
+  return Dense(initW(out, in), create_bias(bias, initb, out), σ)
 end
 
 @functor Dense
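For reference, a sketch of how the three accepted `bias` arguments flow through the new `Dense` constructor (my illustration, assuming the definitions in this diff; `create_bias` itself is added in src/utils.jl below):

```julia
using Flux

Dense(5, 2)                         # bias=true  -> initb(2), a trainable zero vector
Dense(5, 2; bias=false)             # bias=false -> Flux.Zeros(), bias switched off
Dense(5, 2; bias=rand(Float32, 2))  # an array is passed through unchanged
```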

src/layers/conv.jl

Lines changed: 28 additions & 24 deletions

@@ -46,7 +46,7 @@ In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.
 
 Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
-Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
+Setting `bias` to `false` will switch bias off for the layer.
 
 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -82,7 +82,7 @@ end
 
 Constructs the convolutional layer with user defined weight and bias arrays.
 
-Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
+Setting `bias` to `false` would switch `bias` off for the layer.
 
 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -102,15 +102,16 @@ Conv(weight = weight,
      σ = sigmoid)
 ```
 """
-function Conv(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
+function Conv(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
               stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(Conv, pad, size(w)[1:N-2], dilation, stride)
-  return Conv(σ, w, b, stride, pad, dilation)
+  bias = create_bias(b, zeros, size(w, N))
+  return Conv(σ, w, bias, stride, pad, dilation)
 end
 
-function Conv(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
+function Conv(;weight::AbstractArray{T,N}, bias::Union{Bool, Zeros, AbstractVector{T}},
               activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
   Conv(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
 end
@@ -131,7 +132,7 @@ convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
 
 function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
               init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
-              weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
+              weight = convfilter(k, ch, init = init), bias = true) where N
 
   Conv(weight, bias, σ,
        stride = stride, pad = pad, dilation = dilation)
@@ -189,7 +190,7 @@ In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.
 
 Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
-Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
+Setting `bias` to `false` will switch bias off for the layer.
 
 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -215,7 +216,7 @@ end
 Constructs the convolutional transpose layer with user defined weight and bias arrays.
 forward pass.
 
-Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
+Setting `bias` to `false` will switch bias off for the layer.
 
 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -226,22 +227,23 @@ indicating padding values for each spatial dimension at both the ends.
 
 For keyword-only constructor, see also [`Conv`](@ref)
 """
-function ConvTranspose(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
+function ConvTranspose(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
                        stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride)
-  return ConvTranspose(σ, w, b, stride, pad, dilation)
+  bias = create_bias(b, zeros, size(w, N))
+  return ConvTranspose(σ, w, bias, stride, pad, dilation)
 end
 
-function ConvTranspose(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
+function ConvTranspose(;weight::AbstractArray{T,N}, bias::Union{Bool, Zeros, AbstractVector{T}},
                        activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
   ConvTranspose(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
 end
 
 function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                        init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
-                       weight = convfilter(k, reverse(ch), init = init), bias = zeros(ch[2])) where N
+                       weight = convfilter(k, reverse(ch), init = init), bias = true) where N
 
   ConvTranspose(weight, bias, σ,
                 stride = stride, pad = pad, dilation = dilation)
@@ -307,7 +309,7 @@ In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.
 
 Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
-Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
+Setting `bias` to `false` will switch bias off for the layer.
 
 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -333,7 +335,7 @@ end
 Constructs the `DepthwiseConv` layer with user defined weight and bias arrays.
 forward pass.
 
-Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
+Setting `bias` to `false` would switch `bias` off for the layer.
 
 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -344,15 +346,16 @@ indicating padding values for each spatial dimension at both the ends.
 
 For keyword-only constructor, see also [`Conv`](@ref)
 """
-function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
+function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
                        stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(DepthwiseConv, pad, size(w)[1:N-2], dilation, stride)
-  return DepthwiseConv(σ, w, b, stride, pad, dilation)
+  bias = create_bias(b, zeros, prod(size(w)[N-1:end]))
+  return DepthwiseConv(σ, w, bias, stride, pad, dilation)
 end
 
-function DepthwiseConv(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
+function DepthwiseConv(;weight::AbstractArray{T,N}, bias::Union{Bool, Zeros, AbstractVector{T}},
                        activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
   DepthwiseConv(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
 end
@@ -373,7 +376,7 @@ depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
 
 function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                        init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
-                       weight = depthwiseconvfilter(k, ch, init = init), bias = zeros(ch[2])) where N
+                       weight = depthwiseconvfilter(k, ch, init = init), bias = true) where N
   @assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"
 
   return DepthwiseConv(
@@ -424,7 +427,7 @@ In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.
 
 Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
-Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
+Setting `bias` to `false` will switch bias off for the layer.
 
 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -461,7 +464,7 @@ end
 Constructs the standard cross convolutional layer with user defined weight and bias
 arrays.
 
-Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
+Setting `bias` to `false` would switch `bias` off for the layer.
 
 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -472,22 +475,23 @@ indicating padding values for each spatial dimension at both the ends.
 
 For keyword-only constructor, see also [`Conv`](@ref)
 """
-function CrossCor(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
+function CrossCor(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
                   stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(CrossCor, pad, size(w)[1:N-2], dilation, stride)
-  return CrossCor(σ, w, b, stride, pad, dilation)
+  bias = create_bias(b, zeros, size(w, N))
+  return CrossCor(σ, w, bias, stride, pad, dilation)
 end
 
-function CrossCor(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
+function CrossCor(;weight::AbstractArray{T,N}, bias::Union{Bool, Zeros, AbstractVector{T}},
                   activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
   CrossCor(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
 end
 
 function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                   init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
-                  weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
+                  weight = convfilter(k, ch, init = init), bias = true) where N
 
   CrossCor(weight, bias, σ,
            stride = stride, pad = pad, dilation = dilation)
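The same keyword now works uniformly across the convolutional layers; a short sketch (my example, assuming this diff):

```julia
using Flux

c = Conv((3, 3), 1 => 8, relu; bias = false)
c.bias                      # Flux.Zeros()
length(Flux.params(c))      # 1 — only the 3×3×1×8 filter is trainable

# The positional constructor accepts `false` too, via the widened
# Union{Bool, Zeros, AbstractVector{T}} signature:
w = Flux.convfilter((3, 3), 1 => 8)
c2 = Conv(w, false)         # bias resolved by create_bias(false, zeros, size(w, 4))
```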

src/utils.jl

Lines changed: 14 additions & 0 deletions
@@ -176,6 +176,20 @@ zeros(T::Type, dims...) = Base.zeros(T, dims...)
 ones(dims...) = Base.ones(Float32, dims...)
 zeros(dims...) = Base.zeros(Float32, dims...)
 
+"""
+    create_bias(shallcreate::Bool, iftrue, dims...)
+    create_bias(x, ::Any...)
+
+Return a bias parameter for a layer.
+
+Essentially handles the allowed input options for the `bias` keyword:
+If `false`: Return the `Zeros` type, which turns bias off.
+If `true`: Return the result of `iftrue(dims...)`.
+If not a boolean, return `x` itself, to handle the case of `bias = somearray`.
+"""
+create_bias(shallcreate::Bool, iftrue, dims...) = shallcreate ? iftrue(dims...) : Zeros()
+create_bias(x, ::Any...) = x
+
 """
     unsqueeze(xs, dim)
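A quick sketch of the dispatch above (my illustration; `Flux.zeros` here is Flux's Float32 `zeros` defined just before the hunk):

```julia
using Flux

Flux.create_bias(true,  Flux.zeros, 4)             # 4-element Array{Float32,1} of zeros
Flux.create_bias(false, Flux.zeros, 4)             # Flux.Zeros() — bias switched off
Flux.create_bias(rand(Float32, 4), Flux.zeros, 4)  # the given array, unchanged
```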

src/zeros.jl

Lines changed: 26 additions & 83 deletions
@@ -1,10 +1,7 @@
-import Base: +, -, *, reshape, size
-import Base.Broadcast: broadcasted, Broadcasted, BroadcastStyle
+import Base: +, -, *, /, reshape, broadcasted
 
 """
     Zeros()
-    Zeros(size...)
-    Zeros(Type, size...)
 
 Acts as a stand-in for an array of zeros that can be
 used during training, and which is ignored by the optimisers.
@@ -13,94 +10,40 @@ Useful to turn bias off for a forward pass of a layer.
 
 ## Examples
 
-```julia
-julia> Flux.Zeros(3,3)
-3×3 Flux.Zeros{Bool,2}:
- false  false  false
- false  false  false
- false  false  false
-
-julia> Flux.Zeros(Float32, 3,3)
-3×3 Flux.Zeros{Float32,2}:
- 0.0  0.0  0.0
- 0.0  0.0  0.0
- 0.0  0.0  0.0
+```julia-repl
+julia> bias_less_conv = Conv((2,2), 1=>3; bias = false)
+Conv((2, 2), 1=>3)
 
-julia> rand(3,3) .+ Flux.Zeros()
-3×3 Array{Float64,2}:
- 0.198739  0.490459  0.785386
- 0.779074  0.39986   0.66383
- 0.854981  0.447292  0.314497
+julia> params(bias_less_conv) |> length
+1
 
-julia> bias_less_conv = Conv((2,2), 1=>3, bias = Flux.Zeros())
-Conv((2, 2), 1=>3)
+julia> bias_less_conv.bias
+Flux.Zeros()
 ```
 """
-struct Zeros{T,N} <: AbstractArray{T,N}
-  size::Tuple
-end
-
-Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz)
-Zeros(sz::Integer...) = Zeros(Bool, sz...)
-
-Base.size(xs::Zeros) = xs.size
-Base.axes(xs::Zeros) = Base.OneTo.(size(xs))
-
-Base.IndexStyle(::Type{<:Zeros}) = IndexLinear()
-
-Base.getindex(xs::Zeros{T,N}, I::Int) where {T,N} = zero(T)
-Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,N} =
-  Zeros(T, length(inds))
-
-Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))
-
-@adjoint reshape(xs::Zeros{T}, dims...) where T =
-  reshape(xs, dims...), _ -> nothing
-
-# Define basic ops
-for f in (:+, :-)
-  @eval @inline function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros)
-    @assert size(a) == size(b) throw(DimensionMismatch("dimensions must match"))
-    a
-  end
-end
-
-+(a::Zeros, b::AbstractArray) = b + a
--(a::Zeros, b::AbstractArray) = -b + a
-
-Base.copy(xs::Zeros{T,N}) where {T,N} = xs
-
-# Define broadcasting behaviour
-for op in (:+, :-)
-  @eval function broadcasted(::typeof($op), a::AbstractArray, b::Zeros)
-    bs = Broadcast.broadcast_shape(size(a), size(b))
-    size(a) == bs && return a
-    sz = similar(a, bs)
-    sz .= a
-  end
-end
-
-broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(+, b, a)
-broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(+, -b, a)
+struct Zeros end
+# To allow for things like Dense(10, 2, initb = Zeros)
+Zeros(args...) = Zeros()
 
-function broadcasted(::typeof(*), a::AbstractArray, b::Zeros)
-  Zeros(Broadcast.broadcast_shape(size(a), size(b))...)
-end
+Base.reshape(x::Zeros, dims...) = x
 
-broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = broadcasted(*, b, a)
++(::Zeros, b::AbstractArray) = b
++(a::AbstractArray, ::Zeros) = a
++(a::Zeros, ::Zeros) = a
 
-for op in (:+, :-, :*)
-  @eval broadcasted(::typeof($op), a::Zeros, b::Zeros) = Zeros(Broadcast.broadcast_shape(size(a), size(b))...)
-end
+-(::Zeros, b::AbstractArray) = -b
+-(a::AbstractArray, ::Zeros) = a
+-(a::Zeros, ::Zeros) = a
 
 # Some opportunities to avoid scalar indexing, intermediaries
 # Since it replicates a little of what we expect Base to do,
 # it should be possible to remove in the future, but for now,
 # these help with performance.
-broadcasted(::typeof(+), a::AbstractArray, b::Zeros{T,0}) where T = a
-broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray) where T = b
-broadcasted(::typeof(-), a::AbstractArray, b::Zeros{T,0}) where T = a
-broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray) where T = -b
-broadcasted(::typeof(*), a::AbstractArray, b::Zeros{T,0}) where T = zero(a)
-broadcasted(::typeof(*), a::Zeros{T,0}, b::AbstractArray) where T = zero(b)
-broadcasted(::typeof(/), a::Zeros{T,0}, b::AbstractArray) where T = zero(b)
+broadcasted(::typeof(+), a::AbstractArray, b::Zeros) = a
+broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = b
+broadcasted(::typeof(-), a::AbstractArray, b::Zeros) = a
+broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = -b
+# Need adjoints for these, or else the gradient w.r.t. the non-Zeros arg will be nothing as well
+@adjoint broadcasted(::typeof(*), a::AbstractArray, b::Zeros) = zero(a), _ -> (nothing, zero(a), nothing)
+@adjoint broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = zero(b), _ -> (nothing, nothing, zero(b))
+@adjoint broadcasted(::typeof(/), a::Zeros, b::AbstractArray) = zero(b), _ -> (nothing, nothing, zero(b))
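To summarize the slimmed-down semantics, a behavioural sketch (my example, assuming this diff; `*` and `/` only get Zygote adjoints here, so they matter inside `gradient` rather than in plain evaluation):

```julia
using Flux

z = Flux.Zeros()
a = rand(Float32, 3)

a + z == a                 # true: Zeros acts as the additive identity
z - a == -a                # true
a .+ z === a               # true: broadcasted + short-circuits to `a` itself
reshape(z, 3, 1)           # Flux.Zeros(): reshape is a no-op on the marker
Flux.Zeros(Float32, 3, 3)  # also Flux.Zeros(): size arguments are now discarded
```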
