Skip to content

Commit 19b33a0

Browse files
Use cycle counter in Keccak circuit for provable determinism
1 parent f1f6894 commit 19b33a0

File tree

5 files changed

+50
-9
lines changed

5 files changed

+50
-9
lines changed
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: true
2+
3+
import is_zero;
4+
5+
extern GetCycle(): Val;
6+
7+
// RISC Zero STARKs are equivalent up to rotation of the trace. This component
8+
// counts cycles from zero up to the power of two, with constraints that
9+
// guarantee a unique "cycle zero."
10+
11+
// It uses a global that stores the total number of cycles, which the verifier
12+
// should check matches the intended trace length. If the global doesn't match,
13+
// the verifier should reject, so assume that it matches. Then we either have a
14+
// cycle zero or we don't; if we do, then the next cycle must be 1, then 2, and
15+
// so on because of the constraint cycle = cycle@1 + 1. After total_cycles, we
16+
// have gone over the whole trace, which means the cycle before cycle 0 must be
17+
// total_cycles - 1. If we don't, then we always have that cycle = cycle@1. But
18+
// because the trace is cyclic, the cycle number must go down on some cycle, so
19+
// a constraint must have been violated.
20+
component CycleCounter() {
21+
global total_cycles : NondetReg;
22+
23+
cycle := NondetReg(GetCycle());
24+
public is_first_cycle := IsZero(cycle);
25+
26+
[is_first_cycle, 1-is_first_cycle] -> ({
27+
// First cycle; previous cycle should be the last cycle.
28+
cycle@1 = total_cycles - 1;
29+
}, {
30+
// Not first cycle; cycle number should advance by one for every row.
31+
cycle = cycle@1 + 1;
32+
});
33+
cycle
34+
}

zirgen/circuit/keccak/predicates.cpp

+11-4
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,14 @@ template <typename Func>
2727
void addZirgenLift(Module& module, const std::string name, const std::string path, Func func) {
2828
auto circuit = getInterfaceZirgen(module.getModule().getContext(), path);
2929
for (size_t po2 = 14; po2 < 19; ++po2) {
30+
size_t totalCycles = 1 << po2;
3031
module.addFunc<3>(name + "_" + std::to_string(po2),
3132
{gbuf(recursion::kOutSize), ioparg(), ioparg()},
3233
[&](Buffer out, ReadIopVal rootIop, ReadIopVal zirgenSeal) {
3334
DigestVal root = rootIop.readDigests(1)[0];
3435
VerifyInfo info = zirgen::verify::verify(zirgenSeal, po2, *circuit);
3536
llvm::ArrayRef inStream(info.out);
36-
DigestVal outData = func(inStream);
37+
DigestVal outData = func(inStream, totalCycles);
3738
std::vector<Val> outStream;
3839
writeSha(outData, outStream);
3940
doExtern("write", "", 0, outStream);
@@ -57,9 +58,15 @@ int main(int argc, char* argv[]) {
5758
llvm::cl::ParseCommandLineOptions(argc, argv, "keccak predicates");
5859

5960
Module module;
60-
addZirgenLift(module, "keccak_lift", keccakIR.getValue(), [](llvm::ArrayRef<Val>& inStream) {
61-
return readSha(inStream);
62-
});
61+
addZirgenLift(module,
62+
"keccak_lift",
63+
keccakIR.getValue(),
64+
[](llvm::ArrayRef<Val>& inStream, size_t expectedTotalCycles) {
65+
auto sha = readSha(inStream);
66+
Val totalCycles = readVal(inStream);
67+
eq(totalCycles, expectedTotalCycles);
68+
return sha;
69+
});
6370

6471
module.optimize();
6572
module.getModule().walk([&](mlir::func::FuncOp func) { zirgen::emitRecursion(outputDir, func); });

zirgen/circuit/keccak/top.zir

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: true
22

3-
import is_zero;
3+
import cycle_counter;
44
import keccak;
55
import sha2;
66

@@ -276,7 +276,6 @@ component ShaNextBlockCycle(back1: TopState) {
276276
topState
277277
}
278278

279-
extern IsFirstCycle(): Val;
280279
extern GetPreimage(idx: Val): Val;
281280
extern NextPreimage(): Val;
282281

@@ -474,10 +473,10 @@ component WrapOneHot(oneHot: OneHot<12>) {
474473
component Top() {
475474
global finalDigest: DigestReg;
476475

477-
isFirst := NondetReg(IsFirstCycle());
476+
cycle := CycleCounter();
478477
cycleMux : WrapOneHot;
479478
controlState : ControlState;
480-
controlState := if (isFirst) {
479+
controlState := if (cycle.is_first_cycle) {
481480
[email protected] = CycleTypeShutdown();
482481
ControlState(CycleTypeInit(), 0, 0, 0)
483482
} else {

zirgen/circuit/predicates/predicates.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ using Zll::DigestKind;
2525

2626
constexpr size_t kMaxInsnCycles = 2000; // TODO(flaub): update this with precise value.
2727

28-
static Val readVal(llvm::ArrayRef<Val>& stream) {
28+
Val readVal(llvm::ArrayRef<Val>& stream) {
2929
assert(stream.size() >= 1);
3030
Val out = stream[0];
3131
stream = stream.drop_front();

zirgen/circuit/predicates/predicates.h

+1
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ ReceiptClaim join(ReceiptClaim in1, ReceiptClaim in2);
118118
ReceiptClaim identity(ReceiptClaim in);
119119
ReceiptClaim resolve(ReceiptClaim cond, Assumption assum, DigestVal tail, DigestVal journal);
120120

121+
Val readVal(llvm::ArrayRef<Val>& stream);
121122
DigestVal readSha(llvm::ArrayRef<Val>& stream, bool longDigest = false);
122123
void writeSha(DigestVal val, std::vector<Val>& stream);
123124

0 commit comments

Comments
 (0)