RNN and GRU give mutation error; LSTM gives ArgumentError about number of fields #1483

Closed
maetshju opened this issue Jan 27, 2021 · 28 comments

@maetshju
Contributor

I couldn't find another issue mentioning this, so my apologies if this has already come up. On the current master branch, 08e79c4, RNN layers give a mutation error from Zygote during the backward pass.

using Flux
m = RNN(5, 5)
loss(x) = sum(Flux.stack(m.(x), 2))
x = [rand(5) for i in 1:2]
Flux.train!(loss, Flux.params(m), x, ADAM())
RNN stacktrace
ERROR: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] (::Zygote.var"#375#376")(::Nothing) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\array.jl:61
 [3] (::Zygote.var"#2270#back#377"{Zygote.var"#375#376"})(::Nothing) at C:\Users\freez\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [4] copyto_axcheck! at .\abstractarray.jl:946 [inlined]
 [5] (::typeof((copyto_axcheck!)))(::Nothing) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [6] Array at .\array.jl:562 [inlined]
 [7] (::typeof((Array{Float32,2})))(::Nothing) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [8] convert at .\array.jl:554 [inlined]
 [9] (::typeof((convert)))(::Nothing) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [10] setproperty! at .\Base.jl:34 [inlined]
 [11] (::typeof((setproperty!)))(::Nothing) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [12] Recur at C:\Users\freez\.julia\dev\Flux\src\layers\recurrent.jl:34 [inlined]
 [13] (::typeof((λ)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [14] #509 at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\array.jl:187 [inlined]
 [15] #3 at .\generator.jl:36 [inlined]
 [16] iterate at .\generator.jl:47 [inlined]
 [17] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{typeof((λ)),1},Array{FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}},1}}},Base.var"#3#4"{Zygote.var"#509#513"}}) at .\array.jl:686
 [18] map at .\abstractarray.jl:2248 [inlined]
 [19] (::Zygote.var"#508#512"{Array{typeof((λ)),1}})(::Array{FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}},1}) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\array.jl:187
 [20] (::Flux.var"#324#back#177"{Zygote.var"#508#512"{Array{typeof((λ)),1}}})(::Array{FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}},1}) at C:\Users\freez\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [21] loss at .\REPL[3]:1 [inlined]
 [22] (::typeof((loss)))(::Float64) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [23] #150 at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\lib.jl:191 [inlined]
 [24] #1693#back at C:\Users\freez\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59 [inlined]
 [25] #39 at C:\Users\freez\.julia\dev\Flux\src\optimise\train.jl:103 [inlined]
 [26] (::typeof((λ)))(::Float64) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [27] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof((λ))})(::Float64) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface.jl:172
 [28] gradient(::Function, ::Zygote.Params) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface.jl:49        
 [29] macro expansion at C:\Users\freez\.julia\dev\Flux\src\optimise\train.jl:102 [inlined]
 [30] macro expansion at C:\Users\freez\.julia\packages\Juno\n6wyj\src\progress.jl:134 [inlined]
 [31] train!(::Function, ::Zygote.Params, ::Array{Array{Float64,1},1}, ::ADAM; cb::Flux.Optimise.var"#40#46") at C:\Users\freez\.julia\dev\Flux\src\optimise\train.jl:100
 [32] train!(::Function, ::Zygote.Params, ::Array{Array{Float64,1},1}, ::ADAM) at C:\Users\freez\.julia\dev\Flux\src\optimise\train.jl:98
 [33] top-level scope at REPL[5]:1

GRU layers throw a similar error.

m = GRU(5, 5)
Flux.train!(loss, Flux.params(m), x, ADAM())
GRU stacktrace
ERROR: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] (::Zygote.var"#375#376")(::Nothing) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\array.jl:61
 [3] (::Zygote.var"#2270#back#377"{Zygote.var"#375#376"})(::Nothing) at C:\Users\freez\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [4] copyto_axcheck! at .\abstractarray.jl:946 [inlined]
 [5] (::typeof((copyto_axcheck!)))(::Nothing) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0     
 [6] Array at .\array.jl:562 [inlined]
 [7] (::typeof((Array{Float32,2})))(::Nothing) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0    
 [8] convert at .\array.jl:554 [inlined]
 [9] (::typeof((convert)))(::Nothing) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [10] setproperty! at .\Base.jl:34 [inlined]
 [11] (::typeof((setproperty!)))(::Nothing) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0       
 [12] Recur at C:\Users\freez\.julia\dev\Flux\src\layers\recurrent.jl:34 [inlined]
 [13] (::typeof((λ)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [14] #509 at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\array.jl:187 [inlined]
 [15] #3 at .\generator.jl:36 [inlined]
 [16] iterate at .\generator.jl:47 [inlined]
 [17] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{typeof((λ)),1},Array{FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}},1}}},Base.var"#3#4"{Zygote.var"#509#513"}}) at .\array.jl:686
 [18] map at .\abstractarray.jl:2248 [inlined]
 [19] (::Zygote.var"#508#512"{Array{typeof((λ)),1}})(::Array{FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}},1}) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\array.jl:187
 [20] (::Flux.var"#324#back#177"{Zygote.var"#508#512"{Array{typeof((λ)),1}}})(::Array{FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}},1}) at C:\Users\freez\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [21] loss at .\REPL[3]:1 [inlined]
 [22] (::typeof((loss)))(::Float64) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [23] #150 at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\lib.jl:191 [inlined]
 [24] #1693#back at C:\Users\freez\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59 [inlined]
 [25] #39 at C:\Users\freez\.julia\dev\Flux\src\optimise\train.jl:103 [inlined]
 [26] (::typeof((λ)))(::Float64) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [27] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof((λ))})(::Float64) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface.jl:172
 [28] gradient(::Function, ::Zygote.Params) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface.jl:49        
 [29] macro expansion at C:\Users\freez\.julia\dev\Flux\src\optimise\train.jl:102 [inlined]
 [30] macro expansion at C:\Users\freez\.julia\packages\Juno\n6wyj\src\progress.jl:134 [inlined]
 [31] train!(::Function, ::Zygote.Params, ::Array{Array{Float64,1},1}, ::ADAM; cb::Flux.Optimise.var"#40#46") at C:\Users\freez\.julia\dev\Flux\src\optimise\train.jl:100
 [32] train!(::Function, ::Zygote.Params, ::Array{Array{Float64,1},1}, ::ADAM) at C:\Users\freez\.julia\dev\Flux\src\optimise\train.jl:98
 [33] top-level scope at REPL[9]:1

Finally, LSTM layers give an ArgumentError about the type not having a definite number of fields.

m = LSTM(5, 5)
Flux.train!(loss, Flux.params(m), x, ADAM())
LSTM stacktrace
ERROR: ArgumentError: type does not have a definite number of fields
Stacktrace:
 [1] fieldcount(::Any) at .\reflection.jl:725
 [2] fieldnames(::DataType) at .\reflection.jl:172
 [3] #s77#154 at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\lib.jl:208 [inlined]
 [4] #s77#154(::Any, ::Any) at .\none:0
 [5] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at .\boot.jl:527
 [6] grad_mut(::Type{T} where T) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\lib.jl:249
 [7] grad_mut(::Zygote.Context, ::Type{T} where T) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\lib.jl:256
 [8] (::Zygote.var"#back#159"{:parameters,Zygote.Context,DataType,Core.SimpleVector})(::Tuple{Nothing}) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\lib.jl:220
 [9] (::Zygote.var"#1704#back#160"{Zygote.var"#back#159"{:parameters,Zygote.Context,DataType,Core.SimpleVector}})(::Tuple{Nothing}) at C:\Users\freez\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [10] getproperty at .\Base.jl:28 [inlined]
 [11] (::typeof((getproperty)))(::Tuple{Nothing}) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [12] literal_getproperty at C:\Users\freez\.julia\packages\ZygoteRules\OjfTt\src\ZygoteRules.jl:11 [inlined]
 [13] (::typeof((literal_getproperty)))(::Tuple{Nothing}) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [14] tuple_type_tail at .\essentials.jl:223 [inlined]
 [15] (::typeof((tuple_type_tail)))(::Nothing) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [16] convert at .\essentials.jl:310 [inlined]
 [17] (::typeof((convert)))(::Nothing) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [18] convert at .\essentials.jl:310 [inlined]
 [19] (::typeof((convert)))(::Nothing) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [20] setproperty! at .\Base.jl:34 [inlined]
 [21] (::typeof((setproperty!)))(::Nothing) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [22] Recur at C:\Users\freez\.julia\dev\Flux\src\layers\recurrent.jl:34 [inlined]
 [23] (::typeof((λ)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [24] #509 at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\array.jl:187 [inlined]
 [25] #3 at .\generator.jl:36 [inlined]
 [26] iterate at .\generator.jl:47 [inlined]
 [27] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{typeof((λ)),1},Array{FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}},1}}},Base.var"#3#4"{Zygote.var"#509#513"}}) at .\array.jl:686
 [28] map at .\abstractarray.jl:2248 [inlined]
 [29] (::Zygote.var"#508#512"{Array{typeof((λ)),1}})(::Array{FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}},1}) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\array.jl:187
 [30] (::Flux.var"#324#back#177"{Zygote.var"#508#512"{Array{typeof((λ)),1}}})(::Array{FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}},1}) at C:\Users\freez\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [31] loss at .\REPL[3]:1 [inlined]
 [32] (::typeof((loss)))(::Float64) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [33] #150 at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\lib\lib.jl:191 [inlined]
 [34] #1693#back at C:\Users\freez\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59 [inlined]
 [35] #39 at C:\Users\freez\.julia\dev\Flux\src\optimise\train.jl:103 [inlined]
 [36] (::typeof((λ)))(::Float64) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface2.jl:0
 [37] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof((λ))})(::Float64) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface.jl:172
 [38] gradient(::Function, ::Zygote.Params) at C:\Users\freez\.julia\packages\Zygote\Iz3wR\src\compiler\interface.jl:49        
 [39] macro expansion at C:\Users\freez\.julia\dev\Flux\src\optimise\train.jl:102 [inlined]
 [40] macro expansion at C:\Users\freez\.julia\packages\Juno\n6wyj\src\progress.jl:134 [inlined]
 [41] train!(::Function, ::Zygote.Params, ::Array{Array{Float64,1},1}, ::ADAM; cb::Flux.Optimise.var"#40#46") at C:\Users\freez\.julia\dev\Flux\src\optimise\train.jl:100
 [42] train!(::Function, ::Zygote.Params, ::Array{Array{Float64,1},1}, ::ADAM) at C:\Users\freez\.julia\dev\Flux\src\optimise\train.jl:98
 [43] top-level scope at REPL[7]:1

None of these errors occur in the latest release, v0.11.6, and all the tests in the current master branch also pass for me.

@DrChainsaw
Contributor

Just a hunch from the phone:

Try

x = [rand(5, 1) for i in 1:2]

Flux RNNs want 2D input (size x batch). Can't think of why this would trigger mutation, but I've seen stranger things happen :)

@maetshju
Contributor Author

The same errors occur for me when x is a 2D array. Generally, 1D arrays have worked fine for me in the past with Flux, though perhaps that's a by-product of Flux having sufficiently abstract/generic code.
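For reference, the 2D variant I tried looks like the following (a minimal sketch, reusing the m and loss from the original report):

x2d = [rand(5, 1) for i in 1:2]  # features x batch, as suggested above
Flux.train!(loss, Flux.params(m), x2d, ADAM())  # still hits the same mutation error on master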

@jeremiedb
Contributor

I could reproduce the error on current master, but it's a bit puzzling as the RNN code hasn't changed since 0.11.3 and your test script works fine with the latest releases (0.11.6, 0.11.4...). The Manifest on master seems to lag behind the latest release, with Zygote at 0.6.0 rather than 0.6.2 as in Flux 0.11.6. However, even after running a pkg update from master, I still get the same mutating-array error message and couldn't isolate which package is the cause.

@DhairyaLGandhi Any idea if some changes in Project/Manifest may have introduced a mutating array issue? I can't think of any other cause.

@DhairyaLGandhi
Member

The env shouldn't have caused this.

@jeremiedb
Contributor

It shouldn't, but everything seems to point in that direction. I've just made a local PR comparison from 0.11.6 to the master branch, and strangely the only differences are the Project/Manifest, along with the data-test and the .fluxbot folder.
Quite strangely, despite the Manifest on master showing Zygote 0.6.0, the PR comparison with the 0.11.6 branch shows Zygote as if it were set to 0.5.9 in the Manifest. I can open a dummy PR in Flux to better illustrate this.

It shouldn't be the env, but since the Manifest/Project are about the only parts that have moved since 0.11.6, what else could be the cause?

@jeremiedb
Contributor

jeremiedb commented Jan 30, 2021

Sorry for the noise; I was incorrect in assuming that the RNN modifications had been included in Flux 0.11.6.
@maetshju: if you convert the input to Float32, it should work:

x = [rand(Float32, 5) for i in 1:2]
Flux.train!(loss, Flux.params(m), x, ADAM())

This stems from the fix for the type-inference stability of the RNN layers, as the initial state gets initialized by glorot, which defaults to Float32. Would it be desirable to have the initialization follow the expected input type, or rather to convert the input to match the initialization type (Float32)?
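As a quick check (a minimal sketch, assuming the master-branch Recur carries the glorot-initialized Float32 state described above, and reusing the loss from the original report):

m = RNN(5, 5)
eltype(m.state)  # Float32 on this master commit, so Float64 inputs hit the error in the backward pass
x32 = [rand(Float32, 5) for i in 1:2]  # matching the state's eltype
Flux.train!(loss, Flux.params(m), x32, ADAM())  # trains without the mutation error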

@DhairyaLGandhi
Member

Definitely the latter, although where exactly does the difference show up?

@CarloLucibello
Member

Would it be desirable to have the initialization to follow the expected input, or rather convert input to match initialization type (Float32)?

Neither, I guess. We should just follow the Julia promotion rules for the math operations involved. So ideally we shouldn't be seeing an error, or at least we should get a meaningful one.
Float64 inputs shouldn't be forcefully downcast to Float32 just because the weights are Float32. If people want to work in 32-bit precision, they have to feed 32-bit input. If they want to work in 64-bit instead, they can create each layer with Float64 weights, or more conveniently cast the entire model to 64 bit with model |> f64.
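For instance, the 64-bit route could look like this (a minimal sketch, assuming f64 also converts the stored recurrent state on this branch):

using Flux

m64 = RNN(5, 5) |> f64  # cast the whole model to Float64
x64 = [rand(Float64, 5) for i in 1:2]
loss64(x) = sum(Flux.stack(m64.(x), 2))
Flux.train!(loss64, Flux.params(m64), x64, ADAM())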

@jeremiedb
Contributor

To make Recur work fine with Float64, the type parameter S in the struct can simply be dropped:

mutable struct Recur{T,S}
  cell::T
  state::S
end

However, it results in an output whose type cannot be inferred:

julia> @code_warntype m(X[1])
Variables
  m::Flux.Recur{Flux.RNNCell{typeof(tanh),Array{Float32,2},Array{Float32,1},Array{Float32,2}}}
  x::Array{Float32,2}
  @_3::Int64
  y::Any

Body::Any
1 ─       Core.NewvarNode(:(@_3))
│         Core.NewvarNode(:(y))
│   %3  = Flux.eltype(x)::Core.Compiler.Const(Float32, false)
│   %4  = Base.getproperty(m, :state)::Any
│   %5  = Flux.eltype(%4)::Any
│   %6  = (%3 == %5)::Any
└──       goto #3 if not %6
2 ─       goto #4
3 ─ %9  = Base.AssertionError("Recur input elements must have the same type as its state.")::AssertionError
└──       Base.throw(%9)
4 ─ %11 = Base.getproperty(m, :cell)::Flux.RNNCell{typeof(tanh),Array{Float32,2},Array{Float32,1},Array{Float32,2}}
│   %12 = Base.getproperty(m, :state)::Any
│   %13 = (%11)(%12, x)::Tuple{Any,Any}
│   %14 = Base.indexed_iterate(%13, 1)::Core.Compiler.PartialStruct(Tuple{Any,Int64}, Any[Any, Core.Compiler.Const(2, false)])
│   %15 = Core.getfield(%14, 1)::Any
│         Base.setproperty!(m, :state, %15)
│         (@_3 = Core.getfield(%14, 2))
│   %18 = Base.indexed_iterate(%13, 2, @_3::Core.Compiler.Const(2, false))::Core.Compiler.PartialStruct(Tuple{Any,Int64}, Any[Any, Core.Compiler.Const(3, false)])
│         (y = Core.getfield(%18, 1))
└──       return y

This doesn't affect performance on a simple Chain model, but an issue was opened about a significant impact when downstream tasks were performed: #1092

If Recur's state type parameter is kept, adding an assertion on the input type would work, but it looks to add ~2-3% overhead:

function (m::Recur)(x::AbstractArray)
  @assert eltype(x) == eltype(m.state[1]) "Recur input elements must have the same type as its state."
  m.state, y = m.cell(m.state, x)
  return y
end
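
A usage sketch of this assertion variant (assuming it is applied on top of the Float32-initialized RNN above):

m = RNN(5, 5)  # Float32 weights and Float32 initial state
m(rand(Float32, 5, 1))  # runs as before
m(rand(Float64, 5, 1))  # would now fail upfront with the AssertionError, rather than only in the backward pass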

Which would be the preferred option: dropping type-inference stability vs. adding a type assertion? Or a third option?

@DhairyaLGandhi
Member

What if we could error appropriately in the forward pass? Much of the error would then be better understood. It would just require typing things a bit more strongly.

@ToucheSir
Member

Just to clarify, forward pass here means the callable cells? Does stronger typing there (or the assert in Recur) limit mixed precision at all?

@DhairyaLGandhi
Member

So the assert should actually go away; its use is also mentioned in the docs.

But yes, the stronger typing can in fact hurt mixed precision, though that would require some kernel forwarding and accumulation logic anyway, which is somewhat separate and can be handled down the road.

@jeremiedb
Contributor

jeremiedb commented Jan 30, 2021

Would this be in line with what you refer to as the forward pass:

mutable struct Recur{T,S}
  cell::T
  state::Union{S, Tuple{S,S}}
end

function (m::Recur{T,S})(x::S) where {T,S}
  m.state, y = m.cell(m.state, x)
  return y
end

It almost seems to be a workable approach, but it fails when the input x is a single vector rather than a matrix.

@DhairyaLGandhi
Member

Let's definitely avoid unions. Those increase the surface of issues we would face and muddy the intent with the assumptions we make.

@jeremiedb
Contributor

With the hidden state of RNN and GRU being different from that of LSTM, plus the need to handle both Vector and Matrix input, I'm unfortunately short of ideas on how to articulate the type parametrization. Suggestions welcome!

@jeremiedb
Contributor

jeremiedb commented Jan 31, 2021

Here's another proposal. It involves having the hidden state defined as a tuple for all types of recurrent cells, as opposed to a Matrix for RNN/GRU and a Tuple of Matrix for LSTM. It performs a reshape to force a 2D input, similar in spirit to what is done with the Dense layer.

mutable struct Recur{T,S}
  cell::T
  state::Tuple{Vararg{S}}
end

function (m::Recur)(x)
  m.state, y = m.cell(m.state, reshape(x, size(x,1), :))
  return y
end

Then, the consistency check on the input is performed through the forward definitions of the cells:

function (m::RNNCell{F,A,V,S})((h,), x::S) where {F,A,V,S}
  σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
  h = σ.(Wi*x .+ Wh*h .+ b)
  return (h,), h
end

Requiring the input x to share its type with the hidden state triggers a more informative error if the hidden state is Float32 while x is Float64:

julia> m(X[1])
ERROR: MethodError: no method matching (::Flux.RNNCell{typeof(tanh),Array{Float32,2},Array{Float32,1},Array{Float32,2}})(::Tuple{Array{Float32,2}}, ::Array{Float64,2})
Closest candidates are:
  Any(::Any, ::S) where {F, A, V, S} at c:\github\Flux.jl\src\layers\recurrent.jl:85

Also, mixing types would work under this scheme. For example, if one wishes to work with Float64 input and Float32 parameters, it will work fine as long as the initial state is also initialized as Float64.
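
To illustrate the mixed-precision claim with plain arrays standing in for the cell's parameters and state (a minimal sketch, not the proposed cell code itself):

Wi, Wh, b = rand(Float32, 5, 5), rand(Float32, 5, 5), rand(Float32, 5)  # Float32 parameters
h = zeros(Float64, 5, 1)  # Float64 state, same eltype as the input
x = rand(Float64, 5, 1)
h = tanh.(Wi*x .+ Wh*h .+ b)  # Julia's promotion rules yield a Float64 result, no error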

@jeremiedb
Contributor

jeremiedb commented Feb 6, 2021

@DhairyaLGandhi Any opinion on the previous proposal? It would provide both type-inference stability and a meaningful type-mismatch error if the type of the state differs from that of the input. I could move forward with a PR.

@DhairyaLGandhi
Member

Having the hidden state as a tuple might be fine, but we typically don't restrict the type of the input, letting Julia's promotion mechanism take care of it. Also, that might trip up mixed-precision work in the future.

I'm wondering if we could deal with the reshape by dispatching on the kind of cell we are dealing with.

@jeremiedb
Contributor

Maybe there's a misunderstanding, as the above proposal would support mixed precision. The constraint is only that the hidden state be of the same type as the input, which seems reasonable from my perspective since both are essentially inputs. The parameters could well be Float32 while the input/state are Float64.
Why would you want to dispatch the reshape based on the kind of cell? Since all three types of cells share the same reshaping need, I don't see the benefit of adding moving parts.

@CarloLucibello
Member

It involves having the hidden state defined as a tuple for all types of recurrent cells

Although this change makes sense and brings more consistency, I'd like to avoid another breaking change if we really don't have to.

Wouldn't something like this be enough to error out on mismatched types?

function (m::RNNCell{F,A,V,<:AbstractArray{T}})(h, x::AbstractArray{T}) where {F,A,V,T}
  σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
  h = σ.(Wi*x .+ Wh*h .+ b)
  return h, h
end

Something similar could be done for GRU and LSTM, although it would be more complicated.

As a side note, the output of Dense is a vector when the input is a vector:

julia> Dense(3,2)(rand(3))
2-element Array{Float64,1}:
 1.7997510852225493
 0.6333258983290979

julia> Dense(3,2)(rand(3,1))
2×1 Array{Float64,2}:
 0.3791885941573536
 0.5453548860286661

so probably we should have the same behavior for recurrent layers (which is already the case I guess)

@jeremiedb
Contributor

Wouldn't something like this be enough to error out on mismatched types?

This would effectively catch the mismatched types, but it wouldn't solve the type-inference issue.
For the Recur struct to have an output whose type can be inferred, AFAIK Recur's state must have a type parameter set:

mutable struct Recur{T,S}
  cell::T
  state::Tuple{Vararg{S}}
end

This requirement, however, comes at the cost that the type of the state cannot change after initialization. This is why it got initialized as a Matrix. That is compatible with both Vector and Matrix inputs; however, since the state is a Matrix, the output is also a Matrix. To output a Vector, this can be fixed with the following, which results in a stable output type for Matrix input, but a Union of Vector and Matrix if the input is a Vector, which isn't a concern to me:

function (m::Recur)(x)
  m.state, y = m.cell(m.state, x)
  sz = size(x)
  return reshape(y, :, sz[2:end]...)
end

To sum up, if it is desirable to have the type of the Recur output inferable (to me it is), then I see two options:

  1. Recur's state is always a tuple, which brings a change compared to v0.11. However, if I'm not mistaken, the latest changes were never part of a release, only merged into master, so there would not be another breaking change, at least from a release perspective.
  2. Recur's state is defined as a union of Matrix and Tuple, which is less elegant, but keeps the current status where RNN and GRU have a Matrix as state, and LSTM a Tuple of Matrix.

And finally there's option 3: dropping Recur's type inference. However, I don't see a reason to go that route; at least from my usage perspective, I would only end up rewriting the RNN structs for my own use.

@jeremiedb
Contributor

Sorry to bump: is there any of the above three options, or another one, that you'd like to take?

@DhairyaLGandhi
Member

I wasn't sure what was being proposed earlier. I mean that we effectively dispatch based on the type of cell, and let that kernel take care of reshaping, not dispatching on reshape itself. Sorry about the confusion.

@DhairyaLGandhi
Member

It's a bit odd for the state to be a tuple, I admit.

@CarloLucibello
Member

Bumps are very welcome!
Moving within the same breakage surface as the last v0.11 release is fine, so option 1 seems OK (provided we can put deprecations in place).

Still, it is not clear to me what's wrong with what I suggested above, and what you mean by the "issue of type inference". My proposal is:

mutable struct Recur{T,S}
  cell::T
  state::S
end

function (m::RNNCell{F,A,V,<:AbstractArray{T}})(h, x::AbstractArray{T}) where {F,A,V,T}
...
end

# this one is just a lazy guess
function (m::RNNCell{F,A,V,NTuple{2,<:AbstractArray{T}}})(h, x::AbstractArray{T}) where {F,A,V,T}
...
end

I'm not particularly fond of it but I think it should work.

@CarloLucibello
Member

And yes, having a tuple state when it's not needed is not ideal. It's also true, though, that 2 out of 3 cells have a tuple state, so we may have the remaining one be a tuple as well for consistency. Not sure what the best choice is.

@jeremiedb
Contributor

@CarloLucibello Yes, your proposal looks like a good solution. My initial doubt concerned the parametrization of the Recur struct, but keeping it as you stated turns out to work fine.

For the hidden state, the proposal to always go with tuples was initially motivated by avoiding the union of Matrix and Tuple of Matrix. But given the above comment on Recur, it is no longer needed. Therefore, RNN and GRU would continue to keep a Matrix in Recur's state and LSTM a tuple, as before.

As for the requirement that a Vector input also return a Vector, following @DhairyaLGandhi's comment to perform the operation within the cell, it would take this form:

function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat{T}) where {F,A,V,T}
  σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
  h = σ.(Wi*x .+ Wh*h .+ b)
  sz = size(x)
  return h, reshape(h, :, sz[2:end]...)
end

Altogether, I think these will resolve the remaining friction points. If that is fine with you, I'll go ahead with a PR.
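
For clarity, the expected behaviour after such a PR would be roughly (a hedged sketch, not the merged implementation):

m = RNN(5, 5)
m(rand(Float32, 5))  # Vector input -> Vector output
m(rand(Float32, 5, 3))  # Matrix input -> Matrix output, one column per batch element
m(rand(Float64, 5))  # raises a MethodError pointing at the eltype mismatch with the Float32 state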

bors bot added a commit that referenced this issue Mar 4, 2021
1521: Fixes to Recurrent models for informative type mismatch error & output Vector for Vector input r=CarloLucibello a=jeremiedb

Minor fix to Recurrent layers to return a `Vector` for `Vector` input, return an informative error when the eltype of the input doesn't match the eltype of the state, and fix some typos in the associated docs.
As discussed in #1483. 


Co-authored-by: jeremie.db <[email protected]>
Co-authored-by: jeremiedb <[email protected]>
@darsnack
Member

darsnack commented Mar 4, 2021

I believe this was closed with #1521 but GH didn't automatically link the issue with the PR. @jeremiedb can reopen if I missed something.
