diff --git a/benchmarks/wasm/even_odd.rs b/benchmarks/wasm/even_odd.rs new file mode 100644 index 00000000..91ff3190 --- /dev/null +++ b/benchmarks/wasm/even_odd.rs @@ -0,0 +1,18 @@ +#[no_mangle] +#[inline(never)] +fn is_even(n: u32) -> bool { + if n == 0 { true } + else { is_odd(n - 1) } +} + +#[no_mangle] +#[inline(never)] +fn is_odd(n: u32) -> bool { + if n == 0 { false } + else { is_even(n - 1) } +} + +#[no_mangle] +fn real_main() -> bool { + is_even(12) +} diff --git a/benchmarks/wasm/even_odd.wat b/benchmarks/wasm/even_odd.wat new file mode 100644 index 00000000..27e10369 --- /dev/null +++ b/benchmarks/wasm/even_odd.wat @@ -0,0 +1,31 @@ +(module + (type (;0;) (func (param i32) (result i32))) + (type (;1;) (func (result i32))) + (func (;0;) (type 0) (param i32) (result i32) + block ;; label = @1 + local.get 0 + br_if 0 (;@1;) + i32.const 1 + return + end + local.get 0 + i32.const -1 + i32.add + call 1) + (func (;1;) (type 0) (param i32) (result i32) + block ;; label = @1 + local.get 0 + br_if 0 (;@1;) + i32.const 0 + return + end + local.get 0 + i32.const -1 + i32.add + call 0) + (func (;2;) (type 1) (result i32) + i32.const 13 + call 1) + (start 2) + (memory (;0;) 16) +) diff --git a/benchmarks/wasm/return.wat b/benchmarks/wasm/return.wat new file mode 100644 index 00000000..f9609173 --- /dev/null +++ b/benchmarks/wasm/return.wat @@ -0,0 +1,13 @@ +(module + (type (;0;) (func)) + (func (;0;) (type 0) + block ;; label = @1 + return + end + unreachable + ) + (func (;1;) (type 0) + call 0 + ) + (start 1) +) \ No newline at end of file diff --git a/src/main/scala/wasm/MiniWasm.scala b/src/main/scala/wasm/MiniWasm.scala index d87c03a5..98d5fa85 100644 --- a/src/main/scala/wasm/MiniWasm.scala +++ b/src/main/scala/wasm/MiniWasm.scala @@ -168,14 +168,14 @@ case class Frame(module: ModuleInstance, locals: ArrayBuffer[Value]) object Evaluator { import Primtives._ - type RetCont = List[Value] => Unit - type Cont = List[Value] => Unit + type Cont[A] = List[Value] => A - def eval(insts: List[Instr], - stack: List[Value], - frame: Frame, - trail: List[Cont]) - (implicit kont: Cont): Unit = { + def eval[Ans](insts: List[Instr], + stack: List[Value], + frame: Frame, + kont: Cont[Ans], + trail: List[Cont[Ans]], + ret: Int): Ans = { if (insts.isEmpty) return kont(stack) val inst = insts.head @@ -184,23 +184,23 @@ object Evaluator { //println(f"stack size: ${stack.size}") //println(s"eval: $inst") inst match { - case Drop => eval(rest, stack.tail, frame, trail) + case Drop => eval(rest, stack.tail, frame, kont, trail, ret) case Select(_) => val I32V(cond) :: v2 :: v1 :: newStack = stack val value = if (cond == 0) v1 else v2 - eval(rest, value :: newStack, frame, trail) + eval(rest, value :: newStack, frame, kont, trail, ret) case LocalGet(i) => - eval(rest, frame.locals(i) :: stack, frame, trail) + eval(rest, frame.locals(i) :: stack, frame, kont, trail, ret) case LocalSet(i) => val value :: newStack = stack frame.locals(i) = value - eval(rest, newStack, frame, trail) + eval(rest, newStack, frame, kont, trail, ret) case LocalTee(i) => val value :: newStack = stack frame.locals(i) = value - eval(rest, stack, frame, trail) + eval(rest, stack, frame, kont, trail, ret) case GlobalGet(i) => - eval(rest, frame.module.globals(i).value :: stack, frame, trail) + eval(rest, frame.module.globals(i).value :: stack, frame, kont, trail, ret) case GlobalSet(i) => val value :: newStack = stack frame.module.globals(i).ty match { @@ -209,19 +209,16 @@ object Evaluator { case GlobalType(_, true) => throw new Exception("Invalid type") case _ => throw new Exception("Cannot set immutable global") } - eval(rest, newStack, frame, trail) + eval(rest, newStack, frame, kont, trail, ret) case MemorySize => - eval(rest, - I32V(frame.module.memory.head.size) :: stack, - frame, - trail) + eval(rest, I32V(frame.module.memory.head.size) :: stack, frame, kont, trail, ret) case MemoryGrow => val I32V(delta) :: newStack = stack val mem = frame.module.memory.head val oldSize = mem.size mem.grow(delta) match { - case Some(e) => eval(rest, I32V(-1) :: newStack, frame, trail) - case _ => eval(rest, I32V(oldSize) :: newStack, frame, trail) + case Some(e) => eval(rest, I32V(-1) :: newStack, frame, kont, trail, ret) + case _ => eval(rest, I32V(oldSize) :: newStack, frame, kont, trail, ret) } case MemoryFill => val I32V(value) :: I32V(offset) :: I32V(size) :: newStack = stack @@ -229,7 +226,7 @@ object Evaluator { throw new Exception("Out of bounds memory access") // GW: turn this into a `trap`? else { frame.module.memory.head.fill(offset, size, value.toByte) - eval(rest, newStack, frame, trail) + eval(rest, newStack, frame, kont, trail, ret) } case MemoryCopy => val I32V(n) :: I32V(src) :: I32V(dest) :: newStack = stack @@ -237,78 +234,78 @@ object Evaluator { throw new Exception("Out of bounds memory access") else { frame.module.memory.head.copy(dest, src, n) - eval(rest, newStack, frame, trail) + eval(rest, newStack, frame, kont, trail, ret) } - case Const(n) => eval(rest, n :: stack, frame, trail) + case Const(n) => eval(rest, n :: stack, frame, kont, trail, ret) case Binary(op) => val v2 :: v1 :: newStack = stack - eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, trail) + eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail, ret) case Unary(op) => val v :: newStack = stack - eval(rest, evalUnaryOp(op, v) :: newStack, frame, trail) + eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail, ret) case Compare(op) => val v2 :: v1 :: newStack = stack - eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, trail) + eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail, ret) case Test(op) => val v :: newStack = stack - eval(rest, evalTestOp(op, v) :: newStack, frame, trail) + eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail, ret) case Store(StoreOp(align, offset, ty, None)) => val I32V(v) :: I32V(addr) :: newStack = stack frame.module.memory(0).storeInt(addr + offset, v) - eval(rest, newStack, frame, trail) + eval(rest, newStack, frame, kont, trail, ret) case Load(LoadOp(align, offset, ty, None, None)) => val I32V(addr) :: newStack = stack val value = frame.module.memory(0).loadInt(addr + offset) - eval(rest, I32V(value) :: newStack, frame, trail) + eval(rest, I32V(value) :: newStack, frame, kont, trail, ret) case Nop => - eval(rest, stack, frame, trail) + eval(rest, stack, frame, kont, trail, ret) case Unreachable => throw new RuntimeException("Unreachable") case Block(ty, inner) => - val k: Cont = (retStack) => - eval(rest, retStack.take(ty.toList.size) ++ stack, frame, trail) + val k: Cont[Ans] = (retStack) => + eval(rest, retStack.take(ty.toList.size) ++ stack, frame, kont, trail, ret) // TODO: block can take inputs too - eval(inner, List(), frame, k :: trail)(k) + eval(inner, List(), frame, k, k :: trail, ret+1) case Loop(ty, inner) => // We construct two continuations, one for the break (to the begining of the loop), // and one for fall-through to the next instruction following the syntactic structure // of the program. - val restK: Cont = (retStack) => eval(rest, retStack.take(ty.toList.size) ++ stack, frame, trail) - def loop(stack: List[Value]): Unit = { - val k: Cont = (retStack) => loop(retStack.take(ty.toList.size)) - eval(inner, stack, frame, k :: trail)(restK) + val restK: Cont[Ans] = (retStack) => eval(rest, retStack.take(ty.toList.size) ++ stack, frame, kont, trail, ret) + def loop(stack: List[Value]): Ans = { + val k: Cont[Ans] = (retStack) => loop(retStack.take(ty.toList.size)) + eval(inner, stack, frame, restK, k :: trail, ret+1) } loop(List()) case If(ty, thn, els) => val I32V(cond) :: newStack = stack val inner = if (cond != 0) thn else els - val k: Cont = (retStack) => - eval(rest, retStack.take(ty.toList.size) ++ newStack, frame, trail) - eval(inner, List(), frame, k :: trail)(k) + val k: Cont[Ans] = (retStack) => + eval(rest, retStack.take(ty.toList.size) ++ newStack, frame, kont, trail, ret) + eval(inner, List(), frame, k, k :: trail, ret+1) case Br(label) => trail(label)(stack) case BrIf(label) => val I32V(cond) :: newStack = stack if (cond != 0) trail(label)(newStack) - else eval(rest, newStack, frame, trail) - case Return => kont(stack) + else eval(rest, newStack, frame, kont, trail, ret) + case Return => trail(ret)(stack) case Call(f) if frame.module.funcs(f).isInstanceOf[FuncDef] => val FuncDef(_, FuncBodyDef(ty, _, locals, body)) = frame.module.funcs(f) val args = stack.take(ty.inps.size).reverse val newStack = stack.drop(ty.inps.size) val frameLocals = args ++ locals.map(_ => I32V(0)) // GW: always I32? or depending on their types? val newFrame = Frame(frame.module, ArrayBuffer(frameLocals: _*)) - val newK: RetCont = (retStack) => - eval(rest, retStack.take(ty.out.size) ++ newStack, frame, trail) + val newK: Cont[Ans] = (retStack) => + eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail, ret) // We push newK on the trail since function creates a new block to escape // (more or less like `return`) - eval(body, List(), newFrame, newK :: trail)(newK) + eval(body, List(), newFrame, newK, newK :: trail, ret+1) case Call(f) if frame.module.funcs(f).isInstanceOf[Import] => frame.module.funcs(f) match { case Import("console", "log", _) => //println(s"[DEBUG] current stack: $stack") val I32V(v) :: newStack = stack println(v) - eval(rest, newStack, frame, trail) + eval(rest, newStack, frame, kont, trail, ret) case f => throw new Exception(s"Unknown import $f") } case _ => @@ -316,4 +313,58 @@ object Evaluator { throw new Exception(s"instruction $inst not implemented") } } + + // If `main` is given, then we use that function as the entry point of the program; + // otherwise, we look up the top-level `start` instruction to locate the entry point. + def evalTop[Ans](module: Module, halt: Cont[Ans], main: Option[String] = None): Ans = { + val instrs = main match { + case Some(_) => module.definitions.flatMap({ + case FuncDef(`main`, FuncBodyDef(_, _, _, body)) => + println(s"Entering function $main") + body + case _ => List() + }) + case None => module.definitions.flatMap({ + case Start(id) => module.funcEnv(id) match { + case FuncDef(_, FuncBodyDef(_, _, _, body)) => + println(s"Entering unnamed function $id") + body + case _ => throw new Exception("Start function has no concrete definition") + } + case _ => List() + }) + } + + val types = List() + val funcs = module.definitions + .collect({ + case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef + }) + .toList + + val globals = module.definitions + .collect({ + case Global(_, GlobalValue(ty, e)) => + (e.head) match { + case Const(c) => RTGlobal(ty, c) + // Q: What is the default behavior if case in non exhaustive + case _ => ??? + } + }) + .toList + + // TODO: correct the behavior for memory + val memory = module.definitions + .collect({ + case Memory(id, MemoryType(min, max_opt)) => + RTMemory(min, max_opt) + }) + .toList + + val moduleInst = ModuleInstance(types, module.funcEnv, memory, globals) + + Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), halt, List(halt), 0) + } + + def evalTop(m: Module): Unit = evalTop(m, stack => ()) } diff --git a/src/main/scala/wasm/NewStagedEvalCPS.scala b/src/main/scala/wasm/StagedEvalCPS.scala similarity index 100% rename from src/main/scala/wasm/NewStagedEvalCPS.scala rename to src/main/scala/wasm/StagedEvalCPS.scala diff --git a/src/test/scala/genwasym/TestEval.scala b/src/test/scala/genwasym/TestEval.scala index 368c79bc..81fb7df5 100644 --- a/src/test/scala/genwasym/TestEval.scala +++ b/src/test/scala/genwasym/TestEval.scala @@ -19,70 +19,14 @@ class TestEval extends FunSuite { def testFile(filename: String, main: Option[String] = None, expected: Option[Int] = None) = { val module = Parser.parseFile(filename) //println(module) - - val instrs = main match { - case Some(_) => module.definitions.flatMap({ - case FuncDef(`main`, FuncBodyDef(_, _, _, body)) => - println(s"Entering function $main") - body - case _ => List() - }) - case None => module.definitions.flatMap({ - case Start(id) => module.funcEnv(id) match { - case FuncDef(_, FuncBodyDef(_, _, _, body)) => - println(s"Entering unnamed function $id") - body - case _ => throw new Exception("Start function has no concrete definition") - } - case _ => List() - }) - } - - val types = List() - val funcs = module.definitions - .collect({ - case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef - }) - .toList - - val globals = module.definitions - .collect({ - case Global(_, GlobalValue(ty, e)) => - (e.head) match { - case Const(c) => RTGlobal(ty, c) - // Q: What is the default behavior if case in non exhaustive - case _ => ??? - } - }) - .toList - - // TODO: correct the behavior for memory - val memory = module.definitions - .collect({ - case Memory(id, MemoryType(min, max_opt)) => - RTMemory(min, max_opt) - }) - .toList - - val moduleInst = ModuleInstance(types, module.funcEnv, memory, globals) - - val trailK: Evaluator.Cont = newStack => { - println(s"trail: $newStack") + val haltK: Evaluator.Cont[Unit] = stack => { + println(s"halt cont: $stack") expected match { - case Some(e) => assert(newStack(0) == I32V(e)) + case Some(e) => assert(stack(0) == I32V(e)) case None => () } } - - implicit val restK: Evaluator.Cont = newStack => { - println(s"restK trail: $newStack") - expected match { - case Some(e) => assert(newStack(0) == I32V(e)) - case None => () - } - - } - Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), List(trailK)) + Evaluator.evalTop(module, haltK, main) } // TODO: the power test can be used to test the stack @@ -93,9 +37,11 @@ class TestEval extends FunSuite { test("start") { testFile("./benchmarks/wasm/start.wat") } test("fact") { testFile("./benchmarks/wasm/fact.wat", None, Some(120)) } test("loop") { testFile("./benchmarks/wasm/loop.wat", None, Some(10)) } + test("even-odd") { testFile("./benchmarks/wasm/even_odd.wat", None, Some(1)) } + test("return") { testFile("./benchmarks/wasm/return.wat", None, None) } // Parser works, but the memory issue remains - //test("btree") { testFile("./benchmarks/wasm/btree/2o1u-no-label-for-real.wat") } + // test("btree") { testFile("./benchmarks/wasm/btree/2o1u-no-label-for-real.wat") } // TODO: add more wasm spec tests? // test("memory") { test_btree("./benchmarks/wasm/spectest/test.wat", "$real_main") }