diff --git a/gateloop_transformer/associative_scan.py b/gateloop_transformer/associative_scan.py
index 0c782b1..687626e 100644
--- a/gateloop_transformer/associative_scan.py
+++ b/gateloop_transformer/associative_scan.py
@@ -4,6 +4,7 @@
 # will be adapted to test out GateLoop on a small scale https://arxiv.org/abs/2311.01927
 
 import torch
+import torch.nn.functional as F
 from functools import partial
 
 from optree import tree_flatten, tree_unflatten
@@ -87,3 +88,19 @@ def _scan(elems):
     scans = _scan(elems_flat)
 
     return tree_unflatten(tree, scans)
+
+def _interleave(a, b, axis):
+    # https://stackoverflow.com/questions/60869537/how-can-i-interleave-5-pytorch-tensors
+    if b_trunc := (a.shape[axis] == b.shape[axis] + 1):
+        pad = [0, 0] * b.ndim
+        pad[(b.ndim-axis-1)*2+1] = 1 # +1=always end of dim, pad-order is reversed so start is at end
+        b = F.pad(b, pad)
+
+    stacked = torch.stack([a, b], dim=axis+1)
+    interleaved = torch.flatten(stacked, start_dim=axis, end_dim=axis+1)
+
+    if not b_trunc:
+        return interleaved
+
+    # TODO: find torch alternative for slice_along axis for torch.jit.script to work
+    return interleaved[slice_along_axis(0, b.shape[axis]+a.shape[axis]-1, axis=axis)]
diff --git a/gateloop_transformer/gateloop_transformer.py b/gateloop_transformer/gateloop_transformer.py
index 7e4d3f5..c369ca5 100644
--- a/gateloop_transformer/gateloop_transformer.py
+++ b/gateloop_transformer/gateloop_transformer.py
@@ -143,7 +143,8 @@ def binary_operator(a, b):
 
         return a_j * a_i, a_j.real * kv_i + kv_j
 
-    _, kv = associative_scan(binary_operator, (a, kv))
+    a = rearrange(a, '... -> ... 1')
+    _, kv = associative_scan(binary_operator, (a, kv), axis = 1)
 
     return einsum('b n d, b n d e -> b n e', q, kv)
 
diff --git a/setup.py b/setup.py
index 9b687d6..18d17f7 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'gateloop-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'GateLoop Transformer',
   author = 'Phil Wang',