Skip to content

Commit b1b0ea8

Browse files
hpauljcharris
hpaulj
authored and charris committed
ENH: Remove unnecessary broadcasting notation restrictions in einsum.
In a case where 'ik,kj->ij' works, einsum would raise an error for 'ik,k...->i...' because the 'ik' did not have an ellipsis. In einsum.c.src, prepare_op_axes() now passes all 'broadcast' cases through the 'RIGHT' case (iteration from the end). Since the BROADCAST variable is no longer needed, all instances of it have been removed from einsum.c.src. test_einsum.py adds a test_einsum_broadcast case.
1 parent d1dbf8e commit b1b0ea8

File tree

4 files changed

+91
-162
lines changed

4 files changed

+91
-162
lines changed

doc/release/1.9.0-notes.rst

+6-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,12 @@ The `out` argument to `np.argmin` and `np.argmax` and their equivalent
7979
C-API functions is now checked to match the desired output shape exactly.
8080
If the check fails a `ValueError` instead of `TypeError` is raised.
8181

82-
82+
Einsum
83+
~~~~~~
84+
Remove unnecessary broadcasting notation restrictions.
85+
np.einsum('ijk,j->ijk', A, B) can also be written as
86+
np.einsum('ij...,j->ij...', A, B) (ellipsis is no longer required on 'j')
87+
8388
C-API
8489
~~~~~
8590

numpy/add_newdocs.py

+14
Original file line numberDiff line numberDiff line change
@@ -2078,6 +2078,8 @@ def luf(lamdaexpr, *args, **kwargs):
20782078
array([ 30, 80, 130, 180, 230])
20792079
>>> np.dot(a, b)
20802080
array([ 30, 80, 130, 180, 230])
2081+
>>> np.einsum('...j,j', a, b)
2082+
array([ 30, 80, 130, 180, 230])
20812083
20822084
>>> np.einsum('ji', c)
20832085
array([[0, 3],
@@ -2147,6 +2149,18 @@ def luf(lamdaexpr, *args, **kwargs):
21472149
[ 4796., 5162.],
21482150
[ 4928., 5306.]])
21492151
2152+
>>> a = np.arange(6).reshape((3,2))
2153+
>>> b = np.arange(12).reshape((4,3))
2154+
>>> np.einsum('ki,jk->ij', a, b)
2155+
array([[10, 28, 46, 64],
2156+
[13, 40, 67, 94]])
2157+
>>> np.einsum('ki,...k->i...', a, b)
2158+
array([[10, 28, 46, 64],
2159+
[13, 40, 67, 94]])
2160+
>>> np.einsum('k...,jk', a, b)
2161+
array([[10, 28, 46, 64],
2162+
[13, 40, 67, 94]])
2163+
21502164
""")
21512165

21522166
add_newdoc('numpy.core', 'alterdot',

numpy/core/src/multiarray/einsum.c.src

+32-161
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,6 @@
6464
#endif
6565
/**********************************************/
6666

67-
typedef enum {
68-
BROADCAST_NONE,
69-
BROADCAST_LEFT,
70-
BROADCAST_RIGHT,
71-
BROADCAST_MIDDLE
72-
} EINSUM_BROADCAST;
73-
7467
/**begin repeat
7568
* #name = byte, short, int, long, longlong,
7669
* ubyte, ushort, uint, ulong, ulonglong,
@@ -1802,8 +1795,7 @@ parse_operand_subscripts(char *subscripts, int length,
18021795
char *out_label_counts,
18031796
int *out_min_label,
18041797
int *out_max_label,
1805-
int *out_num_labels,
1806-
EINSUM_BROADCAST *out_broadcast)
1798+
int *out_num_labels)
18071799
{
18081800
int i, idim, ndim_left, label;
18091801
int left_labels = 0, right_labels = 0, ellipsis = 0;
@@ -1947,19 +1939,6 @@ parse_operand_subscripts(char *subscripts, int length,
19471939
}
19481940
}
19491941

1950-
if (!ellipsis) {
1951-
*out_broadcast = BROADCAST_NONE;
1952-
}
1953-
else if (left_labels && right_labels) {
1954-
*out_broadcast = BROADCAST_MIDDLE;
1955-
}
1956-
else if (!left_labels) {
1957-
*out_broadcast = BROADCAST_RIGHT;
1958-
}
1959-
else {
1960-
*out_broadcast = BROADCAST_LEFT;
1961-
}
1962-
19631942
return 1;
19641943
}
19651944

@@ -1972,8 +1951,7 @@ static int
19721951
parse_output_subscripts(char *subscripts, int length,
19731952
int ndim_broadcast,
19741953
const char *label_counts,
1975-
char *out_labels,
1976-
EINSUM_BROADCAST *out_broadcast)
1954+
char *out_labels)
19771955
{
19781956
int i, nlabels, label, idim, ndim, ndim_left;
19791957
int left_labels = 0, right_labels = 0, ellipsis = 0;
@@ -2097,19 +2075,6 @@ parse_output_subscripts(char *subscripts, int length,
20972075
out_labels[idim++] = 0;
20982076
}
20992077

2100-
if (!ellipsis) {
2101-
*out_broadcast = BROADCAST_NONE;
2102-
}
2103-
else if (left_labels && right_labels) {
2104-
*out_broadcast = BROADCAST_MIDDLE;
2105-
}
2106-
else if (!left_labels) {
2107-
*out_broadcast = BROADCAST_RIGHT;
2108-
}
2109-
else {
2110-
*out_broadcast = BROADCAST_LEFT;
2111-
}
2112-
21132078
return ndim;
21142079
}
21152080

@@ -2328,136 +2293,44 @@ get_combined_dims_view(PyArrayObject *op, int iop, char *labels)
23282293

23292294
static int
23302295
prepare_op_axes(int ndim, int iop, char *labels, int *axes,
2331-
int ndim_iter, char *iter_labels, EINSUM_BROADCAST broadcast)
2296+
int ndim_iter, char *iter_labels)
23322297
{
23332298
int i, label, ibroadcast;
23342299

2335-
/* Regular broadcasting */
2336-
if (broadcast == BROADCAST_RIGHT) {
2337-
/* broadcast dimensions get placed in rightmost position */
2338-
ibroadcast = ndim-1;
2339-
for (i = ndim_iter-1; i >= 0; --i) {
2340-
label = iter_labels[i];
2341-
/*
2342-
* If it's an unlabeled broadcast dimension, choose
2343-
* the next broadcast dimension from the operand.
2344-
*/
2345-
if (label == 0) {
2346-
while (ibroadcast >= 0 && labels[ibroadcast] != 0) {
2347-
--ibroadcast;
2348-
}
2349-
/*
2350-
* If we used up all the operand broadcast dimensions,
2351-
* extend it with a "newaxis"
2352-
*/
2353-
if (ibroadcast < 0) {
2354-
axes[i] = -1;
2355-
}
2356-
/* Otherwise map to the broadcast axis */
2357-
else {
2358-
axes[i] = ibroadcast;
2359-
--ibroadcast;
2360-
}
2361-
}
2362-
/* It's a labeled dimension, find the matching one */
2363-
else {
2364-
char *match = memchr(labels, label, ndim);
2365-
/* If the op doesn't have the label, broadcast it */
2366-
if (match == NULL) {
2367-
axes[i] = -1;
2368-
}
2369-
/* Otherwise use it */
2370-
else {
2371-
axes[i] = match - labels;
2372-
}
2300+
ibroadcast = ndim-1;
2301+
for (i = ndim_iter-1; i >= 0; --i) {
2302+
label = iter_labels[i];
2303+
/*
2304+
* If it's an unlabeled broadcast dimension, choose
2305+
* the next broadcast dimension from the operand.
2306+
*/
2307+
if (label == 0) {
2308+
while (ibroadcast >= 0 && labels[ibroadcast] != 0) {
2309+
--ibroadcast;
23732310
}
2374-
}
2375-
}
2376-
/* Reverse broadcasting */
2377-
else if (broadcast == BROADCAST_LEFT) {
2378-
/* broadcast dimensions get placed in leftmost position */
2379-
ibroadcast = 0;
2380-
for (i = 0; i < ndim_iter; ++i) {
2381-
label = iter_labels[i];
23822311
/*
2383-
* If it's an unlabeled broadcast dimension, choose
2384-
* the next broadcast dimension from the operand.
2312+
* If we used up all the operand broadcast dimensions,
2313+
* extend it with a "newaxis"
23852314
*/
2386-
if (label == 0) {
2387-
while (ibroadcast < ndim && labels[ibroadcast] != 0) {
2388-
++ibroadcast;
2389-
}
2390-
/*
2391-
* If we used up all the operand broadcast dimensions,
2392-
* extend it with a "newaxis"
2393-
*/
2394-
if (ibroadcast >= ndim) {
2395-
axes[i] = -1;
2396-
}
2397-
/* Otherwise map to the broadcast axis */
2398-
else {
2399-
axes[i] = ibroadcast;
2400-
++ibroadcast;
2401-
}
2315+
if (ibroadcast < 0) {
2316+
axes[i] = -1;
24022317
}
2403-
/* It's a labeled dimension, find the matching one */
2318+
/* Otherwise map to the broadcast axis */
24042319
else {
2405-
char *match = memchr(labels, label, ndim);
2406-
/* If the op doesn't have the label, broadcast it */
2407-
if (match == NULL) {
2408-
axes[i] = -1;
2409-
}
2410-
/* Otherwise use it */
2411-
else {
2412-
axes[i] = match - labels;
2413-
}
2320+
axes[i] = ibroadcast;
2321+
--ibroadcast;
24142322
}
24152323
}
2416-
}
2417-
/* Middle or None broadcasting */
2418-
else {
2419-
/* broadcast dimensions get placed in leftmost position */
2420-
ibroadcast = 0;
2421-
for (i = 0; i < ndim_iter; ++i) {
2422-
label = iter_labels[i];
2423-
/*
2424-
* If it's an unlabeled broadcast dimension, choose
2425-
* the next broadcast dimension from the operand.
2426-
*/
2427-
if (label == 0) {
2428-
while (ibroadcast < ndim && labels[ibroadcast] != 0) {
2429-
++ibroadcast;
2430-
}
2431-
/*
2432-
* If we used up all the operand broadcast dimensions,
2433-
* it's an error
2434-
*/
2435-
if (ibroadcast >= ndim) {
2436-
PyErr_Format(PyExc_ValueError,
2437-
"operand %d did not have enough dimensions "
2438-
"to match the broadcasting, and couldn't be "
2439-
"extended because einstein sum subscripts "
2440-
"were specified at both the start and end",
2441-
iop);
2442-
return 0;
2443-
}
2444-
/* Otherwise map to the broadcast axis */
2445-
else {
2446-
axes[i] = ibroadcast;
2447-
++ibroadcast;
2448-
}
2324+
/* It's a labeled dimension, find the matching one */
2325+
else {
2326+
char *match = memchr(labels, label, ndim);
2327+
/* If the op doesn't have the label, broadcast it */
2328+
if (match == NULL) {
2329+
axes[i] = -1;
24492330
}
2450-
/* It's a labeled dimension, find the matching one */
2331+
/* Otherwise use it */
24512332
else {
2452-
char *match = memchr(labels, label, ndim);
2453-
/* If the op doesn't have the label, broadcast it */
2454-
if (match == NULL) {
2455-
axes[i] = -1;
2456-
}
2457-
/* Otherwise use it */
2458-
else {
2459-
axes[i] = match - labels;
2460-
}
2333+
axes[i] = match - labels;
24612334
}
24622335
}
24632336
}
@@ -2737,7 +2610,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
27372610
char output_labels[NPY_MAXDIMS], *iter_labels;
27382611
int idim, ndim_output, ndim_broadcast, ndim_iter;
27392612

2740-
EINSUM_BROADCAST broadcast[NPY_MAXARGS];
27412613
PyArrayObject *op[NPY_MAXARGS], *ret = NULL;
27422614
PyArray_Descr *op_dtypes_array[NPY_MAXARGS], **op_dtypes;
27432615

@@ -2783,8 +2655,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
27832655
if (!parse_operand_subscripts(subscripts, length,
27842656
PyArray_NDIM(op_in[iop]),
27852657
iop, op_labels[iop], label_counts,
2786-
&min_label, &max_label, &num_labels,
2787-
&broadcast[iop])) {
2658+
&min_label, &max_label, &num_labels)) {
27882659
return NULL;
27892660
}
27902661

@@ -2845,7 +2716,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
28452716
/* Parse the output subscript string */
28462717
ndim_output = parse_output_subscripts(outsubscripts, length,
28472718
ndim_broadcast, label_counts,
2848-
output_labels, &broadcast[nop]);
2719+
output_labels);
28492720
}
28502721
else {
28512722
if (subscripts[0] != '-' || subscripts[1] != '>') {
@@ -2859,7 +2730,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
28592730
/* Parse the output subscript string */
28602731
ndim_output = parse_output_subscripts(subscripts, strlen(subscripts),
28612732
ndim_broadcast, label_counts,
2862-
output_labels, &broadcast[nop]);
2733+
output_labels);
28632734
}
28642735
if (ndim_output < 0) {
28652736
return NULL;
@@ -2961,7 +2832,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
29612832
op_axes[iop] = op_axes_arrays[iop];
29622833

29632834
if (!prepare_op_axes(PyArray_NDIM(op[iop]), iop, op_labels[iop],
2964-
op_axes[iop], ndim_iter, iter_labels, broadcast[iop])) {
2835+
op_axes[iop], ndim_iter, iter_labels)) {
29652836
goto fail;
29662837
}
29672838
}

numpy/core/tests/test_einsum.py

+39
Original file line numberDiff line numberDiff line change
@@ -498,5 +498,44 @@ def test_einsum_misc(self):
498498
[[[1, 3], [3, 9], [5, 15], [7, 21]],
499499
[[8, 16], [16, 32], [24, 48], [32, 64]]])
500500

501+
def test_einsum_broadcast(self):
502+
# Issue #2455 change in handling ellipsis
503+
# remove the 'middle broadcast' error
504+
# only use the 'RIGHT' iteration in prepare_op_axes
505+
# adds auto broadcast on left where it belongs
506+
# broadcast on right has to be explicit
507+
508+
A = np.arange(2*3*4).reshape(2,3,4)
509+
B = np.arange(3)
510+
ref = np.einsum('ijk,j->ijk',A, B)
511+
assert_equal(np.einsum('ij...,j...->ij...',A, B), ref)
512+
assert_equal(np.einsum('ij...,...j->ij...',A, B), ref)
513+
assert_equal(np.einsum('ij...,j->ij...',A, B), ref) # used to raise error
514+
515+
A = np.arange(12).reshape((4,3))
516+
B = np.arange(6).reshape((3,2))
517+
ref = np.einsum('ik,kj->ij', A, B)
518+
assert_equal(np.einsum('ik...,k...->i...', A, B), ref)
519+
assert_equal(np.einsum('ik...,...kj->i...j', A, B), ref)
520+
assert_equal(np.einsum('...k,kj', A, B), ref) # used to raise error
521+
assert_equal(np.einsum('ik,k...->i...', A, B), ref) # used to raise error
522+
523+
dims=[2,3,4,5];
524+
a = np.arange(np.prod(dims)).reshape(dims)
525+
v = np.arange(dims[2])
526+
ref = np.einsum('ijkl,k->ijl', a, v)
527+
assert_equal(np.einsum('ijkl,k', a, v), ref)
528+
assert_equal(np.einsum('...kl,k', a, v), ref) # used to raise error
529+
assert_equal(np.einsum('...kl,k...', a, v), ref)
530+
# no real diff from 1st
531+
532+
J,K,M=160,160,120;
533+
A=np.arange(J*K*M).reshape(1,1,1,J,K,M)
534+
B=np.arange(J*K*M*3).reshape(J,K,M,3)
535+
ref = np.einsum('...lmn,...lmno->...o', A, B)
536+
assert_equal(np.einsum('...lmn,lmno->...o', A, B), ref) # used to raise error
537+
538+
539+
501540
if __name__ == "__main__":
502541
run_module_suite()

0 commit comments

Comments
 (0)