diff --git a/docs/changelog.md b/docs/changelog.md
index a37539e3..f642e1e1 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -11,6 +11,7 @@ Release notes for `quimb`.
**Enhancements:**
+- [`Tensor`](quimb.tensor.tensor_core.Tensor): make binary operations (`+, -, *, /, **`) automatically align and broadcast indices. This would previously error.
- [`MatrixProductState.measure`](quimb.tensor.tensor_1d.MatrixProductState.measure): add a `seed` kwarg
- belief propagation, implement DIIS (direct inversion in the iterative subspace)
- belief propagation, unify various aspects such as message normalization and distance.
diff --git a/docs/tensor-basics.ipynb b/docs/tensor-basics.ipynb
index 2581fb05..cf236b41 100644
--- a/docs/tensor-basics.ipynb
+++ b/docs/tensor-basics.ipynb
@@ -129,6 +129,7 @@
"\n",
"1. Methods manipulating dimensions use the names of indices, thus provided these are labelled correctly, their specific permutation doesn't matter.\n",
"2. The underlying `data` object can be anything that [autoray](https://github.com/jcmgray/autoray) supports - for example a symbolic or GPU array. These are passed around by reference and assumed to be immutable.\n",
+ "3. Binary operations such as `+` and `*` automatically align and broadcast the indices.\n",
"\n",
"Some common [`Tensor`](quimb.tensor.tensor_core.Tensor) methods are:\n",
"\n",
@@ -220,21 +221,26 @@
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "TensorNetwork([\n",
- " Tensor(shape=(2, 2), inds=('k0', 'k1'), tags=oset(['KET'])),\n",
- " Tensor(shape=(2, 2), inds=('k0', 'b0'), tags=oset(['PAULI', 'X', '0'])),\n",
- " Tensor(shape=(2, 2), inds=('k1', 'b1'), tags=oset(['PAULI', 'Y', '1'])),\n",
- " Tensor(shape=(2, 2), inds=('b0', 'b1'), tags=oset(['BRA'])),\n",
- "], tensors=4, indices=4)\n"
- ]
+ "data": {
+ "text/html": [
+ "TensorNetwork(tensors=4, indices=4)
Tensor(shape=(2, 2), inds=[k0, k1], tags={KET}),
backend=numpy, dtype=complex128, data=array([[ 0. -0.j, 0.70710678-0.j],\n",
+ " [-0.70710678-0.j, 0. -0.j]])Tensor(shape=(2, 2), inds=[k0, b0], tags={PAULI, X, 0}),
backend=numpy, dtype=complex128, data=array([[0.+0.j, 1.+0.j],\n",
+ " [1.+0.j, 0.+0.j]])Tensor(shape=(2, 2), inds=[k1, b1], tags={PAULI, Y, 1}),
backend=numpy, dtype=complex128, data=array([[ 0.+0.j, -0.-1.j],\n",
+ " [ 0.+1.j, 0.+0.j]])Tensor(shape=(2, 2), inds=[b0, b1], tags={BRA}),
backend=numpy, dtype=complex128, data=array([[ 0.71586105-0.4477296j , -0.03812291-0.08236035j],\n",
+ " [ 0.16472649-0.39243288j, -0.0464581 -0.30910814j]]) "
+ ],
+ "text/plain": [
+ "TensorNetwork(tensors=4, indices=4)"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
"TN = ket.H & X & Y & bra\n",
- "print(TN)"
+ "TN"
]
},
{
@@ -346,7 +352,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -424,7 +430,7 @@
{
"data": {
"text/plain": [
- "(0.556024406455527+0.28213663229389124j)"
+ "(0.33572951058280465+0.08952224604844464j)"
]
},
"execution_count": 9,
@@ -454,7 +460,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -509,8 +515,8 @@
{
"data": {
"text/html": [
- "Tensor(shape=(2, 4), inds=[a, b], tags={A, B}),
backend=numpy, dtype=float64, data=array([[-1.07482966, 1.40801906, 0.97244287, 2.06712511],\n",
- " [-1.19598176, -1.1037286 , -0.15011799, 1.1052763 ]]) "
+ "Tensor(shape=(2, 4), inds=[a, b], tags={A, B}),
backend=numpy, dtype=float64, data=array([[ 1.70457424, 0.92376522, -2.71552006, 0.73727999],\n",
+ " [-0.14605588, -0.22119553, -0.45599571, -0.43772905]]) "
],
"text/plain": [
"Tensor(shape=(2, 4), inds=('a', 'b'), tags=oset(['A', 'B']))"
@@ -550,7 +556,7 @@
{
"data": {
"text/plain": [
- "0.9999999999999999"
+ "1.0000000000000002"
]
},
"execution_count": 12,
@@ -618,7 +624,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -663,7 +669,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -689,7 +695,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -759,7 +765,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -820,7 +826,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -847,7 +853,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -904,7 +910,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -983,7 +989,7 @@
{
"data": {
"text/plain": [
- "array([2900.03154824+0.j])"
+ "array([3963.85599485+0.j])"
]
},
"execution_count": 22,
@@ -1017,7 +1023,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -1096,7 +1102,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -1130,7 +1136,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -1164,7 +1170,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -1198,7 +1204,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -1243,10 +1249,10 @@
{
"data": {
"text/html": [
- "Tensor(shape=(4, 4, 4, 4), inds=[_835204AAABZ, _835204AAABb, _835204AAABc, _835204AAABT], tags={I2,3, X2, Y3, ALL, ANY}),
backend=numpy, dtype=float64, data=... "
+ "Tensor(shape=(4, 4, 4, 4), inds=[_91c677AAABZ, _91c677AAABb, _91c677AAABc, _91c677AAABT], tags={I2,3, X2, Y3, ALL, ANY}),
backend=numpy, dtype=float64, data=... "
],
"text/plain": [
- "Tensor(shape=(4, 4, 4, 4), inds=('_835204AAABZ', '_835204AAABb', '_835204AAABc', '_835204AAABT'), tags=oset(['I2,3', 'X2', 'Y3', 'ALL', 'ANY']))"
+ "Tensor(shape=(4, 4, 4, 4), inds=('_91c677AAABZ', '_91c677AAABb', '_91c677AAABc', '_91c677AAABT'), tags=oset(['I2,3', 'X2', 'Y3', 'ALL', 'ANY']))"
]
},
"execution_count": 29,
@@ -1275,10 +1281,10 @@
{
"data": {
"text/html": [
- "Tensor(shape=(4, 4, 4, 4), inds=[_835204AAABZ, _835204AAABb, _835204AAABc, _835204AAABT], tags={I2,3, X2, Y3, ALL, ANY}),
backend=numpy, dtype=float64, data=... "
+ "Tensor(shape=(4, 4, 4, 4), inds=[_91c677AAABZ, _91c677AAABb, _91c677AAABc, _91c677AAABT], tags={I2,3, X2, Y3, ALL, ANY}),
backend=numpy, dtype=float64, data=... "
],
"text/plain": [
- "Tensor(shape=(4, 4, 4, 4), inds=('_835204AAABZ', '_835204AAABb', '_835204AAABc', '_835204AAABT'), tags=oset(['I2,3', 'X2', 'Y3', 'ALL', 'ANY']))"
+ "Tensor(shape=(4, 4, 4, 4), inds=('_91c677AAABZ', '_91c677AAABb', '_91c677AAABc', '_91c677AAABT'), tags=oset(['I2,3', 'X2', 'Y3', 'ALL', 'ANY']))"
]
},
"execution_count": 30,
@@ -1395,7 +1401,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -1421,7 +1427,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -1450,18 +1456,18 @@
{
"data": {
"text/plain": [
- "('_835204AAACJ',\n",
- " '_835204AAACK',\n",
- " '_835204AAACL',\n",
- " '_835204AAACM',\n",
- " '_835204AAACN',\n",
- " '_835204AAACO',\n",
- " '_835204AAACP',\n",
- " '_835204AAACQ',\n",
- " '_835204AAACR',\n",
- " '_835204AAACS',\n",
- " '_835204AAACT',\n",
- " '_835204AAACU')"
+ "('_91c677AAACJ',\n",
+ " '_91c677AAACK',\n",
+ " '_91c677AAACL',\n",
+ " '_91c677AAACM',\n",
+ " '_91c677AAACN',\n",
+ " '_91c677AAACO',\n",
+ " '_91c677AAACP',\n",
+ " '_91c677AAACQ',\n",
+ " '_91c677AAACR',\n",
+ " '_91c677AAACS',\n",
+ " '_91c677AAACT',\n",
+ " '_91c677AAACU')"
]
},
"execution_count": 35,
@@ -1489,7 +1495,7 @@
"metadata": {
"celltoolbar": "Raw Cell Format",
"kernelspec": {
- "display_name": "py312",
+ "display_name": "Python 3",
"language": "python",
"name": "python3"
},
@@ -1503,7 +1509,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.6"
+ "version": "3"
}
},
"nbformat": 4,
diff --git a/quimb/tensor/tensor_core.py b/quimb/tensor/tensor_core.py
index f49c9391..a22a9859 100644
--- a/quimb/tensor/tensor_core.py
+++ b/quimb/tensor/tensor_core.py
@@ -3300,15 +3300,35 @@ def COPY_tree_tensors(d, inds, tags=None, dtype=float, ssa_path=None):
def _make_promote_array_func(op, meth_name):
@functools.wraps(getattr(np.ndarray, meth_name))
def _promote_array_func(self, other):
- """Use standard array func, but make sure Tensor inds match."""
+ """Use standard array func, but auto match up indices."""
if isinstance(other, Tensor):
- if set(self.inds) != set(other.inds):
- raise ValueError(
- "The indicies of these two tensors do not "
- f"match: {self.inds} != {other.inds}"
- )
-
- otherT = other.transpose(*self.inds)
+ # auto match up indices - i.e. broadcast dimensions
+ left_expand = []
+ right_expand = []
+
+ for ix in self.inds:
+ if ix not in other.inds:
+ right_expand.append(ix)
+ for ix in other.inds:
+ if ix not in self.inds:
+ left_expand.append(ix)
+
+ # new_ind is an inplace operation -> track if we need to copy
+ copied = False
+ for ix in left_expand:
+ if not copied:
+ self = self.copy()
+ copied = True
+ self.new_ind(ix, axis=-1)
+
+ copied = False
+ for ix in right_expand:
+ if not copied:
+ other = other.copy()
+ copied = True
+ other.new_ind(ix)
+
+ otherT = other.transpose(*self.inds, inplace=copied)
return Tensor(
data=op(self.data, otherT.data),
diff --git a/tests/test_tensor/test_tensor_core.py b/tests/test_tensor/test_tensor_core.py
index c377ce4c..b53f9fda 100644
--- a/tests/test_tensor/test_tensor_core.py
+++ b/tests/test_tensor/test_tensor_core.py
@@ -126,8 +126,11 @@ def test_tensor_tensor_arithmetic(self, op, mismatch):
b = Tensor(np.random.rand(2, 3, 4), inds=[0, 1, 2], tags="red")
if mismatch:
b.modify(inds=(0, 1, 3))
- with pytest.raises(ValueError):
- op(a, b)
+ c = op(a, b)
+ assert_allclose(c.data, op(
+ a.data.reshape(2, 3, 4, 1),
+ b.data.reshape(2, 3, 1, 4))
+ )
else:
c = op(a, b)
assert_allclose(c.data, op(a.data, b.data))