Skip to content

Adding dot for xtensor #1450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: labeled_tensors
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 116 additions & 2 deletions pytensor/xtensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

import pytensor.scalar as ps
from pytensor import config
from pytensor.graph.basic import Apply
from pytensor.scalar import ScalarOp
from pytensor.scalar.basic import _cast_mapping
from pytensor.xtensor.basic import as_xtensor
from pytensor.scalar.basic import _cast_mapping, upcast
from pytensor.xtensor.basic import XOp, as_xtensor
from pytensor.xtensor.type import xtensor
from pytensor.xtensor.vectorization import XElemwise


Expand Down Expand Up @@ -134,3 +136,115 @@ def cast(x, dtype):
if dtype not in _xelemwise_cast_op:
_xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype])
return _xelemwise_cast_op[dtype](x)


class XDot(XOp):
"""Matrix multiplication between two XTensorVariables.

This operation performs matrix multiplication between two tensors, automatically
aligning and contracting dimensions. The behavior matches xarray's dot operation.

Parameters
----------
dims : tuple of str
The dimensions to contract over. If None, will contract over all matching dimensions.
"""

__props__ = ("dims",)

def __init__(self, dims: tuple[str, ...] | None = None):
self.dims = dims
super().__init__()

def make_node(self, x, y):
x = as_xtensor(x)
y = as_xtensor(y)

# Get dimensions to contract
if self.dims is None:
# Contract over all matching dimensions
x_dims = set(x.type.dims)
y_dims = set(y.type.dims)
contract_dims = tuple(x_dims & y_dims)
else:
contract_dims = self.dims

# Determine output dimensions and shapes
x_dims = list(x.type.dims)
y_dims = list(y.type.dims)
x_shape = list(x.type.shape)
y_shape = list(y.type.shape)

# Remove contracted dimensions
for dim in contract_dims:
x_idx = x_dims.index(dim)
y_idx = y_dims.index(dim)
x_dims.pop(x_idx)
y_dims.pop(y_idx)
x_shape.pop(x_idx)
y_shape.pop(y_idx)

# Combine remaining dimensions
out_dims = tuple(x_dims + y_dims)
out_shape = tuple(x_shape + y_shape)

# Determine output dtype
out_dtype = upcast(x.type.dtype, y.type.dtype)

out = xtensor(dtype=out_dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, y], [out])


def dot(x, y, dims: tuple[str, ...] | None = None):
"""Matrix multiplication between two XTensorVariables.

This operation performs matrix multiplication between two tensors, automatically
aligning and contracting dimensions. The behavior matches xarray's dot operation.

Parameters
----------
x : XTensorVariable
First input tensor
y : XTensorVariable
Second input tensor
dims : tuple of str, optional
The dimensions to contract over. If None, will contract over all matching dimensions.

Returns
-------
XTensorVariable
The result of the matrix multiplication.

Examples
--------
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
>>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4))
>>> z = dot(x, y) # Result has dimensions ("a", "c")
"""
x = as_xtensor(x)
y = as_xtensor(y)

# Validate dimensions if specified
if dims is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These checks are better placed in make_node

if not isinstance(dims, tuple):
dims = tuple(dims)
for dim in dims:
if dim not in x.type.dims:
raise ValueError(
f"Dimension {dim} not found in first input {x.type.dims}"
)
if dim not in y.type.dims:
raise ValueError(
f"Dimension {dim} not found in second input {y.type.dims}"
)
# Check for compatible shapes in contracted dimensions
x_idx = x.type.dims.index(dim)
y_idx = y.type.dims.index(dim)
x_size = x.type.shape[x_idx]
y_size = y.type.shape[y_idx]
if x_size is not None and y_size is not None and x_size != y_size:
raise ValueError(
f"Dimension {dim} has incompatible shapes: {x_size} and {y_size}"
)

return XDot(dims=dims)(x, y)
1 change: 1 addition & 0 deletions pytensor/xtensor/rewriting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytensor.xtensor.rewriting.basic
import pytensor.xtensor.rewriting.indexing
import pytensor.xtensor.rewriting.math
import pytensor.xtensor.rewriting.reduction
import pytensor.xtensor.rewriting.shape
import pytensor.xtensor.rewriting.vectorization
40 changes: 40 additions & 0 deletions pytensor/xtensor/rewriting/math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from pytensor.graph import node_rewriter
from pytensor.tensor import tensordot
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.math import XDot
from pytensor.xtensor.rewriting.utils import register_lower_xtensor


@register_lower_xtensor
@node_rewriter(tracks=[XDot])
def lower_dot(fgraph, node):
"""Rewrite XDot to tensor.dot.

This rewrite converts an XDot operation to a tensor-based dot operation,
handling dimension alignment and contraction.
"""
[x, y] = node.inputs
[out] = node.outputs

# Convert inputs to tensors
x_tensor = tensor_from_xtensor(x)
y_tensor = tensor_from_xtensor(y)

# Get dimensions to contract
if node.op.dims is None:
# Contract over all matching dimensions
x_dims = set(x.type.dims)
y_dims = set(y.type.dims)
contract_dims = tuple(x_dims & y_dims)
else:
contract_dims = node.op.dims

# Get axes to contract for each input
x_axes = [x.type.dims.index(dim) for dim in contract_dims]
y_axes = [y.type.dims.index(dim) for dim in contract_dims]

# Perform dot product
out_tensor = tensordot(x_tensor, y_tensor, axes=(x_axes, y_axes))

# Convert back to xtensor
return [xtensor_from_tensor(out_tensor, out.type.dims)]
4 changes: 4 additions & 0 deletions pytensor/xtensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,10 @@ def stack(self, dim, **dims):
def unstack(self, dim, **dims):
return px.shape.unstack(self, dim, **dims)

def dot(self, other, dims=None):
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
return px.math.dot(self, other, dims=dims)


class XTensorConstantSignature(tuple):
def __eq__(self, other):
Expand Down
31 changes: 31 additions & 0 deletions tests/xtensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,34 @@ def test_cast():
yc64 = x.astype("complex64")
with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"):
yc64.astype("float64")


def test_dot():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks promising, but needs some tests calling dot with specific dims?

"""Test basic dot product operations."""
# Test matrix-matrix dot product
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
z = x.dot(y)
assert z.type.dims == ("a", "c")
assert z.type.shape == (2, 4)

fn = xr_function([x, y], z)
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
y_test = DataArray(np.ones((3, 4)), dims=("b", "c"))
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test)
xr_assert_allclose(z_test, expected)

# Test matrix-vector dot product
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b",), shape=(3,))
z = x.dot(y)
assert z.type.dims == ("a",)
assert z.type.shape == (2,)

fn = xr_function([x, y], z)
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
y_test = DataArray(np.ones(3), dims=("b",))
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test)
xr_assert_allclose(z_test, expected)
Loading