diff --git a/resources/tests/GanacheTests/tests.json b/resources/tests/GanacheTests/tests.json index a4244072..d6d61477 100644 --- a/resources/tests/GanacheTests/tests.json +++ b/resources/tests/GanacheTests/tests.json @@ -279,11 +279,6 @@ "expected": "5743", "deploy_logged" : [-57896044618658097711785492504343953926634992332820282019728792003956564819808,5738], "shows_that_we_support": "default constructors + values get written to storage and resurrected on subsequent calls" - }, - { - "file": "TraceThis.obs", - "expected": "-2438", - "shows_that_we_support": "emitting tracers and default constructors for more complicated nestings of objects. note that this emits 30 log messages between deployment and invocation; they look right but haven't been checked carefully" } ] } diff --git a/src/main/scala/edu/cmu/cs/obsidian/codegen/CodeGenYul.scala b/src/main/scala/edu/cmu/cs/obsidian/codegen/CodeGenYul.scala index 8a44808c..07d890c4 100644 --- a/src/main/scala/edu/cmu/cs/obsidian/codegen/CodeGenYul.scala +++ b/src/main/scala/edu/cmu/cs/obsidian/codegen/CodeGenYul.scala @@ -450,13 +450,39 @@ object CodeGenYul extends CodeGenerator { val id = nextTemp() val e_yul = translateExpr(id, e, contractName, checkedTable) val ct = checkedTable.contractLookup(contractName) - decl_0exp(id) +: - e_yul :+ - (if (ct.allFields.exists(f => f.name.equals(x))) { - updateField(ct, x, id) + + val field_address = fieldFromThis(ct, x) + + // look at assignTo. if it's a storage address, then look at the type of e. + // if it's a primitive, ignore it. if it's a contract reference, it needs + // to get traced first + val trace_for_e: Seq[YulStatement] = e.obstype match { + case Some(value) => value match { + case _: PrimitiveType => Seq() + case t: NonPrimitiveType => t match { + case ContractReferenceType(contractType, _, _) => // todo + Seq( + edu.cmu.cs.obsidian.codegen.If(apply("not", compareToThresholdExp(field_address)), + Block(Do(apply(nameTracer(contractType.contractName), id)))) + ) + case StateType(_, _, _) => assert(assertion = false, "not yet implemented"); Seq() + case InterfaceContractType(_, _) => assert(assertion = false, "not yet implemented"); Seq() + case GenericType(_, _) => assert(assertion = false, "not yet implemented"); Seq() + } + case BottomType() => Seq() + } + case None => assert(assertion = false, "encountered an expression without a type annotation"); Seq() + } + + val update_instructions: Seq[YulStatement] = + if (ct.allFields.exists(f => f.name.equals(x))) { + trace_for_e :+ updateField(ct, x, id) } else { - assign1(Identifier(x), id) - }) + Seq(assign1(Identifier(x), id)) + } + + decl_0exp(id) +: (e_yul ++ update_instructions) + case _ => assert(assertion = false, "trying to assign to non-assignable: " + e.toString) Seq() @@ -614,7 +640,7 @@ object CodeGenYul extends CodeGenerator { // here or not. ids.map(id => decl_0exp(id)) ++ seqs.flatten ++ (width match { - case 0 => Seq(ExpressionStatement(FunctionCall(Identifier(name), thisID +: ids))) + case 0 => Do(FunctionCall(Identifier(name), thisID +: ids)) case 1 => val id: Identifier = nextTemp() Seq(decl_1exp(id, FunctionCall(Identifier(name), thisID +: ids)), assign1(retvar, id)) @@ -747,7 +773,7 @@ object CodeGenYul extends CodeGenerator { // we only call the tracer after the constructor for the main contract val traceCall = if (isMainContract) { - Seq(ExpressionStatement(apply(nameTracer(contractType.contractName), Identifier("this")))) + Do(apply(nameTracer(contractType.contractName), Identifier("this"))) } else { Seq() } diff --git a/src/main/scala/edu/cmu/cs/obsidian/codegen/Util.scala b/src/main/scala/edu/cmu/cs/obsidian/codegen/Util.scala index 3c778e20..baee30ab 100644 --- a/src/main/scala/edu/cmu/cs/obsidian/codegen/Util.scala +++ b/src/main/scala/edu/cmu/cs/obsidian/codegen/Util.scala @@ -95,7 +95,7 @@ object Util { * @return the yul if-statement doing the check */ def revertIf(cond: Expression): YulStatement = - edu.cmu.cs.obsidian.codegen.If(cond, Block(Seq(ExpressionStatement(apply("revert", intlit(0), intlit(0)))))) + edu.cmu.cs.obsidian.codegen.If(cond, Block(Do(apply("revert", intlit(0), intlit(0))))) /** * @return the yul call value check statement, which makes sure that funds are not spent inappropriately @@ -111,7 +111,7 @@ object Util { * @return the yul if-statement doing the check */ def revertForwardIfZero(id: Expression): YulStatement = - edu.cmu.cs.obsidian.codegen.If(apply("iszero", id), Block(Seq(ExpressionStatement(apply("revert_forward_1"))))) + edu.cmu.cs.obsidian.codegen.If(apply("iszero", id), Block(Do(apply("revert_forward_1")))) /** * shorthand for building yul assignment statements, here assigning one expression to just one @@ -404,8 +404,8 @@ object Util { def updateField(ct: ContractTable, fieldName: String, value: Expression): YulStatement = { val address_of_field: Expression = fieldFromThis(ct, fieldName) ifInStorge(addr_to_check = address_of_field, - true_case = Seq(ExpressionStatement(apply("sstore", address_of_field, value))), - false_case = Seq(ExpressionStatement(apply("mstore", address_of_field, value))) + true_case = Do(apply("sstore", address_of_field, value)), // todo double check why there isn't a mapToStorageAddr call here; i think it's because you only end up in this branch when `this` is already big enough + false_case = Do(apply("mstore", address_of_field, value)) ) } @@ -595,6 +595,24 @@ object Util { apply("add", x, storage_threshold) } + /** the returned expression evaluates to true iff the argument is above or equal to the storage threshold + * + * @param addr + * @return + */ + def compareToThresholdExp(addr: Expression): Expression = { + apply("gt", addr, apply("sub", storage_threshold, intlit(1))) + } + + /** convenience function for taking an expression and building a singleton statement sequence + * + * @param x + * @return + */ + def Do(x: Expression): Seq[YulStatement] = { + Seq(ExpressionStatement(x)) + } + /** given an expression that represents an address, compute the yul statement that checks if it's * a storage address or not and execute a sequence of statements in either case. if it does not * represent an address, the behaviour is undefined. @@ -606,7 +624,7 @@ object Util { */ def ifInStorge(addr_to_check: Expression, true_case: Seq[YulStatement], false_case: Seq[YulStatement]): YulStatement = { // note: gt(x,y) is x > y; to get x >= y, subtract 1 since they're integers - edu.cmu.cs.obsidian.codegen.Switch(apply("gt", addr_to_check, apply("sub", storage_threshold, intlit(1))), + edu.cmu.cs.obsidian.codegen.Switch(compareToThresholdExp(addr_to_check), Seq(Case(boollit(true), Block(true_case)), Case(boollit(false), Block(false_case)))) } @@ -650,6 +668,4 @@ object Util { case YATContractName(name) => throw new RuntimeException("can't query for a default contract value") } } - - }