diff --git a/src/main/scala/wasm/MiniWasm.scala b/src/main/scala/wasm/MiniWasm.scala index d8035481..5da89770 100644 --- a/src/main/scala/wasm/MiniWasm.scala +++ b/src/main/scala/wasm/MiniWasm.scala @@ -174,9 +174,9 @@ object Evaluator { def eval(insts: List[Instr], stack: List[Value], frame: Frame, + kont: Cont, trail: List[Cont], - kont: Cont) - (implicit retKontIdx: Int): Unit = { + ret: Int): Unit = { if (insts.isEmpty) return kont(stack) val inst = insts.head @@ -185,23 +185,23 @@ object Evaluator { //println(f"stack size: ${stack.size}") //println(s"eval: $inst") inst match { - case Drop => eval(rest, stack.tail, frame, trail, kont) + case Drop => eval(rest, stack.tail, frame, kont, trail, ret) case Select(_) => val I32V(cond) :: v2 :: v1 :: newStack = stack val value = if (cond == 0) v1 else v2 - eval(rest, value :: newStack, frame, trail, kont) + eval(rest, value :: newStack, frame, kont, trail, ret) case LocalGet(i) => - eval(rest, frame.locals(i) :: stack, frame, trail, kont) + eval(rest, frame.locals(i) :: stack, frame, kont, trail, ret) case LocalSet(i) => val value :: newStack = stack frame.locals(i) = value - eval(rest, newStack, frame, trail, kont) + eval(rest, newStack, frame, kont, trail, ret) case LocalTee(i) => val value :: newStack = stack frame.locals(i) = value - eval(rest, stack, frame, trail, kont) + eval(rest, stack, frame, kont, trail, ret) case GlobalGet(i) => - eval(rest, frame.module.globals(i).value :: stack, frame, trail, kont) + eval(rest, frame.module.globals(i).value :: stack, frame, kont, trail, ret) case GlobalSet(i) => val value :: newStack = stack frame.module.globals(i).ty match { @@ -210,16 +210,16 @@ object Evaluator { case GlobalType(_, true) => throw new Exception("Invalid type") case _ => throw new Exception("Cannot set immutable global") } - eval(rest, newStack, frame, trail, kont) + eval(rest, newStack, frame, kont, trail, ret) case MemorySize => - eval(rest, I32V(frame.module.memory.head.size) :: stack, frame, trail, kont) + eval(rest, I32V(frame.module.memory.head.size) :: stack, frame, kont, trail, ret) case MemoryGrow => val I32V(delta) :: newStack = stack val mem = frame.module.memory.head val oldSize = mem.size mem.grow(delta) match { - case Some(e) => eval(rest, I32V(-1) :: newStack, frame, trail, kont) - case _ => eval(rest, I32V(oldSize) :: newStack, frame, trail, kont) + case Some(e) => eval(rest, I32V(-1) :: newStack, frame, kont, trail, ret) + case _ => eval(rest, I32V(oldSize) :: newStack, frame, kont, trail, ret) } case MemoryFill => val I32V(value) :: I32V(offset) :: I32V(size) :: newStack = stack @@ -227,7 +227,7 @@ object Evaluator { throw new Exception("Out of bounds memory access") // GW: turn this into a `trap`? else { frame.module.memory.head.fill(offset, size, value.toByte) - eval(rest, newStack, frame, trail, kont) + eval(rest, newStack, frame, kont, trail, ret) } case MemoryCopy => val I32V(n) :: I32V(src) :: I32V(dest) :: newStack = stack @@ -235,83 +235,78 @@ object Evaluator { throw new Exception("Out of bounds memory access") else { frame.module.memory.head.copy(dest, src, n) - eval(rest, newStack, frame, trail, kont) + eval(rest, newStack, frame, kont, trail, ret) } - case Const(n) => eval(rest, n :: stack, frame, trail, kont) + case Const(n) => eval(rest, n :: stack, frame, kont, trail, ret) case Binary(op) => val v2 :: v1 :: newStack = stack - eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, trail, kont) + eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail, ret) case Unary(op) => val v :: newStack = stack - eval(rest, evalUnaryOp(op, v) :: newStack, frame, trail, kont) + eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail, ret) case Compare(op) => val v2 :: v1 :: newStack = stack - eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, trail, kont) + eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail, ret) case Test(op) => val v :: newStack = stack - eval(rest, evalTestOp(op, v) :: newStack, frame, trail, kont) + eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail, ret) case Store(StoreOp(align, offset, ty, None)) => val I32V(v) :: I32V(addr) :: newStack = stack frame.module.memory(0).storeInt(addr + offset, v) - eval(rest, newStack, frame, trail, kont) + eval(rest, newStack, frame, kont, trail, ret) case Load(LoadOp(align, offset, ty, None, None)) => val I32V(addr) :: newStack = stack val value = frame.module.memory(0).loadInt(addr + offset) - eval(rest, I32V(value) :: newStack, frame, trail, kont) + eval(rest, I32V(value) :: newStack, frame, kont, trail, ret) case Nop => - eval(rest, stack, frame, trail, kont) + eval(rest, stack, frame, kont, trail, ret) case Unreachable => throw new RuntimeException("Unreachable") case Block(ty, inner) => val k: Cont = (retStack) => - eval(rest, retStack.take(ty.toList.size) ++ stack, frame, trail, kont) + eval(rest, retStack.take(ty.toList.size) ++ stack, frame, kont, trail, ret) // TODO: block can take inputs too - eval(inner, List(), frame, k :: trail, k)(retKontIdx + 1) + eval(inner, List(), frame, k, k :: trail, ret+1) case Loop(ty, inner) => // We construct two continuations, one for the break (to the begining of the loop), // and one for fall-through to the next instruction following the syntactic structure // of the program. - val restK: Cont = (retStack) => eval(rest, retStack.take(ty.toList.size) ++ stack, frame, trail, kont) + val restK: Cont = (retStack) => eval(rest, retStack.take(ty.toList.size) ++ stack, frame, kont, trail, ret) def loop(stack: List[Value]): Unit = { val k: Cont = (retStack) => loop(retStack.take(ty.toList.size)) - eval(inner, stack, frame, k :: trail, restK)(retKontIdx + 1) + eval(inner, stack, frame, restK, k :: trail, ret+1) } loop(List()) case If(ty, thn, els) => val I32V(cond) :: newStack = stack val inner = if (cond != 0) thn else els val k: Cont = (retStack) => - eval(rest, retStack.take(ty.toList.size) ++ newStack, frame, trail, kont) - eval(inner, List(), frame, k :: trail, k)(retKontIdx + 1) + eval(rest, retStack.take(ty.toList.size) ++ newStack, frame, kont, trail, ret) + eval(inner, List(), frame, k, k :: trail, ret+1) case Br(label) => trail(label)(stack) case BrIf(label) => val I32V(cond) :: newStack = stack - println(s"brif: $cond") if (cond != 0) trail(label)(newStack) - else { - println(s"br if rest $rest") - eval(rest, newStack, frame, trail, kont) - } - case Return => trail(retKontIdx)(stack) + else eval(rest, newStack, frame, kont, trail, ret) + case Return => trail(ret)(stack) case Call(f) if frame.module.funcs(f).isInstanceOf[FuncDef] => val FuncDef(_, FuncBodyDef(ty, _, locals, body)) = frame.module.funcs(f) - println(s"calling $f") val args = stack.take(ty.inps.size).reverse val newStack = stack.drop(ty.inps.size) val frameLocals = args ++ locals.map(_ => I32V(0)) // GW: always I32? or depending on their types? val newFrame = Frame(frame.module, ArrayBuffer(frameLocals: _*)) val newK: RetCont = (retStack) => - eval(rest, retStack.take(ty.out.size) ++ newStack, frame, trail, kont) + eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail, ret) // We push newK on the trail since function creates a new block to escape // (more or less like `return`) - eval(body, List(), newFrame, newK :: trail, newK)(retKontIdx + 1) + eval(body, List(), newFrame, newK, newK :: trail, ret+1) case Call(f) if frame.module.funcs(f).isInstanceOf[Import] => frame.module.funcs(f) match { case Import("console", "log", _) => //println(s"[DEBUG] current stack: $stack") val I32V(v) :: newStack = stack println(v) - eval(rest, newStack, frame, trail, kont) + eval(rest, newStack, frame, kont, trail, ret) case f => throw new Exception(s"Unknown import $f") } case _ => @@ -319,4 +314,56 @@ object Evaluator { throw new Exception(s"instruction $inst not implemented") } } + + // If `main` is given, then we use that function as the entry point of the program; + // otherwise, we look up the top-level `start` instruction to locate the entry point. + def evalTop(module: Module, halt: Cont, main: Option[String]): Unit = { + val instrs = main match { + case Some(_) => module.definitions.flatMap({ + case FuncDef(`main`, FuncBodyDef(_, _, _, body)) => + println(s"Entering function $main") + body + case _ => List() + }) + case None => module.definitions.flatMap({ + case Start(id) => module.funcEnv(id) match { + case FuncDef(_, FuncBodyDef(_, _, _, body)) => + println(s"Entering unnamed function $id") + body + case _ => throw new Exception("Start function has no concrete definition") + } + case _ => List() + }) + } + + val types = List() + val funcs = module.definitions + .collect({ + case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef + }) + .toList + + val globals = module.definitions + .collect({ + case Global(_, GlobalValue(ty, e)) => + (e.head) match { + case Const(c) => RTGlobal(ty, c) + // Q: What is the default behavior if case in non exhaustive + case _ => ??? + } + }) + .toList + + // TODO: correct the behavior for memory + val memory = module.definitions + .collect({ + case Memory(id, MemoryType(min, max_opt)) => + RTMemory(min, max_opt) + }) + .toList + + val moduleInst = ModuleInstance(types, module.funcEnv, memory, globals) + + Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), halt, List(halt), 0) + } } diff --git a/src/test/scala/genwasym/TestEval.scala b/src/test/scala/genwasym/TestEval.scala index 5cdb5507..05e5cd50 100644 --- a/src/test/scala/genwasym/TestEval.scala +++ b/src/test/scala/genwasym/TestEval.scala @@ -19,70 +19,14 @@ class TestEval extends FunSuite { def testFile(filename: String, main: Option[String] = None, expected: Option[Int] = None) = { val module = Parser.parseFile(filename) //println(module) - - val instrs = main match { - case Some(_) => module.definitions.flatMap({ - case FuncDef(`main`, FuncBodyDef(_, _, _, body)) => - println(s"Entering function $main") - body - case _ => List() - }) - case None => module.definitions.flatMap({ - case Start(id) => module.funcEnv(id) match { - case FuncDef(_, FuncBodyDef(_, _, _, body)) => - println(s"Entering unnamed function $id") - body - case _ => throw new Exception("Start function has no concrete definition") - } - case _ => List() - }) - } - - val types = List() - val funcs = module.definitions - .collect({ - case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef - }) - .toList - - val globals = module.definitions - .collect({ - case Global(_, GlobalValue(ty, e)) => - (e.head) match { - case Const(c) => RTGlobal(ty, c) - // Q: What is the default behavior if case in non exhaustive - case _ => ??? - } - }) - .toList - - // TODO: correct the behavior for memory - val memory = module.definitions - .collect({ - case Memory(id, MemoryType(min, max_opt)) => - RTMemory(min, max_opt) - }) - .toList - - val moduleInst = ModuleInstance(types, module.funcEnv, memory, globals) - - val trailK: Evaluator.Cont = newStack => { - println(s"trail: $newStack") + val haltK: Evaluator.Cont = newStack => { + println(s"halt cont: $newStack") expected match { case Some(e) => assert(newStack(0) == I32V(e)) case None => () } } - - implicit val restK: Evaluator.Cont = newStack => { - println(s"restK trail: $newStack") - expected match { - case Some(e) => assert(newStack(0) == I32V(e)) - case None => () - } - - } - Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), List(trailK), restK)(0) + Evaluator.evalTop(module, haltK, main) } // TODO: the power test can be used to test the stack