
Recur struct's fields are not type annotated, which is causing run-time dispatch and significant slowdowns #1092

Closed
AzamatB opened this issue Mar 19, 2020 · 13 comments


@AzamatB
Contributor

AzamatB commented Mar 19, 2020

In my use case, the state of the LSTM layer is inferred as Any because of this, which propagates down the network, leading to run-time dispatch in ALL layers and a significant slowdown.

@AzamatB changed the title from "Recur's fields are not type annotated, which is causing run-time dispatch and significant slowdowns" to "Recur struct's fields are not type annotated, which is causing run-time dispatch and significant slowdowns" Mar 19, 2020
@bhvieira
Contributor

Do you have a minimal working example for this?

@AzamatB
Contributor Author

AzamatB commented Apr 17, 2020

Here's one:

using Flux

f(m, xs) = sum(m.(xs))

m = LSTM(3,5)
xs = [rand(Float32, 3, 7) for _ ∈ 1:11]

and you can see that the type inference fails:

julia> @code_warntype f(m, xs)
Variables
  #self#::Core.Compiler.Const(f, false)
  m::Flux.Recur{Flux.LSTMCell{Array{Float32,2},Array{Float32,1}}}
  xs::Array{Array{Float32,2},1}

Body::Any
1 ─ %1 = Base.broadcasted(m, xs)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Flux.Recur{Flux.LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}}
│   %2 = Base.materialize(%1)::Any
│   %3 = Main.sum(%2)::Any
└──      return %3

@bhvieira
Contributor

Recur is used for all recurrent layers, but in LSTM init and state are both Tuples; in RNN, for example, they're not. In LSTM the state starts out holding vectors but may come to hold matrices if your data has a batch dimension. We'd need to work that out; we can't type annotate the fields without some design changes (which, looking at the code right now, I'm sure are possible).
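A quick way to see this in the REPL (a sketch against the Flux version in this thread; these are the same two tuple types that show up in the convert frames of the stacktrace further down):

using Flux

m = LSTM(3, 5)
typeof(m.state)         # Tuple{Array{Float32,1},Array{Float32,1}}

m(rand(Float32, 3, 7))  # feed one batched input
typeof(m.state)         # now Tuple{Array{Float32,2},Array{Float32,2}}

No single concrete annotation on state covers both of those types.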

@AzamatB
Contributor Author

AzamatB commented Apr 17, 2020

One solution is to add type parameters for Recur's init and state fields and initialize them as n×1 matrices instead of vectors, but this has the possible downside of returning a matrix output for a vector input.

@bhvieira
Contributor

That won't work; for LSTM, init and state are Tuples.

@AzamatB
Contributor Author

AzamatB commented Apr 17, 2020

I don't see how their being Tuples is a problem here.

@bhvieira
Contributor

How will you set the type of init and cell then?

@AzamatB
Contributor Author

AzamatB commented Apr 17, 2020

> type of init and cell then?

I believe you meant init and state instead? For that, you would do:

struct Recur{T,I,S}
    cell::T
    init::I
    state::S
end

so for RNN and GRU the type parameters I and S would be instantiated with a matrix type, and for LSTM with a tuple of matrices instead. This is essentially what I was proposing at the beginning of this issue.
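Concretely, the parameters would be instantiated along these lines (my own sketch on top of the struct above; zeros is just a stand-in for the real initializer, with n×1 matrices per the proposal):

# RNN/GRU: I and S become a concrete matrix type
h = zeros(Float32, 5, 1)                 # n×1 matrix instead of a length-n vector
r = Recur(Flux.RNNCell(3, 5), h, h)      # I and S are Matrix{Float32}

# LSTM: I and S become a tuple of matrices
hc = (zeros(Float32, 5, 1), zeros(Float32, 5, 1))
l = Recur(Flux.LSTMCell(3, 5), hc, hc)   # I and S are Tuple{Matrix{Float32},Matrix{Float32}}

Either way the field types are concrete, which is what inference needs.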

@bhvieira
Contributor

bhvieira commented Apr 17, 2020

Yeah, sorry, I meant state. I think I tested that earlier and will try it again, but the Tuples won't work quite so directly.

@bhvieira
Contributor

bhvieira commented Apr 17, 2020

Yeah, I remember testing it and hitting this error:

using Flux
using Flux: @functor

mutable struct RecurI{T,I,S}
    cell::T
    init::I
    state::S
end

RecurI(m, h = Flux.hidden(m)) = RecurI(m, h, h)

function (m::RecurI)(xs...)
    h, y = m.cell(m.state, xs...)
    m.state = h  # this line will throw the error below
    return y
end

LSTMI(a...; ka...) = RecurI(Flux.LSTMCell(a...; ka...))

f(m, xs) = sum(m.(xs))

m = LSTMI(3,5)
xs = [rand(Float32, 3,7) for _ ∈ 1:11]

#julia> f(m, xs)
#ERROR: MethodError: no method matching Array{Float32,1}(::Array{Float32,2})
#Closest candidates are:
#  Array{Float32,1}(::AbstractArray{S,N}) where {T, N, S} at array.jl:497
#  Array{Float32,1}() where T at boot.jl:413
#  Array{Float32,1}(::UndefInitializer, ::Int64) where T at boot.jl:394
#  ...
#Stacktrace:
# [1] convert(::Type{Array{Float32,1}}, ::Array{Float32,2}) at .\array.jl:489
# [2] convert(::Type{Tuple{Array{Float32,1},Array{Float32,1}}}, ::Tuple{Array{Float32,2},Array{Float32,2}}) at .\essentials.jl:275
# [3] setproperty!(::RecurI{Flux.LSTMCell{Array{Float32,2},Array{Float32,1}},Tuple{Array{Float32,1},Array{Float32,1}},Tuple{Array{Float32,1},Array{Float32,1}}}, ::Symbol, ::Tuple{Array{Float32,2},Array{Float32,2}}) at .\sysimg.jl:19
# [4] (::RecurI{Flux.LSTMCell{Array{Float32,2},Array{Float32,1}},Tuple{Array{Float32,1},Array{Float32,1}},Tuple{Array{Float32,1},Array{Float32,1}}})(::Array{Float32,2}) at .\REPL[5]:3
# [5] _broadcast_getindex_evalf at .\broadcast.jl:582 [inlined]
# [6] _broadcast_getindex at .\broadcast.jl:555 [inlined]
# [7] getindex at .\broadcast.jl:515 [inlined]
# [8] copy(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},RecurI{Flux.LSTMCell{Array{Float32,2},Array{Float32,1}},Tuple{Array{Float32,1},Array{Float32,1}},Tuple{Array{Float32,1},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}}) at .\broadcast.jl:790
# [9] materialize at .\broadcast.jl:756 [inlined]
# [10] f(::RecurI{Flux.LSTMCell{Array{Float32,2},Array{Float32,1}},Tuple{Array{Float32,1},Array{Float32,1}},Tuple{Array{Float32,1},Array{Float32,1}}}, ::Array{Array{Float32,2},1}) at .\REPL[7]:1
# [11] top-level scope at none:0
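
The root cause is visible in frames [1]-[3]: setproperty! on a mutable struct converts the assigned value to the declared field type, and no convert exists from an n×batch matrix to a vector. A minimal reproduction outside Flux (my own illustration):

mutable struct Holder
    x::Vector{Float32}
end

h = Holder(zeros(Float32, 5))
h.x = zeros(Float32, 5, 7)  # ERROR: MethodError: no method matching Array{Float32,1}(::Array{Float32,2})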

@AzamatB
Contributor Author

AzamatB commented Apr 17, 2020

As I've mentioned previously, you would also need to change the initialization of RNNCell's and GRUCell's hidden state h from a vector to an n×1 matrix for this to work. In the case of LSTMCell, you would initialize both h and c as n×1 matrices instead of vectors.
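
A sketch of that fix on top of the RecurI definition above (to_matrix and the *IM names are hypothetical helpers, not Flux API): with the initial state stored as n×1 matrices, the batched states produced later have the same concrete type, so m.state = h no longer triggers the failing convert.

to_matrix(h::AbstractVector) = reshape(h, :, 1)  # length-n vector -> n×1 matrix
to_matrix(h::Tuple) = map(to_matrix, h)          # covers LSTM's (h, c) tuple

RecurIM(m, h = to_matrix(Flux.hidden(m))) = RecurI(m, h, h)
LSTMIM(a...; ka...) = RecurIM(Flux.LSTMCell(a...; ka...))

m2 = LSTMIM(3, 5)
f(m2, xs)  # runs; a plain vector input would now come back as an n×1 matrix, the caveat noted above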

@bhvieira
Contributor

I'm getting around a 10% performance improvement on the toy examples. You should definitely write a PR with this proposed change. The only caveat seems to be the matrix output for vector input. Not sure about GPU stuff, though.
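
For anyone reproducing the comparison, something along these lines (BenchmarkTools is my addition; exact numbers will vary by machine):

using BenchmarkTools

m1 = LSTM(3, 5)     # stock LSTM with untyped Recur fields
@btime f($m1, $xs)
@btime f($m2, $xs)  # typed RecurI with n×1 matrix state, from the sketch above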

@ToucheSir
Member

Seems like this was addressed in #1367, but GitHub wasn't smart enough to auto-close.
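
For reference, the shape of that fix is a Recur with type parameters on its fields, roughly (a sketch of the idea only, not necessarily the code that was merged; see #1367 for the details):

mutable struct Recur{T,S}
    cell::T
    state::S
end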
