Skip to content

Commit

Permalink
Update docstring for einops.reduce (#358)
Browse files Browse the repository at this point in the history
* Update docstring for einops.reduce

* rm trailing whitespace

* please doctest, add more comments
  • Loading branch information
arogozhnikov authored Jan 8, 2025
1 parent 1bee724 commit a462d07
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions einops/einops.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,40 +459,48 @@ def _prepare_recipes_for_all_dims(

def reduce(tensor: Union[Tensor, List[Tensor]], pattern: str, reduction: Reduction, **axes_lengths: Size) -> Tensor:
"""
einops.reduce provides combination of reordering and reduction using reader-friendly notation.
einops.reduce combines rearrangement and reduction using reader-friendly notation.
Examples for reduce operation:
Some examples:
```python
>>> x = np.random.randn(100, 32, 64)
# perform max-reduction on the first axis
# Axis t does not appear on RHS - thus we reduced over t
>>> y = reduce(x, 't b c -> b c', 'max')
# same as previous, but with clearer axes meaning
# same as previous, but using verbose names for axes
>>> y = reduce(x, 'time batch channel -> batch channel', 'max')
# let's pretend now that x is a batch of images
# with 4 dims: batch=10, height=20, width=30, channel=40
>>> x = np.random.randn(10, 20, 30, 40)
# 2d max-pooling with kernel size = 2 * 2 for image processing
>>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)
# if one wants to go back to the original height and width, depth-to-space trick can be applied
>>> y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2)
>>> assert parse_shape(x, 'b _ h w') == parse_shape(y2, 'b _ h w')
# same as previous, using anonymous axes,
# note: only reduced axes can be anonymous
>>> y1 = reduce(x, 'b c (h1 2) (w1 2) -> b c h1 w1', 'max')
# Adaptive 2d max-pooling to 3 * 4 grid
# adaptive 2d max-pooling to 3 * 4 grid,
# each element is max of 10x10 tile in the original tensor.
>>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape
(10, 20, 3, 4)
# Global average pooling
>>> reduce(x, 'b c h w -> b c', 'mean').shape
(10, 20)
# Subtracting mean over batch for each channel
>>> y = x - reduce(x, 'b c h w -> () c () ()', 'mean')
# subtracting mean over batch for each channel;
# similar to x - np.mean(x, axis=(0, 2, 3), keepdims=True)
>>> y = x - reduce(x, 'b c h w -> 1 c 1 1', 'mean')
# Subtracting per-image mean for each channel
>>> y = x - reduce(x, 'b c h w -> b c 1 1', 'mean')
# same as previous, but using empty compositions
>>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean')
```
Expand All @@ -501,9 +509,9 @@ def reduce(tensor: Union[Tensor, List[Tensor]], pattern: str, reduction: Reducti
tensor: tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch).
list of tensors is also accepted, those should be of the same type and shape
pattern: string, reduction pattern
reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided.
This allows using various reductions, examples: np.max, tf.reduce_logsumexp, torch.var, etc.
reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod', 'any', 'all').
Alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided.
This allows using various reductions like: np.max, np.nanmean, tf.reduce_logsumexp, torch.var, etc.
axes_lengths: any additional specifications for dimensions
Returns:
Expand Down Expand Up @@ -540,7 +548,7 @@ def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths:
This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
stack, concatenate and other operations.
Examples for rearrange operation:
Examples:
```python
# suppose we have a set of 32 images in "h w c" format (height-width-channel)
Expand Down Expand Up @@ -595,7 +603,7 @@ def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths:
def repeat(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths: Size) -> Tensor:
"""
einops.repeat allows reordering elements and repeating them in arbitrary combinations.
This operation includes functionality of repeat, tile, broadcast functions.
This operation includes functionality of repeat, tile, and broadcast functions.
Examples for repeat operation:
Expand Down

0 comments on commit a462d07

Please sign in to comment.