Recurrent network interface updates/design #1678
I added the ConvLSTM issue in the top post so we can track that here as well.

I've started working on the Folded interface. It actually should be pretty easy to add without disrupting the current API, but I haven't thought through how CuDNN fits yet. I'll make a PR after we get #1675 merged, so we can iterate.

I separated out CuDNN support from 3d array support. They both influence each other, but I think getting the FoldedRNN's API right needs a bit of iteration first.

I added Bidirectional RNNs as per the conversation in #1686.
1686: Adding support for folding RNNs over 3d arrays r=DhairyaLGandhi a=mkschleg

From #1678, adding a Recur-like interface for a folded operation with support for 3-dimensional arrays. This is how many users familiar with PyTorch and TensorFlow expect RNNs to work, and there seems to be some desire for this feature as per the discussion in #1671 and `@jeremiedb`. This will also make a push towards implementing support for the CuDNN versions of RNNs/GRUs/LSTMs more streamlined, as this is the data layout that API expects. I did a barebones implementation to add support so we can start iterating on the API. There are several questions that I have lingering with this interface:

- ~~Should we support different modes where we return all or only the last hidden state? Is there a better way to do the concat of the hidden states?~~
- What kind of tests should we have? Just follow what we currently do for RNNs/LSTMs/GRUs?
- ~~For the CPU version, does it make sense not to specialize on the different RNN types? We might be able to take more advantage of BLAS if we specialized on, say, `Folded{GRU}`.~~
- ~~Do we want to force the temporal dimension to be the 2nd?~~
- ~~Do we want this to be stateful? (i.e. allow the user to change what the starting hidden state is, rather than state0).~~

### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] API changes require approval from a committer (different from the author, if applicable)

Co-authored-by: Matthew Schlegel <[email protected]>
Co-authored-by: Dhairya Gandhi <[email protected]>
Now that #1686 is merged, I think we should disambiguate the interface: a 3d tensor input could be a single time step of batched grey images ([width, height, batch_size]) or multiple time steps of batched 1d inputs ([num_features, batch_size, seq_length]). I think that starting from the next breaking release we should always assume that the input's last dimensions are batch and time, and start introducing a deprecation path.
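To make the ambiguity concrete (sizes picked arbitrarily, just for illustration):

```julia
# Two 3d arrays of identical size that mean completely different things:
x_images = rand(Float32, 28, 28, 16)  # one time step of 16 batched 28×28 grey images
x_series = rand(Float32, 28, 28, 16)  # 28 features, batch of 28, 16 time steps

size(x_images) == size(x_series)      # true – nothing in the array itself says which is which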
It's not something that needs deprecation yet.

If we decide to go with this, at the very least we should immediately update the docs.

👍 for a docs update. I don't think we need to deprecate anything though, because it's not at all clear that …
@ToucheSir @DhairyaLGandhi @CarloLucibello and others: sorry for the recent radio silence, but this semester was absolutely chaotic/brutal for me. I've started an RNN repo where we can play with the interface. I think the first thing we need to do is figure out the weirdness in the API that was introduced by #1686. As discussed, it would not be possible to have convolutional RNNs with this interface. I'm thinking the Recur struct needs to know some amount of information about its cell type for this to work; it is also possible that we could use traits to accomplish this. This would also be a good place to test CuDNN paths. I'm curious to see what people think so we can get the RNN API in a better place. I have some ideas and will hopefully get to working on instantiating them over the next few days. The current repo is just a copy of the current interface, and the tests are also the same (but just for the RNN layers).
Been playing around with the designs in FluxRNNs.jl. The one I've settled on is using traits for the input types and then EllipsisNotation to deal with looping over the final index. For each type of input, all we need to do is implement `forward!(::InputType, m, x)` and it should dispatch appropriately. This does bring up some odd hard edges, though. For example, say you use a 2D array as input and assume that the dims are `in × timesteps`: the current implementation assumes timesteps is actually batch and does not roll through. This is solved by documenting that the array should be reshaped as `in × 1 × timesteps`, which is not unreasonable.
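Roughly the kind of thing I mean, as a minimal sketch (the trait names and the `forward!`/`apply_rnn` signatures here are placeholders I made up, not a settled API):

```julia
using EllipsisNotation  # provides the `..` index filler

# Hypothetical traits for the kinds of input a recurrent layer might see.
abstract type InputType end
struct SingleStep <: InputType end   # (features,) or (features, batch)
struct Sequence   <: InputType end   # (features, batch, timesteps)

input_type(::AbstractVecOrMat)       = SingleStep()
input_type(::AbstractArray{<:Any,3}) = Sequence()

# One forward! method per input kind; `m` is assumed to be a stateful
# Recur-like callable that advances its hidden state on each call.
forward!(::SingleStep, m, x) = m(x)

function forward!(::Sequence, m, x)
    # `x[.., t]` expands to `x[:, :, t]` for a 3d array.
    ys = [m(x[.., t]) for t in axes(x, ndims(x))]
    return cat(ys...; dims = 3)      # put time back on the last dimension
end

# The layer's call site then just dispatches on the trait:
apply_rnn(m, x) = forward!(input_type(x), m, x)
```

The nice part would be that adding a new input kind (e.g. image sequences for a ConvLSTM) only means adding a trait and a `forward!` method.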
My next objective is to actually implement a ConvLSTM to see how it works with the interface.
One thing I didn't think of for #1686, but which could help quite a bit with performance, is using `eachslice`.
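For comparison, a rough sketch of the two ways of walking the time dimension (illustrative only, not the actual #1686 code):

```julia
# Iterating over the time dimension of a (features, batch, timesteps) array.

# View-based, roughly what the folded code does now:
fold_views(m, x) = [m(view(x, :, :, t)) for t in axes(x, 3)]

# eachslice-based: each element is already a non-allocating view with the
# time dimension dropped, so the cell sees a (features, batch) matrix.
fold_eachslice(m, x) = [m(xt) for xt in eachslice(x, dims = 3)]
```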
Oh! Right, this was what I was looking for at some point. I'll do some comparisons and see. I have no qualms with setting the assumption that the last dimension is time, as this makes the most sense to me. When we consider the types of inputs we may need for a recurrent architecture, each cell will have three options:

…

These three feel appropriate, and I don't think it would be useful to have a utility which can take something that is … The one we might be able to remove would be the first, forcing people to shape as … I still need to read #1790 to figure out how Bidirectional fits within all this, but from some brief skimming it looks like it can be handled outside of Recur, which is likely for the best.
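For reference, the reshape mentioned a couple of comments up (turning an unbatched `in × timesteps` matrix into the canonical 3d layout) is just:

```julia
# An unbatched sequence stored as (features, timesteps)...
x2d = rand(Float32, 8, 100)

# ...reshaped to the canonical (features, batch, timesteps) layout with batch = 1,
# so the last dimension is unambiguously time rather than batch.
x3d = reshape(x2d, size(x2d, 1), 1, size(x2d, 2))
```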
Ok. From the above, `eachslice` is slightly better than (but likely about the same as) our view approach, so I would say it's a good option; I've switched to it in the repo I'm working in. In the FluxRNNs repo I've been working on some tools to measure and compare the performance of RNNs and store/plot the data, so I can make changes without having to go back and manually check for regressions. It is inspired by what Flux does already, but adapted to work with RNNs; we can eventually use some of this in Flux. Instead of going towards ConvLSTMs (because I'm still looking for a canonical implementation that I can base things off of), I started playing with cuDNN paths in this branch. There are some challenges in getting the new cuDNN interface from CUDA.jl working with our RNNs that I'm still working through. Right now, cuDNN expects the weights to be in a block matrix of size (out, in + out + bias), which we don't do. I implemented a CuRNN which is literally just a simple RNN, except we store the weights in a single block and then use views to access its parts. I see a few options where some opinions would be good:
```julia
struct CuRNNCell{F,A,V,M,S}
  σ::F       # activation function
  Wi::A      # view into W: input weights
  Wh::A      # view into W: recurrent weights
  b::V       # view into W: bias
  W::M       # the full blocked weight matrix [Wi | Wh | b]
  state0::S  # initial hidden state
end
```
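For concreteness, constructing such a cell could look roughly like this. This is only a sketch under my current assumptions: the `(out, in + out + 1)` block with the bias as the last column, the `init` default, and the constructor signature are all placeholders, not a settled design:

```julia
using Flux

# Hypothetical constructor: one contiguous block W = [Wi | Wh | b] plus views into it.
function CuRNNCell((in, out)::Pair{<:Integer,<:Integer}, σ = tanh;
                   init = Flux.glorot_uniform)
  W  = init(out, in + out + 1)            # blocked weights, bias as the last column
  Wi = view(W, :, 1:in)                   # (out, in) input weights
  Wh = view(W, :, in+1:in+out)            # (out, out) recurrent weights
  b  = view(W, :, in + out + 1)           # (out,) bias vector
  state0 = zeros(Float32, out, 1)
  return CuRNNCell(σ, Wi, Wh, b, W, state0)
end

# The forward pass is an ordinary RNN cell that only ever reads the views,
# so the contiguous W can be handed to cuDNN as-is.
function (m::CuRNNCell)(h, x)
  h′ = m.σ.(m.Wi * x .+ m.Wh * h .+ m.b)
  return h′, h′
end
```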
I was looking to see how this was handled in the past (i.e. v0.10.0), and what was done was to use a function … I think this is a bit bigger of an issue than the conv issue, and we should try to resolve it relatively soon, before making too many changes to RNNs.
Do weights need to be blocked for stacked (multi-layer) cuDNN RNNs as well? This certainly makes the design more challenging, but I like your view idea. Since basically the entire forward pass would require a custom rrule to make use of cuDNN functionality, there's also more room to optimize inside that rrule.
Yes. It is really handy to take a look at the docs for … I think the blocking would be interesting. I think we should avoid having separate CPU/GPU structs and stick with a shared structure; it seems like a bad idea to change the structure in a hidden way when moving to the GPU. This would also give us a path to create custom CPU paths/kernels for various configurations, if that is something we want to work on eventually. I'm going to ping CUDA.jl to see if there is something like the old …

Anyway, no reason to ping them. I was searching for … In any case, spelunking through some of TensorFlow's source, they also use these structures pretty wholesale AFAICT. I'm still trying to figure out PyTorch's source, but I would be surprised if they didn't also use this. All in all, I think moving to implementations with blocked weights might just be best for the future. It will probably be more flexible for optimizations in the end, and hopefully doesn't come with a performance regression.
My understanding from reading the main C API docs is that the underlying representation of these structures is purposefully kept opaque. If generating an RNN descriptor isn't too expensive, that could always be done on the fly. Caching is also an option, as is currently done for some other cuDNN descriptor types.
That makes sense. My guess is that generating the RNN descriptor on the fly shouldn't be too bad. My only concern is how to manage the weights (as the …). I see a few options:

…

I think 3 is pretty sensible, as this is how … Also, sorry for the dump of info/iteration on my end. Interacting with cuDNN is really new to me, so it is taking a while to get up and running.
First steps to blocking weights in RNNs: #1855. Would really like some thoughts before I move forward on this for the other cells. This implementation works with cuDNN (I've only worked with the forward pass, but I'm assuming it will also work with the backward). The main change is blocking the weights and adding some useful views.

Was perusing the new Optimisers.jl stateless approach, and was wondering if we might want something similar for recurrent cells. This would also be similar to how Flax works, I think. I remember a lot of headaches coming from having state embedded in Recur, and I actually have a whole stack for inspecting and resetting state to a value other than s0 (for RL things) in arbitrary chains/models. Lux.jl already does this, so we might be able to take advantage of their implementation of "Recurrence", although I'm not sure where the boundaries between Lux and Flux are being drawn. This would only be concerned with the state of the cell, not the parameterization. I'm also not sure how well this would interact with the Chain interface; it could be impossible.
I was thinking much the same thing! Never got around to it though, but now that we have https://github.com/FluxML/Fluxperimental.jl/ this should be easier to prototype.

My rough idea would be:

```julia
function Flux.apply(c::Chain, x)
  l1, y1 = apply(c[1], x)
  l2, y2 = apply(c[2], y1)
  ...
  return Chain(l1, l2, ...), yN
end
```

Instead of updating an externally passed-in piece of state, we return an updated layer object. For layers without any auxiliary, non-parameter state, a fallback that just calls the layer and returns it unchanged should be enough.
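Concretely, the pair of methods could look something like this (a sketch only, not necessarily what Fluxperimental ends up with; it assumes the usual `cell(state, x) -> (state, y)` convention of `Recur`):

```julia
using Flux

# Fallback for layers with no auxiliary state: the layer comes back unchanged.
apply(layer, x) = layer, layer(x)

# For a recurrent wrapper, return a new wrapper carrying the advanced hidden
# state instead of mutating the original in place.
function apply(r::Flux.Recur, x)
  state, y = r.cell(r.state, x)
  return Flux.Recur(r.cell, state), y
end
```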
And note that an … For examples of how these are different in JAX, you can compare Flax vs Equinox; the latter is closer to Flux's direction.
This was a pretty easy interface to implement in Fluxperimental. Pull request here: FluxML/Fluxperimental.jl#5. Once we are happy with the basic interface, we can start fleshing out the functionality for stateful structures.
While we were discussing #1675 and #1671, several improvements/updates to the recurrent network API came up. Instead of taking over #1675, @ToucheSir and I thought it would be best to separate the needed improvements into their own issue so they can be worked on and discussed here. That way we can finish #1675 and move on with the other changes in lock-step.
- `eltype` restriction breaks `outputsize` (#1565)

Any others I'm missing?