Skip to content

Commit

Permalink
Merge pull request #13 from yuehhua/develop
Browse files Browse the repository at this point in the history
Fix multihead issue
  • Loading branch information
yuehhua authored Jul 21, 2022
2 parents a970d4b + 1d1908b commit 29855d1
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 22 deletions.
7 changes: 3 additions & 4 deletions src/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Hopfield layer forward function.
- `key::AbstractArray`:
- `value::AbstractArray`:
- `max_iter`: `-1` for iteration until stable.
- `ϵ`:
- `ϵ`: The error tolerance for convergence.
"""
function hopfield_forward(Qt, Kt, Vt, out_proj, dropout, heads::Int,
β::AbstractArray, query::AbstractArray, key::AbstractArray, value::AbstractArray,
Expand All @@ -25,11 +25,10 @@ function hopfield_forward(Qt, Kt, Vt, out_proj, dropout, heads::Int,

= attention_prob(Q, K, β)
= multiple_updates(Â, Q, K, β, heads, max_iter, ϵ)
= move_heads_to_first(Â, heads)
V = move_heads_to_first(V, heads)

V = dropout(V)
attn_out = batched_mul(V, batched_transpose(Â))
attn_out = batched_innerprod(V, Â, dims=2)
attn_out = move_heads_to_first(attn_out, heads)
return out_proj(attn_out)
end

Expand Down
10 changes: 5 additions & 5 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
l = HopfieldCore(emb_dim, heads; kdim=kdim, vdim=vdim,
head_dim=head_dim, pattern_dim=pattern_dim) |> gpu
Y = l(Q, K, V)
@test size(Y) == (emb_dim, heads*target_len, batch_size)
@test size(Y) == (emb_dim, target_len, batch_size)

g = Zygote.gradient(() -> sum(l(Q, K, V)), Flux.params(l))
@test length(g.grads) == 6
Expand All @@ -28,7 +28,7 @@
l = HopfieldCore(emb_dim, heads; kdim=kdim, vdim=vdim,
head_dim=head_dim, pattern_dim=pattern_dim, enable_out_proj=false) |> gpu
Y = l(Q, K, V)
@test size(Y) == (heads*pattern_dim, heads*target_len, batch_size)
@test size(Y) == (heads*pattern_dim, target_len, batch_size)

g = Zygote.gradient(() -> sum(l(Q, K, V)), Flux.params(l))
@test length(g.grads) == 4
Expand All @@ -38,17 +38,17 @@
l = HopfieldCore(emb_dim, heads; kdim=kdim, vdim=kdim,
head_dim=head_dim, pattern_dim=pattern_dim) |> gpu
Y = l(Q, nothing, nothing)
@test size(Y) == (emb_dim, heads*target_len, batch_size)
@test size(Y) == (emb_dim, target_len, batch_size)

g = Zygote.gradient(() -> sum(l(Q, nothing, nothing)), Flux.params(l))
@test length(g.grads) == 8
@test length(g.grads) == 10
end

@testset "static Q" begin
l = HopfieldCore(emb_dim, heads; kdim=kdim, vdim=vdim,
head_dim=head_dim, pattern_dim=pattern_dim) |> gpu
Y = l(nothing, K, V)
@test size(Y) == (emb_dim, heads*emb_dim, batch_size)
@test size(Y) == (emb_dim, emb_dim, batch_size)

g = Zygote.gradient(() -> sum(l(nothing, K, V)), Flux.params(l))
@test length(g.grads) == 8
Expand Down
26 changes: 13 additions & 13 deletions test/layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
@test size(l.linear_v(V)) == (heads*pattern_dim, source_len, batch_size)

Y = l(Q, K, V)
@test size(Y) == (emb_dim, heads*target_len, batch_size)
@test size(Y) == (emb_dim, target_len, batch_size)

g = Zygote.gradient(() -> sum(l(Q, K, V)), Flux.params(l))
@test length(g.grads) == 6
Expand All @@ -41,7 +41,7 @@
l = HopfieldCore(emb_dim, heads; kdim=kdim, vdim=vdim,
head_dim=head_dim, pattern_dim=pattern_dim, enable_out_proj=false)
Y = l(Q, K, V)
@test size(Y) == (heads*pattern_dim, heads*target_len, batch_size)
@test size(Y) == (heads*pattern_dim, target_len, batch_size)

g = Zygote.gradient(() -> sum(l(Q, K, V)), Flux.params(l))
@test length(g.grads) == 4
Expand All @@ -51,17 +51,17 @@
l = HopfieldCore(emb_dim, heads; kdim=kdim, vdim=kdim,
head_dim=head_dim, pattern_dim=pattern_dim)
Y = l(Q, nothing, nothing)
@test size(Y) == (emb_dim, heads*target_len, batch_size)
@test size(Y) == (emb_dim, target_len, batch_size)

g = Zygote.gradient(() -> sum(l(Q, nothing, nothing)), Flux.params(l))
@test length(g.grads) == 8
@test length(g.grads) == 10
end

@testset "static Q" begin
l = HopfieldCore(emb_dim, heads; kdim=kdim, vdim=vdim,
head_dim=head_dim, pattern_dim=pattern_dim)
Y = l(nothing, K, V)
@test size(Y) == (emb_dim, heads*emb_dim, batch_size)
@test size(Y) == (emb_dim, emb_dim, batch_size)

g = Zygote.gradient(() -> sum(l(nothing, K, V)), Flux.params(l))
@test length(g.grads) == 8
Expand All @@ -81,7 +81,7 @@
@test Hopfields.pattern_projection_dim(l) == pattern_projection_dim

Y = l(Q, K, V)
@test size(Y) == (out_channel, heads*target_len, batch_size)
@test size(Y) == (out_channel, target_len, batch_size)

g = Zygote.gradient(() -> sum(l(Q, K, V)), Flux.params(l))
@test length(g.grads) == 6
Expand All @@ -100,10 +100,10 @@
@test Hopfields.pattern_projection_dim(l) == stored_pattern_dim

Y = l(Q, nothing, nothing)
@test size(Y) == (out_channel, heads*target_len, batch_size)
@test size(Y) == (out_channel, target_len, batch_size)

g = Zygote.gradient(() -> sum(l(Q, nothing, nothing)), Flux.params(l))
@test length(g.grads) == 8
@test length(g.grads) == 10
end

@testset "HopfieldPooling" begin
Expand All @@ -120,7 +120,7 @@
@test Hopfields.pattern_projection_dim(l) == pattern_projection_dim

Y = l(nothing, K, V)
@test size(Y) == (out_channel, heads*emb_dim, batch_size)
@test size(Y) == (out_channel, emb_dim, batch_size)

g = Zygote.gradient(() -> sum(l(nothing, K, V)), Flux.params(l))
@test length(g.grads) == 8
Expand All @@ -132,7 +132,7 @@
pattern_projection_dim=pattern_projection_dim, max_iter=3)

Y = l(Q, K, V)
@test size(Y) == (out_channel, heads*target_len, batch_size)
@test size(Y) == (out_channel, target_len, batch_size)

g = Zygote.gradient(() -> sum(l(Q, K, V)), Flux.params(l))
@test length(g.grads) == 6
Expand All @@ -143,10 +143,10 @@
stored_pattern_dim=stored_pattern_dim, max_iter=3)

Y = l(Q, nothing, nothing)
@test size(Y) == (out_channel, heads*target_len, batch_size)
@test size(Y) == (out_channel, target_len, batch_size)

g = Zygote.gradient(() -> sum(l(Q, nothing, nothing)), Flux.params(l))
@test length(g.grads) == 20
@test length(g.grads) == 22
end

@testset "HopfieldPooling with multiple updates" begin
Expand All @@ -155,7 +155,7 @@
pattern_projection_dim=pattern_projection_dim, max_iter=3)

Y = l(nothing, K, V)
@test size(Y) == (out_channel, heads*emb_dim, batch_size)
@test size(Y) == (out_channel, emb_dim, batch_size)

g = Zygote.gradient(() -> sum(l(nothing, K, V)), Flux.params(l))
@test length(g.grads) == 8
Expand Down

0 comments on commit 29855d1

Please sign in to comment.