Skip to content

Commit

Permalink
Supporting export (#58)
Browse files Browse the repository at this point in the history
evalTop runs exported function now
  • Loading branch information
ahuoguo authored Oct 14, 2024
1 parent 70b3d23 commit f295830
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 76 deletions.
4 changes: 2 additions & 2 deletions benchmarks/wasm/ack.wat
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
(global (;1;) i32 (i32.const 1048576))
(global (;2;) i32 (i32.const 1048576))
(export "memory" (memory 0))
(export "ack" (func $ack))
(export "real_main" (func $real_main))
(export "ack" (func 0))
(export "real_main" (func 1))
(export "__data_end" (global 1))
(export "__heap_base" (global 2)))
4 changes: 2 additions & 2 deletions benchmarks/wasm/pow.wat
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
(global (;1;) i32 (i32.const 1048576))
(global (;2;) i32 (i32.const 1048576))
(export "memory" (memory 0))
(export "power" (func $power))
(export "real_main" (func $real_main))
(export "power" (func 0))
(export "real_main" (func 1))
(export "__data_end" (global 1))
(export "__heap_base" (global 2)))
2 changes: 1 addition & 1 deletion benchmarks/wasm/return.wat
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
call 0
unreachable
)
(start 1)
(export "$real_main" (func 1))
)
15 changes: 11 additions & 4 deletions grammar/gen-wat-parser.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@
#
ANTLR4=antlr-4.13.0-complete.jar

java20 -jar $ANTLR4 WatLexer.g4
java20 -jar $ANTLR4 -visitor WatParser.g4
java -jar $ANTLR4 WatLexer.g4
java -jar $ANTLR4 -visitor WatParser.g4

DST=../src/main/java/wasm/
echo "Copy WAT parsers into $DST"

for file in "WatLexer.java" "WatParserBaseVisitor.java" "WatParserListener.java" "WatParserBaseListener.java" "WatParser.java" "WatParserVisitor.java"
do
sed -i "1ipackage gensym.wasm;$line" $file
cp $file $DST/$file
if [[ "$OSTYPE" == "darwin"* ]]; then
sed -i '' $'1i\\\npackage gensym.wasm;\n' $file
else
sed -i "1ipackage gensym.wasm;$line" $file
fi
cp $file "$DST/$file"
rm $file
done

rm *.tokens *.interp
53 changes: 18 additions & 35 deletions src/main/scala/wasm/AST.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ case class TypeDef(id: Option[String], tipe: FuncType) extends Definition
case class Table(id: Option[String], f: TableField) extends Definition
case class Memory(id: Option[String], f: MemoryField) extends Definition
case class Global(id: Option[String], f: GlobalField) extends Definition
case class Elem(id: Option[Int], offset: List[Instr], elemList: ElemList)
extends Definition
case class Elem(id: Option[Int], offset: List[Instr], elemList: ElemList) extends Definition
case class Data(id: Option[String], value: String) extends Definition
case class Start(id: Int) extends Definition
case class Import(mod: String, name: String, desc: ImportDesc) extends Definition
case class Export(name: String, desc: ExportDesc) extends Definition
// FIXME: missing top-level module fields, see WatParser.g4

abstract class ImportDesc extends WIR
Expand All @@ -30,16 +30,9 @@ case class ElemListFunc(funcs: List[String]) extends ElemList
case class ElemListExpr(exprs: List[List[Instr]]) extends ElemList

abstract class FuncField extends WIR
case class FuncBodyDef(tipe: FuncType,
localNames: List[String],
locals: List[ValueType],
body: List[Instr])
extends FuncField
case class FunInlineImport(mod: String,
name: String,
typeUse: Option[Int],
imports: Any /*FIXME*/ )
case class FuncBodyDef(tipe: FuncType, localNames: List[String], locals: List[ValueType], body: List[Instr])
extends FuncField
case class FunInlineImport(mod: String, name: String, typeUse: Option[Int], imports: Any /*FIXME*/ ) extends FuncField
case class FunInlineExport(fd: List[FuncDef]) extends FuncField

abstract class TableField extends WIR
Expand Down Expand Up @@ -69,17 +62,11 @@ case object Alloc extends Instr
case object Free extends Instr
case class Select(ty: Option[List[ValueType]]) extends Instr
case class Block(ty: Option[ValueType], instrs: List[Instr]) extends Instr
case class IdBlock(id: Int, ty: Option[ValueType], instrs: List[Instr])
extends Instr
case class IdBlock(id: Int, ty: Option[ValueType], instrs: List[Instr]) extends Instr
case class Loop(ty: Option[ValueType], instrs: List[Instr]) extends Instr
case class IdLoop(id: Int, ty: Option[ValueType], instrs: List[Instr])
extends Instr
case class If(ty: Option[ValueType],
thenInstrs: List[Instr],
elseInstrs: List[Instr])
extends Instr
case class IdIf(ty: Option[ValueType], thenInstrs: IdBlock, elseInstrs: IdBlock)
extends Instr
case class IdLoop(id: Int, ty: Option[ValueType], instrs: List[Instr]) extends Instr
case class If(ty: Option[ValueType], thenInstrs: List[Instr], elseInstrs: List[Instr]) extends Instr
case class IdIf(ty: Option[ValueType], thenInstrs: IdBlock, elseInstrs: IdBlock) extends Instr
// FIXME: labelId can be string?
case class Br(labelId: Int) extends Instr
case class BrIf(labelId: Int) extends Instr
Expand Down Expand Up @@ -221,16 +208,8 @@ case object SX extends Extension
case object ZX extends Extension

abstract class MemOp(align: Int, offset: Int) extends WIR
case class StoreOp(align: Int,
offset: Int,
tipe: NumType,
pack_size: Option[PackSize])
extends MemOp(align, offset)
case class LoadOp(align: Int,
offset: Int,
tipe: NumType,
pack_size: Option[PackSize],
extension: Option[Extension])
case class StoreOp(align: Int, offset: Int, tipe: NumType, pack_size: Option[PackSize]) extends MemOp(align, offset)
case class LoadOp(align: Int, offset: Int, tipe: NumType, pack_size: Option[PackSize], extension: Option[Extension])
extends MemOp(align, offset)

// Types
Expand All @@ -255,10 +234,7 @@ case class NumType(kind: NumKind) extends ValueType
case class VecType(kind: VecKind) extends ValueType
case class RefType(kind: RefKind) extends ValueType

case class FuncType(argNames /*optional*/: List[String],
inps: List[ValueType],
out: List[ValueType])
extends WasmType
case class FuncType(argNames /*optional*/: List[String], inps: List[ValueType], out: List[ValueType]) extends WasmType

case class GlobalType(ty: ValueType, mut: Boolean) extends WasmType

Expand Down Expand Up @@ -302,3 +278,10 @@ case class I32V(value: Int) extends Num
case class I64V(value: Long) extends Num
case class F32V(value: Float) extends Num
case class F64V(value: Double) extends Num

// Exports
abstract class ExportDesc extends WIR
case class ExportFunc(i: Int) extends ExportDesc
case class ExportTable(i: Int) extends ExportDesc
case class ExportMemory(i: Int) extends ExportDesc
case class ExportGlobal(i: Int) extends ExportDesc
3 changes: 1 addition & 2 deletions src/main/scala/wasm/Memory.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ object RTMemory {
def apply(): RTMemory = this.apply(1024)
def apply(size: Int): RTMemory = this.apply(size, None)
def apply(size: Int, maxSize: Option[Int]): RTMemory = {
new RTMemory(RTMemoryType(size, maxSize),
ArrayBuffer.fill[Byte](size * pageSize.toInt)(0))
new RTMemory(RTMemoryType(size, maxSize), ArrayBuffer.fill[Byte](size * pageSize.toInt)(0))
}
}
59 changes: 35 additions & 24 deletions src/main/scala/wasm/MiniWasm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ case class ModuleInstance(
funcs: HashMap[Int, WIR],
memory: List[RTMemory] = List(RTMemory()),
globals: List[RTGlobal] = List(),
exports: List[Export] = List()
)

object Primtives {
Expand Down Expand Up @@ -230,8 +231,10 @@ object Evaluator {
val mem = frame.module.memory.head
val oldSize = mem.size
mem.grow(delta) match {
case Some(e) => eval(rest, I32V(-1) :: newStack, frame, kont, trail, ret)
case _ => eval(rest, I32V(oldSize) :: newStack, frame, kont, trail, ret)
case Some(e) =>
eval(rest, I32V(-1) :: newStack, frame, kont, trail, ret)
case _ =>
eval(rest, I32V(oldSize) :: newStack, frame, kont, trail, ret)
}
case MemoryFill =>
val I32V(value) :: I32V(offset) :: I32V(size) :: newStack = stack
Expand Down Expand Up @@ -274,26 +277,24 @@ object Evaluator {
eval(rest, stack, frame, kont, trail, ret)
case Unreachable => throw Trap()
case Block(ty, inner) =>
val k: Cont[Ans] = (retStack) =>
eval(rest, retStack.take(ty.toList.size) ++ stack, frame, kont, trail, ret)
val k: Cont[Ans] = (retStack) => eval(rest, retStack.take(ty.toList.size) ++ stack, frame, kont, trail, ret)
// TODO: block can take inputs too
eval(inner, List(), frame, k, k :: trail, ret+1)
eval(inner, List(), frame, k, k :: trail, ret + 1)
case Loop(ty, inner) =>
// We construct two continuations, one for the break (to the begining of the loop),
// and one for fall-through to the next instruction following the syntactic structure
// of the program.
val restK: Cont[Ans] = (retStack) => eval(rest, retStack.take(ty.toList.size) ++ stack, frame, kont, trail, ret)
def loop(stack: List[Value]): Ans = {
val k: Cont[Ans] = (retStack) => loop(retStack.take(ty.toList.size))
eval(inner, stack, frame, restK, k :: trail, ret+1)
eval(inner, stack, frame, restK, k :: trail, ret + 1)
}
loop(List())
case If(ty, thn, els) =>
val I32V(cond) :: newStack = stack
val inner = if (cond != 0) thn else els
val k: Cont[Ans] = (retStack) =>
eval(rest, retStack.take(ty.toList.size) ++ newStack, frame, kont, trail, ret)
eval(inner, List(), frame, k, k :: trail, ret+1)
val k: Cont[Ans] = (retStack) => eval(rest, retStack.take(ty.toList.size) ++ newStack, frame, kont, trail, ret)
eval(inner, List(), frame, k, k :: trail, ret + 1)
case Br(label) =>
trail(label)(stack)
case BrIf(label) =>
Expand All @@ -311,8 +312,7 @@ object Evaluator {
val newStack = stack.drop(ty.inps.size)
val frameLocals = args ++ locals.map(zero(_))
val newFrame = Frame(frame.module, ArrayBuffer(frameLocals: _*))
val newK: Cont[Ans] = (retStack) =>
eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail, ret)
val newK: Cont[Ans] = (retStack) => eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail, ret)
// We push newK on the trail since function creates a new block to escape
// (more or less like `return`)
eval(body, List(), newFrame, newK, List(newK), 0)
Expand All @@ -335,21 +335,32 @@ object Evaluator {
// otherwise, we look up the top-level `start` instruction to locate the entry point.
def evalTop[Ans](module: Module, halt: Cont[Ans], main: Option[String] = None): Ans = {
val instrs = main match {
case Some(_) => module.definitions.flatMap({
case FuncDef(`main`, FuncBodyDef(_, _, _, body)) =>
case Some(func_name) =>
module.definitions.flatMap({
case Export(`func_name`, ExportFunc(fid)) =>
println(s"Entering function $main")
body
module.funcEnv(fid) match {
case FuncDef(_, FuncBodyDef(_, _, _, body)) => body
case _ =>
throw new Exception("Entry function has no concrete body")
}
case _ => List()
})
case None => module.definitions.flatMap({
case Start(id) => module.funcEnv(id) match {
case FuncDef(_, FuncBodyDef(_, _, _, body)) =>
case None =>
module.definitions.flatMap({
case Start(id) =>
println(s"Entering unnamed function $id")
body
case _ => throw new Exception("Start function has no concrete definition")
}
case _ => List()
})
module.funcEnv(id) match {
case FuncDef(_, FuncBodyDef(_, _, _, body)) => body
case _ =>
throw new Exception("Entry function has no concrete body")
}
case _ => List()
})
}

if (instrs.isEmpty) {
println("Warning: nothing is executed")
}

val types = List()
Expand All @@ -364,8 +375,8 @@ object Evaluator {
case Global(_, GlobalValue(ty, e)) =>
(e.head) match {
case Const(c) => RTGlobal(ty, c)
// Q: What is the default behavior if case in non exhaustive
case _ => ???
// Q: What is the default behavior if case in non-exhaustive
case _ => ???
}
})
.toList
Expand Down
22 changes: 22 additions & 0 deletions src/main/scala/wasm/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,28 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] {
val field = visitGlobalField(ctx.globalField)
Global(name, field)
}

override def visitExport_(ctx: Export_Context): WIR = {
val name: String = ctx.name.getText.substring(1).dropRight(1)
val desc = visitExportDesc(ctx.exportDesc).asInstanceOf[ExportDesc]
Export(name, desc)
}

override def visitExportDesc(ctx: ExportDescContext): WIR = {
val id = if (ctx.idx.VAR() != null) {
println(s"Warning: we don't support labeling yet")
throw new RuntimeException("Unsupported")
} else {
getVar(ctx.idx()).toInt
}
if (ctx.FUNC != null) ExportFunc(id)
else if (ctx.TABLE != null) ExportTable(id)
else if (ctx.MEMORY != null) ExportMemory(id)
else if (ctx.GLOBAL != null) ExportGlobal(id)
else error

}

}

object Parser {
Expand Down
10 changes: 4 additions & 6 deletions src/test/scala/genwasym/TestEval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ import collection.mutable.ArrayBuffer
import org.scalatest.FunSuite

class TestEval extends FunSuite {
def testFile(filename: String,
main: Option[String] = None,
expected: Option[Int] = None) = {
def testFile(filename: String, main: Option[String] = None, expected: Option[Int] = None) = {
val module = Parser.parseFile(filename)
//println(module)
val haltK: Evaluator.Cont[Unit] = stack => {
Expand All @@ -32,8 +30,8 @@ class TestEval extends FunSuite {
// TODO: the power test can be used to test the stack
// For now: 2^10 works, 2^100 results in 0 (TODO: why?),
// and 2^1000 results in a stack overflow
test("ack") { testFile("./benchmarks/wasm/ack.wat", Some("$real_main"), Some(7)) }
test("power") { testFile("./benchmarks/wasm/pow.wat", Some("$real_main"), Some(1024)) }
test("ack") { testFile("./benchmarks/wasm/ack.wat", Some("real_main"), Some(7)) }
test("power") { testFile("./benchmarks/wasm/pow.wat", Some("real_main"), Some(1024)) }
test("start") { testFile("./benchmarks/wasm/start.wat") }
test("fact") { testFile("./benchmarks/wasm/fact.wat", None, Some(120)) }
test("loop") { testFile("./benchmarks/wasm/loop.wat", None, Some(10)) }
Expand All @@ -45,7 +43,7 @@ class TestEval extends FunSuite {

test("return") {
intercept[gensym.wasm.miniwasm.Trap] {
testFile("./benchmarks/wasm/return.wat", None, None)
testFile("./benchmarks/wasm/return.wat", Some("$real_main"), None)
}
}

Expand Down

0 comments on commit f295830

Please sign in to comment.