Skip to content

[WIP] Lua typecasting #997

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import de.peeeq.wurstscript.utils.Utils;
import org.eclipse.lsp4j.jsonrpc.messages.Either;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.io.File;
import java.io.PrintStream;
Expand Down Expand Up @@ -226,6 +227,33 @@ private void executeCompiletimeExpr(ImCompiletimeExpr cte) {
LocalState localState = new LocalState();
ILconst value = cte.evaluate(globalState, localState);
ImExpr newExpr = constantToExpr(cte.getTrace(), value);
if(translator.isLuaTarget() && value.toString().equals("0")) {
// convert 0 to null/nil, if the value is 0 and not a numeric type
ImExpr expr = cte.getExpr();

if(expr instanceof ImNull) {
newExpr = ImHelper.nullExpr();
} else {
@Nullable ImType exprType = null;
if(expr instanceof ImFunctionCall) {
exprType = ((ImFunctionCall) expr).getFunc().getReturnType();
} else if(expr instanceof ImVarAccess) {
exprType = ((ImVarAccess)expr).getVar().getType();
} else if(expr instanceof ImVarArrayAccess) {
ImType type = ((ImVarArrayAccess)expr).getVar().getType();
if(type instanceof ImArrayLikeType) {
exprType = ((ImArrayLikeType) type).getEntryType();
}
}
if(exprType != null && !TypesHelper.isIntType(exprType) && !TypesHelper.isRealType(exprType)) {
newExpr = ImHelper.nullExpr();
}
}
// TODO is this complete? Are there more cases where 0 must be replaced?
// A function can return null
// null can be a literal
// null can be a variable
}
cte.replaceBy(newExpr);
} catch (InterpreterException e) {
String msg = ILInterpreter.buildStacktrace(globalState, e);
Expand Down Expand Up @@ -259,7 +287,17 @@ public ImVar initFor(ILconstObject obj) {
ImExprs indexesT = indexes.stream()
.map(i -> constantToExpr(trace, ILconstInt.create(i)))
.collect(Collectors.toCollection(JassIm::ImExprs));
addCompiletimeStateInit(JassIm.ImSet(trace, JassIm.ImMemberAccess(trace, JassIm.ImVarAccess(res), JassIm.ImTypeArguments(), var, indexesT), constantToExpr(trace, attrValue)));
ImExpr value2 = constantToExpr(trace, attrValue);
if(translator.isLuaTarget() && value2.toString().equals("0")) {
ImType varType = var.getType();
if(varType instanceof ImArrayLikeType) {
varType = ((ImArrayLikeType) varType).getEntryType();
}
if (!TypesHelper.isIntType(varType) && !TypesHelper.isRealType(varType)) {
value2 = ImHelper.nullExpr();
}
}
addCompiletimeStateInit(JassIm.ImSet(trace, JassIm.ImMemberAccess(trace, JassIm.ImVarAccess(res), JassIm.ImTypeArguments(), var, indexesT), value2));
}
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public ReflectionNativeProvider(AbstractInterpreter interpreter) {
addProvider(new ImageProvider(interpreter));
addProvider(new IntegerProvider(interpreter));
addProvider(new FrameProvider(interpreter));
addProvider(new LuaEnsureTypeProvider(interpreter));
}

public NativeJassFunction getFunctionPair(String funcName) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package de.peeeq.wurstio.jassinterpreter.providers;

import de.peeeq.wurstscript.intermediatelang.*;
import de.peeeq.wurstscript.intermediatelang.interpreter.AbstractInterpreter;

public class LuaEnsureTypeProvider extends Provider {
public LuaEnsureTypeProvider(AbstractInterpreter interpreter) {
super(interpreter);
}

public ILconstInt intEnsure(ILconstInt x) {
return x;
}

public ILconstString stringEnsure(ILconstString x) {
return x;
}

public ILconstBool boolEnsure(ILconstBool x) {
return x;
}

public ILconstReal realEnsure(ILconstReal x) {
return x;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import de.peeeq.wurstscript.utils.Utils;
import io.vavr.control.Either;
import io.vavr.control.Option;
import org.eclipse.jdt.annotation.Nullable;

import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -98,13 +99,35 @@ private static ImExpr wrapTranslation(Expr e, ImTranslator t, ImExpr translated)
return wrapTranslation(e, t, translated, actualType, expectedTypRaw);
}

static ImExpr wrapTranslation(Element trace, ImTranslator t, ImExpr translated, WurstType actualType, WurstType expectedTypRaw) {
if (t.isLuaTarget()) {
// for lua we do not need fromIndex/toIndex
return translated;
}
static ImExpr wrapLua(Element trace, ImTranslator t, ImExpr translated, WurstType actualType) {
// use ensureType functions for lua
// these functions convert nil to the default value for primitive types (int, string, bool, real)
if (t.isLuaTarget() && actualType instanceof WurstTypeBoundTypeParam) {
WurstTypeBoundTypeParam wtb = (WurstTypeBoundTypeParam) actualType;

@Nullable ImFunction ensureType = null;
switch (wtb.getName()) {
case "integer":
ensureType = t.ensureIntFunc;
break;
case "string":
ensureType = t.ensureStrFunc;
break;
case "boolean":
ensureType = t.ensureBoolFunc;
break;
case "real":
ensureType = t.ensureRealFunc;
break;
}
if(ensureType != null) {
return ImFunctionCall(trace, ensureType, ImTypeArguments(), JassIm.ImExprs(translated), false, CallType.NORMAL);
}
}
return translated;
}

static ImExpr wrapTranslation(Element trace, ImTranslator t, ImExpr translated, WurstType actualType, WurstType expectedTypRaw) {
ImFunction toIndex = null;
ImFunction fromIndex = null;
if (actualType instanceof WurstTypeBoundTypeParam) {
Expand All @@ -129,15 +152,19 @@ static ImExpr wrapTranslation(Element trace, ImTranslator t, ImExpr translated,
if (toIndex != null && fromIndex != null) {
// System.out.println(" --> cancel");
// the two conversions cancel each other out
return translated;
return wrapLua(trace, t, translated, actualType);
} else if (fromIndex != null) {
// System.out.println(" --> fromIndex");
if(t.isLuaTarget()) {
translated = ImFunctionCall(trace, t.ensureIntFunc, ImTypeArguments(), JassIm.ImExprs(translated), false, CallType.NORMAL);
}
// no ensure type necessary here, because the fromIndex function is already type safe
return ImFunctionCall(trace, fromIndex, ImTypeArguments(), JassIm.ImExprs(translated), false, CallType.NORMAL);
} else if (toIndex != null) {
// System.out.println(" --> toIndex");
return ImFunctionCall(trace, toIndex, ImTypeArguments(), JassIm.ImExprs(translated), false, CallType.NORMAL);
return wrapLua(trace, t, ImFunctionCall(trace, toIndex, ImTypeArguments(), JassIm.ImExprs(translated), false, CallType.NORMAL), actualType);
}
return translated;
return wrapLua(trace, t, translated, actualType);
}

public static ImExpr translateIntern(ExprBinary e, ImTranslator t, ImFunction f) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ public class ImTranslator {

private @Nullable ImFunction configFunc = null;

@Nullable public ImFunction ensureIntFunc = null;
@Nullable public ImFunction ensureBoolFunc = null;
@Nullable public ImFunction ensureRealFunc = null;
@Nullable public ImFunction ensureStrFunc = null;

private final Map<ImVar, VarsForTupleResult> varsForTupleVar = new LinkedHashMap<>();

private boolean isUnitTestMode;
Expand Down Expand Up @@ -122,6 +127,17 @@ public ImProg translateProg() {
debugPrintFunction = ImFunction(emptyTrace, $DEBUG_PRINT, ImTypeVars(), ImVars(JassIm.ImVar(wurstProg, WurstTypeString.instance().imTranslateType(this), "msg",
false)), ImVoid(), ImVars(), ImStmts(), flags(IS_NATIVE, IS_BJ));

if(isLuaTarget()) {
ensureIntFunc = JassIm.ImFunction(emptyTrace, "intEnsure", ImTypeVars(), ImVars(JassIm.ImVar(wurstProg, WurstTypeInt.instance().imTranslateType(this), "x", false)), WurstTypeInt.instance().imTranslateType(this), ImVars(), ImStmts(), flags(IS_NATIVE, IS_BJ));
ensureBoolFunc = JassIm.ImFunction(emptyTrace, "boolEnsure", ImTypeVars(), ImVars(JassIm.ImVar(wurstProg, WurstTypeBool.instance().imTranslateType(this), "x", false)), WurstTypeBool.instance().imTranslateType(this), ImVars(), ImStmts(), flags(IS_NATIVE, IS_BJ));
ensureRealFunc = JassIm.ImFunction(emptyTrace, "realEnsure", ImTypeVars(), ImVars(JassIm.ImVar(wurstProg, WurstTypeReal.instance().imTranslateType(this), "x", false)), WurstTypeReal.instance().imTranslateType(this), ImVars(), ImStmts(), flags(IS_NATIVE, IS_BJ));
ensureStrFunc = JassIm.ImFunction(emptyTrace, "stringEnsure", ImTypeVars(), ImVars(JassIm.ImVar(wurstProg, WurstTypeString.instance().imTranslateType(this), "x", false)), WurstTypeString.instance().imTranslateType(this), ImVars(), ImStmts(), flags(IS_NATIVE, IS_BJ));
addFunction(ensureIntFunc);
addFunction(ensureBoolFunc);
addFunction(ensureRealFunc);
addFunction(ensureStrFunc);
}

calculateCompiletimeOrder();

for (CompilationUnit cu : wurstProg) {
Expand All @@ -136,6 +152,7 @@ public ImProg translateProg() {
configFunc = ImFunction(emptyTrace, "config", ImTypeVars(), ImVars(), ImVoid(), ImVars(), ImStmts(), flags());
addFunction(configFunc);
}

finishInitFunctions();
EliminateCallFunctionsWithAnnotation.process(imProg);
removeDuplicateNatives(imProg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,9 @@ public static void print(LuaStatements stmts, StringBuilder sb, int indent) {
}

public static void print(LuaTableConstructor e, StringBuilder sb, int indent) {
sb.append("{");
sb.append("({");
e.getTableFields().print(sb, indent);
sb.append("}");
sb.append("})");
}

public static void print(LuaTableExprField e, StringBuilder sb, int indent) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ public LuaMethod initFor(ImClass a) {

LuaFunction instanceOfFunction = LuaAst.LuaFunction(uniqueName("isInstanceOf"), LuaAst.LuaParams(), LuaAst.LuaStatements());

LuaFunction ensureIntFunction = LuaAst.LuaFunction(uniqueName("intEnsure"), LuaAst.LuaParams(), LuaAst.LuaStatements());
LuaFunction ensureStrFunction = LuaAst.LuaFunction(uniqueName("stringEnsure"), LuaAst.LuaParams(), LuaAst.LuaStatements());
LuaFunction ensureBoolFunction = LuaAst.LuaFunction(uniqueName("boolEnsure"), LuaAst.LuaParams(), LuaAst.LuaStatements());
LuaFunction ensureRealFunction = LuaAst.LuaFunction(uniqueName("realEnsure"), LuaAst.LuaParams(), LuaAst.LuaStatements());

private final Lazy<LuaFunction> errorFunc = Lazy.create(() ->
this.getProg().getFunctions().stream()
.flatMap(f -> {
Expand Down Expand Up @@ -178,6 +183,7 @@ public LuaCompilationUnit translate() {
createStringConcatFunction();
createInstanceOfFunction();
createObjectIndexFunctions();
createEnsureTypeFunctions();

for (ImVar v : prog.getGlobals()) {
translateGlobal(v);
Expand Down Expand Up @@ -379,6 +385,26 @@ function defaultArray(d)
luaModel.add(arrayInitFunction);
}

private void createEnsureTypeFunctions() {
LuaFunction[] ensureTypeFunctions = {ensureIntFunction, ensureBoolFunction, ensureRealFunction, ensureStrFunction};
String[] defaultValue = {"0", "false", "0.0", "\"\""};
for(int i = 0; i < ensureTypeFunctions.length; ++i) {
String[] code = {
"if x == nil then",
" return " + defaultValue[i],
"else",
" return x",
"end"
};

ensureTypeFunctions[i].getParams().add(LuaAst.LuaVariable("x", LuaAst.LuaNoExpr()));
for (String c : code) {
ensureTypeFunctions[i].getBody().add(LuaAst.LuaLiteral(c));
}
luaModel.add(ensureTypeFunctions[i]);
}
}

private void cleanStatements() {
luaModel.accept(new LuaModel.DefaultVisitor() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,37 @@
import java.io.File;
import java.io.IOException;

import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.testng.AssertJUnit.*;


public class LuaTranslationTests extends WurstScriptTest {

private void assertFunctionReturns(String output, String functionName, String returnValue) {
/*
The function functionName must only return returnValue and do nothing else.
*/
Pattern pattern = Pattern.compile("function\\s+" + functionName + "\\(.*?\\)\\s+return\\s+" + returnValue + "\\s+end");
Matcher matcher = pattern.matcher(output);
assertTrue("Function " + functionName + " with return value " + returnValue + " was not found.", matcher.find());
}

private void assertFunctionCall(String output, String functionName, String arguments) {
/*
The function declaration is ignored by using negative lookbehind.
All function calls must use the specified arguments.
*/
Pattern pattern = Pattern.compile("(?<!\\sfunction\\s)" + functionName + "\\((.*)\\)");
Matcher matcher = pattern.matcher(output);
boolean findAtLeastOne = false;
while (matcher.find()) {
assertEquals(arguments, matcher.group(1));
findAtLeastOne = true;
}
assertTrue("Function call to function " + functionName + " with arguments (" + arguments + ") was not found.", findAtLeastOne);
}

@Test
public void testStdLib() throws IOException {
Expand Down Expand Up @@ -47,7 +73,7 @@ public void nullString1() throws IOException {
" nullString()"
);
String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_nullString1.lua"), Charsets.UTF_8);
assertTrue(compiled.contains("return \"\"") && !compiled.contains("return nil"));
assertFunctionReturns(compiled, "nullString", "\"\"");
}

@Test
Expand All @@ -59,7 +85,7 @@ public void nullString2() throws IOException {
" takesString(null)"
);
String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_nullString2.lua"), Charsets.UTF_8);
assertTrue(compiled.contains("takesString(\"\")") && !compiled.contains("takesString(nil)"));
assertFunctionCall(compiled, "takesString", "\"\"");
}

@Ignore
Expand Down Expand Up @@ -91,7 +117,7 @@ public void nullObject1() throws IOException {
" nullObject()"
);
String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_nullObject1.lua"), Charsets.UTF_8);
assertTrue(compiled.contains("return nil") && !compiled.contains("return \"\""));
assertFunctionReturns(compiled, "nullObject", "nil");
}

@Test
Expand All @@ -104,7 +130,7 @@ public void nullObject2() throws IOException {
" takesObject(null)"
);
String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_nullObject2.lua"), Charsets.UTF_8);
assertTrue(compiled.contains("takesObject(nil)") && !compiled.contains("takesObject(\"\")"));
assertFunctionCall(compiled, "takesObject", "nil");
}

@Test
Expand All @@ -117,7 +143,7 @@ public void nullUnit1() throws IOException {
" nullUnit()"
);
String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_nullUnit1.lua"), Charsets.UTF_8);
assertTrue(compiled.contains("return nil") && !compiled.contains("return \"\""));
assertFunctionReturns(compiled, "nullUnit", "nil");
}

@Test
Expand All @@ -129,7 +155,7 @@ public void nullUnit2() throws IOException {
" takesUnit(null)"
);
String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_nullUnit2.lua"), Charsets.UTF_8);
assertTrue(compiled.contains("takesUnit(nil)") && !compiled.contains("takesUnit(\"\")"));
assertFunctionCall(compiled, "takesUnit", "nil");
}
}

Loading