From bf83440853dbd87bfc317399b5f61ce12948b23d Mon Sep 17 00:00:00 2001 From: Charles Sherk Date: Mon, 14 Feb 2022 01:29:13 -0500 Subject: [PATCH] monolithic commit. type inference now uses Z3 --- build.sbt | 4 +- generic_funcs_pass.pdl | 59 -------- src/main/scala/pipedsl/Main.scala | 2 + src/main/scala/pipedsl/Parser.scala | 2 +- .../codegen/bsv/BSVPrettyPrinter.scala | 8 +- .../scala/pipedsl/codegen/bsv/BSVSyntax.scala | 11 +- .../codegen/bsv/ConstraintsToBluespec.scala | 88 +++++++++++ .../scala/pipedsl/common/Constraints.scala | 142 ++++++++++++++++++ src/main/scala/pipedsl/common/Errors.scala | 7 + src/main/scala/pipedsl/common/Syntax.scala | 25 ++- src/main/scala/pipedsl/common/Utilities.scala | 35 +++-- .../FunctionConstraintChecker.scala | 128 ++++++++++++++++ .../typechecker/TypeInferenceWrapper.scala | 102 ++++++++----- .../typecheckTests/generic_func_fail.pdl | 9 +- .../typecheckTests/generic_funcs_passz3.pdl | 66 ++++++++ .../generic_funcs_passz3.typechecksol | 1 + 16 files changed, 572 insertions(+), 117 deletions(-) delete mode 100644 generic_funcs_pass.pdl create mode 100644 src/main/scala/pipedsl/codegen/bsv/ConstraintsToBluespec.scala create mode 100644 src/main/scala/pipedsl/common/Constraints.scala create mode 100644 src/main/scala/pipedsl/typechecker/FunctionConstraintChecker.scala create mode 100644 src/test/tests/typecheckTests/generic_funcs_passz3.pdl create mode 100644 src/test/tests/typecheckTests/solutions/generic_funcs_passz3.typechecksol diff --git a/build.sbt b/build.sbt index e4b1fcf7..af50e78c 100644 --- a/build.sbt +++ b/build.sbt @@ -24,7 +24,9 @@ libraryDependencies ++= Seq( "org.scalactic" %% "scalactic" % "3.2.2", ) +scalacOptions += "-language:implicitConversions" + //Deployment Options assemblyJarName in assembly := "pdl.jar" test in assembly := {} -mainClass in assembly := Some("pipedsl.Main") \ No newline at end of file +mainClass in assembly := Some("pipedsl.Main") diff --git a/generic_funcs_pass.pdl b/generic_funcs_pass.pdl deleted file mode 100644 index ec97f65b..00000000 --- a/generic_funcs_pass.pdl +++ /dev/null @@ -1,59 +0,0 @@ -// def error(inpt :int<10>) :bool -// { -// shamt = cast(inpt{4:0}, uint<5>); -// return true; -// } - -// def extract(inpt :uint<32>) :bool -// { -// uint<5> a = inpt{0:4}; -// return true; -// } - -// def iden(a :int) :int -// { -// return a; -// } - - -// def adder(x :int, y :int) :int -// { -// return x + iden(y); -// } - -// def my_concat(a :int, b :int, c :int) :int -// { -// tmp = a ++ b ++ c; -// other = b ++ a ++ c; -// return tmp + other; -// } - -// def indexing(x :int, y :int) :int<5> -// { -// tmp = (x ++ y){4:0}; -// return (x ++ y){4:0}; -// } - -// // def in_scope(x :int, y :int) :int<1> -// // { -// // return x{J:0}; -// // } - - -// def infty(x :int<5>, y :int<5>) :int<10> -// { -// return x ++ y; -// } - -pipe test1(inpt: uint<32>)[rf: int<32>[32]] :int<100> -{ - uint<5> a = inpt{0:4}; - //output (adder(3, 4)); - output (4); -} - - - -circuit { - r = memory(int<32>, 32); -} \ No newline at end of file diff --git a/src/main/scala/pipedsl/Main.scala b/src/main/scala/pipedsl/Main.scala index 06f13770..190ab0f5 100644 --- a/src/main/scala/pipedsl/Main.scala +++ b/src/main/scala/pipedsl/Main.scala @@ -78,9 +78,11 @@ object Main { val canonProg2 = new CanonicalizePass().run(verifProg) //new PrettyPrinter(None).printProgram(canonProg2) val canonProg1 = new TypeInference(autocast).checkProgram(canonProg2) + val canonProg = canonProg1//LockOpTranslationPass.run(canonProg1) //new PrettyPrinter(None).printProgram(canonProg) val basetypes = BaseTypeChecker.check(canonProg, None) + FunctionConstraintChecker.check(canonProg) val nprog = new BindModuleTypes(basetypes).run(canonProg) MarkNonRecursiveModulePass.run(nprog) val recvProg = SimplifyRecvPass.run(nprog) diff --git a/src/main/scala/pipedsl/Parser.scala b/src/main/scala/pipedsl/Parser.scala index 76fd560b..7eb3308f 100644 --- a/src/main/scala/pipedsl/Parser.scala +++ b/src/main/scala/pipedsl/Parser.scala @@ -129,7 +129,7 @@ class Parser(rflockImpl: String) extends RegexParsers with PackratParsers { lazy val indexAtom :P[EIndex] = dlog(positioned { - iden ^^ { id => EIndVar(id).setType(TInteger()) } | + iden ^^ { id => EIndVar(Id(generic_type_prefix + id.v)).setType(TInteger()) } | posint ^^ { n => EIndConst(n).setType(TInteger()) } })("index atom") diff --git a/src/main/scala/pipedsl/codegen/bsv/BSVPrettyPrinter.scala b/src/main/scala/pipedsl/codegen/bsv/BSVPrettyPrinter.scala index 4d4849d1..4a690b80 100644 --- a/src/main/scala/pipedsl/codegen/bsv/BSVPrettyPrinter.scala +++ b/src/main/scala/pipedsl/codegen/bsv/BSVPrettyPrinter.scala @@ -15,10 +15,12 @@ object BSVPrettyPrinter { mkExprString(toBSVTypeStr(v.typ), v.name) } - private def toProvisoString(name: String, p: Proviso): String = p match { + def toProvisoString(name: String, p: Proviso): String = p match { case PBits(szName) => "Bits#(" + name + "," + szName +")" case PAdd(num1, num2, sum) => s"Add#($num1, $num2, $sum)" case PMin(name, min) => s"Min#($name, $min, $min)" + case PMax(num1, num2, max) => s"Max#($num1, $num2, $max)" + case PEq(num1, num2) => s"Add#($num1, 0, $num2)" } private def getTypeParams(typ: BSVType): Set[BTypeParam] = typ match { @@ -277,9 +279,9 @@ object BSVPrettyPrinter { func.name, "(", paramstr, ")", if(func.provisos.nonEmpty) { - " provisos(" + + "\n\tprovisos(" + func.provisos.tail.foldLeft(toProvisoString("", func.provisos.head))((str, proviso) => - str + ", " + toProvisoString("", proviso)) + + str + ",\n\t\t" + toProvisoString("", proviso)) + ")" } else "" )) diff --git a/src/main/scala/pipedsl/codegen/bsv/BSVSyntax.scala b/src/main/scala/pipedsl/codegen/bsv/BSVSyntax.scala index e1c49654..9dbd58bf 100644 --- a/src/main/scala/pipedsl/codegen/bsv/BSVSyntax.scala +++ b/src/main/scala/pipedsl/codegen/bsv/BSVSyntax.scala @@ -1,11 +1,13 @@ package pipedsl.codegen.bsv import pipedsl.codegen.Translations.Translator +import pipedsl.codegen.bsv.ConstraintsToBluespec.to_provisos import pipedsl.common.Errors.{MissingType, UnexpectedBSVType, UnexpectedCommand, UnexpectedExpr, UnexpectedType} import pipedsl.common.{LockImplementation, Syntax} import pipedsl.common.LockImplementation.LockInterface import pipedsl.common.Syntax.Latency.{Asynchronous, Combinational} import pipedsl.common.Syntax._ +import pipedsl.common.Utilities.generic_type_prefix object BSVSyntax { @@ -18,6 +20,8 @@ object BSVSyntax { case class PBits(szName: String) extends Proviso case class PAdd(num1 :String, num2 :String, sum :String) extends Proviso case class PMin(name :String, min :Int) extends Proviso + case class PMax(num1 :String, num2 :String, max :String) extends Proviso + case class PEq(num1 :String, num2 :String) extends Proviso sealed trait BSVType case class BNumericType(sz: Int) extends BSVType @@ -145,7 +149,7 @@ object BSVSyntax { case EIndConst(v) => BIndConst(v) case EIndAdd(l, r) => BIndAdd(toBSVIndex(l), toBSVIndex(r)) case EIndSub(l, r) => BIndSub(toBSVIndex(l), toBSVIndex(r)) - case EIndVar(id) => BIndVar(id.v) + case EIndVar(id) => BIndVar("val" + id.v) } def toExpr(e: Expr): BExpr = e match { @@ -318,7 +322,8 @@ object BSVSyntax { private def getProvisos(b :FuncDef) :List[Proviso] = { val tmp = (b.adds.toList.map(pairid => PAdd(pairid._1._1, pairid._1._2, pairid._2.v)) ++ - b.mins.toList.map(pair => PMin(pair._1, pair._2))).distinct + b.mins.toList.map(pair => PMin(pair._1, pair._2))).distinct ++ + to_provisos(b.constraints) tmp } @@ -326,7 +331,7 @@ object BSVSyntax { { b.templateTypes.map(id => { - val tmp = Id(id.v + "_val") + val tmp = Id("val" + id.v) tmp.typ = Some(TInteger()) BDecl(toVar(tmp), Some(BValueOf(id.v))) diff --git a/src/main/scala/pipedsl/codegen/bsv/ConstraintsToBluespec.scala b/src/main/scala/pipedsl/codegen/bsv/ConstraintsToBluespec.scala new file mode 100644 index 00000000..db1f6162 --- /dev/null +++ b/src/main/scala/pipedsl/codegen/bsv/ConstraintsToBluespec.scala @@ -0,0 +1,88 @@ +package pipedsl.codegen.bsv + +import pipedsl.common.Constraints._ +import pipedsl.codegen.bsv.BSVSyntax.{PAdd, PEq, PMax, Proviso} +import pipedsl.common.Syntax.Id + +/** + * code to turn a list of constraints into a list of Bluespec Provisos + */ +object ConstraintsToBluespec + { + + var globl_cnt = 1 + + def freshVar() :Id = + { + globl_cnt += 1; + Id("_proviso_fresh" + globl_cnt) + } + + def collect_vars(e :IntExpr) :Set[Id] = e match + { + case IntConst(v) => Set() + case IntVar(id) => Set(id) + case IntAdd(a, b) => collect_vars(a).union(collect_vars(b)) + case IntSub(a, b) => collect_vars(a).union(collect_vars(b)) + case IntMax(a, b) => collect_vars(a).union(collect_vars(b)) + } + + def collect_vars(c :Constraint) : Set[Id] = c match + { + case RelLt(a, b) => collect_vars(a).union(collect_vars(b)) + case ReGe(a, b) => collect_vars(a).union(collect_vars(b)) + case ReEq(a, b) => collect_vars(a).union(collect_vars(b)) + } + + def tile_one(c :IntExpr) : (List[Proviso], IntValue) = + { + def helper(a :IntExpr, b :IntExpr, tp :(String, String, String) => Proviso) = + { + val frsh = freshVar() + val (prov_left, left_dest) = tile_one(a) + val (prov_right, right_dest) = tile_one(b) + ((prov_left ++ prov_right).prepended(tp(left_dest.toString, right_dest.toString, frsh.v)), + IntVar(frsh)) + } + c match + { + case c :IntConst=> (List(), c) + case c :IntVar => (List(), c) + case IntAdd(a, b) => helper(a, b, PAdd) + case IntSub(a, b) => helper(a, b, (left, right, dest) => PAdd(right, dest, left)) + case IntMax(a, b) => helper(a, b, PMax) + } + } + + def to_provisos_one(c :Constraint) :List[Proviso] = + { + c match + { + //max(a, b) = b <=> b >= a + //b > a <=> b >= a + 1) + //a < b <=> a + 1 <= b + //a + 1 <= b <=> (max (1+ a) b) = b + case RelLt(a, b) => + val (provs_left, left) = tile_one(a) + val (provs_right, right) = tile_one(b) + val frsh = freshVar() + val one_p_left = PAdd("1", left.toString, frsh.v) + val prov = PMax(frsh.v, right.toString, right.toString) + (provs_left ++ provs_right).prependedAll(List(one_p_left, prov)) + case ReGe(a, b) => + val (provs_left, left) = tile_one(a) + val (provs_right, right) = tile_one(b) + val prov = PMax(left.toString, right.toString, left.toString) + (provs_left ++ provs_right).prepended(prov) + case ReEq(a, b) => + val (provs_left, left) = tile_one(a) + val (provs_right, right) = tile_one(b) + val prov = PEq(left.toString, right.toString) + (provs_left ++ provs_right).prepended(prov) + } + } + + def to_provisos(cstrts :List[Constraint]) :List[Proviso] = + cstrts.flatMap(to_provisos_one).distinct + + } diff --git a/src/main/scala/pipedsl/common/Constraints.scala b/src/main/scala/pipedsl/common/Constraints.scala new file mode 100644 index 00000000..29a31935 --- /dev/null +++ b/src/main/scala/pipedsl/common/Constraints.scala @@ -0,0 +1,142 @@ +package pipedsl.common + +import pipedsl.common.Syntax.{EIndAdd, EIndConst, EIndSub, EIndex, Id, TBitWidth, TNamedType} +import com.microsoft.z3.{Status, AST => Z3AST, ArithExpr => Z3ArithExpr, BoolExpr => Z3BoolExpr, Context => Z3Context, Expr => Z3Expr, IntExpr => Z3IntExpr, Solver => Z3Solver} +import pipedsl.common.Errors.UnsatisfiableConstraint +import pipedsl.common.Utilities.{degenerify, generic_type_prefix, not_gen_pref} + +import scala.collection.mutable +import scala.language.implicitConversions + + +object Constraints + { + sealed trait Constraint + sealed trait IntRel extends Constraint + + sealed trait IntExpr + + sealed trait IntValue extends IntExpr + + case class IntConst(v :Int) extends IntValue + { + override def toString: String = v.toString + } + case class IntVar(id :Id) extends IntValue + { + override def toString: String = id.v + } + case class IntAdd(a :IntExpr, b :IntExpr) extends IntExpr + case class IntSub(a :IntExpr, b :IntExpr) extends IntExpr + case class IntMax(a :IntExpr, b :IntExpr) extends IntExpr + + case class RelLt(a :IntExpr, b :IntExpr) extends Constraint + case class ReGe(a :IntExpr, b :IntExpr) extends Constraint + case class ReEq(a :IntExpr, b :IntExpr) extends Constraint + + def degenerify_expr(i :IntExpr) :IntExpr = i match + { + case _ :IntConst => i + case IntVar(id) => IntVar(Id(not_gen_pref + id.v)) + case IntAdd(a, b) => IntAdd(degenerify_expr(a), degenerify_expr(b)) + case IntSub(a, b) => IntSub(degenerify_expr(a), degenerify_expr(b)) + case IntMax(a, b) => IntMax(degenerify_expr(a), degenerify_expr(b)) + } + + def degenerify_constr(c :Constraint) :Constraint = c match + { + case RelLt(a, b) => RelLt(degenerify_expr(a), degenerify_expr(b)) + case ReGe(a, b) => ReGe(degenerify_expr(a), degenerify_expr(b)) + case ReEq(a, b) => ReEq(degenerify_expr(a), degenerify_expr(b)) + } + + + def to_z3(ctxt : Z3Context, cons :Constraint) :Z3BoolExpr = + { + cons match + { + case RelLt(a, b) => ctxt.mkLt(to_z3(ctxt, a), to_z3(ctxt, b)) + case ReEq(a, b) => ctxt.mkEq(to_z3(ctxt, a), to_z3(ctxt, b)) + case ReGe(a, b) => ctxt.mkGe(to_z3(ctxt,a ), to_z3(ctxt, b)) + } + } + + def to_z3(ctxt :Z3Context, expr :IntExpr) :Z3ArithExpr = expr match + { + case IntConst(v) => ctxt.mkInt(v) + case IntVar(id) => ctxt.mkIntConst(id.v) + case IntAdd(a, b) => ctxt.mkAdd(to_z3(ctxt, a), to_z3(ctxt, b)) + case IntSub(a, b) => ctxt.mkSub(to_z3(ctxt, a), to_z3(ctxt, b)) + //(max a b) = (if (> a b) a b) + case IntMax(a, b) => + val z3a = to_z3(ctxt, a); val z3b = to_z3(ctxt, b) + ctxt.mkITE(ctxt.mkGt(z3a, z3b), z3a, z3b).asInstanceOf[Z3ArithExpr] + + } + + object ImplicitConstraints + { + implicit def toConstraint(i :Int) :IntExpr = IntConst(i) + + implicit def toConstraint(w :TBitWidth) :IntExpr = w match + { + case Syntax.TBitWidthVar(name) =>IntVar(name) + case Syntax.TBitWidthLen(len) => IntConst(len) + case Syntax.TBitWidthAdd(b1, b2) => IntAdd(toConstraint(b1), toConstraint(b2)) + case Syntax.TBitWidthSub(b1, b2) => IntSub(toConstraint(b1), toConstraint(b2)) + case Syntax.TBitWidthMax(b1, b2) => IntMax(toConstraint(b1), toConstraint(b2)) + } + + implicit def toConstraint(i : EIndex) :IntExpr = i match + { + case EIndConst(v) => IntConst(v) + case Syntax.EIndAdd(l, r) => IntAdd(toConstraint(l), toConstraint(r)) + case Syntax.EIndSub(l, r) => IntSub(toConstraint(l), toConstraint(r)) + case Syntax.EIndVar(id) => IntVar(Id(id.v)) + } + } + + case class NotConst() extends RuntimeException + + def eval_const(expr: IntExpr) :Int = expr match + { + case value: IntValue => value match + { + case IntConst(v) => v + case IntVar(_) => throw NotConst() + } + case IntAdd(a, b) => eval_const(a) + eval_const(b) + case IntSub(a, b) => eval_const(a) - eval_const(b) + case IntMax(a, b) => math.max(eval_const(a), eval_const(b)) + } + + def check_const(c :Constraint) :Unit = c match + { + case RelLt(a, b) => if (! (eval_const(a) < eval_const(b))) + throw UnsatisfiableConstraint(c) + case ReGe(a, b) => if (! (eval_const(a) >= eval_const(b))) + throw UnsatisfiableConstraint(c) + case ReEq(a, b) => if (! (eval_const(a) == eval_const(b))) + throw UnsatisfiableConstraint(c) + } + + def reduce_constraint_list(lst :List[Constraint]) :List[Constraint] = + { + val set = mutable.HashSet[Constraint]() + lst.foreach(c => + { + if(!set.contains(c)) + { + try + { + check_const(c) + } catch + { + case _: NotConst => set.add(c) + } + } + }) + set.toList + } + + } diff --git a/src/main/scala/pipedsl/common/Errors.scala b/src/main/scala/pipedsl/common/Errors.scala index 8ee53836..45f8c960 100644 --- a/src/main/scala/pipedsl/common/Errors.scala +++ b/src/main/scala/pipedsl/common/Errors.scala @@ -198,9 +198,16 @@ object Errors { withPos(s"Not enough constraints provided to infer types. Found error at $e", e.pos) ) + case class UnsatisfiableConstraint(c :Any) extends RuntimeException( + s"Cannot satisfy constraint $c" + ) + case class IntWidthNotSpecified() extends RuntimeException case class NotSoGenericAreWe(id :Id, needed :Type) extends RuntimeException( withPos(s"Generic type $id should not need to be set to ${needed}.", id.pos) ) + case class BadConstraintsAtCall(app :EApp) extends RuntimeException( + withPos(s"Constraints for $app not satisfied", app.pos) + ) } diff --git a/src/main/scala/pipedsl/common/Syntax.scala b/src/main/scala/pipedsl/common/Syntax.scala index d7f24dd6..a43e21a5 100644 --- a/src/main/scala/pipedsl/common/Syntax.scala +++ b/src/main/scala/pipedsl/common/Syntax.scala @@ -6,6 +6,7 @@ import pipedsl.common.LockImplementation.LockInterface import pipedsl.common.Locks.{General, LockGranularity, LockState} import com.microsoft.z3.BoolExpr import pipedsl.common.Syntax.EIndConst +import pipedsl.common.Utilities.{generic_type_prefix, is_my_generic} import pipedsl.typechecker.Subtypes import scala.collection.mutable @@ -53,6 +54,7 @@ object Syntax { { val adds: mutable.Map[(String, String), Id] = mutable.HashMap[(String, String), Id]() val mins: mutable.Map[String, Int] = mutable.HashMap[String, Int]() + var constraints: List[Constraints.Constraint] = List() } } @@ -357,7 +359,11 @@ object Syntax { override def stringRep(): String = { val lst = (b1.stringRep() :: b2.stringRep() :: Nil).sorted - lst.head + "_ADD_" + lst(1) + val tmp = "A" + lst.head + "_" + lst(1) + "A" + if (is_my_generic(lst.head, accept_lit = true) && is_my_generic(lst(1), accept_lit = true)) + generic_type_prefix + tmp + else + tmp } } object TBitWidthAdd @@ -375,8 +381,11 @@ object Syntax { { override def stringRep(): String = { - val lst = (b1.stringRep() :: b2.stringRep() :: Nil).sorted - lst.head + "_SUB_" + lst(1) + val tmp = "S" + b1.stringRep() + "_" + b2.stringRep() + "S" + if (is_my_generic(b1, accept_lit = true) && is_my_generic(b2, accept_lit = true)) + generic_type_prefix + tmp + else + tmp } } object TBitWidthSub @@ -395,7 +404,15 @@ object Syntax { case class TBitWidthMax(b1: TBitWidth, b2: TBitWidth) extends TBitWidth { override def stringRep(): String = - b1.stringRep() + "_max_" + b2.stringRep() + { + val lst = (b1.stringRep() :: b2.stringRep() :: Nil).sorted + val tmp = "M" + lst.head + "_" + lst(1) + "M" + if (is_my_generic(lst.head, accept_lit = true) && is_my_generic(lst(1), accept_lit = true)) + generic_type_prefix + tmp + else + tmp + } + } object TBitWidthMax { diff --git a/src/main/scala/pipedsl/common/Utilities.scala b/src/main/scala/pipedsl/common/Utilities.scala index 5d02b892..7f032230 100644 --- a/src/main/scala/pipedsl/common/Utilities.scala +++ b/src/main/scala/pipedsl/common/Utilities.scala @@ -728,20 +728,37 @@ object Utilities { case s:String => s.startsWith(generic_type_prefix) } + def string_is_int (s :String) :Boolean= + { + try + { + s.toInt; true + } catch + { + case _: Throwable => false + } + } + /** * check if a type represents a generic, and is one of THE generics of the * current function. (There are generic types of other called functions, but * they are specialized, and can be distinguished) */ - def is_my_generic(t :Any) :Boolean = t match { - case TNamedType(name) => name.v.startsWith(generic_type_prefix) && !name.v.endsWith("*") - case TSizedInt(l, _) => is_my_generic(t) - case TBitWidthAdd(b1, b2) => is_my_generic(b1) || is_my_generic(b2) - case TBitWidthVar(name) => name.v.startsWith(generic_type_prefix) && !name.v.endsWith("*") - case _:Type => false - case name:Id => name.v.startsWith(generic_type_prefix) && !name.v.endsWith("*") - case name:String => name.startsWith(generic_type_prefix) && !name.endsWith("*") - } + def is_my_generic(t :Any, accept_lit :Boolean = false) :Boolean = + { + val tmp = t match { + case TNamedType(name) => name.v.startsWith(generic_type_prefix) && !name.v.endsWith("*") + case TSizedInt(l, _) => is_my_generic(l, accept_lit) + case TBitWidthAdd(b1, b2) => is_my_generic(b1, accept_lit) || is_my_generic(b2, accept_lit) + case TBitWidthVar(name) => name.v.startsWith(generic_type_prefix) && !name.v.endsWith("*") + case _:Type => accept_lit + case name:Id => name.v.startsWith(generic_type_prefix) && !name.v.endsWith("*") + case name:String if string_is_int(name) => accept_lit + case name:String => name.startsWith(generic_type_prefix) && !name.endsWith("*") + } + tmp + } + val not_gen_pref = "__NOT" diff --git a/src/main/scala/pipedsl/typechecker/FunctionConstraintChecker.scala b/src/main/scala/pipedsl/typechecker/FunctionConstraintChecker.scala new file mode 100644 index 00000000..be15eae1 --- /dev/null +++ b/src/main/scala/pipedsl/typechecker/FunctionConstraintChecker.scala @@ -0,0 +1,128 @@ +package pipedsl.typechecker + +import com.microsoft.z3.{Status, Context => Z3Context, Solver => Z3Solver} +import pipedsl.common.Constraints.ImplicitConstraints._ +import pipedsl.common.Constraints._ +import pipedsl.common.Errors.{BadConstraintsAtCall, MissingType} +import pipedsl.common.Syntax._ +import pipedsl.common.Utilities.degenerify +import scala.collection.mutable +import scala.language.implicitConversions + +object FunctionConstraintChecker + { + def check(p: Prog): Unit = + { + val cons_map = mutable.HashMap[Id, FuncDef]() + p.fdefs.foreach(f => cons_map.addOne(f.name, f)) + p.fdefs.foreach(checkFunc(cons_map, _)) + p.moddefs.foreach(checkMod(cons_map, _)) + } + + def checkFunc(cons_map: mutable.HashMap[Id, FuncDef], f: FuncDef): Unit = + { + val ctxt = new Z3Context() + val solv = ctxt.mkSolver() + f.constraints.foreach(c => solv.add(to_z3(ctxt, c))) + checkCmd(f.body, ctxt, solv, cons_map) + } + + def checkMod(cons_map: mutable.HashMap[Id, FuncDef], m: ModuleDef): Unit = + { + val ctxt = new Z3Context() + val solv = ctxt.mkSolver() + checkCmd(m.body, ctxt, solv, cons_map) + } + + def checkCmd(c: Command, ctxt: Z3Context, solv: Z3Solver, cons_map: mutable.HashMap[Id, FuncDef]): Unit = + { + def _checkCmd(cmd: Command): Unit = checkCmd(cmd, ctxt, solv, cons_map) + + def _checkExpr(ex: Expr): Unit = checkExpr(ex, ctxt, solv, cons_map) + + c match + { + case CSeq(c1, c2) => _checkCmd(c1); _checkCmd(c2) + case CTBar(c1, c2) => _checkCmd(c1); _checkCmd(c2) + case CIf(cond, cons, alt) => _checkExpr(cond); _checkCmd(cons); _checkCmd(alt) + case CAssign(lhs, rhs) => _checkExpr(rhs) + case CRecv(lhs, rhs) => _checkExpr(rhs) + case CSpecCall(handle, pipe, args) => args.foreach(_checkExpr) + case CVerify(handle, args, preds, update) => args.foreach(_checkExpr) + case CUpdate(newHandle, handle, args, preds) => args.foreach(_checkExpr) + case CPrint(args) => args.foreach(_checkExpr) + case COutput(exp) => _checkExpr(exp) + case CReturn(exp) => _checkExpr(exp) + case CExpr(exp) => _checkExpr(exp) + case CSplit(cases, default) => cases.foreach(c => + { + _checkExpr(c.cond); + _checkCmd(c.body) + }); + _checkCmd(default) + case _ => () + } + } + + def extract_width(t: Type): Option[IntExpr] = t match + { + case TSizedInt(len, _) => Some(len) + case _ => None + } + + private implicit class PipelineContainer[F](val value: F) + { + def |>[G](f: F => G): G = f(value) + } + + def type_of_fdef(f: FuncDef): TFun = TFun(f.args.map(a => a.typ), f.ret) + + def checkExpr(e: Expr, ctxt: Z3Context, solv: Z3Solver, cons_map: mutable.HashMap[Id, FuncDef]): Unit = + { + def _checkExpr(ex: Expr): Unit = checkExpr(ex, ctxt, solv, cons_map) + + e match + { + case EIsValid(ex) => _checkExpr(ex) + case EFromMaybe(ex) => _checkExpr(ex) + case EToMaybe(ex) => _checkExpr(ex) + case EUop(_, ex) => _checkExpr(ex) + case EBinop(_, e1, e2) => _checkExpr(e1); _checkExpr(e2) + case ERecAccess(rec, _) => _checkExpr(rec) + case ERecLiteral(fields) => fields.foreach((ex) => _checkExpr(ex._2)) + case EMemAccess(_, index, wmask, _, _, _) => _checkExpr(index); wmask.foreach(_checkExpr) + case EBitExtract(num, _, _) => _checkExpr(num) + case ETernary(cond, tval, fval) => _checkExpr(cond); _checkExpr(tval); _checkExpr(fval) + case ea@EApp(func, args) => + + type_of_fdef(cons_map(func)).matchOrError(e.pos, "func type", "func type") + { case TFun(targs, ret) => solv.push() + val contraints_here = targs.zip(args.map(e => e.typ.getOrElse(throw MissingType(e.pos, "arg type")))).map(pair => + { + (pair._1 |> extract_width, pair._2 |> degenerify |> extract_width) match + { + case (Some(a), Some(b)) => Some(ReEq(a, b)) + case _ => None + } + }).collect + { case Some(cons) => cons + }.prependedAll((ret |> extract_width, e.typ.getOrElse(throw MissingType(e.pos, "ret type")) |> degenerify |> extract_width) match + { case (Some(a), Some(b)) => List(ReEq(a, b)) + case _ => List() + }).map(degenerify_constr) + val called_cons = cons_map(func).constraints.map(degenerify_constr) + val constraints = called_cons prependedAll contraints_here + constraints.foreach(c => solv.add(to_z3(ctxt, c))) + solv.check() match + { + case Status.UNSATISFIABLE | Status.UNKNOWN => throw BadConstraintsAtCall(ea) + case Status.SATISFIABLE => () + } + solv.pop() + } + case ECall(_, _, args) => args.foreach(_checkExpr) + case ECast(_, exp) => _checkExpr(exp) + case _ => () + } + } + } diff --git a/src/main/scala/pipedsl/typechecker/TypeInferenceWrapper.scala b/src/main/scala/pipedsl/typechecker/TypeInferenceWrapper.scala index 61893555..16c68a8c 100644 --- a/src/main/scala/pipedsl/typechecker/TypeInferenceWrapper.scala +++ b/src/main/scala/pipedsl/typechecker/TypeInferenceWrapper.scala @@ -1,16 +1,20 @@ package pipedsl.typechecker -import pipedsl.common.Errors +import pipedsl.common.{Errors, Syntax} import pipedsl.common.Errors._ import pipedsl.common.Syntax.Latency.{Asynchronous, Combinational, Latency, Sequential} import pipedsl.common.Syntax._ +import pipedsl.common.Constraints._ +import pipedsl.common.Constraints.ImplicitConstraints._ import pipedsl.common.Utilities.{defaultReadPorts, defaultWritePorts, degenerify, fopt_func, is_generic, is_my_generic, specialize, typeMap, typeMapFunc, typeMapModule, without_prefix} import pipedsl.typechecker.Environments.{EmptyTypeEnv, Environment, TypeEnv} import pipedsl.typechecker.Subtypes.{canCast, isSubtype} import com.microsoft.z3.{Status, AST => Z3AST, ArithExpr => Z3ArithExpr, BoolExpr => Z3BoolExpr, Context => Z3Context, IntExpr => Z3IntExpr, Solver => Z3Solver} - import TBitWidthImplicits._ +import pipedsl.codegen.bsv.ConstraintsToBluespec.to_provisos + import scala.collection.mutable +import scala.language.implicitConversions object TypeInferenceWrapper { @@ -68,6 +72,10 @@ object TypeInferenceWrapper val w1 = subst_into_type(typevar, toType, b1) |> to_width val w2 = subst_into_type(typevar, toType, b2) |> to_width TBitWidthAdd(w1, w2).setPos(inType.pos) + case TBitWidthSub(b1, b2) => + val w1 = subst_into_type(typevar, toType, b1) |> to_width + val w2 = subst_into_type(typevar, toType, b2) |> to_width + TBitWidthSub(w1, w2).setPos(inType.pos) case TBitWidthMax(b1, b2) => val t1 = TBitWidthMax(subst_into_type(typevar, toType, b1) |> to_width, subst_into_type(typevar, toType, b2) |> to_width) @@ -96,6 +104,10 @@ object TypeInferenceWrapper private var context = new Z3Context() private var solver = context.mkSolver() + + private var constraints = List[Constraint]() + //we'll keep a local list of the statements, and check it at the end w/ z3 + private def type_subst_map_fopt(t :Type, tp_mp: mutable.HashMap[Id, Type], templated :mutable.HashSet[Id] = mutable.HashSet.empty) :Option[Type] = try { Some(type_subst_map(t, tp_mp, templated)) @@ -128,18 +140,29 @@ object TypeInferenceWrapper case None => throw IntWidthNotSpecified() } case m@TMaybe(btyp) => m.copy(btyp = type_subst_map(btyp, tp_mp, templated)) - case TBitWidthAdd(b1, b2) => + case tba@TBitWidthAdd(b1, b2) => val len1 = type_subst_map(b1, tp_mp, templated) val len2 = type_subst_map(b2, tp_mp, templated) (len1, len2) match { case (TBitWidthLen(l1), TBitWidthLen(l2)) => TBitWidthLen(l1 + l2) case (l1:TBitWidth, l2:TBitWidth) => - { - val lst = (l1.stringRep()::l2.stringRep()::Nil).sorted - val id = Id(lst.head + "_ADD_" + lst(1)) - add_map.addOne((lst.head, lst(1)), id) + val id = Id(tba.stringRep()) + //val id = Id("_A" + lst.head + "_ADD_" + lst(1)) + //add_map.addOne((lst.head, lst(1)), id) + constraints = constraints.prepended( + ReEq(IntAdd(IntVar(Id(l1.stringRep())), IntVar(Id(l2.stringRep()))), IntVar(id))) + TBitWidthVar(id) + } + case tbs@TBitWidthSub(b1, b2) => + val len1 = type_subst_map(b1, tp_mp, templated) + val len2 = type_subst_map(b2, tp_mp, templated) + (len1, len2) match { + case (TBitWidthLen(l1), TBitWidthLen(l2)) => TBitWidthLen(l1 + l2) + case (l1 :TBitWidth, l2 :TBitWidth) => + val id = Id(tbs.stringRep()) + constraints = constraints.prepended( + ReEq(IntSub(IntVar(Id(l1.stringRep())), IntVar(Id(l2.stringRep()))), IntVar(id))) TBitWidthVar(id) - } } case TBitWidthMax(b1, b2) => val len1 = type_subst_map(b1, tp_mp, templated) |> to_len @@ -234,6 +257,7 @@ object TypeInferenceWrapper min_map.clear() context = new Z3Context() solver = context.mkSolver() + constraints = List() val inputTypes = f.args.map(a => a.typ) val funType = TFun(inputTypes, f.ret) val funEnv = env.add(f.name, funType) @@ -242,6 +266,17 @@ object TypeInferenceWrapper val inEnv = f.args.foldLeft[Environment[Id, Type]](template_vals_env)((env, a) => env.add(a.name, a.typ)) val (fixed_cmd, _, subst) = checkCommand(f.body, inEnv.asInstanceOf[TypeEnv], List()) + constraints = reduce_constraint_list(constraints) + solver.push() + constraints.foreach(c => solver.add(to_z3(context, c))) + val stat = solver.check() + solver.pop() + stat match + { + case Status.UNSATISFIABLE | + Status.UNKNOWN => throw new RuntimeException("YOU HAVE FUCKED UP IN THE GRANDEST OF WAYS") + case Status.SATISFIABLE => + } val hash = mutable.HashMap.from(subst) val newFunc = typeMapFunc(f.copy(body = fixed_cmd, args = f.args.map(p => p.copy(typ = type_subst_map(p.typ, hash, templated))), @@ -249,6 +284,7 @@ object TypeInferenceWrapper /*TODO: add another pass over the types to make substitutions for bit expressions. Collect them as provisos*/ newFunc.adds.addAll(add_map) + newFunc.constraints = newFunc.constraints ++ constraints newFunc.mins.addAll(min_map) (funEnv, newFunc) } @@ -461,7 +497,7 @@ object TypeInferenceWrapper }) val newEnv = lhs match { - case EVar(id) => rhsEnv.add(id, tlhs) + case EVar(id) => rhsEnv.remove(id).add(id, tlhs) case _ => rhsEnv } lhs.typ = Some(apply_subst_typ(s1, lhstyp)) @@ -528,20 +564,23 @@ object TypeInferenceWrapper case (t1 :TBitWidthVar, t2 :TBitWidthVar) if is_my_generic(t1) => if (!occursIn(t2.name, t1)) (List((t2.name, t1)), false) else (List(), false) case (t1 :TBitWidthVar, t2 :TBitWidthVar) if is_generic(t2) => if (!occursIn(t1.name, t2)) (List((t1.name, t2)), false) else (List(), false) case (t1 :TBitWidthVar, t2 :TBitWidthVar) if is_generic(t1) => if (!occursIn(t2.name, t1)) (List((t2.name, t1)), false) else (List(), false) - case (t1: TBitWidthVar, t2: TBitWidth) => if (!occursIn(t1.name, t2)) (List((t1.name, t2)), false) else (List(), false) - case (t1: TBitWidth, t2: TBitWidthVar) => if (!occursIn(t2.name, t1)) (List((t2.name, t1)), false) else (List(), false) + case (TBitWidthAdd(a1: TBitWidthLen, a2), TBitWidthLen(len)) => unify(a2, TBitWidthLen(len - a1.len), binop) case (TBitWidthAdd(a2, a1: TBitWidthLen), TBitWidthLen(len)) => unify(a2, TBitWidthLen(len - a1.len), binop) case (TBitWidthLen(len), TBitWidthAdd(a1: TBitWidthLen, a2)) => unify(a2, TBitWidthLen(len - a1.len), binop) case (TBitWidthLen(len), TBitWidthAdd(a2, a1: TBitWidthLen)) => unify(a2, TBitWidthLen(len - a1.len), binop) - case (TBitWidthAdd(a1: TBitWidthVar, a2: TBitWidthVar), TBitWidthLen(len)) => if (len % 2 != 0) throw new RuntimeException(s"result of bitwidthadd should be even. Consider multiply. Found $len") - val (s1, c1) = unify(a1, TBitWidthLen(len / 2), binop) - val (s2, c2) = unify(a2, TBitWidthLen(len / 2), binop) - (compose_subst(s1, s2), c1 || c2) - case (TBitWidthLen(len), TBitWidthAdd(a1: TBitWidthVar, a2: TBitWidthVar)) => if (len % 2 != 0) throw new RuntimeException(s"result of bitwidthadd should be even. Consider multiply. Found $len") - val (s1, c1) = unify(a1, TBitWidthLen(len / 2), binop) - val (s2, c2) = unify(a2, TBitWidthLen(len / 2), binop) - (compose_subst(s1, s2), c1 || c2) + + case (t1: TBitWidthVar, t2: TBitWidth) => if (!occursIn(t1.name, t2)) (List((t1.name, t2)), false) else (List(), false) + case (t1: TBitWidth, t2: TBitWidthVar) => if (!occursIn(t2.name, t1)) (List((t2.name, t1)), false) else (List(), false) +// case (TBitWidthAdd(a1: TBitWidthVar, a2: TBitWidthVar), TBitWidthLen(len)) => if (len % 2 != 0) throw new RuntimeException(s"result of bitwidthadd should be even. Consider multiply. Found $len") +// val (s1, c1) = unify(a1, TBitWidthLen(len / 2), binop) +// val (s2, c2) = unify(a2, TBitWidthLen(len / 2), binop) +// (compose_subst(s1, s2), c1 || c2) +// case (TBitWidthLen(len), TBitWidthAdd(a1: TBitWidthVar, a2: TBitWidthVar)) => if (len % 2 != 0) throw new RuntimeException(s"result of bitwidthadd should be even. Consider multiply. Found $len") +// val (s1, c1) = unify(a1, TBitWidthLen(len / 2), binop) +// val (s2, c2) = unify(a2, TBitWidthLen(len / 2), binop) +// (compose_subst(s1, s2), c1 || c2) + case (t1: TBitWidthLen, t2: TBitWidthLen) => if (autocast) { @@ -550,6 +589,9 @@ object TypeInferenceWrapper { if (t2.len != t1.len) throw UnificationError(t1, t2) else (List(), false) } + case (t1 :TBitWidth, t2 :TBitWidth) => + constraints = constraints.prepended(ReEq(t1, t2)) + (List(), binop) case (other1, other2) if other1 == other2 => (List(), false) case _ => throw UnificationError(a, b) @@ -628,6 +670,7 @@ object TypeInferenceWrapper case TSignVar(name1) => name1 == name case TBitWidthVar(name1) => name1 == name case TBitWidthAdd(b1, b2) => occursIn(name, b1) || occursIn(name, b2) + case TBitWidthSub(b1, b2) => occursIn(name, b1) || occursIn(name, b2) case TBitWidthMax(b1, b2) => occursIn(name, b1) || occursIn(name, b2) case TBitWidthLen(_) => false case _: TSignedNess => false @@ -774,24 +817,13 @@ object TypeInferenceWrapper t match { case TSizedInt(bitwidth, signedness) => - solver.push() - solver.add(context.mkLt(context.mkSub(z3_of_index(end), z3_of_index(start)), z3_of_width(bitwidth))) - solver.add(context.mkGe(z3_of_index(start), context.mkInt(0))) - solver.add(context.mkGe(z3_of_index(end), context.mkInt(0))) - val sat = solver.check() - solver.pop() - if (sat != Status.SATISFIABLE) - throw InvalidBitExtraction(s"$end - $start", bitwidth) + constraints = constraints.prepended(ReGe(toConstraint(end), toConstraint(0))) + constraints = constraints.prepended(ReGe(toConstraint(start), toConstraint(0))) + constraints = constraints.prepended(RelLt(toConstraint(end), toConstraint(bitwidth))) + constraints = constraints.prepended(RelLt(toConstraint(start), toConstraint(bitwidth))) + constraints = constraints.prepended(RelLt(IntSub(toConstraint(end), toConstraint(start)), toConstraint(bitwidth))) (s, TSizedInt(TBitWidthAdd(TBitWidthSub(end,start), 1), signedness), en, b.copy(num = fixed_num).copyMeta(b)) - //NOTE: old implementation. Not getting rid of until confident in Z3 -// case TSizedInt(widthvar, signedNess: TSignedNess) => -// val width = widthvar.stringRep() -// val len = math.abs(end - start) + 1 -// min_map.addOne((width, len)) -// if(!is_my_generic(widthvar)) -// throw LackOfConstraints(e) -// (s, TSizedInt(TBitWidthLen(len), signedNess), en, b.copy(num = fixed_num).copyMeta(b)) case b => throw UnificationError(b, TSizedInt(TBitWidthLen(32), TUnsigned())) //TODO Add better error message } //TODO case trn@ETernary(cond, tval, fval) => val (sc, tc, env1, fixed_cond) = infer(env, cond) diff --git a/src/test/tests/typecheckTests/generic_func_fail.pdl b/src/test/tests/typecheckTests/generic_func_fail.pdl index 206b1dfa..cd562211 100644 --- a/src/test/tests/typecheckTests/generic_func_fail.pdl +++ b/src/test/tests/typecheckTests/generic_func_fail.pdl @@ -1,6 +1,11 @@ -def adder(x :int, y :int) :int +def indexing(x :int, y :int) :int<6> { - return x + y + 1<1>; + return x{J - 1:0}; +} + +def error(x :int<3>, y :int<6>) :int<6> +{ + return indexing(x, y); } circuit { diff --git a/src/test/tests/typecheckTests/generic_funcs_passz3.pdl b/src/test/tests/typecheckTests/generic_funcs_passz3.pdl new file mode 100644 index 00000000..e33594cc --- /dev/null +++ b/src/test/tests/typecheckTests/generic_funcs_passz3.pdl @@ -0,0 +1,66 @@ + +def extract(inpt :uint<32>) :bool +{ + uint<5> a = inpt{4:0}; + return true; +} + +def iden(a :int) :int +{ + return a; +} + + +def adder(x :int, y :int) :int +{ + return x + iden(y); +} + +def my_concat(a :int, b :int, c :int) :int +{ + tmp = a ++ b ++ c; + other = b ++ a ++ c; + return tmp + other; +} + +def indexing(x :int, y :int) :int<5> +{ + tmp = (x ++ y){4:0}; + return (x ++ y){4:0}; +} + +def in_scope(x :int, y :int) :int +{ + int z = x{J - 1:0}; + q = x{I - 1:0}; + return z; +} + + +def test_in_scope(x :int<6>, y :int<3>) :int<3> +{ + return in_scope(x, y); +} + +def infty(x :int<5>, y :int<5>) :int<10> +{ + return x ++ y; +} + +pipe test1(inpt: uint<32>)[rf: int<32>[32]] :int<100> +{ + uint<5> a = inpt{4:0}; + output (adder(3, 4)); + +} + +pipe test2(inpt :int<3>)[] :int<5> +{ + tmp = indexing(inpt, inpt); + output(tmp); +} + + +circuit { + r = memory(int<32>, 32); +} \ No newline at end of file diff --git a/src/test/tests/typecheckTests/solutions/generic_funcs_passz3.typechecksol b/src/test/tests/typecheckTests/solutions/generic_funcs_passz3.typechecksol new file mode 100644 index 00000000..863339fb --- /dev/null +++ b/src/test/tests/typecheckTests/solutions/generic_funcs_passz3.typechecksol @@ -0,0 +1 @@ +Passed