Skip to content

Fixes to Recurrent models for informative type mismatch error & output Vector for Vector input #1521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 4, 2021

Conversation

jeremiedb
Copy link
Contributor

Minor fix to Recurrent to return Vector with Vector input, returns an indicative error relative to type incompatibility where eltype of input doesn't match with eltype of state, as well as some typos in associated docs.
As discussed in #1483.

@jeremiedb jeremiedb changed the title Rnn test2 Fixes to Recurrent models for informative type mismatch error & output Vector for Vector input Feb 28, 2021
@jeremiedb
Copy link
Contributor Author

jeremiedb commented Feb 28, 2021

Tests error on OneHotInput. I missed that case, which brakes the design of the informative error on inconsistent eltype between input and state. Would dropping the parametric type in Cells' arguments be acceptable? That would revert to current case where the input/state eltype mismatch results in a more obscure mutation not supported error.

@darsnack
Copy link
Member

darsnack commented Feb 28, 2021

Why not just let x::S where S<:Union{<:AbstractArray{S}, OneHotArray}? Probably in all cases the operations involving x (e.g. W*x) will promote to the eltype of the weights. Assuming the weights and initial state are aligned, the type inference should be stable?

@jeremiedb
Copy link
Contributor Author

Why not just let x::S where S<:Union{<:AbstractArray{S}, OneHotArray}?

I think that's a good suggestion :) My understanding is that Union are to be avoided if possible, but in this case, it seems the best course of action. Tests are back to green, thanks!

@CarloLucibello
Copy link
Member

looks good! it only needs a @test_throw check for the type mismatch and tests for the vector input.

@DhairyaLGandhi
Copy link
Member

We are also considering using cudnn instead, it's unclear whether we have actually seen a performance advantage.

@CarloLucibello
Copy link
Member

CarloLucibello commented Mar 2, 2021

whether we do it through cudnn or not, we can add support for sequences batched in 3d arrays at any point in addition to the current interface. I think @jeremiedb already mentioned that a few months ago

@ToucheSir
Copy link
Member

Given how other frameworks have to bend over backwards and add a bunch of nasty conditionals/documentation caveats to integrate with cuDNN RNNs, I'd say we should only consider adding it if we decide to support 3d sequences.

Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want to avoid unions for struct fields, but it should be fine for use in dispatch like this.

@jeremiedb
Copy link
Contributor Author

@CarloLucibello Tests on type mismatch and vector vs matrix input have been added.
@DhairyaLGandhi From the previously held discussion around CUDNN's RNN, point was following what @ToucheSir mentioned, that is that CUDNN's would makes sense in the context of 3d input. I still don't see benefit of introducing a machinery designed for 3D input whose optimizations get lost in the process of adapting it back into the Flux 2D design. I still see adding such feature as desirable, but of a different nature than the current Flux Recurrent approach.

@CarloLucibello
Copy link
Member

bors r+

@bors
Copy link
Contributor

bors bot commented Mar 4, 2021

Build succeeded:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants