Skip to content

Commit

Permalink
fixed bug in code order generation for nested if statements
Browse files Browse the repository at this point in the history
  • Loading branch information
dz333 committed Sep 27, 2021
1 parent 388d3c7 commit 5d4bb6c
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 34 deletions.
11 changes: 8 additions & 3 deletions src/main/scala/pipedsl/common/DAGSyntax.scala
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,15 @@ object DAGSyntax {
}

/**
* Add the new commands to the end of this stage.
* Add the new commands either to the end or beginning of this stage.
*/
def mergeStmts(newCmds: Iterable[Command]): Unit = {
this.setCmds(this.getCmds ++ newCmds)
def mergeStmts(newCmds: Iterable[Command], addAfter: Boolean = true): Unit = {
val finalCmds = if (addAfter) {
this.getCmds ++ newCmds
} else {
newCmds ++ this.getCmds
}
this.setCmds(finalCmds)
}
}

Expand Down
23 changes: 17 additions & 6 deletions src/main/scala/pipedsl/common/Locks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import pipedsl.common.DAGSyntax.PStage
import pipedsl.common.Dataflow.DFMap
import pipedsl.common.Errors.InvalidLockState
import pipedsl.common.Syntax.Annotations.TypeAnnotation
import pipedsl.common.Syntax.{CLockEnd, CLockOp, CLockStart, Command, EVar, ICheckLockFree, ICheckLockOwned, IReleaseLock, IReserveLock, Id, LockArg}
import pipedsl.common.Syntax.{CLockEnd, CLockOp, CLockStart, Command, EVar, Expr, ICheckLockFree, ICheckLockOwned, ICondCommand, IReleaseLock, IReserveLock, Id, LockArg}
import pipedsl.common.Utilities.updateSetMap

import scala.util.parsing.input.Position
Expand Down Expand Up @@ -113,8 +113,21 @@ object Locks {
* @param stg The stage to modify
*/
def eliminateLockRegions(stg: PStage): Unit = {
//get all ids that we start or stop regions for in this stage
val (startedRegions, endedRegions) = stg.getCmds.foldLeft(
val tmpCmds: Iterable[Command] = eliminateUnconditionalLockRegions(stg.getCmds)
val newCmds = tmpCmds.foldLeft(List[Command]())((l, c) => c match {
case ICondCommand(cond, cs) =>
val ncs = eliminateUnconditionalLockRegions(cs)
l :+ ICondCommand(cond, ncs.toList).setPos(c.pos)
case _ => l :+ c
})
stg.setCmds(newCmds)
}

//This does not recurse into ICondCommands, thus the top level function
//recurses exactly 1 level (we assume there are not nested ICondCommands)
private def eliminateUnconditionalLockRegions(cmds: Iterable[Command]): Iterable[Command] = {
//get all ids that we start or stop regions for in this set of commands
val (startedRegions, endedRegions) = cmds.foldLeft(
(Set[Id](), Set[Id]()))((s:(Set[Id], Set[Id]), c) => c match {
case CLockStart(mod) => (s._1 + mod, s._2)
case CLockEnd(mod) => (s._1, s._2 + mod)
Expand All @@ -123,14 +136,12 @@ object Locks {
//anytime we start and end in the same stage, we don't need those
val unnecessaryReservations = startedRegions.intersect(endedRegions)
//returns only necessary reservation cmds and all other cmds
val newCmds = stg.getCmds.filter {
cmds.filter {
case CLockStart(mod) if unnecessaryReservations.contains(mod) => false
case CLockEnd(mod) if unnecessaryReservations.contains(mod) => false
case _ => true
}
stg.setCmds(newCmds)
}

/**
* Define common helper methods implicit classes.
*/
Expand Down
55 changes: 30 additions & 25 deletions src/main/scala/pipedsl/passes/CollapseStagesPass.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pipedsl.passes

import pipedsl.common.DAGSyntax._
import pipedsl.common.Locks.eliminateLockRegions
import pipedsl.common.Syntax._
import pipedsl.common.Utilities.{andExpr, getReachableStages, getUsedVars, isReceivingCmd, updateListMap}
import pipedsl.passes.Passes.StagePass
Expand Down Expand Up @@ -32,7 +33,7 @@ object CollapseStagesPass extends StagePass[List[PStage]] {
val isBranchesComb = s.condStages.forall(stg => stg.head.outEdges.exists(e => e.to == s.joinStage))
val isDefaultComb = s.defaultStages.head.outEdges.exists(e => e.to == s.joinStage)
//Merge in the first true and false stages since that delay is artificial
mergeStages(s, s.condStages.map(stg => stg.head) :+ s.defaultStages.head, false)
mergeStages(s, s.condStages.map(stg => stg.head) :+ s.defaultStages.head, isForward = false)
//Update IF stage metadata
s.condStages = s.condStages.map(stg => stg.tail)
s.defaultStages= s.defaultStages.tail
Expand All @@ -44,21 +45,21 @@ object CollapseStagesPass extends StagePass[List[PStage]] {
val newOutEdge = PipelineEdge(None, None, s, s.joinStage, joinStgInputs)
s.removeEdgesTo(s.joinStage)
s.addEdge(newOutEdge)
mergeStages(s, List(s.joinStage), false)
mergeStages(s, List(s.joinStage), isForward = false)
} else {
//merge join stage into its successor, as long as its successor isn't another If stage
//and that stage ALSO only has one predecessor
if (s.joinStage.succs.size == 1 && s.joinStage.succs.exists(p => p match {
case _: IfStage => false
case nstg => nstg.preds.size == 1
})) {
mergeStages(s.joinStage, List(s.joinStage.succs.head), false);
mergeStages(s.joinStage, List(s.joinStage.succs.head), isForward = false);
}
}
//there must only be one by construction
val priorstg = s.inEdges.head.from
//merge this into the prior stage since that delay was added unnecessarily
mergeStages(priorstg, List(s), false)
mergeStages(priorstg, List(s), isForward = false)
case _ =>
}

Expand All @@ -82,19 +83,30 @@ object CollapseStagesPass extends StagePass[List[PStage]] {
}

//Ensures that ICondCommands don't contain any ICondCommands
//Try to minimize the number of ICondCommands we generate in order to avoid confusingly generated code
private def flattenCondStmts(cond: Expr, stmts: Iterable[Command]): List[Command] = {
var condMap: Map[Expr, List[Command]] = ListMap()
stmts.foreach {
case ICondCommand(cex, cs) =>
val mapcondition = andExpr(Some(cond), Some(cex)).get
condMap = updateListMap(condMap, mapcondition, cs)
case c =>
val mapcondition = cond
condMap = updateListMap(condMap, mapcondition, c)
}
condMap.keys.foldLeft(List[Command]())((l, k) => {
l :+ ICondCommand(k, condMap(k))
var resultList: List[Command] = List()
val lastCmd = stmts.foldLeft(ICondCommand(cond, List()).setPos(cond.pos))((ic, s) => {
s match {
case ICondCommand(cex, cs) if ic.cond != cex =>
val condition = andExpr(Some(cond), Some(cex)).get
condition.setPos(cond.pos)
resultList = resultList :+ ic //push last command to output list
ICondCommand(condition, cs).setPos(s.pos) //return new conditional command with composite condition
case ICondCommand(_, cs) => //condition matches, just add our subcommands to it
ICondCommand(ic.cond, ic.cs ++ cs).setPos(ic.pos)
case c if ic.cond != cond =>
resultList = resultList :+ ic //push last command to output list
ICondCommand(cond, List(c)).setPos(c.pos) //start new command with original condition
case c =>
ICondCommand(ic.cond, ic.cs :+ c).setPos(ic.pos) //just add to command list
}
})
if (lastCmd.cs.nonEmpty) {
resultList :+ lastCmd //add in the remaining command if its subcommands not empty
} else {
resultList
}
}

//split the given commands into a pair (receiving, normal) commands
Expand Down Expand Up @@ -127,9 +139,6 @@ object CollapseStagesPass extends StagePass[List[PStage]] {
*/
private def mergeStages(target: PStage, srcs: Iterable[PStage], isForward: Boolean): Unit = {
var newstmts = List[Command]()
val lockids = srcs.foldLeft(Set[LockArg]())((ids, s) => {
ids ++ getLockIds(s.getCmds)
})
srcs.foreach(src => {
val hasIncorrectEdges = if (isForward) {
(s: PStage) => s.outEdges.exists(e => e.to != target) ||
Expand Down Expand Up @@ -157,11 +166,7 @@ object CollapseStagesPass extends StagePass[List[PStage]] {
//merge in the commands
var receivingStmts = List[Command]()
if (cond.isDefined) {
val needIds = lockids.diff(getLockIds(src.getCmds).toSet)
val noops = needIds.foldLeft(List[Command]())((l, id) => {
l :+ ILockNoOp(id)
})
val flattenedCmds = flattenCondStmts(cond.get, src.getCmds ++ noops)
val flattenedCmds = flattenCondStmts(cond.get, src.getCmds)
val (recvstmts, normalstmts) = splitReceivingStmts(flattenedCmds)
newstmts = newstmts ++ normalstmts
if (isForward) {
Expand Down Expand Up @@ -220,7 +225,7 @@ object CollapseStagesPass extends StagePass[List[PStage]] {
target.addEdge(newedge)
})
})
//merge all subsequent stages at the same time
target.mergeStmts(newstmts)
target.mergeStmts(newstmts, addAfter = !isForward) //merge all subsequent stages at the same time
eliminateLockRegions(target) //after merging then remove any unnecessary lock region statements immediately
}
}

0 comments on commit 5d4bb6c

Please sign in to comment.