diff --git a/src/main/scala/pipedsl/Parser.scala b/src/main/scala/pipedsl/Parser.scala index d247ca47..a400e81e 100644 --- a/src/main/scala/pipedsl/Parser.scala +++ b/src/main/scala/pipedsl/Parser.scala @@ -35,7 +35,8 @@ class Parser extends RegexParsers with PackratParsers { val e = EInt(n, base, if (bits.isDefined) bits.get else log2(n)) e.typ = bits match { case Some(b) => Some(TSizedInt(TBitWidthLen(b), SignFactory.ofBool(!isUnsigned))) - case None => None + case None if isUnsigned => Some(TSizedInt(TBitWidthLen(e.bits), TUnsigned())) + case _ => None } // e.typ = Some(TSizedInt(TBitWidthLen(e.bits), unsigned = isUnsigned)) e @@ -264,7 +265,7 @@ class Parser extends RegexParsers with PackratParsers { seqCmd } - lazy val sizedInt: P[Type] = "int" ~> angular(posint) ^^ { bits => TSizedInt(TBitWidthLen(bits), TUnsigned() /*unsigned = false*/) } | + lazy val sizedInt: P[Type] = "int" ~> angular(posint) ^^ { bits => TSizedInt(TBitWidthLen(bits), TSigned() /*unsigned = false*/) } | "uint" ~> angular(posint) ^^ { bits => TSizedInt(TBitWidthLen(bits), TUnsigned() /*unsigned = true*/) } lazy val latency: P[Latency.Latency] = diff --git a/src/main/scala/pipedsl/common/DAGSyntax.scala b/src/main/scala/pipedsl/common/DAGSyntax.scala index 3032affb..7c2603f3 100644 --- a/src/main/scala/pipedsl/common/DAGSyntax.scala +++ b/src/main/scala/pipedsl/common/DAGSyntax.scala @@ -276,13 +276,13 @@ object DAGSyntax { val defaultNum = conds.size val condVar = EVar(Id("__cond" + n.v)) val intSize = log2(defaultNum) - condVar.typ = Some(TSizedInt(TBitWidthLen(intSize), sign = true)) + condVar.typ = Some(TSizedInt(TBitWidthLen(intSize), TUnsigned())) condVar.id.typ = condVar.typ var eTernary = ETernary(conds(defaultNum - 1), EInt(defaultNum - 1, bits = intSize), EInt(defaultNum, bits = intSize)) for(i <- defaultNum-2 to 0 by -1 ) { eTernary = ETernary(conds(i), EInt(i, bits = intSize), eTernary.copy()) } - this.addCmd(CAssign(condVar, eTernary, Some(TSizedInt(TBitWidthLen(intSize), sign = true)))) + this.addCmd(CAssign(condVar, eTernary, Some(TSizedInt(TBitWidthLen(intSize), TUnsigned())))) for (i <- 0 until defaultNum) { this.addEdgeTo(condStages(i).head, condSend = Some (EBinop(EqOp("=="), condVar, EInt(i, bits = intSize)))) condStages(i).last.addEdgeTo(joinStage, condRecv = Some (EBinop(EqOp("=="), condVar, EInt(i, bits = intSize)))) diff --git a/src/main/scala/pipedsl/common/Errors.scala b/src/main/scala/pipedsl/common/Errors.scala index a4471fd3..61a3f54e 100644 --- a/src/main/scala/pipedsl/common/Errors.scala +++ b/src/main/scala/pipedsl/common/Errors.scala @@ -180,6 +180,6 @@ object Errors { ) case class UnificationError(t1: Type, t2: Type) extends RuntimeException( - s"Unable to unify type $t1 and type $t2" + withPos(s"Unable to unify type $t1 and type $t2", t1.pos) ) } diff --git a/src/main/scala/pipedsl/common/PrettyPrinter.scala b/src/main/scala/pipedsl/common/PrettyPrinter.scala index 1f18f0a9..39fd16ac 100644 --- a/src/main/scala/pipedsl/common/PrettyPrinter.scala +++ b/src/main/scala/pipedsl/common/PrettyPrinter.scala @@ -148,7 +148,7 @@ class PrettyPrinter(output: Option[File]) { def printType(t: Type): Unit = pline(printTypeToString(t)) def printTypeToString(t: Type): String = t match { - case TSizedInt(len, unsigned) => (if (!unsigned) "s" else "") + "int<" + len.toString + ">" + case TSizedInt(len, sign) => (if (sign.signed()) "s" else "") + "int<" + len.toString + ">" case TVoid() => "void" case TBool() => "bool" case TFun(args, ret) => "(" + args.map(a => printTypeToString(a)).mkString(",") + ") -> " + printTypeToString(ret) diff --git a/src/main/scala/pipedsl/common/Syntax.scala b/src/main/scala/pipedsl/common/Syntax.scala index 30bb4d46..288df447 100644 --- a/src/main/scala/pipedsl/common/Syntax.scala +++ b/src/main/scala/pipedsl/common/Syntax.scala @@ -104,7 +104,9 @@ object Syntax { case TBitWidthLen(len) => len.toString case TBitWidthMax(b1, b2) => "max(" + b1 + ", " + b2 + ")" case TBitWidthVar(name) => "bitVar(" + name + ")" - + case TSigned() => "signed" + case TUnsigned() => "unsigned" + case TSignVar(name) => "sign(" + name + ")" } } // Types that can be upcast to Ints @@ -115,10 +117,12 @@ object Syntax { { case TSigned() => true case TUnsigned() => false + case _ => false } def unsigned() :Boolean = this match { case TSigned() => false case TUnsigned() => true + case _ => false } } object SignFactory diff --git a/src/main/scala/pipedsl/common/Utilities.scala b/src/main/scala/pipedsl/common/Utilities.scala index 783dfc88..02655493 100644 --- a/src/main/scala/pipedsl/common/Utilities.scala +++ b/src/main/scala/pipedsl/common/Utilities.scala @@ -6,6 +6,8 @@ import pipedsl.common.DAGSyntax.PStage import pipedsl.common.Errors.UnexpectedCommand import pipedsl.common.Syntax._ +import scala.annotation.tailrec + object Utilities { @@ -305,6 +307,99 @@ object Utilities { } } + + def opt_func[A, B](f :A => B) : Option[A] => Option[B] = + { + case Some(value) => Some(f(value)) + case None => None + } + + private def typeMapExpr(e :Expr, f_opt : Option[Type] => Option[Type]) : Unit = + { + println(s"setting ${e.typ} to ${f_opt(e.typ)}") + e.typ = f_opt(e.typ) + e match + { + case EIsValid(ex) => typeMapExpr(ex, f_opt) + case EFromMaybe(ex) => typeMapExpr(ex, f_opt) + case EToMaybe(ex) => typeMapExpr(ex, f_opt) + case EUop(_, ex) => typeMapExpr(ex, f_opt) + case EBinop(_, e1, e2) => typeMapExpr(e1, f_opt); typeMapExpr(e2, f_opt) + case ERecAccess(rec, fieldName) => typeMapId(fieldName, f_opt); typeMapExpr(rec, f_opt); + case ERecLiteral(fields) => + fields.foreach(idex => {typeMapId(idex._1, f_opt); typeMapExpr(idex._2, f_opt)}) + case EMemAccess(mem, index, wmask) => + typeMapId(mem, f_opt); typeMapExpr(index, f_opt) + wmask.fold(())(typeMapExpr(_, f_opt)) + case EBitExtract(num, _, _) => typeMapExpr(num, f_opt) + case ETernary(cond, tval, fval) => + typeMapExpr(cond, f_opt); typeMapExpr(tval, f_opt); typeMapExpr(fval, f_opt) + case EApp(func, args) => + typeMapId(func, f_opt); args.foreach(typeMapExpr(_, f_opt)) + case ECall(mod, args) => + typeMapId(mod, f_opt); args.foreach(typeMapExpr(_, f_opt)) + case EVar(id) => typeMapId(id, f_opt) + case ECast(_, exp) => typeMapExpr(exp, f_opt) + case expr: CirExpr => expr match + { + case CirLock(mem, _, _) => typeMapId(mem, f_opt) + case CirNew(mod, mods) => typeMapId(mod, f_opt) + mods.foreach((i: Id) => typeMapId(i, f_opt)) + case CirCall(mod, args) => typeMapId(mod, f_opt) + args.foreach(typeMapExpr(_, f_opt)) + case _ => () + } + case _ => () + } + } + private def typeMapCmd(c :Command, f_opt :Option[Type] => Option[Type]) :Unit = c match + { + case CSeq(c1, c2) => typeMapCmd(c1, f_opt); typeMapCmd(c2, f_opt) + case CTBar(c1, c2) => typeMapCmd(c1, f_opt); typeMapCmd(c2, f_opt) + case CIf(cond, cons, alt) => typeMapExpr(cond, f_opt); typeMapCmd(cons, f_opt); typeMapCmd(alt, f_opt) + case CAssign(lhs, rhs, _) => typeMapExpr(lhs, f_opt); typeMapExpr(rhs, f_opt) + case CRecv(lhs, rhs, _) => typeMapExpr(lhs, f_opt); typeMapExpr(rhs, f_opt) + case CSpecCall(handle, pipe, args) => + typeMapExpr(handle, f_opt); typeMapId(pipe, f_opt); args.foreach(typeMapExpr(_, f_opt)) + case CVerify(handle, args, preds) => + typeMapExpr(handle, f_opt); args.foreach(typeMapExpr(_, f_opt)); preds.foreach(typeMapExpr(_, f_opt)) + case CInvalidate(handle) => typeMapExpr(handle, f_opt) + case CPrint(args) => args.foreach(typeMapExpr(_, f_opt)) + case COutput(exp) => typeMapExpr(exp, f_opt) + case CReturn(exp) => typeMapExpr(exp, f_opt) + case CExpr(exp) => typeMapExpr(exp, f_opt) + case CLockStart(mod) => typeMapId(mod, f_opt) + case CLockEnd(mod) => typeMapId(mod, f_opt) + case CLockOp(mem, _, _) => + typeMapId(mem.id, f_opt); + mem.evar match + {case Some(value) => typeMapExpr(value, f_opt) + case None => ()} + case CSplit(cases, default) => + cases.foreach(cs => {typeMapExpr(cs.cond, f_opt); typeMapCmd(cs.body, f_opt)}) + typeMapCmd(default, f_opt) + case _ => () + } + + private def typeMapId(i: Id, f_opt: Option[Type] => Option[Type]):Unit = + { + i.typ = f_opt(i.typ) + } + + + + def typeMapFunc(fun :FuncDef, f_opt :Option[Type] => Option[Type]) :Unit = typeMapCmd(fun.body, f_opt) + def typeMapModule(mod :ModuleDef, f_opt :Option[Type] => Option[Type]) :Unit = typeMapCmd(mod.body, f_opt) + + def typeMap(p: Prog, f: Type => Type) :Unit= + { + val f_opt = opt_func(f) + p.fdefs.foreach(typeMapFunc(_, f_opt)) + p.moddefs.foreach(typeMapModule(_, f_opt)) + } + + + /** Like [[Z3Context.mkAnd]], but automatically casts inputs to [[Z3BoolExpr]]s. */ def mkAnd(ctx: Z3Context, expressions: Z3AST *): Z3BoolExpr = ctx.mkAnd(expressions.map(ast => ast.asInstanceOf[Z3BoolExpr]):_*) diff --git a/src/main/scala/pipedsl/typechecker/Subtypes.scala b/src/main/scala/pipedsl/typechecker/Subtypes.scala index a4f253a1..dfceda9e 100644 --- a/src/main/scala/pipedsl/typechecker/Subtypes.scala +++ b/src/main/scala/pipedsl/typechecker/Subtypes.scala @@ -42,8 +42,8 @@ object Subtypes { } def areEqual(t1: Type, t2: Type): Boolean = (t1, t2) match { - case (TSizedInt(l1, u1), TBool()) => l1.asInstanceOf[TBitWidthLen].len == 1 && u1.asInstanceOf[TSignedNess].unsigned() - case (TBool(), TSizedInt(l1, u1)) => l1.asInstanceOf[TBitWidthLen].len == 1 && u1.asInstanceOf[TSignedNess].unsigned() + case (TSizedInt(l1, u1), TBool()) => l1.asInstanceOf[TBitWidthLen].len == 1 && u1.unsigned() + case (TBool(), TSizedInt(l1, u1)) => l1.asInstanceOf[TBitWidthLen].len == 1 && u1.unsigned() case (TSizedInt(l1, u1), TSizedInt(l2, u2)) => l1 == l2 && u1 == u2 case (TMemType(e1, as1, r1, w1, rp1, wp1), TMemType(e2, as2, r2, w2, rp2, wp2)) diff --git a/src/main/scala/pipedsl/typechecker/TypeInferenceWrapper.scala b/src/main/scala/pipedsl/typechecker/TypeInferenceWrapper.scala index 8229b7c2..9bb8c90b 100644 --- a/src/main/scala/pipedsl/typechecker/TypeInferenceWrapper.scala +++ b/src/main/scala/pipedsl/typechecker/TypeInferenceWrapper.scala @@ -5,9 +5,11 @@ import pipedsl.common.Syntax._ import pipedsl.typechecker.Subtypes.{areEqual, canCast, isSubtype} import pipedsl.common.Errors import pipedsl.common.Syntax.Latency.{Asynchronous, Combinational, Sequential} -import pipedsl.common.Utilities.{defaultReadPorts, defaultWritePorts} +import pipedsl.common.Utilities.{defaultReadPorts, defaultWritePorts, opt_func, typeMapModule} import pipedsl.typechecker.Environments.{Environment, TypeEnv} +import scala.collection.mutable + object TypeInferenceWrapper { type Subst = List[(Id, Type)] @@ -17,12 +19,19 @@ object TypeInferenceWrapper { case t@TMemType(elem, addrSize, readLatency, writeLatency, readPorts, writePorts) => t.copy(elem = subst_into_type(typevar, toType, elem)) case t1@TLockedMemType(t2@TMemType(elem, _, _, _, _, _), _, _) => t1.copy(t2.copy(elem = subst_into_type(typevar, toType, elem))) - case TSizedInt(len, unsigned) => TSizedInt(subst_into_type(typevar, toType, len).asInstanceOf[TBitWidth], unsigned) + case TSizedInt(len, signedness) => + TSizedInt(subst_into_type(typevar, toType, len).asInstanceOf[TBitWidth], + subst_into_type(typevar, toType, signedness).asInstanceOf[TSignedNess]) case TString() => inType case TBool() => inType case TVoid() => inType + case TSigned() => inType + case TUnsigned() => inType case TFun(args, ret) => TFun(args.map(a => subst_into_type(typevar, toType, a)), subst_into_type(typevar, toType, ret)) - case TNamedType(name) => if (name == typevar) toType else inType + case TNamedType(name) => + if (name == typevar) toType else inType + case TSignVar(name) => + if (name == typevar) toType else inType case TModType(inputs, refs, retType, name) => TModType(inputs.map(i => subst_into_type(typevar, toType, i)), refs.map(r => subst_into_type(typevar, toType, r)), retType match { case Some(value) => Some(subst_into_type(typevar, toType, value)) case None => None @@ -50,6 +59,34 @@ object TypeInferenceWrapper } } + private def type_subst_map(t :Type, tp_mp :mutable.HashMap[Id, Type]) :Type = t match + { + case TSignVar(nm) => type_subst_map(tp_mp.getOrElse(nm, t), tp_mp) + case sz@TSizedInt(len, sign) => + sz.copy(len = type_subst_map(len, tp_mp).asInstanceOf[TBitWidth], + sign = type_subst_map(sign, tp_mp).asInstanceOf[TSignedNess]) + case f@TFun(args, ret) => f.copy(args = args.map(type_subst_map(_, tp_mp)), ret = type_subst_map(ret, tp_mp)) + case r@TRecType(name, fields) => r.copy(fields = fields.map((idtp) => (idtp._1, type_subst_map(idtp._2, tp_mp)))) + case m@TMemType(elem, addrSize, readLatency, writeLatency, readPorts, writePorts) => + m.copy(elem = type_subst_map(elem, tp_mp)) + case m@TModType(inputs, refs, retType, name) => + m.copy(inputs = inputs.map(type_subst_map(_, tp_mp)), refs = refs.map(type_subst_map(_, tp_mp))) + case l@TLockedMemType(mem, idSz, limpl) => + l.copy(mem = type_subst_map(mem, tp_mp).asInstanceOf[TMemType]) + case TNamedType(name) => type_subst_map(tp_mp.getOrElse(name, t), tp_mp) + case m@TMaybe(btyp) => m.copy(btyp = type_subst_map(btyp, tp_mp)) + case TBitWidthAdd(b1, b2) => + val tmp = TBitWidthLen(type_subst_map(b1, tp_mp).asInstanceOf[TBitWidthLen].len + + type_subst_map(b2, tp_mp).asInstanceOf[TBitWidthLen].len) + println(s"mapping $b1 + $b2 to ${tmp.len}") + tmp + case TBitWidthMax(b1, b2) => + TBitWidthLen(Math.max(type_subst_map(b1, tp_mp).asInstanceOf[TBitWidthLen].len, + type_subst_map(b2, tp_mp).asInstanceOf[TBitWidthLen].len)) + case TBitWidthVar(name) => type_subst_map(tp_mp.getOrElse(name, t), tp_mp) + case _ => t + } + class TypeInference @@ -112,30 +149,24 @@ object TypeInferenceWrapper c.typ = Some(ltyp) (ltyp, tenv) case CirNew(mod, mods) => - { val mtyp = tenv(mod) mtyp match { case TModType(_, refs, _, _) => - { if (refs.length != mods.length) { throw ArgLengthMismatch(c.pos, mods.length, refs.length) } refs.zip(mods).foreach { case (reftyp, mname) => - { if (!(isSubtype(tenv(mname), reftyp))) { throw UnexpectedSubtype(mname.pos, mname.toString, reftyp, tenv(mname)) } - } } (mtyp, tenv) - } case x => throw UnexpectedType(c.pos, c.toString, "Module Type", x) } - } case CirCall(mod, inits) => { val mtyp = tenv(mod) @@ -171,7 +202,10 @@ object TypeInferenceWrapper val modEnv = env.add(m.name, TModType(inputTypes, modTypes, m.ret, Some(m.name))) val inEnv = m.inputs.foldLeft[Environment[Id, Type]](modEnv)((env, p) => env.add(p.name, p.typ)) val pipeEnv = m.modules.zip(modTypes).foldLeft[Environment[Id, Type]](inEnv)((env, m) => env.add(m._1.name, m._2)) - checkCommand(m.body, pipeEnv.asInstanceOf[TypeEnv], List()) + val (_, subst) = checkCommand(m.body, pipeEnv.asInstanceOf[TypeEnv], List()) + val hash = mutable.HashMap.from(subst) + typeMapModule(m, opt_func(type_subst_map(_, hash))) + println("SUBSTITUTIONS:\n" + subst) modEnv } @@ -196,6 +230,8 @@ object TypeInferenceWrapper Transforms the argument env by subbing in the returned substitution and adding any relevatn variables */ def checkCommand(c: Command, env: TypeEnv, sub: Subst): (TypeEnv, Subst) = { + /*println(c) + println(c.pos)*/ c match { case CLockOp(mem, op, lockType) => //test basic first @@ -312,7 +348,8 @@ object TypeInferenceWrapper case _ => rhsEnv } (newEnv.asInstanceOf[TypeEnv].apply_subst_typeenv(sret), sret) - case CAssign(lhs, rhs, typ) => val (slhs, tlhs, lhsEnv) = (List(), typ.getOrElse(generateTypeVar()), env) + case CAssign(lhs, rhs, typ) => + val (slhs, tlhs, lhsEnv) = (List(), typ.getOrElse(generateTypeVar()), env) val (srhs, trhs, rhsEnv) = infer(lhsEnv, rhs) val tempSub = compose_many_subst(sub, slhs, srhs) val lhstyp = apply_subst_typ(tempSub, tlhs) @@ -327,10 +364,13 @@ object TypeInferenceWrapper case EVar(id) => rhsEnv.add(id, tlhs) case _ => rhsEnv } + lhs.typ = Some(lhstyp) + rhs.typ = Some(rhstyp) (newEnv.asInstanceOf[TypeEnv].apply_subst_typeenv(sret), sret) case CSeq(c1, c2) => val (e1, s) = checkCommand(c1, env, sub) val (e2, s2) = checkCommand(c2, e1, s) (e2, s2) + case _:InternalCommand => (env, sub) } } @@ -363,10 +403,12 @@ object TypeInferenceWrapper case TMemType(elem, addrSize, readLatency, writeLatency, readPorts, writePorts) => false case TModType(inputs, refs, retType, name) => false case TNamedType(name1) => name1 == name + case TSignVar(name1) => name1 == name case TBitWidthVar(name1) => name1 == name case TBitWidthAdd(b1, b2) => occursIn(name, b1) || occursIn(name, b2) case TBitWidthMax(b1, b2) => occursIn(name, b1) || occursIn(name, b2) case TBitWidthLen(len) => false + case _:TSignedNess => false } @@ -381,12 +423,16 @@ object TypeInferenceWrapper { case (t1: TNamedType, t2) => if (!occursIn(t1.name, t2)) List((t1.name, t2)) else List() case (t1, t2: TNamedType) => if (!occursIn(t2.name, t1)) List((t2.name, t1)) else List() + case (t1 :TSignVar, t2 :TSignedNess) => if (!occursIn(t1.id, t2)) List((t1.id, t2)) else List() + case (t1 :TSignedNess, t2 :TSignVar) => if (!occursIn(t2.id, t1)) List((t2.id, t1)) else List() case (_: TString, _: TString) => List() case (_: TBool, _: TBool) => List() case (_: TVoid, _: TVoid) => List() + case (_ :TSigned, _ :TSigned) => List() + case (_ :TUnsigned, _ :TUnsigned) => List() case (TBool(), TSizedInt(len, u)) if len.asInstanceOf[TBitWidthLen].len == 1 && u.unsigned() => List() case (TSizedInt(len, u), TBool()) if len.asInstanceOf[TBitWidthLen].len == 1 && u.unsigned() => List() - case (TSizedInt(len1, unsigned1), TSizedInt(len2, unsigned)) => unify(len1, len2) //TODO: TSIZEDINT + case (TSizedInt(len1, signed1), TSizedInt(len2, signed2)) => compose_subst(unify(len1, len2), unify(signed1, signed2)) //TODO: TSIZEDINT case (TFun(args1, ret1), TFun(args2, ret2)) if args1.length == args2.length => val s1 = args1.zip(args2).foldLeft[Subst](List())((s, t) => compose_subst(s, unify(apply_subst_typ(s, t._1), apply_subst_typ(s, t._2)))) compose_subst(s1, unify(apply_subst_typ(s1, ret1), apply_subst_typ(s1, ret2))) @@ -403,6 +449,11 @@ object TypeInferenceWrapper unify(elem1, elem2) case (t1: TBitWidthVar, t2: TBitWidth) => if (!occursIn(t1.name, t2)) List((t1.name, t2)) else List() case (t1: TBitWidth, t2: TBitWidthVar) => if (!occursIn(t2.name, t1)) List((t2.name, t1)) else List() + + case (TBitWidthAdd(a1 :TBitWidthLen, a2), TBitWidthLen(len)) => println("UNIFY"); unify(a2, TBitWidthLen(len - a1.len)) + case (TBitWidthAdd(a2, a1 :TBitWidthLen), TBitWidthLen(len)) => println("UNIFY"); unify(a2, TBitWidthLen(len - a1.len)) + case (TBitWidthLen(len), TBitWidthAdd(a1 :TBitWidthLen, a2)) => println("UNIFY"); unify(a2, TBitWidthLen(len - a1.len)) + case (TBitWidthLen(len), TBitWidthAdd(a2, a1 :TBitWidthLen)) => println("UNIFY"); unify(a2, TBitWidthLen(len - a1.len)) case (t1: TBitWidthLen, t2: TBitWidthLen) => if (t2.len < t1.len) throw UnificationError(t1, t2) else List() //TODO: need to figure this out: we want this subtyping rule to throw error when its being used, but not when its a binop!!! case _ => throw UnificationError(a, b) } @@ -412,8 +463,8 @@ object TypeInferenceWrapper def infer(env: TypeEnv, e: Expr): (Subst, Type, TypeEnv) = e match { case EInt(v, base, bits) => - val is_unsigned = if (e.typ.isDefined) e.typ.get.asInstanceOf[TSizedInt].sign else generateSignTypeVar() - (List(), if(e.typ.isDefined) TSizedInt(TBitWidthLen(bits), is_unsigned) else generateTypeVar(), env) + val sign = if (e.typ.isDefined) e.typ.get.asInstanceOf[TSizedInt].sign else generateSignTypeVar() + (List(), if(e.typ.isDefined) TSizedInt(TBitWidthLen(bits), sign) else generateTypeVar(), env) case EString(v) => (List(), TString(), env) case EBool(v) => (List(), TBool(), env) case EUop(op, ex) => val (s, t, env1) = infer(env, ex) @@ -424,15 +475,14 @@ object TypeInferenceWrapper val retTyp = apply_subst_typ(retSubst, retType) (retSubst, retTyp, env1.apply_subst_typeenv(retSubst)) case EBinop(op, e1, e2) => - println(s"inferring for op $op") + println(s"op: $op") val (s1, t1, env1) = infer(env, e1) val (s2, t2, env2) = infer(env1, e2) - println(s"inferred left = $t1; right = $t2") + println(s"inferred left: $t1;\tright: $t2") val retType = generateTypeVar() val subTemp = compose_subst(s1, s2) val t1New = apply_subst_typ(subTemp, t1) val t2New = apply_subst_typ(subTemp, t2) - println(s"t1New: $t1New; t2New: $t2New") val subst = unify(TFun(List(t1New, t2New), retType), binOpExpectedType(op)) val retSubst = compose_many_subst(subTemp, subst) val retTyp = apply_subst_typ(retSubst, retType) @@ -440,7 +490,7 @@ object TypeInferenceWrapper val t2VNew = apply_subst_typ(retSubst, t2New) e1.typ = Some(t1VNew) e2.typ = Some(t2VNew) - println(s"retTyp: $retTyp;\nt1VNew: $t1VNew;\nt2VNew: $t2VNew") + println(s"decided left: $t1VNew;\tright: $t2VNew") (retSubst, retTyp, env2.apply_subst_typeenv(retSubst)) case EMemAccess(mem, index, wmask) => if (!(env(mem).isInstanceOf[TMemType] || env(mem).isInstanceOf[TLockedMemType])) throw UnexpectedType(e.pos, "Memory Access", "TMemtype", env(mem)) @@ -457,11 +507,12 @@ object TypeInferenceWrapper val retSubst = compose_subst(s, subst) val retTyp = apply_subst_typ(retSubst, retType) (retSubst, retTyp, env1.apply_subst_typeenv(retSubst)) - case EBitExtract(num, start, end) => val (s, t, e) = infer(env, num) + case EBitExtract(num, start, end) => + val (s, t, e) = infer(env, num) t match { - case TSizedInt(TBitWidthLen(len), unsigned) if len >= (math.abs(end - start) + 1) => - (s, TSizedInt(TBitWidthLen(math.abs(end - start) + 1), TUnsigned()/*true*/), e) + case TSizedInt(TBitWidthLen(len), signedness) if len >= (math.abs(end - start) + 1) => + (s, TSizedInt(TBitWidthLen(math.abs(end - start) + 1), signedness), e) case b => throw UnificationError(b, TSizedInt(TBitWidthLen(32), TUnsigned()/*true*/)) //TODO Add better error message } //TODO case ETernary(cond, tval, fval) => val (sc, tc, env1) = infer(env, cond) @@ -549,7 +600,7 @@ object TypeInferenceWrapper op match { case "++" => TFun(List(TSizedInt(b1, s), TSizedInt(b2, s)), TSizedInt(TBitWidthAdd(b1, b2), s)) - case _ => TFun(List(TSizedInt(b1, s), TSizedInt(b2, s)), TSizedInt(b1, s)) + case _ => TFun(List(TSizedInt(b1, s), TSizedInt(b2, generateSignTypeVar())), TSizedInt(b1, s)) } } diff --git a/src/test/tests/typecheckTests/solutions/type-inference-bit-width-tests.typechecksol b/src/test/tests/typecheckTests/solutions/type-inference-bit-width-tests.typechecksol new file mode 100644 index 00000000..9fb4ec93 --- /dev/null +++ b/src/test/tests/typecheckTests/solutions/type-inference-bit-width-tests.typechecksol @@ -0,0 +1 @@ +Passed \ No newline at end of file diff --git a/src/test/tests/typecheckTests/type-inference-bit-width-tests.pdl b/src/test/tests/typecheckTests/type-inference-bit-width-tests.pdl new file mode 100644 index 00000000..49580d16 --- /dev/null +++ b/src/test/tests/typecheckTests/type-inference-bit-width-tests.pdl @@ -0,0 +1,50 @@ + +def helper1(a: int<32>, b:bool, c: String): int<32> { + d = a + 1; + f = d + 4; + if (f == d) { + return f; + } else { return d; } +} + +def helper2(a:bool, b: bool): bool { + c = a && b; + return c; +} + +pipe test1(input: int<32>)[rf: int<32>[32]] { + a = 6 * 6; + int<32> b = a; + int<32> c = a + b; +} + +pipe test2(input: int<32>)[rf: int<32>[32]] { + a = 1; + b = rf[a]; +} + +pipe test3(input: int<32>)[rf: int<32>[32]] { + a = 1; + int<32> b = a << 4; + int<5> c = a << 4; +} + +pipe test4(input: int<32>)[rf: int<32>[32]] { + a = 15<32>; + int<32> b = a{0:8}; +} + +pipe test5()[] { + a = 6 * 6; + int<6> b = a; + int<64> c = helper1(a, true, "hi"); +} + +pipe test6()[] { + a = 6 - 6; + int<3> b = a; +} + +circuit { + r = memory(int<32>, 32); +} \ No newline at end of file