From 657fec5a69e1a98f8fd87cda4bab1bd44279d88a Mon Sep 17 00:00:00 2001 From: Guannan Wei Date: Mon, 21 Oct 2024 00:53:03 +0200 Subject: [PATCH] no need to take the 'ret' argument (#59) --- src/main/scala/wasm/MiniWasm.scala | 67 +++++++++++++++--------------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/src/main/scala/wasm/MiniWasm.scala b/src/main/scala/wasm/MiniWasm.scala index e691ec3e..62c6d2cc 100644 --- a/src/main/scala/wasm/MiniWasm.scala +++ b/src/main/scala/wasm/MiniWasm.scala @@ -197,8 +197,7 @@ object Evaluator { stack: List[Value], frame: Frame, kont: Cont[Ans], - trail: List[Cont[Ans]], - ret: Int): Ans = { + trail: List[Cont[Ans]]): Ans = { if (insts.isEmpty) return kont(stack) val inst = insts.head @@ -207,23 +206,23 @@ object Evaluator { // println(s"inst: ${inst} \t | ${frame.locals} | ${stack.reverse}" ) inst match { - case Drop => eval(rest, stack.tail, frame, kont, trail, ret) + case Drop => eval(rest, stack.tail, frame, kont, trail) case Select(_) => val I32V(cond) :: v2 :: v1 :: newStack = stack val value = if (cond == 0) v1 else v2 - eval(rest, value :: newStack, frame, kont, trail, ret) + eval(rest, value :: newStack, frame, kont, trail) case LocalGet(i) => - eval(rest, frame.locals(i) :: stack, frame, kont, trail, ret) + eval(rest, frame.locals(i) :: stack, frame, kont, trail) case LocalSet(i) => val value :: newStack = stack frame.locals(i) = value - eval(rest, newStack, frame, kont, trail, ret) + eval(rest, newStack, frame, kont, trail) case LocalTee(i) => val value :: newStack = stack frame.locals(i) = value - eval(rest, stack, frame, kont, trail, ret) + eval(rest, stack, frame, kont, trail) case GlobalGet(i) => - eval(rest, frame.module.globals(i).value :: stack, frame, kont, trail, ret) + eval(rest, frame.module.globals(i).value :: stack, frame, kont, trail) case GlobalSet(i) => val value :: newStack = stack frame.module.globals(i).ty match { @@ -232,18 +231,18 @@ object Evaluator { case GlobalType(_, true) => throw new Exception("Invalid type") case _ => throw new Exception("Cannot set immutable global") } - eval(rest, newStack, frame, kont, trail, ret) + eval(rest, newStack, frame, kont, trail) case MemorySize => - eval(rest, I32V(frame.module.memory.head.size) :: stack, frame, kont, trail, ret) + eval(rest, I32V(frame.module.memory.head.size) :: stack, frame, kont, trail) case MemoryGrow => val I32V(delta) :: newStack = stack val mem = frame.module.memory.head val oldSize = mem.size mem.grow(delta) match { case Some(e) => - eval(rest, I32V(-1) :: newStack, frame, kont, trail, ret) + eval(rest, I32V(-1) :: newStack, frame, kont, trail) case _ => - eval(rest, I32V(oldSize) :: newStack, frame, kont, trail, ret) + eval(rest, I32V(oldSize) :: newStack, frame, kont, trail) } case MemoryFill => val I32V(value) :: I32V(offset) :: I32V(size) :: newStack = stack @@ -251,7 +250,7 @@ object Evaluator { throw new Exception("Out of bounds memory access") // GW: turn this into a `trap`? else { frame.module.memory.head.fill(offset, size, value.toByte) - eval(rest, newStack, frame, kont, trail, ret) + eval(rest, newStack, frame, kont, trail) } case MemoryCopy => val I32V(n) :: I32V(src) :: I32V(dest) :: newStack = stack @@ -259,38 +258,38 @@ object Evaluator { throw new Exception("Out of bounds memory access") else { frame.module.memory.head.copy(dest, src, n) - eval(rest, newStack, frame, kont, trail, ret) + eval(rest, newStack, frame, kont, trail) } - case Const(n) => eval(rest, n :: stack, frame, kont, trail, ret) + case Const(n) => eval(rest, n :: stack, frame, kont, trail) case Binary(op) => val v2 :: v1 :: newStack = stack - eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail, ret) + eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail) case Unary(op) => val v :: newStack = stack - eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail, ret) + eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail) case Compare(op) => val v2 :: v1 :: newStack = stack - eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail, ret) + eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail) case Test(op) => val v :: newStack = stack - eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail, ret) + eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail) case Store(StoreOp(align, offset, ty, None)) => val I32V(v) :: I32V(addr) :: newStack = stack frame.module.memory(0).storeInt(addr + offset, v) - eval(rest, newStack, frame, kont, trail, ret) + eval(rest, newStack, frame, kont, trail) case Load(LoadOp(align, offset, ty, None, None)) => val I32V(addr) :: newStack = stack val value = frame.module.memory(0).loadInt(addr + offset) - eval(rest, I32V(value) :: newStack, frame, kont, trail, ret) + eval(rest, I32V(value) :: newStack, frame, kont, trail) case Nop => - eval(rest, stack, frame, kont, trail, ret) + eval(rest, stack, frame, kont, trail) case Unreachable => throw Trap() case Block(ty, inner) => val funcTy = getFuncType(frame.module, ty) val (inputs, restStack) = stack.splitAt(funcTy.inps.size) val restK: Cont[Ans] = (retStack) => - eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail, ret) - eval(inner, inputs, frame, restK, restK :: trail, ret + 1) + eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) + eval(inner, inputs, frame, restK, restK :: trail) case Loop(ty, inner) => // We construct two continuations, one for the break (to the begining of the loop), // and one for fall-through to the next instruction following the syntactic structure @@ -298,9 +297,9 @@ object Evaluator { val funcTy = getFuncType(frame.module, ty) val (inputs, restStack) = stack.splitAt(funcTy.inps.size) val restK: Cont[Ans] = (retStack) => - eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail, ret) + eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) def loop(retStack: List[Value]): Ans = - eval(inner, retStack.take(funcTy.inps.size), frame, restK, loop _ :: trail, ret + 1) + eval(inner, retStack.take(funcTy.inps.size), frame, restK, loop _ :: trail) loop(inputs) case If(ty, thn, els) => val funcTy = getFuncType(frame.module, ty) @@ -308,36 +307,36 @@ object Evaluator { val inner = if (cond != 0) thn else els val (inputs, restStack) = newStack.splitAt(funcTy.inps.size) val restK: Cont[Ans] = (retStack) => - eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail, ret) - eval(inner, inputs, frame, restK, restK :: trail, ret + 1) + eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) + eval(inner, inputs, frame, restK, restK :: trail) case Br(label) => trail(label)(stack) case BrIf(label) => val I32V(cond) :: newStack = stack if (cond != 0) trail(label)(newStack) - else eval(rest, newStack, frame, kont, trail, ret) + else eval(rest, newStack, frame, kont, trail) case BrTable(labels, default) => val I32V(cond) :: newStack = stack val goto = if (cond < labels.length) labels(cond) else default trail(goto)(newStack) - case Return => trail(ret)(stack) + case Return => trail.last(stack) case Call(f) if frame.module.funcs(f).isInstanceOf[FuncDef] => val FuncDef(_, FuncBodyDef(ty, _, locals, body)) = frame.module.funcs(f) val args = stack.take(ty.inps.size).reverse val newStack = stack.drop(ty.inps.size) val frameLocals = args ++ locals.map(zero(_)) val newFrame = Frame(frame.module, ArrayBuffer(frameLocals: _*)) - val newK: Cont[Ans] = (retStack) => eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail, ret) + val newK: Cont[Ans] = (retStack) => eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail) // We push newK on the trail since function creates a new block to escape // (more or less like `return`) - eval(body, List(), newFrame, newK, List(newK), 0) + eval(body, List(), newFrame, newK, List(newK)) case Call(f) if frame.module.funcs(f).isInstanceOf[Import] => frame.module.funcs(f) match { case Import("console", "log", _) => //println(s"[DEBUG] current stack: $stack") val I32V(v) :: newStack = stack println(v) - eval(rest, newStack, frame, kont, trail, ret) + eval(rest, newStack, frame, kont, trail) case f => throw new Exception(s"Unknown import $f") } case _ => @@ -404,7 +403,7 @@ object Evaluator { val moduleInst = ModuleInstance(types, module.funcEnv, memory, globals) - Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), halt, List(halt), 0) + Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), halt, List(halt)) } def evalTop(m: Module): Unit = evalTop(m, stack => ())