@@ -131,6 +131,13 @@ bool mlir::emitc::isPointerWideType(Type type) {
131131 type);
132132}
133133
134+ bool mlir::emitc::isSwitchOperandType (Type type) {
135+ auto intType = llvm::dyn_cast<IntegerType>(type);
136+ return isa<emitc::OpaqueType>(type) ||
137+ (isSupportedIntegerType (type) && intType.getWidth () != 1 &&
138+ intType.getWidth () != 8 );
139+ }
140+
134141// / Check that the type of the initial value is compatible with the operations
135142// / result type.
136143static LogicalResult verifyInitializationAttribute (Operation *op,
@@ -1096,6 +1103,205 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
10961103 return success ();
10971104}
10981105
1106+ // ===----------------------------------------------------------------------===//
1107+ // SwitchOp
1108+ // ===----------------------------------------------------------------------===//
1109+
1110+ // / Parse the case regions and values.
1111+ static ParseResult
1112+ parseSwitchCases (OpAsmParser &parser, DenseI64ArrayAttr &cases,
1113+ SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
1114+ SmallVector<int64_t > caseValues;
1115+ while (succeeded (parser.parseOptionalKeyword (" case" ))) {
1116+ int64_t value;
1117+ Region ®ion = *caseRegions.emplace_back (std::make_unique<Region>());
1118+
1119+ if (parser.parseInteger (value) || parser.parseColon () ||
1120+ parser.parseRegion (region, /* arguments=*/ {}))
1121+ return failure ();
1122+ caseValues.push_back (value);
1123+ }
1124+ cases = parser.getBuilder ().getDenseI64ArrayAttr (caseValues);
1125+ return success ();
1126+ }
1127+
1128+ // / Print the case regions and values.
1129+ static void printSwitchCases (OpAsmPrinter &parser, Operation *op,
1130+ DenseI64ArrayAttr cases, RegionRange caseRegions) {
1131+ for (auto [value, region] : llvm::zip (cases.asArrayRef (), caseRegions)) {
1132+ parser.printNewline ();
1133+ parser << " case " << value << " : " ;
1134+ parser.printRegion (*region, /* printEntryBlockArgs=*/ false );
1135+ }
1136+ return ;
1137+ }
1138+
1139+ ParseResult SwitchOp::parse (OpAsmParser &parser, OperationState &result) {
1140+ OpAsmParser::UnresolvedOperand arg;
1141+ DenseI64ArrayAttr casesAttr;
1142+ SmallVector<std::unique_ptr<Region>, 2 > caseRegionsRegions;
1143+ std::unique_ptr<Region> defaultRegionRegion = std::make_unique<Region>();
1144+
1145+ if (parser.parseOperand (arg))
1146+ return failure ();
1147+
1148+ Type argType;
1149+ // Parse the case's type.
1150+ if (parser.parseColon () || parser.parseType (argType))
1151+ return failure ();
1152+
1153+ auto loc = parser.getCurrentLocation ();
1154+ if (parser.parseOptionalAttrDict (result.attributes ))
1155+ return failure ();
1156+
1157+ if (failed (verifyInherentAttrs (result.name , result.attributes , [&]() {
1158+ return parser.emitError (loc)
1159+ << " '" << result.name .getStringRef () << " ' op " ;
1160+ })))
1161+ return failure ();
1162+
1163+ auto odsResult = parseSwitchCases (parser, casesAttr, caseRegionsRegions);
1164+ if (odsResult)
1165+ return failure ();
1166+
1167+ result.getOrAddProperties <SwitchOp::Properties>().cases = casesAttr;
1168+
1169+ if (parser.parseKeyword (" default" ) || parser.parseColon ())
1170+ return failure ();
1171+
1172+ if (parser.parseRegion (*defaultRegionRegion))
1173+ return failure ();
1174+
1175+ result.addRegion (std::move (defaultRegionRegion));
1176+ result.addRegions (caseRegionsRegions);
1177+
1178+ if (parser.resolveOperand (arg, argType, result.operands ))
1179+ return failure ();
1180+
1181+ return success ();
1182+ }
1183+
1184+ void SwitchOp::print (OpAsmPrinter &parser) {
1185+ parser << ' ' ;
1186+ parser << getArg ();
1187+ SmallVector<StringRef, 2 > elidedAttrs;
1188+ elidedAttrs.push_back (" cases" );
1189+ parser.printOptionalAttrDict ((*this )->getAttrs (), elidedAttrs);
1190+ parser << ' ' ;
1191+ printSwitchCases (parser, *this , getCasesAttr (), getCaseRegions ());
1192+ parser.printNewline ();
1193+ parser << " default" ;
1194+ parser << ' ' ;
1195+ parser.printRegion (getDefaultRegion (), /* printEntryBlockArgs=*/ true ,
1196+ /* printBlockTerminators=*/ true );
1197+
1198+ return ;
1199+ }
1200+
1201+ static LogicalResult verifyRegion (emitc::SwitchOp op, Region ®ion,
1202+ const Twine &name) {
1203+ auto yield = dyn_cast<emitc::YieldOp>(region.front ().back ());
1204+ if (!yield)
1205+ return op.emitOpError (" expected region to end with emitc.yield, but got " )
1206+ << region.front ().back ().getName ();
1207+
1208+ if (yield.getNumOperands () != 0 ) {
1209+ return (op.emitOpError (" expected each region to return " )
1210+ << " 0 values, but " << name << " returns "
1211+ << yield.getNumOperands ())
1212+ .attachNote (yield.getLoc ())
1213+ << " see yield operation here" ;
1214+ }
1215+ return success ();
1216+ }
1217+
1218+ LogicalResult emitc::SwitchOp::verify () {
1219+ if (!isSwitchOperandType (getArg ().getType ()))
1220+ return emitOpError (" unsupported type " ) << getArg ().getType ();
1221+
1222+ if (getCases ().size () != getCaseRegions ().size ()) {
1223+ return emitOpError (" has " )
1224+ << getCaseRegions ().size () << " case regions but "
1225+ << getCases ().size () << " case values" ;
1226+ }
1227+
1228+ DenseSet<int64_t > valueSet;
1229+ for (int64_t value : getCases ())
1230+ if (!valueSet.insert (value).second )
1231+ return emitOpError (" has duplicate case value: " ) << value;
1232+
1233+ if (failed (verifyRegion (*this , getDefaultRegion (), " default region" )))
1234+ return failure ();
1235+
1236+ for (auto [idx, caseRegion] : llvm::enumerate (getCaseRegions ()))
1237+ if (failed (verifyRegion (*this , caseRegion, " case region #" + Twine (idx))))
1238+ return failure ();
1239+
1240+ return success ();
1241+ }
1242+
1243+ unsigned emitc::SwitchOp::getNumCases () { return getCases ().size (); }
1244+
1245+ Block &emitc::SwitchOp::getDefaultBlock () { return getDefaultRegion ().front (); }
1246+
1247+ Block &emitc::SwitchOp::getCaseBlock (unsigned idx) {
1248+ assert (idx < getNumCases () && " case index out-of-bounds" );
1249+ return getCaseRegions ()[idx].front ();
1250+ }
1251+
1252+ void SwitchOp::getSuccessorRegions (
1253+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
1254+ llvm::copy (getRegions (), std::back_inserter (successors));
1255+ return ;
1256+ }
1257+
1258+ void SwitchOp::getEntrySuccessorRegions (
1259+ ArrayRef<Attribute> operands,
1260+ SmallVectorImpl<RegionSuccessor> &successors) {
1261+ FoldAdaptor adaptor (operands, *this );
1262+
1263+ // If a constant was not provided, all regions are possible successors.
1264+ auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg ());
1265+ if (!arg) {
1266+ llvm::copy (getRegions (), std::back_inserter (successors));
1267+ return ;
1268+ }
1269+
1270+ // Otherwise, try to find a case with a matching value. If not, the
1271+ // default region is the only successor.
1272+ for (auto [caseValue, caseRegion] : llvm::zip (getCases (), getCaseRegions ())) {
1273+ if (caseValue == arg.getInt ()) {
1274+ successors.emplace_back (&caseRegion);
1275+ return ;
1276+ }
1277+ }
1278+ successors.emplace_back (&getDefaultRegion ());
1279+ return ;
1280+ }
1281+
1282+ void SwitchOp::getRegionInvocationBounds (
1283+ ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
1284+ auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front ());
1285+ if (!operandValue) {
1286+ // All regions are invoked at most once.
1287+ bounds.append (getNumRegions (), InvocationBounds (/* lb=*/ 0 , /* ub=*/ 1 ));
1288+ return ;
1289+ }
1290+
1291+ unsigned liveIndex = getNumRegions () - 1 ;
1292+ const auto *iteratorToInt = llvm::find (getCases (), operandValue.getInt ());
1293+
1294+ liveIndex = iteratorToInt != getCases ().end ()
1295+ ? std::distance (getCases ().begin (), iteratorToInt)
1296+ : liveIndex;
1297+
1298+ for (unsigned regIndex = 0 , regNum = getNumRegions (); regIndex < regNum;
1299+ ++regIndex)
1300+ bounds.emplace_back (/* lb=*/ 0 , /* ub=*/ regIndex == liveIndex);
1301+
1302+ return ;
1303+ }
1304+
10991305// ===----------------------------------------------------------------------===//
11001306// TableGen'd op method definitions
11011307// ===----------------------------------------------------------------------===//
0 commit comments