Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify cps #26

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
48 changes: 29 additions & 19 deletions headers/gensym/state_tsnt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,15 @@ class Frame {
public:
using Env = std::map<Id, PtrVal>;
using Cont = std::function<std::monostate(SS&, PtrVal)>;
Cont cont;
size_t prev_stack_size;
Cont k;
private:
Env env;
public:
Frame(Cont ct): cont(ct), env() {}
Frame() : env() {}
Frame(Env env) : env(std::move(env)) {}
Frame() : env(std::map<Id, PtrVal>{}) {}
Frame(size_t ss, Cont k): prev_stack_size(ss), k(k), env() {}

size_t size() { return env.size(); }
PtrVal lookup_id(Id id) const { return env.at(id); }
Frame&& assign(Id id, PtrVal v) {
Expand Down Expand Up @@ -204,12 +206,16 @@ class Stack {
return std::move(*this);
}
PtrVal error_loc() { return errno_location; }
typename Frame::Cont pop(size_t keep) {
auto &it = env.at(env.size() - 1);
auto ret = it.cont;
Stack&& pop(size_t keep) {
mem.take(keep);
env.take(env.size() - 1);
return ret;
return std::move(*this);
}
std::monostate pop(SS& s, PtrVal v) {
auto f = env.at(env.size() - 1);
mem.take(f.prev_stack_size);
env.take(env.size() - 1);
return f.k(s, v);
}
Stack&& push() {
return push(Frame());
Expand All @@ -218,8 +224,8 @@ class Stack {
env.push_back(std::move(f));
return std::move(*this);
}
Stack&& push(std::function<std::monostate(SS&, PtrVal)> cont) {
return push(Frame(cont));
Stack&& push(size_t ss, std::function<std::monostate(SS&, PtrVal)> k) {
return push(Frame(ss, k));
}

Stack&& assign(Id id, PtrVal val) {
Expand Down Expand Up @@ -382,10 +388,10 @@ class SS {
return read_res;
}
PtrVal at_simpl(PtrVal addr) {
auto loc = addr->to_LocV();
ASSERT(loc != nullptr, "Lookup an non-address value");
if (loc->k == LocV::kStack) return stack.at(loc->l);
return heap.at(loc->l);
auto loc = addr->to_LocV();
ASSERT(loc != nullptr, "Lookup an non-address value");
if (loc->k == LocV::kStack) return stack.at(loc->l);
return heap.at(loc->l);
}
PtrVal at(PtrVal addr, size_t size) {
if (auto loc = addr->to_LocV()) {
Expand Down Expand Up @@ -461,12 +467,16 @@ class SS {
stack.push();
return std::move(*this);
}
SS&& push(std::function<std::monostate(SS&, PtrVal)> cont) {
stack.push(cont);
SS&& push(size_t ss, std::function<std::monostate(SS&, PtrVal)> cont) {
stack.push(ss, cont);
return std::move(*this);
}
SS&& pop(size_t keep) {
stack.pop(keep);
return std::move(*this);
}
typename Frame::Cont pop(size_t keep) {
return stack.pop(keep);
std::monostate pop(PtrVal v) {
return stack.pop(*this, v);
}
SS&& assign(Id id, PtrVal val) {
stack.assign(id, val);
Expand Down Expand Up @@ -586,8 +596,8 @@ inline std::monostate cps_apply(PtrVal v, SS ss, List<PtrVal> args, std::functio
ABORT("cps_apply: not applicable");
}

inline std::monostate cont_apply(std::function<std::monostate(SS&, PtrVal)> cont, SS& ss, PtrVal val) {
return cont(ss, val);
inline std::monostate pop_cont_apply(SS& ss, PtrVal val) {
return ss.pop(val);
}

#endif
33 changes: 30 additions & 3 deletions src/main/scala/gensym/Codegen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,34 @@ trait GenericGSCodeGen extends CppSAICodeGenBase {
case _ => super.traverse(n)
}

// Note: this enhances the EmitHelper in lms-clean, to be merged
implicit class EmitHelperExt(val sc: StringContext) {
def realEmit(x: Any): Unit = x match {
case (x: Def) => shallow(x)
case (n: Node) => shallow(n)
case (s: String) => emit(s)
case (i: Int) => emit(i.toString)
case (args: List[Def]) =>
if (args.nonEmpty) {
realEmit(args(0))
args.tail.foreach { arg => emit(", "); realEmit(arg) }
}
}
def es(args: Any*): Unit = {
val strings = sc.parts.iterator
val expressions = args.iterator
emit(strings.next)
while(strings.hasNext) {
realEmit(expressions.next)
emit(strings.next)
}
}
def esln(args: Any*): Unit = {
es(args:_*)
emitln()
}
}

override def shallow(n: Node): Unit = n match {
case n @ Node(s, "P", List(x), _) => es"std::cout << $x << std::endl"
case Node(s,"kStack", _, _) => emit("LocV::kStack")
Expand Down Expand Up @@ -138,9 +166,8 @@ trait GenericGSCodeGen extends CppSAICodeGenBase {
case Node(s, "ss-update", List(ss, k, v, sz), _) => es"$ss.update($k, $v, $sz)"
case Node(s, "ss-update", List(ss, k, v), _) => es"$ss.update($k, $v)"
case Node(s, "ss-update-seq", List(ss, k, v), _) => es"$ss.update_seq($k, $v)"
case Node(s, "ss-push", List(ss), _) => es"$ss.push()"
case Node(s, "ss-push", List(ss, k), _) => es"$ss.push($k)"
case Node(s, "ss-pop", List(ss, n), _) => es"$ss.pop($n)"
case Node(s, "ss-push", ss::args, _) => es"$ss.push($args)"
case Node(s, "ss-pop", ss::args, _) => es"$ss.pop($args)"
case Node(s, "ss-addpc", List(ss, e), _) => es"$ss.add_PC($e)"
case Node(s, "add-pc", List(pc, e), _) => es"$pc.add($e)"
case Node(s, "ss-addpcset", List(ss, es), _) => es"$ss.add_PC_set($es)"
Expand Down
17 changes: 10 additions & 7 deletions src/main/scala/gensym/EngineBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,14 @@ trait EngineBase extends SAIOps { self: BasicDefs with ValueDefs =>
def info(msg: String) = unchecked("INFO(\"" + msg + "\")")

val mainRename = "gs_main"
def getRealFunName(funName: String): String = {
val newFname = if (funName != "@main") "__GS_USER_"+funName.tail else mainRename
newFname.replaceAllLiterally(".","_")
}
val gsPrefix = "__GS_USER_"

def getRealFunName(funName: String, prefix: String = gsPrefix): String =
if (funName != "@main") gsPrefix + funName.tail.replaceAllLiterally(".", "_")
else mainRename

def strippedFunName(funName: String): String = getRealFunName(funName, "")

def getRealBlockFunName(ctx: Ctx): String = blockNameMap(Counter.block.get(ctx.toString))

def compile(funName: String, b: BB): Unit = {
Expand Down Expand Up @@ -100,7 +104,7 @@ trait EngineBase extends SAIOps { self: BasicDefs with ValueDefs =>
}
val fn = repExternFun(f, ret, argTypes)
val node = Unwrap(fn).asInstanceOf[Backend.Sym]
funNameMap(node) = "__GS_NATIVE_EXTERNAL_"+mangledName.tail
funNameMap(node) = "__GS_NATIVE_"+mangledName.tail
FunFuns(mangledName) = fn
}

Expand Down Expand Up @@ -148,7 +152,6 @@ trait EngineBase extends SAIOps { self: BasicDefs with ValueDefs =>
// "When indexing into a (optionally packed) structure, only i32 integer
// constants are allowed"
// TODO: the align argument for getTySize
// TODO: test this
val indexCst: List[Long] = index.map { case IntV(n, _) => n.toLong }
IntV(calculateOffsetStatic(ty, indexCst), DEFAULT_INDEX_BW)
case PackedStruct(types) =>
Expand All @@ -159,7 +162,7 @@ trait EngineBase extends SAIOps { self: BasicDefs with ValueDefs =>
}

// Note: we can also assign symbolic values here
def uninitValue: Rep[Value] = IntV(0, 8) //NullPtr()
def uninitValue: Rep[Value] = IntV(0, 8)

def evalHeapAtomicConst(v: Constant, ty: LLVMType): Rep[Value] = v match {
case BoolConst(b) => IntV(if (b) 1 else 0, 1)
Expand Down
89 changes: 48 additions & 41 deletions src/main/scala/gensym/GenericDefs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ object Counter {
val block = Counter()
val variable = Counter()
val function = Counter()
val cont = Counter()
val branchStat: HashMap[Int, Int] = HashMap[Int, Int]()
def setBranchNum(ctx: Ctx, n: Int): Unit = {
val blockId = Counter.block.get(ctx.toString)
Expand Down Expand Up @@ -339,16 +340,16 @@ trait ValueDefs { self: SAIOps with BasicDefs with Opaques =>
}

object IntOp2 {
def applyNoOpt(op: String, o1: Rep[Value], o2: Rep[Value]): Rep[Value] =
def primOp2(op: String, o1: Rep[Value], o2: Rep[Value]): Rep[Value] =
"int_op_2".reflectWith[Value](op, o1, o2)
def apply(op: String, o1: Rep[Value], o2: Rep[Value]): Rep[Value] =
if (!Config.opt) applyNoOpt(op, o1, o2)
else op match {
case "neq" => neq(o1, o2)
case "eq" => eq(o1, o2)
case "add" => add(o1, o2)
case _ => applyNoOpt(op, o1, o2)
}

def apply(op: String, o1: Rep[Value], o2: Rep[Value]): Rep[Value] = op match {
case "neq" => neq(o1, o2)
case "eq" => eq(o1, o2)
case "add" => add(o1, o2)
case "mul" => mul(o1, o2)
case _ => primOp2(op, o1, o2)
}

def unapply(v: Rep[Value]): Option[(String, Rep[Value], Rep[Value])] = Unwrap(v) match {
case gNode("int_op_2", bConst(x: String)::(o1: bSym)::(o2: bSym)::_) =>
Expand All @@ -357,41 +358,48 @@ trait ValueDefs { self: SAIOps with BasicDefs with Opaques =>
}

def add(v1: Rep[Value], v2: Rep[Value]): Rep[Value] = (v1, v2) match {
case (IntV(n1, bw1), IntV(n2, bw2)) if (bw1 == bw2) => IntV(n1 + n2, bw1)
case _ => applyNoOpt("add", v1, v2)
case (IntV(n1, bw1), IntV(n2, bw2)) if (bw1 == bw2) && Config.opt => IntV(n1 + n2, bw1)
case _ => primOp2("add", v1, v2)
}

def mul(v1: Rep[Value], v2: Rep[Value]): Rep[Value] = (v1, v2) match {
case (IntV(n1, bw1), IntV(n2, bw2)) if (bw1 == bw2) => IntV(n1 * n2, bw1)
case _ => applyNoOpt("mul", v1, v2)
case (IntV(n1, bw1), IntV(n2, bw2)) if (bw1 == bw2) && Config.opt => IntV(n1 * n2, bw1)
case _ => primOp2("mul", v1, v2)
}

def neq(o1: Rep[Value], o2: Rep[Value]): Rep[Value] = (Unwrap(o1), Unwrap(o2)) match {
case (gNode("bv_sext", (e1: bExp)::bConst(bw1: Int)::_),
gNode("bv_sext", (e2: bExp)::bConst(bw2: Int)::_)) if bw1 == bw2 =>
gNode("bv_sext", (e2: bExp)::bConst(bw2: Int)::_)) if bw1 == bw2 && Config.opt =>
val v1 = Wrap[Value](e1)
val v2 = Wrap[Value](e2)
if (v1.bw == v2.bw) applyNoOpt("neq", v1, v2)
else applyNoOpt("neq", o1, o2)
case _ => applyNoOpt("neq", o1, o2)
if (v1.bw == v2.bw) neq(v1, v2)
else primOp2("neq", o1, o2)
case _ => primOp2("neq", o1, o2)
}
def eq(o1: Rep[Value], o2: Rep[Value]): Rep[Value] = (Unwrap(o1), Unwrap(o2)) match {
case (gNode("bv_sext", (e1: bExp)::bConst(bw1: Int)::_),
gNode("bv_sext", (e2: bExp)::bConst(bw2: Int)::_)) if bw1 == bw2 =>
gNode("bv_sext", (e2: bExp)::bConst(bw2: Int)::_)) if bw1 == bw2 && Config.opt =>
val v1 = Wrap[Value](e1)
val v2 = Wrap[Value](e2)
if (v1.bw == v2.bw) applyNoOpt("eq", v1, v2)
else applyNoOpt("eq", o1, o2)
case _ => applyNoOpt("eq", o1, o2)
if (v1.bw == v2.bw) eq(v1, v2)
else primOp2("eq", o1, o2)
case _ => primOp2("eq", o1, o2)
}
}

object FloatOp2 {
def apply(op: String, o1: Rep[Value], o2: Rep[Value]) = "float_op_2".reflectWith[Value](op, o1, o2)
}

object ContOpt {
def dummyCont[W[_]](implicit m: Manifest[W[SS]]): ContOpt[W] = ContOpt[W]((s, v) => ())
def fromRepCont[W[_]](k: Rep[PCont[W]])(implicit m: Manifest[W[SS]]) = ContOpt[W]((s, v) => k(s, v))
}

case class ContOpt[W[_]](k: (Rep[W[SS]], Rep[Value]) => Rep[Unit])(implicit m: Manifest[W[SS]]) {
lazy val repK = fun(k(_, _))
lazy val repK: Rep[PCont[W]] =
if (Config.onStackCont && !usingPureEngine) unchecked[PCont[W]]("pop_cont_apply")
else fun(k(_, _))
def apply(s: Rep[W[SS]], v: Rep[Value]): Rep[Unit] = k(s, v)
}

Expand Down Expand Up @@ -433,25 +441,23 @@ trait ValueDefs { self: SAIOps with BasicDefs with Opaques =>
case _ => "direct_apply".reflectWith[List[(SS, Value)]](v, s, args)
}

// The CPS version
// W[_] is parameterized over pass-by-value (Id) or pass-by-ref (Ref) of SS
def apply[W[_]](s: Rep[W[SS]], args: Rep[List[Value]], k: Rep[PCont[W]])(implicit m: Manifest[W[SS]]): Rep[Unit] =
v match {
case ExternalFun("noop", ty) if Config.opt => k(s, defaultRetVal(ty))
case ExternalFun(f, ty) => f.reflectWith[Unit](s, args, k)
case CPSFunV(f) => f(s, args, k) // direct call
case _ => "cps_apply".reflectWith[Unit](v, s, args, k) // indirect call
}

// This `apply` works for the CPS version that takes an optimizable continuation `ContOpt`.
// Using `ContOpt`, we may choose to call the continuation at staging-time, or to generate
// the continuation function into the second stage.
// W[_] is parameterized over pass-by-value (Id[_]) or pass-by-ref (Ref[_]) of SS.
def apply[W[_]](s: Rep[W[SS]], args: Rep[List[Value]], k: ContOpt[W])(implicit m: Manifest[W[SS]]): Rep[Unit] =
v match {
case ExternalFun("noop", ty) if Config.opt => k(s, defaultRetVal(ty))
case ExternalFun(f, ty) if ExternalFun.isDeterministic(f) && !usingPureEngine =>
case ExternalFun("noop", ty) if Config.opt =>
// Avoids generating continuations for the `noop` function.
k(s, defaultRetVal(ty))
case ExternalFun(f, ty) if Config.opt && ExternalFun.isDeterministic(f) && !usingPureEngine =>
// Avoids generating continuations for the _imperative_ CPS engine if the function is deterministic.
// Be careful: since the state is not passed/returned, with the imperative backend it means
// the state must be passed by reference to f_det! Currently not all deterministic functions
// defined in backend works in this way (see external_shared.hpp).
k(s, (f+"_det").reflectCtrlWith[Value](s, args))
case ExternalFun(f, ty) if ExternalFun.isDeterministic(f) && usingPureEngine =>
case ExternalFun(f, ty) if Config.opt && ExternalFun.isDeterministic(f) && usingPureEngine =>
// Avoids generating continuations for the _pure_ engine if the function is deterministic.
val sv = (f+"_det").reflectCtrlWith[(W[SS], Value)](s, args)
k(sv._1, sv._2)
case ExternalFun(f, ty) => f.reflectWith[Unit](s, args, k.repK)
Expand All @@ -464,22 +470,22 @@ trait ValueDefs { self: SAIOps with BasicDefs with Opaques =>
val foldableOp = StaticSet[String]("make_SymV", "make_IntV", "bv_sext", "bv_zext")

def sExt(bw: Int): Rep[Value] = Unwrap(v) match {
case gNode(s, (v1: bExp)::bConst(bw1: Int)::_) if (foldableOp(s) && (bw1 == bw)) => v
case gNode("make_IntV", (v1: bExp)::bConst(bw1: Int)::_) if bw > bw1 =>
case gNode(s, (v1: bExp)::bConst(bw1: Int)::_) if foldableOp(s) && (bw1 == bw) && Config.opt => v
case gNode("make_IntV", (v1: bExp)::bConst(bw1: Int)::_) if bw > bw1 && Config.opt =>
// sExt(IntV(n, bw1), bw) ⇒ IntV(n, bw)
IntV(Wrap[Long](v1), bw)
case gNode("bv_sext", (v1: bExp)::bConst(bw1: Int)::_) if bw > bw1=>
case gNode("bv_sext", (v1: bExp)::bConst(bw1: Int)::_) if bw > bw1 && Config.opt =>
// sExt(sExt(n, bw1), bw) ⇒ sExt(n, bw)
Wrap[Value](v1).sExt(bw)
case _ => "bv_sext".reflectWith[Value](v, bw)
}

def zExt(bw: Int): Rep[Value] = Unwrap(v) match {
case gNode(s, (v1: bExp)::bConst(bw1: Int)::_) if (foldableOp(s) && (bw1 == bw)) => v
case gNode("make_IntV", (v1: bExp)::bConst(bw1: Int)::_) if bw > bw1 =>
case gNode(s, (v1: bExp)::bConst(bw1: Int)::_) if foldableOp(s) && (bw1 == bw) && Config.opt => v
case gNode("make_IntV", (v1: bExp)::bConst(bw1: Int)::_) if bw > bw1 && Config.opt =>
// zExt(IntV(n, bw1), bw) ⇒ IntV(n, bw)
IntV(Wrap[Long](v1), bw)
case gNode("bv_zext", (v1: bExp)::bConst(bw1: Int)::_) if bw > bw1=>
case gNode("bv_zext", (v1: bExp)::bConst(bw1: Int)::_) if bw > bw1 && Config.opt =>
// zExt(sExt(n, bw1), bw) ⇒ zExt(n, bw)
Wrap[Value](v1).zExt(bw)
case _ => "bv_zext".reflectWith[Value](v, bw)
Expand All @@ -502,6 +508,7 @@ trait ValueDefs { self: SAIOps with BasicDefs with Opaques =>
def *(rhs: Rep[Value]): Rep[Value] = IntOp2.mul(v, rhs)
def &(rhs: Rep[Value]): Rep[Value] = IntOp2("and", v, rhs)
def |(rhs: Rep[Value]): Rep[Value] = IntOp2("or", v, rhs)
def ≡(rhs: Rep[Value]): Rep[Value] = IntOp2.eq(v, rhs)
def unary_! : Rep[Value] = IntOp1.neg(v)
def unary_~ : Rep[Value] = IntOp1.bvnot(v)

Expand Down
Loading