From 2eb36f046fe722ea38b080a905ded4800899789a Mon Sep 17 00:00:00 2001 From: ahuoguo <52595524+ahuoguo@users.noreply.github.com> Date: Mon, 4 Nov 2024 18:05:27 -0500 Subject: [PATCH] Basic `(module bin ...)` support (#62) * basic .bin.wast working --------- Co-authored-by: butterunderflow Co-authored-by: Guannan Wei --- .github/workflows/scala.yml | 4 + .gitmodules | 3 + benchmarks/wasm/script/script_basic.bin.wast | 7 ++ src/main/scala/wasm/MiniWasm.scala | 2 + src/main/scala/wasm/Parser.scala | 126 +++++++++++++++++-- src/test/scala/genwasym/TestScriptRun.scala | 6 +- src/test/scala/genwasym/TestSyntax.scala | 17 +++ third-party/wasmfx-tools | 1 + 8 files changed, 152 insertions(+), 14 deletions(-) create mode 100644 benchmarks/wasm/script/script_basic.bin.wast create mode 100644 src/test/scala/genwasym/TestSyntax.scala create mode 160000 third-party/wasmfx-tools diff --git a/.github/workflows/scala.yml b/.github/workflows/scala.yml index a691b744..1938fc40 100644 --- a/.github/workflows/scala.yml +++ b/.github/workflows/scala.yml @@ -63,6 +63,10 @@ jobs: make sudo make install sudo ldconfig + - name: Install wasmfx-tools + run: | + cd third-party/wasmfx-tools + cargo build --release - name: Generate models run: sbt 'runMain gensym.GenerateExternal' - name: Run tests diff --git a/.gitmodules b/.gitmodules index a42bcff5..8184fb1b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -24,3 +24,6 @@ path = third-party/lms-clean url = https://github.com/TiarkRompf/lms-clean ignore = dirty +[submodule "third-party/wasmfx-tools"] + path = third-party/wasmfx-tools + url = git@github.com:wasmfx/wasmfx-tools.git diff --git a/benchmarks/wasm/script/script_basic.bin.wast b/benchmarks/wasm/script/script_basic.bin.wast new file mode 100644 index 00000000..e0ae65fe --- /dev/null +++ b/benchmarks/wasm/script/script_basic.bin.wast @@ -0,0 +1,7 @@ +(module binary + "\00\61\73\6d\01\00\00\00\01\85\80\80\80\00\01\60" + "\00\01\7f\03\82\80\80\80\00\01\00\07\87\80\80\80" + "\00\01\03\6f\6e\65\00\00\0a\8a\80\80\80\00\01\84" + "\80\80\80\00\00\41\01\0b" +) +(assert_return (invoke "one") (i32.const 0x1)) diff --git a/src/main/scala/wasm/MiniWasm.scala b/src/main/scala/wasm/MiniWasm.scala index 2f84d63f..f2ae33aa 100644 --- a/src/main/scala/wasm/MiniWasm.scala +++ b/src/main/scala/wasm/MiniWasm.scala @@ -63,6 +63,8 @@ object Primtives { (lhs, rhs) match { case (I32V(v1), I32V(v2)) => I32V(v1 + v2) case (I64V(v1), I64V(v2)) => I64V(v1 + v2) + case (F32V(v1), F32V(v2)) => F32V(v1 + v2) + case (F64V(v1), F64V(v2)) => F64V(v1 + v2) case _ => throw new Exception("Invalid types") } case Mul(_) => diff --git a/src/main/scala/wasm/Parser.scala b/src/main/scala/wasm/Parser.scala index 3757bcbc..b53ce638 100644 --- a/src/main/scala/wasm/Parser.scala +++ b/src/main/scala/wasm/Parser.scala @@ -3,6 +3,7 @@ package gensym.wasm.parser import gensym.wasm.ast._ import gensym.wasm.source._ +import scala.util.Try import scala.util.parsing.combinator._ import scala.util.parsing.input.Positional import scala.util.matching.Regex @@ -14,6 +15,9 @@ import scala.collection.JavaConverters._ import collection.mutable.{HashMap, ListBuffer} import gensym.wasm._ +import java.io.OutputStream + + import scala.collection.mutable class GSWasmVisitor extends WatParserBaseVisitor[WIR] { @@ -188,23 +192,78 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] { ??? } + // TODO: This doesn't seems quite correct + def parseHexFloat(text: String): Float = { + if (text.startsWith("0x") || text.startsWith("-0x") || text.startsWith("+0x")) { + // Remove optional sign and "0x" prefix + val cleanText = text.replaceFirst("^[+-]?0x", "") + // why removing the seemling irrelevant following two lines will effect + // the value being parsed? + val value: Float = BigDecimal(text).floatValue + print(f"cleanText = $cleanText, value = $value\n") + + val Array(mantissa, exponent) = cleanText.split("p", 2) + + // Convert mantissa and exponent + val mantissaValue = java.lang.Float.intBitsToFloat(java.lang.Integer.parseUnsignedInt(mantissa.replace(".", ""), 16)) + val exponentValue = Math.pow(2, exponent.toInt).toFloat + // print(s"mantissaValue = $mantissaValue, exponentValue = $exponentValue\n") + mantissaValue * exponentValue + } else { + text.toFloat // Fall back to regular decimal parsing + } + } + + def visitLiteralWithType(ctx: LiteralContext, ty: NumType): Num = { if (ctx.NAT != null) { ty.kind match { - case I32Type => I32V(ctx.NAT.getText.toInt) - case I64Type => I64V(ctx.NAT.getText.toLong) + case I32Type => { + if (ctx.NAT.getText.startsWith("0x")) { + I32V(Integer.parseInt(ctx.NAT.getText.substring(2), 16)) + } else { + I32V(ctx.NAT.getText.toInt) + } + } + case I64Type => { + if (ctx.NAT.getText.startsWith("0x")) { + I64V(java.lang.Long.parseLong(ctx.NAT.getText.substring(2), 16)) + } else { + I64V(ctx.NAT.getText.toLong) + } + } } } else if (ctx.INT != null) { ty.kind match { - case I32Type => I32V(ctx.INT.getText.toInt) - case I64Type => I64V(ctx.INT.getText.toLong) + case I32Type => { + if (ctx.INT.getText.startsWith("0x")) { + I32V(Integer.parseInt(ctx.INT.getText.substring(2), 16)) + } else { + I32V(ctx.INT.getText.toInt) + } + } + case I64Type => { + if (ctx.INT.getText.startsWith("0x")) { + I64V(java.lang.Long.parseLong(ctx.INT.getText.substring(2), 16)) + } else { + I64V(ctx.INT.getText.toLong) + } + } } + // TODO: parsing support for hex representation for f32/f64 not quite there yet } else if (ctx.FLOAT != null) { ty.kind match { - case F32Type => F32V(ctx.FLOAT.getText.toFloat) - case F64Type => F64V(ctx.FLOAT.getText.toDouble) + case F32Type => + val parsedValue = Try(parseHexFloat(ctx.FLOAT.getText).toFloat).getOrElse(ctx.FLOAT.getText.toFloat) + F32V(parsedValue) + + case F64Type => + // TODO: not processed at all + val parsedValue = ctx.FLOAT.getText.toDouble + F64V(parsedValue) } - } else error + } + else error } override def visitPlainInstr(ctx: PlainInstrContext): Instr = { @@ -635,10 +694,45 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] { else error } - override def visitScriptModule(ctx: ScriptModuleContext): Module = { + override def visitScriptModule(ctx: ScriptModuleContext): Module = { if (ctx.module_ != null) { visitModule_(ctx.module_).asInstanceOf[Module] - } else { + } + else if (ctx.BIN != null) { + + val bin = ctx.STRING_ + val hexString = bin.asScala.toList.map(_.getText.substring(1).dropRight(1)).mkString + + val byteArray: Array[Byte] = hexStringToByteArray(hexString) + + // just for fact checking + // val filePath = "temp.bin" + // Files.write(Paths.get(filePath), byteArray) + + // use `wasmfx-tools` to convert the binary file to a text file + val processBuilder = new ProcessBuilder("./third-party/wasmfx-tools/target/release/wasm-tools", "print") + + val process = processBuilder.start() + val outputStream: OutputStream = process.getOutputStream + try { + outputStream.write(byteArray) + outputStream.flush() + } finally { + outputStream.close() // Close the stream to signal end of input + } + + val output = scala.io.Source.fromInputStream(process.getInputStream).mkString + val errorOutput = scala.io.Source.fromInputStream(process.getErrorStream).mkString + val exitCode = process.waitFor() + + // println(s"Exit code: $exitCode") + // println(s"Output:\n$output") + // println(s"Error Output:\n$errorOutput") + + val module = Parser.parse(output) + module + } + else { throw new RuntimeException("Unsupported") } } @@ -688,14 +782,20 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] { Script(cmds.toList) } - override def visitTag(ctx: TagContext): WIR = { - val name = getVar(ctx.bindVar) - val ty = visitFuncType(ctx.funcType) - Tag(name, ty) + // Function to convert a hex string representation to an Array[Byte] + def hexStringToByteArray(hex: String): Array[Byte] = { + // Split the input string by '\' and filter out empty strings + val byteStrings = hex.split("\\\\").filter(_.nonEmpty) + + byteStrings.map { byteStr => + // Parse the hex value to a byte + Integer.parseInt(byteStr, 16).toByte + } } } + object Parser { private def makeWatVisitor(input: String) = { val charStream = new ANTLRInputStream(input) diff --git a/src/test/scala/genwasym/TestScriptRun.scala b/src/test/scala/genwasym/TestScriptRun.scala index 2c1e3f9d..c2b79156 100644 --- a/src/test/scala/genwasym/TestScriptRun.scala +++ b/src/test/scala/genwasym/TestScriptRun.scala @@ -5,7 +5,6 @@ import gensym.wasm.miniwasmscript.ScriptRunner import org.scalatest.FunSuite - class TestScriptRun extends FunSuite { def testFile(filename: String): Unit = { val script = Parser.parseScriptFile(filename).get @@ -16,4 +15,9 @@ class TestScriptRun extends FunSuite { test("simple script") { testFile("./benchmarks/wasm/script/script_basic.wast") } + + test("simple bin script") { + testFile("./benchmarks/wasm/script/script_basic.bin.wast") + } + } diff --git a/src/test/scala/genwasym/TestSyntax.scala b/src/test/scala/genwasym/TestSyntax.scala new file mode 100644 index 00000000..8fecf874 --- /dev/null +++ b/src/test/scala/genwasym/TestSyntax.scala @@ -0,0 +1,17 @@ +package gensym.wasm + +import gensym.wasm.parser.Parser +import org.scalatest.FunSuite + +class TestSyntax extends FunSuite { + def testFile(filename: String) = { + val script = Parser.parseScriptFile(filename) + println(s"script = $script") + assert(script != None, "this syntax is not defined in antlr grammar") + } + + test("basic script") { + testFile("./benchmarks/wasm/script/script_basic.wabt") + } +} + diff --git a/third-party/wasmfx-tools b/third-party/wasmfx-tools new file mode 160000 index 00000000..c9218cf2 --- /dev/null +++ b/third-party/wasmfx-tools @@ -0,0 +1 @@ +Subproject commit c9218cf2f439ff38bdb6fbf172fcc237b971ea03