diff --git a/docs/changelog.md b/docs/changelog.md index a37539e3..f642e1e1 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -11,6 +11,7 @@ Release notes for `quimb`. **Enhancements:** +- [`Tensor`](quimb.tensor.tensor_core.Tensor): make binary operations (`+, -, *, /, **`) automatically align and broadcast indices. This would previously error. - [`MatrixProductState.measure`](quimb.tensor.tensor_1d.MatrixProductState.measure): add a `seed` kwarg - belief propagation, implement DIIS (direct inversion in the iterative subspace) - belief propagation, unify various aspects such as message normalization and distance. diff --git a/docs/tensor-basics.ipynb b/docs/tensor-basics.ipynb index 2581fb05..cf236b41 100644 --- a/docs/tensor-basics.ipynb +++ b/docs/tensor-basics.ipynb @@ -129,6 +129,7 @@ "\n", "1. Methods manipulating dimensions use the names of indices, thus provided these are labelled correctly, their specific permutation doesn't matter.\n", "2. The underlying `data` object can be anything that [autoray](https://github.com/jcmgray/autoray) supports - for example a symbolic or GPU array. These are passed around by reference and assumed to be immutable.\n", + "3. Binary operations such as `+` and `*` automatically align and broadcast the indices.\n", "\n", "Some common [`Tensor`](quimb.tensor.tensor_core.Tensor) methods are:\n", "\n", @@ -220,21 +221,26 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "TensorNetwork([\n", - " Tensor(shape=(2, 2), inds=('k0', 'k1'), tags=oset(['KET'])),\n", - " Tensor(shape=(2, 2), inds=('k0', 'b0'), tags=oset(['PAULI', 'X', '0'])),\n", - " Tensor(shape=(2, 2), inds=('k1', 'b1'), tags=oset(['PAULI', 'Y', '1'])),\n", - " Tensor(shape=(2, 2), inds=('b0', 'b1'), tags=oset(['BRA'])),\n", - "], tensors=4, indices=4)\n" - ] + "data": { + "text/html": [ + "
TensorNetwork(tensors=4, indices=4)
Tensor(shape=(2, 2), inds=[k0, k1], tags={KET}),backend=numpy, dtype=complex128, data=array([[ 0. -0.j, 0.70710678-0.j],\n", + " [-0.70710678-0.j, 0. -0.j]])
Tensor(shape=(2, 2), inds=[k0, b0], tags={PAULI, X, 0}),backend=numpy, dtype=complex128, data=array([[0.+0.j, 1.+0.j],\n", + " [1.+0.j, 0.+0.j]])
Tensor(shape=(2, 2), inds=[k1, b1], tags={PAULI, Y, 1}),backend=numpy, dtype=complex128, data=array([[ 0.+0.j, -0.-1.j],\n", + " [ 0.+1.j, 0.+0.j]])
Tensor(shape=(2, 2), inds=[b0, b1], tags={BRA}),backend=numpy, dtype=complex128, data=array([[ 0.71586105-0.4477296j , -0.03812291-0.08236035j],\n", + " [ 0.16472649-0.39243288j, -0.0464581 -0.30910814j]])
" + ], + "text/plain": [ + "TensorNetwork(tensors=4, indices=4)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "TN = ket.H & X & Y & bra\n", - "print(TN)" + "TN" ] }, { @@ -346,7 +352,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -424,7 +430,7 @@ { "data": { "text/plain": [ - "(0.556024406455527+0.28213663229389124j)" + "(0.33572951058280465+0.08952224604844464j)" ] }, "execution_count": 9, @@ -454,7 +460,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -509,8 +515,8 @@ { "data": { "text/html": [ - "
Tensor(shape=(2, 4), inds=[a, b], tags={A, B}),backend=numpy, dtype=float64, data=array([[-1.07482966, 1.40801906, 0.97244287, 2.06712511],\n", - " [-1.19598176, -1.1037286 , -0.15011799, 1.1052763 ]])
" + "
Tensor(shape=(2, 4), inds=[a, b], tags={A, B}),backend=numpy, dtype=float64, data=array([[ 1.70457424, 0.92376522, -2.71552006, 0.73727999],\n", + " [-0.14605588, -0.22119553, -0.45599571, -0.43772905]])
" ], "text/plain": [ "Tensor(shape=(2, 4), inds=('a', 'b'), tags=oset(['A', 'B']))" @@ -550,7 +556,7 @@ { "data": { "text/plain": [ - "0.9999999999999999" + "1.0000000000000002" ] }, "execution_count": 12, @@ -618,7 +624,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -663,7 +669,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -689,7 +695,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -759,7 +765,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -820,7 +826,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -847,7 +853,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -904,7 +910,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -983,7 +989,7 @@ { "data": { "text/plain": [ - "array([2900.03154824+0.j])" + "array([3963.85599485+0.j])" ] }, "execution_count": 22, @@ -1017,7 +1023,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -1096,7 +1102,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -1130,7 +1136,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -1164,7 +1170,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -1198,7 +1204,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -1243,10 +1249,10 @@ { "data": { "text/html": [ - "
Tensor(shape=(4, 4, 4, 4), inds=[_835204AAABZ, _835204AAABb, _835204AAABc, _835204AAABT], tags={I2,3, X2, Y3, ALL, ANY}),backend=numpy, dtype=float64, data=...
" + "
Tensor(shape=(4, 4, 4, 4), inds=[_91c677AAABZ, _91c677AAABb, _91c677AAABc, _91c677AAABT], tags={I2,3, X2, Y3, ALL, ANY}),backend=numpy, dtype=float64, data=...
" ], "text/plain": [ - "Tensor(shape=(4, 4, 4, 4), inds=('_835204AAABZ', '_835204AAABb', '_835204AAABc', '_835204AAABT'), tags=oset(['I2,3', 'X2', 'Y3', 'ALL', 'ANY']))" + "Tensor(shape=(4, 4, 4, 4), inds=('_91c677AAABZ', '_91c677AAABb', '_91c677AAABc', '_91c677AAABT'), tags=oset(['I2,3', 'X2', 'Y3', 'ALL', 'ANY']))" ] }, "execution_count": 29, @@ -1275,10 +1281,10 @@ { "data": { "text/html": [ - "
Tensor(shape=(4, 4, 4, 4), inds=[_835204AAABZ, _835204AAABb, _835204AAABc, _835204AAABT], tags={I2,3, X2, Y3, ALL, ANY}),backend=numpy, dtype=float64, data=...
" + "
Tensor(shape=(4, 4, 4, 4), inds=[_91c677AAABZ, _91c677AAABb, _91c677AAABc, _91c677AAABT], tags={I2,3, X2, Y3, ALL, ANY}),backend=numpy, dtype=float64, data=...
" ], "text/plain": [ - "Tensor(shape=(4, 4, 4, 4), inds=('_835204AAABZ', '_835204AAABb', '_835204AAABc', '_835204AAABT'), tags=oset(['I2,3', 'X2', 'Y3', 'ALL', 'ANY']))" + "Tensor(shape=(4, 4, 4, 4), inds=('_91c677AAABZ', '_91c677AAABb', '_91c677AAABc', '_91c677AAABT'), tags=oset(['I2,3', 'X2', 'Y3', 'ALL', 'ANY']))" ] }, "execution_count": 30, @@ -1395,7 +1401,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -1421,7 +1427,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "
" @@ -1450,18 +1456,18 @@ { "data": { "text/plain": [ - "('_835204AAACJ',\n", - " '_835204AAACK',\n", - " '_835204AAACL',\n", - " '_835204AAACM',\n", - " '_835204AAACN',\n", - " '_835204AAACO',\n", - " '_835204AAACP',\n", - " '_835204AAACQ',\n", - " '_835204AAACR',\n", - " '_835204AAACS',\n", - " '_835204AAACT',\n", - " '_835204AAACU')" + "('_91c677AAACJ',\n", + " '_91c677AAACK',\n", + " '_91c677AAACL',\n", + " '_91c677AAACM',\n", + " '_91c677AAACN',\n", + " '_91c677AAACO',\n", + " '_91c677AAACP',\n", + " '_91c677AAACQ',\n", + " '_91c677AAACR',\n", + " '_91c677AAACS',\n", + " '_91c677AAACT',\n", + " '_91c677AAACU')" ] }, "execution_count": 35, @@ -1489,7 +1495,7 @@ "metadata": { "celltoolbar": "Raw Cell Format", "kernelspec": { - "display_name": "py312", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -1503,7 +1509,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.6" + "version": "3" } }, "nbformat": 4, diff --git a/quimb/tensor/tensor_core.py b/quimb/tensor/tensor_core.py index f49c9391..a22a9859 100644 --- a/quimb/tensor/tensor_core.py +++ b/quimb/tensor/tensor_core.py @@ -3300,15 +3300,35 @@ def COPY_tree_tensors(d, inds, tags=None, dtype=float, ssa_path=None): def _make_promote_array_func(op, meth_name): @functools.wraps(getattr(np.ndarray, meth_name)) def _promote_array_func(self, other): - """Use standard array func, but make sure Tensor inds match.""" + """Use standard array func, but auto match up indices.""" if isinstance(other, Tensor): - if set(self.inds) != set(other.inds): - raise ValueError( - "The indicies of these two tensors do not " - f"match: {self.inds} != {other.inds}" - ) - - otherT = other.transpose(*self.inds) + # auto match up indices - i.e. broadcast dimensions + left_expand = [] + right_expand = [] + + for ix in self.inds: + if ix not in other.inds: + right_expand.append(ix) + for ix in other.inds: + if ix not in self.inds: + left_expand.append(ix) + + # new_ind is an inplace operation -> track if we need to copy + copied = False + for ix in left_expand: + if not copied: + self = self.copy() + copied = True + self.new_ind(ix, axis=-1) + + copied = False + for ix in right_expand: + if not copied: + other = other.copy() + copied = True + other.new_ind(ix) + + otherT = other.transpose(*self.inds, inplace=copied) return Tensor( data=op(self.data, otherT.data), diff --git a/tests/test_tensor/test_tensor_core.py b/tests/test_tensor/test_tensor_core.py index c377ce4c..b53f9fda 100644 --- a/tests/test_tensor/test_tensor_core.py +++ b/tests/test_tensor/test_tensor_core.py @@ -126,8 +126,11 @@ def test_tensor_tensor_arithmetic(self, op, mismatch): b = Tensor(np.random.rand(2, 3, 4), inds=[0, 1, 2], tags="red") if mismatch: b.modify(inds=(0, 1, 3)) - with pytest.raises(ValueError): - op(a, b) + c = op(a, b) + assert_allclose(c.data, op( + a.data.reshape(2, 3, 4, 1), + b.data.reshape(2, 3, 1, 4)) + ) else: c = op(a, b) assert_allclose(c.data, op(a.data, b.data))