@@ -131,6 +131,12 @@ bool mlir::emitc::isPointerWideType(Type type) {
131131 type);
132132}
133133
134+ bool mlir::emitc::isSwitchOperandType (Type type) {
135+ auto intType = llvm::dyn_cast<IntegerType>(type);
136+ return isSupportedIntegerType (type) && intType.getWidth () != 1 &&
137+ intType.getWidth () != 8 ;
138+ }
139+
134140// / Check that the type of the initial value is compatible with the operations
135141// / result type.
136142static LogicalResult verifyInitializationAttribute (Operation *op,
@@ -1096,6 +1102,205 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
10961102 return success ();
10971103}
10981104
1105+ // ===----------------------------------------------------------------------===//
1106+ // SwitchOp
1107+ // ===----------------------------------------------------------------------===//
1108+
1109+ // / Parse the case regions and values.
1110+ static ParseResult
1111+ parseSwitchCases (OpAsmParser &parser, DenseI64ArrayAttr &cases,
1112+ SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
1113+ SmallVector<int64_t > caseValues;
1114+ while (succeeded (parser.parseOptionalKeyword (" case" ))) {
1115+ int64_t value;
1116+ Region ®ion = *caseRegions.emplace_back (std::make_unique<Region>());
1117+
1118+ if (parser.parseInteger (value) || parser.parseColon () ||
1119+ parser.parseRegion (region, /* arguments=*/ {}))
1120+ return failure ();
1121+ caseValues.push_back (value);
1122+ }
1123+ cases = parser.getBuilder ().getDenseI64ArrayAttr (caseValues);
1124+ return success ();
1125+ }
1126+
1127+ // / Print the case regions and values.
1128+ static void printSwitchCases (OpAsmPrinter &parser, Operation *op,
1129+ DenseI64ArrayAttr cases, RegionRange caseRegions) {
1130+ for (auto [value, region] : llvm::zip (cases.asArrayRef (), caseRegions)) {
1131+ parser.printNewline ();
1132+ parser << " case " << value << " : " ;
1133+ parser.printRegion (*region, /* printEntryBlockArgs=*/ false );
1134+ }
1135+ return ;
1136+ }
1137+
1138+ ParseResult SwitchOp::parse (OpAsmParser &parser, OperationState &result) {
1139+ OpAsmParser::UnresolvedOperand arg;
1140+ DenseI64ArrayAttr casesAttr;
1141+ SmallVector<std::unique_ptr<Region>, 2 > caseRegionsRegions;
1142+ std::unique_ptr<Region> defaultRegionRegion = std::make_unique<Region>();
1143+
1144+ if (parser.parseOperand (arg))
1145+ return failure ();
1146+
1147+ Type argType = parser.getBuilder ().getI32Type ();
1148+ // Parse optional type, else assume i32.
1149+ if (!parser.parseOptionalColon () && parser.parseType (argType))
1150+ return failure ();
1151+
1152+ auto loc = parser.getCurrentLocation ();
1153+ if (parser.parseOptionalAttrDict (result.attributes ))
1154+ return failure ();
1155+
1156+ if (failed (verifyInherentAttrs (result.name , result.attributes , [&]() {
1157+ return parser.emitError (loc)
1158+ << " '" << result.name .getStringRef () << " ' op " ;
1159+ })))
1160+ return failure ();
1161+
1162+ auto odsResult = parseSwitchCases (parser, casesAttr, caseRegionsRegions);
1163+ if (odsResult)
1164+ return failure ();
1165+
1166+ result.getOrAddProperties <SwitchOp::Properties>().cases = casesAttr;
1167+
1168+ if (parser.parseKeyword (" default" ) || parser.parseColon ())
1169+ return failure ();
1170+
1171+ if (parser.parseRegion (*defaultRegionRegion))
1172+ return failure ();
1173+
1174+ result.addRegion (std::move (defaultRegionRegion));
1175+ result.addRegions (caseRegionsRegions);
1176+
1177+ if (parser.resolveOperand (arg, argType, result.operands ))
1178+ return failure ();
1179+
1180+ return success ();
1181+ }
1182+
1183+ void SwitchOp::print (OpAsmPrinter &parser) {
1184+ parser << ' ' ;
1185+ parser << getArg ();
1186+ SmallVector<StringRef, 2 > elidedAttrs;
1187+ elidedAttrs.push_back (" cases" );
1188+ parser.printOptionalAttrDict ((*this )->getAttrs (), elidedAttrs);
1189+ parser << ' ' ;
1190+ printSwitchCases (parser, *this , getCasesAttr (), getCaseRegions ());
1191+ parser.printNewline ();
1192+ parser << " default" ;
1193+ parser << ' ' ;
1194+ parser.printRegion (getDefaultRegion (), /* printEntryBlockArgs=*/ true ,
1195+ /* printBlockTerminators=*/ true );
1196+
1197+ return ;
1198+ }
1199+
1200+ static LogicalResult verifyRegion (emitc::SwitchOp op, Region ®ion,
1201+ const Twine &name) {
1202+ auto yield = dyn_cast<emitc::YieldOp>(region.front ().back ());
1203+ if (!yield)
1204+ return op.emitOpError (" expected region to end with emitc.yield, but got " )
1205+ << region.front ().back ().getName ();
1206+
1207+ if (yield.getNumOperands () != 0 ) {
1208+ return (op.emitOpError (" expected each region to return " )
1209+ << " 0 values, but " << name << " returns "
1210+ << yield.getNumOperands ())
1211+ .attachNote (yield.getLoc ())
1212+ << " see yield operation here" ;
1213+ }
1214+ return success ();
1215+ }
1216+
1217+ LogicalResult emitc::SwitchOp::verify () {
1218+ if (!isSwitchOperandType (getArg ().getType ()))
1219+ return emitOpError (" unsupported type " ) << getArg ().getType ();
1220+
1221+ if (getCases ().size () != getCaseRegions ().size ()) {
1222+ return emitOpError (" has " )
1223+ << getCaseRegions ().size () << " case regions but "
1224+ << getCases ().size () << " case values" ;
1225+ }
1226+
1227+ DenseSet<int64_t > valueSet;
1228+ for (int64_t value : getCases ())
1229+ if (!valueSet.insert (value).second )
1230+ return emitOpError (" has duplicate case value: " ) << value;
1231+
1232+ if (failed (verifyRegion (*this , getDefaultRegion (), " default region" )))
1233+ return failure ();
1234+
1235+ for (auto [idx, caseRegion] : llvm::enumerate (getCaseRegions ()))
1236+ if (failed (verifyRegion (*this , caseRegion, " case region #" + Twine (idx))))
1237+ return failure ();
1238+
1239+ return success ();
1240+ }
1241+
1242+ unsigned emitc::SwitchOp::getNumCases () { return getCases ().size (); }
1243+
1244+ Block &emitc::SwitchOp::getDefaultBlock () { return getDefaultRegion ().front (); }
1245+
1246+ Block &emitc::SwitchOp::getCaseBlock (unsigned idx) {
1247+ assert (idx < getNumCases () && " case index out-of-bounds" );
1248+ return getCaseRegions ()[idx].front ();
1249+ }
1250+
1251+ void SwitchOp::getSuccessorRegions (
1252+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
1253+ llvm::copy (getRegions (), std::back_inserter (successors));
1254+ return ;
1255+ }
1256+
1257+ void SwitchOp::getEntrySuccessorRegions (
1258+ ArrayRef<Attribute> operands,
1259+ SmallVectorImpl<RegionSuccessor> &successors) {
1260+ FoldAdaptor adaptor (operands, *this );
1261+
1262+ // If a constant was not provided, all regions are possible successors.
1263+ auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg ());
1264+ if (!arg) {
1265+ llvm::copy (getRegions (), std::back_inserter (successors));
1266+ return ;
1267+ }
1268+
1269+ // Otherwise, try to find a case with a matching value. If not, the
1270+ // default region is the only successor.
1271+ for (auto [caseValue, caseRegion] : llvm::zip (getCases (), getCaseRegions ())) {
1272+ if (caseValue == arg.getInt ()) {
1273+ successors.emplace_back (&caseRegion);
1274+ return ;
1275+ }
1276+ }
1277+ successors.emplace_back (&getDefaultRegion ());
1278+ return ;
1279+ }
1280+
1281+ void SwitchOp::getRegionInvocationBounds (
1282+ ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
1283+ auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front ());
1284+ if (!operandValue) {
1285+ // All regions are invoked at most once.
1286+ bounds.append (getNumRegions (), InvocationBounds (/* lb=*/ 0 , /* ub=*/ 1 ));
1287+ return ;
1288+ }
1289+
1290+ unsigned liveIndex = getNumRegions () - 1 ;
1291+ const auto *iteratorToInt = llvm::find (getCases (), operandValue.getInt ());
1292+
1293+ liveIndex = iteratorToInt != getCases ().end ()
1294+ ? std::distance (getCases ().begin (), iteratorToInt)
1295+ : liveIndex;
1296+
1297+ for (unsigned regIndex = 0 , regNum = getNumRegions (); regIndex < regNum;
1298+ ++regIndex)
1299+ bounds.emplace_back (/* lb=*/ 0 , /* ub=*/ regIndex == liveIndex);
1300+
1301+ return ;
1302+ }
1303+
10991304// ===----------------------------------------------------------------------===//
11001305// TableGen'd op method definitions
11011306// ===----------------------------------------------------------------------===//
0 commit comments