64
64
#endif
65
65
/**********************************************/
66
66
67
- typedef enum {
68
- BROADCAST_NONE,
69
- BROADCAST_LEFT,
70
- BROADCAST_RIGHT,
71
- BROADCAST_MIDDLE
72
- } EINSUM_BROADCAST;
73
-
74
67
/**begin repeat
75
68
* #name = byte, short, int, long, longlong,
76
69
* ubyte, ushort, uint, ulong, ulonglong,
@@ -1802,8 +1795,7 @@ parse_operand_subscripts(char *subscripts, int length,
1802
1795
char *out_label_counts,
1803
1796
int *out_min_label,
1804
1797
int *out_max_label,
1805
- int *out_num_labels,
1806
- EINSUM_BROADCAST *out_broadcast)
1798
+ int *out_num_labels)
1807
1799
{
1808
1800
int i, idim, ndim_left, label;
1809
1801
int left_labels = 0, right_labels = 0, ellipsis = 0;
@@ -1947,19 +1939,6 @@ parse_operand_subscripts(char *subscripts, int length,
1947
1939
}
1948
1940
}
1949
1941
1950
- if (!ellipsis) {
1951
- *out_broadcast = BROADCAST_NONE;
1952
- }
1953
- else if (left_labels && right_labels) {
1954
- *out_broadcast = BROADCAST_MIDDLE;
1955
- }
1956
- else if (!left_labels) {
1957
- *out_broadcast = BROADCAST_RIGHT;
1958
- }
1959
- else {
1960
- *out_broadcast = BROADCAST_LEFT;
1961
- }
1962
-
1963
1942
return 1;
1964
1943
}
1965
1944
@@ -1972,8 +1951,7 @@ static int
1972
1951
parse_output_subscripts(char *subscripts, int length,
1973
1952
int ndim_broadcast,
1974
1953
const char *label_counts,
1975
- char *out_labels,
1976
- EINSUM_BROADCAST *out_broadcast)
1954
+ char *out_labels)
1977
1955
{
1978
1956
int i, nlabels, label, idim, ndim, ndim_left;
1979
1957
int left_labels = 0, right_labels = 0, ellipsis = 0;
@@ -2097,19 +2075,6 @@ parse_output_subscripts(char *subscripts, int length,
2097
2075
out_labels[idim++] = 0;
2098
2076
}
2099
2077
2100
- if (!ellipsis) {
2101
- *out_broadcast = BROADCAST_NONE;
2102
- }
2103
- else if (left_labels && right_labels) {
2104
- *out_broadcast = BROADCAST_MIDDLE;
2105
- }
2106
- else if (!left_labels) {
2107
- *out_broadcast = BROADCAST_RIGHT;
2108
- }
2109
- else {
2110
- *out_broadcast = BROADCAST_LEFT;
2111
- }
2112
-
2113
2078
return ndim;
2114
2079
}
2115
2080
@@ -2328,136 +2293,44 @@ get_combined_dims_view(PyArrayObject *op, int iop, char *labels)
2328
2293
2329
2294
static int
2330
2295
prepare_op_axes(int ndim, int iop, char *labels, int *axes,
2331
- int ndim_iter, char *iter_labels, EINSUM_BROADCAST broadcast )
2296
+ int ndim_iter, char *iter_labels)
2332
2297
{
2333
2298
int i, label, ibroadcast;
2334
2299
2335
- /* Regular broadcasting */
2336
- if (broadcast == BROADCAST_RIGHT) {
2337
- /* broadcast dimensions get placed in rightmost position */
2338
- ibroadcast = ndim-1;
2339
- for (i = ndim_iter-1; i >= 0; --i) {
2340
- label = iter_labels[i];
2341
- /*
2342
- * If it's an unlabeled broadcast dimension, choose
2343
- * the next broadcast dimension from the operand.
2344
- */
2345
- if (label == 0) {
2346
- while (ibroadcast >= 0 && labels[ibroadcast] != 0) {
2347
- --ibroadcast;
2348
- }
2349
- /*
2350
- * If we used up all the operand broadcast dimensions,
2351
- * extend it with a "newaxis"
2352
- */
2353
- if (ibroadcast < 0) {
2354
- axes[i] = -1;
2355
- }
2356
- /* Otherwise map to the broadcast axis */
2357
- else {
2358
- axes[i] = ibroadcast;
2359
- --ibroadcast;
2360
- }
2361
- }
2362
- /* It's a labeled dimension, find the matching one */
2363
- else {
2364
- char *match = memchr(labels, label, ndim);
2365
- /* If the op doesn't have the label, broadcast it */
2366
- if (match == NULL) {
2367
- axes[i] = -1;
2368
- }
2369
- /* Otherwise use it */
2370
- else {
2371
- axes[i] = match - labels;
2372
- }
2300
+ ibroadcast = ndim-1;
2301
+ for (i = ndim_iter-1; i >= 0; --i) {
2302
+ label = iter_labels[i];
2303
+ /*
2304
+ * If it's an unlabeled broadcast dimension, choose
2305
+ * the next broadcast dimension from the operand.
2306
+ */
2307
+ if (label == 0) {
2308
+ while (ibroadcast >= 0 && labels[ibroadcast] != 0) {
2309
+ --ibroadcast;
2373
2310
}
2374
- }
2375
- }
2376
- /* Reverse broadcasting */
2377
- else if (broadcast == BROADCAST_LEFT) {
2378
- /* broadcast dimensions get placed in leftmost position */
2379
- ibroadcast = 0;
2380
- for (i = 0; i < ndim_iter; ++i) {
2381
- label = iter_labels[i];
2382
2311
/*
2383
- * If it's an unlabeled broadcast dimension, choose
2384
- * the next broadcast dimension from the operand.
2312
+ * If we used up all the operand broadcast dimensions,
2313
+ * extend it with a "newaxis"
2385
2314
*/
2386
- if (label == 0) {
2387
- while (ibroadcast < ndim && labels[ibroadcast] != 0) {
2388
- ++ibroadcast;
2389
- }
2390
- /*
2391
- * If we used up all the operand broadcast dimensions,
2392
- * extend it with a "newaxis"
2393
- */
2394
- if (ibroadcast >= ndim) {
2395
- axes[i] = -1;
2396
- }
2397
- /* Otherwise map to the broadcast axis */
2398
- else {
2399
- axes[i] = ibroadcast;
2400
- ++ibroadcast;
2401
- }
2315
+ if (ibroadcast < 0) {
2316
+ axes[i] = -1;
2402
2317
}
2403
- /* It's a labeled dimension, find the matching one */
2318
+ /* Otherwise map to the broadcast axis */
2404
2319
else {
2405
- char *match = memchr(labels, label, ndim);
2406
- /* If the op doesn't have the label, broadcast it */
2407
- if (match == NULL) {
2408
- axes[i] = -1;
2409
- }
2410
- /* Otherwise use it */
2411
- else {
2412
- axes[i] = match - labels;
2413
- }
2320
+ axes[i] = ibroadcast;
2321
+ --ibroadcast;
2414
2322
}
2415
2323
}
2416
- }
2417
- /* Middle or None broadcasting */
2418
- else {
2419
- /* broadcast dimensions get placed in leftmost position */
2420
- ibroadcast = 0;
2421
- for (i = 0; i < ndim_iter; ++i) {
2422
- label = iter_labels[i];
2423
- /*
2424
- * If it's an unlabeled broadcast dimension, choose
2425
- * the next broadcast dimension from the operand.
2426
- */
2427
- if (label == 0) {
2428
- while (ibroadcast < ndim && labels[ibroadcast] != 0) {
2429
- ++ibroadcast;
2430
- }
2431
- /*
2432
- * If we used up all the operand broadcast dimensions,
2433
- * it's an error
2434
- */
2435
- if (ibroadcast >= ndim) {
2436
- PyErr_Format(PyExc_ValueError,
2437
- "operand %d did not have enough dimensions "
2438
- "to match the broadcasting, and couldn't be "
2439
- "extended because einstein sum subscripts "
2440
- "were specified at both the start and end",
2441
- iop);
2442
- return 0;
2443
- }
2444
- /* Otherwise map to the broadcast axis */
2445
- else {
2446
- axes[i] = ibroadcast;
2447
- ++ibroadcast;
2448
- }
2324
+ /* It's a labeled dimension, find the matching one */
2325
+ else {
2326
+ char *match = memchr(labels, label, ndim);
2327
+ /* If the op doesn't have the label, broadcast it */
2328
+ if (match == NULL) {
2329
+ axes[i] = -1;
2449
2330
}
2450
- /* It's a labeled dimension, find the matching one */
2331
+ /* Otherwise use it */
2451
2332
else {
2452
- char *match = memchr(labels, label, ndim);
2453
- /* If the op doesn't have the label, broadcast it */
2454
- if (match == NULL) {
2455
- axes[i] = -1;
2456
- }
2457
- /* Otherwise use it */
2458
- else {
2459
- axes[i] = match - labels;
2460
- }
2333
+ axes[i] = match - labels;
2461
2334
}
2462
2335
}
2463
2336
}
@@ -2737,7 +2610,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
2737
2610
char output_labels[NPY_MAXDIMS], *iter_labels;
2738
2611
int idim, ndim_output, ndim_broadcast, ndim_iter;
2739
2612
2740
- EINSUM_BROADCAST broadcast[NPY_MAXARGS];
2741
2613
PyArrayObject *op[NPY_MAXARGS], *ret = NULL;
2742
2614
PyArray_Descr *op_dtypes_array[NPY_MAXARGS], **op_dtypes;
2743
2615
@@ -2783,8 +2655,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
2783
2655
if (!parse_operand_subscripts(subscripts, length,
2784
2656
PyArray_NDIM(op_in[iop]),
2785
2657
iop, op_labels[iop], label_counts,
2786
- &min_label, &max_label, &num_labels,
2787
- &broadcast[iop])) {
2658
+ &min_label, &max_label, &num_labels)) {
2788
2659
return NULL;
2789
2660
}
2790
2661
@@ -2845,7 +2716,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
2845
2716
/* Parse the output subscript string */
2846
2717
ndim_output = parse_output_subscripts(outsubscripts, length,
2847
2718
ndim_broadcast, label_counts,
2848
- output_labels, &broadcast[nop] );
2719
+ output_labels);
2849
2720
}
2850
2721
else {
2851
2722
if (subscripts[0] != '-' || subscripts[1] != '>') {
@@ -2859,7 +2730,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
2859
2730
/* Parse the output subscript string */
2860
2731
ndim_output = parse_output_subscripts(subscripts, strlen(subscripts),
2861
2732
ndim_broadcast, label_counts,
2862
- output_labels, &broadcast[nop] );
2733
+ output_labels);
2863
2734
}
2864
2735
if (ndim_output < 0) {
2865
2736
return NULL;
@@ -2961,7 +2832,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
2961
2832
op_axes[iop] = op_axes_arrays[iop];
2962
2833
2963
2834
if (!prepare_op_axes(PyArray_NDIM(op[iop]), iop, op_labels[iop],
2964
- op_axes[iop], ndim_iter, iter_labels, broadcast[iop] )) {
2835
+ op_axes[iop], ndim_iter, iter_labels)) {
2965
2836
goto fail;
2966
2837
}
2967
2838
}
0 commit comments