diff --git a/examples/exn_output.pdl b/examples/exn_output.pdl new file mode 100644 index 0000000..e69de29 diff --git a/src/main/scala/pipedsl/Main.scala b/src/main/scala/pipedsl/Main.scala index f566a83..42d1e38 100644 --- a/src/main/scala/pipedsl/Main.scala +++ b/src/main/scala/pipedsl/Main.scala @@ -103,8 +103,10 @@ object Main { val specChecker = new SpeculationChecker(ctx) specChecker.check(recvProg, None) val lock_prog = LockOpTranslationPass.run(recvProg) - TimingTypeChecker.check(lock_prog, Some(basetypes)) - val exnprog = ExnTranslationPass.run(lock_prog) + TimingTypeChecker.check(lock_prog, Some(basetypes)) + val exnTranslationPass = new ExnTranslationPass() + val exnprog = exnTranslationPass.run(lock_prog) + new PrettyPrinter(None).printProgram(exnprog) if (printOutput) { val writer = new PrintWriter(outputFile) writer.write("Passed") diff --git a/src/main/scala/pipedsl/codegen/bsv/BSVPrettyPrinter.scala b/src/main/scala/pipedsl/codegen/bsv/BSVPrettyPrinter.scala index 4a690b8..e8b771f 100644 --- a/src/main/scala/pipedsl/codegen/bsv/BSVPrettyPrinter.scala +++ b/src/main/scala/pipedsl/codegen/bsv/BSVPrettyPrinter.scala @@ -110,8 +110,9 @@ object BSVPrettyPrinter { case BTruncate(e) => mkExprString("truncate(", toBSVExprStr(e), ")") case BStructAccess(rec, field) => toBSVExprStr(rec) + "." + toBSVExprStr(field) case BVar(name, _) => name - case BBOp(op, lhs, rhs, isInfix) if isInfix => mkExprString("(", toBSVExprStr(lhs), op, toBSVExprStr(rhs), ")") - case BBOp(op, lhs, rhs, isInfix) if !isInfix => mkExprString( op + "(", toBSVExprStr(lhs), ",", toBSVExprStr(rhs), ")") + case BBOp(op, lhs, rhs, isInfix, omitBrackets) if isInfix && !omitBrackets => mkExprString("(", toBSVExprStr(lhs), op, toBSVExprStr(rhs), ")") + case BBOp(op, lhs, rhs, isInfix, omitBrackets) if isInfix && omitBrackets => mkExprString(toBSVExprStr(lhs), op, toBSVExprStr(rhs)) + case BBOp(op, lhs, rhs, isInfix, _) if !isInfix => mkExprString( op + "(", toBSVExprStr(lhs), ",", toBSVExprStr(rhs), ")") case BUOp(op, expr) => mkExprString("(", op, toBSVExprStr(expr), ")") //TODO incorporate bit types into the typesystem properly //and then remove the custom pack/unpack operations @@ -238,10 +239,10 @@ object BSVPrettyPrinter { } def printBSVRule(rule: BRuleDef): Unit = { - val condString = if (rule.conds.nonEmpty) { - "(" + rule.conds.map(c => toBSVExprStr(c)).mkString(" && ") + ")" - } else { - "" + val condString = rule.conds match { + case BDontCare => "" + case _ => toBSVExprStr(rule.conds) + } w.write(mkStatementString("rule", rule.name, condString)) incIndent() diff --git a/src/main/scala/pipedsl/codegen/bsv/BSVSyntax.scala b/src/main/scala/pipedsl/codegen/bsv/BSVSyntax.scala index 59b5857..b33a472 100644 --- a/src/main/scala/pipedsl/codegen/bsv/BSVSyntax.scala +++ b/src/main/scala/pipedsl/codegen/bsv/BSVSyntax.scala @@ -452,7 +452,7 @@ object BSVSyntax { case class BStructLit(typ: BStruct, fields: Map[BVar, BExpr]) extends BExpr case class BStructAccess(rec: BExpr, field: BExpr) extends BExpr case class BVar(name: String, typ: BSVType) extends BExpr - case class BBOp(op: String, lhs: BExpr, rhs: BExpr, isInfix: Boolean = true) extends BExpr + case class BBOp(op: String, lhs: BExpr, rhs: BExpr, isInfix: Boolean = true, omitBrackets: Boolean = false) extends BExpr case class BUOp(op: String, expr: BExpr) extends BExpr case class BBitExtract(expr: BExpr, start: BIndex, end: BIndex) extends BExpr case class BConcat(first: BExpr, rest: List[BExpr]) extends BExpr @@ -502,7 +502,7 @@ object BSVSyntax { case class BStructDef(typ: BStruct, derives: List[String]) - case class BRuleDef(name: String, conds: List[BExpr], body: List[BStatement]) + case class BRuleDef(name: String, conds: BExpr, body: List[BStatement]) case class BMethodSig(name: String, typ: MethodType, params: List[BVar]) diff --git a/src/main/scala/pipedsl/codegen/bsv/BluespecGeneration.scala b/src/main/scala/pipedsl/codegen/bsv/BluespecGeneration.scala index cb97fc6..8efbaf0 100644 --- a/src/main/scala/pipedsl/codegen/bsv/BluespecGeneration.scala +++ b/src/main/scala/pipedsl/codegen/bsv/BluespecGeneration.scala @@ -1,6 +1,6 @@ package pipedsl.codegen.bsv -import BSVSyntax.{BBOp, _} +import BSVSyntax.{BBOp, BEmpty, _} import pipedsl.common.DAGSyntax.{PStage, PipelineEdge} import pipedsl.common.Errors.{UnexpectedCommand, UnexpectedExpr} import pipedsl.common.LockImplementation.{LockInterface, MethodInfo} @@ -398,7 +398,7 @@ object BluespecGeneration { private val outputData = BVar("data", translator.toType(mod.ret.getOrElse(TVoid()))) private val outputQueue = BVar("outputQueue", bsInts.getOutputQType(threadIdVar.typ, outputData.typ)) //Registers for exceptions - private val globalExnFlag = BVar("globalExnFlag", bsInts.getRegType(BBool)) + private val globalExnFlag = BVar("_globalExnFlag", bsInts.getRegType(BBool)) //Data types for passing between stages private val edgeStructInfo = getEdgeStructInfo(otherStages, addTId = true, addSpecId = mod.maybeSpec) //First stage should have exactly one input edge by definition @@ -432,7 +432,7 @@ object BluespecGeneration { vars + (m.name -> nvar) case _ => vars } - }) + (mod.name -> BVar("_lock_" + mod.name, translator.getLockedModType(LockImplementation.getModLockImpl))) + }) private val lockRegions: LockInfo = mod.modules.foldLeft[LockInfo](Map())((locks, m) => { locks + (m.name -> BVar(genLockRegionName(m.name), @@ -550,6 +550,19 @@ object BluespecGeneration { } private var stgSpecOrder: Int = 0 + + private def getMergedCond(lhs: BExpr, rhs: BExpr): BExpr = { + if (lhs == BDontCare && rhs == BDontCare){ + BDontCare + } else if (lhs == BDontCare){ + rhs + } else if (rhs == BDontCare){ + lhs + } else { + BBOp("&&", lhs, rhs, isInfix = true, omitBrackets = true) + } + } + /** * Given a pipeline stage and the necessary edge info, * generate a BSV module definition. @@ -570,7 +583,7 @@ object BluespecGeneration { val execRule = getStageRule(stg) //Add a stage kill rule if it needs one val killRule = getStageKillRule(stg) - val rules = if (mod.maybeSpec && killRule.isDefined) List(execRule, killRule.get) else List(execRule) + val rules = if ((mod.maybeSpec && killRule.isDefined) || (is_excepting(mod) && killRule.isDefined)) List(execRule, killRule.get) else List(execRule) stgSpecOrder = 0 translator.setVariablePrefix("") (sBody, rules) @@ -595,7 +608,7 @@ object BluespecGeneration { BDisplay(Some(mod.name.v + ":Thread %d:Executing Stage " + stg.name + " %t"), List(translator.toBSVVar(threadIdVar), BTime)) } else BEmpty - BRuleDef( genParamName(stg) + "_execute", blockingConds ++ recvConds, + BRuleDef( genParamName(stg) + "_execute", getMergedCond(blockingConds, recvConds), writeCmdDecls ++ writeCmdStmts ++ queueStmts :+ debugStmt) } @@ -608,46 +621,40 @@ object BluespecGeneration { */ private def getStageKillRule(stg: PStage): Option[BRuleDef] = { val killConds = getKillConds(stg.getCmds) - if (killConds.isEmpty) { - None - } else { - val recvConds = getRecvConds(stg.getCmds) - val debugStmt = if (debug) { - BDisplay(Some(mod.name.v + ":SpecId %d: Killing Stage " + stg.name + "%t"), - List(getSpecIdVal, BTime)) - } else { BEmpty } - val deqStmts = getEdgeQueueStmts(stg, stg.inEdges) ++ getRecvCmds(stg.getCmds) - val freeStmt = BExprStmt(bsInts.getSpecFree(specTable, getSpecIdVal)) - Some(BRuleDef( genParamName(stg) + "_kill", killConds ++ recvConds, deqStmts :+ freeStmt :+ debugStmt)) - } + val recvConds = getRecvConds(stg.getCmds) + val debugStmt = if (debug) { + BDisplay(Some(mod.name.v + ":SpecId %d: Killing Stage " + stg.name + "%t"), + List(getSpecIdVal, BTime)) + } else { BEmpty } + val deqStmts = getEdgeQueueStmts(stg, stg.inEdges) ++ getRecvCmds(stg.getCmds) + val freeStmt = BExprStmt(bsInts.getSpecFree(specTable, getSpecIdVal)) + Some(BRuleDef( genParamName(stg) + "_kill", getMergedCond(killConds, recvConds), deqStmts :+ freeStmt :+ debugStmt)) } - private def getKillConds(cmds: Iterable[Command]): List[BExpr] = { - cmds.foldLeft(List[BExpr]())((l, c) => c match { - //check definitely misspeculated - // isValid(spec) && !fromMaybe(True, check(spec)) + private def getKillConds(cmds: Iterable[Command]): BExpr = { + var resultingConds: BExpr = BDontCare + var readGlobalExnFlag: Boolean = false + cmds.foreach(c => c match { case CCheckSpec(_) => - l :+ BBOp("&&", BIsValid(translator.toBSVVar(specIdVar)), + resultingConds = getMergedCond(resultingConds, BBOp("&&", BIsValid(translator.toBSVVar(specIdVar)), //order is LATE if stage has no update BUOp("!", BFromMaybe(BBoolLit(true), - bsInts.getSpecCheck(specTable, getSpecIdVal, stgSpecOrder)))) + bsInts.getSpecCheck(specTable, getSpecIdVal, stgSpecOrder))))) //also need these in case we're waiting on responses we need to dequeue case IStageClear() => - l :+ globalExnFlag - + readGlobalExnFlag = true case ICondCommand(cond, cs) => - val condconds = getKillConds(cs) - if (condconds.nonEmpty) { - val nestedConds = condconds.tail.foldLeft(condconds.head)((exp, n) => { - BBOp("&&", exp, n) - }) - val newCond = BBOp("||", BUOp("!", translator.toExpr(cond)), nestedConds) - l :+ newCond - } else { - l - } - case _ => l + val nestedConds = getKillConds(cs) + val newCond = BBOp("||", BUOp("!", translator.toExpr(cond)), nestedConds) + resultingConds = getMergedCond(resultingConds, newCond) + case _ => BDontCare }) + + if(readGlobalExnFlag){ + resultingConds = BBOp("||", globalExnFlag, resultingConds) + } + + resultingConds } private def translateMethod(mod: BVar, mi: MethodInfo): BMethodInvoke = { @@ -670,115 +677,107 @@ object BluespecGeneration { * @param cmds The list of commands to translate * @return The list of translated blocking commands */ - private def getBlockingConds(cmds: Iterable[Command]): List[BExpr] = { - cmds.foldLeft(List[BExpr]())((l, c) => c match { + private def getBlockingConds(cmds: Iterable[Command]): BExpr = { + var resultingConds: BExpr = BDontCare + cmds.foreach(c => c match { case CLockStart(mod) => - l :+ bsInts.getCheckStart(lockRegions(mod)) + resultingConds = getMergedCond(resultingConds, bsInts.getCheckStart(lockRegions(mod))) case cl@IReserveLock(_, mem) => val methodInfo = LockImplementation.getCanReserveInfo(cl) if (methodInfo.isDefined) { - l :+ translateMethod(modParams(mem.id), methodInfo.get) + resultingConds = getMergedCond(resultingConds, translateMethod(modParams(mem.id), methodInfo.get)) } else { - l + resultingConds } case cl@ICheckLockOwned(mem, _, _) => val methodInfo = LockImplementation.getBlockInfo(cl) if (methodInfo.isDefined) { - l :+ translateMethod(getLockName(mem.id), methodInfo.get) + resultingConds = getMergedCond(resultingConds, translateMethod(getLockName(mem.id), methodInfo.get)) } else { - l + resultingConds } case im@IMemWrite(mem, addr, data, _, _, isAtomic) if isAtomic => val methodInfo = LockImplementation.getCanAtomicWrite(mem, addr, data, im.portNum) if (methodInfo.isDefined) { - l :+ translateMethod(getLockName(mem), methodInfo.get) + resultingConds = getMergedCond(resultingConds, translateMethod(getLockName(mem), methodInfo.get)) } else { - l + resultingConds } case im@IMemSend(_, _, mem, data, addr, _, _, isAtomic) if isAtomic => val methodInfo = LockImplementation.getCanAtomicAccess(mem, addr, data, im.portNum) if (methodInfo.isDefined) { - l :+ translateMethod(getLockName(mem), methodInfo.get) + resultingConds = getMergedCond(resultingConds, translateMethod(getLockName(mem), methodInfo.get)) } else { - l + resultingConds } //these are just to find EMemAccesses that are also atomic - case CAssign(_, rhs) => l ++ getBlockConds(rhs) - case CRecv(_, rhs) => l ++ getBlockConds(rhs) - case COutput(_) => l :+ bsInts.getOutCanWrite(outputQueue, translator.toBSVVar(threadIdVar)) + case CAssign(_, rhs) => resultingConds = getMergedCond(resultingConds, getBlockConds(rhs)) + case CRecv(_, rhs) => resultingConds = getMergedCond(resultingConds, getBlockConds(rhs)) + case COutput(_) => resultingConds = getMergedCond(resultingConds, bsInts.getOutCanWrite(outputQueue, translator.toBSVVar(threadIdVar))) //Execute ONLY if check(specid) == Valid(True) && isValid(specid) // fromMaybe(False, check(specId)) <=> check(specid) == Valid(True) - case CCheckSpec(isBlocking) if isBlocking => l ++ List( + case CCheckSpec(isBlocking) if isBlocking => resultingConds = getMergedCond(resultingConds, BBOp("||", BUOp("!", BIsValid(translator.toBSVVar(specIdVar))), BFromMaybe(BBoolLit(false), bsInts.getSpecCheck(specTable, getSpecIdVal, stgSpecOrder)) ) ) //Execute if check(specid) != Valid(False) //fromMaybe(True, check(specId)) <=> check(specid) == (Valid(True) || Invalid) - case CCheckSpec(isBlocking) if !isBlocking => l ++ List( + case CCheckSpec(isBlocking) if !isBlocking => resultingConds = getMergedCond(resultingConds, BBOp("||", BUOp("!", BIsValid(translator.toBSVVar(specIdVar))), //order is LATE if stage has no update BFromMaybe(BBoolLit(true), bsInts.getSpecCheck(specTable, getSpecIdVal, stgSpecOrder)) ) ) case ICondCommand(cond, cs) => - val condconds = getBlockingConds(cs) - if (condconds.nonEmpty) { - val nestedConds = condconds.tail.foldLeft(condconds.head)((exp, n) => { - BBOp("&&", exp, n) - }) - val newCond = BBOp("||", BUOp("!", translator.toExpr(cond)), nestedConds) - l :+ newCond - } else { - l - } - case _ => l + val nestedConds = getMergedCond(getBlockingConds(cs), BDontCare) + val newCond = BBOp("||", BUOp("!", translator.toExpr(cond)), nestedConds) + resultingConds = getMergedCond(resultingConds, newCond) + case _ => resultingConds }) + resultingConds } //only necessary to find atomic Reads and get their blocking conds - private def getBlockConds(e: Expr): List[BExpr] = e match { + private def getBlockConds(e: Expr): BExpr = e match { case EIsValid(ex) => getBlockConds(ex) case EFromMaybe(ex) => getBlockConds(ex) case EToMaybe(ex) => getBlockConds(ex) case EUop(_, ex) => getBlockConds(ex) - case EBinop(_, e1, e2) => getBlockConds(e1) ++ getBlockConds(e2) + case EBinop(_, e1, e2) => getMergedCond(getBlockConds(e1), getBlockConds(e2)) case em@EMemAccess(mem, index, _, _, _, isAtomic) if isAtomic => val methodInfo = LockImplementation.getCanAtomicRead(mem, index, em.portNum) if (methodInfo.isDefined) { - List(translateMethod(getLockName(mem), methodInfo.get)) - } else List() + translateMethod(getLockName(mem), methodInfo.get) + } else BDontCare case EBitExtract(num, _, _) => getBlockConds(num) //can't appear in cond of ternary - case ETernary(_, tval, fval) => getBlockConds(tval) ++ getBlockConds(fval) - case EApp(_, args) => args.foldLeft(List[BExpr]())((l, a) => { - l ++ getBlockConds(a) - }) - case ECall(_, _, args, isAtomic) => args.foldLeft(List[BExpr]())((l, a) => { - l ++ getBlockConds(a) - }) + case ETernary(_, tval, fval) => getMergedCond(getBlockConds(tval), getBlockConds(fval)) + case EApp(_, args) => + var resultingConds: BExpr = BDontCare + args.foreach(a => getMergedCond(resultingConds, getBlockConds(a))) + resultingConds + case ECall(_, _, args, isAtomic) => + var resultingConds: BExpr = BDontCare + args.foreach(a => getMergedCond(resultingConds, getBlockConds(a))) + resultingConds case ECast(_, exp) => getBlockConds(exp) - case _ => List() + case _ => BDontCare } - private def getRecvConds(cmds: Iterable[Command]): List[BExpr] = { - cmds.foldLeft(List[BExpr]())((l, c) => c match { + private def getRecvConds(cmds: Iterable[Command]): BExpr = { + var resultingConds: BExpr = BDontCare + cmds.foreach(c => c match { case IMemRecv(mem: Id, handle: EVar, _: Option[EVar]) => - l :+ bsInts.getCheckMemResp(modParams(mem), translator.toVar(handle), c.portNum, isLockedMemory(mem)) + resultingConds = getMergedCond(resultingConds, bsInts.getCheckMemResp(modParams(mem), translator.toVar(handle), c.portNum, isLockedMemory(mem))) case IRecv(handle, sender, _) => - l :+ bsInts.getModCheckHandle(modParams(sender), translator.toExpr(handle)) + resultingConds = getMergedCond(resultingConds, bsInts.getModCheckHandle(modParams(sender), translator.toExpr(handle))) case ICondCommand(cond, cs) => - val condconds = getRecvConds(cs) - if (condconds.nonEmpty) { - val nestedConds = condconds.tail.foldLeft(condconds.head)((exp, n) => { - BBOp("&&", exp, n) - }) - val newCond = BBOp("||", BUOp("!", translator.toExpr(cond)), nestedConds) - l :+ newCond - } else { - l - } - case _ => l + val nestedConds = getMergedCond(getRecvConds(cs), BDontCare) + val newCond = BBOp("||", BUOp("!", translator.toExpr(cond)), nestedConds) + resultingConds = getMergedCond(resultingConds, newCond) + case _ => resultingConds }) + resultingConds } /** @@ -1375,6 +1374,7 @@ object BluespecGeneration { case CAssign(_, _) => None case CExpr(_) => None case CEmpty() => None + case ISetGlobalExnFlag(state) => Some(BModAssign(globalExnFlag, BBoolLit(state))) case _: IStageClear => None case _: InternalCommand => throw UnexpectedCommand(cmd) case CRecv(_, _) => throw UnexpectedCommand(cmd) diff --git a/src/main/scala/pipedsl/codegen/bsv/BluespecInterfaces.scala b/src/main/scala/pipedsl/codegen/bsv/BluespecInterfaces.scala index b0f62ff..69894fc 100644 --- a/src/main/scala/pipedsl/codegen/bsv/BluespecInterfaces.scala +++ b/src/main/scala/pipedsl/codegen/bsv/BluespecInterfaces.scala @@ -39,21 +39,21 @@ class BluespecInterfaces() { val debugStart = if (debug) { BDisplay(Some("Starting Pipeline %t"), List(BTime)) } else BEmpty val initRule = BRuleDef( name = "initTB", - conds = List(initCond), + conds = initCond, body = initStmts :+ setStartReg :+ debugStart ) val timerRule = BRuleDef( name = "timerCount", - conds = List(), + conds = BDontCare, body = List(BModAssign(timerReg, BBOp("+", timerReg, BOne))) ) val timerDone = BBOp(">=", timerReg, BIntLit(1000000,10,32)) val doneConds = if (modDone.isEmpty) { - List(timerDone) + timerDone } else { - List(BBOp("||", timerDone, modDone.reduce((l, r) => { + BBOp("||", timerDone, modDone.reduce((l, r) => { BBOp("&&", l, r)} - ))) + )) } val doneRule = BRuleDef( name = "stopTB", diff --git a/src/main/scala/pipedsl/common/Syntax.scala b/src/main/scala/pipedsl/common/Syntax.scala index 5af97c3..1efc184 100644 --- a/src/main/scala/pipedsl/common/Syntax.scala +++ b/src/main/scala/pipedsl/common/Syntax.scala @@ -39,6 +39,9 @@ object Syntax { sealed trait SpeculativeAnnotation { var maybeSpec: Boolean = false } + sealed trait ExceptionAnnotation { + var isExcepting: Boolean = false + } sealed trait LockInfoAnnotation { var memOpType: Option[LockType] = None var granularity: LockGranularity = General @@ -649,6 +652,7 @@ object Syntax { case class IAbort(mem: Id) extends InternalCommand case class IStageClear() extends InternalCommand + case class ISetGlobalExnFlag(state: Boolean) extends InternalCommand case class ICondCommand(cond: Expr, cs: List[Command]) extends InternalCommand case class IUpdate(specId: Id, value: EVar, originalSpec: EVar) extends InternalCommand case class ICheck(specId: Id, value: EVar) extends InternalCommand @@ -695,6 +699,7 @@ object Syntax { def map(f : Command => Command) : ExceptBlock def foreach(f : Command => Unit) :Unit def get :Command + def args : List[Id] def copyMeta(other :ExceptBlock) = this.setPos(other.pos) } @@ -703,12 +708,14 @@ object Syntax { { override def map(f: Command => Command): ExceptBlock = this override def foreach(f: Command => Unit): Unit = () + override def args : List[Id] = throw new NoSuchElementException("EmptyExcept") override def get :Command = throw new NoSuchElementException("EmptyExcept") } - case class ExceptFull(args: List[Id], c: Command) extends ExceptBlock + case class ExceptFull(exn_args: List[Id], c: Command) extends ExceptBlock { override def map(f: Command => Command): ExceptBlock = ExceptFull(args, f(c)).copyMeta(this) override def foreach(f: Command => Unit): Unit = f(c) + override def args : List[Id] = exn_args override def get :Command = c } @@ -719,12 +726,13 @@ object Syntax { body: Command, commit_blk: Option[Command], except_blk: ExceptBlock) - extends Definition with RecursiveAnnotation with SpeculativeAnnotation with HasCopyMeta + extends Definition with RecursiveAnnotation with SpeculativeAnnotation with ExceptionAnnotation with HasCopyMeta { override val copyMeta: HasCopyMeta => ModuleDef = { case from :ModuleDef => maybeSpec = from.maybeSpec + isExcepting = from.isExcepting isRecursive = from.isRecursive pos = from.pos this diff --git a/src/main/scala/pipedsl/common/Utilities.scala b/src/main/scala/pipedsl/common/Utilities.scala index f8ebcf9..a48096b 100644 --- a/src/main/scala/pipedsl/common/Utilities.scala +++ b/src/main/scala/pipedsl/common/Utilities.scala @@ -198,6 +198,7 @@ object Utilities { }) case ILockNoOp(_) => Set() case IStageClear() => Set() + case ISetGlobalExnFlag(_) => Set() case IAbort(_) => Set() case CLockStart(_) => Set() case CLockEnd(_) => Set() diff --git a/src/main/scala/pipedsl/passes/ExnTranslationPass.scala b/src/main/scala/pipedsl/passes/ExnTranslationPass.scala index 952b903..c269229 100644 --- a/src/main/scala/pipedsl/passes/ExnTranslationPass.scala +++ b/src/main/scala/pipedsl/passes/ExnTranslationPass.scala @@ -4,8 +4,11 @@ import pipedsl.common.Syntax._ import pipedsl.passes.Passes.{ModulePass, ProgPass} import pipedsl.typechecker.BaseTypeChecker.replaceNamedType -object ExnTranslationPass extends ModulePass[ModuleDef] with ProgPass[Prog]{ - private var exnArgMap = Map[Id, Id]() +class ExnTranslationPass extends ModulePass[ModuleDef] with ProgPass[Prog]{ + private var exnArgIdMap = Map[Id, Id]() + private var exnArgTypeMap = Map[Id, Type]() + + private val localExnFlag = EVar(Id("_localExnFlag")) override def run(m: ModuleDef): ModuleDef = { @@ -22,29 +25,38 @@ object ExnTranslationPass extends ModulePass[ModuleDef] with ProgPass[Prog]{ val new_m = addExnVars(m) new_m.name.typ = m.name.typ val modified_exnblk = m.except_blk.map(convertExnArgsId) - createNewStg(new_m.copy(body = new_m.body, commit_blk = new_m.commit_blk, except_blk = modified_exnblk)) + createNewStg(new_m.copy(body = new_m.body, commit_blk = new_m.commit_blk, except_blk = modified_exnblk).copyMeta(m)) } override def run(p: Prog): Prog = p.copy(moddefs = p.moddefs.map(m => run(m))) def addExnVars(m: ModuleDef): ModuleDef = { + localExnFlag.typ = Some(TBool()) + localExnFlag.id.typ = localExnFlag.typ val fixed_except = m.except_blk match { case ExceptFull(args, c) => var arg_count = 0 args.foreach(arg => { - exnArgMap = exnArgMap + (arg -> Id("_exnArg_"+arg_count.toString())) - arg_count += 1 + arg.typ match { + case Some(t: Type) => + val newExnArgId = Id("_exnArg_"+arg_count.toString()) + arg_count += 1 + exnArgIdMap = exnArgIdMap + (arg -> newExnArgId) + exnArgTypeMap = exnArgTypeMap + (newExnArgId -> t) + case _ => + arg_count += 1 + } }) case ExceptEmpty() => CEmpty() } - val newAssignments = exnArgMap.foldLeft(CSeq(CEmpty(), CEmpty()))((c, id_mapping) => { + val newAssignments = exnArgIdMap.foldLeft(CSeq(CEmpty(), CEmpty()))((c, id_mapping) => { val (lhs, rhs) = id_mapping val setCurrArg = CAssign(EVar(lhs), EBool(false)) CSeq(c, setCurrArg) }) - m.copy(body = CSeq(IStageClear(), convertPrimitives(m.body)), commit_blk = m.commit_blk, except_blk = m.except_blk) + m.copy(body = CSeq(IStageClear(), convertPrimitives(m.body)), commit_blk = m.commit_blk, except_blk = m.except_blk).copyMeta(m) } def convertPrimitives(c: Command): Command = { @@ -53,24 +65,23 @@ object ExnTranslationPass extends ModulePass[ModuleDef] with ProgPass[Prog]{ case CIf(cond, cons, alt) => CIf(cond, convertPrimitives(cons), convertPrimitives(alt)).copyMeta(c) case CTBar(c1, c2) => CTBar(convertPrimitives(c1), CSeq(IStageClear(), convertPrimitives(c2))).copyMeta(c) case CSplit(cases, default) => - val newCases = List[CaseObj]() - val newDefault = convertPrimitives(default) - for (index <- cases.indices) { - val newBody = convertPrimitives(cases(index).body) - newCases :+ cases(index).copy(body = newBody) - } - CSplit(newCases, newDefault) + val newCases = cases.map(c => CaseObj(c.cond, convertPrimitives(c.body))) + CSplit(newCases, convertPrimitives(default)).copyMeta(c) case CExcept(args) => - val localflag = EVar(Id("_localExnFlag")) - localflag.typ = Some(TBool()) - localflag.id.typ = localflag.typ - - val setLocalErrFlag = CAssign(localflag, EBool(true)).copyMeta(c) + val setLocalErrFlag = CAssign(localExnFlag, EBool(true)).copyMeta(c) var arg_count = 0 val setArgs: Command = args.foldLeft[Command](CSeq(setLocalErrFlag, CEmpty()))((c, arg) => { - val setCurrArg = CAssign(EVar(Id("_exnArg_"+arg_count.toString())), arg).copyMeta(c) - arg_count += 1 - CSeq(c, setCurrArg).copyMeta(c) + arg.typ match { + case Some(t: Type) => + val translatedVarId = Id("_exnArg_"+arg_count.toString()) + translatedVarId.setType(exnArgTypeMap.getOrElse(translatedVarId, TVoid())) + val translatedVar = EVar(translatedVarId) + translatedVar.typ = translatedVarId.typ + val setCurrArg = CAssign(translatedVar, arg).copyMeta(c) + arg_count += 1 + CSeq(c, setCurrArg).copyMeta(c) + case _ => c + } }) setArgs case _ => c @@ -79,28 +90,27 @@ object ExnTranslationPass extends ModulePass[ModuleDef] with ProgPass[Prog]{ def convertExnArgsId(c: Command): Command = { c match { - case CSeq(c1, c2) => CSeq(convertExnArgsId(c1), convertExnArgsId(c2)) - case CIf(cond, cons, alt) => CIf(cond, convertExnArgsId(cons), convertExnArgsId(alt)) - case CTBar(c1, c2) => CSeq(convertExnArgsId(c1), convertExnArgsId(c2)); + case CSeq(c1, c2) => CSeq(convertExnArgsId(c1), convertExnArgsId(c2)).copyMeta(c) + case CIf(cond, cons, alt) => CIf(cond, convertExnArgsId(cons), convertExnArgsId(alt)).copyMeta(c) + case CTBar(c1, c2) => CSeq(convertExnArgsId(c1), convertExnArgsId(c2)).copyMeta(c); case CSplit(cases, default) => - val newCases = List[CaseObj]() - val newDefault = convertExnArgsId(default) - for (index <- cases.indices) { - val newBody = convertExnArgsId(cases(index).body) - newCases :+ cases(index).copy(body = newBody) - } - CSplit(newCases, newDefault) + val newCases = cases.map(c => CaseObj(c.cond, convertExnArgsId(c.body))) + CSplit(newCases, convertExnArgsId(default)).copyMeta(c) case CAssign(v, exp) => - val newv = EVar(exnArgMap.getOrElse(v.id, v.id)).setPos(v.pos) - CAssign(newv, exp) + val newv = EVar(exnArgIdMap.getOrElse(v.id, v.id)).setPos(v.pos) + newv.typ = Some(exnArgTypeMap.getOrElse(v.id, TVoid())) + CAssign(newv, exp).copyMeta(c) case CPrint(args) => val newArgs = args.foldLeft(List[Expr]())((l, arg) => { arg match { - case EVar(id) => l :+ EVar(exnArgMap.getOrElse(id, id)).setPos(arg.pos) + case EVar(id) => + val newv = EVar(exnArgIdMap.getOrElse(id, id)).setPos(c.pos) + newv.typ = Some(exnArgTypeMap.getOrElse(id, TVoid())) + l :+ newv case _ => l } }) - CPrint(newArgs) + CPrint(newArgs).copyMeta(c) case _ => c } } @@ -112,19 +122,20 @@ object ExnTranslationPass extends ModulePass[ModuleDef] with ProgPass[Prog]{ } val except_stmts = m.except_blk match { case ExceptFull(_, c) => -// m.modules.filter() - CSeq(IAbort(m.name), c) + val setGlobalExnFlag = ISetGlobalExnFlag(true) + val unsetGlobalExnFlag = ISetGlobalExnFlag(false) + val abortStmts = m.modules.foldLeft(CSeq(CEmpty(), CEmpty()))((c, mod) => + mod.typ match { + case TLockedMemType(mem, _, _) => CSeq(c, IAbort(mod.name)) + case _ => c + }) + CSeq(setGlobalExnFlag, CSeq(abortStmts, CTBar(c, unsetGlobalExnFlag))) case ExceptEmpty() => CEmpty() } - val localflag = EVar(Id("_localExnFlag")) - localflag.typ = Some(TBool()) - localflag.id.typ = localflag.typ - val checkLocalFlag = CIf(localflag, commit_stmts, except_stmts) - val newBody = CTBar(m.body, checkLocalFlag) + val checkLocalFlag = CIf(localExnFlag, commit_stmts, except_stmts) + val newBody = CSeq(m.body, checkLocalFlag) - val inputTyps = m.inputs.foldLeft[List[Type]](List())((l, p) => { l :+ p.typ }) //TODO require memory or module types - m.name.typ = Some(TModType(inputTyps, List[Type](), m.ret, Some(m.name))) - m.copy(body = newBody, commit_blk = None, except_blk = ExceptEmpty()) + m.copy(body = newBody, commit_blk = m.commit_blk, except_blk = m.except_blk).copyMeta(m) } } diff --git a/src/main/scala/pipedsl/typechecker/BaseTypeChecker.scala b/src/main/scala/pipedsl/typechecker/BaseTypeChecker.scala index c14f766..a271516 100644 --- a/src/main/scala/pipedsl/typechecker/BaseTypeChecker.scala +++ b/src/main/scala/pipedsl/typechecker/BaseTypeChecker.scala @@ -100,7 +100,15 @@ object BaseTypeChecker extends TypeChecks[Id, Type] { // // } // env.add() - m.except_blk.foreach(checkCommand(m.name, _, bodyEnv)) + if(m.except_blk.isInstanceOf[ExceptFull]){ + val exnenv = m.except_blk.args.foldLeft[Environment[Id, Type]](bodyEnv)((env, arg) => { + arg.typ match { + case Some(t: Type) => env.add(arg, t) + case None => env.add(arg, TVoid()) + } + }) + m.except_blk.foreach(checkCommand(m.name, _, exnenv)) + } outEnv }