Incorrect code generated when contraction index is not innermost #555

@smr97

Hi, I am trying out different index orders for the following contraction:

T1[p, c, q, ] = start[p, q, r, ] * C3[c, r, ]

where T1 and start are 3D sparse tensors and C3 is a sparse matrix. For the following index ordering, I suspect that the generated compute kernel is incorrect:

Expression 1: T1[q, c, p, ] = start[q, r, p, ] * C3[r, c, ]                                                                                                   

Here is the kernel I get from TACO:

    2 int compute(taco_tensor_t *t1, taco_tensor_t *start, taco_tensor_t *c3) {
    3   int t11_dimension = (int)(t1->dimensions[0]);
    4   double* restrict t1_vals = (double*)(t1->vals);
    5   int start1_dimension = (int)(start->dimensions[0]);
    6   int* restrict start2_pos = (int*)(start->indices[1][0]);
    7   int* restrict start2_crd = (int*)(start->indices[1][1]);
    8   int* restrict start3_pos = (int*)(start->indices[2][0]);
    9   int* restrict start3_crd = (int*)(start->indices[2][1]);
   10   double* restrict start_vals = (double*)(start->vals);
   11   int c31_dimension = (int)(c3->dimensions[0]);
   12   int* restrict c32_pos = (int*)(c3->indices[1][0]);
   13   int* restrict c32_crd = (int*)(c3->indices[1][1]);
   14   double* restrict c3_vals = (double*)(c3->vals);
   15
   16   int32_t pt1 = 0;
   17
   18   for (int32_t q = 0; q < start1_dimension; q++) {
   19     for (int32_t rstart = start2_pos[q]; rstart < start2_pos[(q + 1)]; rstart++) {
   20       int32_t r = start2_crd[rstart];
   21       for (int32_t cc3 = c32_pos[r]; cc3 < c32_pos[(r + 1)]; cc3++) {
   22         for (int32_t pstart = start3_pos[rstart]; pstart < start3_pos[(rstart + 1)]; pstart++) {
   23           t1_vals[pt1] = 0.0;
   24           t1_vals[pt1] = t1_vals[pt1] + start_vals[pstart] * c3_vals[cc3];
   25           pt1++;
   26         }
   27       }
   28     }
   29   }
   30   return 0;
   31 }

It looks like the update to t1_vals is written to a different position for every value of r (line 20), because pt1 is incremented on every innermost iteration.
Since the contraction is over r, all contributions for a given (q, c, p) should in fact be accumulated at the same position in t1_vals.
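
For comparison, here is a dense reference version of the same contraction, using the same loop order (q, r, c, p) as the kernel above. This is only a sketch to make the expected behavior concrete; it is not TACO output and not a proposed fix. The function name `contract_reference`, the extents `Q`, `R`, `P`, `C`, and the row-major dense layout are all assumptions made for illustration.

```c
#include <assert.h>
#include <stddef.h>

/* Dense reference for T1[q, c, p] = sum over r of start[q, r, p] * C3[r, c].
 * All arrays are dense and row-major; Q, R, P, C are hypothetical extents
 * chosen for this sketch, not names taken from TACO. */
void contract_reference(size_t Q, size_t R, size_t P, size_t C,
                        const double *start, /* Q x R x P */
                        const double *c3,    /* R x C     */
                        double *t1)          /* Q x C x P */
{
  assert(start != NULL && c3 != NULL && t1 != NULL);

  /* Zero the output once, before any accumulation. */
  for (size_t i = 0; i < Q * C * P; i++) {
    t1[i] = 0.0;
  }

  /* Same loop order as the generated kernel: q, r, c, p. */
  for (size_t q = 0; q < Q; q++) {
    for (size_t r = 0; r < R; r++) {
      for (size_t c = 0; c < C; c++) {
        for (size_t p = 0; p < P; p++) {
          /* The output position depends on (q, c, p) only, so every
           * value of r accumulates into the same slot. */
          t1[(q * C + c) * P + p] +=
              start[(q * R + r) * P + p] * c3[r * C + c];
        }
      }
    }
  }
}
```

The point of contrast with the generated kernel is that the output index here does not advance with r, and the output is zeroed once up front rather than immediately before each update.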
