Commit 4b89716

jessegrabowski authored and ricardoV94 committed
Implement Pack and Unpack
1 parent a0be97e commit 4b89716

File tree

2 files changed: +455 -7 lines changed

pytensor/tensor/reshape.py

Lines changed: 266 additions & 6 deletions
@@ -1,4 +1,5 @@
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
+from itertools import pairwise
 from typing import cast as type_cast

 import numpy as np
@@ -9,7 +10,7 @@
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
 from pytensor.tensor import TensorLike, as_tensor_variable
-from pytensor.tensor.basic import infer_static_shape
+from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
 from pytensor.tensor.math import prod
 from pytensor.tensor.shape import ShapeValueType
 from pytensor.tensor.type import tensor
@@ -152,12 +153,12 @@ def __init__(self, axis: int):
         self.axis = axis

     def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
-        if shape.type.numpy_dtype.kind not in "iu":
-            raise TypeError("shape must be an integer tensor")
-
         x = as_tensor_variable(x)
         shape = as_tensor_variable(shape, dtype=int, ndim=1)

+        if shape.type.numpy_dtype.kind not in "iu":
+            raise TypeError("shape must be an integer tensor")
+
         axis = self.axis
         _, constant_shape = infer_static_shape(shape)

@@ -261,4 +262,263 @@ def split_dims(
     return type_cast(TensorVariable, split_op(x, shape))


-__all__ = ["join_dims", "split_dims"]
+def _analyze_axes_list(axes) -> tuple[int, int, int]:
+    """
+    Analyze the provided axes list to determine how many axes come before and after the interval to be raveled,
+    as well as the minimum number of axes that the inputs must have.
+
+    The rules are:
+
+    - Axes must be strictly increasing in both the positive and negative parts of the list.
+    - Negative axes must come after positive axes.
+    - There can be at most one "hole" in the axes list, which can be either an implicit hole at an endpoint
+      (e.g. [0, 1]) or an explicit hole in the middle (e.g. [0, -1]).
+
+    Returns
+    -------
+    n_axes_before : int
+        The number of axes before the interval to be raveled.
+    n_axes_after : int
+        The number of axes after the interval to be raveled.
+    min_axes : int
+        The minimum number of axes that the inputs must have.
+    """
+    if axes is None:
+        return 0, 0, 0
+
+    if isinstance(axes, int):
+        axes = (axes,)
+    elif not isinstance(axes, Iterable):
+        raise TypeError("axes must be an int, an iterable of ints, or None")
+
+    axes = tuple(axes)
+
+    if len(axes) == 0:
+        raise ValueError("axes=[] is ambiguous; use None to ravel all")
+
+    if len(set(axes)) != len(axes):
+        raise ValueError("axes must have no duplicates")
+
+    first_negative_idx = next((i for i, a in enumerate(axes) if a < 0), len(axes))
+    positive_axes = list(axes[:first_negative_idx])
+    negative_axes = list(axes[first_negative_idx:])
+
+    if not all(a < 0 for a in negative_axes):
+        raise ValueError("Negative axes must come after positive")
+
+    def not_strictly_increasing(s):
+        if len(s) < 1:
+            return False
+        return any(b <= a for a, b in pairwise(s))
+
+    if not_strictly_increasing(positive_axes):
+        raise ValueError("Axes must be strictly increasing in the positive part")
+    if not_strictly_increasing(negative_axes):
+        raise ValueError("Axes must be strictly increasing in the negative part")
+
+    def find_gaps(s):
+        """Find if there are gaps in a strictly increasing sequence."""
+        return any(b - a > 1 for a, b in pairwise(s))
+
+    if find_gaps(positive_axes):
+        raise ValueError("Positive axes must be contiguous")
+    if find_gaps(negative_axes):
+        raise ValueError("Negative axes must be contiguous")
+
+    if positive_axes and positive_axes[0] != 0:
+        raise ValueError(
+            "If positive axes are provided, the first positive axis must be 0 to avoid ambiguity. To ravel indices "
+            "starting from the front, use negative axes only."
+        )
+
+    if negative_axes and negative_axes[-1] != -1:
+        raise ValueError(
+            "If negative axes are provided, the last negative axis must be -1 to avoid ambiguity. To ravel indices "
+            "up to the end, use positive axes only."
+        )
+
+    n_before = len(positive_axes)
+    n_after = len(negative_axes)
+    min_axes = n_before + n_after
+
+    return n_before, n_after, min_axes
+
+
+def pack(
+    *tensors: TensorLike, axes: Sequence[int] | int | None = None
+) -> tuple[TensorVariable, list[ShapeValueType]]:
+    """
+    Combine multiple tensors by preserving the specified axes and raveling the rest into a single axis.
+
+    Parameters
+    ----------
+    *tensors : TensorLike
+        Input tensors to be packed.
+    axes : int, sequence of int, or None, optional
+        Axes to preserve during packing. If None, all axes are raveled. See the Notes section for the rules.
+
+    Returns
+    -------
+    packed_tensor : TensorVariable
+        The packed tensor, with the specified axes preserved and the others raveled.
+    packed_shapes : list of ShapeValueType
+        A list containing the shapes of the raveled dimensions for each input tensor.
+
+    Notes
+    -----
+    The `axes` parameter determines which axes are preserved during packing. Axes can be specified using positive
+    or negative indices, but must follow these rules:
+
+    - If axes is None, all axes are raveled.
+    - If a single integer is provided, it can be positive or negative, and can take any value up to the smallest
+      number of dimensions among the input tensors.
+    - If a list is provided, it can be all positive, all negative, or a combination of positive and negative.
+    - Positive axes must be contiguous and start from 0.
+    - Negative axes must be contiguous and end at -1.
+    - If positive and negative axes are combined, positive axes must come before negative axes, and both 0 and -1
+      must be included.
+
+    Examples
+    --------
+    The easiest way to understand pack is through examples. The simplest case is axes=None, which is equivalent
+    to ``join(0, *[t.ravel() for t in tensors])``:
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+
+        x = pt.tensor("x", shape=(2, 3))
+        y = pt.tensor("y", shape=(4, 5, 6))
+
+        packed_tensor, packed_shapes = pt.pack(x, y, axes=None)
+        # packed_tensor has shape (6 + 120,) == (126,)
+        # packed_shapes is [(2, 3), (4, 5, 6)]
+
+    To preserve a single axis, we can use either positive or negative indexing. Note that all tensors must have
+    the same size along the preserved axis. For example, using axes=0:
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+
+        x = pt.tensor("x", shape=(2, 3))
+        y = pt.tensor("y", shape=(2, 5, 6))
+
+        packed_tensor, packed_shapes = pt.pack(x, y, axes=0)
+        # packed_tensor has shape (2, 3 + 30) == (2, 33)
+        # packed_shapes is [(3,), (5, 6)]
+
+    Using negative indexing, we can preserve the last two axes:
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+
+        x = pt.tensor("x", shape=(4, 2, 3))
+        y = pt.tensor("y", shape=(5, 2, 3))
+
+        packed_tensor, packed_shapes = pt.pack(x, y, axes=(-2, -1))
+        # packed_tensor has shape (4 + 5, 2, 3) == (9, 2, 3)
+        # packed_shapes is [(4,), (5,)]
+
+    Or, using a mix of positive and negative axes, we can preserve the first and last axes:
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+
+        x = pt.tensor("x", shape=(2, 4, 3))
+        y = pt.tensor("y", shape=(2, 5, 3))
+
+        packed_tensor, packed_shapes = pt.pack(x, y, axes=(0, -1))
+        # packed_tensor has shape (2, 4 + 5, 3) == (2, 9, 3)
+        # packed_shapes is [(4,), (5,)]
+    """
+    tensor_list = [as_tensor_variable(t) for t in tensors]
+
+    n_before, n_after, min_axes = _analyze_axes_list(axes)
+
+    reshaped_tensors: list[TensorVariable] = []
+    packed_shapes: list[ShapeValueType] = []
+
+    for i, input_tensor in enumerate(tensor_list):
+        n_dim = input_tensor.ndim
+
+        if n_dim < min_axes:
+            raise ValueError(
+                f"Input {i} (zero indexed) to pack has {n_dim} dimensions, "
+                f"but axes={axes} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}."
+            )
+        n_after_packed = n_dim - n_after
+        packed_shapes.append(input_tensor.shape[n_before:n_after_packed])
+
+        if n_dim == min_axes:
+            # If an input has the minimum number of axes, pack implicitly inserts a new axis based on the
+            # pattern implied by the axes.
+            input_tensor = expand_dims(input_tensor, axis=n_before)
+            reshaped_tensors.append(input_tensor)
+            continue
+
+        # The reshape we want is (shape[:n_before], -1, shape[n_after_packed:]). join_dims does
+        # (shape[:min(axes)], -1, shape[max(axes) + 1:]), so this works if we choose
+        # axes=range(n_before, n_after_packed). Because of the rules on the axes input, we always have
+        # n_before <= n_after_packed - 1. The range also covers the corner case where
+        # n_before == n_after_packed - 1 (i.e., there is only one axis to ravel, so join_dims is a no-op).
+        join_axes = range(n_before, n_after_packed)
+        joined = join_dims(input_tensor, tuple(join_axes))
+        reshaped_tensors.append(joined)
+
+    return join(n_before, *reshaped_tensors), packed_shapes
+
+
+def unpack(
+    packed_input: TensorLike,
+    axes: int | Sequence[int] | None,
+    packed_shapes: list[ShapeValueType],
+) -> list[TensorVariable]:
+    """
+    Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping.
+
+    The unpacking process reverses the packing operation, restoring the original shapes of the input tensors.
+    `axes` corresponds to the axes that were preserved during packing, and `packed_shapes` contains the shapes of
+    the raveled dimensions for each output tensor (that is, the shapes that were destroyed during packing).
+
+    The signature of unpack is such that the same `axes` should be passed to both `pack` and `unpack` to create a
+    "round-trip" operation. For details on the rules for `axes`, see the documentation for `pack`.
+
+    Parameters
+    ----------
+    packed_input : TensorLike
+        The packed tensor to be unpacked.
+    axes : int, sequence of int, or None
+        Axes that were preserved during packing. If None, the input is assumed to be 1D and axis 0 is used.
+    packed_shapes : list of ShapeValueType
+        A list containing the shapes of the raveled dimensions for each output tensor.
+
+    Returns
+    -------
+    unpacked_tensors : list of TensorVariable
+        A list of unpacked tensors with their original shapes restored.
+    """
+    packed_input = as_tensor_variable(packed_input)
+
+    if axes is None:
+        if packed_input.ndim != 1:
+            raise ValueError("unpack can only be called with axes=None for 1d inputs")
+        split_axis = 0
+    else:
+        axes = normalize_axis_tuple(axes, ndim=packed_input.ndim)
+        try:
+            [split_axis] = (i for i in range(packed_input.ndim) if i not in axes)
+        except ValueError as err:
+            raise ValueError(
+                "Input to unpack must have exactly one more dimension than implied by axes"
+            ) from err
+
+    split_inputs = split(
+        packed_input,
+        splits_size=[prod(shape, dtype=int) for shape in packed_shapes],
+        n_splits=len(packed_shapes),
+        axis=split_axis,
+    )
+
+    return [
+        split_dims(inp, shape, split_axis)
+        for inp, shape in zip(split_inputs, packed_shapes, strict=True)
+    ]
+
+
+__all__ = ["join_dims", "pack", "split_dims", "unpack"]
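
A minimal round-trip sketch of the new API, built from the axes=0 example in the pack docstring. It assumes pack and unpack are re-exported under pytensor.tensor (the docstring examples use pt.pack, which suggests this) and that graphs are evaluated with .eval():

    import numpy as np
    import pytensor.tensor as pt

    # Two inputs that share only their first axis (size 2).
    x = pt.tensor("x", shape=(2, 3))
    y = pt.tensor("y", shape=(2, 5, 6))

    # Keep axis 0, ravel the remaining axes of each input, join along axis 1.
    packed, packed_shapes = pt.pack(x, y, axes=0)  # packed has shape (2, 33)

    # Passing the same axes to unpack reverses the operation.
    x2, y2 = pt.unpack(packed, axes=0, packed_shapes=packed_shapes)

    x_val = np.zeros((2, 3))
    y_val = np.ones((2, 5, 6))
    assert x2.eval({x: x_val, y: y_val}).shape == (2, 3)
    assert y2.eval({x: x_val, y: y_val}).shape == (2, 5, 6)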

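The axes rules are easiest to see from the (n_before, n_after, min_axes) triples that the validator returns. A short sketch, assuming direct access to the private helper at pytensor.tensor.reshape._analyze_axes_list as added in this diff:

    from pytensor.tensor.reshape import _analyze_axes_list

    # (n_before, n_after, min_axes) for a few accepted axes values:
    assert _analyze_axes_list(None) == (0, 0, 0)     # ravel everything
    assert _analyze_axes_list((0, 1)) == (2, 0, 2)   # keep the two leading axes
    assert _analyze_axes_list(-1) == (0, 1, 1)       # keep the last axis
    assert _analyze_axes_list((0, -1)) == (1, 1, 2)  # ravel the middle "hole"

    # Positive axes that do not start at 0 are rejected as ambiguous:
    try:
        _analyze_axes_list((1, 2))
    except ValueError as err:
        print(err)  # "...the first positive axis must be 0..."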
0 commit comments