Skip to content

Commit

Permalink
refactor and nicify things
Browse files Browse the repository at this point in the history
  • Loading branch information
Kraks committed Oct 4, 2024
1 parent fe59dad commit 4d9fe0c
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 97 deletions.
123 changes: 85 additions & 38 deletions src/main/scala/wasm/MiniWasm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ object Evaluator {
def eval(insts: List[Instr],
stack: List[Value],
frame: Frame,
kont: Cont,
trail: List[Cont],
kont: Cont)
(implicit retKontIdx: Int): Unit = {
ret: Int): Unit = {
if (insts.isEmpty) return kont(stack)

val inst = insts.head
Expand All @@ -185,23 +185,23 @@ object Evaluator {
//println(f"stack size: ${stack.size}")
//println(s"eval: $inst")
inst match {
case Drop => eval(rest, stack.tail, frame, trail, kont)
case Drop => eval(rest, stack.tail, frame, kont, trail, ret)
case Select(_) =>
val I32V(cond) :: v2 :: v1 :: newStack = stack
val value = if (cond == 0) v1 else v2
eval(rest, value :: newStack, frame, trail, kont)
eval(rest, value :: newStack, frame, kont, trail, ret)
case LocalGet(i) =>
eval(rest, frame.locals(i) :: stack, frame, trail, kont)
eval(rest, frame.locals(i) :: stack, frame, kont, trail, ret)
case LocalSet(i) =>
val value :: newStack = stack
frame.locals(i) = value
eval(rest, newStack, frame, trail, kont)
eval(rest, newStack, frame, kont, trail, ret)
case LocalTee(i) =>
val value :: newStack = stack
frame.locals(i) = value
eval(rest, stack, frame, trail, kont)
eval(rest, stack, frame, kont, trail, ret)
case GlobalGet(i) =>
eval(rest, frame.module.globals(i).value :: stack, frame, trail, kont)
eval(rest, frame.module.globals(i).value :: stack, frame, kont, trail, ret)
case GlobalSet(i) =>
val value :: newStack = stack
frame.module.globals(i).ty match {
Expand All @@ -210,113 +210,160 @@ object Evaluator {
case GlobalType(_, true) => throw new Exception("Invalid type")
case _ => throw new Exception("Cannot set immutable global")
}
eval(rest, newStack, frame, trail, kont)
eval(rest, newStack, frame, kont, trail, ret)
case MemorySize =>
eval(rest, I32V(frame.module.memory.head.size) :: stack, frame, trail, kont)
eval(rest, I32V(frame.module.memory.head.size) :: stack, frame, kont, trail, ret)
case MemoryGrow =>
val I32V(delta) :: newStack = stack
val mem = frame.module.memory.head
val oldSize = mem.size
mem.grow(delta) match {
case Some(e) => eval(rest, I32V(-1) :: newStack, frame, trail, kont)
case _ => eval(rest, I32V(oldSize) :: newStack, frame, trail, kont)
case Some(e) => eval(rest, I32V(-1) :: newStack, frame, kont, trail, ret)
case _ => eval(rest, I32V(oldSize) :: newStack, frame, kont, trail, ret)
}
case MemoryFill =>
val I32V(value) :: I32V(offset) :: I32V(size) :: newStack = stack
if (memOutOfBound(frame, 0, offset, size))
throw new Exception("Out of bounds memory access") // GW: turn this into a `trap`?
else {
frame.module.memory.head.fill(offset, size, value.toByte)
eval(rest, newStack, frame, trail, kont)
eval(rest, newStack, frame, kont, trail, ret)
}
case MemoryCopy =>
val I32V(n) :: I32V(src) :: I32V(dest) :: newStack = stack
if (memOutOfBound(frame, 0, src, n) || memOutOfBound(frame, 0, dest, n))
throw new Exception("Out of bounds memory access")
else {
frame.module.memory.head.copy(dest, src, n)
eval(rest, newStack, frame, trail, kont)
eval(rest, newStack, frame, kont, trail, ret)
}
case Const(n) => eval(rest, n :: stack, frame, trail, kont)
case Const(n) => eval(rest, n :: stack, frame, kont, trail, ret)
case Binary(op) =>
val v2 :: v1 :: newStack = stack
eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, trail, kont)
eval(rest, evalBinOp(op, v1, v2) :: newStack, frame, kont, trail, ret)
case Unary(op) =>
val v :: newStack = stack
eval(rest, evalUnaryOp(op, v) :: newStack, frame, trail, kont)
eval(rest, evalUnaryOp(op, v) :: newStack, frame, kont, trail, ret)
case Compare(op) =>
val v2 :: v1 :: newStack = stack
eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, trail, kont)
eval(rest, evalRelOp(op, v1, v2) :: newStack, frame, kont, trail, ret)
case Test(op) =>
val v :: newStack = stack
eval(rest, evalTestOp(op, v) :: newStack, frame, trail, kont)
eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail, ret)
case Store(StoreOp(align, offset, ty, None)) =>
val I32V(v) :: I32V(addr) :: newStack = stack
frame.module.memory(0).storeInt(addr + offset, v)
eval(rest, newStack, frame, trail, kont)
eval(rest, newStack, frame, kont, trail, ret)
case Load(LoadOp(align, offset, ty, None, None)) =>
val I32V(addr) :: newStack = stack
val value = frame.module.memory(0).loadInt(addr + offset)
eval(rest, I32V(value) :: newStack, frame, trail, kont)
eval(rest, I32V(value) :: newStack, frame, kont, trail, ret)
case Nop =>
eval(rest, stack, frame, trail, kont)
eval(rest, stack, frame, kont, trail, ret)
case Unreachable => throw new RuntimeException("Unreachable")
case Block(ty, inner) =>
val k: Cont = (retStack) =>
eval(rest, retStack.take(ty.toList.size) ++ stack, frame, trail, kont)
eval(rest, retStack.take(ty.toList.size) ++ stack, frame, kont, trail, ret)
// TODO: block can take inputs too
eval(inner, List(), frame, k :: trail, k)(retKontIdx + 1)
eval(inner, List(), frame, k, k :: trail, ret+1)
case Loop(ty, inner) =>
// We construct two continuations, one for the break (to the begining of the loop),
// and one for fall-through to the next instruction following the syntactic structure
// of the program.
val restK: Cont = (retStack) => eval(rest, retStack.take(ty.toList.size) ++ stack, frame, trail, kont)
val restK: Cont = (retStack) => eval(rest, retStack.take(ty.toList.size) ++ stack, frame, kont, trail, ret)
def loop(stack: List[Value]): Unit = {
val k: Cont = (retStack) => loop(retStack.take(ty.toList.size))
eval(inner, stack, frame, k :: trail, restK)(retKontIdx + 1)
eval(inner, stack, frame, restK, k :: trail, ret+1)
}
loop(List())
case If(ty, thn, els) =>
val I32V(cond) :: newStack = stack
val inner = if (cond != 0) thn else els
val k: Cont = (retStack) =>
eval(rest, retStack.take(ty.toList.size) ++ newStack, frame, trail, kont)
eval(inner, List(), frame, k :: trail, k)(retKontIdx + 1)
eval(rest, retStack.take(ty.toList.size) ++ newStack, frame, kont, trail, ret)
eval(inner, List(), frame, k, k :: trail, ret+1)
case Br(label) =>
trail(label)(stack)
case BrIf(label) =>
val I32V(cond) :: newStack = stack
println(s"brif: $cond")
if (cond != 0) trail(label)(newStack)
else {
println(s"br if rest $rest")
eval(rest, newStack, frame, trail, kont)
}
case Return => trail(retKontIdx)(stack)
else eval(rest, newStack, frame, kont, trail, ret)
case Return => trail(ret)(stack)
case Call(f) if frame.module.funcs(f).isInstanceOf[FuncDef] =>
val FuncDef(_, FuncBodyDef(ty, _, locals, body)) = frame.module.funcs(f)
println(s"calling $f")
val args = stack.take(ty.inps.size).reverse
val newStack = stack.drop(ty.inps.size)
val frameLocals = args ++ locals.map(_ => I32V(0)) // GW: always I32? or depending on their types?
val newFrame = Frame(frame.module, ArrayBuffer(frameLocals: _*))
val newK: RetCont = (retStack) =>
eval(rest, retStack.take(ty.out.size) ++ newStack, frame, trail, kont)
eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail, ret)
// We push newK on the trail since function creates a new block to escape
// (more or less like `return`)
eval(body, List(), newFrame, newK :: trail, newK)(retKontIdx + 1)
eval(body, List(), newFrame, newK, newK :: trail, ret+1)
case Call(f) if frame.module.funcs(f).isInstanceOf[Import] =>
frame.module.funcs(f) match {
case Import("console", "log", _) =>
//println(s"[DEBUG] current stack: $stack")
val I32V(v) :: newStack = stack
println(v)
eval(rest, newStack, frame, trail, kont)
eval(rest, newStack, frame, kont, trail, ret)
case f => throw new Exception(s"Unknown import $f")
}
case _ =>
println(inst)
throw new Exception(s"instruction $inst not implemented")
}
}

// If `main` is given, then we use that function as the entry point of the program;
// otherwise, we look up the top-level `start` instruction to locate the entry point.
def evalTop(module: Module, halt: Cont, main: Option[String]): Unit = {
val instrs = main match {
case Some(_) => module.definitions.flatMap({
case FuncDef(`main`, FuncBodyDef(_, _, _, body)) =>
println(s"Entering function $main")
body
case _ => List()
})
case None => module.definitions.flatMap({
case Start(id) => module.funcEnv(id) match {
case FuncDef(_, FuncBodyDef(_, _, _, body)) =>
println(s"Entering unnamed function $id")
body
case _ => throw new Exception("Start function has no concrete definition")
}
case _ => List()
})
}

val types = List()
val funcs = module.definitions
.collect({
case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef
})
.toList

val globals = module.definitions
.collect({
case Global(_, GlobalValue(ty, e)) =>
(e.head) match {
case Const(c) => RTGlobal(ty, c)
// Q: What is the default behavior if case in non exhaustive
case _ => ???
}
})
.toList

// TODO: correct the behavior for memory
val memory = module.definitions
.collect({
case Memory(id, MemoryType(min, max_opt)) =>
RTMemory(min, max_opt)
})
.toList

val moduleInst = ModuleInstance(types, module.funcEnv, memory, globals)

Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), halt, List(halt), 0)
}
}
62 changes: 3 additions & 59 deletions src/test/scala/genwasym/TestEval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,70 +19,14 @@ class TestEval extends FunSuite {
def testFile(filename: String, main: Option[String] = None, expected: Option[Int] = None) = {
val module = Parser.parseFile(filename)
//println(module)

val instrs = main match {
case Some(_) => module.definitions.flatMap({
case FuncDef(`main`, FuncBodyDef(_, _, _, body)) =>
println(s"Entering function $main")
body
case _ => List()
})
case None => module.definitions.flatMap({
case Start(id) => module.funcEnv(id) match {
case FuncDef(_, FuncBodyDef(_, _, _, body)) =>
println(s"Entering unnamed function $id")
body
case _ => throw new Exception("Start function has no concrete definition")
}
case _ => List()
})
}

val types = List()
val funcs = module.definitions
.collect({
case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef
})
.toList

val globals = module.definitions
.collect({
case Global(_, GlobalValue(ty, e)) =>
(e.head) match {
case Const(c) => RTGlobal(ty, c)
// Q: What is the default behavior if case in non exhaustive
case _ => ???
}
})
.toList

// TODO: correct the behavior for memory
val memory = module.definitions
.collect({
case Memory(id, MemoryType(min, max_opt)) =>
RTMemory(min, max_opt)
})
.toList

val moduleInst = ModuleInstance(types, module.funcEnv, memory, globals)

val trailK: Evaluator.Cont = newStack => {
println(s"trail: $newStack")
val haltK: Evaluator.Cont = newStack => {
println(s"halt cont: $newStack")
expected match {
case Some(e) => assert(newStack(0) == I32V(e))
case None => ()
}
}

implicit val restK: Evaluator.Cont = newStack => {
println(s"restK trail: $newStack")
expected match {
case Some(e) => assert(newStack(0) == I32V(e))
case None => ()
}

}
Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), List(trailK), restK)(0)
Evaluator.evalTop(module, haltK, main)
}

// TODO: the power test can be used to test the stack
Expand Down

0 comments on commit 4d9fe0c

Please sign in to comment.