Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix return #46

Merged
merged 8 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions benchmarks/wasm/even_odd.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#[no_mangle]
#[inline(never)]
fn is_even(n: u32) -> bool {
if n == 0 { true }
else { is_odd(n - 1) }
}

#[no_mangle]
#[inline(never)]
fn is_odd(n: u32) -> bool {
if n == 0 { false }
else { is_even(n - 1) }
}

#[no_mangle]
fn real_main() -> bool {
is_even(12)
}
31 changes: 31 additions & 0 deletions benchmarks/wasm/even_odd.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
(module
(type (;0;) (func (param i32) (result i32)))
(type (;1;) (func (result i32)))
(func (;0;) (type 0) (param i32) (result i32)
block ;; label = @1
local.get 0
br_if 0 (;@1;)
i32.const 1
return
end
local.get 0
i32.const -1
i32.add
call 1)
(func (;1;) (type 0) (param i32) (result i32)
block ;; label = @1
local.get 0
br_if 0 (;@1;)
i32.const 0
return
end
local.get 0
i32.const -1
i32.add
call 0)
(func (;2;) (type 1) (result i32)
i32.const 13
call 1)
(start 2)
(memory (;0;) 16)
)
13 changes: 13 additions & 0 deletions benchmarks/wasm/return.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
(module
(type (;0;) (func))
(func (;0;) (type 0)
block ;; label = @1
return
end
unreachable
)
(func (;1;) (type 0)
call 0
)
(start 1)
)
143 changes: 97 additions & 46 deletions src/main/scala/wasm/MiniWasm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,14 @@ case class Frame(module: ModuleInstance, locals: ArrayBuffer[Value])
object Evaluator {
import Primtives._

type RetCont = List[Value] => Unit
type Cont = List[Value] => Unit
type Cont[A] = List[Value] => A

def eval(insts: List[Instr],
stack: List[Value],
frame: Frame,
trail: List[Cont])
(implicit kont: Cont): Unit = {
def eval[Ans](insts: List[Instr],
stack: List[Value],
frame: Frame,
kont: Cont[Ans],
trail: List[Cont[Ans]],
ret: Int): Ans = {
if (insts.isEmpty) return kont(stack)

val inst = insts.head
Expand All @@ -184,23 +184,23 @@ object Evaluator {
//println(f"stack size: ${stack.size}")
//println(s"eval: $inst")
inst match {
case Drop => eval(rest, stack.tail, frame, trail)
case Drop => eval(rest, stack.tail, frame, kont, trail, ret)
case Select(_) =>
val I32V(cond) :: v2 :: v1 :: newStack = stack
val value = if (cond == 0) v1 else v2
eval(rest, value :: newStack, frame, trail)
eval(rest, value :: newStack, frame, kont, trail, ret)
case LocalGet(i) =>
eval(rest, frame.locals(i) :: stack, frame, trail)
eval(rest, frame.locals(i) :: stack, frame, kont, trail, ret)
case LocalSet(i) =>
val value :: newStack = stack
frame.locals(i) = value
eval(rest, newStack, frame, trail)
eval(rest, newStack, frame, kont, trail, ret)
case LocalTee(i) =>
val value :: newStack = stack
frame.locals(i) = value
eval(rest, stack, frame, trail)
eval(rest, stack, frame, kont, trail, ret)
case GlobalGet(i) =>
eval(rest, frame.module.globals(i).value :: stack, frame, trail)
eval(rest, frame.module.globals(i).value :: stack, frame, kont, trail, ret)
case GlobalSet(i) =>
val value :: newStack = stack
frame.module.globals(i).ty match {
Expand All @@ -209,111 +209,162 @@ object Evaluator {
case GlobalType(_, true) => throw new Exception("Invalid type")
case _ => throw new Exception("Cannot set immutable global")
}
eval(rest, newStack, frame, trail)
eval(rest, newStack, frame, kont, trail, ret)
case MemorySize =>
eval(rest,
I32V(frame.module.memory.head.size) :: stack,
frame,
trail)
eval(rest, I32V(frame.module.memory.head.size) :: stack, frame, kont, trail, ret)
case MemoryGrow =>
val I32V(delta) :: newStack = stack
val mem = frame.module.memory.head
val oldSize = mem.size
mem.grow(delta) match {
case Some(e) => eval(rest, I32V(-1) :: newStack, frame, trail)
case _ => eval(rest, I32V(oldSize) :: newStack, frame, trail)
case Some(e) => eval(rest, I32V(-1) :: newStack, frame, kont, trail, ret)
case _ => eval(rest, I32V(oldSize) :: newStack, frame, kont, trail, ret)
}
case MemoryFill =>
val I32V(value) :: I32V(offset) :: I32V(size) :: newStack = stack
if (memOutOfBound(frame, 0, offset, size))
throw new Exception("Out of bounds memory access") // GW: turn this into a `trap`?
else {
frame.module.memory.head.fill(offset, size, value.toByte)
eval(rest, newStack, frame, trail)
eval(rest, newStack, frame, kont, trail, ret)
}
case MemoryCopy =>
val I32V(n) :: I32V(src) :: I32V(dest) :: newStack = stack
if (memOutOfBound(frame, 0, src, n) || memOutOfBound(frame, 0, dest, n))
throw new Exception("Out of bounds memory access")
else {
frame.module.memory.head.copy(dest, src, n)
eval(rest, newStack, frame, trail)
eval(rest, newStack, frame, kont, trail, ret)
}
case Const(n) => eval(rest, n :: stack, frame, trail)
case Const(n) => eval(rest, n :: stack, frame, kont, trail, ret)
case Binary(op) =>
val v2 :: v1 :: newStack = stack
eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, trail)
eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail, ret)
case Unary(op) =>
val v :: newStack = stack
eval(rest, evalUnaryOp(op, v) :: newStack, frame, trail)
eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail, ret)
case Compare(op) =>
val v2 :: v1 :: newStack = stack
eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, trail)
eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail, ret)
case Test(op) =>
val v :: newStack = stack
eval(rest, evalTestOp(op, v) :: newStack, frame, trail)
eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail, ret)
case Store(StoreOp(align, offset, ty, None)) =>
val I32V(v) :: I32V(addr) :: newStack = stack
frame.module.memory(0).storeInt(addr + offset, v)
eval(rest, newStack, frame, trail)
eval(rest, newStack, frame, kont, trail, ret)
case Load(LoadOp(align, offset, ty, None, None)) =>
val I32V(addr) :: newStack = stack
val value = frame.module.memory(0).loadInt(addr + offset)
eval(rest, I32V(value) :: newStack, frame, trail)
eval(rest, I32V(value) :: newStack, frame, kont, trail, ret)
case Nop =>
eval(rest, stack, frame, trail)
eval(rest, stack, frame, kont, trail, ret)
case Unreachable => throw new RuntimeException("Unreachable")
case Block(ty, inner) =>
val k: Cont = (retStack) =>
eval(rest, retStack.take(ty.toList.size) ++ stack, frame, trail)
val k: Cont[Ans] = (retStack) =>
eval(rest, retStack.take(ty.toList.size) ++ stack, frame, kont, trail, ret)
// TODO: block can take inputs too
eval(inner, List(), frame, k :: trail)(k)
eval(inner, List(), frame, k, k :: trail, ret+1)
case Loop(ty, inner) =>
// We construct two continuations, one for the break (to the begining of the loop),
// and one for fall-through to the next instruction following the syntactic structure
// of the program.
val restK: Cont = (retStack) => eval(rest, retStack.take(ty.toList.size) ++ stack, frame, trail)
def loop(stack: List[Value]): Unit = {
val k: Cont = (retStack) => loop(retStack.take(ty.toList.size))
eval(inner, stack, frame, k :: trail)(restK)
val restK: Cont[Ans] = (retStack) => eval(rest, retStack.take(ty.toList.size) ++ stack, frame, kont, trail, ret)
def loop(stack: List[Value]): Ans = {
val k: Cont[Ans] = (retStack) => loop(retStack.take(ty.toList.size))
eval(inner, stack, frame, restK, k :: trail, ret+1)
}
loop(List())
case If(ty, thn, els) =>
val I32V(cond) :: newStack = stack
val inner = if (cond != 0) thn else els
val k: Cont = (retStack) =>
eval(rest, retStack.take(ty.toList.size) ++ newStack, frame, trail)
eval(inner, List(), frame, k :: trail)(k)
val k: Cont[Ans] = (retStack) =>
eval(rest, retStack.take(ty.toList.size) ++ newStack, frame, kont, trail, ret)
eval(inner, List(), frame, k, k :: trail, ret+1)
case Br(label) =>
trail(label)(stack)
case BrIf(label) =>
val I32V(cond) :: newStack = stack
if (cond != 0) trail(label)(newStack)
else eval(rest, newStack, frame, trail)
case Return => kont(stack)
else eval(rest, newStack, frame, kont, trail, ret)
case Return => trail(ret)(stack)
case Call(f) if frame.module.funcs(f).isInstanceOf[FuncDef] =>
val FuncDef(_, FuncBodyDef(ty, _, locals, body)) = frame.module.funcs(f)
val args = stack.take(ty.inps.size).reverse
val newStack = stack.drop(ty.inps.size)
val frameLocals = args ++ locals.map(_ => I32V(0)) // GW: always I32? or depending on their types?
val newFrame = Frame(frame.module, ArrayBuffer(frameLocals: _*))
val newK: RetCont = (retStack) =>
eval(rest, retStack.take(ty.out.size) ++ newStack, frame, trail)
val newK: Cont[Ans] = (retStack) =>
eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail, ret)
// We push newK on the trail since function creates a new block to escape
// (more or less like `return`)
eval(body, List(), newFrame, newK :: trail)(newK)
eval(body, List(), newFrame, newK, newK :: trail, ret+1)
case Call(f) if frame.module.funcs(f).isInstanceOf[Import] =>
frame.module.funcs(f) match {
case Import("console", "log", _) =>
//println(s"[DEBUG] current stack: $stack")
val I32V(v) :: newStack = stack
println(v)
eval(rest, newStack, frame, trail)
eval(rest, newStack, frame, kont, trail, ret)
case f => throw new Exception(s"Unknown import $f")
}
case _ =>
println(inst)
throw new Exception(s"instruction $inst not implemented")
}
}

// If `main` is given, then we use that function as the entry point of the program;
// otherwise, we look up the top-level `start` instruction to locate the entry point.
def evalTop[Ans](module: Module, halt: Cont[Ans], main: Option[String] = None): Ans = {
val instrs = main match {
case Some(_) => module.definitions.flatMap({
case FuncDef(`main`, FuncBodyDef(_, _, _, body)) =>
println(s"Entering function $main")
body
case _ => List()
})
case None => module.definitions.flatMap({
case Start(id) => module.funcEnv(id) match {
case FuncDef(_, FuncBodyDef(_, _, _, body)) =>
println(s"Entering unnamed function $id")
body
case _ => throw new Exception("Start function has no concrete definition")
}
case _ => List()
})
}

val types = List()
val funcs = module.definitions
.collect({
case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef
})
.toList

val globals = module.definitions
.collect({
case Global(_, GlobalValue(ty, e)) =>
(e.head) match {
case Const(c) => RTGlobal(ty, c)
// Q: What is the default behavior if case in non exhaustive
case _ => ???
}
})
.toList

// TODO: correct the behavior for memory
val memory = module.definitions
.collect({
case Memory(id, MemoryType(min, max_opt)) =>
RTMemory(min, max_opt)
})
.toList

val moduleInst = ModuleInstance(types, module.funcEnv, memory, globals)

Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), halt, List(halt), 0)
}

def evalTop(m: Module): Unit = evalTop(m, stack => ())
}
Loading
Loading