Skip to content

Commit

Permalink
some fixes and testcases
Browse files Browse the repository at this point in the history
  • Loading branch information
yy665 committed Oct 16, 2023
1 parent ef2f70f commit 5176c08
Show file tree
Hide file tree
Showing 6 changed files with 1,524 additions and 26 deletions.
4 changes: 2 additions & 2 deletions src/main/scala/pipedsl/typechecker/TypeInferenceWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,9 @@ object TypeInferenceWrapper
val (fixed_except, final_subst) = m.except_blk match
{
case ExceptEmpty() => (ExceptEmpty(), subst1)
//TODO IMPROVE PRECISION OF OUT_ENV, SUBST
//TODO - EXN: IMPROVE PRECISION OF OUT_ENV, SUBST
case ExceptFull(args, c) =>
checkCommand(c, args.foldLeft(out_env)((env, id) => env.add(id, id.typ.get).asInstanceOf[TypeEnv]), subst) |>
checkCommand(c, args.foldLeft(pipeEnv)((env, id) => env.add(id, id.typ.get).asInstanceOf[TypeEnv]), subst) |>
{case (cmd, _, s) => (ExceptFull(args, cmd), s)}
}
val hash = mutable.HashMap.from(final_subst)
Expand Down
370 changes: 370 additions & 0 deletions src/test/tests/risc-pipe/risc-pipe-spec-exn-1.pdl
Original file line number Diff line number Diff line change
@@ -0,0 +1,370 @@
// risc-pipe-spec-write.pdl
extern BHT<T> {
method req(pc: int<T>, skip: int<T>, take: int<T>): int<T>;
method upd(pc: int<T>, taken: bool): ();
}

def mul(arg1: int<32>, arg2: int<32>, op: uint<3>): int<32> {
uint<32> mag1 = cast(mag(arg1), uint<32>);
uint<32> mag2 = cast(mag(arg2), uint<32>);
//MULHU => positive sign always
int<32> s1 = (op == u3<3>) ? 1<32> : sign(arg1);
//MULHU/MULHSU => positive sign always
int<32> s2 = (op >= u2<3>) ? 1<32> : sign(arg2);
int<64> magRes = cast((mag1 * mag2), int<64>);
int<64> m = (s1 == s2) ? (magRes) : -(magRes);
if (op == u0<3>) { //MUL
return m{31:0};
} else {
return m{63:32};
}
}

def alu(arg1: int<32>, arg2: int<32>, op: uint<3>, flip: bool): int<32> {
uint<5> shamt = cast(arg2{4:0}, uint<5>);
if (op == u0<3>) { //000 == ADD , flip == sub
if (!flip) {
return arg1 + arg2;
} else {
return arg1 - arg2;
}
} else {
if (op == u1<3>) { //001 == SLL
return arg1 << shamt;
} else {
if (op == u2<3>) { //010 == SLT
return (arg1 < arg2) ? 1<32> : 0<32>;
} else {
if (op == u3<3>) { //011 == SLTU
uint<32> un1 = cast(arg1, uint<32>);
uint<32> un2 = cast(arg2, uint<32>);
return (un1 < un2) ? 1<32> : 0<32>;
} else {
if (op == u4<3>) { //100 == XOR
return arg1 ^ arg2;
} else {
if (op == u5<3>) { //101 == SRL / SRA
if (!flip) {
return cast((cast(arg1, uint<32>)) >> shamt, int<32>); //SRL
} else {
return arg1 >> shamt; //SRA
}
} else {
if (op == u6<3>) { //110 == OR
return arg1 | arg2;
} else { //111 == AND
return arg1 & arg2;
}}}}}}}

}

//true if taking branch
def br(op:uint<3>, arg1:int<32>, arg2:int<32>): bool {
if (op == u0<3>) { //BEQ
return (arg1 == arg2);
} else {
if (op == u1<3>) { //BNE
return (arg1 != arg2);
} else {
if (op == u4<3>) { //BLT
return (arg1 < arg2);
} else {
if (op == u5<3>) { //BGE
return (arg1 >= arg2);
} else {
if (op == u6<3>) { //BLTU
uint<32> un1 = cast(arg1, uint<32>);
uint<32> un2 = cast(arg2, uint<32>);
return (un1 < un2);
} else {
if (op == u7<3>) { //BGEU
uint<32> un1 = cast(arg1, uint<32>);
uint<32> un2 = cast(arg2, uint<32>);
return(un1 >= un2);
} else {
return false;
}}}}}}
}


def storeMask(off: uint<2>, op: uint<3>): uint<4> {
if (op == u0<3>) { //SB
return (u0b0001<4> << off);
} else {
if (op == u1<3>) { //SH
uint<2> shamt = off{1:1} ++ u0<1>;
return (u0b0011<4> << shamt);
} else { //SW
return u0b1111<4>;
}}
}

def maskLoad(data: int<32>, op: uint<3>, start: uint<2>): int<32> {
//start == offset in bytes, need to multiply by 8
uint<5> boff = start ++ u0<3>;
int<32> tmp = data >> boff;
uint<8> bdata = cast(tmp, uint<8>);
uint<16> hdata = cast(tmp, uint<16>);

if (op == u0<3>) { //LB
return cast(bdata, int<32>);
} else {
if (op == u1<3>) { //LH
return cast(hdata, int<32>);
} else {
if (op == u2<3>) { //LW
return data;
} else {
if (op == u4<3>) { //LBU
uint<32> zext = cast(bdata, uint<32>);
return cast(zext, int<32>);
} else {
if (op == u5<3>) { //LHU
uint<32> zext = cast(hdata, uint<32>);
return cast(zext, int<32>);
} else {
return 0<32>;
}}}}}
}

pipe multi_stg_div(num: uint<32>, denom: uint<32>, quot: uint<32>, acc: uint<32>, cnt: uint<5>, retQuot: bool)[]: uint<32> {
uint<32> tmp = acc{30:0} ++ num{31:31};
uint<32> na = (tmp >= denom) ? (tmp - denom) : (tmp);
uint<32> nq = (tmp >= denom) ? ((quot << 1){31:1} ++ u1<1>) : (quot << 1);
uint<32> nnum = num << 1;
bool done = (cnt == u31<5>);
if (done) {
output( (retQuot) ? nq : na );
} else {
call multi_stg_div(nnum, denom, nq, na, cnt + u1<5>, retQuot);
}
}


exn-pipe cpu(pc: int<16>)[rf: int<32>[5]<c2,s>(CheckpointRF), imem: int<32>[16]<a,a>, dmem: int<32>[16]<a,a>, div: multi_stg_div, bht: BHT, csrf: int<32>[2]<c,s>]: bool {
spec_check();
start(imem);
uint<16> pcaddr = cast(pc, uint<16>);
int<32> insn <- imem[pcaddr];
end(imem);
s <- speccall cpu(pc + 1<16>);
---
//This OPCODE is J Self and thus we're using it to signal termination
bool done = insn == 0x0000006f<32>;
int<7> opcode = insn{6:0};
uint<5> rs1 = cast(insn{19:15}, uint<5>);
uint<5> rs2 = cast(insn{24:20}, uint<5>);
uint<5> rd = cast(insn{11:7}, uint<5>);
uint<7> funct7 = cast(insn{31:25}, uint<7>);
uint<3> funct3 = cast(insn{14:12}, uint<3>);
int<1> flipBit = insn{30:30};
int<32> immI = cast(insn{31:20}, int<32>);
int<32> immS = cast((insn{31:25} ++ insn{11:7}), int<32>);
int<13> immBTmp = insn{31:31} ++ insn{7:7} ++ insn{30:25} ++ insn{11:8} ++ 0<1>;
int<16> immB = cast(immBTmp, int<16>);
int<21> immJTmp = insn{31:31} ++ insn{19:12} ++ insn{20:20} ++ insn{30:21} ++ 0<1>;
int<32> immJ = cast(immJTmp, int<32>);
int<12> immJRTmp = insn{31:20};
int<16> immJR = cast(immJRTmp, int<16>);
int<32> immU = insn{31:12} ++ 0<12>;
uint<3> doAdd = u0<3>;
bool isOpImm = opcode == 0b0010011<7>;
bool flip = (!isOpImm) && (flipBit == 1<1>);
bool isLui = opcode == 0b0110111<7>;
bool isAui = opcode == 0b0010111<7>;
bool isOp = opcode == 0b0110011<7>;
bool isJal = opcode == 0b1101111<7>;
bool isJalr = opcode == 0b1100111<7>;
bool isBranch = opcode == 0b1100011<7>;
bool isStore = opcode == 0b0100011<7>;
bool isLoad = opcode == 0b0000011<7>;
bool isMDiv = (funct7 == u1<7>) && isOp;
bool isDiv = isMDiv && (funct3 >= u4<3>);
bool isMul = isMDiv && (funct3 < u4<3>);
bool isSystem = opcode == 0b1110011<7>;
bool isEcall = isSystem && (immI == 0<32>);
bool isMret = isSystem && (funct7 == u24<7>) && (rs2 == u2<5>);
bool needrs1 = !isJal;
bool needrs2 = isOp || isBranch || isStore || isJalr;
bool writerd = (rd != u0<5>) && (isOp || isOpImm || isLoad || isJal || isJalr || isLui || isAui);
spec_check();
bool notBranch = (!isBranch) && (!isJal) && (!isJalr);
if (!done) {
if (notBranch) {
s2 <- s;
} else {
if (isBranch) {
s2 <- update(s, bht.req(pc, immB, 1<16>));
} else {
s2 <- s;
invalidate(s);
}}} else { s2 <- s; invalidate(s); }
if (isEcall){
throw(u8<4>);
}
if (isMret){
throw(u3<4>);
}
start(rf);
if (!done && (notBranch || isBranch)) { checkpoint(rf); }
if (needrs1) {
reserve(rf[rs1], R);
}
if (needrs2) {
reserve(rf[rs2], R);
}
if (writerd) {
reserve(rf[rd], W);
}
end(rf);
---
spec_check();
if (needrs1) {
block(rf[rs1]);
int<32> rf1 = rf[rs1];
release(rf[rs1]);
} else {
int<32> rf1 = 0<32>;
}
if (needrs2) {
block(rf[rs2]);
int<32> rf2 = rf[rs2];
release(rf[rs2]);
} else {
int<32> rf2 = 0<32>;
}
bool take = br(funct3, rf1, rf2);
if (isBranch) {
//divide by 4 b/c we count instructions not bytes
int<16> offpc = pc + (immB >> 2);
int<16> npc = (take) ? (offpc) : (pc + 1<16>);
} else {
if (isJal) {
//divide by 4 since it counts bytes instead of insns
int<32> npc32 = cast(pc, int<32>) + (immJ >> 2);
int<16> npc = npc32{15:0};
} else {
if (isJalr) {
int<16> npc = (rf1{15:0} + immJR) >> 2;
} else {
int<16> npc = pc + 1<16>;
}}}
int<32> alu_arg1 = (isAui) ? ((0<16> ++ pc) << 2) : rf1;
int<32> alu_arg2 = (isAui) ? immU : ((isStore) ? immS : ((isOpImm || isLoad) ? immI : rf2));
bool alu_flip = (isStore || isLoad || isAui) ? false : flip;
uint<3> alu_funct3 = (isStore || isLoad || isAui) ? doAdd : funct3;
int<32> alu_res = alu(alu_arg1, alu_arg2, alu_funct3, alu_flip);
int<16> tmppc = pc + 1<16>;
int<32> linkpc = 0<16> ++ (tmppc << 2);
int<32> mulres = mul(rf1, rf2, funct3);
if (writerd && (!isLoad) && (!isDiv)) {
block(rf[rd]);
int<32> rddata = (isLui) ? immU : ((isMul) ? mulres : ((isJal || isJalr) ? linkpc : alu_res));
rf[rd] <- rddata;
} else {
int<32> rddata = 0<32>;
}
---
spec_barrier(); //delayed just to force writes to happen speculatively
if (!done) {
if (!notBranch) {
if (isBranch) {
verify(s2, npc) { bht.upd(pc, take) };
} else {
call cpu(npc);
}
} else {
verify(s, pc + 1<16>);
}
}

if (isDiv) {
int<32> sdividend = sign(rf1);
//For REM, ignore sign of divisor
int<32> sdivisor = (funct3 == u6<3>) ? 1<32> : sign(rf2);
bool isSignedDiv = ((funct3 == u4<3>) || (funct3 == u6<3>));
uint<32> dividend = (isSignedDiv) ? cast(mag(rf1), uint<32>) : cast(rf1, uint<32>);
uint<32> divisor = (isSignedDiv) ? cast(mag(rf2), uint<32>) : cast(rf2, uint<32>);
bool retQuot = funct3 <= u5<3>;
bool invertRes <- isSignedDiv && (sdividend != sdivisor);
uint<32> udivout <- call div(dividend, divisor, u0<32>, u0<32>, u0<5>, retQuot);
} else {
bool invertRes <- false;
uint<32> udivout <- u0<32>;
}
uint<32> tmpaddr = cast(alu_res, uint<32>);
uint<16> memaddr = (tmpaddr >> 2){15:0};
uint<2> boff = cast(alu_res{1:0}, uint<2>);
start(dmem);
split {
case: (isLoad) {
uint<16> raddr = memaddr;
int<32> wdata <- dmem[raddr];
}
case: (isStore) {
uint<16> waddr = memaddr;
//use bottom bits of data and place in correct offset
//shift by boff*8
uint<5> nboff = boff ++ u0<3>;
dmem[waddr, storeMask(boff, funct3)] <- (rf2 << nboff);
int<32> wdata <- 0<32>;
}
default: {
int<32> wdata <- 0<32>;
}
}
end(dmem);
---
print("PC: %h", pc << 2);
print("INSN: %h", insn);
if (writerd) {
if (isLoad) {
block(rf[rd]);
int<32> insnout = maskLoad(wdata, funct3, boff);
rf[rd] <- insnout;
} else {
if (isDiv) {
block(rf[rd]);
int<32> divout = (invertRes) ? -(cast(udivout, int<32>)) : cast(udivout, int<32>);
int<32> insnout = divout;
rf[rd] <- insnout;
} else {
int<32> insnout = rddata;
}}
print("Writing %d to r%d", insnout, rd);
}
---
commit:
if (writerd) {
release(rf[rd]);
}
if (done) { output(true); }
except(exncode: uint<4>): // mepc = csrf[0], mtvec = csrf[1], mcause = csrf[2]
print("exception at PC == %d! with exncode: %d", pc, exncode);
start(csrf);
if (exncode == u8<4>){
csrf[0] <- cast(pc, int<32>); // write to mepc
int<32> csrret = csrf[1]; // read from mtvec
int<16> npc = csrret{15:0};
} else { if (exncode == u3<4>){
int<32> csrret = csrf[0]; // read from mepc
int<16> npc = csrret{15:0} + 1;
} else { int<16> npc = pc + 1; }}
---
csrf[2] <- cast(exncode, int<32>); // write to mcause
end(csrf);
---
print("GOTO PC: %d", npc);
call cpu(npc);
}

circuit {
ti = memory(int<32>, 16);
td = memory(int<32>, 16);
rf = rflock CheckpointRF(int<32>, 5, 64);
csrf = regfile(int<32>, 2);
div = new multi_stg_div[];
b = new BHT<16>[](4);
c = new cpu[rf, ti, td, div, b, csrf];
call c(0<16>);
}
Loading

0 comments on commit 5176c08

Please sign in to comment.