|
1 | | -from collections.abc import Sequence |
| 1 | +from collections.abc import Iterable, Sequence |
| 2 | +from itertools import pairwise |
2 | 3 | from typing import cast as type_cast |
3 | 4 |
|
4 | 5 | import numpy as np |
|
9 | 10 | from pytensor.graph.op import Op |
10 | 11 | from pytensor.graph.replace import _vectorize_node |
11 | 12 | from pytensor.tensor import TensorLike, as_tensor_variable |
12 | | -from pytensor.tensor.basic import infer_static_shape |
| 13 | +from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split |
13 | 14 | from pytensor.tensor.math import prod |
14 | 15 | from pytensor.tensor.shape import ShapeValueType |
15 | 16 | from pytensor.tensor.type import tensor |
@@ -152,12 +153,12 @@ def __init__(self, axis: int): |
152 | 153 | self.axis = axis |
153 | 154 |
|
154 | 155 | def make_node(self, x: Variable, shape: Variable) -> Apply: # type: ignore[override] |
155 | | - if shape.type.numpy_dtype.kind not in "iu": |
156 | | - raise TypeError("shape must be an integer tensor") |
157 | | - |
158 | 156 | x = as_tensor_variable(x) |
159 | 157 | shape = as_tensor_variable(shape, dtype=int, ndim=1) |
160 | 158 |
|
| 159 | + if shape.type.numpy_dtype.kind not in "iu": |
| 160 | + raise TypeError("shape must be an integer tensor") |
| 161 | + |
161 | 162 | axis = self.axis |
162 | 163 | _, constant_shape = infer_static_shape(shape) |
163 | 164 |
|
@@ -261,4 +262,263 @@ def split_dims( |
261 | 262 | return type_cast(TensorVariable, split_op(x, shape)) |
262 | 263 |
|
263 | 264 |
|
264 | | -__all__ = ["join_dims", "split_dims"] |
| 265 | +def _analyze_axes_list(axes: int | Iterable[int] | None) -> tuple[int, int, int]: |
| 266 | + """ |
| 267 | + Analyze the provided axes list to determine how many axes come before and after the interval to be raveled, |
| 268 | + as well as the minimum number of axes that the inputs must have. |
| 269 | +
|
| 270 | + The rules are: |
| 271 | + - Axes must be strictly increasing in both the positive and negative parts of the list. |
| 272 | + - Negative axes must come after positive axes. |
| 273 | + - There can be at most one "hole" in the axes list, which can be either an implicit hole at an endpoint |
| 274 | + (e.g. [0, 1] or [-2, -1]) or an explicit hole in the middle (e.g. [0, -1]). |
| 275 | +
|
| 276 | + Returns |
| 277 | + ------- |
| 278 | + n_axes_before: int |
| 279 | + The number of axes before the interval to be raveled. |
| 280 | + n_axes_after: int |
| 281 | + The number of axes after the interval to be raveled. |
| 282 | + min_axes: int |
| 283 | + The minimum number of axes that the inputs must have. |
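| | +
| | + Examples |
| | + -------- |
| | + For illustration, the (n_axes_before, n_axes_after, min_axes) triples returned for a few representative inputs: |
| | +
| | + .. code-block:: python |
| | +
| | + _analyze_axes_list(None)  # (0, 0, 0): ravel everything |
| | + _analyze_axes_list(0)  # (1, 0, 1): keep the leading axis |
| | + _analyze_axes_list((-2, -1))  # (0, 2, 2): keep the last two axes |
| | + _analyze_axes_list((0, -1))  # (1, 1, 2): keep the first and last axes |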
| 284 | + """ |
| 285 | + if axes is None: |
| 286 | + return 0, 0, 0 |
| 287 | + |
| 288 | + if isinstance(axes, int): |
| 289 | + axes = (axes,) |
| 290 | + elif not isinstance(axes, Iterable): |
| 291 | + raise TypeError("axes must be an int, an iterable of ints, or None") |
| 292 | + |
| 293 | + axes = tuple(axes) |
| 294 | + |
| 295 | + if len(axes) == 0: |
| 296 | + raise ValueError("axes=[] is ambiguous; use None to ravel all") |
| 297 | + |
| 298 | + if len(set(axes)) != len(axes): |
| 299 | + raise ValueError("axes must have no duplicates") |
| 300 | + |
| 301 | + first_negative_idx = next((i for i, a in enumerate(axes) if a < 0), len(axes)) |
| 302 | + positive_axes = list(axes[:first_negative_idx]) |
| 303 | + negative_axes = list(axes[first_negative_idx:]) |
| 304 | + |
| 305 | + if not all(a < 0 for a in negative_axes): |
| 306 | + raise ValueError("Negative axes must come after positive") |
| 307 | + |
| 308 | + def not_strictly_increasing(s): |
| 309 | + if len(s) < 2: |
| 310 | + return False |
| 311 | + return any(b <= a for a, b in pairwise(s)) |
| 312 | + |
| 313 | + if not_strictly_increasing(positive_axes): |
| 314 | + raise ValueError("Axes must be strictly increasing in the positive part") |
| 315 | + if not_strictly_increasing(negative_axes): |
| 316 | + raise ValueError("Axes must be strictly increasing in the negative part") |
| 317 | + |
| 318 | + def find_gaps(s): |
| 319 | + """Find if there are gaps in a strictly increasing sequence.""" |
| 320 | + return any(b - a > 1 for a, b in pairwise(s)) |
| 321 | + |
| 322 | + if find_gaps(positive_axes): |
| 323 | + raise ValueError("Positive axes must be contiguous") |
| 324 | + if find_gaps(negative_axes): |
| 325 | + raise ValueError("Negative axes must be contiguous") |
| 326 | + |
| 327 | + if positive_axes and positive_axes[0] != 0: |
| 328 | + raise ValueError( |
| 329 | + "If positive axes are provided, the first positive axis must be 0 to avoid ambiguity. To ravel indices " |
| 330 | + "starting from the front, use negative axes only." |
| 331 | + ) |
| 332 | + |
| 333 | + if negative_axes and negative_axes[-1] != -1: |
| 334 | + raise ValueError( |
| 335 | + "If negative axes are provided, the last negative axis must be -1 to avoid ambiguity. To ravel indices " |
| 336 | + "up to the end, use positive axes only." |
| 337 | + ) |
| 338 | + |
| 339 | + n_before = len(positive_axes) |
| 340 | + n_after = len(negative_axes) |
| 341 | + min_axes = n_before + n_after |
| 342 | + |
| 343 | + return n_before, n_after, min_axes |
| 344 | + |
| 345 | + |
| 346 | +def pack( |
| 347 | + *tensors: TensorLike, axes: Sequence[int] | int | None = None |
| 348 | +) -> tuple[TensorVariable, list[ShapeValueType]]: |
| 349 | + """ |
| 350 | + Combine multiple tensors by preserving the specified axes and raveling the rest into a single axis. |
| 351 | +
|
| 352 | + Parameters |
| 353 | + ---------- |
| 354 | + *tensors : TensorLike |
| 355 | + Input tensors to be packed. |
| 356 | + axes : int, sequence of int, or None, optional |
| 357 | + Axes to preserve during packing. If None, all axes are raveled. See the Notes section for the rules. |
| 358 | +
|
| 359 | + Returns |
| 360 | + ------- |
| 361 | + packed_tensor : TensorVariable |
| 362 | + The packed tensor with specified axes preserved and others raveled. |
| 363 | + packed_shapes : list of ShapeValueType |
| 364 | + A list containing the shapes of the raveled dimensions for each input tensor. |
| 365 | +
|
| 366 | + Notes |
| 367 | + ----- |
| 368 | + The `axes` parameter determines which axes are preserved during packing. Axes can be specified using positive or |
| 369 | + negative indices, but must follow these rules: |
| 370 | + - If axes is None, all axes are raveled. |
| 371 | + - If a single integer is provided, it is treated as a one-element list: it must be 0 (preserving the |
| 372 | + leading axis) or -1 (preserving the trailing axis). |
| 373 | + - If a list is provided, it can be all positive, all negative, or a combination of positive and negative. |
| 374 | + - Positive axes must be contiguous and start from 0. |
| 375 | + - Negative axes must be contiguous and end at -1. |
| 376 | + - If positive and negative axes are combined, positive axes must come before negative axes, and both 0 and -1 |
| 377 | + must be included. |
| 378 | +
|
| 379 | + Examples |
| 380 | + -------- |
| 381 | + The easiest way to understand pack is through examples. The simplest case is using axes=None, which is equivalent |
| 382 | + to ``join(0, *[t.ravel() for t in tensors])``: |
| 383 | +
|
| 384 | + .. code-block:: python |
| | +
| 385 | + import pytensor.tensor as pt |
| 386 | +
|
| 387 | + x = pt.tensor("x", shape=(2, 3)) |
| 388 | + y = pt.tensor("y", shape=(4, 5, 6)) |
| 389 | +
|
| 390 | + packed_tensor, packed_shapes = pt.pack(x, y, axes=None) |
| 391 | + # packed_tensor has shape (6 + 120,) == (126,) |
| 392 | + # packed_shapes is [(2, 3), (4, 5, 6)] |
| 393 | +
|
| 394 | + If we want to preserve a single axis, we can use either positive or negative indexing. Notice that all tensors |
| 395 | + must have the same size along the preserved axis. For example, using axes=0: |
| 396 | +
|
| 397 | + .. code-block:: python |
| | +
| 398 | + import pytensor.tensor as pt |
| 399 | +
|
| 400 | + x = pt.tensor("x", shape=(2, 3)) |
| 401 | + y = pt.tensor("y", shape=(2, 5, 6)) |
| 402 | + packed_tensor, packed_shapes = pt.pack(x, y, axes=0) |
| 403 | + # packed_tensor has shape (2, 3 + 30) == (2, 33) |
| 404 | + # packed_shapes is [(3,), (5, 6)] |
| 405 | +
|
| 406 | +
|
| 407 | + Using negative indexing we can preserve the last two axes: |
| 408 | +
|
| 409 | + .. code-block:: python |
| | +
| 410 | + import pytensor.tensor as pt |
| 411 | +
|
| 412 | + x = pt.tensor("x", shape=(4, 2, 3)) |
| 413 | + y = pt.tensor("y", shape=(5, 2, 3)) |
| 414 | + packed_tensor, packed_shapes = pt.pack(x, y, axes=(-2, -1)) |
| 415 | + # packed_tensor has shape (4 + 5, 2, 3) == (9, 2, 3) |
| 416 | + # packed_shapes is [(4,), (5,)] |
| 417 | +
|
| 418 | + Or using a mix of positive and negative axes, we can preserve the first and last axes: |
| 419 | +
|
| 420 | + .. code-block:: python |
| | +
| 421 | + import pytensor.tensor as pt |
| 422 | +
|
| 423 | + x = pt.tensor("x", shape=(2, 4, 3)) |
| 424 | + y = pt.tensor("y", shape=(2, 5, 3)) |
| 425 | + packed_tensor, packed_shapes = pt.pack(x, y, axes=(0, -1)) |
| 426 | + # packed_tensor has shape (2, 4 + 5, 3) == (2, 9, 3) |
| 427 | + # packed_shapes is [(4,), (5,)] |
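| | +
| | + An input with exactly the minimum number of dimensions implied by ``axes`` is given an implicit new axis of |
| | + size 1 before joining, so its packed shape is empty (a sketch of the ``n_dim == min_axes`` branch below): |
| | +
| | + .. code-block:: python |
| | +
| | + import pytensor.tensor as pt |
| | +
| | + x = pt.tensor("x", shape=(2,)) |
| | + y = pt.tensor("y", shape=(2, 5)) |
| | + packed_tensor, packed_shapes = pt.pack(x, y, axes=0) |
| | + # packed_tensor has shape (2, 1 + 5) == (2, 6) |
| | + # packed_shapes is [(), (5,)] |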
| 428 | + """ |
| 429 | + tensor_list = [as_tensor_variable(t) for t in tensors] |
| 430 | + |
| 431 | + n_before, n_after, min_axes = _analyze_axes_list(axes) |
| 432 | + |
| 433 | + reshaped_tensors: list[TensorVariable] = [] |
| 434 | + packed_shapes: list[ShapeValueType] = [] |
| 435 | + |
| 436 | + for i, input_tensor in enumerate(tensor_list): |
| 437 | + n_dim = input_tensor.ndim |
| 438 | + |
| 439 | + if n_dim < min_axes: |
| 440 | + raise ValueError( |
| 441 | + f"Input {i} (zero indexed) to pack has {n_dim} dimensions, " |
| 442 | + f"but axes={axes} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}." |
| 443 | + ) |
| 444 | + n_after_packed = n_dim - n_after |
| 445 | + packed_shapes.append(input_tensor.shape[n_before:n_after_packed]) |
| 446 | + |
| 447 | + if n_dim == min_axes: |
| 448 | + # If an input has the minimum number of axes, pack implicitly inserts a new axis based on the pattern |
| 449 | + # implied by the axes. |
| 450 | + input_tensor = expand_dims(input_tensor, axis=n_before) |
| 451 | + reshaped_tensors.append(input_tensor) |
| 452 | + continue |
| 453 | + |
| 454 | + # The reshape we want is (shape[:n_before], -1, shape[n_after_packed:]). join_dims does (shape[:min(axes)], |
| 455 | + # -1, shape[max(axes)+1:]), so passing axes=range(n_before, n_after_packed) gives the desired result. Because |
| 456 | + # of the rules on the axes input, we always have n_before <= n_after_packed - 1. The range also covers the |
| 457 | + # corner case n_before == n_after_packed - 1 (only one axis to ravel, so join_dims does nothing). |
| 458 | + join_axes = range(n_before, n_after_packed) |
| 459 | + joined = join_dims(input_tensor, tuple(join_axes)) |
| 460 | + reshaped_tensors.append(joined) |
| 461 | + |
| 462 | + return join(n_before, *reshaped_tensors), packed_shapes |
| 463 | + |
| 464 | + |
| 465 | +def unpack( |
| 466 | + packed_input: TensorLike, |
| 467 | + axes: int | Sequence[int] | None, |
| 468 | + packed_shapes: list[ShapeValueType], |
| 469 | +) -> list[TensorVariable]: |
| 470 | + """ |
| 471 | + Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping. |
| 472 | +
|
| 473 | + The unpacking process reverses the packing operation, restoring the original shapes of the input tensors. `axes` |
| 474 | + corresponds to the axes that were preserved during packing, and `packed_shapes` contains the shapes of the raveled |
| 475 | + dimensions for each output tensor (that is, the shapes that were destroyed during packing). |
| 476 | +
|
| 477 | + The signature of unpack is such that the same `axes` should be passed to both `pack` and `unpack` to create a |
| 478 | + "round-trip" operation. For details on the rules for `axes`, see the documentation for `pack`. |
| 479 | +
|
| 480 | + Parameters |
| 481 | + ---------- |
| 482 | + packed_input : TensorLike |
| 483 | + The packed tensor to be unpacked. |
| 484 | + axes : int, sequence of int, or None |
| 485 | + Axes that were preserved during packing. If None, the input is assumed to be 1D and axis 0 is used. |
| 486 | + packed_shapes : list of ShapeValueType |
| 487 | + A list containing the shapes of the raveled dimensions for each output tensor. |
| 488 | +
|
| 489 | + Returns |
| 490 | + ------- |
| 491 | + unpacked_tensors : list of TensorVariable |
| 492 | + A list of unpacked tensors with their original shapes restored. |
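| | +
| | + Examples |
| | + -------- |
| | + A round trip through ``pack`` and ``unpack``, passing the same ``axes`` to both (this sketch assumes |
| | + ``unpack`` is importable from ``pytensor.tensor`` alongside ``pack``): |
| | +
| | + .. code-block:: python |
| | +
| | + import pytensor.tensor as pt |
| | +
| | + x = pt.tensor("x", shape=(2, 3)) |
| | + y = pt.tensor("y", shape=(2, 5, 6)) |
| | +
| | + packed_tensor, packed_shapes = pt.pack(x, y, axes=0) |
| | + # packed_tensor has shape (2, 3 + 30) == (2, 33) |
| | +
| | + new_x, new_y = pt.unpack(packed_tensor, axes=0, packed_shapes=packed_shapes) |
| | + # new_x has shape (2, 3); new_y has shape (2, 5, 6) |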
| 493 | + """ |
| 494 | + packed_input = as_tensor_variable(packed_input) |
| 495 | + |
| 496 | + if axes is None: |
| 497 | + if packed_input.ndim != 1: |
| 498 | + raise ValueError( |
| 499 | + "unpack can only be called with keep_axis=None for 1d inputs" |
| 500 | + ) |
| 501 | + split_axis = 0 |
| 502 | + else: |
| 503 | + axes = normalize_axis_tuple(axes, ndim=packed_input.ndim) |
| 504 | + try: |
| 505 | + [split_axis] = (i for i in range(packed_input.ndim) if i not in axes) |
| 506 | + except ValueError as err: |
| 507 | + raise ValueError( |
| 508 | + "Unpack must have exactly one more dimension that implied by axes" |
| 509 | + ) from err |
| 510 | + |
| 511 | + split_inputs = split( |
| 512 | + packed_input, |
| 513 | + splits_size=[prod(shape, dtype=int) for shape in packed_shapes], |
| 514 | + n_splits=len(packed_shapes), |
| 515 | + axis=split_axis, |
| 516 | + ) |
| 517 | + |
| 518 | + return [ |
| 519 | + split_dims(inp, shape, split_axis) |
| 520 | + for inp, shape in zip(split_inputs, packed_shapes, strict=True) |
| 521 | + ] |
| 522 | + |
| 523 | + |
| 524 | +__all__ = ["join_dims", "pack", "split_dims", "unpack"] |