Adding GRUv3 support. #1675
@@ -183,6 +183,15 @@ end

 # GRU

+function _gru_output(Wi, Wh, b, x, h)
+  o = size(h, 1)
+  gx, gh = Wi*x, Wh*h
+  r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
+  z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
+
+  return gx, gh, r, z
+end
+
 struct GRUCell{A,V,S}
   Wi::A
   Wh::A
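An editorial aside, not part of the diff: reading `gate(v, o, k)` as "take the k-th block of `o` rows of `v`" (which is how the surrounding GRU code uses it), the new helper computes the reset and update gates shared by both GRU variants,

```math
r = \sigma\!\left(W_i^{(1)} x + W_h^{(1)} h + b^{(1)}\right), \qquad
z = \sigma\!\left(W_i^{(2)} x + W_h^{(2)} h + b^{(2)}\right),
```

and also returns the raw products `gx = Wi*x` and `gh = Wh*h` so that each cell can build its own candidate state from the third block.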
@@ -195,9 +204,7 @@ GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =

 function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
   b, o = m.b, size(h, 1)
-  gx, gh = m.Wi*x, m.Wh*h
-  r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
-  z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
+  gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h)
   h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
   h′ = (1 .- z) .* h̃ .+ z .* h
   sz = size(x)
@@ -212,8 +219,9 @@ Base.show(io::IO, l::GRUCell) =
 """
     GRU(in::Integer, out::Integer)

-[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078) layer. Behaves like an
-RNN but generally exhibits a longer memory span over sequences.
+[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
+RNN but generally exhibits a longer memory span over sequences. This implements
+the variant proposed in v1 of the referenced paper.

 See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
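Not part of the diff, but for readers following along, a minimal usage sketch of the existing API (shapes chosen arbitrarily; `GRU` wraps the cell in `Recur`, which carries the hidden state between calls):

```julia
using Flux

gru = GRU(3, 5)          # 3 input features, 5 hidden units
x = rand(Float32, 3)     # one timestep of input
y = gru(x)               # 5-element output; the hidden state lives in the Recur wrapper
Flux.reset!(gru)         # restore state0 before feeding a new sequence
```

The new `GRUv3` layer below is intended to be used the same way.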
@@ -233,6 +241,49 @@ function Base.getproperty(m::GRUCell, sym::Symbol)
   end
 end

+
+# GRU v3
+
+struct GRUv3Cell{A,V,S}
+  Wi::A
+  Wh::A
+  b::V
+  Wh_h̃::A
+  state0::S
+end
+
+GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
+  GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
+            init(out, out), init_state(out,1))
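One observation on the constructor (mine, not stated in the PR): unlike the existing `GRUCell`, where `Wh` carries all three recurrent blocks, here `Wh` only carries the reset and update blocks (`out * 2` rows); the recurrent weights for the candidate state sit in the separate `out × out` matrix `Wh_h̃`, which is applied to `r .* h` as a whole rather than sliced with `gate`.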
Review discussion on the constructor:

Needs an activation.

I'm following the exact constructor for the GRU currently in Flux. If we want to add activations here, it would make sense to add them for the original GRU and LSTMs for consistency.

AFAIK the activations in LSTM/GRU are very specifically chosen. That's why they are currently not options, and we should probably keep that consistent.

Worth pointing out that TF and JAX-based libraries do allow you to customize the activation. I presume PyTorch doesn't because it lacks a non-CuDNN path for its GPU RNN backend. That said, this would be better as a separate PR that changes every RNN layer.

Interesting, is it just the output activation or all of them?

All of them I believe: tensorflow GRU

All of them, IIUC. There's a distinction between "activation" functions (by default …

For a non-Google implementation, here's MXNet.

Prior art in Flux: #964

Yeah, I agree that we should have the activations in Flux generally across all layers.
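To make the discussion concrete, a purely hypothetical sketch of what a configurable-activation cell could look like; the struct name, keyword names, and extra fields are invented here and are not part of this PR, which deliberately keeps the fixed σ/tanh choices:

```julia
using Flux: glorot_uniform, zeros32, σ

# Hypothetical sketch only - mirrors GRUv3Cell but stores the activations as fields.
struct MyGRUv3Cell{F,G,A,V,S}
  gate_act::F   # gate activation, σ by default
  out_act::G    # candidate/output activation, tanh by default
  Wi::A
  Wh::A
  b::V
  Wh_h̃::A
  state0::S
end

MyGRUv3Cell(in, out; gate_activation = σ, activation = tanh,
            init = glorot_uniform, initb = zeros32, init_state = zeros32) =
  MyGRUv3Cell(gate_activation, activation,
              init(out * 3, in), init(out * 2, out), initb(out * 3),
              init(out, out), init_state(out, 1))
```

As the thread concludes, a change like this would belong in a separate PR covering all the recurrent layers.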
+function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
Review discussion on the method signature:

Do we need to remove the types on the input and parameters?

I believe these were introduced as part of #1521. We should tackle them separately for all recurrent cells.

Maybe we should have a central issue which details all the updates to recurrent cells that the discussion in this PR and the related issue has mentioned?

I'm definitely invested in the recurrent architectures for Flux, so I would like to help. But knowing all the outstanding issues is out of scope for what I can use my time for right now.

Yeah, let's get this through and litigate general changes to the RNN interface in a separate issue.
+  b, o = m.b, size(h, 1)
+  gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h)
+  h̃ = tanh.(gate(gx, o, 3) .+ (m.Wh_h̃ * (r .* h)) .+ gate(b, o, 3))
Review comment on the `h̃` line above: @CarloLucibello This line.
+  h′ = (1 .- z) .* h̃ .+ z .* h
+  sz = size(x)
+  return h′, reshape(h′, :, sz[2:end]...)
+end
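A summary for reviewers comparing the two variants (my reading of the diff, not text from the PR): the cells differ only in the candidate state. The existing `GRUCell` applies the reset gate after the recurrent product, whereas `GRUv3Cell` applies `Wh_h̃` to the already-reset state:

```math
\text{v1:}\quad \tilde{h} = \tanh\!\left(W_i^{(3)} x + r \odot (W_h^{(3)} h) + b^{(3)}\right)
\qquad
\text{v3:}\quad \tilde{h} = \tanh\!\left(W_i^{(3)} x + W_{h\tilde{h}}\,(r \odot h) + b^{(3)}\right)
```

The gates `r`, `z` and the update `h′ = (1 .- z) .* h̃ .+ z .* h` are identical in both.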
+
+@functor GRUv3Cell
+
+Base.show(io::IO, l::GRUv3Cell) =
+  print(io, "GRUv3Cell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
+
""" | ||
GRUv3(in::Integer, out::Integer) | ||
|
||
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an | ||
RNN but generally exhibits a longer memory span over sequences. This implements | ||
the variant proposed in v3 of the referenced paper. | ||
|
||
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) | ||
for a good overview of the internals. | ||
""" | ||
+GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
Review discussion on the `GRUv3` constructor:

How do the versions differ API-wise? Does this need any extra terms?

We shouldn't. The main difference is the …
+
+Recur(m::GRUv3Cell) = Recur(m, m.state0)
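From the user's side the new layer is meant to mirror `GRU`; a minimal usage sketch, assuming the PR as written:

```julia
using Flux

g = GRUv3(3, 5)          # same (in, out) signature as GRU
x = rand(Float32, 3)
y = g(x)                 # 5-element output; state carried by the Recur wrapper
Flux.reset!(g)
```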
+
+
 @adjoint function Broadcast.broadcasted(f::Recur, args...)
   Zygote.∇map(__context__, f, args...)
 end
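For context, the `@adjoint` above appears in this hunk as pre-existing context rather than new code; as I understand it, it is what lets a `Recur`-wrapped cell be broadcast over a sequence of timesteps under Zygote by routing the gradient through `∇map`. A small usage sketch, assuming the `GRUv3` layer from this PR:

```julia
using Flux

g   = GRUv3(3, 5)
seq = [rand(Float32, 3) for _ in 1:4]   # a 4-step sequence
out = g.(seq)                           # broadcast over timesteps: 4 outputs of length 5

# Gradients flow through the recurrence via the broadcast adjoint.
grads = gradient(() -> sum(sum, g.(seq)), Flux.params(g))
```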