Add support of ellipsis to einmix (#360)
* add ellipsis to einmix

* add tests for different failure modes

* update documentation for einmix
arogozhnikov authored Jan 11, 2025
1 parent 9ac21cd commit 253545a
Showing 2 changed files with 205 additions and 28 deletions.
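For orientation before the diff: with an einops version that includes this commit, an ellipsis may appear on both sides of an EinMix pattern. A minimal sketch (torch backend assumed; axis names and sizes are illustrative and mirror the mixin6 test case added below):

import torch
from einops.layers.torch import EinMix

# Mix the trailing (a c) group into (a d) while broadcasting over any number of middle axes.
mix = EinMix(
    "b ... (a c) -> b ... (a d)",
    weight_shape="c d",
    bias_shape="a d",
    a=1, c=3, d=4,
)
x = torch.randn(2, 7, 5, 3)   # b=2, two ellipsis axes, then (a c) = 1 * 3
y = mix(x)
print(y.shape)                # torch.Size([2, 7, 5, 4]), i.e. (a d) = 1 * 4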
107 changes: 80 additions & 27 deletions einops/layers/_einmix.py
@@ -1,7 +1,7 @@
from typing import Any, List, Optional, Dict

from einops import EinopsError
from einops.parsing import ParsedExpression
from einops.parsing import ParsedExpression, _ellipsis
import warnings
import string
from ..einops import _product
@@ -17,21 +17,21 @@ def __init__(self, pattern: str, weight_shape: str, bias_shape: Optional[str] =
"""
EinMix - Einstein summation with automated tensor management and axis packing/unpacking.
EinMix is an advanced tool, helpful tutorial:
EinMix is a combination of einops and MLP, see tutorial:
https://github.com/arogozhnikov/einops/blob/main/docs/3-einmix-layer.ipynb
Imagine taking einsum with two arguments, one of each input, and one - tensor with weights
>>> einsum('time batch channel_in, channel_in channel_out -> time batch channel_out', input, weight)
This layer manages weights for you, syntax highlights separate role of weight matrix
This layer manages weights for you, syntax highlights a special role of weight matrix
>>> EinMix('time batch channel_in -> time batch channel_out', weight_shape='channel_in channel_out')
But otherwise it is the same einsum under the hood.
But otherwise it is the same einsum under the hood. Plus einops-rearrange.
Simple linear layer with bias term (you have one like that in your framework)
Simple linear layer with a bias term (you have one like that in your framework)
>>> EinMix('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=10, cout=20)
There is no restriction to mix the last axis. Let's mix along height
>>> EinMix('h w c-> hout w c', weight_shape='h hout', bias_shape='hout', h=32, hout=32)
Channel-wise multiplication (like one used in normalizations)
Example of channel-wise multiplication (like one used in normalizations)
>>> EinMix('t b c -> t b c', weight_shape='c', c=128)
Multi-head linear layer (each head is own linear layer):
>>> EinMix('t b (head cin) -> t b (head cout)', weight_shape='head cin cout', ...)
@@ -42,14 +42,16 @@ def __init__(self, pattern: str, weight_shape: str, bias_shape: Optional[str] =
- when channel dimension is not last, use EinMix, not transposition
- patch/segment embeddings
- when need only within-group connections to reduce number of weights and computations
- perfect as a part of sequential models
- next-gen MLPs (follow tutorial to learn more!)
- next-gen MLPs (follow tutorial link above to learn more!)
- in general, any time you want to combine linear layer and einops.rearrange
Uniform He initialization is applied to weight tensor. This accounts for number of elements mixed.
Uniform He initialization is applied to weight tensor.
This accounts for the number of elements mixed and produced.
Parameters
:param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output
:param weight_shape: axes of weight. A tensor of this shape is created, stored, and optimized in a layer
If bias_shape is not specified, bias is not created.
:param bias_shape: axes of bias added to output. Weights of this shape are created and stored. If `None` (the default), no bias is added.
:param axes_lengths: dimensions of weight tensor
"""
@@ -71,9 +73,13 @@ def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optiona
set.difference(right.identifiers, {*left.identifiers, *weight.identifiers}),
"Unrecognized identifiers on the right side of EinMix {}",
)

if left.has_ellipsis or right.has_ellipsis or weight.has_ellipsis:
raise EinopsError("Ellipsis is not supported in EinMix (right now)")
if weight.has_ellipsis:
raise EinopsError("Ellipsis is not supported in weight, as its shape should be fully specified")
if left.has_ellipsis or right.has_ellipsis:
if not (left.has_ellipsis and right.has_ellipsis):
raise EinopsError(f"Ellipsis in EinMix should be on both sides, {pattern}")
if left.has_ellipsis_parenthesized:
raise EinopsError(f"Ellipsis on left side can't be in parenthesis, got {pattern}")
if any(x.has_non_unitary_anonymous_axes for x in [left, right, weight]):
raise EinopsError("Anonymous axes (numbers) are not allowed in EinMix")
if "(" in weight_shape or ")" in weight_shape:
@@ -86,16 +92,18 @@ def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optiona
names: List[str] = []
for group in left.composition:
names += group
names = [name if name != _ellipsis else "..." for name in names]
composition = " ".join(names)
pre_reshape_pattern = f"{left_pattern}->{composition}"
pre_reshape_pattern = f"{left_pattern}-> {composition}"
pre_reshape_lengths = {name: length for name, length in axes_lengths.items() if name in names}

if any(len(group) != 1 for group in right.composition):
if any(len(group) != 1 for group in right.composition) or right.has_ellipsis_parenthesized:
names = []
for group in right.composition:
names += group
names = [name if name != _ellipsis else "..." for name in names]
composition = " ".join(names)
post_reshape_pattern = f"{composition}->{right_pattern}"
post_reshape_pattern = f"{composition} ->{right_pattern}"

self._create_rearrange_layers(pre_reshape_pattern, pre_reshape_lengths, post_reshape_pattern, {})

@@ -116,22 +124,36 @@ def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optiona
# single output element is a combination of fan_in input elements
_fan_in = _product([axes_lengths[axis] for (axis,) in weight.composition if axis not in right.identifiers])
if bias_shape is not None:
# maybe I should put ellipsis in the beginning for simplicity?
if not isinstance(bias_shape, str):
raise EinopsError("bias shape should be string specifying which axes bias depends on")
bias = ParsedExpression(bias_shape)
_report_axes(set.difference(bias.identifiers, right.identifiers), "Bias axes {} not present in output")
_report_axes(
set.difference(bias.identifiers, right.identifiers),
"Bias axes {} not present in output",
)
_report_axes(
set.difference(bias.identifiers, set(axes_lengths)),
"Sizes not provided for bias axes {}",
)

_bias_shape = []
used_non_trivial_size = False
for axes in right.composition:
for axis in axes:
if axis in bias.identifiers:
_bias_shape.append(axes_lengths[axis])
else:
_bias_shape.append(1)
if axes == _ellipsis:
if used_non_trivial_size:
raise EinopsError("all bias dimensions should go after ellipsis in the output")
else:
# handles ellipsis correctly
for axis in axes:
if axis == _ellipsis:
if used_non_trivial_size:
raise EinopsError("all bias dimensions should go after ellipsis in the output")
elif axis in bias.identifiers:
_bias_shape.append(axes_lengths[axis])
used_non_trivial_size = True
else:
_bias_shape.append(1)
else:
_bias_shape = None

@@ -142,15 +164,26 @@ def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optiona
# rewrite einsum expression with single-letter latin identifiers so that
# expression will be understood by any framework
mapped_identifiers = {*left.identifiers, *right.identifiers, *weight.identifiers}
if _ellipsis in mapped_identifiers:
mapped_identifiers.remove(_ellipsis)
mapped_identifiers = list(sorted(mapped_identifiers))
mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapped_identifiers)}

def write_flat(axes: list):
return "".join(mapping2letters[axis] for axis in axes)
mapping2letters[_ellipsis] = "..." # preserve ellipsis

def write_flat_remapped(axes: ParsedExpression):
result = []
for composed_axis in axes.composition:
if isinstance(composed_axis, list):
result.extend([mapping2letters[axis] for axis in composed_axis])
else:
assert composed_axis == _ellipsis
result.append("...")
return "".join(result)

self.einsum_pattern: str = "{},{}->{}".format(
write_flat(left.flat_axes_order()),
write_flat(weight.flat_axes_order()),
write_flat(right.flat_axes_order()),
write_flat_remapped(left),
write_flat_remapped(weight),
write_flat_remapped(right),
)

def _create_rearrange_layers(
@@ -174,3 +207,23 @@ def __repr__(self):
for axis, length in self.axes_lengths.items():
params += ", {}={}".format(axis, length)
return "{}({})".format(self.__class__.__name__, params)


class _EinmixDebugger(_EinmixMixin):
"""Used only to test mixin"""

def _create_rearrange_layers(
self,
pre_reshape_pattern: Optional[str],
pre_reshape_lengths: Optional[Dict],
post_reshape_pattern: Optional[str],
post_reshape_lengths: Optional[Dict],
):
self.pre_reshape_pattern = pre_reshape_pattern
self.pre_reshape_lengths = pre_reshape_lengths
self.post_reshape_pattern = post_reshape_pattern
self.post_reshape_lengths = post_reshape_lengths

def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
self.saved_weight_shape = weight_shape
self.saved_bias_shape = bias_shape
126 changes: 125 additions & 1 deletion einops/tests/test_layers.py
@@ -4,7 +4,7 @@
import numpy
import pytest

from einops import rearrange, reduce
from einops import rearrange, reduce, EinopsError
from einops.tests import collect_test_backends, is_backend_tested, FLOAT_REDUCTIONS as REDUCTIONS

__author__ = "Alex Rogozhnikov"
@@ -343,3 +343,127 @@ def eval_at_point(params):
# check serialization
fbytes = flax.serialization.to_bytes(params)
_loaded = flax.serialization.from_bytes(params, fbytes)


def test_einmix_decomposition():
"""
Testing that einmix correctly decomposes into smaller transformations.
"""
from einops.layers._einmix import _EinmixDebugger

mixin1 = _EinmixDebugger(
"a b c d e -> e d c b a",
weight_shape="d a b",
d=2, a=3, b=5,
) # fmt: off
assert mixin1.pre_reshape_pattern is None
assert mixin1.post_reshape_pattern is None
assert mixin1.einsum_pattern == "abcde,dab->edcba"
assert mixin1.saved_weight_shape == [2, 3, 5]
assert mixin1.saved_bias_shape is None

mixin2 = _EinmixDebugger(
"a b c d e -> e d c b a",
weight_shape="d a b",
bias_shape="a b c d e",
a=1, b=2, c=3, d=4, e=5,
) # fmt: off
assert mixin2.pre_reshape_pattern is None
assert mixin2.post_reshape_pattern is None
assert mixin2.einsum_pattern == "abcde,dab->edcba"
assert mixin2.saved_weight_shape == [4, 1, 2]
assert mixin2.saved_bias_shape == [5, 4, 3, 2, 1]

mixin3 = _EinmixDebugger(
"... -> ...",
weight_shape="",
bias_shape="",
) # fmt: off
assert mixin3.pre_reshape_pattern is None
assert mixin3.post_reshape_pattern is None
assert mixin3.einsum_pattern == "...,->..."
assert mixin3.saved_weight_shape == []
assert mixin3.saved_bias_shape == []

mixin4 = _EinmixDebugger(
"b a ... -> b c ...",
weight_shape="b a c",
a=1, b=2, c=3,
) # fmt: off
assert mixin4.pre_reshape_pattern is None
assert mixin4.post_reshape_pattern is None
assert mixin4.einsum_pattern == "ba...,bac->bc..."
assert mixin4.saved_weight_shape == [2, 1, 3]
assert mixin4.saved_bias_shape is None

mixin5 = _EinmixDebugger(
"(b a) ... -> b c (...)",
weight_shape="b a c",
a=1, b=2, c=3,
) # fmt: off
assert mixin5.pre_reshape_pattern == "(b a) ... -> b a ..."
assert mixin5.pre_reshape_lengths == dict(a=1, b=2)
assert mixin5.post_reshape_pattern == "b c ... -> b c (...)"
assert mixin5.einsum_pattern == "ba...,bac->bc..."
assert mixin5.saved_weight_shape == [2, 1, 3]
assert mixin5.saved_bias_shape is None

mixin6 = _EinmixDebugger(
"b ... (a c) -> b ... (a d)",
weight_shape="c d",
bias_shape="a d",
a=1, c=3, d=4,
) # fmt: off
assert mixin6.pre_reshape_pattern == "b ... (a c) -> b ... a c"
assert mixin6.pre_reshape_lengths == dict(a=1, c=3)
assert mixin6.post_reshape_pattern == "b ... a d -> b ... (a d)"
assert mixin6.einsum_pattern == "b...ac,cd->b...ad"
assert mixin6.saved_weight_shape == [3, 4]
assert mixin6.saved_bias_shape == [1, 1, 4] # (b) a d, ellipsis does not participate

mixin7 = _EinmixDebugger(
"a ... (b c) -> a (... d b)",
weight_shape="c d b",
bias_shape="d b",
b=2, c=3, d=4,
) # fmt: off
assert mixin7.pre_reshape_pattern == "a ... (b c) -> a ... b c"
assert mixin7.pre_reshape_lengths == dict(b=2, c=3)
assert mixin7.post_reshape_pattern == "a ... d b -> a (... d b)"
assert mixin7.einsum_pattern == "a...bc,cdb->a...db"
assert mixin7.saved_weight_shape == [3, 4, 2]
assert mixin7.saved_bias_shape == [1, 4, 2] # (a) d b, ellipsis does not participate


def test_einmix_restrictions():
"""
Testing different cases
"""
from einops.layers._einmix import _EinmixDebugger

with pytest.raises(EinopsError):
_EinmixDebugger(
"a b c d e -> e d c b a",
weight_shape="d a b",
d=2, a=3, # missing b
) # fmt: off

with pytest.raises(EinopsError):
_EinmixDebugger(
"a b c d e -> e d c b a",
weight_shape="w a b",
d=2, a=3, b=1 # missing d
) # fmt: off

with pytest.raises(EinopsError):
_EinmixDebugger(
"(...) a -> ... a",
weight_shape="a", a=1, # ellipsis on the left
) # fmt: off

with pytest.raises(EinopsError):
_EinmixDebugger(
"(...) a -> a ...",
weight_shape="a", a=1, # ellipsis on the right side after bias axis
bias_shape='a',
) # fmt: off
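
Outside the test suite, the new one-sided-ellipsis check surfaces like this (a sketch; torch layer assumed, axis size illustrative):

from einops import EinopsError
from einops.layers.torch import EinMix

try:
    EinMix("... a -> a", weight_shape="a", a=3)   # ellipsis only on the left side
except EinopsError as e:
    print(e)   # message says the ellipsis should appear on both sides of the pattern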
