Skip to content

Commit

Permalink
Tensor: align and broadcast binary ops
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Feb 20, 2025
1 parent f3fd72f commit 6fa9dfa
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 60 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Release notes for `quimb`.

**Enhancements:**

- [`Tensor`](quimb.tensor.tensor_core.Tensor): make binary operations (`+, -, *, /, **`) automatically align and broadcast indices. This would previously error.
- [`MatrixProductState.measure`](quimb.tensor.tensor_1d.MatrixProductState.measure): add a `seed` kwarg
- belief propagation, implement DIIS (direct inversion in the iterative subspace)
- belief propagation, unify various aspects such as message normalization and distance.
Expand Down
106 changes: 56 additions & 50 deletions docs/tensor-basics.ipynb

Large diffs are not rendered by default.

36 changes: 28 additions & 8 deletions quimb/tensor/tensor_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3300,15 +3300,35 @@ def COPY_tree_tensors(d, inds, tags=None, dtype=float, ssa_path=None):
def _make_promote_array_func(op, meth_name):
@functools.wraps(getattr(np.ndarray, meth_name))
def _promote_array_func(self, other):
"""Use standard array func, but make sure Tensor inds match."""
"""Use standard array func, but auto match up indices."""
if isinstance(other, Tensor):
if set(self.inds) != set(other.inds):
raise ValueError(
"The indicies of these two tensors do not "
f"match: {self.inds} != {other.inds}"
)

otherT = other.transpose(*self.inds)
# auto match up indices - i.e. broadcast dimensions
left_expand = []
right_expand = []

for ix in self.inds:
if ix not in other.inds:
right_expand.append(ix)
for ix in other.inds:
if ix not in self.inds:
left_expand.append(ix)

# new_ind is an inplace operation -> track if we need to copy
copied = False
for ix in left_expand:
if not copied:
self = self.copy()
copied = True
self.new_ind(ix, axis=-1)

copied = False
for ix in right_expand:
if not copied:
other = other.copy()
copied = True
other.new_ind(ix)

otherT = other.transpose(*self.inds, inplace=copied)

return Tensor(
data=op(self.data, otherT.data),
Expand Down
7 changes: 5 additions & 2 deletions tests/test_tensor/test_tensor_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,11 @@ def test_tensor_tensor_arithmetic(self, op, mismatch):
b = Tensor(np.random.rand(2, 3, 4), inds=[0, 1, 2], tags="red")
if mismatch:
b.modify(inds=(0, 1, 3))
with pytest.raises(ValueError):
op(a, b)
c = op(a, b)
assert_allclose(c.data, op(
a.data.reshape(2, 3, 4, 1),
b.data.reshape(2, 3, 1, 4))
)
else:
c = op(a, b)
assert_allclose(c.data, op(a.data, b.data))
Expand Down

0 comments on commit 6fa9dfa

Please sign in to comment.