Implement more meaningful Reshape operation
#883 · ricardoV94 started this conversation in Ideas
            Replies: 1 comment
Recent related discussions: #1201 and #1192 (comment)
Description
Analyzing graphs with reshape operations is rather complex because Reshape represents not "the meaning", but rather "the final look" of the operation.
Except for esoteric cases where `Reshape` shapes come from a complex computation or from the shapes of other variables, a reshape is usually a matter of multiplying some dimensions (merging) and dividing others (splitting), plus squeezing/expand_dims. The last two are well encoded by `DimShuffle`, but there's nothing nice for the first two.
What if we had:
It almost begs for an extension of `DimShuffle`, which was brought up before: Theano/Theano#4640
Splitting dims is trickier, because there are many choices: we can split in different orders and sizes.
Even then, `split_dims` is still unfortunate because we don't have symbolic dims. We can say `split_dims(..., sizes=(x.shape[0], x.shape[1]))` though, which is still a bit more readable than `Reshape` (especially with the sneaky -1).
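Continuing the sketch above, compare the two spellings of the same split (the `split_dims` signature is still just an assumption):

```python
# Current: the -1 hides which dimension is being recovered
z = y.reshape((x.shape[0], x.shape[1], -1))

# Proposed: the split sizes are stated explicitly
z = split_dims(y, 0, sizes=(x.shape[0], x.shape[1]))
```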
How would it be used
Users will probably not know about this specialized Op, but in our internal uses, where we know this is the goal, we can introduce it. That covers most of the cases I've seen: the `tensordot`, `tile`, `repeat`, ..., and matmul rewrites.
We can also try to pay the one-time cost of canonicalizing arbitrary Reshapes into `join_dims` / `split_dims`. In the end we can specialize back to `Reshape`.
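For static shapes, such a canonicalization could look something like this (a toy, non-symbolic sketch; the helper name and the restriction to pure joins are just for illustration):

```python
def reshape_as_joins(in_shape, out_shape):
    """Toy sketch: express a static reshape as groups of consecutive input
    dims that are joined to produce each output dim (splitting not handled)."""
    groups, i = [], 0
    for size in out_shape:
        group, prod = [], 1
        while prod < size:
            prod *= in_shape[i]
            group.append(i)
            i += 1
        if prod != size:
            raise ValueError("not expressible as joins of consecutive dims")
        groups.append(group)
    if i != len(in_shape):
        raise ValueError("not all input dims were consumed")
    return groups


# (2, 3, 4) -> (6, 4): dims 0 and 1 are joined, dim 2 is kept as-is
print(reshape_as_joins((2, 3, 4), (6, 4)))  # [[0, 1], [2]]
```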
Existing pain points
An example where `Reshape` is currently hard to work with is during vectorization. If we have a common graph like `reshape(x, (x.shape[0] * x.shape[1], -1))`, we cannot eagerly return the desired output `reshape(new_x, (new_x.shape[0], new_x.shape[1] * new_x.shape[2], -1))`, because there is a chain of complex operations we must vectorize before we get to the `Reshape` node (`Shape -> Subtensor -> Mul -> MakeVector`). So we need to put it in a costly `Blockwise` and try our best to remove it during rewrites. This came up in #722 when vectorizing `tensordot` to get a `batched_tensordot`.
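A minimal reproduction of the issue (assuming `vectorize_graph` from `pytensor.graph.replace`; the exact outcome depends on the rewrites in place):

```python
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

x = pt.tensor3("x")
y = x.reshape((x.shape[0] * x.shape[1], -1))

# Replace x by a batched input and vectorize the whole graph
new_x = pt.tensor4("new_x")
new_y = vectorize_graph(y, replace={x: new_x})

# Ideally new_y would just be
#   new_x.reshape((new_x.shape[0], new_x.shape[1] * new_x.shape[2], -1))
# but the Shape -> Subtensor -> Mul -> MakeVector chain feeding the Reshape's
# shape argument has to be vectorized first, so the Reshape ends up wrapped
# in a Blockwise that later rewrites have to try to remove.
```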
Such a problem wouldn't exist with a `join_dims`, although it would still exist to some extent with a `split_dims`.
Another pain point involves reductions for which repeated elements are irrelevant, where we should be able to just ignore the reshape:
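For instance (using the helper sketched above; the rewrites are shown as comments rather than actual rewrite code):

```python
import pytensor.tensor as pt

x = pt.tensor3("x")
y = x.reshape((x.shape[0] * x.shape[1], -1))

# A full max reduction doesn't care how elements are grouped (or whether any
# are repeated), so the reshape can simply be dropped:
#   y.max()  ->  x.max()
#
# With the proposed Op the partial-reduction case is also easy to state:
#   join_dims(x, 0, 2).sum(axis=0)  ->  x.sum(axis=(0, 1))
```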
It also makes rewrites to remove/lift reshapes much simpler than they currently are:
`pytensor/pytensor/tensor/rewriting/shape.py`, lines 798 to 895 at `bf73f8a`
Precedent
This is somewhat related to why we have both `Second` and `Alloc`. The first one is easier to reason about because it tells us more immediately that we are broadcasting with the shape of a variable, whereas `Alloc` specifies the desired output without its meaning (especially after some rewrites, where the shape may become dissociated from the original variable).
`pytensor/pytensor/tensor/rewriting/basic.py`, lines 3 to 23 at `d62f4b1`
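For reference, the two ways of spelling the same broadcast:

```python
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.scalar("y")

# Second (a.k.a. fill) keeps the intent in the graph: "broadcast y like x"
b1 = pt.second(x, y)

# Alloc only records the target shape; after rewrites that shape can become
# dissociated from x and the broadcasting intent is lost
b2 = pt.alloc(y, x.shape[0], x.shape[1])
```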