Skip to content

Commit

Permalink
no need to take the 'ret' argument (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kraks authored Oct 20, 2024
1 parent ea3855a commit 657fec5
Showing 1 changed file with 33 additions and 34 deletions.
67 changes: 33 additions & 34 deletions src/main/scala/wasm/MiniWasm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ object Evaluator {
stack: List[Value],
frame: Frame,
kont: Cont[Ans],
trail: List[Cont[Ans]],
ret: Int): Ans = {
trail: List[Cont[Ans]]): Ans = {
if (insts.isEmpty) return kont(stack)

val inst = insts.head
Expand All @@ -207,23 +206,23 @@ object Evaluator {
// println(s"inst: ${inst} \t | ${frame.locals} | ${stack.reverse}" )

inst match {
case Drop => eval(rest, stack.tail, frame, kont, trail, ret)
case Drop => eval(rest, stack.tail, frame, kont, trail)
case Select(_) =>
val I32V(cond) :: v2 :: v1 :: newStack = stack
val value = if (cond == 0) v1 else v2
eval(rest, value :: newStack, frame, kont, trail, ret)
eval(rest, value :: newStack, frame, kont, trail)
case LocalGet(i) =>
eval(rest, frame.locals(i) :: stack, frame, kont, trail, ret)
eval(rest, frame.locals(i) :: stack, frame, kont, trail)
case LocalSet(i) =>
val value :: newStack = stack
frame.locals(i) = value
eval(rest, newStack, frame, kont, trail, ret)
eval(rest, newStack, frame, kont, trail)
case LocalTee(i) =>
val value :: newStack = stack
frame.locals(i) = value
eval(rest, stack, frame, kont, trail, ret)
eval(rest, stack, frame, kont, trail)
case GlobalGet(i) =>
eval(rest, frame.module.globals(i).value :: stack, frame, kont, trail, ret)
eval(rest, frame.module.globals(i).value :: stack, frame, kont, trail)
case GlobalSet(i) =>
val value :: newStack = stack
frame.module.globals(i).ty match {
Expand All @@ -232,112 +231,112 @@ object Evaluator {
case GlobalType(_, true) => throw new Exception("Invalid type")
case _ => throw new Exception("Cannot set immutable global")
}
eval(rest, newStack, frame, kont, trail, ret)
eval(rest, newStack, frame, kont, trail)
case MemorySize =>
eval(rest, I32V(frame.module.memory.head.size) :: stack, frame, kont, trail, ret)
eval(rest, I32V(frame.module.memory.head.size) :: stack, frame, kont, trail)
case MemoryGrow =>
val I32V(delta) :: newStack = stack
val mem = frame.module.memory.head
val oldSize = mem.size
mem.grow(delta) match {
case Some(e) =>
eval(rest, I32V(-1) :: newStack, frame, kont, trail, ret)
eval(rest, I32V(-1) :: newStack, frame, kont, trail)
case _ =>
eval(rest, I32V(oldSize) :: newStack, frame, kont, trail, ret)
eval(rest, I32V(oldSize) :: newStack, frame, kont, trail)
}
case MemoryFill =>
val I32V(value) :: I32V(offset) :: I32V(size) :: newStack = stack
if (memOutOfBound(frame, 0, offset, size))
throw new Exception("Out of bounds memory access") // GW: turn this into a `trap`?
else {
frame.module.memory.head.fill(offset, size, value.toByte)
eval(rest, newStack, frame, kont, trail, ret)
eval(rest, newStack, frame, kont, trail)
}
case MemoryCopy =>
val I32V(n) :: I32V(src) :: I32V(dest) :: newStack = stack
if (memOutOfBound(frame, 0, src, n) || memOutOfBound(frame, 0, dest, n))
throw new Exception("Out of bounds memory access")
else {
frame.module.memory.head.copy(dest, src, n)
eval(rest, newStack, frame, kont, trail, ret)
eval(rest, newStack, frame, kont, trail)
}
case Const(n) => eval(rest, n :: stack, frame, kont, trail, ret)
case Const(n) => eval(rest, n :: stack, frame, kont, trail)
case Binary(op) =>
val v2 :: v1 :: newStack = stack
eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail, ret)
eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail)
case Unary(op) =>
val v :: newStack = stack
eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail, ret)
eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail)
case Compare(op) =>
val v2 :: v1 :: newStack = stack
eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail, ret)
eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail)
case Test(op) =>
val v :: newStack = stack
eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail, ret)
eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail)
case Store(StoreOp(align, offset, ty, None)) =>
val I32V(v) :: I32V(addr) :: newStack = stack
frame.module.memory(0).storeInt(addr + offset, v)
eval(rest, newStack, frame, kont, trail, ret)
eval(rest, newStack, frame, kont, trail)
case Load(LoadOp(align, offset, ty, None, None)) =>
val I32V(addr) :: newStack = stack
val value = frame.module.memory(0).loadInt(addr + offset)
eval(rest, I32V(value) :: newStack, frame, kont, trail, ret)
eval(rest, I32V(value) :: newStack, frame, kont, trail)
case Nop =>
eval(rest, stack, frame, kont, trail, ret)
eval(rest, stack, frame, kont, trail)
case Unreachable => throw Trap()
case Block(ty, inner) =>
val funcTy = getFuncType(frame.module, ty)
val (inputs, restStack) = stack.splitAt(funcTy.inps.size)
val restK: Cont[Ans] = (retStack) =>
eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail, ret)
eval(inner, inputs, frame, restK, restK :: trail, ret + 1)
eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail)
eval(inner, inputs, frame, restK, restK :: trail)
case Loop(ty, inner) =>
// We construct two continuations, one for the break (to the begining of the loop),
// and one for fall-through to the next instruction following the syntactic structure
// of the program.
val funcTy = getFuncType(frame.module, ty)
val (inputs, restStack) = stack.splitAt(funcTy.inps.size)
val restK: Cont[Ans] = (retStack) =>
eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail, ret)
eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail)
def loop(retStack: List[Value]): Ans =
eval(inner, retStack.take(funcTy.inps.size), frame, restK, loop _ :: trail, ret + 1)
eval(inner, retStack.take(funcTy.inps.size), frame, restK, loop _ :: trail)
loop(inputs)
case If(ty, thn, els) =>
val funcTy = getFuncType(frame.module, ty)
val I32V(cond) :: newStack = stack
val inner = if (cond != 0) thn else els
val (inputs, restStack) = newStack.splitAt(funcTy.inps.size)
val restK: Cont[Ans] = (retStack) =>
eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail, ret)
eval(inner, inputs, frame, restK, restK :: trail, ret + 1)
eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail)
eval(inner, inputs, frame, restK, restK :: trail)
case Br(label) =>
trail(label)(stack)
case BrIf(label) =>
val I32V(cond) :: newStack = stack
if (cond != 0) trail(label)(newStack)
else eval(rest, newStack, frame, kont, trail, ret)
else eval(rest, newStack, frame, kont, trail)
case BrTable(labels, default) =>
val I32V(cond) :: newStack = stack
val goto = if (cond < labels.length) labels(cond) else default
trail(goto)(newStack)
case Return => trail(ret)(stack)
case Return => trail.last(stack)
case Call(f) if frame.module.funcs(f).isInstanceOf[FuncDef] =>
val FuncDef(_, FuncBodyDef(ty, _, locals, body)) = frame.module.funcs(f)
val args = stack.take(ty.inps.size).reverse
val newStack = stack.drop(ty.inps.size)
val frameLocals = args ++ locals.map(zero(_))
val newFrame = Frame(frame.module, ArrayBuffer(frameLocals: _*))
val newK: Cont[Ans] = (retStack) => eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail, ret)
val newK: Cont[Ans] = (retStack) => eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail)
// We push newK on the trail since function creates a new block to escape
// (more or less like `return`)
eval(body, List(), newFrame, newK, List(newK), 0)
eval(body, List(), newFrame, newK, List(newK))
case Call(f) if frame.module.funcs(f).isInstanceOf[Import] =>
frame.module.funcs(f) match {
case Import("console", "log", _) =>
//println(s"[DEBUG] current stack: $stack")
val I32V(v) :: newStack = stack
println(v)
eval(rest, newStack, frame, kont, trail, ret)
eval(rest, newStack, frame, kont, trail)
case f => throw new Exception(s"Unknown import $f")
}
case _ =>
Expand Down Expand Up @@ -404,7 +403,7 @@ object Evaluator {

val moduleInst = ModuleInstance(types, module.funcEnv, memory, globals)

Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), halt, List(halt), 0)
Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), halt, List(halt))
}

def evalTop(m: Module): Unit = evalTop(m, stack => ())
Expand Down

0 comments on commit 657fec5

Please sign in to comment.