Skip to content

Labeled tensors #1411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ jobs:
install-numba: [0]
install-jax: [0]
install-torch: [0]
install-xarray: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests/scan"
Expand Down Expand Up @@ -115,6 +116,7 @@ jobs:
install-numba: 0
install-jax: 0
install-torch: 0
install-xarray: 0
- install-numba: 1
os: "ubuntu-latest"
python-version: "3.10"
Expand Down Expand Up @@ -150,6 +152,13 @@ jobs:
fast-compile: 0
float32: 0
part: "tests/link/pytorch"
- install-xarray: 1
os: "ubuntu-latest"
python-version: "3.13"
numpy-version: ">=2.0"
fast-compile: 0
float32: 0
part: "tests/xtensor"
- os: macos-15
python-version: "3.13"
numpy-version: ">=2.0"
Expand Down Expand Up @@ -196,6 +205,7 @@ jobs:
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
pip install pytest-sphinx

pip install -e ./
Expand All @@ -212,6 +222,7 @@ jobs:
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}}
INSTALL_XARRAY: ${{ matrix.install-xarray }}
OS: ${{ matrix.os}}

- name: Run tests
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4551,7 +4551,7 @@ def ix_(*args):
new = as_tensor(new)
if new.ndim != 1:
raise ValueError("Cross index must be 1 dimensional")
new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
new = new.dimshuffle(*(("x",) * k), 0, *(("x",) * (nd - k - 1)))
out.append(new)
return tuple(out)

Expand Down
18 changes: 0 additions & 18 deletions pytensor/tensor/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,24 +473,6 @@ def cumprod(x, axis=None):
return CumOp(axis=axis, mode="mul")(x)


class CumsumOp(Op):
__props__ = ("axis",)

def __new__(typ, *args, **kwargs):
obj = object.__new__(CumOp, *args, **kwargs)
obj.mode = "add"
return obj


class CumprodOp(Op):
__props__ = ("axis",)

def __new__(typ, *args, **kwargs):
obj = object.__new__(CumOp, *args, **kwargs)
obj.mode = "mul"
return obj


def diff(x, n=1, axis=-1):
"""Calculate the `n`-th order discrete difference along the given `axis`.

Expand Down
7 changes: 1 addition & 6 deletions pytensor/tensor/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3021,12 +3021,7 @@ def make_node(self, x, y, *inputs):
return Apply(
self,
(x, y, *new_inputs),
[
tensor(
dtype=x.type.dtype,
shape=tuple(1 if s == 1 else None for s in x.type.shape),
)
],
[x.type()],
)

def perform(self, node, inputs, out_):
Expand Down
16 changes: 16 additions & 0 deletions pytensor/xtensor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import warnings

# Imported for its side effects only: loading the rewriting module is what
# makes the xtensor -> tensor lowering rewrites available — TODO confirm
# registration happens at import time in pytensor.xtensor.rewriting.
import pytensor.xtensor.rewriting
from pytensor.xtensor import (
    linalg,
    special,
)
from pytensor.xtensor.shape import concat
from pytensor.xtensor.type import (
    as_xtensor,
    xtensor,
    xtensor_constant,
)


# Emitted once when the module is first imported, so users know the
# xtensor API is unstable.
warnings.warn("xtensor module is experimental and full of bugs")
104 changes: 104 additions & 0 deletions pytensor/xtensor/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from collections.abc import Sequence

from pytensor.compile import ViewOp
from pytensor.graph import Apply, Op
from pytensor.link.c.op import COp
from pytensor.tensor.type import TensorType
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


class XOp(Op):
    """Base class for symbolic xtensor operations.

    Subclasses are graph placeholders only: they must be lowered
    (rewritten into equivalent plain-tensor operations) before the graph
    can be evaluated, so `perform` always raises.
    """

    def perform(self, node, inputs, outputs):
        # If we get here, the lowering rewrites did not run / did not
        # replace this node — evaluation is impossible by design.
        raise NotImplementedError(
            f"xtensor operation {self} must be lowered to equivalent tensor operations"
        )


class XTypeCastOp(COp):
    """Base class for Ops that type cast between TensorType and XTensorType.

    Like a `ViewOp`, the output aliases the input buffer (see `view_map`),
    but unlike `ViewOp` the input and output types are allowed to differ.
    """

    # Output 0 is declared as a view of input 0 — no data is copied.
    view_map = {0: [0]}

    def perform(self, node, inputs, output_storage):
        # Pure pass-through of the underlying value.
        output_storage[0][0] = inputs[0]

    def c_code(self, node, nodename, inp, out, sub):
        # Reuse ViewOp's C implementation for plain tensors. The C template
        # is filled in via locals(), so the names `iname`, `oname` and
        # `fail` must be exactly these.
        fail = sub["fail"]
        (iname,) = inp
        (oname,) = out
        code, _ = ViewOp.c_code_and_version[TensorType]
        return code % locals()

    def c_code_cache_version(self):
        # Track ViewOp's version so cached C code is invalidated together.
        _, version = ViewOp.c_code_and_version[TensorType]
        return (version,)


class TensorFromXTensor(XTypeCastOp):
    """Strip dimension names from an XTensor, yielding a plain tensor view.

    The output shares dtype and static shape with the input; only the
    dimension labels are dropped.
    """

    __props__ = ()

    def make_node(self, x):
        if not isinstance(x.type, XTensorType):
            # Grammar fix: was "x must be have an XTensorType".
            raise TypeError(f"x must have an XTensorType, got {type(x.type)}")
        output = TensorType(x.type.dtype, shape=x.type.shape)()
        return Apply(self, [x], [output])


# Stateless Op (__props__ == ()), so a single shared instance suffices.
tensor_from_xtensor = TensorFromXTensor()


class XTensorFromTensor(XTypeCastOp):
    """Attach dimension names to a plain tensor, yielding an XTensor view.

    Parameters
    ----------
    dims
        Dimension names, one per axis of the input tensor.
    """

    __props__ = ("dims",)

    def __init__(self, dims: Sequence[str]):
        super().__init__()
        # Stored as a tuple so the Op is hashable and __props__ equality works.
        self.dims = tuple(dims)

    def make_node(self, x):
        if not isinstance(x.type, TensorType):
            # Grammar fix: was "x must be an TensorType type".
            raise TypeError(f"x must be a TensorType, got {type(x.type)}")
        output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
        return Apply(self, [x], [output])


def xtensor_from_tensor(x, dims):
    """Wrap tensor `x` into an XTensor labelled with `dims`."""
    return XTensorFromTensor(dims=dims)(x)


class Rename(XTypeCastOp):
    """Return a view of an XTensor with its dimensions relabelled.

    Only the dimension names change; dtype and shape are preserved via
    `type.clone`, and the data buffer is shared (view semantics inherited
    from `XTypeCastOp`).
    """

    __props__ = ("new_dims",)

    def __init__(self, new_dims: tuple[str, ...]):
        super().__init__()
        # NOTE(review): not coerced with tuple(...) here, unlike
        # XTensorFromTensor.dims — callers must pass a tuple for
        # __props__ hashing to behave; confirm intended.
        self.new_dims = new_dims

    def make_node(self, x):
        x = as_xtensor(x)
        # Same type as the input except for the dimension names.
        output = x.type.clone(dims=self.new_dims)()
        return Apply(self, [x], [output])


def rename(x, name_dict: dict[str, str] | None = None, **names: str):
    """Rename dimensions of an XTensor.

    Parameters
    ----------
    x
        Input converted with `as_xtensor`.
    name_dict
        Mapping of old dimension names to new ones. Mutually exclusive
        with keyword names.
    **names
        ``old_name=new_name`` pairs, as an alternative to `name_dict`.

    Returns
    -------
    XTensor variable with the renamed dimensions.

    Raises
    ------
    ValueError
        If both `name_dict` and keyword names are given, or if an old
        name is not among the input's dimensions.
    """
    if name_dict is not None:
        if names:
            raise ValueError("Cannot use both positional and keyword names in rename")
        names = name_dict

    x = as_xtensor(x)
    old_names = x.type.dims
    new_names = list(old_names)
    for old_name, new_name in names.items():
        try:
            new_names[old_names.index(old_name)] = new_name
        except ValueError:
            # Bug fix: list/tuple .index raises ValueError (not IndexError)
            # for a missing item, so the informative message below was
            # previously unreachable and users got a bare "not in list".
            raise ValueError(
                f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
            ) from None

    return Rename(tuple(new_names))(x)
Loading