Skip to content

Commit

Permalink
Add post-LayerNorm (`post_ln`) as an option for the simplified GateLoop layer
Browse files (browse the repository at this point in the history)
  • Loading branch information
lucidrains committed Jan 31, 2024
1 parent 3d83895 commit 1c00da3
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 3 additions & 0 deletions gateloop_transformer/simplified_gate_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
prenorm = True,
use_heinsen = False,
use_jax_associative_scan = False,
post_ln = False,
reverse = False
):
super().__init__()
Expand All @@ -135,6 +136,7 @@ def __init__(
else:
self.gate_loop_fn = gate_loop_operator

self.maybe_post_ln = nn.LayerNorm(dim) if post_ln else nn.Identity()
self.split_heads = Rearrange('(b d) n 1 -> b n d', d = dim)

self.reverse = reverse
Expand All @@ -155,6 +157,7 @@ def forward(
out, cache = self.gate_loop_fn(q, kv, a.sigmoid(), cache = cache)

out = self.split_heads(out)
out = self.maybe_post_ln(out)

if self.reverse:
out = torch.flip(out, dims = (-2,))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'gateloop-transformer',
packages = find_packages(exclude=[]),
version = '0.2.3',
version = '0.2.4',
license='MIT',
description = 'GateLoop Transformer',
author = 'Phil Wang',
Expand Down

0 comments on commit 1c00da3

Please sign in to comment.