@@ -649,24 +649,29 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
649
649
return getShapes ().front ();
650
650
}
651
651
652
- // TODO: Support folding with more than 2 input shapes
653
- if (getShapes ().size () > 2 )
652
+ if (!adaptor.getShapes ().front ())
654
653
return nullptr ;
655
654
656
- if (!adaptor.getShapes ()[0 ] || !adaptor.getShapes ()[1 ])
657
- return nullptr ;
658
- auto lhsShape = llvm::to_vector<6 >(
659
- llvm::cast<DenseIntElementsAttr>(adaptor.getShapes ()[0 ])
660
- .getValues <int64_t >());
661
- auto rhsShape = llvm::to_vector<6 >(
662
- llvm::cast<DenseIntElementsAttr>(adaptor.getShapes ()[1 ])
655
+ SmallVector<int64_t , 6 > resultShape (
656
+ llvm::cast<DenseIntElementsAttr>(adaptor.getShapes ().front ())
663
657
.getValues <int64_t >());
664
- SmallVector<int64_t , 6 > resultShape;
665
658
666
- // If the shapes are not compatible, we can't fold it.
667
- // TODO: Fold to an "error".
668
- if (!OpTrait::util::getBroadcastedShape (lhsShape, rhsShape, resultShape))
669
- return nullptr ;
659
+ for (auto next : adaptor.getShapes ().drop_front ()) {
660
+ if (!next)
661
+ return nullptr ;
662
+ auto nextShape = llvm::to_vector<6 >(
663
+ llvm::cast<DenseIntElementsAttr>(next).getValues <int64_t >());
664
+
665
+ SmallVector<int64_t , 6 > tmpShape;
666
+ // If the shapes are not compatible, we can't fold it.
667
+ // TODO: Fold to an "error".
668
+ if (!OpTrait::util::getBroadcastedShape (resultShape, nextShape, tmpShape))
669
+ return nullptr ;
670
+
671
+ resultShape.clear ();
672
+ std::copy (tmpShape.begin (), tmpShape.end (),
673
+ std::back_inserter (resultShape));
674
+ }
670
675
671
676
Builder builder (getContext ());
672
677
return builder.getIndexTensorAttr (resultShape);
0 commit comments