Add a pass that fuses matmul and transpose operations #567
base: main
Conversation
For the transpose-matmul-transpose pattern in Llama2, this pass performs fusion and vectorization during dialect lowering. At the operator level, the fused version is 1.84x faster than the unfused one; at the model level, no measurable speedup is visible.
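As a rough illustration, a minimal skeleton of the kind of rewrite pattern described here might look as follows. The op names and matching order are assumptions reconstructed from the review snippets below, not the actual implementation:

```cpp
// Hypothetical sketch: match transpose -> matmul -> transpose and fuse.
struct TransposeMatMulTransposeFusion
    : public OpRewritePattern<tosa::MatMulOp> {
  using OpRewritePattern<tosa::MatMulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MatMulOp op,
                                PatternRewriter &rewriter) const override {
    // The B operand must be produced by a transpose.
    auto transposeBOp = op->getOperand(1).getDefiningOp<tosa::TransposeOp>();
    if (!transposeBOp)
      return failure();
    // The matmul result must feed exactly one transpose.
    Value c = op->getOpResult(0);
    if (!c.hasOneUse())
      return failure();
    auto transposeCOp = dyn_cast<tosa::TransposeOp>(*c.getUsers().begin());
    if (!transposeCOp)
      return failure();
    // ... replace the three ops with a single fused, vectorized loop ...
    return success();
  }
};
```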
```cpp
Value B = op->getOperand(1);
Value C = op->getOpResult(0);

tosa::ReshapeOp reshapeBOp = B.getDefiningOp<tosa::ReshapeOp>();
```
Maybe you can use `auto` here, e.g. `auto reshapeBOp = B.getDefiningOp<tosa::ReshapeOp>();`; the template argument already tells us the type of the op.
linuxlonelyeagle left a comment:
A brief review.
```cpp
if (!transposeBOp) {
  return failure();
}
Value::user_iterator reshapeCUserIt = C.getUsers().begin();
```
A `C.getUsers().empty()` check is good here before dereferencing the first user.
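A minimal sketch of that guard, assuming the pattern goes on to inspect C's first user:

```cpp
// Bail out early when C has no users at all; only then is it safe to
// dereference the first user iterator.
if (C.getUsers().empty())
  return failure();
Operation *reshapeCUser = *C.getUsers().begin();
```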
```cpp
ShapedType newBType =
    cast<ShapedType>(transposeBOp.getOperand(0).getType());
ShapedType newCType =
    cast<ShapedType>(transposeCOp->getOpResult(0).getType());
```
Can you use `transposeCOp->getResult(0).getType()`?
Or `transposeCOp.getType()`.
```cpp
Value vlStep = rewriter.create<arith::ConstantIndexOp>(loc, vecSize);
Value zero = rewriter.create<arith::ConstantOp>(
    loc, rewriter.getZeroAttr(elementType));
const AffineExpr d0 = rewriter.getAffineDimExpr(0);
```
Don't use `const` here.
```cpp
// Create pass through vector.
Value passThroughVec = rewriter.create<SplatOp>(loc, vectorTy, zero);
Value newA = rewriter.create<bufferization::ToMemrefOp>(
```
Is it possible to avoid using the bufferization dialect here? This is just a fusion pattern.
```cpp
Value aCol = rewriter.create<memref::DimOp>(loc, newA, c2);
Value bCol = rewriter.create<memref::DimOp>(loc, newB, c3);

Value upperBoundTmp = rewriter.create<arith::SubIOp>(loc, bCol, vlStep);
```
For the sub and add we can use `affine.apply`, rather than creating separate `arith` add and sub operations.
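A hedged sketch of that suggestion for the upper-bound computation, reusing the variable names from the snippet above:

```cpp
// Fold `bCol - vlStep` into an affine map applied to bCol, instead of
// materializing an explicit arith.subi.
AffineExpr d0 = rewriter.getAffineDimExpr(0);
AffineMap subMap =
    AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0 - vecSize);
Value upperBoundTmp =
    rewriter.create<affine::AffineApplyOp>(loc, subMap, ValueRange{bCol});
```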
```cpp
// loopBody->addArguments(types, locs);
Block &loopBody = parOp.getRegion().front();
rewriter.setInsertionPointToStart(&loopBody);
Value ivs0 = loopBody.getArguments()[0];
```
You can get this via the loop op's induction-variable accessor, e.g. `iv = parOp.getInductionVars()[0]`, instead of indexing the block arguments directly.
```cpp
    newC,
    ValueRange{c0, ivs1, ivs0, iv});
Value idx =
    nestedBuilder.create<arith::AddIOp>(nestedLoc, iv, vlStep);
```
Use `affine.apply` for this add as well.
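Sketched the same way as the subtraction above:

```cpp
// Compute `iv + vecSize` with affine.apply instead of arith.addi.
AffineExpr d0 = nestedBuilder.getAffineDimExpr(0);
AffineMap addMap =
    AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0 + vecSize);
Value idx = nestedBuilder.create<affine::AffineApplyOp>(nestedLoc, addMap,
                                                        ValueRange{iv});
```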
Thank you for your feedback. While considering how to avoid the bufferization dialect, I discovered that this pass can be moved from the TOSA level and completed at the Linalg level instead. I will resubmit all modifications after completing this change.
After applying this pass, the measured runtime of the transpose-matmul-transpose pattern dropped from 0.00528216 to 0.00191212.