|
1 | 1 | // Generated by the Tensor Algebra Compiler (tensor-compiler.org)
|
2 |
| -// taco "A(i,j)=B(i,k,l)*C(k,j)*D(l,j)" -f=A:dd:0,1 -f=B:sss:0,1,2 -f=C:dd:0,1 -f=D:dd:0,1 -write-source=taco_kernel.c -write-compute=taco_compute.c -write-assembly=taco_assembly.c |
| 2 | +// taco "A(i,j)=B(i,k,l)*D(l,j)*C(k,j)" -f=A:dd:0,1 -f=B:sss:0,1,2 -f=D:dd:0,1 -f=C:dd:0,1 -s="reorder(i,k,l,j)" -s="precompute(B(i,k,l)*D(l,j),j,j)" -s="split(i,i0,i1,32)" -s="parallelize(i0,CPUThread,NoRaces)" -write-source=taco_kernel.c -write-compute=taco_compute.c -write-assembly=taco_assembly.c |
3 | 3 |
|
4 |
| -int compute(taco_tensor_t *A, taco_tensor_t *B, taco_tensor_t *C, taco_tensor_t *D) { |
| 4 | +int compute(taco_tensor_t *A, taco_tensor_t *B, taco_tensor_t *D, taco_tensor_t *C) { |
5 | 5 | int A1_dimension = (int)(A->dimensions[0]);
|
6 | 6 | int A2_dimension = (int)(A->dimensions[1]);
|
7 | 7 | double* restrict A_vals = (double*)(A->vals);
|
| 8 | + int B1_dimension = (int)(B->dimensions[0]); |
8 | 9 | int* restrict B1_pos = (int*)(B->indices[0][0]);
|
9 | 10 | int* restrict B1_crd = (int*)(B->indices[0][1]);
|
10 | 11 | int* restrict B2_pos = (int*)(B->indices[1][0]);
|
11 | 12 | int* restrict B2_crd = (int*)(B->indices[1][1]);
|
12 | 13 | int* restrict B3_pos = (int*)(B->indices[2][0]);
|
13 | 14 | int* restrict B3_crd = (int*)(B->indices[2][1]);
|
14 | 15 | double* restrict B_vals = (double*)(B->vals);
|
15 |
| - int C1_dimension = (int)(C->dimensions[0]); |
16 |
| - int C2_dimension = (int)(C->dimensions[1]); |
17 |
| - double* restrict C_vals = (double*)(C->vals); |
18 | 16 | int D1_dimension = (int)(D->dimensions[0]);
|
19 | 17 | int D2_dimension = (int)(D->dimensions[1]);
|
20 | 18 | double* restrict D_vals = (double*)(D->vals);
|
| 19 | + int C1_dimension = (int)(C->dimensions[0]); |
| 20 | + int C2_dimension = (int)(C->dimensions[1]); |
| 21 | + double* restrict C_vals = (double*)(C->vals); |
21 | 22 |
|
22 | 23 | #pragma omp parallel for schedule(static)
|
23 | 24 | for (int32_t pA = 0; pA < (A1_dimension * A2_dimension); pA++) {
|
24 | 25 | A_vals[pA] = 0.0;
|
25 | 26 | }
|
26 | 27 |
|
27 | 28 | #pragma omp parallel for schedule(runtime)
|
28 |
| - for (int32_t iB = B1_pos[0]; iB < B1_pos[1]; iB++) { |
| 29 | + for (int32_t i0 = 0; i0 < ((B1_dimension + 31) / 32); i0++) { |
| 30 | + int32_t pB1_begin = i0 * 32; |
| 31 | + int32_t iB = taco_binarySearchAfter(B1_crd, B1_pos[0], B1_pos[1], pB1_begin); |
| 32 | + int32_t pB1_end = B1_pos[1]; |
| 33 | + int32_t iB0 = B1_crd[iB]; |
29 | 34 | int32_t i = B1_crd[iB];
|
30 |
| - for (int32_t kB = B2_pos[iB]; kB < B2_pos[(iB + 1)]; kB++) { |
31 |
| - int32_t k = B2_crd[kB]; |
32 |
| - for (int32_t lB = B3_pos[kB]; lB < B3_pos[(kB + 1)]; lB++) { |
33 |
| - int32_t l = B3_crd[lB]; |
34 |
| - for (int32_t j = 0; j < D2_dimension; j++) { |
35 |
| - int32_t jA = i * A2_dimension + j; |
36 |
| - int32_t jC = k * C2_dimension + j; |
37 |
| - int32_t jD = l * D2_dimension + j; |
38 |
| - A_vals[jA] = A_vals[jA] + (B_vals[lB] * C_vals[jC]) * D_vals[jD]; |
| 35 | + int32_t i1 = i - i0 * 32; |
| 36 | + int32_t i1_end = 32; |
| 37 | + |
| 38 | + while (iB < pB1_end && i1 < i1_end) { |
| 39 | + iB0 = B1_crd[iB]; |
| 40 | + i = B1_crd[iB]; |
| 41 | + if (iB0 == i) { |
| 42 | + double* restrict workspace = 0; |
| 43 | + workspace = (double*)malloc(sizeof(double) * C2_dimension); |
| 44 | + |
| 45 | + for (int32_t kB = B2_pos[iB]; kB < B2_pos[(iB + 1)]; kB++) { |
| 46 | + int32_t k = B2_crd[kB]; |
| 47 | + for (int32_t pworkspace = 0; pworkspace < C2_dimension; pworkspace++) { |
| 48 | + workspace[pworkspace] = 0.0; |
| 49 | + } |
| 50 | + for (int32_t lB = B3_pos[kB]; lB < B3_pos[(kB + 1)]; lB++) { |
| 51 | + int32_t l = B3_crd[lB]; |
| 52 | + for (int32_t j = 0; j < C2_dimension; j++) { |
| 53 | + int32_t jD = l * D2_dimension + j; |
| 54 | + workspace[j] = workspace[j] + B_vals[lB] * D_vals[jD]; |
| 55 | + } |
| 56 | + } |
| 57 | + for (int32_t j = 0; j < C2_dimension; j++) { |
| 58 | + int32_t jA = i * A2_dimension + j; |
| 59 | + int32_t jC = k * C2_dimension + j; |
| 60 | + A_vals[jA] = A_vals[jA] + workspace[j] * C_vals[jC]; |
| 61 | + } |
39 | 62 | }
|
| 63 | + |
| 64 | + free(workspace); |
40 | 65 | }
|
| 66 | + iB += (int32_t)(iB0 == i); |
| 67 | + iB0 = B1_crd[iB]; |
| 68 | + i = B1_crd[iB]; |
| 69 | + i1 = i - i0 * 32; |
41 | 70 | }
|
42 | 71 | }
|
| 72 | + |
| 73 | + A->vals = (uint8_t*)A_vals; |
43 | 74 | return 0;
|
44 | 75 | }
|
0 commit comments