Recur struct's fields are not type annotated, which is causing run-time dispatch and significant slowdowns #1092
Comments
Do you have a minimal working example for this? |
Here's one:
using Flux
f(m, xs) = sum(m.(xs))
m = LSTM(3,5)
xs = [rand(Float32, 3,7) for _ ∈ 1:11]
and you can see that the type inference fails:
julia> @code_warntype f(m, xs)
Variables
#self#::Core.Compiler.Const(f, false)
m::Flux.Recur{Flux.LSTMCell{Array{Float32,2},Array{Float32,1}}}
xs::Array{Array{Float32,2},1}
Body::Any
1 ─ %1 = Base.broadcasted(m, xs)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,Flux.Recur{Flux.LSTMCell{Array{Float32,2},Array{Float32,1}}},Tuple{Array{Array{Float32,2},1}}}
│ %2 = Base.materialize(%1)::Any
│ %3 = Main.sum(%2)::Any
└── return %3 |
|
One solution is to add type parameters to the |
Won't work, for LSTM |
I don't see how them being |
How will you set the type of |
I believe you meant
struct Recur{T,I,S}
    cell::T
    init::I
    state::S
end
so for |
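To make the effect of the type parameters concrete, here is a minimal sketch that does not depend on Flux. `UntypedRecur`, `TypedRecur`, `ToyCell` and `step!` are hypothetical stand-ins invented for illustration; they are not Flux's definitions and only mimic the shape of the problem (a wrapper whose fields are either untyped or parametric). `Base.return_types` is a reflection helper used here simply to ask the compiler what it infers.

```julia
# Hypothetical stand-ins for illustration only; NOT Flux's actual types.

# Untyped fields: every field is implicitly `::Any`.
mutable struct UntypedRecur
    cell
    state
end

# Parametric fields: the concrete field types become part of the wrapper type.
mutable struct TypedRecur{T,S}
    cell::T
    state::S
end

# A toy cell: maps (state, input) to (new_state, output).
struct ToyCell
    W::Matrix{Float32}
end

function (c::ToyCell)(h, x)
    hnew = tanh.(c.W * x .+ h)
    return hnew, hnew
end

# The same stepping logic is used for both wrappers.
function step!(m, x)
    m.state, y = m.cell(m.state, x)
    return y
end

cell = ToyCell(rand(Float32, 5, 3))
h0 = zeros(Float32, 5, 7)
x = rand(Float32, 3, 7)

untyped = UntypedRecur(cell, h0)
typed = TypedRecur(cell, h0)

# Ask the compiler which return type it infers for each wrapper.
rt_untyped = only(Base.return_types(step!, (UntypedRecur, Matrix{Float32})))
rt_typed = only(Base.return_types(step!, (typeof(typed), Matrix{Float32})))

println("untyped wrapper infers: ", rt_untyped)
println("typed wrapper infers:   ", rt_typed)
```

With the untyped wrapper the compiler cannot know what `m.cell` or `m.state` hold, so the result should be inferred as `Any`, while the parametric wrapper should give a concrete matrix type. Note that in this sketch the state type `S` is fixed at construction, which is the same constraint discussed in this thread: the initial state has to already have the type that the cell returns.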
Yeah, sorry, meant |
Yeah, I remember testing it and hitting this error
|
As I've mentioned previously, you would also need to change the initialization of |
I'm getting around a 10% performance improvement on the toy examples. You should definitely write a PR with this proposed change. The only caveat seems to be the matrix output for vector input. Not sure about the GPU side, though. |
Seems like this was addressed in #1367, but GH wasn't smart enough to auto-close. |
In my use case, the state of the LSTM layer is inferred as Any because of this, which propagates down the network, leading to run-time dispatch in ALL layers and a significant slowdown.