Skip to content

Commit

Permalink
missing interleave for associative scan, fix axis
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 8, 2023
1 parent e2a342b commit 044c266
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
17 changes: 17 additions & 0 deletions gateloop_transformer/associative_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# will be adapted to test out GateLoop on a small scale https://arxiv.org/abs/2311.01927

import torch
import torch.nn.functional as F
from functools import partial
from optree import tree_flatten, tree_unflatten

Expand Down Expand Up @@ -87,3 +88,19 @@ def _scan(elems):
scans = _scan(elems_flat)

return tree_unflatten(tree, scans)

def _interleave(a, b, axis):
    """Interleave ``a`` and ``b`` element-wise along ``axis``.

    The result alternates ``a[0], b[0], a[1], b[1], ...`` along ``axis``.
    ``b`` must have either the same size as ``a`` along ``axis`` or be exactly
    one element shorter; in the shorter case the result ends with the last
    element of ``a``.

    Args:
        a: tensor supplying the even positions of the result.
        b: tensor supplying the odd positions; its other dims must match ``a``.
        axis: dimension to interleave along; negative values count from the end.

    Returns:
        A tensor whose size along ``axis`` is ``a.shape[axis] + b.shape[axis]``.
    """
    # https://stackoverflow.com/questions/60869537/how-can-i-interleave-5-pytorch-tensors

    # Normalise a negative axis first: the pad index and the `axis + 1`
    # arithmetic below are only valid for a non-negative dimension index.
    if axis < 0:
        axis += a.ndim

    len_a = a.shape[axis]
    len_b = b.shape[axis]
    b_trunc = len_a == len_b + 1

    if b_trunc:
        # Pad one element at the end of `axis` so `a` and `b` can be stacked.
        # F.pad lists (start, end) pairs from the last dim backwards, hence the
        # reversed index; the trailing +1 selects the "end" slot of the pair.
        pad = [0, 0] * b.ndim
        pad[(b.ndim - axis - 1) * 2 + 1] = 1
        b = F.pad(b, pad)

    # Pair up elements on a new dim right after `axis`, then merge that dim
    # back into `axis` so the pairs become adjacent entries.
    stacked = torch.stack([a, b], dim=axis + 1)
    interleaved = torch.flatten(stacked, start_dim=axis, end_dim=axis + 1)

    if not b_trunc:
        return interleaved

    # Drop the padded trailing element; narrow() is the native torch way to
    # take a contiguous range along an arbitrary dimension.
    return interleaved.narrow(axis, 0, len_a + len_b)
3 changes: 2 additions & 1 deletion gateloop_transformer/gateloop_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def binary_operator(a, b):

return a_j * a_i, a_j.real * kv_i + kv_j

_, kv = associative_scan(binary_operator, (a, kv))
a = rearrange(a, '... -> ... 1')
_, kv = associative_scan(binary_operator, (a, kv), axis = 1)

return einsum('b n d, b n d e -> b n e', q, kv)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'gateloop-transformer',
packages = find_packages(exclude=[]),
version = '0.0.1',
version = '0.0.2',
license='MIT',
description = 'GateLoop Transformer',
author = 'Phil Wang',
Expand Down

0 comments on commit 044c266

Please sign in to comment.