From 295eccfd0fea2861974ef1645beca28c3c0a0597 Mon Sep 17 00:00:00 2001 From: Charles Sherk Date: Wed, 11 Aug 2021 17:56:41 -0400 Subject: [PATCH] added type inference, autocasting, better parsing error messages. --- .gitignore | 1 + src/main/scala/pipedsl/Main.scala | 10 +- src/main/scala/pipedsl/Parser.scala | 169 ++- src/main/scala/pipedsl/common/Errors.scala | 10 +- .../scala/pipedsl/common/PrettyPrinter.scala | 7 +- src/main/scala/pipedsl/common/Syntax.scala | 106 +- src/main/scala/pipedsl/common/Utilities.scala | 191 ++- .../pipedsl/typechecker/BaseTypeChecker.scala | 4 +- .../typechecker/LockConstraintChecker.scala | 2 +- .../typechecker/TypeInferenceWrapper.scala | 1071 ++++++++--------- .../autocastTests/autocast-basic-pass.pdl | 78 +- .../tests/autocastTests/risc-pipe-spec.pdl | 336 ++++++ .../solutions/risc-pipe-spec.typechecksol | 1 + .../type-inference-bit-width-tests.pdl | 6 +- src/test/tests/risc-pipe/risc-pipe-spec.pdl | 70 +- .../type-inference-bit-width-tests.pdl | 6 +- 16 files changed, 1343 insertions(+), 725 deletions(-) create mode 100644 src/test/tests/autocastTests/risc-pipe-spec.pdl create mode 100644 src/test/tests/autocastTests/solutions/risc-pipe-spec.typechecksol diff --git a/.gitignore b/.gitignore index e2e9fdb9..77ad2164 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ out testOutputs/* miscnotes \#* +tmp/ ### SBT ### diff --git a/src/main/scala/pipedsl/Main.scala b/src/main/scala/pipedsl/Main.scala index 1817bdb6..05641c99 100644 --- a/src/main/scala/pipedsl/Main.scala +++ b/src/main/scala/pipedsl/Main.scala @@ -40,11 +40,11 @@ object Main { throw new RuntimeException(s"File $inputFile does not exist") } val p: Parser = new Parser() - val r = p.parseAll(p.prog, new String(Files.readAllBytes(inputFile.toPath))) - val outputName = FilenameUtils.getBaseName(inputFile.getName) + ".parse" + val prog = p.parseCode(new String(Files.readAllBytes(inputFile.toPath))) + val outputName = FilenameUtils.getBaseName(inputFile.getName) + ".parse" val outputFile = new File(Paths.get(outDir.getPath, outputName).toString) - if (printOutput) new PrettyPrinter(Some(outputFile)).printProgram(r.get) - r.get + if (printOutput) new PrettyPrinter(Some(outputFile)).printProgram(prog) + prog } def interpret(maxIterations:Int, memoryInputs: Seq[String], inputFile: File, outDir: File): Unit = { @@ -67,7 +67,7 @@ object Main { try { val verifProg = AddVerifyValuesPass.run(prog) val canonProg1 = new CanonicalizePass().run(verifProg) - val canonProg = (new TypeInference(autocast)).checkProgram(canonProg1) + val canonProg = new TypeInference(autocast).checkProgram(canonProg1) val basetypes = BaseTypeChecker.check(canonProg, None) val nprog = new BindModuleTypes(basetypes).run(canonProg) TimingTypeChecker.check(nprog, Some(basetypes)) diff --git a/src/main/scala/pipedsl/Parser.scala b/src/main/scala/pipedsl/Parser.scala index a400e81e..60667189 100644 --- a/src/main/scala/pipedsl/Parser.scala +++ b/src/main/scala/pipedsl/Parser.scala @@ -1,7 +1,6 @@ package pipedsl import scala.util.parsing.combinator._ import common.Syntax._ -import common.Utilities._ import common.Locks._ import pipedsl.common.LockImplementation @@ -10,6 +9,45 @@ import scala.util.matching.Regex class Parser extends RegexParsers with PackratParsers { type P[T] = PackratParser[T] + var bitVarCount = 0 + var namedTypeCount = 0 + + /*TODO: is this really the best way of doing this?*/ + var finalFail :Option[ParseResult[Any]] = None + + val debug = false + + /** + * adds debugging info to a parser if [[debug]] is true + */ + def dlog[T](p: => P[T])(msg :String) :P[T] = if (debug) log(p)(msg) else p + + /** + * wrap this around a parser for its failures to be recorded in a way to be + * reported to the user + */ + def failRecord[T](p: => P[T]) :P[T] = Parser {in => + val r = p(in) + r match { + case Failure(msg, _) if msg != "Base Failure" => + finalFail = Some(r) + case _ => () + } + r + } + + def genBitVar() :TBitWidth = + { + bitVarCount += 1 + TBitWidthVar(Id("__PARSER__BITWIDTH__" + bitVarCount)) + } + + def genTypeVar() :TNamedType = + { + namedTypeCount += 1 + TNamedType(Id("__PARSER__NAMED__" + namedTypeCount)) + } + // General parser combinators def braces[T](parser: P[T]): P[T] = "{" ~> parser <~ "}" @@ -32,31 +70,31 @@ class Parser extends RegexParsers with PackratParsers { "\"" ~> "[^\"]*".r <~ "\"" ^^ {n => EString(n)} private def toInt(n: Int, base: Int, bits: Option[Int], isUnsigned: Boolean): EInt = { - val e = EInt(n, base, if (bits.isDefined) bits.get else log2(n)) + val e = EInt(n, base, if (bits.isDefined) bits.get else -1) e.typ = bits match { case Some(b) => Some(TSizedInt(TBitWidthLen(b), SignFactory.ofBool(!isUnsigned))) - case None if isUnsigned => Some(TSizedInt(TBitWidthLen(e.bits), TUnsigned())) - case _ => None + case None if isUnsigned => Some(TSizedInt(genBitVar(), TUnsigned())) + case _ => Some(genTypeVar()) } - // e.typ = Some(TSizedInt(TBitWidthLen(e.bits), unsigned = isUnsigned)) e } // Atoms - lazy val dec: P[EInt] = "u".? ~ "-?[0-9]+".r ~ angular(posint).? ^^ { + lazy val dec: P[EInt] = positioned { "u".? ~ "-?[0-9]+".r ~ angular(posint).? ^^ { case u ~ n ~ bits => toInt(n.toInt, 10, bits, u.isDefined) - } - lazy val hex: P[EInt] = "u".? ~ "0x-?[0-9a-fA-F]+".r ~ angular(posint).? ^^ { + }} + lazy val hex: P[EInt] = positioned { "u".? ~ "0x-?[0-9a-fA-F]+".r ~ angular(posint).? ^^ { case u ~ n ~ bits => toInt(Integer.parseInt(n.substring(2), 16), 16, bits, u.isDefined) - } - lazy val octal: P[EInt] = "u".? ~ "0-?[0-7]+".r ~ angular(posint).? ^^ { + }} + lazy val octal: P[EInt] = positioned { "u".? ~ "0-?[0-7]+".r ~ angular(posint).? ^^ { case u ~ n ~ bits => toInt(Integer.parseInt(n.substring(1), 8), 8, bits, u.isDefined) - } - lazy val binary: P[EInt] = "u".? ~ "0b-?[0-1]+".r ~ angular(posint).? ^^ { + }} + lazy val binary: P[EInt] = positioned { "u".? ~ "0b-?[0-1]+".r ~ angular(posint).? ^^ { case u ~ n ~ bits => toInt(Integer.parseInt(n.substring(2), 2), 2, bits, u.isDefined) - } + }} - lazy val num: P[EInt] = dec | hex | octal | binary + lazy val num: P[EInt] = binary | hex | octal | dec ^^ + { x: EInt => x.typ.get.setPos(x.pos); x } lazy val boolean: P[Boolean] = "true" ^^ { _ => true } | "false" ^^ { _ => false } @@ -66,7 +104,7 @@ class Parser extends RegexParsers with PackratParsers { } lazy val variable: P[EVar] = positioned { - iden ^^ (id => EVar(id)) + iden ^^ (id => { EVar(id)}) } lazy val recAccess: P[Expr] = positioned { @@ -85,17 +123,16 @@ class Parser extends RegexParsers with PackratParsers { expr ~ braces(posint ~ ":" ~ posint) ^^ { case n ~ (e ~ _ ~ s) => EBitExtract(n, s, e) } } - lazy val ternary: P[Expr] = positioned { - parens(expr) ~ "?" ~ expr ~ ":" ~ expr ^^ { case c ~ _ ~ t ~ _ ~ v => ETernary(c, t, v) } - } lazy val cast: P[Expr] = positioned { - "cast" ~> parens(expr ~ "," ~ typ) ^^ { case e ~ _ ~ t => ECast(t, e) } + "cast" ~> parens(expr ~ "," ~ typ) ^^ { case e ~ _ ~ t => ECast(t, e) } } //UOps lazy val not: P[UOp] = positioned("!" ^^ { _ => NotOp() }) + lazy val binv :P[UOp] = positioned("~" ^^ {_ => InvOp() } ) + lazy val mag: P[Expr] = positioned { "mag" ~> parens(expr) ^^ (e => EUop(MagOp(), e)) } @@ -108,7 +145,7 @@ class Parser extends RegexParsers with PackratParsers { } lazy val simpleAtom: P[Expr] = positioned { "call" ~> iden ~ parens(repsep(expr, ",")) ^^ { case i ~ args => ECall(i, args) } | - not ~ expr ^^ { case n ~ e => EUop(n, e) } | + not ~ simpleAtom ^^ { case n ~ e => EUop(n, e) } | neg | cast | mag | @@ -116,16 +153,12 @@ class Parser extends RegexParsers with PackratParsers { memAccess | bitAccess | recAccess | - ternary | recLiteral | - hex | - octal | - binary | - dec | + num | stringVal | boolean ^^ (b => EBool(b)) | iden ~ parens(repsep(expr, ",")) ^^ { case f ~ args => EApp(f, args) } | - variable | + variable | parens(expr) } @@ -162,7 +195,11 @@ class Parser extends RegexParsers with PackratParsers { def parseOp(base: P[Expr], op: P[BOp]): P[Expr] = positioned { - chainl1[Expr](base, op ^^ (op => EBinop(op, _, _))) + chainl1[Expr](base, op ^^ (op => { EBinop(op, _, _)})) + } + + lazy val ternary: P[Expr] = positioned { + parens(expr) ~ "?" ~ nontern ~ ":" ~ nontern ^^ { case c ~ _ ~ t ~ _ ~ v => ETernary(c, t, v) } } lazy val binMul: P[Expr] = parseOp(simpleAtom, mulOps) @@ -175,15 +212,16 @@ class Parser extends RegexParsers with PackratParsers { lazy val binAnd: P[Expr] = parseOp(binBOr, and) lazy val binOr: P[Expr] = parseOp(binAnd, or) lazy val binConcat: P[Expr] = parseOp(binOr, concat) - lazy val expr: Parser[Expr] = positioned(binConcat) + lazy val nontern: Parser[Expr] = positioned(binConcat) + lazy val expr :Parser[Expr] = ternary | nontern + + lazy val lhs: Parser[Expr] = memAccess | variable lazy val simpleCmd: P[Command] = positioned { speccall | - typ.? ~ variable ~ "=" ~ expr ^^ { case t ~ n ~ _ ~ r => n.typ = t; CAssign(n, r, t) } | - typ.? ~ lhs ~ "<-" ~ expr ^^ { case t ~ l ~ _ ~ r => l.typ = t - CRecv(l, r, t) - } | + typ.? ~ variable ~ "=" ~ expr ^^ { case t ~ n ~ _ ~ r => n.typ = t; CAssign(n, r, t) } | + typ.? ~ lhs ~ "<-" ~ expr ^^ { case t ~ l ~ _ ~ r => l.typ = t; CRecv(l, r, t) } | check | resolveSpec | "start" ~> parens(iden) ^^ { i => CLockStart(i) } | @@ -192,10 +230,10 @@ class Parser extends RegexParsers with PackratParsers { "reserve" ~> parens(lockArg ~ ("," ~> lockType).?) ^^ { case i ~ t => CLockOp(i, Reserved, t)} | "block" ~> parens(lockArg) ^^ { i => CLockOp(i, Acquired, None) } | "release" ~> parens(lockArg) ^^ { i => CLockOp(i, Released, None)} | - "print" ~> parens(repsep(expr, ",")) ^^ (e => CPrint(e)) | + "print" ~> parens(repsep(expr, ",")) ^^ (e => { CPrint(e)}) | "return" ~> expr ^^ (e => CReturn(e)) | - "output" ~> expr ^^ (e => COutput(e)) | - expr ^^ (e => CExpr(e)) + "output" ~> expr ^^ (e => { COutput(e)}) | + expr ^^ (e => { CExpr(e)}) } lazy val lockArg: P[LockArg] = positioned { @@ -249,24 +287,36 @@ class Parser extends RegexParsers with PackratParsers { } lazy val conditional: P[Command] = positioned { - "if" ~> parens(expr) ~ block ~ ("else" ~> blockCmd).? ^^ { - case cond ~ cons ~ alt => CIf(cond, cons, if (alt.isDefined) alt.get else CEmpty()) + failRecord("if" ~> parens(expr) ~ block ~ ("else" ~> blockCmd).? ^^ { + case cond ~ cons ~ alt => CIf(cond, cons, alt.getOrElse(CEmpty())) + }) + } + + lazy val parseEmpty: P[Command] = + { + "" ^^ {_ => CEmpty()} } + + def printPos[T](p: => P[T])(tag :String) :P[T] = Parser { + in => + println(tag + in.pos); p(in) } lazy val seqCmd: P[Command] = { - simpleCmd ~ ";" ~ seqCmd ^^ { case c1 ~ _ ~ c2 => CSeq(c1, c2) } | blockCmd ~ seqCmd ^^ { case c1 ~ c2 => CSeq(c1, c2) } | - simpleCmd <~ ";" | blockCmd | "" ^^ { _ => CEmpty() } + simpleCmd ~ ";" ~ seqCmd ^^ { case c1 ~ _ ~ c2 => CSeq(c1, c2) } | + simpleCmd <~ ";" | blockCmd } - lazy val cmd: P[Command] = positioned { + lazy val cmd: P[Command] = failRecord( positioned { seqCmd ~ "---" ~ cmd ^^ { case c1 ~ _ ~ c2 => CTBar(c1, c2) } | - seqCmd - } + "---" ~> cmd ^^ {c => CTBar(CEmpty(), c)} | + cmd <~ "---" ^^ {c => CTBar(c, CEmpty())} | + seqCmd } ) + - lazy val sizedInt: P[Type] = "int" ~> angular(posint) ^^ { bits => TSizedInt(TBitWidthLen(bits), TSigned() /*unsigned = false*/) } | - "uint" ~> angular(posint) ^^ { bits => TSizedInt(TBitWidthLen(bits), TUnsigned() /*unsigned = true*/) } + lazy val sizedInt: P[Type] = "int" ~> angular(posint) ^^ { bits => TSizedInt(TBitWidthLen(bits), TSigned() ) } | + "uint" ~> angular(posint) ^^ { bits => TSizedInt(TBitWidthLen(bits), TUnsigned() ) } lazy val latency: P[Latency.Latency] = "c" ^^ { _ => Latency.Combinational } | @@ -274,7 +324,7 @@ class Parser extends RegexParsers with PackratParsers { "a" ^^ { _ => Latency.Asynchronous } lazy val lat_and_ports: P[(Latency.Latency, Int)] = - latency ~ ((posint).?) ^^ + latency ~ posint.? ^^ { case lat ~ int => int match { @@ -301,16 +351,15 @@ class Parser extends RegexParsers with PackratParsers { ((l1, i1), (l2, i2)) } | "a" ~> intopt ^^ - { - case n => - val v :(Latency.Latency, Int) = (Latency.Asynchronous, n) + { n => + val v: (Latency.Latency, Int) = (Latency.Asynchronous, n) (v, v) - } + } lazy val lockedMemory: P[Type] = sizedInt ~ brackets(posint) ~ (angular(latsnports) - /*((latency ~ intopt) ~ ("," ~> (latency ~ intopt)))*/ ~ parens(iden).?).? ^^ + ~ parens(iden).?).? ^^ { case elem ~ size ~ lats => if (lats.isDefined) { @@ -334,14 +383,7 @@ class Parser extends RegexParsers with PackratParsers { { case elem ~ size ~ ports ~ lock => val mtyp = TMemType(elem, size, Latency.Asynchronous, Latency.Asynchronous, ports, ports) - lock match - { - case lk => - //case Some(lk) => - TLockedMemType(mtyp, None, LockImplementation.getLockImpl(lk)) -// case None => -// TLockedMemType(mtyp, None, LockImplementation.getDefaultLockImpl) - } + TLockedMemType(mtyp, None, LockImplementation.getLockImpl(lock)) } @@ -434,9 +476,20 @@ class Parser extends RegexParsers with PackratParsers { "circuit" ~> braces(cseq) } + lazy val prog: P[Prog] = positioned { fdef.* ~ moddef.* ~ circuit ^^ { case f ~ p ~ c => Prog(f, p, c) } } + + + def parseCode(code :String) :Prog = { + val r = parseAll(prog, code) + r match { + case Success(program, _) => program + case x :Failure => throw new RuntimeException(finalFail.getOrElse(x).toString) + case x :Error => throw new RuntimeException(x.toString()) + } + } } diff --git a/src/main/scala/pipedsl/common/Errors.scala b/src/main/scala/pipedsl/common/Errors.scala index 61a3f54e..c8bbb455 100644 --- a/src/main/scala/pipedsl/common/Errors.scala +++ b/src/main/scala/pipedsl/common/Errors.scala @@ -1,6 +1,6 @@ package pipedsl.common -import scala.util.parsing.input.Position +import scala.util.parsing.input.{NoPosition, Position} import Syntax._ import pipedsl.common.Locks.LockState @@ -182,4 +182,12 @@ object Errors { case class UnificationError(t1: Type, t2: Type) extends RuntimeException( withPos(s"Unable to unify type $t1 and type $t2", t1.pos) ) + + case class TypeMeetError(t1 :Type, t2 :Type) extends RuntimeException( + withPos(s"Cannot generate meet of type $t1 and type $t2", if (t1.pos eq NoPosition) t2.pos else t1.pos) + ) + + case class LackOfConstraints(e :Expr) extends RuntimeException( + withPos(s"Not enough constraints provided to infer types. Found error at $e", e.pos) + ) } diff --git a/src/main/scala/pipedsl/common/PrettyPrinter.scala b/src/main/scala/pipedsl/common/PrettyPrinter.scala index 39fd16ac..628af4d9 100644 --- a/src/main/scala/pipedsl/common/PrettyPrinter.scala +++ b/src/main/scala/pipedsl/common/PrettyPrinter.scala @@ -118,8 +118,8 @@ class PrettyPrinter(output: Option[File]) { case 16 => "0x" + v.toHexString }) + "<" + bits.toString + ">" case Syntax.EBool(v) => v.toString - case Syntax.EUop(op, ex) => op.op + printExprToString(ex) - case Syntax.EBinop(op, e1, e2) => printExprToString(e1) + " " + op.op + " " + printExprToString(e2) + case Syntax.EUop(op, ex) => op.op + "(" + printExprToString(ex) + ")" + case Syntax.EBinop(op, e1, e2) => "(" + printExprToString(e1) + " " + op.op + " " + printExprToString(e2) + ")" case Syntax.ERecAccess(rec, fieldName) => printExprToString(rec) + "." + fieldName case Syntax.ERecLiteral(fields) => "{" + fields.keySet.map(i => i.v + printExprToString(fields(i))).mkString(",") + "}" case Syntax.EMemAccess(mem, index, m) => mem.v + "[" + printExprToString(index) + @@ -158,7 +158,8 @@ class PrettyPrinter(output: Option[File]) { "<" + rlat + rPorts + ", " + wlat + wPorts + ">" case TModType(_, _, _, _) => "TODO MOD TYPE" case TNamedType(name) => name.v - case _ => throw UnexpectedType(t.pos, "pretty printing", "unimplemented", t) + case x => x.toString + //case _ => throw UnexpectedType(t.pos, "pretty printing", "unimplemented", t) } def printStageGraph(name: String, stages: List[PStage]): Unit = { diff --git a/src/main/scala/pipedsl/common/Syntax.scala b/src/main/scala/pipedsl/common/Syntax.scala index 2c0c0133..8c288225 100644 --- a/src/main/scala/pipedsl/common/Syntax.scala +++ b/src/main/scala/pipedsl/common/Syntax.scala @@ -5,6 +5,9 @@ import Security._ import pipedsl.common.LockImplementation.LockInterface import pipedsl.common.Locks.{General, LockGranularity, LockState} import com.microsoft.z3.BoolExpr +import pipedsl.typechecker.Subtypes + + @@ -132,10 +135,85 @@ object Syntax { maybeSpec = from.maybeSpec this } + + //hex code 22C1 + def ⋁(that :Type) :Type = this match + { + case ness: TSignedNess => ness match + { + case TSignVar(id1) => that match + { + case TSignVar(id2) => if (id1.v != id2.v) throw TypeMeetError(this, that) else this + case _ => that + } + case TSigned() => if (that.isInstanceOf[TUnsigned]) throw TypeMeetError(this, that) else this + case TUnsigned() => if (that.isInstanceOf[TSigned]) throw TypeMeetError(this, that) else this + case _ => throw TypeMeetError(this, that) + } + case TSizedInt(len1, sign1) => that match + { + case TSizedInt(len2, sign2) => + TSizedInt((len1 ⋁ len2).asInstanceOf[TBitWidth], (sign1 ⋁ sign2).asInstanceOf[TSignedNess]) + case TNamedType(name) => this + case _ => throw TypeMeetError(this, that) + } + case TFun(args1, ret1) => that match { + case TFun(args2, ret2) => + TFun(args1.zip(args2).map(t1t2 => t1t2._1 ⋁ t1t2._2), ret1 ⋁ ret2) + case _ => throw TypeMeetError(this, that) + } + case TRecType(name, fields) => if (this eq that) this else throw TypeMeetError(this, that) + case TMemType(elem, addrSize, readLatency, writeLatency, readPorts, writePorts) => + if(this eq that) this else throw TypeMeetError(this, that) + case TModType(inputs, refs, retType, name) => + if(this eq that) this else throw TypeMeetError(this, that) + case TLockedMemType(mem, idSz, limpl) => + if(this eq that) this else throw TypeMeetError(this, that) + case TRequestHandle(mod, rtyp) => + if(this eq that) this else throw TypeMeetError(this, that) + case TNamedType(name) => + that match { + case TNamedType(name2) => if (name.v eq name2.v) this else throw TypeMeetError(this, that) + case _ => that + } + case TMaybe(btyp) => that match { + case TMaybe(btyp2) => TMaybe(btyp ⋁ btyp2) + case _ => throw TypeMeetError(this, that) + } + case width: TBitWidth => that match { + case w2 :TBitWidth => (width, w2) match { + case (TBitWidthLen(l1), TBitWidthLen(l2)) => TBitWidthLen(Math.max(l1, l2)) + case (w1, w2) if w1 === w2 => w1 + case _ => TBitWidthMax(width, w2) + } + case _ => throw TypeMeetError(this, that) + } + case _ => if (this.getClass eq that.getClass) this else throw TypeMeetError(this, that) + } + def ===(that :Any) :Boolean = + { + that match { + case that :Type => Subtypes.areEqual(this, that) + case _ => throw new IllegalArgumentException + } + } + def =!=(that :Any) :Boolean = !(this === that) + + def <<=(that :Any) :Boolean = that match { + case that :Type => Subtypes.isSubtype(this, that) + case _ => throw new IllegalArgumentException + } + + def >>=(that :Any) :Boolean = that match { + case that :Type => Subtypes.isSubtype(that, this) + case _ => throw new IllegalArgumentException + } + + def meet(that :Type) :Type = ⋁(that) } // Types that can be upcast to Ints sealed trait IntType - trait TSignedNess extends Type + sealed trait TSignedNess extends Type { def signed() :Boolean = this match { @@ -160,12 +238,34 @@ object Syntax { case class TUnsigned() extends TSignedNess case class TSignVar(id :Id) extends TSignedNess case class TSizedInt(len: TBitWidth, sign: TSignedNess) extends Type with IntType + { + override def setPos(newpos: Position): TSizedInt.this.type = + { + sign.setPos(newpos) + super.setPos(newpos) + } + } // Use case class instead of case object to get unique positions case class TString() extends Type case class TVoid() extends Type case class TBool() extends Type case class TFun(args: List[Type], ret: Type) extends Type + { + override def setPos(newpos: Position): TFun.this.type = + { + args.foreach(a => a.setPos(newpos)) + ret.setPos(newpos) + super.setPos(newpos) + } + } case class TRecType(name: Id, fields: Map[Id, Type]) extends Type + { + override def setPos(newpos :Position) :TRecType.this.type = + { + fields.foreach(idtp => idtp._2.setPos(newpos)) + super.setPos(newpos) + } + } case class TMemType(elem: Type, addrSize: Int, readLatency: Latency = Latency.Asynchronous, @@ -217,6 +317,7 @@ object Syntax { def NegOp(): NumUOp = NumUOp("-") def NotOp(): BoolUOp = BoolUOp("!") + def InvOp(): BitUOp = BitUOp("~") def MagOp(): NumUOp = NumUOp("abs") def SignOp(): NumUOp = NumUOp("signum") def AndOp(e1: Expr,e2: Expr): EBinop = EBinop(BoolOp("&&", OpConstructor.and), e1,e2) @@ -290,6 +391,9 @@ object Syntax { case class ECall(mod: Id, args: List[Expr]) extends Expr case class EVar(id: Id) extends Expr case class ECast(ctyp: Type, exp: Expr) extends Expr + { + typ = Some(ctyp) + } sealed trait Command extends Positional with SMTPredicate with PortAnnotation with HasCopyMeta diff --git a/src/main/scala/pipedsl/common/Utilities.scala b/src/main/scala/pipedsl/common/Utilities.scala index bf1fe8d5..05eded34 100644 --- a/src/main/scala/pipedsl/common/Utilities.scala +++ b/src/main/scala/pipedsl/common/Utilities.scala @@ -3,7 +3,7 @@ package pipedsl.common import com.microsoft.z3.{AST => Z3AST, BoolExpr => Z3BoolExpr, Context => Z3Context} import com.sun.org.apache.xpath.internal.Expression import pipedsl.common.DAGSyntax.PStage -import pipedsl.common.Errors.UnexpectedCommand +import pipedsl.common.Errors.{LackOfConstraints, UnexpectedCommand} import pipedsl.common.Syntax._ import scala.annotation.tailrec @@ -314,85 +314,150 @@ object Utilities { case None => None } - private def typeMapExpr(e :Expr, f_opt : Option[Type] => Option[Type]) : Unit = + /** + * Maps the function f_opt over the types of e1. + * MODIFIES THE TYPE OF e1 + * @param e1 the expression to map over + * @param f_opt the function to apply to the types + * @return the expression with new types + */ + private def typeMapExpr(e1 :Expr, f_opt : Option[Type] => Option[Type]) : Expr = { - e.typ = f_opt(e.typ) - e match + try + { + e1.typ = f_opt(e1.typ) + } catch + { + case _ :scala.MatchError => + e1 match { + case EInt(v, _, _) => +// println(s"DEFAULTING ON INT. CURRENTLY: ${e1.typ}") + val sign :TSignedNess = if(e1.typ.isDefined && e1.typ.get.isInstanceOf[TSizedInt]) + if(e1.typ.get.asInstanceOf[TSizedInt].sign == TSigned() || e1.typ.get.asInstanceOf[TSizedInt].sign == TUnsigned()) + e1.typ.get.asInstanceOf[TSizedInt].sign + else TSigned() + else TSigned() + e1.typ = Some(TSizedInt(TBitWidthLen(log2(v)), sign)) + case _ => throw LackOfConstraints(e1) + } + } + e1 match { - case EIsValid(ex) => typeMapExpr(ex, f_opt) - case EFromMaybe(ex) => typeMapExpr(ex, f_opt) - case EToMaybe(ex) => typeMapExpr(ex, f_opt) - case EUop(_, ex) => typeMapExpr(ex, f_opt) - case EBinop(_, e1, e2) => typeMapExpr(e1, f_opt); typeMapExpr(e2, f_opt) - case ERecAccess(rec, fieldName) => typeMapId(fieldName, f_opt); typeMapExpr(rec, f_opt); - case ERecLiteral(fields) => - fields.foreach(idex => {typeMapId(idex._1, f_opt); typeMapExpr(idex._2, f_opt)}) - case EMemAccess(mem, index, wmask) => - typeMapId(mem, f_opt); typeMapExpr(index, f_opt) - wmask.fold(())(typeMapExpr(_, f_opt)) - case EBitExtract(num, _, _) => typeMapExpr(num, f_opt) - case ETernary(cond, tval, fval) => - typeMapExpr(cond, f_opt); typeMapExpr(tval, f_opt); typeMapExpr(fval, f_opt) - case EApp(func, args) => - typeMapId(func, f_opt); args.foreach(typeMapExpr(_, f_opt)) - case ECall(mod, args) => - typeMapId(mod, f_opt); args.foreach(typeMapExpr(_, f_opt)) - case EVar(id) => typeMapId(id, f_opt) - case ECast(_, exp) => typeMapExpr(exp, f_opt) + case e@EInt(v, base, bits) => +// println(s"bits: $bits;\ttype: ${e.typ}\t$e") + if(e.typ.isEmpty) + e.typ = Some(TSizedInt(TBitWidthLen(log2(v)), TSigned())) + e.copy(bits = e.typ.get.asInstanceOf[TSizedInt].len.asInstanceOf[TBitWidthLen].len).copyMeta(e) + case e@EIsValid(ex) => e.copy(ex = typeMapExpr(ex, f_opt)).copyMeta(e) + case e@EFromMaybe(ex) => e.copy(ex = typeMapExpr(ex, f_opt)).copyMeta(e) + case e@EToMaybe(ex) => e.copy(ex = typeMapExpr(ex, f_opt)).copyMeta(e) + case e@EUop(_, ex) => e.copy(ex = typeMapExpr(ex, f_opt)).copyMeta(e) + case e@EBinop(_, e1, e2) => + e.copy(e1 = typeMapExpr(e1, f_opt), + e2 = typeMapExpr(e2, f_opt)).copyMeta(e) + case e@ERecAccess(rec, fieldName) => + e.copy(fieldName = typeMapId(fieldName, f_opt), rec = typeMapExpr(rec, f_opt)).copyMeta(e) + case e@ERecLiteral(fields) => + e.copy(fields = fields.map(idex => typeMapId(idex._1, f_opt) -> typeMapExpr(idex._2, f_opt))).copyMeta(e) + case e@EMemAccess(mem, index, wmask) => + e.copy(mem = typeMapId(mem, f_opt), index = typeMapExpr(index, f_opt), + wmask = opt_func(typeMapExpr(_, f_opt))(wmask)).copyMeta(e) + case e@EBitExtract(num, _, _) => e.copy(num = typeMapExpr(num, f_opt)).copyMeta(e) + case e@ETernary(cond, tval, fval) => + e.copy(cond = typeMapExpr(cond, f_opt), + tval = typeMapExpr(tval, f_opt), + fval = typeMapExpr(fval, f_opt)).copyMeta(e) + case e@EApp(func, args) => + e.copy(func = typeMapId(func, f_opt), args = args.map(typeMapExpr(_, f_opt))).copyMeta(e) + case e@ECall(mod, args) => + e.copy(mod = typeMapId(mod, f_opt), args = args.map(typeMapExpr(_, f_opt))).copyMeta(e) + case e@EVar(id) => e.copy(id = typeMapId(id, f_opt)).copyMeta(e) + case e@ECast(tp, exp) => + val ntp = f_opt(Some(tp)).get + val tmp = e.copy(ctyp = ntp, exp = typeMapExpr(exp, f_opt)).copyMeta(e) +// println(s"setting $e type to $ntp") + tmp.typ = Some(ntp) + tmp case expr: CirExpr => expr match { - case CirLock(mem, _, _) => typeMapId(mem, f_opt) - case CirNew(mod, mods) => typeMapId(mod, f_opt) - mods.foreach((i: Id) => typeMapId(i, f_opt)) - case CirCall(mod, args) => typeMapId(mod, f_opt) - args.foreach(typeMapExpr(_, f_opt)) - case _ => () + case e@CirLock(mem, _, _) => e.copy(mem = typeMapId(mem, f_opt)).copyMeta(e) + case e@CirNew(mod, mods) => e.copy(mod = typeMapId(mod, f_opt), + mods = mods.map((i: Id) => typeMapId(i, f_opt))).copyMeta(e) + case e@CirCall(mod, args) => e.copy(mod = typeMapId(mod, f_opt), + args = args.map(typeMapExpr(_, f_opt))).copyMeta(e) + case _ => e1 } - case _ => () + case _ => e1 } } - private def typeMapCmd(c :Command, f_opt :Option[Type] => Option[Type]) :Unit = + + /** + * maps a function over the types of a command + * CHANGED THE TYPES OF THE ORIGINAL COMMAND + * @param c1 the command to map over + * @param f_opt the function from types to types + * @return a new command with the types mapped + */ + private def typeMapCmd(c1 :Command, f_opt :Option[Type] => Option[Type]) :Command = { - c match + c1 match { - case CSeq(c1, c2) => typeMapCmd(c1, f_opt); typeMapCmd(c2, f_opt) - case CTBar(c1, c2) => typeMapCmd(c1, f_opt); typeMapCmd(c2, f_opt) - case CIf(cond, cons, alt) => typeMapExpr(cond, f_opt); typeMapCmd(cons, f_opt); typeMapCmd(alt, f_opt) - case CAssign(lhs, rhs, _) => typeMapExpr(lhs, f_opt); typeMapExpr(rhs, f_opt) - case CRecv(lhs, rhs, _) => typeMapExpr(lhs, f_opt); typeMapExpr(rhs, f_opt) - case CSpecCall(handle, pipe, args) => typeMapExpr(handle, f_opt); typeMapId(pipe, f_opt); args.foreach(typeMapExpr(_, f_opt)) - case CVerify(handle, args, preds) => typeMapExpr(handle, f_opt); args.foreach(typeMapExpr(_, f_opt)); preds.foreach(typeMapExpr(_, f_opt)) - case CInvalidate(handle) => typeMapExpr(handle, f_opt) - case CPrint(args) => args.foreach(typeMapExpr(_, f_opt)) - case COutput(exp) => typeMapExpr(exp, f_opt) - case CReturn(exp) => typeMapExpr(exp, f_opt) - case CExpr(exp) => typeMapExpr(exp, f_opt) - case CLockStart(mod) => typeMapId(mod, f_opt) - case CLockEnd(mod) => typeMapId(mod, f_opt) - case CLockOp(mem, _, _) => typeMapId(mem.id, f_opt); - mem.evar match - { - case Some(value) => typeMapExpr(value, f_opt) - case None => () - } - case CSplit(cases, default) => cases.foreach(cs => - { - typeMapExpr(cs.cond, f_opt); typeMapCmd(cs.body, f_opt) - }) - typeMapCmd(default, f_opt) - case _ => () + case c@CSeq(c1, c2) => c.copy(c1 = typeMapCmd(c1, f_opt), c2 = typeMapCmd(c2, f_opt)).copyMeta(c) + case c@CTBar(c1, c2) => c.copy(c1 = typeMapCmd(c1, f_opt), c2 = typeMapCmd(c2, f_opt)).copyMeta(c) + case c@CIf(cond, cons, alt) => + c.copy(cond = typeMapExpr(cond, f_opt), + cons = typeMapCmd(cons, f_opt), + alt = typeMapCmd(alt, f_opt)).copyMeta(c) + case c@CAssign(lhs, rhs, _) => + c.copy(lhs = typeMapExpr(lhs, f_opt).asInstanceOf[EVar], rhs = typeMapExpr(rhs, f_opt)).copyMeta(c) + case c@CRecv(lhs, rhs, _) => + c.copy(lhs = typeMapExpr(lhs, f_opt), rhs = typeMapExpr(rhs, f_opt)).copyMeta(c) + case c@CSpecCall(handle, pipe, args) => + c.copy(handle = typeMapExpr(handle, f_opt).asInstanceOf[EVar], + pipe = typeMapId(pipe, f_opt), + args = args.map(typeMapExpr(_, f_opt))).copyMeta(c) + case c@CVerify(handle, args, preds) => + c.copy(handle = typeMapExpr(handle, f_opt).asInstanceOf[EVar], + args = args.map(typeMapExpr(_, f_opt)), + preds = preds.map(typeMapExpr(_, f_opt))).copyMeta(c) + case c@CInvalidate(handle) => c.copy(typeMapExpr(handle, f_opt).asInstanceOf[EVar]).copyMeta(c) + case c@CPrint(args) => c.copy(args = args.map(typeMapExpr(_, f_opt))).copyMeta(c) + case c@COutput(exp) => c.copy(exp = typeMapExpr(exp, f_opt)).copyMeta(c) + case c@CReturn(exp) => c.copy(exp = typeMapExpr(exp, f_opt)).copyMeta(c) + case c@CExpr(exp) => c.copy(exp = typeMapExpr(exp, f_opt)).copyMeta(c) + case c@CLockStart(mod) => c.copy(mod = typeMapId(mod, f_opt)).copyMeta(c) + case c@CLockEnd(mod) => c.copy(mod = typeMapId(mod, f_opt)).copyMeta(c) + case c@CLockOp(mem@LockArg(id, evar), _, _) => + c.copy(mem = mem.copy(id = typeMapId(id, f_opt), evar = evar match { + case Some(e) => Some(typeMapExpr(e, f_opt).asInstanceOf[EVar]) + case None => None + })).copyMeta(c) + case c@CSplit(cases, default) => + c.copy(cases = cases.map(cs => cs.copy(cond = typeMapExpr(cs.cond, f_opt), body = typeMapCmd(cs.body, f_opt))), + default = typeMapCmd(default, f_opt)).copyMeta(c) + case _ => c1 } } - private def typeMapId(i: Id, f_opt: Option[Type] => Option[Type]):Unit = + /** + * Maps a function over the type of an Id + * @param i the id to map over + * @param f_opt the function to apply to i.typ + * @return a COPY of i with a (potentially) new type + */ + private def typeMapId(i: Id, f_opt: Option[Type] => Option[Type]) :Id = { - i.typ = f_opt(i.typ) + val ni = i.copy() + ni.typ = f_opt(i.typ) + ni } - def typeMapFunc(fun :FuncDef, f_opt :Option[Type] => Option[Type]) :Unit = typeMapCmd(fun.body, f_opt) - def typeMapModule(mod :ModuleDef, f_opt :Option[Type] => Option[Type]) :Unit = typeMapCmd(mod.body, f_opt) + def typeMapFunc(fun :FuncDef, f_opt :Option[Type] => Option[Type]) :FuncDef = + fun.copy(body = typeMapCmd(fun.body, f_opt)) + def typeMapModule(mod :ModuleDef, f_opt :Option[Type] => Option[Type]) :ModuleDef = + mod.copy(body = typeMapCmd(mod.body, f_opt)).copyMeta(mod) def typeMap(p: Prog, f: Type => Type) :Unit= { diff --git a/src/main/scala/pipedsl/typechecker/BaseTypeChecker.scala b/src/main/scala/pipedsl/typechecker/BaseTypeChecker.scala index 0c581965..a09225f0 100644 --- a/src/main/scala/pipedsl/typechecker/BaseTypeChecker.scala +++ b/src/main/scala/pipedsl/typechecker/BaseTypeChecker.scala @@ -302,7 +302,7 @@ object BaseTypeChecker extends TypeChecks[Id, Type] { val (idxt, _) = checkExpression(mem.evar.get, tenv, None) idxt match { case TSizedInt(l, TUnsigned()/*true*/) if l.asInstanceOf[TBitWidthLen].len == memt.addrSize => tenv - case _ => throw UnexpectedType(mem.pos, "lock operation", "ubit<" + memt.addrSize + ">", idxt) + case _ => throw UnexpectedType(mem.pos, s"lock operation $c", "ubit<" + memt.addrSize + ">", idxt) } } } @@ -521,7 +521,7 @@ object BaseTypeChecker extends TypeChecks[Id, Type] { id.typ = Some(t) (t, tenv) } else { - throw UnexpectedType(id.pos, "variable", "variable type set to new conflicting type", t) + throw UnexpectedType(id.pos, "variable", s"variable type set to new conflicting type : ${tenv(id)}", t) } case None if (tenv.get(id).isDefined || defaultType.isEmpty) => id.typ = Some(tenv(id)); (tenv(id), tenv) case None => id.typ = defaultType; (defaultType.get, tenv) diff --git a/src/main/scala/pipedsl/typechecker/LockConstraintChecker.scala b/src/main/scala/pipedsl/typechecker/LockConstraintChecker.scala index 4db6bd10..c3046990 100644 --- a/src/main/scala/pipedsl/typechecker/LockConstraintChecker.scala +++ b/src/main/scala/pipedsl/typechecker/LockConstraintChecker.scala @@ -246,7 +246,7 @@ class LockConstraintChecker(lockMap: Map[Id, Set[LockArg]], lockGranularityMap: //The locks are still in Read mode, so if it is possible to be in //Write mode, it is error. val check = solver.check() - solver.reset + solver.reset() check } diff --git a/src/main/scala/pipedsl/typechecker/TypeInferenceWrapper.scala b/src/main/scala/pipedsl/typechecker/TypeInferenceWrapper.scala index db2045de..a9694382 100644 --- a/src/main/scala/pipedsl/typechecker/TypeInferenceWrapper.scala +++ b/src/main/scala/pipedsl/typechecker/TypeInferenceWrapper.scala @@ -1,131 +1,99 @@ package pipedsl.typechecker -import pipedsl.common.Errors.{ArgLengthMismatch, MalformedLockTypes, TooManyPorts, UnexpectedReturn, UnexpectedSubtype, UnexpectedType, UnificationError} -import pipedsl.common.Syntax._ -import pipedsl.typechecker.Subtypes.{areEqual, canCast, isSubtype} import pipedsl.common.Errors +import pipedsl.common.Errors._ import pipedsl.common.Syntax.Latency.{Asynchronous, Combinational, Sequential} +import pipedsl.common.Syntax._ import pipedsl.common.Utilities.{defaultReadPorts, defaultWritePorts, opt_func, typeMapFunc, typeMapModule} import pipedsl.typechecker.Environments.{Environment, TypeEnv} - +import pipedsl.typechecker.Subtypes.{canCast, isSubtype} import scala.collection.mutable object TypeInferenceWrapper { type Subst = List[(Id, Type)] type bool = Boolean + def apply_subst_typ(subst: Subst, t: Type): Type = subst.foldLeft[Type](t)((t1, s) => subst_into_type(s._1, s._2, t1)) private def subst_into_type(typevar: Id, toType: Type, inType: Type): Type = inType match { - case t@TMemType(elem, _, _, _, _, _) => - t.copy(elem = subst_into_type(typevar, toType, elem)).setPos(t.pos) - case t1@TLockedMemType(t2@TMemType(elem, _, _, _, _, _), _, _) => - t1.copy(t2.copy(elem = subst_into_type(typevar, toType, elem))).setPos(t1.pos) - case TSizedInt(len, signedness) => - TSizedInt(subst_into_type(typevar, toType, len).asInstanceOf[TBitWidth], - subst_into_type(typevar, toType, signedness).asInstanceOf[TSignedNess]).setPos(inType.pos) + case t: TMemType => t.copy(elem = subst_into_type(typevar, toType, t.elem)).setPos(t.pos) + case t1@TLockedMemType(t2: TMemType, _, _) => t1.copy(t2.copy(elem = subst_into_type(typevar, toType, t2.elem))).setPos(t1.pos) + case TSizedInt(len, signedness) => TSizedInt(subst_into_type(typevar, toType, len).asInstanceOf[TBitWidth], subst_into_type(typevar, toType, signedness).asInstanceOf[TSignedNess]).setPos(inType.pos) case TString() => inType case TBool() => inType case TVoid() => inType case TSigned() => inType case TUnsigned() => inType - case TFun(args, ret) => - TFun(args.map(a => subst_into_type(typevar, toType, a)), - subst_into_type(typevar, toType, ret)).setPos(inType.pos) - case TNamedType(name) => - if (name == typevar) toType else inType - case TSignVar(name) => - if (name == typevar) toType else inType - case TModType(inputs, refs, retType, name) => - TModType(inputs.map(i => subst_into_type(typevar, toType, i)), - refs.map(r => subst_into_type(typevar, toType, r)), retType match + case TFun(args, ret) => TFun(args.map(a => subst_into_type(typevar, toType, a)), subst_into_type(typevar, toType, ret)).setPos(inType.pos) + case TNamedType(name) => if (name == typevar) toType else inType + case TSignVar(name) => if (name == typevar) toType else inType + case TModType(inputs, refs, retType, name) => TModType(inputs.map(i => subst_into_type(typevar, toType, i)), refs.map(r => subst_into_type(typevar, toType, r)), retType match { case Some(value) => Some(subst_into_type(typevar, toType, value)) case None => None }, name).setPos(inType.pos) - case t: TBitWidth => - t match + case t: TBitWidth => t match { case TBitWidthVar(name) => if (name == typevar) toType else inType - case TBitWidthLen(len) => inType + case TBitWidthLen(_) => inType case TBitWidthAdd(b1, b2) => val t1 = TBitWidthAdd(subst_into_type(typevar, toType, b1).asInstanceOf[TBitWidth], subst_into_type(typevar, toType, b2).asInstanceOf[TBitWidth]) (t1.b1, t1.b2) match { case (TBitWidthLen(len1), TBitWidthLen(len2)) => TBitWidthLen(len1 + len2).setPos(inType.pos) case _ => t1.setPos(inType.pos) } - case TBitWidthMax(b1, b2) => - val t1 = TBitWidthMax(subst_into_type(typevar, toType, b1).asInstanceOf[TBitWidth], subst_into_type(typevar, toType, b2).asInstanceOf[TBitWidth]) + case TBitWidthMax(b1, b2) => val t1 = TBitWidthMax(subst_into_type(typevar, toType, b1).asInstanceOf[TBitWidth], subst_into_type(typevar, toType, b2).asInstanceOf[TBitWidth]) (t1.b1, t1.b2) match { case (TBitWidthLen(len1), TBitWidthLen(len2)) => TBitWidthLen(len1.max(len2)).setPos(inType.pos) - case (TBitWidthLen(len), _ :TBitWidthVar) => TBitWidthLen(len).setPos(inType.pos) - case (_ :TBitWidthVar, TBitWidthLen(len)) => TBitWidthLen(len).setPos(inType.pos) + case (TBitWidthLen(len), _: TBitWidthVar) => TBitWidthLen(len).setPos(inType.pos) + case (_: TBitWidthVar, TBitWidthLen(len)) => TBitWidthLen(len).setPos(inType.pos) + case (TBitWidthVar(v1), TBitWidthVar(v2)) if v1.v == v2.v => TBitWidthVar(v1) case _ => t1.setPos(inType.pos) } } } - - private def type_subst_map(t :Type, tp_mp :mutable.HashMap[Id, Type]) :Type = t match + private def type_subst_map(t: Type, tp_mp: mutable.HashMap[Id, Type]): Type = t match { case TSignVar(nm) => tp_mp.get(nm) match { case Some(value) => type_subst_map(value, tp_mp) - case None => t } - case sz@TSizedInt(len, sign) => -// println(len) -// println(sign) - val tmp = sz.copy(len = type_subst_map(len, tp_mp).copyMeta(sz).asInstanceOf[TBitWidth], - - sign = type_subst_map(sign, tp_mp).asInstanceOf[TSignedNess]) -// println(tmp.len) -// println(tmp.sign) + case sz@TSizedInt(len, sign) => val tmp = sz.copy(len = type_subst_map(len, tp_mp).copyMeta(sz).asInstanceOf[TBitWidth], sign = type_subst_map(sign, tp_mp).asInstanceOf[TSignedNess]) tmp - case f@TFun(args, ret) => - f.copy(args = args.map(type_subst_map(_, tp_mp)), ret = type_subst_map(ret, tp_mp)).copyMeta(f) - case r@TRecType(name, fields) => r.copy(fields = fields.map((idtp) => (idtp._1, type_subst_map(idtp._2, tp_mp)))) - case m@TMemType(elem, addrSize, readLatency, writeLatency, readPorts, writePorts) => - m.copy(elem = type_subst_map(elem, tp_mp)) - case m@TModType(inputs, refs, retType, name) => - m.copy(inputs = inputs.map(type_subst_map(_, tp_mp)), refs = refs.map(type_subst_map(_, tp_mp))) - case l@TLockedMemType(mem, idSz, limpl) => - l.copy(mem = type_subst_map(mem, tp_mp).asInstanceOf[TMemType]) + case f@TFun(args, ret) => f.copy(args = args.map(type_subst_map(_, tp_mp)), ret = type_subst_map(ret, tp_mp)).copyMeta(f) + case r@TRecType(_, fields) => r.copy(fields = fields.map(idtp => (idtp._1, type_subst_map(idtp._2, tp_mp)))).copyMeta(r) + case m: TMemType => m.copy(elem = type_subst_map(m.elem, tp_mp)).copyMeta(m) + case m@TModType(inputs, refs, _, _) => m.copy(inputs = inputs.map(type_subst_map(_, tp_mp)), refs = refs.map(type_subst_map(_, tp_mp))).copyMeta(m) + case l@TLockedMemType(mem, _, _) => l.copy(mem = type_subst_map(mem, tp_mp).asInstanceOf[TMemType]).copyMeta(l) case TNamedType(name) => tp_mp.get(name) match { case Some(value) => type_subst_map(value, tp_mp) - case None => t } case m@TMaybe(btyp) => m.copy(btyp = type_subst_map(btyp, tp_mp)) - case TBitWidthAdd(b1, b2) => - val tmp = TBitWidthLen(type_subst_map(b1, tp_mp).asInstanceOf[TBitWidthLen].len + - type_subst_map(b2, tp_mp).asInstanceOf[TBitWidthLen].len) -// println(s"mapping $b1 + $b2 to ${tmp.len}") + case TBitWidthAdd(b1, b2) => val tmp = TBitWidthLen(type_subst_map(b1, tp_mp).asInstanceOf[TBitWidthLen].len + type_subst_map(b2, tp_mp).asInstanceOf[TBitWidthLen].len) tmp case TBitWidthMax(b1, b2) => - TBitWidthLen(Math.max(type_subst_map(b1, tp_mp).asInstanceOf[TBitWidthLen].len, - type_subst_map(b2, tp_mp).asInstanceOf[TBitWidthLen].len)) + + + TBitWidthLen(Math.max(type_subst_map(b1, tp_mp).asInstanceOf[TBitWidthLen].len, type_subst_map(b2, tp_mp).asInstanceOf[TBitWidthLen].len)) case TBitWidthVar(name) => tp_mp.get(name) match { case Some(value) => type_subst_map(value, tp_mp) - case None => t } case _ => t } - - - class TypeInference(autocast :bool) + class TypeInference(autocast: bool) { - private var currentDef: Id = Id("-invalid-") private var counter = 0 def checkProgram(p: Prog): Prog = { - val (funcEnvs, newFuncs) = - p.fdefs.foldLeft[(Environment[Id, Type], List[FuncDef])]((TypeEnv(), List.empty[FuncDef]))((envNlst :(Environment[Id, Type], List[FuncDef]), f) => + val (funcEnvs, newFuncs) = p.fdefs.foldLeft[(Environment[Id, Type], List[FuncDef])]((TypeEnv(), List.empty[FuncDef]))((envNlst: (Environment[Id, Type], List[FuncDef]), f) => { val env = envNlst._1 val lst = envNlst._2 @@ -147,77 +115,13 @@ object TypeInferenceWrapper def checkCircuit(c: Circuit, tenv: Environment[Id, Type]): (Environment[Id, Type], Circuit) = c match { - case cs@CirSeq(c1, c2) => - val (e1, nc1) = checkCircuit(c1, tenv) - val (e2, nc2) = checkCircuit(c2, e1) - (e2, cs.copy(c1 = nc1, c2 = nc2)) - case cc@CirConnect(name, ce) => - val (t, env2, nce) = checkCirExpr(ce, tenv) - (env2.add(name, t), cc.copy(c = nce)) - case ces@CirExprStmt(ce) => - val (_, nv, nce) = checkCirExpr(ce, tenv) - (nv, ces.copy(ce = nce)) - } - - private def checkCirExpr(c: CirExpr, tenv: Environment[Id, Type]): (Type, Environment[Id, Type], CirExpr) = c match - { - case CirMem(elemTyp, addrSize, numPorts) => - if (numPorts > 2) throw TooManyPorts(c.pos, 2) - val mtyp = TMemType(elemTyp, addrSize, Asynchronous, Asynchronous, numPorts, numPorts) - c.typ = Some(mtyp) - (mtyp, tenv, c) - case CirLock(mem, impl, _) => - val mtyp: TMemType = tenv(mem).matchOrError(mem.pos, "lock instantiation", "memory") - { case c: TMemType => c } - mem.typ = Some(mtyp) - val newtyp = TLockedMemType(mtyp, None, impl) - c.typ = Some(newtyp) - (newtyp, tenv, c) - case CirLockMem(elemTyp, addrSize, impl, _, numPorts) => - val mtyp = TMemType(elemTyp, addrSize, Asynchronous, Asynchronous, numPorts, numPorts) - val ltyp = TLockedMemType(mtyp, None, impl) - c.typ = Some(ltyp) - (ltyp, tenv, c) - case CirRegFile(elemTyp, addrSize) => val mtyp = TMemType(elemTyp, addrSize, Combinational, Sequential, defaultReadPorts, defaultWritePorts) - c.typ = Some(mtyp) - (mtyp, tenv, c) - case CirLockRegFile(elemTyp, addrSize, impl, szParams) => - val mtyp = TMemType(elemTyp, addrSize, Combinational, Sequential, defaultReadPorts, defaultWritePorts) - val idsz = szParams.headOption - val ltyp = TLockedMemType(mtyp, idsz, impl) - c.typ = Some(ltyp) - (ltyp, tenv, c) - case CirNew(mod, mods) => - val mtyp = tenv(mod) - mtyp match - { - case TModType(_, refs, _, _) => - if (refs.length != mods.length) - throw ArgLengthMismatch(c.pos, mods.length, refs.length) - refs.zip(mods).foreach - { case (reftyp, mname) => - if (!isSubtype(tenv(mname), reftyp)) - throw UnexpectedSubtype(mname.pos, mname.toString, reftyp, tenv(mname)) } - (mtyp, tenv, c) - case x => throw UnexpectedType(c.pos, c.toString, "Module Type", x) - } - case cc@CirCall(mod, inits) => - val mtyp = tenv(mod) - mtyp match - { - case TModType(ityps, _, _, _) => - if (ityps.length != inits.length) - throw ArgLengthMismatch(c.pos, inits.length, ityps.length) - val fixed_args = ityps.zip(inits).map - { case (expectedT, arg) => - val (subst, atyp, aenv, a_fixed) = infer(tenv.asInstanceOf[TypeEnv], arg) - if (!isSubtype(atyp, expectedT)) - throw UnexpectedSubtype(arg.pos, arg.toString, expectedT, atyp) - a_fixed - } - (mtyp, tenv, cc.copy(args = fixed_args)) - case x => throw UnexpectedType(c.pos, c.toString, "Module Type", x) - } + case cs@CirSeq(c1, c2) => val (e1, nc1) = checkCircuit(c1, tenv) + val (e2, nc2) = checkCircuit(c2, e1) + (e2, cs.copy(c1 = nc1, c2 = nc2).setPos(cs.pos)) + case cc@CirConnect(name, ce) => val (t, env2, nce) = checkCirExpr(ce, tenv) + (env2.add(name, t), cc.copy(c = nce).setPos(cc.pos)) + case ces@CirExprStmt(ce) => val (_, nv, nce) = checkCirExpr(ce, tenv) + (nv, ces.copy(ce = nce).setPos(ces.pos)) } def checkModule(m: ModuleDef, env: TypeEnv): (Environment[Id, Type], ModuleDef) = @@ -229,8 +133,8 @@ object TypeInferenceWrapper val pipeEnv = m.modules.zip(modTypes).foldLeft[Environment[Id, Type]](inEnv)((env, m) => env.add(m._1.name, m._2)) val (fixed_cmd, _, subst) = checkCommand(m.body, pipeEnv.asInstanceOf[TypeEnv], List()) val hash = mutable.HashMap.from(subst) - typeMapModule(m, opt_func(type_subst_map(_, hash))) - (modEnv, m.copy(body = fixed_cmd).copyMeta(m)) + val newMod = typeMapModule(m.copy(body = fixed_cmd).copyMeta(m), opt_func(type_subst_map(_, hash))) + (modEnv, newMod) } def checkFunc(f: FuncDef, env: TypeEnv): (Environment[Id, Type], FuncDef) = @@ -241,485 +145,557 @@ object TypeInferenceWrapper val inEnv = f.args.foldLeft[Environment[Id, Type]](funEnv)((env, a) => env.add(a.name, a.typ)) val (fixed_cmd, _, subst) = checkCommand(f.body, inEnv.asInstanceOf[TypeEnv], List()) val hash = mutable.HashMap.from(subst) - typeMapFunc(f, opt_func(type_subst_map(_, hash))) -// println("SUBSTITUTIONS:\n" + subst) - //TODO this is filler - (funEnv, f.copy(body = fixed_cmd).setPos(f.pos)) + val newFunc = typeMapFunc(f.copy(body = fixed_cmd).setPos(f.pos), opt_func(type_subst_map(_, hash))) + (funEnv, newFunc) } - private def replaceNamedType(t: Type, tenv: TypeEnv): Type = t match + /** INVARIANTS + * Transforms the argument sub by composing any additional substitution + * Transforms the argument env by subbing in the returned substitution and adding any relevatn variables */ + def checkCommand(c: Command, env: TypeEnv, sub: Subst): (Command, TypeEnv, Subst) = c match { - case TNamedType(name) => tenv(name) - case _ => t - } - - /*INVARIANTS - Transforms the argument sub by composing any additional substitution - Transforms the argument env by subbing in the returned substitution and adding any relevatn variables */ - def checkCommand(c: Command, env: TypeEnv, sub: Subst): (Command, TypeEnv, Subst) = + case CLockOp(mem, _, _) => env(mem.id) match { - /*println(c) - println(c.pos)*/ - c match + case tm: TMemType => mem.evar match { - case CLockOp(mem, op, lockType) => //test basic first - env(mem.id) match - { - case t@TMemType(elem, addrSize, readLatency, writeLatency, readPorts, writePorts) => mem.evar match - { - case Some(value) => val (s, t, e, _) = infer(env, value) - val tempSub = compose_subst(sub, s) - val tNew = apply_subst_typ(tempSub, t) - val newSub = compose_subst(tempSub, unify(tNew, TSizedInt(TBitWidthLen(addrSize), TUnsigned()))._1) - (c, e.apply_subst_typeenv(newSub), newSub) - case None => (c, env, sub) - } - case TLockedMemType(TMemType(_, addrSize, _, _, _, _), idSz, limpl) => mem.evar match - { - case Some(value) => val (s, t, e, _) = infer(env, value) - val tempSub = compose_subst(sub, s) - val tNew = apply_subst_typ(tempSub, t) - val newSub = compose_subst(tempSub, unify(tNew, TSizedInt(TBitWidthLen(addrSize), TUnsigned()/*unsigned = true*/))._1) - (c, e.apply_subst_typeenv(newSub), newSub) - case None => (c, env, sub) - } - case TModType(inputs, refs, retType, name) => if (mem.evar.isDefined) throw MalformedLockTypes("Pipeline modules can not have specific locks") - (c, env, sub) - case b => throw UnexpectedType(mem.id.pos, c.toString, "Memory or Module Type", b) - } - case CEmpty() => (c, env, sub) - case cr@CReturn(exp) => val (s, t, e, fixed) = infer(env, exp) + case Some(value) => val (s, t, e, _) = infer(env, value) val tempSub = compose_subst(sub, s) val tNew = apply_subst_typ(tempSub, t) - val funT = env(currentDef) - funT match - { - case TFun(args, ret) => - val (subst, cast) = unify(tNew, ret) - /*TODO insert cast*/ - val retSub = compose_subst(tempSub, subst) - (cr.copy(exp = fixed).copyMeta(cr), e.apply_subst_typeenv(retSub), retSub) - case b => throw UnexpectedType(c.pos, c.toString, funT.toString, b) - } - case CLockStart(mod) => if (!(env(mod).isInstanceOf[TMemType] || env(mod).isInstanceOf[TModType] || env(mod).isInstanceOf[TLockedMemType])) - { - throw UnexpectedType(mod.pos, c.toString, "Memory or Module Type", env(mod)) - } - (c, env, sub) - case i@CIf(cond, cons, alt) => - val (condS, condT, env1, fixed_cond) = infer(env, cond) - val tempSub = compose_subst(sub, condS) + val newSub = compose_subst(tempSub, unify(tNew, TSizedInt(TBitWidthLen(tm.addrSize), TUnsigned()))._1) + (c, e.apply_subst_typeenv(newSub), newSub) + case None => (c, env, sub) + } + case TLockedMemType(tm: TMemType, _, _) => mem.evar match + { + case Some(value) => val (s, t, e, _) = infer(env, value) + val tempSub = compose_subst(sub, s) + val tNew = apply_subst_typ(tempSub, t) + val newSub = compose_subst(tempSub, unify(tNew, TSizedInt(TBitWidthLen(tm.addrSize), TUnsigned()))._1) + (c, e.apply_subst_typeenv(newSub), newSub) + case None => (c, env, sub) + } + case _: TModType => if (mem.evar.isDefined) throw MalformedLockTypes("Pipeline modules can not have specific locks") + (c, env, sub) + case b => throw UnexpectedType(mem.id.pos, c.toString, "Memory or Module Type", b) + } + case CEmpty() => (c, env, sub) + case cr@CReturn(exp) => val (s, t, e, fixed) = infer(env, exp) + val tempSub = compose_subst(sub, s) + val tNew = apply_subst_typ(tempSub, t) + val funT = env(currentDef) + funT match + { + case TFun(_, ret) => val (subst, cast) = unify(tNew, ret) + val more_fixed = if (cast) + { + val tmp = ECast(ret, fixed) + tmp.typ = Some(tmp.ctyp) + tmp + } else fixed + val retSub = compose_subst(tempSub, subst) + (cr.copy(exp = more_fixed).copyMeta(cr), e.apply_subst_typeenv(retSub), retSub) + case b => throw UnexpectedType(c.pos, c.toString, funT.toString, b) + } + case CLockStart(mod) => if (!(env(mod).isInstanceOf[TMemType] || env(mod).isInstanceOf[TModType] || env(mod).isInstanceOf[TLockedMemType])) + { + throw UnexpectedType(mod.pos, c.toString, "Memory or Module Type", env(mod)) + } + (c, env, sub) + case i@CIf(cond, cons, alt) => val (condS, condT, env1, fixed_cond) = infer(env, cond) + val tempSub = compose_subst(sub, condS) + val condTyp = apply_subst_typ(tempSub, condT) + val newSub = compose_subst(tempSub, unify(condTyp, TBool())._1) + val newEnv = env1.apply_subst_typeenv(newSub) + val (fixed_cons, consEnv, consSub) = checkCommand(cons, newEnv, newSub) + val newEnv2 = newEnv.apply_subst_typeenv(consSub) + val (fixed_alt, altEnv, altSub) = checkCommand(alt, newEnv2, consSub) + (i.copy(cond = fixed_cond, cons = fixed_cons, alt = fixed_alt).copyMeta(i), consEnv.apply_subst_typeenv(altSub).intersect(altEnv).asInstanceOf[TypeEnv], altSub) + case CLockEnd(mod) => if (!(env(mod).isInstanceOf[TMemType] || env(mod).isInstanceOf[TModType] || env(mod).isInstanceOf[TLockedMemType])) + { + throw UnexpectedType(mod.pos, c.toString, "Memory or Module Type", env(mod)) + } + (c, env, sub) + case cs@CSplit(cases, default) => var (fixed_def, runningEnv, runningSub) = checkCommand(default, env, sub) + var fixed_cases: List[CaseObj] = List() + for (c <- cases) + { + val (condS, condT, env1, fixed_cond) = infer(env, c.cond) + val tempSub = compose_subst(runningSub, condS) val condTyp = apply_subst_typ(tempSub, condT) val newSub = compose_subst(tempSub, unify(condTyp, TBool())._1) val newEnv = env1.apply_subst_typeenv(newSub) - val (fixed_cons, consEnv, consSub) = checkCommand(cons, newEnv, newSub) - val newEnv2 = newEnv.apply_subst_typeenv(consSub) - val (fixed_alt, altEnv, altSub) = checkCommand(alt, newEnv2, consSub) //TODO: Intersection of both envs? - (i.copy(cond = fixed_cond, cons = fixed_cons, alt = fixed_alt).copyMeta(i), - consEnv.apply_subst_typeenv(altSub).intersect(altEnv).asInstanceOf[TypeEnv], altSub) - case CLockEnd(mod) => if (!(env(mod).isInstanceOf[TMemType] || env(mod).isInstanceOf[TModType] || env(mod).isInstanceOf[TLockedMemType])) + val (fixed_bod, caseEnv, caseSub) = checkCommand(c.body, newEnv, newSub) + fixed_cases = fixed_cases :+ c.copy(cond = fixed_cond, body = fixed_bod).setPos(c.pos) + runningSub = caseSub + runningEnv = runningEnv.apply_subst_typeenv(runningSub).intersect(caseEnv).asInstanceOf[TypeEnv] + } + (cs.copy(cases = fixed_cases, default = fixed_def).copyMeta(cs), runningEnv, runningSub) + case ce@CExpr(exp) => val (s, _, e, fixed) = infer(env, exp) + val retS = compose_subst(sub, s) + (ce.copy(exp = fixed).copyMeta(ce), e.apply_subst_typeenv(retS), retS) /*TODO is this right? I don't know. Maybe someone else does :)*/ + case CCheckSpec(_) => (c, env, sub) + case _: CVerify => (c, env, sub) + case CInvalidate(_) => (c, env, sub) + case ct@CTBar(c1, c2) => val (fixed1, e, s) = checkCommand(c1, env, sub) + val (fixed2, e2, s2) = checkCommand(c2, e, s) + (ct.copy(c1 = fixed1, c2 = fixed2).copyMeta(ct), e2, s2) + case CPrint(_) => (c, env, sub) + case _: CSpecCall => (c, env, sub) + case co@COutput(exp) => val (s, t, e, fixed) = infer(env, exp) + val tempSub = compose_subst(sub, s) + val tNew = apply_subst_typ(tempSub, t) + val modT = env(currentDef) + modT match + { + case tm: TModType => tm.retType match + { + case Some(value) => val (subst, cast) = unify(tNew, value) + val fixed1 = if (cast) ECast(value, fixed) else fixed + val retSub = compose_subst(tempSub, subst) + (co.copy(exp = fixed1).copyMeta(co), e.apply_subst_typeenv(retSub), retSub) + case None => (co.copy(exp = fixed).copyMeta(co), e.apply_subst_typeenv(tempSub), tempSub) + } + case b => throw UnexpectedType(c.pos, c.toString, modT.toString, b) + } + case cr@CRecv(lhs, rhs, typ) => val (slhs, tlhs, lhsEnv, lhsFixed) = lhs match + { + case EVar(_) => (List(), typ.getOrElse(generateTypeVar()), env, lhs) + case _ => infer(env, lhs) + } + val (srhs, trhs, rhsEnv, rhsFixed) = infer(lhsEnv, rhs) + val tempSub = compose_many_subst(sub, slhs, srhs) + val lhstyp = apply_subst_typ(tempSub, tlhs) + val rhstyp = apply_subst_typ(tempSub, trhs) + lhs.typ = Some(lhstyp) + rhs.typ = Some(rhstyp) + val (s1, cast) = unify(rhstyp, lhstyp) + val rhsFixed1 = if (cast) ECast(lhstyp, rhsFixed) else rhsFixed + val sret = compose_many_subst(tempSub, s1, typ match + { case Some(value) => val (s2, _) = unify(lhstyp, value) + val (s3, _) = unify(rhstyp, value) + compose_subst(s2, s3) + case None => List() + }) + val newEnv = lhs match + { + case EVar(id) => rhsEnv.add(id, tlhs) + case _ => rhsEnv + } + (cr.copy(lhs = lhsFixed, rhs = rhsFixed1).copyMeta(cr), newEnv.asInstanceOf[TypeEnv].apply_subst_typeenv(sret), sret) + case ca@CAssign(lhs, rhs, typ) => val (slhs, tlhs, lhsEnv) = (List(), typ.getOrElse(generateTypeVar()), env) + val (srhs, trhs, rhsEnv, rhsFixed) = infer(lhsEnv, rhs) + val tempSub = compose_many_subst(sub, slhs, srhs) + val lhstyp = apply_subst_typ(tempSub, tlhs) + val rhstyp = apply_subst_typ(tempSub, trhs) + val (s1, cast) = unify(rhstyp, lhstyp) + val rhsFixed1 = if (cast) ECast(tlhs, rhsFixed) else rhsFixed + val sret = compose_many_subst(tempSub, s1, typ match + { case Some(value) => val (s2, _) = unify(lhstyp, value) + val (s3, _) = unify(rhstyp, value) + compose_subst(s2, s3) + case None => List() + }) + val newEnv = lhs match + { + case EVar(id) => rhsEnv.add(id, tlhs) + case _ => rhsEnv + } + lhs.typ = Some(lhstyp) + rhs.typ = Some(rhstyp) + (ca.copy(rhs = rhsFixed1).copyMeta(ca), newEnv.asInstanceOf[TypeEnv].apply_subst_typeenv(sret), sret) + case cs@CSeq(c1, c2) => val (fixed1, e1, s) = checkCommand(c1, env, sub) + val (fixed2, e2, s2) = checkCommand(c2, e1, s) + (cs.copy(c1 = fixed1, c2 = fixed2).copyMeta(cs), e2, s2) + case _: InternalCommand => (c, env, sub) + } + + /** for subtyping bit widths, t1 is the subtype, t2 is the supertype. so t2 is the expected private */ + def unify(a: Type, b: Type, binop: bool = false): (Subst, bool) = + { + (a, b) match + { + case (t1: TNamedType, t2) => if (!occursIn(t1.name, t2)) (List((t1.name, t2)), false) else (List(), false) + case (t1, t2: TNamedType) => if (!occursIn(t2.name, t1)) (List((t2.name, t1)), false) else (List(), false) + case (t1: TSignVar, t2: TSignedNess) => if (!occursIn(t1.id, t2)) (List((t1.id, t2)), false) else (List(), false) + case (t1: TSignedNess, t2: TSignVar) => if (!occursIn(t2.id, t1)) (List((t2.id, t1)), false) else (List(), false) + case (_: TString, _: TString) => (List(), false) + case (_: TBool, _: TBool) => (List(), false) + case (_: TVoid, _: TVoid) => (List(), false) + case (_: TSigned, _: TSigned) => (List(), false) + case (_: TUnsigned, _: TUnsigned) => (List(), false) + case (TBool(), TSizedInt(len, u)) if len.asInstanceOf[TBitWidthLen].len == 1 && u.unsigned() => (List(), false) + case (TSizedInt(len, u), TBool()) if len.asInstanceOf[TBitWidthLen].len == 1 && u.unsigned() => (List(), false) + case (TSizedInt(len1, signed1), TSizedInt(len2, signed2)) => val (s1, c1) = unify(len1, len2, binop) + val (s2, c2) = unify(signed1, signed2, binop) + (compose_subst(s1, s2), c1 || c2) + case (TFun(args1, ret1), TFun(args2, ret2)) if args1.length == args2.length => val (s1, c1) = args1.zip(args2).foldLeft[(Subst, bool)]((List(), false))((sc, t) => { - throw UnexpectedType(mod.pos, c.toString, "Memory or Module Type", env(mod)) - } - (c, env, sub) - case cs@CSplit(cases, default) => //TODO - var (fixed_def, runningEnv, runningSub) = checkCommand(default, env, sub) - var fixed_cases : List[CaseObj] = List() - for (c <- cases) + val (unif_s, unif_c) = unify(apply_subst_typ(sc._1, t._1), apply_subst_typ(sc._1, t._2), binop) + (compose_subst(sc._1, unif_s), unif_c || sc._2) + }) + val (s2, c2) = unify(apply_subst_typ(s1, ret1), apply_subst_typ(s1, ret2), binop) + (compose_subst(s1, s2), c1 || c2) + case (TModType(input1, refs1, retType1, name1), TModType(input2, refs2, retType2, name2)) => //TODO: Name?\ if (name1 != name2) throw UnificationError(a, b) + if (name1 != name2) throw UnificationError(a, b) + val (s1, c1) = input1.zip(input2).foldLeft[(Subst, bool)]((List(), false))((sc, t) => { - val (condS, condT, env1, fixed_cond) = infer(env, c.cond) - val tempSub = compose_subst(runningSub, condS) - val condTyp = apply_subst_typ(tempSub, condT) - val newSub = compose_subst(tempSub, unify(condTyp, TBool())._1) - /*apply substitution to original environmnet, which you will use to check the body*/ - val newEnv = env1.apply_subst_typeenv(newSub) - val (fixed_bod, caseEnv, caseSub) = checkCommand(c.body, newEnv, newSub) - fixed_cases = fixed_cases :+ c.copy(cond = fixed_cond, body = fixed_bod).setPos(c.pos) - runningSub = caseSub - runningEnv = runningEnv.apply_subst_typeenv(runningSub).intersect(caseEnv).asInstanceOf[TypeEnv] - } - (cs.copy(cases = fixed_cases, default = fixed_def).copyMeta(cs), runningEnv, runningSub) - case ce@CExpr(exp) => val (s, t, e, fixed) = infer(env, exp) - val retS = compose_subst(sub, s) - (ce.copy(exp = fixed).copyMeta(ce), e.apply_subst_typeenv(retS), retS) //TODO - /*TODO is this right? I don't know. Maybe someone else does :)*/ - case CCheckSpec(isBlocking) => (c, env, sub) - case CVerify(handle, args, preds) => (c, env, sub) - case CInvalidate(handle) => (c, env, sub) /*ODOT*/ - case ct@CTBar(c1, c2) => val (fixed1, e, s) = checkCommand(c1, env, sub) - val (fixed2, e2, s2) = checkCommand(c2, e, s) - (ct.copy(c1 = fixed1, c2 = fixed2).copyMeta(ct), e2, s2) - case CPrint(evar) => (c, env, sub) - case CSpecCall(handle, pipe, args) => (c, env, sub) - case co@COutput(exp) => - val (s, t, e, fixed) = infer(env, exp) - val tempSub = compose_subst(sub, s) - val tNew = apply_subst_typ(tempSub, t) - val modT = env(currentDef) - modT match - { - case TModType(inputs, refs, retType, name) => retType match + val (unif_s, unif_c) = unify(apply_subst_typ(sc._1, t._1), apply_subst_typ(sc._1, t._2)) + (compose_subst(sc._1, unif_s), unif_c || sc._2) + }) + val (s2, c2) = refs1.zip(refs2).foldLeft[(Subst, bool)](s1, c1)((sc, t) => { - case Some(value) => - val (subst, cast) = unify(tNew, value) - /*TODO insert cast*/ - val retSub = compose_subst(tempSub, subst) - (co.copy(exp = fixed).copyMeta(co), e.apply_subst_typeenv(retSub), retSub) - case None => (co.copy(exp = fixed).copyMeta(co), e.apply_subst_typeenv(tempSub), tempSub) - } - case b => throw UnexpectedType(c.pos, c.toString, modT.toString, b) - } //How to check wellformedness with the module body - case cr@CRecv(lhs, rhs, typ) => - val (slhs, tlhs, lhsEnv, lhsFixed) = lhs match + val (unif_s, unif_c) = unify(apply_subst_typ(sc._1, t._1), apply_subst_typ(sc._1, t._2)) + (compose_subst(sc._1, unif_s), sc._2 || unif_c) + }) + val (s3, c3) = (retType1, retType2) match { - case EVar(id) => (List(), typ.getOrElse(generateTypeVar()), env, lhs) - case _ => infer(env, lhs) + case (Some(t1: Type), Some(t2: Type)) => unify(apply_subst_typ(s2, t1), apply_subst_typ(s2, t2)) + case (None, None) => (List(), false) + case _ => throw UnificationError(a, b) } - val (srhs, trhs, rhsEnv, rhsFixed) = infer(lhsEnv, rhs) - val tempSub = compose_many_subst(sub, slhs, srhs) - val lhstyp = apply_subst_typ(tempSub, tlhs) - val rhstyp = apply_subst_typ(tempSub, trhs) - lhs.typ = Some(lhstyp) - rhs.typ = Some(rhstyp) - val (s1, cast) = unify(rhstyp, lhstyp) - /*TODO insert cast*/ - val sret = compose_many_subst(tempSub, s1, typ match - { case Some(value) => - val (s2, c2) = unify(lhstyp, value) - val (s3, c3) = unify(rhstyp, value) - /*TODO insert cast*/ - compose_subst(s2, s3) - case None => List() - }) - val newEnv = lhs match + (compose_subst(s2, s3), c2 || c3) + case (TMemType(elem1, addr1, rl1, wl1, rp1, wp1), TMemType(elem2, addr2, rl2, wl2, rp2, wp2)) => if (addr1 != addr2 || rl1 != rl2 || wl1 != wl2 || rp1 < rp2 || wp1 < wp2) throw UnificationError(a, b) + unify(elem1, elem2) + case (t1: TBitWidthVar, t2: TBitWidth) => if (!occursIn(t1.name, t2)) (List((t1.name, t2)), false) else (List(), false) + case (t1: TBitWidth, t2: TBitWidthVar) => if (!occursIn(t2.name, t1)) (List((t2.name, t1)), false) else (List(), false) + case (TBitWidthAdd(a1: TBitWidthLen, a2), TBitWidthLen(len)) => unify(a2, TBitWidthLen(len - a1.len), binop) + case (TBitWidthAdd(a2, a1: TBitWidthLen), TBitWidthLen(len)) => unify(a2, TBitWidthLen(len - a1.len), binop) + case (TBitWidthLen(len), TBitWidthAdd(a1: TBitWidthLen, a2)) => unify(a2, TBitWidthLen(len - a1.len), binop) + case (TBitWidthLen(len), TBitWidthAdd(a2, a1: TBitWidthLen)) => unify(a2, TBitWidthLen(len - a1.len), binop) + case (TBitWidthAdd(a1: TBitWidthVar, a2: TBitWidthVar), TBitWidthLen(len)) => if (len % 2 != 0) throw new RuntimeException(s"result of bitwidthadd should be even. Consider multiply. Found $len") + val (s1, c1) = unify(a1, TBitWidthLen(len / 2), binop) + val (s2, c2) = unify(a2, TBitWidthLen(len / 2), binop) + (compose_subst(s1, s2), c1 || c2) + case (TBitWidthLen(len), TBitWidthAdd(a1: TBitWidthVar, a2: TBitWidthVar)) => if (len % 2 != 0) throw new RuntimeException(s"result of bitwidthadd should be even. Consider multiply. Found $len") + val (s1, c1) = unify(a1, TBitWidthLen(len / 2), binop) + val (s2, c2) = unify(a2, TBitWidthLen(len / 2), binop) + (compose_subst(s1, s2), c1 || c2) + case (t1: TBitWidthLen, t2: TBitWidthLen) => if (autocast) { - case EVar(id) => rhsEnv.add(id, tlhs) - case _ => rhsEnv - } - (cr.copy(lhs = lhsFixed, rhs = rhsFixed).copyMeta(cr), newEnv.asInstanceOf[TypeEnv].apply_subst_typeenv(sret), sret) - case ca@CAssign(lhs, rhs, typ) => - val (slhs, tlhs, lhsEnv) = (List(), typ.getOrElse(generateTypeVar()), env) - val (srhs, trhs, rhsEnv, rhsFixed) = infer(lhsEnv, rhs) - val tempSub = compose_many_subst(sub, slhs, srhs) - val lhstyp = apply_subst_typ(tempSub, tlhs) - val rhstyp = apply_subst_typ(tempSub, trhs) - val (s1, cast) = unify(rhstyp, lhstyp) - /*TODO insert cast*/ - val sret = compose_many_subst(tempSub, s1, typ match - { case Some(value) => - val (s2, c2) = unify(lhstyp, value) - val (s3, c3) = unify(rhstyp, value) - /*TODO insert cast*/ - compose_subst(s2, s3) - case None => List() - }) - val newEnv = lhs match + if (binop) (List(), t1.len != t2.len) else if (t2.len < t1.len) throw UnificationError(t1, t2) else (List(), t1.len != t2.len) + } else { - case EVar(id) => rhsEnv.add(id, tlhs) - case _ => rhsEnv + if (t2.len != t1.len) throw UnificationError(t1, t2) else (List(), false) } - lhs.typ = Some(lhstyp) - rhs.typ = Some(rhstyp) - (ca.copy(rhs = rhsFixed).copyMeta(ca), newEnv.asInstanceOf[TypeEnv].apply_subst_typeenv(sret), sret) - case cs@CSeq(c1, c2) => - val (fixed1, e1, s) = checkCommand(c1, env, sub) - val (fixed2, e2, s2) = checkCommand(c2, e1, s) - (cs.copy(c1 = fixed1, c2 = fixed2).copyMeta(cs), e2, s2) - case _:InternalCommand => (c, env, sub) + case _ => throw UnificationError(a, b) } } - private def generateTypeVar(): TNamedType = - { - counter += 1 - TNamedType(Id("__TYPE__" + counter)) - } + private def checkCirExpr(c: CirExpr, tenv: Environment[Id, Type]): (Type, Environment[Id, Type], CirExpr) = c match + { + case CirMem(elemTyp, addrSize, numPorts) => if (numPorts > 2) throw TooManyPorts(c.pos, 2) + val mtyp = TMemType(elemTyp, addrSize, Asynchronous, Asynchronous, numPorts, numPorts) + c.typ = Some(mtyp) + (mtyp, tenv, c) + case CirLock(mem, impl, _) => val mtyp: TMemType = tenv(mem).matchOrError(mem.pos, "lock instantiation", "memory") + { case c: TMemType => c } + mem.typ = Some(mtyp) + val newtyp = TLockedMemType(mtyp, None, impl) + c.typ = Some(newtyp) + (newtyp, tenv, c) + case CirLockMem(elemTyp, addrSize, impl, _, numPorts) => val mtyp = TMemType(elemTyp, addrSize, Asynchronous, Asynchronous, numPorts, numPorts) + val ltyp = TLockedMemType(mtyp, None, impl) + c.typ = Some(ltyp) + (ltyp, tenv, c) + case CirRegFile(elemTyp, addrSize) => val mtyp = TMemType(elemTyp, addrSize, Combinational, Sequential, defaultReadPorts, defaultWritePorts) + c.typ = Some(mtyp) + (mtyp, tenv, c) + case CirLockRegFile(elemTyp, addrSize, impl, szParams) => val mtyp = TMemType(elemTyp, addrSize, Combinational, Sequential, defaultReadPorts, defaultWritePorts) + val idsz = szParams.headOption + val ltyp = TLockedMemType(mtyp, idsz, impl) + c.typ = Some(ltyp) + (ltyp, tenv, c) + case CirNew(mod, mods) => val mtyp = tenv(mod) + mtyp match + { + case TModType(_, refs, _, _) => if (refs.length != mods.length) throw ArgLengthMismatch(c.pos, mods.length, refs.length) + refs.zip(mods).foreach + { case (reftyp, mname) => if (!isSubtype(tenv(mname), reftyp)) throw UnexpectedSubtype(mname.pos, mname.toString, reftyp, tenv(mname)) } + (mtyp, tenv, c) + case x => throw UnexpectedType(c.pos, c.toString, "Module Type", x) + } + case cc@CirCall(mod, inits) => val mtyp = tenv(mod) + mtyp match + { + case TModType(ityps, _, _, _) => if (ityps.length != inits.length) throw ArgLengthMismatch(c.pos, inits.length, ityps.length) + val fixed_args = ityps.zip(inits).map + { case (expectedT, arg) => val (_, atyp, _, a_fixed) = infer(tenv.asInstanceOf[TypeEnv], arg) + if (!isSubtype(atyp, expectedT)) throw UnexpectedSubtype(arg.pos, arg.toString, expectedT, atyp) + a_fixed + } + (mtyp, tenv, cc.copy(args = fixed_args)) + case x => throw UnexpectedType(c.pos, c.toString, "Module Type", x) + } + } - private def generateBitWidthTypeVar(): TBitWidthVar = - { - counter += 1 - TBitWidthVar(Id("__BITWIDTH__" + counter)) - } + private def replaceNamedType(t: Type, tenv: TypeEnv): Type = t match + { + case TNamedType(name) => tenv(name) + case _ => t + } - private def generateSignTypeVar(): TSignedNess = - { - counter += 1 - TSignVar(Id("__SIGN__" + counter)) - } - private def occursIn(name: Id, b: Type): bool = b match { - case TSizedInt(len, unsigned) => occursIn(name, len) + case TSizedInt(len, _) => occursIn(name, len) case TString() => false case TVoid() => false case TBool() => false case TFun(args, ret) => args.foldLeft[bool](false)((b, t) => b || occursIn(name, t)) || occursIn(name, ret) - case TRecType(name, fields) => false - case TMemType(elem, addrSize, readLatency, writeLatency, readPorts, writePorts) => false - case TModType(inputs, refs, retType, name) => false + case _: TRecType => false + case _: TMemType => false + case _: TModType => false case TNamedType(name1) => name1 == name case TSignVar(name1) => name1 == name case TBitWidthVar(name1) => name1 == name case TBitWidthAdd(b1, b2) => occursIn(name, b1) || occursIn(name, b2) case TBitWidthMax(b1, b2) => occursIn(name, b1) || occursIn(name, b2) - case TBitWidthLen(len) => false - case _:TSignedNess => false + case TBitWidthLen(_) => false + case _: TSignedNess => false } - private def apply_subst_substs(subst: Subst, inSubst: Subst): Subst = inSubst.foldLeft[Subst](List())((s, c) => s :+ ((c._1, apply_subst_typ(subst, c._2)))) private def compose_subst(sub1: Subst, sub2: Subst): Subst = sub1 ++ apply_subst_substs(sub1, sub2) private def compose_many_subst(subs: Subst*): Subst = subs.foldRight[Subst](List())((s1, s2) => compose_subst(s1, s2)) - /*for subtyping bit widths, t1 is the subtype, t2 is the supertype. so t2 is the expected private*/ - def unify(a: Type, b: Type): (Subst, bool) = { - println(s"unifying $a and $b") - (a, b) match + private def binOpExpectedType(b: BOp): Type = { - case (t1: TNamedType, t2) => if (!occursIn(t1.name, t2)) (List((t1.name, t2)), false) else (List(), false) - case (t1, t2: TNamedType) => if (!occursIn(t2.name, t1)) (List((t2.name, t1)), false) else (List(), false) - case (t1 :TSignVar, t2 :TSignedNess) => if (!occursIn(t1.id, t2)) (List((t1.id, t2)), false) else (List(), false) - case (t1 :TSignedNess, t2 :TSignVar) => if (!occursIn(t2.id, t1)) (List((t2.id, t1)), false) else (List(), false) - case (_: TString, _: TString) => (List(), false) - case (_: TBool, _: TBool) => (List(), false) - case (_: TVoid, _: TVoid) => (List(), false) - case (_ :TSigned, _ :TSigned) => (List(), false) - case (_ :TUnsigned, _ :TUnsigned) => (List(), false) - case (TBool(), TSizedInt(len, u)) if len.asInstanceOf[TBitWidthLen].len == 1 && u.unsigned() => (List(), false) - case (TSizedInt(len, u), TBool()) if len.asInstanceOf[TBitWidthLen].len == 1 && u.unsigned() => (List(), false) - case (TSizedInt(len1, signed1), TSizedInt(len2, signed2)) => - val (s1, c1) = unify(len1, len2); val (s2, c2) = unify(signed1, signed2) - (compose_subst(s1, s2), c1 || c2) - case (TFun(args1, ret1), TFun(args2, ret2)) if args1.length == args2.length => - val (s1, c1) = args1.zip(args2).foldLeft[(Subst, bool)]((List(), false))((sc, t) => - { - val (unif_s, unif_c) = unify(apply_subst_typ(sc._1, t._1), apply_subst_typ(sc._1, t._2)) - (compose_subst(sc._1, unif_s), unif_c || sc._2) - }) - val (s2, c2) = unify(apply_subst_typ(s1, ret1), apply_subst_typ(s1, ret2)) - (compose_subst(s1, s2), c1 || c2) - case (TModType(input1, refs1, retType1, name1), TModType(input2, refs2, retType2, name2)) => //TODO: Name?\ if (name1 != name2) throw UnificationError(a, b) - val (s1, c1) = input1.zip(input2).foldLeft[(Subst, bool)]((List(), false))((sc, t) => + val tmp = b match + { + case EqOp(_) => val t = generateTypeVar() // TODO: This can be anything? + TFun(List(t, t), TBool()) + case CmpOp(_) => val t = generateTypeVar() // TODO: This can be anything? + TFun(List(t, t), TBool()) + case _: BoolOp => TFun(List(TBool(), TBool()), TBool()) + case NumOp(op, _) => val b1 = generateBitWidthTypeVar() + val b2 = generateBitWidthTypeVar() + val s = generateSignTypeVar() + op match { - val (unif_s, unif_c) = unify(apply_subst_typ(sc._1, t._1), apply_subst_typ(sc._1, t._2)) - (compose_subst(sc._1, unif_s), unif_c || sc._2) - }) - val (s2, c2) = refs1.zip(refs2).foldLeft[(Subst, bool)](s1, c1)((sc, t) => + case "/" => TFun(List(TSizedInt(b1, s), TSizedInt(b2, s)), TSizedInt(b1, s)) + case "*" => TFun(List(TSizedInt(b1, s), TSizedInt(b2, s)), TSizedInt(TBitWidthAdd(b1, b2), s)) + case "$*" => TFun(List(TSizedInt(b1, s), TSizedInt(b1, s)), TSizedInt(b1, s)) + case "+" => TFun(List(TSizedInt(b1, s), TSizedInt(b1, s)), TSizedInt(b1, s)) + case "-" => TFun(List(TSizedInt(b1, s), TSizedInt(b1, s)), TSizedInt(b1, s)) + case "%" => TFun(List(TSizedInt(b1, s), TSizedInt(b2, s)), TSizedInt(b1, s)) + } + case BitOp(op, _) => val b1 = generateBitWidthTypeVar() + val b2 = generateBitWidthTypeVar() + val s = generateSignTypeVar() + op match { - val (unif_s, unif_c) = unify(apply_subst_typ(sc._1, t._1), apply_subst_typ(sc._1, t._2)) - (compose_subst(sc._1, unif_s), sc._2 || unif_c) - }) - val (s3, c3) = (retType1, retType2) match - { - case (Some(t1: Type), Some(t2: Type)) => unify(apply_subst_typ(s2, t1), apply_subst_typ(s2, t2)) - case (None, None) => (List(), false) - case _ => throw UnificationError(a, b) - } - (compose_subst(s2, s3), c2 || c3) - case (TMemType(elem1, addr1, rl1, wl1, rp1, wp1), TMemType(elem2, addr2, rl2, wl2, rp2, wp2)) => if (addr1 != addr2 || rl1 != rl2 || wl1 != wl2 || rp1 < rp2 || wp1 < wp2) throw UnificationError(a, b) - unify(elem1, elem2) - case (t1: TBitWidthVar, t2: TBitWidth) => if (!occursIn(t1.name, t2)) (List((t1.name, t2)), false) else (List(), false) - case (t1: TBitWidth, t2: TBitWidthVar) => if (!occursIn(t2.name, t1)) (List((t2.name, t1)), false) else (List(), false) - - case (TBitWidthAdd(a1 :TBitWidthLen, a2), TBitWidthLen(len)) => unify(a2, TBitWidthLen(len - a1.len)) - case (TBitWidthAdd(a2, a1 :TBitWidthLen), TBitWidthLen(len)) => unify(a2, TBitWidthLen(len - a1.len)) - case (TBitWidthLen(len), TBitWidthAdd(a1 :TBitWidthLen, a2)) => unify(a2, TBitWidthLen(len - a1.len)) - case (TBitWidthLen(len), TBitWidthAdd(a2, a1 :TBitWidthLen)) => unify(a2, TBitWidthLen(len - a1.len)) - case (TBitWidthAdd(a1 :TBitWidthVar, a2 :TBitWidthVar), TBitWidthLen(len)) => - if(len % 2 != 0) throw new RuntimeException(s"result of bitwidthadd should be even. Consider multiply. Found $len") - val (s1, c1) = unify(a1, TBitWidthLen(len / 2)) - val (s2, c2) = unify(a2, TBitWidthLen(len / 2)) - (compose_subst(s1, s2), c1 || c2) - case (TBitWidthLen(len), TBitWidthAdd(a1 :TBitWidthVar, a2 :TBitWidthVar)) => - if(len % 2 != 0) throw new RuntimeException(s"result of bitwidthadd should be even. Consider multiply. Found $len") - val (s1, c1) = unify(a1, TBitWidthLen(len / 2)) - val (s2, c2) = unify(a2, TBitWidthLen(len / 2)) - (compose_subst(s1, s2), c1 || c2) - case (t1: TBitWidthLen, t2: TBitWidthLen) => - if(autocast) - {if (t2.len < t1.len) throw UnificationError(t1, t2) else (List(), t1.len != t2.len)} //TODO: need to figure this out: we want this subtyping rule to throw error when its being used, but not when its a binop!!! - else - {if (t2.len != t1.len) throw UnificationError(t1, t2) else (List(), false)} - case _ => throw UnificationError(a, b) + case "++" => TFun(List(TSizedInt(b1, s), TSizedInt(b2, s)), TSizedInt(TBitWidthAdd(b1, b2), s)) + case _ => TFun(List(TSizedInt(b1, s), TSizedInt(b2, generateSignTypeVar())), TSizedInt(b1, s)) + } + } + tmp.setPos(b.pos) } - } - //Updating the type environment with the new substitution whenever you generate one allows errors to be found :D - //The environment returned is guaratneed to already have been substituted into with the returned substitution private - def infer(env: TypeEnv, e: Expr): (Subst, Type, TypeEnv, Expr) = { - //println(e) - e match + private def generateTypeVar(): TNamedType = { - case EInt(v, base, bits) => - val sign = if (e.typ.isDefined) e.typ.get.asInstanceOf[TSizedInt].sign else generateSignTypeVar() - (List(), if(e.typ.isDefined) TSizedInt(TBitWidthLen(bits), sign) else generateTypeVar(), env, e) - case EString(v) => (List(), TString(), env, e) - case EBool(v) => (List(), TBool(), env, e) - case u@EUop(op, ex) => - val (s, t, env1, fixed) = infer(env, ex) - val retType = generateTypeVar() - val tNew = apply_subst_typ(s, t) - val (subst, cast) = unify(TFun(List(tNew), retType), uOpExpectedType(op)) - /*TODO insert cast*/ - val retSubst = compose_subst(s, subst) - val retTyp = apply_subst_typ(retSubst, retType) - (retSubst, retTyp, env1.apply_subst_typeenv(retSubst), u.copy(ex = fixed).copyMeta(u)) - case b@EBinop(op, e1, e2) => - println(s"op: $op") - val (s1, t1, env1, fixed1) = infer(env, e1) - val (s2, t2, env2, fixed2) = infer(env1, e2) - println(s"inferred left: $t1;\tright: $t2") - val retType = generateTypeVar() - val subTemp = compose_subst(s1, s2) - val t1New = apply_subst_typ(subTemp, t1) - val t2New = apply_subst_typ(subTemp, t2) - val (subst, cast) = unify(TFun(List(t1New, t2New), retType), binOpExpectedType(op)) + counter += 1 + TNamedType(Id("__TYPE__" + counter)) + } - /*TODO insert cast*/ - val retSubst = compose_many_subst(subTemp, subst) - val retTyp = apply_subst_typ(retSubst, retType) - val t1VNew = apply_subst_typ(retSubst, t1New) - val t2VNew = apply_subst_typ(retSubst, t2New) - println(s"decided left: $t1VNew;\tright: $t2VNew") - println(s"needs cast: $cast") - fixed1.typ = Some(t1VNew) - fixed2.typ = Some(t2VNew) - // println(s"decided left: $t1VNew;\tright: $t2VNew") - (retSubst, retTyp, env2.apply_subst_typeenv(retSubst), b.copy(e1 = fixed1, e2 = fixed2).copyMeta(b)) - case m@EMemAccess(mem, index, wmask) => if (!(env(mem).isInstanceOf[TMemType] || env(mem).isInstanceOf[TLockedMemType])) throw UnexpectedType(e.pos, "Memory Access", "TMemtype", env(mem)) + private def generateBitWidthTypeVar(): TBitWidthVar = + { + counter += 1 + TBitWidthVar(Id("__BITWIDTH__" + counter)) + } - val retType = generateTypeVar() - val (s, t, env1, fixed_idx) = infer(env, index) - val tTemp = apply_subst_typ(s, t) - val memt = env1(mem) match - { - case t@TMemType(_, _, _, _, _, _) => t - case TLockedMemType(t, _, _) => t - case _ => throw UnexpectedType(e.pos, "Memory Access", "TMemtype", env1(mem)) - } - val (subst, cast) = unify(TFun(List(tTemp), retType), getMemAccessType(memt)) - /*TODO insert cast*/ - val retSubst = compose_subst(s, subst) - val retTyp = apply_subst_typ(retSubst, retType) - (retSubst, retTyp, env1.apply_subst_typeenv(retSubst), m.copy(index = fixed_idx).copyMeta(m)) - case b@EBitExtract(num, start, end) => - val (s, t, e, fixed_num) = infer(env, num) - t match - { - case TSizedInt(TBitWidthLen(len), signedness) if len >= (math.abs(end - start) + 1) => - (s, TSizedInt(TBitWidthLen(math.abs(end - start) + 1), signedness), e, b.copy(num = fixed_num).copyMeta(b)) - case b => throw UnificationError(b, TSizedInt(TBitWidthLen(32), TUnsigned())) //TODO Add better error message - } //TODO - case trn@ETernary(cond, tval, fval) => - val (sc, tc, env1, fixed_cond) = infer(env, cond) - val (st, tt, env2, fixed_tval) = infer(env1, tval) - val (sf, tf, env3, fixed_fval) = infer(env2, fval) - val substSoFar = compose_many_subst(sc, st, sf) - val tcNew = apply_subst_typ(substSoFar, tc) - val ttNew = apply_subst_typ(substSoFar, tt) - val tfNew = apply_subst_typ(substSoFar, tf) - val (substc, _) = unify(tcNew, TBool()) - val (subst, cast) = unify(ttNew, tfNew) //TODO this will fail with bad subtyping stuff going on currently bc for sized ints, we don't care which one is bigger right - /*TODO insert cast*/ - val retSubst = compose_many_subst(sc, st, sf, substc, subst) - val retType = apply_subst_typ(retSubst, ttNew) - (retSubst, retType, env3.apply_subst_typeenv(retSubst), trn.copy(cond = fixed_cond, tval = fixed_tval, fval = fixed_fval).copyMeta(trn)) - case ap@EApp(func, args) => val expectedType = env(func) - val retType = generateTypeVar() - var runningEnv: TypeEnv = env - var runningSubst: Subst = List() - var typeList: List[Type] = List() - var argList: List[Expr] = List() - for (a <- args) + /** Updating the type environment with the new substitution whenever you generate one allows errors to be found :D + * The environment returned is guaratneed to already have been substituted into with the returned substitution */ + private def infer(env: TypeEnv, e: Expr): (Subst, Type, TypeEnv, Expr) = + { + e match + { + case _: EInt => val newvar = generateTypeVar() + if (e.typ.isEmpty) e.typ = Some(newvar) + (List(), e.typ.getOrElse(generateTypeVar()), env, e) + case EString(_) => (List(), TString(), env, e) + case EBool(_) => (List(), TBool(), env, e) + case u@EUop(op, ex) => val (s, t, env1, fixed) = infer(env, ex) + val retType = generateTypeVar() + val tNew = apply_subst_typ(s, t) + val (subst, cast) = unify(TFun(List(tNew), retType), uOpExpectedType(op)) + val retSubst = compose_subst(s, subst) + val retTyp = apply_subst_typ(retSubst, retType) + val fixed1 = if (cast) ECast(retTyp, fixed) else fixed + (retSubst, retTyp, env1.apply_subst_typeenv(retSubst), u.copy(ex = fixed1).copyMeta(u)) + case b@EBinop(op, e1, e2) => + + val (s1, t1, env1, fixed1) = infer(env, e1) + val (s2, t2, env2, fixed2) = infer(env1, e2) + val retType = generateTypeVar() + val subTemp = compose_subst(s1, s2) + val t1New = apply_subst_typ(subTemp, t1) + val t2New = apply_subst_typ(subTemp, t2) + val (subst, _) = unify(TFun(List(t1New, t2New), retType), binOpExpectedType(op), binop = true) + val retSubst = compose_many_subst(subTemp, subst) + val retTyp = apply_subst_typ(retSubst, retType) + val t1VNew = apply_subst_typ(retSubst, t1New) + val t2VNew = apply_subst_typ(retSubst, t2New) + val (castL, castR) = binOpTypesFromRet(op, retTyp, t1VNew, t2VNew) + val moreFixed1 = castL match { - val (sub, typ, env1, fixed_a) = infer(runningEnv, a) - runningSubst = compose_subst(runningSubst, sub) - typeList = typeList :+ typ - argList = argList :+ fixed_a - runningEnv = env1 + case Some(tp) => assert(autocast) + ECast(tp, fixed1).copyMeta(fixed1) + case _ => fixed1 } - typeList = typeList.map(t => apply_subst_typ(runningSubst, t)) - val (subst, cast) = unify(TFun(typeList, retType), expectedType) - /*TODO insert cast*/ - val retSubst = compose_subst(runningSubst, subst) - val retEnv = runningEnv.apply_subst_typeenv(retSubst) - val retTyp = apply_subst_typ(retSubst, retType) - (retSubst, retTyp, retEnv, ap.copy(args = argList).copyMeta(ap)) - case ca@ECall(mod, args) => if (!env(mod).isInstanceOf[TModType]) throw UnexpectedType(e.pos, "Module Call", "TModType", env(mod)) - val expectedType = getArrowModType(env(mod).asInstanceOf[TModType]) - val retType = generateTypeVar() - var runningEnv: TypeEnv = env - var runningSubst: Subst = List() - var typeList: List[Type] = List() - var argList: List[Expr] = List() - for (a <- args) + val moreFixed2 = castR match { - val (sub, typ, env1, fixed_a) = infer(runningEnv, a) - runningSubst = compose_subst(runningSubst, sub) - typeList = typeList :+ typ - argList = argList :+ fixed_a - runningEnv = env1 + case Some(tp) => assert(autocast) + ECast(tp, fixed2).copyMeta(fixed2) + case _ => fixed2 } - typeList = typeList.map(t => apply_subst_typ(runningSubst, t)) - val (subst, cast) = unify(TFun(typeList, retType), expectedType) - /*TODO insert cast*/ - val retSubst = compose_subst(runningSubst, subst) - val retEnv = runningEnv.apply_subst_typeenv(retSubst) - val retTyp = apply_subst_typ(retSubst, retType) - (retSubst, retTyp, retEnv, ca.copy(args = argList).copyMeta(ca)) - case EVar(id) => (List(), env(id), env, e) - case ECast(ctyp, exp) => /*TODO this is wrong probably*/ - val (s, t, env1, _) = infer(env, exp) - val newT = apply_subst_typ(s, t) - if (!canCast(ctyp, newT)) throw Errors.IllegalCast(e.pos, ctyp, newT) - (s, ctyp, env1, e) + moreFixed1.typ = Some(castL.getOrElse(t1VNew)) + moreFixed2.typ = Some(castR.getOrElse(t2VNew)) + val finalRetType = generateTypeVar() + val (finSubst, _) = unify(TFun(List(moreFixed1.typ.get, moreFixed2.typ.get), finalRetType), binOpExpectedType(op), binop = true) + val finalRetSubst = compose_many_subst(subTemp, finSubst) + val finalRetTyp = apply_subst_typ(finalRetSubst, finalRetType) + val bFixed = b.copy(e1 = moreFixed1, e2 = moreFixed2).copyMeta(b) + bFixed.typ = Some(finalRetTyp) + (finalRetSubst, finalRetTyp, env2.apply_subst_typeenv(finalRetSubst), bFixed) + case m@EMemAccess(mem, index, _) => if (!(env(mem).isInstanceOf[TMemType] || env(mem).isInstanceOf[TLockedMemType])) throw UnexpectedType(e.pos, "Memory Access", "TMemtype", env(mem)) + val retType = generateTypeVar() + val (s, t, env1, fixed_idx) = infer(env, index) + val tTemp = apply_subst_typ(s, t) + val memt = env1(mem) match + { + case t@TMemType(_, _, _, _, _, _) => t + case TLockedMemType(t, _, _) => t + case _ => throw UnexpectedType(e.pos, "Memory Access", "TMemtype", env1(mem)) + } + val (subst, _) = unify(TFun(List(tTemp), retType), getMemAccessType(memt)) + val retSubst = compose_subst(s, subst) + val retTyp = apply_subst_typ(retSubst, retType) + (retSubst, retTyp, env1.apply_subst_typeenv(retSubst), m.copy(index = fixed_idx).copyMeta(m)) + case b@EBitExtract(num, start, end) => val (s, t, e, fixed_num) = infer(env, num) + t match + { + case TSizedInt(TBitWidthLen(len), signedness) if len >= (math.abs(end - start) + 1) => (s, TSizedInt(TBitWidthLen(math.abs(end - start) + 1), signedness), e, b.copy(num = fixed_num).copyMeta(b)) + case b => throw UnificationError(b, TSizedInt(TBitWidthLen(32), TUnsigned())) //TODO Add better error message + } //TODO + case trn@ETernary(cond, tval, fval) => val (sc, tc, env1, fixed_cond) = infer(env, cond) + val (st, tt, env2, fixed_tval) = infer(env1, tval) + val (sf, tf, env3, fixed_fval) = infer(env2, fval) + val substSoFar = compose_many_subst(sc, st, sf) + val tcNew = apply_subst_typ(substSoFar, tc) + val ttNew = apply_subst_typ(substSoFar, tt) + val tfNew = apply_subst_typ(substSoFar, tf) + val (substc, _) = unify(tcNew, TBool()) + val (subst, cast) = unify(ttNew, tfNew) + val (fixed_tval1, fixed_fval1) = if (cast) if (ttNew <<= tfNew) (ECast(tfNew, fixed_tval), fixed_fval) else (fixed_tval, ECast(ttNew, fixed_fval)) else (fixed_tval, fixed_fval) + val retSubst = compose_many_subst(sc, st, sf, substc, subst) + val retType = apply_subst_typ(retSubst, ttNew) + (retSubst, retType, env3.apply_subst_typeenv(retSubst), trn.copy(cond = fixed_cond, tval = fixed_tval1, fval = fixed_fval1).copyMeta(trn)) + case ap@EApp(func, args) => val expectedType = env(func).asInstanceOf[TFun] + val retType = generateTypeVar() + var runningEnv: TypeEnv = env + var runningSubst: Subst = List() + var typeList: List[Type] = List() + var argList: List[Expr] = List() + for (a <- args) + { + val (sub, typ, env1, fixed_a) = infer(runningEnv, a) + runningSubst = compose_subst(runningSubst, sub) + typeList = typeList :+ typ + argList = argList :+ fixed_a + runningEnv = env1 + } + typeList = typeList.map(t => apply_subst_typ(runningSubst, t)) + val (subst, cast) = unify(TFun(typeList, retType), expectedType) + val fixed_arg_list = if (cast) + { + argList.zip(expectedType.args).map(argtp => ECast(argtp._2, argtp._1)) + } else argList + val retSubst = compose_subst(runningSubst, subst) + val retEnv = runningEnv.apply_subst_typeenv(retSubst) + val retTyp = apply_subst_typ(retSubst, retType) + (retSubst, retTyp, retEnv, ap.copy(args = fixed_arg_list).copyMeta(ap)) + case ca@ECall(mod, args) => if (!env(mod).isInstanceOf[TModType]) throw UnexpectedType(e.pos, "Module Call", "TModType", env(mod)) + val expectedType = getArrowModType(env(mod).asInstanceOf[TModType]) + val retType = generateTypeVar() + var runningEnv: TypeEnv = env + var runningSubst: Subst = List() + var typeList: List[Type] = List() + var argList: List[Expr] = List() + for (a <- args) + { + val (sub, typ, env1, fixed_a) = infer(runningEnv, a) + runningSubst = compose_subst(runningSubst, sub) + typeList = typeList :+ typ + argList = argList :+ fixed_a + runningEnv = env1 + } + typeList = typeList.map(t => apply_subst_typ(runningSubst, t)) + val (subst, cast) = unify(TFun(typeList, retType), expectedType) + val fixed_arg_list = if (cast) + { + argList.zip(expectedType.args).map(argtp => ECast(argtp._2, argtp._1)) + } else argList + val retSubst = compose_subst(runningSubst, subst) + val retEnv = runningEnv.apply_subst_typeenv(retSubst) + val retTyp = apply_subst_typ(retSubst, retType) + (retSubst, retTyp, retEnv, ca.copy(args = fixed_arg_list).copyMeta(ca)) + case EVar(id) => (List(), env(id), env, e) + case ECast(ctyp, exp) => /*TODO this is wrong probably*/ val (s, t, env1, _) = infer(env, exp) + val newT = apply_subst_typ(s, t) + if (!canCast(ctyp, newT)) throw Errors.IllegalCast(e.pos, ctyp, newT) + (s, ctyp, env1, e) + } } - } + private def generateSignTypeVar(): TSignedNess = + { + counter += 1 + TSignVar(Id("__SIGN__" + counter)) + } - private def binOpExpectedType(b: BOp): Type = b match - { - case EqOp(op) => val t = generateTypeVar() // TODO: This can be anything? - TFun(List(t, t), TBool()) - case CmpOp(op) => - val t = generateSignTypeVar() - TFun(List(TSizedInt(generateBitWidthTypeVar(), t), TSizedInt(generateBitWidthTypeVar(), t)), TBool()) //TODO: TSizedInt? - case BoolOp(op, fun) => TFun(List(TBool(), TBool()), TBool()) - case NumOp(op, fun) => val b1 = generateBitWidthTypeVar() - val b2 = generateBitWidthTypeVar() - val s = generateSignTypeVar() - op match + private def binOpTypesFromRet(b: BOp, retType: Type, t1: Type, t2: Type): (Option[Type], Option[Type]) = + { + val tmp = b match { - case "/" => TFun(List(TSizedInt(b1, s), TSizedInt(b2, s)), TSizedInt(b1, s)) - case "*" => TFun(List(TSizedInt(b1, s), TSizedInt(b2, s)), TSizedInt(TBitWidthAdd(b1, b2), s)) - // case "+" => TFun(List(TSizedInt(b1, true), TSizedInt(b2, true)), TSizedInt(TBitWidthMax(b1, b2), true)) - case "+" => TFun(List(TSizedInt(b1, s), TSizedInt(b1, s)), TSizedInt(b1, s)) - //case "-" => TFun(List(TSizedInt(b1, true), TSizedInt(b2, true)), TSizedInt(TBitWidthMax(b1, b2), true)) - case "-" => TFun(List(TSizedInt(b1, s), TSizedInt(b1, s)), TSizedInt(b1, s)) - case "%" => TFun(List(TSizedInt(b1, s), TSizedInt(b2, s)), TSizedInt(b1, s)) + case EqOp(_) => val meet = t1 ⋁ t2 + (if (meet === t1) None else Some(meet), if (meet === t2) None else Some(meet)) + case CmpOp(_) => val meet = t1 ⋁ t2 + (if (meet === t1) None else Some(meet), if (meet === t2) None else Some(meet)) + case _: BoolOp => (None, None) + case NumOp(op, _) => op match + { + case "/" | "%" => val meet = retType ⋁ t1 + if (meet === t1) (None, None) else (Some(meet), None) + case "*" => (None, None) + case "+" | "-" | "$*" => val meet = t1 ⋁ t2 ⋁ retType + (if (meet === t1) None else Some(meet), if (meet === t2) None else Some(meet)) + } + case BitOp(op, _) => op match + { + case "++" => (None, None) + case _ => val meet = t1 ⋁ retType + if (meet === t1) (None, None) else (Some(meet), None) + } } - case BitOp(op, fun) => val b1 = generateBitWidthTypeVar() - val b2 = generateBitWidthTypeVar() - val s = generateSignTypeVar() - op match + tmp match { - case "++" => TFun(List(TSizedInt(b1, s), TSizedInt(b2, s)), TSizedInt(TBitWidthAdd(b1, b2), s)) - case _ => TFun(List(TSizedInt(b1, s), TSizedInt(b2, generateSignTypeVar())), TSizedInt(b1, s)) + case (Some(x), Some(y)) => (Some(x.setPos(b.pos)), Some(y.setPos(b.pos))) + case (Some(x), None) => (Some(x.setPos(b.pos)), None) + case (None, Some(y)) => (None, Some(y.setPos(b.pos))) + case _ => tmp } - } + } private def uOpExpectedType(u: UOp): Type = u match { - case BitUOp(op) => val b1 = generateBitWidthTypeVar() //TODO: Fix this + case BitUOp(_) => val b1 = generateBitWidthTypeVar() //TODO: Fix this val s = generateSignTypeVar() TFun(List(TSizedInt(b1, s)), TSizedInt(b1, s)) - case BoolUOp(op) => TFun(List(TBool()), TBool()) - case NumUOp(op) => val b1 = generateBitWidthTypeVar() + case BoolUOp(_) => TFun(List(TBool()), TBool()) + case NumUOp(_) => val b1 = generateBitWidthTypeVar() val s = generateSignTypeVar() TFun(List(TSizedInt(b1, s)), TSizedInt(b1, s)) - } private def getArrowModType(t: TModType): TFun = @@ -732,8 +708,7 @@ object TypeInferenceWrapper private def getMemAccessType(t: TMemType): TFun = { - TFun(List(TSizedInt(TBitWidthLen(t.addrSize), sign = TUnsigned()/*true*/)), t.elem) + TFun(List(TSizedInt(TBitWidthLen(t.addrSize), sign = TUnsigned())), t.elem) } - } } \ No newline at end of file diff --git a/src/test/tests/autocastTests/autocast-basic-pass.pdl b/src/test/tests/autocastTests/autocast-basic-pass.pdl index 4af4ea65..3cbad275 100644 --- a/src/test/tests/autocastTests/autocast-basic-pass.pdl +++ b/src/test/tests/autocastTests/autocast-basic-pass.pdl @@ -1,4 +1,3 @@ - def helper1(a: int<32>, b:bool, c: String): int<32> { d = a + 1<16>; e = b && false; @@ -13,6 +12,81 @@ def helper2(a:bool, b: bool): bool { return c; } +def ret_cast_test() :int<32> + { + b = 69<16>; + a = 15<12>; + return a + b; + } + +def ass_cast_test() :int<32> + { + a = 10<10>; + b = 5<10>; + int<32> c = a + b; + return c; + } + +def tern_cast_test(input :bool) :int<32> + { + a = 10<10>; + b = 1<30>; + d = ((input) ? (a) : (b)); + return d; + } + +pipe app_cast()[] + { + s = "owgeu"; + a = helper1(10<5>, true, s); + call app_cast(); + } + +pipe phelper1(a: int<32>, b:bool, c: String)[]: int<32> +{ + d = a + 1<16>; + e = b && false; + if (e) { + f = c; + } + output(d); +} + + +pipe call_cast()[] + { + s = "owgeu"; + a <- call phelper1(10<19>, true, s); + --- + call call_cast(); + } + +pipe output_cast()[] :int<32> + { output(10<5>); } + + + +pipe recv_cast1(input :uint<32>)[rf :int<32>[32]] + { + start(rf); + reserve(rf[input], W); + end(rf); + --- + int<10> a = 4; + block(rf[input]); + rf[input] <- a; + --- + release(rf[input]); + call recv_cast1(input); + } + +pipe recv_cast2()[] + { + int<32> a <- 5<12>; + call recv_cast2(); + } + + pipe test1(input: int<32>)[rf: int<32>[32]] { a = input; b = true; @@ -156,4 +230,4 @@ pipe test11()[rf: int<32>[32]] { circuit { r = memory(int<32>, 32); -} \ No newline at end of file +} diff --git a/src/test/tests/autocastTests/risc-pipe-spec.pdl b/src/test/tests/autocastTests/risc-pipe-spec.pdl new file mode 100644 index 00000000..e3990b76 --- /dev/null +++ b/src/test/tests/autocastTests/risc-pipe-spec.pdl @@ -0,0 +1,336 @@ +def mul(arg1: int<32>, arg2: int<32>, op: uint<3>): int<32> { + uint<32> mag1 = cast(mag(arg1), uint<32>); + uint<32> mag2 = cast(mag(arg2), uint<32>); + //MULHU => positive sign always + int<32> s1 = (op == 3) ? 1 : sign(arg1); + //MULHU/MULHSU => positive sign always + int<32> s2 = (op >= 2) ? 1 : sign(arg2); + int<64> magRes = cast((mag1 * mag2), int<64>); + int<64> m = (s1 == s2) ? (magRes) : -(magRes); + if (op == 0) { //MUL + return m{31:0}; + } else { + return m{63:32}; + } +} + +def alu(arg1: int<32>, arg2: int<32>, op: uint<3>, flip: bool): int<32> { + shamt = cast(arg2{4:0}, uint<5>); + if (op == 0) { //000 == ADD , flip == sub + if (!flip) { + return arg1 + arg2; + } else { + return arg1 - arg2; + } + } else { + if (op == u1) { //001 == SLL + return arg1 << shamt; + } else { + if (op == u2) { //010 == SLT + return (arg1 < arg2) ? 1<32> : 0<32>; + } else { + if (op == u3) { //011 == SLTU + uint<32> un1 = cast(arg1, uint<32>); + uint<32> un2 = cast(arg2, uint<32>); + return (un1 < un2) ? 1<32> : 0<32>; + } else { + if (op == u4) { //100 == XOR + return arg1 ^ arg2; + } else { + if (op == u5) { //101 == SRL / SRA + if (!flip) { + return cast((cast(arg1, uint<32>)) >> shamt, int<32>); //SRL + } else { + return arg1 >> shamt; //SRA + } + } else { + if (op == u6) { //110 == OR + return arg1 | arg2; + } else { //111 == AND + return arg1 & arg2; + }}}}}}} + +} + +def br(pc: int<16>, off:int<16>, op:uint<3>, arg1:int<32>, arg2:int<32>): int<16> { + //divide by 4 b/c we count instructions not bytes + int<16> offpc = pc + (off >> 2); + int<16> npc = pc + 1<16>; + if (op == u0<3>) { //BEQ + if (arg1 == arg2) { return offpc; } else { return npc; } + } else { + if (op == u1<3>) { //BNE + if (arg1 != arg2) { return offpc; } else { return npc; } + } else { + if (op == u4<3>) { //BLT + if (arg1 < arg2) { return offpc; } else { return npc; } + } else { + if (op == u5<3>) { //BGE + if (arg1 >= arg2) { return offpc; } else { return npc; } + } else { + if (op == u6<3>) { //BLTU + uint<32> un1 = cast(arg1, uint<32>); + uint<32> un2 = cast(arg2, uint<32>); + if (un1 < un2) { return offpc; } else { return npc; } + } else { + if (op == u7<3>) { //BGEU + uint<32> un1 = cast(arg1, uint<32>); + uint<32> un2 = cast(arg2, uint<32>); + if (un1 >= un2) { return offpc; } else { return npc; } + } else { + return npc; + }}}}}} +} + + +def storeMask(off: uint<2>, op: uint<3>): uint<4> { + if (op == u0<3>) { //SB + return ((1) << (off)); + } else { + if (op == u1<3>) { //SH + uint<2> shamt = off{1:1} ++ 0; + return (u0b0011<4> << shamt); + } else { //SW + return u0b1111<4>; + }} +} + +def maskLoad(data: int<32>, op: uint<3>, start: uint<2>): int<32> { + //start == offset in bytes, need to multiply by 8 + uint<5> boff = start ++ u0<3>; + int<32> tmp = data >> boff; + uint<8> bdata = cast(tmp, uint<8>); + uint<16> hdata = cast(tmp, uint<16>); + + if (op == u0<3>) { //LB + return cast(bdata, int<32>); + } else { + if (op == u1<3>) { //LH + return cast(hdata, int<32>); + } else { + if (op == u2<3>) { //LW + return data; + } else { + if (op == u4<3>) { //LBU + uint<32> zext = cast(bdata, uint<32>); + return cast(zext, int<32>); + } else { + if (op == u5<3>) { //LHU + uint<32> zext = cast(hdata, uint<32>); + return cast(zext, int<32>); + } else { + return 0<32>; + }}}}} +} + +pipe multi_stg_div(num: uint<32>, denom: uint<32>, quot: uint<32>, acc: uint<32>, cnt: uint<5>, retQuot: bool)[]: uint<32> { + uint<32> tmp = acc{30:0} ++ num{31:31}; + uint<32> na = (tmp >= denom) ? (tmp - denom) : (tmp); + uint<32> nq = (tmp >= denom) ? ((quot << 1){31:1} ++ u1<1>) : (quot << 1); + uint<32> nnum = num << 1; + if (cnt == u31<5>) { + output( (retQuot) ? nq : na ); + } else { + call multi_stg_div(nnum, denom, nq, na, cnt + u1<5>, retQuot); + } +} + + +pipe cpu(pc: int<16>)[rf: int<32>[5](FAQueue), imem: int<32>[16](Queue), dmem: int<32>[16](FAQueue), div: multi_stg_div]: bool { + spec_check(); + start(imem); + uint<16> pcaddr = cast(pc, uint<16>); + acquire(imem[pcaddr], R); + int<32> insn <- imem[pcaddr]; + release(imem[pcaddr]); + end(imem); + s <- speccall cpu(pc + 1<16>); + --- + //This OPCODE is J Self and thus we're using it to signal termination + bool done = insn == 0x0000006f<32>; + int<7> opcode = insn{6:0}; + uint<5> rs1 = cast(insn{19:15}, uint<5>); + uint<5> rs2 = cast(insn{24:20}, uint<5>); + uint<5> rd = cast(insn{11:7}, uint<5>); + uint<7> funct7 = cast(insn{31:25}, uint<7>); + uint<3> funct3 = cast(insn{14:12}, uint<3>); + int<1> flipBit = insn{30:30}; + int<32> immI = cast(insn{31:20}, int<32>); + int<32> immS = cast((insn{31:25} ++ insn{11:7}), int<32>); + int<13> immBTmp = insn{31:31} ++ insn{7:7} ++ insn{30:25} ++ insn{11:8} ++ 0; + int<16> immB = cast(immBTmp, int<16>); + int<21> immJTmp = insn{31:31} ++ insn{19:12} ++ insn{20:20} ++ insn{30:21} ++ 0<1>; + int<32> immJ = cast(immJTmp, int<32>); + immJRTmp = insn{31:20}; + int<16> immJR = cast(immJRTmp, int<16>); + int<32> immU = insn{31:12} ++ 0<12>; + bool isOpImm = opcode == 0b0010011<7>; + bool flip = (!isOpImm) && (flipBit == 1<1>); + bool isLui = opcode == 0b0110111; + bool isAui = opcode == 0b0010111; + bool isOp = opcode == 0b0110011; + bool isJal = opcode == 0b1101111; + bool isJalr = opcode == 0b1100111; + bool isBranch = opcode == 0b1100011; + bool isStore = opcode == 0b0100011; + bool isLoad = opcode == 0b0000011; + bool isMDiv = (funct7 == u1<7>) && isOp; + bool isDiv = isMDiv && (funct3 >= u4<3>); + bool needrs1 = !isJal; + bool needrs2 = isOp || isBranch || isStore || isJalr; + bool writerd = (rd != u0<5>) && (isOp || isOpImm || isLoad || isJal || isJalr || isLui || isAui); + spec_barrier(); + bool notBranch = (!isBranch) && (!isJal) && (!isJalr); + if ((!done) && notBranch) { + verify(s, pc + 1<16>); + } else { + invalidate(s); + } + start(rf); + if (needrs1) { + acquire(rf[rs1], R); + int<32> rf1 = rf[rs1]; + release(rf[rs1]); + } else { + int<32> rf1 = 0<32>; + } + if (needrs2) { + acquire(rf[rs2], R); + int<32> rf2 = rf[rs2]; + release(rf[rs2]); + } else { + int<32> rf2 = 0<32>; + } + if (writerd) { + reserve(rf[rd], W); + } + end(rf); + --- + if (isBranch) { + int<16> npc = br(pc, immB, funct3, rf1, rf2); + } else { + if (isJal) { + //divide by 4 since it counts bytes instead of insns + int<32> npc32 = cast(pc, int<32>) + (immJ >> 2); + int<16> npc = npc32{15:0}; + } else { + if (isJalr) { + int<16> npc = (rf1{15:0} + immJR) >> 2; + } else { + int<16> npc = pc + 1<16>; + }}} + if ((!done) && (!notBranch)) { call cpu(npc); } + + int<32> alu_arg2 = (isOpImm) ? immI : rf2; + + if (isDiv) { + int<32> sdividend = sign(rf1); + //For REM, ignore sign of divisor + int<32> sdivisor = (funct3 == u6<3>) ? 1<32> : sign(rf2); + bool isSignedDiv = ((funct3 == u4<3>) || (funct3 == u6<3>)); + bool invertRes = isSignedDiv && (sdividend != sdivisor); + uint<32> dividend = (isSignedDiv) ? cast(mag(rf1), uint<32>) : cast(rf1, uint<32>); + uint<32> divisor = (isSignedDiv) ? cast(mag(rf2), uint<32>) : cast(rf2, uint<32>); + bool retQuot = funct3 <= u5<3>; + uint<32> udivout <- call div(dividend, divisor, u0<32>, u0<32>, u0<5>, retQuot); + } else { + uint<32> udivout <- u0<32>; + bool invertRes = false; + } + --- + split { + case: (isLui) { + int<32> alu_res = immU; + } + case: (isAui) { + //all pc computation needs to be multiplied by 4 + int<32> pc32 = (0<16> ++ pc) << 2; + int<32> alu_res = pc32 + immU; + } + case: (isDiv) { + int<32> alu_res = (invertRes) ? -(cast(udivout, int<32>)) : cast(udivout, int<32>); + } + case: (isMDiv) { + int<32> alu_res = mul(rf1, rf2, funct3); + } + default: { + int<32> alu_res = alu(rf1, alu_arg2, funct3, flip); + } + } + split { + case: (isStore) { + //addresses also are word-sized + int<32> tmp = immS + rf1; + uint<32> ctmp = cast(tmp, uint<32>); + uint<16> memaddr = (ctmp >> 2){15:0}; + uint<2> boff = ctmp{1:0}; + } + case: (isLoad) { + //addresses also are word-sized + int<32> tmp = immI + rf1; + uint<32> ctmp = cast(tmp, uint<32>); + uint<16> memaddr = (ctmp >> 2){15:0}; + uint<2> boff = ctmp{1:0}; + } + default: { + uint<16> memaddr = u0<16>; + uint<2> boff = u0<2>; + } + } + --- + start(dmem); + split { + case: (isLoad) { + uint<16> raddr = memaddr; + acquire(dmem[raddr], R); + int<32> wdata <- dmem[raddr]; + release(dmem[raddr]); + } + case: (isStore) { + uint<16> waddr = memaddr; + acquire(dmem[waddr], W); + //use bottom bits of data and place in correct offset + //shift by boff*8 + uint<5> nboff = boff ++ u0<3>; + dmem[waddr, storeMask(boff, funct3)] <- (rf2 << nboff); + release(dmem[waddr]); + int<32> wdata <- 0<32>; + } + default: { + int<32> wdata <- 0<32>; + } + } + end(dmem); + --- + print("PC: %h", pc << 2); + print("INSN: %h", insn); + if (writerd) { + block(rf[rd]); + if (isLoad) { + int<32> insnout = maskLoad(wdata, funct3, boff); + } else { + if (isJal || isJalr) { + //need to multiply by 4 b/c it is arch visible. + int<16> nextpc = pc + 1<16>; + int<32> insnout = 0<16> ++ (nextpc << 2); //todo make pc 32 bits + } else { + int<32> insnout = alu_res; + }} + print("Writing %d to r%d", insnout, rd); + rf[rd] <- insnout; + release(rf[rd]); + } + if (done) { output(true); } +} + +circuit { + ti = memory(int<32>, 16); + i = Queue(ti); + td = memory(int<32>, 16); + d = FAQueue(td); + rf = regfile(int<32>, 5); + r = FAQueue(rf); + div = new multi_stg_div[]; + c = new cpu[r, i, d, div]; + call c(0<16>); +} \ No newline at end of file diff --git a/src/test/tests/autocastTests/solutions/risc-pipe-spec.typechecksol b/src/test/tests/autocastTests/solutions/risc-pipe-spec.typechecksol new file mode 100644 index 00000000..9fb4ec93 --- /dev/null +++ b/src/test/tests/autocastTests/solutions/risc-pipe-spec.typechecksol @@ -0,0 +1 @@ +Passed \ No newline at end of file diff --git a/src/test/tests/autocastTests/type-inference-bit-width-tests.pdl b/src/test/tests/autocastTests/type-inference-bit-width-tests.pdl index b2a31e6e..d5330578 100644 --- a/src/test/tests/autocastTests/type-inference-bit-width-tests.pdl +++ b/src/test/tests/autocastTests/type-inference-bit-width-tests.pdl @@ -13,7 +13,7 @@ def helper2(a:bool, b: bool): bool { } pipe test1(input: int<32>)[rf: int<32>[32]] { - a = 6 * 6; + a = 6 * 6<10>; int<32> b = a; int<32> c = a + b; call test1(c); @@ -31,7 +31,7 @@ pipe test2(input: int<32>)[rf: int<32>[32]] { } pipe test3(input: int<32>)[rf: int<32>[32]] { - a = 1; + a = 1<32>; int<32> b = a << 4; c = a << 4; call test3(c); @@ -44,7 +44,7 @@ pipe test4(input: int<32>)[rf: int<32>[32]] { } pipe test5()[] { - a = 6 * 6; + a = 6 * 6<5>; int<6> b = a; int<32> c = helper1(cast(a, int<32>), true, "hi"); call test5(); diff --git a/src/test/tests/risc-pipe/risc-pipe-spec.pdl b/src/test/tests/risc-pipe/risc-pipe-spec.pdl index cd855a1b..3a822899 100644 --- a/src/test/tests/risc-pipe/risc-pipe-spec.pdl +++ b/src/test/tests/risc-pipe/risc-pipe-spec.pdl @@ -2,12 +2,12 @@ def mul(arg1: int<32>, arg2: int<32>, op: uint<3>): int<32> { uint<32> mag1 = cast(mag(arg1), uint<32>); uint<32> mag2 = cast(mag(arg2), uint<32>); //MULHU => positive sign always - int<32> s1 = (op == u3<3>) ? 1<32> : sign(arg1); + int<32> s1 = (op == 3) ? 1 : sign(arg1); //MULHU/MULHSU => positive sign always - int<32> s2 = (op >= u2<3>) ? 1<32> : sign(arg2); + int<32> s2 = (op >= 2) ? 1 : sign(arg2); int<64> magRes = cast((mag1 * mag2), int<64>); int<64> m = (s1 == s2) ? (magRes) : -(magRes); - if (op == u0<3>) { //MUL + if (op == 0) { //MUL return m{31:0}; } else { return m{63:32}; @@ -15,36 +15,36 @@ def mul(arg1: int<32>, arg2: int<32>, op: uint<3>): int<32> { } def alu(arg1: int<32>, arg2: int<32>, op: uint<3>, flip: bool): int<32> { - uint<5> shamt = cast(arg2{4:0}, uint<5>); - if (op == u0<3>) { //000 == ADD , flip == sub + shamt = cast(arg2{4:0}, uint<5>); + if (op == 0) { //000 == ADD , flip == sub if (!flip) { return arg1 + arg2; } else { return arg1 - arg2; } } else { - if (op == u1<3>) { //001 == SLL + if (op == u1) { //001 == SLL return arg1 << shamt; } else { - if (op == u2<3>) { //010 == SLT - return (arg1 < arg2) ? 1<32> : 0<32>; + if (op == u2) { //010 == SLT + return (arg1 < arg2) ? 1 : 0; } else { - if (op == u3<3>) { //011 == SLTU - uint<32> un1 = cast(arg1, uint<32>); - uint<32> un2 = cast(arg2, uint<32>); - return (un1 < un2) ? 1<32> : 0<32>; + if (op == u3) { //011 == SLTU + un1 = cast(arg1, uint<32>); + un2 = cast(arg2, uint<32>); + return (un1 < un2) ? 1 : 0; } else { - if (op == u4<3>) { //100 == XOR + if (op == u4) { //100 == XOR return arg1 ^ arg2; } else { - if (op == u5<3>) { //101 == SRL / SRA + if (op == u5) { //101 == SRL / SRA if (!flip) { return cast((cast(arg1, uint<32>)) >> shamt, int<32>); //SRL } else { return arg1 >> shamt; //SRA } } else { - if (op == u6<3>) { //110 == OR + if (op == u6) { //110 == OR return arg1 | arg2; } else { //111 == AND return arg1 & arg2; @@ -54,8 +54,8 @@ def alu(arg1: int<32>, arg2: int<32>, op: uint<3>, flip: bool): int<32> { def br(pc: int<16>, off:int<16>, op:uint<3>, arg1:int<32>, arg2:int<32>): int<16> { //divide by 4 b/c we count instructions not bytes - int<16> offpc = pc + (off >> 2); - int<16> npc = pc + 1<16>; + offpc = pc + (off >> 2); + npc = pc + 1<16>; if (op == u0<3>) { //BEQ if (arg1 == arg2) { return offpc; } else { return npc; } } else { @@ -69,13 +69,13 @@ def br(pc: int<16>, off:int<16>, op:uint<3>, arg1:int<32>, arg2:int<32>): int<16 if (arg1 >= arg2) { return offpc; } else { return npc; } } else { if (op == u6<3>) { //BLTU - uint<32> un1 = cast(arg1, uint<32>); - uint<32> un2 = cast(arg2, uint<32>); + un1 = cast(arg1, uint<32>); + un2 = cast(arg2, uint<32>); if (un1 < un2) { return offpc; } else { return npc; } } else { if (op == u7<3>) { //BGEU - uint<32> un1 = cast(arg1, uint<32>); - uint<32> un2 = cast(arg2, uint<32>); + un1 = cast(arg1, uint<32>); + un2 = cast(arg2, uint<32>); if (un1 >= un2) { return offpc; } else { return npc; } } else { return npc; @@ -85,7 +85,7 @@ def br(pc: int<16>, off:int<16>, op:uint<3>, arg1:int<32>, arg2:int<32>): int<16 def storeMask(off: uint<2>, op: uint<3>): uint<4> { if (op == u0<3>) { //SB - return (u0b0001<4> << off); + return ((1) << (off)); } else { if (op == u1<3>) { //SH uint<2> shamt = off{1:1} ++ u0<1>; @@ -97,10 +97,10 @@ def storeMask(off: uint<2>, op: uint<3>): uint<4> { def maskLoad(data: int<32>, op: uint<3>, start: uint<2>): int<32> { //start == offset in bytes, need to multiply by 8 - uint<5> boff = start ++ u0<3>; - int<32> tmp = data >> boff; - uint<8> bdata = cast(tmp, uint<8>); - uint<16> hdata = cast(tmp, uint<16>); + boff = start ++ u0<3>; + tmp = data >> boff; + bdata = cast(tmp, uint<8>); + hdata = cast(tmp, uint<16>); if (op == u0<3>) { //LB return cast(bdata, int<32>); @@ -119,19 +119,19 @@ def maskLoad(data: int<32>, op: uint<3>, start: uint<2>): int<32> { uint<32> zext = cast(hdata, uint<32>); return cast(zext, int<32>); } else { - return 0<32>; + return 0; }}}}} } pipe multi_stg_div(num: uint<32>, denom: uint<32>, quot: uint<32>, acc: uint<32>, cnt: uint<5>, retQuot: bool)[]: uint<32> { - uint<32> tmp = acc{30:0} ++ num{31:31}; - uint<32> na = (tmp >= denom) ? (tmp - denom) : (tmp); - uint<32> nq = (tmp >= denom) ? ((quot << 1){31:1} ++ u1<1>) : (quot << 1); - uint<32> nnum = num << 1; - if (cnt == u31<5>) { + tmp = acc{30:0} ++ num{31:31}; + na = (tmp >= denom) ? (tmp - denom) : (tmp); + nq = (tmp >= denom) ? ((quot << 1){31:1} ++ 1) : (quot << 1); + nnum = num << 1; + if (cnt == 31) { output( (retQuot) ? nq : na ); } else { - call multi_stg_div(nnum, denom, nq, na, cnt + u1<5>, retQuot); + call multi_stg_div(nnum, denom, nq, na, cnt + 1, retQuot); } } @@ -139,9 +139,9 @@ pipe multi_stg_div(num: uint<32>, denom: uint<32>, quot: uint<32>, acc: uint<32> pipe cpu(pc: int<16>)[rf: int<32>[5](FAQueue), imem: int<32>[16](Queue), dmem: int<32>[16](FAQueue), div: multi_stg_div]: bool { spec_check(); start(imem); - uint<16> pcaddr = cast(pc, uint<16>); + pcaddr = cast(pc, uint<16>); acquire(imem[pcaddr], R); - int<32> insn <- imem[pcaddr]; + insn <- imem[pcaddr]; release(imem[pcaddr]); end(imem); s <- speccall cpu(pc + 1<16>); diff --git a/src/test/tests/typecheckTests/type-inference-bit-width-tests.pdl b/src/test/tests/typecheckTests/type-inference-bit-width-tests.pdl index b2a31e6e..e6cd8277 100644 --- a/src/test/tests/typecheckTests/type-inference-bit-width-tests.pdl +++ b/src/test/tests/typecheckTests/type-inference-bit-width-tests.pdl @@ -13,7 +13,7 @@ def helper2(a:bool, b: bool): bool { } pipe test1(input: int<32>)[rf: int<32>[32]] { - a = 6 * 6; + a = 6 * 6<5>; int<32> b = a; int<32> c = a + b; call test1(c); @@ -31,7 +31,7 @@ pipe test2(input: int<32>)[rf: int<32>[32]] { } pipe test3(input: int<32>)[rf: int<32>[32]] { - a = 1; + a = 1<32>; int<32> b = a << 4; c = a << 4; call test3(c); @@ -44,7 +44,7 @@ pipe test4(input: int<32>)[rf: int<32>[32]] { } pipe test5()[] { - a = 6 * 6; + a = 6 * 6<3>; int<6> b = a; int<32> c = helper1(cast(a, int<32>), true, "hi"); call test5();