
Commit c09657a

Add option to set bias=false to use Zeros as bias
1 parent d17bcd7 commit c09657a

File tree: 10 files changed, +78 -46 lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
+# v0.11.3
+* Added option to set `bias` to [false](https://github.com/FluxML/Flux.jl/pull/1379) to exclude `bias` from being trained.
+
 # v0.11.2

 * Adds the [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser.

docs/src/models/layers.md

Lines changed: 0 additions & 1 deletion
@@ -24,7 +24,6 @@ ConvTranspose
 CrossCor
 SamePad
 flatten
-Flux.Zeros
 Flux.convfilter
 Flux.depthwiseconvfilter
 ```

src/layers/basic.jl

Lines changed: 8 additions & 3 deletions
@@ -83,7 +83,7 @@ extraChain(::Tuple{}, x) = ()


 """
-    Dense(in::Integer, out::Integer, σ = identity)
+    Dense(in::Integer, out::Integer, σ = identity; bias=true)

 Create a traditional `Dense` layer with parameters `W` and `b`.

@@ -92,6 +92,8 @@ Create a traditional `Dense` layer with parameters `W` and `b`.
 The input `x` must be a vector of length `in`, or a batch of vectors represented
 as an `in × N` matrix. The output `y` will be a vector or batch of length `out`.

+Setting `bias` to `false` will switch bias off for the layer.
+
 # Example
 ```
 julia> d = Dense(5, 2)
@@ -101,6 +103,9 @@ julia> d(rand(5))
 2-element Array{Float32,1}:
 -0.16210233
 0.123119034
+
+julia> d = Dense(5, 2; bias=false)
+Dense(5, 2)
 ```
 """
 struct Dense{F,S<:AbstractArray,T<:Union{Zeros, AbstractVector}}
@@ -112,8 +117,8 @@ end
 Dense(W, b) = Dense(W, b, identity)

 function Dense(in::Integer, out::Integer, σ = identity;
-               initW = glorot_uniform, initb = zeros)
-  return Dense(initW(out, in), initb(out), σ)
+               initW = glorot_uniform, initb = zeros, bias=true)
+  return Dense(initW(out, in), create_bias(bias, initb, out), σ)
 end

 @functor Dense
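
For orientation (not part of the diff): a minimal sketch of what the new keyword does for `Dense`, assuming Flux at this commit, where `Flux.params` collects only trainable arrays:

```julia
using Flux

d = Dense(5, 2; bias = false)  # weight initialised as usual, bias replaced by Flux.Zeros()
length(Flux.params(d))         # 1: only the weight matrix is trainable
y = d(rand(Float32, 5))        # forward pass unchanged; the bias term contributes nothing
size(y)                        # (2,)
```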

src/layers/conv.jl

Lines changed: 28 additions & 24 deletions
@@ -46,7 +46,7 @@ In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.

 Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
-Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
+Setting `bias` to `false` will switch bias off for the layer.

 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -82,7 +82,7 @@ end

 Constructs the convolutional layer with user-defined weight and bias arrays.

-Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
+Setting `bias` to `false` will switch `bias` off for the layer.

 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -102,15 +102,16 @@ Conv(weight = weight,
      σ = sigmoid)
 ```
 """
-function Conv(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
+function Conv(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
               stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(Conv, pad, size(w)[1:N-2], dilation, stride)
-  return Conv(σ, w, b, stride, pad, dilation)
+  bias = create_bias(b, zeros, size(w, N))
+  return Conv(σ, w, bias, stride, pad, dilation)
 end

-function Conv(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
+function Conv(;weight::AbstractArray{T,N}, bias::Union{Bool, Zeros, AbstractVector{T}},
               activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
   Conv(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
 end
@@ -131,7 +132,7 @@ convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};

 function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
               init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
-              weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
+              weight = convfilter(k, ch, init = init), bias = true) where N

   Conv(weight, bias, σ,
        stride = stride, pad = pad, dilation = dilation)
@@ -189,7 +190,7 @@ In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.

 Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
-Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
+Setting `bias` to `false` will switch bias off for the layer.

 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -215,7 +216,7 @@ end
 Constructs the convolutional transpose layer with user-defined weight and bias arrays.

-Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
+Setting `bias` to `false` will switch bias off for the layer.

 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -226,22 +227,23 @@ indicating padding values for each spatial dimension at both the ends.

 For keyword-only constructor, see also [`Conv`](@ref)
 """
-function ConvTranspose(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
+function ConvTranspose(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
                        stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride)
-  return ConvTranspose(σ, w, b, stride, pad, dilation)
+  bias = create_bias(b, zeros, size(w, N))
+  return ConvTranspose(σ, w, bias, stride, pad, dilation)
 end

-function ConvTranspose(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
+function ConvTranspose(;weight::AbstractArray{T,N}, bias::Union{Bool, Zeros, AbstractVector{T}},
                        activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
   ConvTranspose(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
 end

 function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                        init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
-                       weight = convfilter(k, reverse(ch), init = init), bias = zeros(ch[2])) where N
+                       weight = convfilter(k, reverse(ch), init = init), bias = true) where N

   ConvTranspose(weight, bias, σ,
                 stride = stride, pad = pad, dilation = dilation)
@@ -307,7 +309,7 @@ In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.

 Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
-Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
+Setting `bias` to `false` will switch bias off for the layer.

 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -333,7 +335,7 @@ end
 Constructs the `DepthwiseConv` layer with user-defined weight and bias arrays.

-Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
+Setting `bias` to `false` will switch `bias` off for the layer.

 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -344,15 +346,16 @@ indicating padding values for each spatial dimension at both the ends.

 For keyword-only constructor, see also [`Conv`](@ref)
 """
-function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
+function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
                        stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(DepthwiseConv, pad, size(w)[1:N-2], dilation, stride)
-  return DepthwiseConv(σ, w, b, stride, pad, dilation)
+  bias = create_bias(b, zeros, prod(size(w)[N-1:end]))
+  return DepthwiseConv(σ, w, bias, stride, pad, dilation)
 end

-function DepthwiseConv(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
+function DepthwiseConv(;weight::AbstractArray{T,N}, bias::Union{Bool, Zeros, AbstractVector{T}},
                        activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
   DepthwiseConv(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
 end
@@ -373,7 +376,7 @@ depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};

 function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                        init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
-                       weight = depthwiseconvfilter(k, ch, init = init), bias = zeros(ch[2])) where N
+                       weight = depthwiseconvfilter(k, ch, init = init), bias = true) where N
   @assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"

   return DepthwiseConv(
@@ -424,7 +427,7 @@ In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.

 Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
-Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
+Setting `bias` to `false` will switch bias off for the layer.

 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -461,7 +464,7 @@ end
 Constructs the standard cross convolutional layer with user-defined weight and bias
 arrays.

-Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
+Setting `bias` to `false` will switch `bias` off for the layer.

 Takes the keyword arguments `pad`, `stride` and `dilation`.
 For input dimension N,
@@ -472,22 +475,23 @@ indicating padding values for each spatial dimension at both the ends.

 For keyword-only constructor, see also [`Conv`](@ref)
 """
-function CrossCor(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
+function CrossCor(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
                   stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(CrossCor, pad, size(w)[1:N-2], dilation, stride)
-  return CrossCor(σ, w, b, stride, pad, dilation)
+  bias = create_bias(b, zeros, size(w, N))
+  return CrossCor(σ, w, bias, stride, pad, dilation)
 end

-function CrossCor(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
+function CrossCor(;weight::AbstractArray{T,N}, bias::Union{Bool, Zeros, AbstractVector{T}},
                   activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
   CrossCor(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
 end

 function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                   init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
-                  weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
+                  weight = convfilter(k, ch, init = init), bias = true) where N

   CrossCor(weight, bias, σ,
            stride = stride, pad = pad, dilation = dilation)
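
Again for orientation (not part of the diff): the same keyword on a convolution, with a gradient call in the style of the updated tests below; a sketch assuming Flux at this commit:

```julia
using Flux

c = Conv((2, 2), 1 => 3; bias = false)  # bias stored as Flux.Zeros()
x = rand(Float32, 28, 28, 1, 1)
ps = Flux.params(c)                     # contains only the weight array
gs = gradient(() -> sum(c(x)), ps)      # no gradient entry is produced for the bias
length(ps)                              # 1
```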

src/utils.jl

Lines changed: 14 additions & 0 deletions
@@ -176,6 +176,20 @@ zeros(T::Type, dims...) = Base.zeros(T, dims...)
 ones(dims...) = Base.ones(Float32, dims...)
 zeros(dims...) = Base.zeros(Float32, dims...)

+"""
+    create_bias(shallcreate::Bool, iftrue, dims...)
+    create_bias(x, ::Any...)
+
+Return a bias parameter for a layer.
+
+Essentially handles the allowed input options for the `bias` keyword:
+If `false`: return the `Zeros` type, which turns bias off.
+If `true`: return the result of `iftrue(dims...)`.
+If not a boolean, return `x` itself, to handle the case `bias = somearray`.
+"""
+create_bias(shallcreate::Bool, iftrue, dims...) = shallcreate ? iftrue(dims...) : Zeros()
+create_bias(x, ::Any...) = x
+
 """
     unsqueeze(xs, dim)
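
For orientation (not part of the diff): how the new helper dispatches on its first argument; a sketch assuming the internal helper is reached as `Flux.create_bias`:

```julia
using Flux

Flux.create_bias(true, Flux.zeros, 3)   # bias = true:  Flux.zeros(3), a trainable Float32 vector
Flux.create_bias(false, Flux.zeros, 3)  # bias = false: Flux.Zeros(), switching bias off
Flux.create_bias(rand(Float32, 3))      # bias = somearray: the array itself passes through
```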

src/zeros.jl

Lines changed: 6 additions & 3 deletions
@@ -11,11 +11,14 @@ Useful to turn bias off for a forward pass of a layer.
 ## Examples

 ```julia-repl
-julia> bias_less_conv = Conv((2,2), 1=>3, bias = Flux.Zeros())
+julia> bias_less_conv = Conv((2,2), 1=>3; bias = false)
 Conv((2, 2), 1=>3)

-julia> bias_less_dense = Dense(10, 2, initb = Zeros)
-Dense(10, 2)
+julia> params(bias_less_conv) |> length
+1
+
+julia> bias_less_conv.bias
+Flux.Zeros()
 ```
 """
 struct Zeros end
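
For orientation (not part of the diff): the placeholder is a singleton that survives precision conversion, which the updated `test/utils.jl` below asserts; a minimal sketch:

```julia
using Flux

m = Dense(3, 2; bias = false)
m.b === Flux.Zeros()            # true: the field holds the placeholder, not a vector
Flux.f64(m).b === Flux.Zeros()  # true: f64/f32 leave Zeros() untouched
```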

test/cuda/layers.jl

Lines changed: 5 additions & 5 deletions
@@ -46,10 +46,10 @@ end
 # Repeats from Conv, CrossCor

 # Just to give testset in gradtest meaningful labels
-ConvNoBias(args...) = Conv(args...; bias=Flux.Zeros())
-ConvTransposeNoBias(args...) = ConvTranspose(args...; bias=Flux.Zeros())
-CrossCorNoBias(args...) = CrossCor(args...; bias=Flux.Zeros())
-DepthwiseConvNoBias(args...) = DepthwiseConv(args...; bias=Flux.Zeros())
+ConvNoBias(args...) = Conv(args...; bias=false)
+ConvTransposeNoBias(args...) = ConvTranspose(args...; bias=false)
+CrossCorNoBias(args...) = CrossCor(args...; bias=false)
+DepthwiseConvNoBias(args...) = DepthwiseConv(args...; bias=false)
 r = rand(Float32, 28, 28, 1, 1)
 conv_layers = [Conv, ConvNoBias, ConvTranspose, ConvTransposeNoBias, CrossCor, CrossCorNoBias, DepthwiseConv, DepthwiseConvNoBias]
 gradtest("Conv", conv_layers, r, (2,2), 1=>3)
@@ -102,7 +102,7 @@ end
 end

 @testset "Zeros mapped for $cl" for cl in (Conv, ConvTranspose, CrossCor, DepthwiseConv)
-  l = cl((2,2), 1=>3, bias = Flux.Zeros()) |> gpu
+  l = cl((2,2), 1=>3, bias = false) |> gpu
   ip = zeros(Float32, 28,28,1,1) |> gpu
   if cl in BROKEN_LAYERS
     @test_broken sum(l(ip)) ≈ 0.f0

test/layers/basic.jl

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ import Flux: activations
 @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
 @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
 @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
-@test Dense(10, 2, identity, initW = ones, initb = Zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
+@test Dense(10, 2, identity, initW = ones, bias = false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
 end

 @testset "Diagonal" begin

test/layers/conv.jl

Lines changed: 12 additions & 8 deletions
@@ -42,8 +42,8 @@ end
   op = bias(ip)
   @test sum(op) == prod(size(op))

-  @testset "Zeros mapped through $lmap" for lmap in (identity, cpu, f32)
-    bias = Conv((2,2), 1=>3, bias = Flux.Zeros()) |> lmap
+  @testset "No bias mapped through $lmap" for lmap in (identity, cpu, f32)
+    bias = Conv((2,2), 1=>3, bias = false) |> lmap
     op = bias(ip)
     @test sum(op) ≈ 0.f0
     gs = gradient(() -> sum(bias(ip)), Flux.params(bias))
@@ -52,7 +52,7 @@ end

   # Train w/o bias and make sure no convergence happens
   # when only bias can be converged
-  bias = Conv((2, 2), 1=>3, bias = Flux.Zeros());
+  bias = Conv((2, 2), 1=>3, bias = false);
   ip = zeros(Float32, 28,28,1,1)
   op = zeros(Float32, 27,27,3,1) .+ 2.f0
   opt = Descent()
@@ -87,8 +87,11 @@ end
   m1 = DepthwiseConv((2, 2), 3=>15)
   @test size(m1(r), 3) == 15

-  m3 = DepthwiseConv((2, 3), 3=>9)
-  @test size(m3(r), 3) == 9
+  m2 = DepthwiseConv((2, 3), 3=>9)
+  @test size(m2(r), 3) == 9
+
+  m3 = DepthwiseConv((2, 3), 3=>9; bias=false)
+  @test size(m3(r), 3) == 9

   # Test that we cannot ask for non-integer multiplication factors
   @test_throws AssertionError DepthwiseConv((2,2), 3=>10)
@@ -97,8 +100,9 @@ end
 @testset "ConvTranspose" begin
   x = zeros(Float32, 28, 28, 1, 1)
   y = Conv((3,3), 1 => 1)(x)
-  x_hat = ConvTranspose((3, 3), 1 => 1)(y)
-  @test size(x_hat) == size(x)
+  x_hat1 = ConvTranspose((3, 3), 1 => 1)(y)
+  x_hat2 = ConvTranspose((3, 3), 1 => 1, bias=false)(y)
+  @test size(x_hat1) == size(x_hat2) == size(x)

   m = ConvTranspose((3,3), 1=>1)
   # Test that the gradient call does not throw: #900
@@ -116,7 +120,7 @@ end
   m = Chain(
     CrossCor((2, 2), 1=>16, relu),
     MaxPool((2,2)),
-    CrossCor((2, 2), 16=>8, relu),
+    CrossCor((2, 2), 16=>8, relu; bias=false),
     MaxPool((2,2)),
     x -> reshape(x, :, size(x, 4)),
     Dense(288, 10), softmax)

test/utils.jl

Lines changed: 1 addition & 1 deletion
@@ -130,7 +130,7 @@ end
 end

 @testset "Zeros" begin
-  m = Dense(randn(2,3), Zeros())
+  m = Dense(3, 2; bias=false)
   @test f64(m).b === m.b === Zeros()
   @test f32(m).b === m.b === Zeros()
