Skip to content

Commit 2cf6a48

Browse files
committed
Add circuit location statistics for recursion predicates
* Overhaul EDSL location handling; we now generate mlir Locations directly instead of keeping our own SourceLoc structure, and it's relatively easy to add additional context to component construction using ScopedSourceLoc. * Expand --op-stats flag to work on `gen_predicates`; this outputs `...-opstats.txt` files in the output directory for each predicate processed. * --op-stats on `gen_predicates` also outputs encoded cycle counts in `...-encoded-cycles.txt` * Added a bunch more source annotations to get better granularity for recursion predicates.
1 parent 8051576 commit 2cf6a48

File tree

31 files changed

+511
-475
lines changed

31 files changed

+511
-475
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
fetch-depth: 0
3636
- uses: risc0/risc0/.github/actions/rustup@a9d723e29a44563497220a998b5de4e03d9da049
3737
- name: Install cargo-sort
38-
uses: risc0/cargo-install@v1
38+
uses: risc0/cargo-install@9f6037ed331dcf7da101461a20656273fa72abf0
3939
with:
4040
crate: cargo-sort
4141
version: "1.0"

risc0/core/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ cc_library(
88
hdrs = [
99
"elf.h",
1010
"log.h",
11-
"source_loc.h",
1211
"util.h",
1312
],
1413
visibility = ["//visibility:public"],

risc0/core/source_loc.h

Lines changed: 0 additions & 78 deletions
This file was deleted.

zirgen/circuit/keccak/predicates.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ static cl::opt<std::string> keccakIR("keccak-ir",
5555
int main(int argc, char* argv[]) {
5656
llvm::InitLLVM y(argc, argv);
5757
registerEdslCLOptions();
58+
registerOpStatsCLOptions();
5859
llvm::cl::ParseCommandLineOptions(argc, argv, "keccak predicates");
5960

6061
Module module;

zirgen/circuit/predicates/gen_predicates.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "zirgen/circuit/verify/wrap_rv32im.h"
2020
#include "zirgen/circuit/verify/wrap_zirgen.h"
2121
#include "zirgen/compiler/codegen/codegen.h"
22+
#include "zirgen/compiler/stats/OpStats.h"
2223

2324
using namespace zirgen;
2425
using namespace zirgen::verify;
@@ -202,6 +203,7 @@ static cl::opt<std::string>
202203
int main(int argc, char* argv[]) {
203204
llvm::InitLLVM y(argc, argv);
204205
registerEdslCLOptions();
206+
registerOpStatsCLOptions();
205207
llvm::cl::ParseCommandLineOptions(argc, argv, "gen_predicates edsl");
206208

207209
Module module;

zirgen/circuit/recursion/encode.cpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,10 @@ struct Instructions {
136136
uint64_t padShaCountConst;
137137
uint64_t shaRngConsts;
138138
llvm::StringMap<uint64_t> tagConsts;
139+
EncodeStats* stats = nullptr;
139140

140-
Instructions(HashType hashType) : hashType(hashType), nextOut(1), microUsed(0) {
141+
Instructions(HashType hashType, EncodeStats* stats)
142+
: hashType(hashType), nextOut(1), microUsed(0), stats(stats) {
141143
addMacro(/*outs=*/0, MacroOpcode::WOM_INIT);
142144
// Make: [0, 1, 0, 0], [0, 0, 1, 0], and [0, 0, 0, 1]
143145
fp4Rot1 = addConst(0, 1);
@@ -224,6 +226,15 @@ struct Instructions {
224226
if (microUsed == 3) {
225227
microUsed = 0;
226228
}
229+
230+
if (stats) {
231+
if (out) {
232+
ScopedLocation loc(out.getLoc());
233+
stats->locs[currentLoc()]++;
234+
} else {
235+
stats->locs[currentLoc()]++;
236+
}
237+
}
227238
return outId;
228239
}
229240

@@ -245,6 +256,8 @@ struct Instructions {
245256
uint64_t writeAddr = nextOut;
246257
data.back().writeAddr = writeAddr;
247258
nextOut += outs;
259+
if (stats)
260+
stats->locs[currentLoc()] += 3;
248261
return writeAddr;
249262
}
250263

@@ -263,6 +276,8 @@ struct Instructions {
263276
uint64_t writeAddr = nextOut;
264277
data.back().writeAddr = writeAddr;
265278
nextOut += 1; // Write exactly one thing (evaluated point)
279+
if (stats)
280+
stats->locs[currentLoc()] += 3;
266281
return writeAddr;
267282
}
268283

@@ -286,6 +301,8 @@ struct Instructions {
286301
data.back().data.poseidon2Mem.inputs[i] = inputs[i];
287302
}
288303
data.back().writeAddr = nextOut;
304+
if (stats)
305+
stats->locs[currentLoc()] += 3;
289306
}
290307

291308
void addPoseidon2Full(uint64_t cycle) {
@@ -295,6 +312,8 @@ struct Instructions {
295312
data.back().opType = OpType::POSEIDON2_FULL;
296313
data.back().data.poseidon2Full.cycle = cycle;
297314
data.back().writeAddr = nextOut;
315+
if (stats)
316+
stats->locs[currentLoc()] += 3;
298317
}
299318

300319
void addPoseidon2Partial() {
@@ -303,6 +322,8 @@ struct Instructions {
303322
data.emplace_back();
304323
data.back().opType = OpType::POSEIDON2_PARTIAL;
305324
data.back().writeAddr = nextOut;
325+
if (stats)
326+
stats->locs[currentLoc()] += 3;
306327
}
307328

308329
uint64_t addPoseidon2Store(uint64_t doMont, uint64_t group) {
@@ -315,30 +336,36 @@ struct Instructions {
315336
uint64_t writeAddr = nextOut;
316337
data.back().writeAddr = writeAddr;
317338
nextOut += 8;
339+
if (stats)
340+
stats->locs[currentLoc()] += 3;
318341
return writeAddr;
319342
}
320343

321344
void doShaInit() {
345+
ScopedLocation loc;
322346
for (size_t i = 0; i < 4; i++) {
323347
shaUsed++;
324348
addMacro(/*outs=*/0, MacroOpcode::SHA_INIT, shaInit[3 - i], shaInit[3 - i + 4]);
325349
}
326350
}
327351

328352
void doShaLoad(llvm::ArrayRef<uint64_t> values, uint64_t subtype) {
353+
ScopedLocation loc;
329354
for (size_t i = 0; i < 16; i++) {
330355
shaUsed++;
331356
addMacro(/*outs=*/0, MacroOpcode::SHA_LOAD, values[i], shaK[i], subtype);
332357
}
333358
}
334359

335360
void doShaMix() {
361+
ScopedLocation loc;
336362
for (size_t i = 0; i < 48; i++) {
337363
shaUsed++;
338364
addMacro(/*outs=*/0, MacroOpcode::SHA_MIX, 0, shaK[16 + i]);
339365
}
340366
}
341367
uint64_t doShaFini() {
368+
ScopedLocation loc;
342369
uint64_t out = nextOut;
343370
for (size_t i = 0; i < 4; i++) {
344371
shaUsed++;
@@ -349,6 +376,7 @@ struct Instructions {
349376
}
350377

351378
uint64_t doSha(llvm::ArrayRef<uint64_t> values, uint64_t subtype) {
379+
ScopedLocation loc;
352380
doShaInit();
353381
uint64_t ret = 0;
354382
for (size_t i = 0; i < values.size() / 16; i++) {
@@ -360,6 +388,7 @@ struct Instructions {
360388
}
361389

362390
uint64_t doShaFold(uint64_t lhs, uint64_t rhs) {
391+
ScopedLocation loc;
363392
std::vector<uint64_t> ids(16);
364393
for (size_t i = 0; i < 8; i++) {
365394
ids[i] = lhs + i;
@@ -369,6 +398,7 @@ struct Instructions {
369398
}
370399

371400
uint64_t doIntoDigestShaBytes(llvm::ArrayRef<uint64_t> bytes) {
401+
ScopedLocation loc;
372402
// We keep things in low / high form right until the end so that the final adds are
373403
// all contiguous since all the 'digest' stuff assumes digests are always contiguous.
374404
std::vector<uint64_t> low;
@@ -388,6 +418,7 @@ struct Instructions {
388418
}
389419

390420
uint64_t doIntoDigestShaWords(llvm::ArrayRef<uint64_t> words) {
421+
ScopedLocation loc;
391422
// We keep things in low / high form right until the end so that the final adds are
392423
// all contiguous since all the 'digest' stuff assumes digests are always contiguous.
393424
std::vector<uint64_t> low;
@@ -404,6 +435,7 @@ struct Instructions {
404435
}
405436

406437
uint64_t doShaTag(llvm::StringRef tag) {
438+
ScopedLocation loc;
407439
if (tagConsts.count(tag)) {
408440
return tagConsts.find(tag)->second;
409441
} else {
@@ -422,6 +454,7 @@ struct Instructions {
422454
llvm::ArrayRef<uint64_t> digests,
423455
llvm::ArrayRef<DigestKind> digestTypes,
424456
llvm::ArrayRef<uint64_t> vals) {
457+
ScopedLocation loc;
425458
std::vector<uint64_t> words;
426459

427460
for (size_t i = 0; i < 8; i++) {
@@ -472,6 +505,7 @@ struct Instructions {
472505
// representation, with 16 bits in each of the two low components of an
473506
// extension field element.
474507
void taggedStructPushVals(std::vector<uint64_t>& words, llvm::ArrayRef<uint64_t> vals) {
508+
ScopedLocation loc;
475509
// Get low 16 bits of each value (done in a loop for better packing)
476510
std::vector<uint64_t> lowVals;
477511
for (size_t i = 0; i < vals.size(); i++) {
@@ -489,6 +523,7 @@ struct Instructions {
489523
}
490524

491525
std::pair<uint64_t, std::vector<uint64_t>> doHashCheckedBytes(uint64_t evalPt, uint64_t count) {
526+
ScopedLocation loc;
492527
if (!count) {
493528
// Special case for 0 outputs
494529
return {doPoseidon2({}), {}};
@@ -519,6 +554,7 @@ struct Instructions {
519554

520555
std::tuple<uint64_t, uint64_t, std::vector<uint64_t>> doHashCheckedBytesPublic(uint64_t evalPt,
521556
uint64_t count) {
557+
ScopedLocation loc;
522558
if (!count)
523559
throw std::runtime_error("Cannont publically hash empty checked bytes");
524560

@@ -570,6 +606,7 @@ struct Instructions {
570606
}
571607

572608
uint64_t doPoseidon2(llvm::ArrayRef<uint64_t> values) {
609+
ScopedLocation loc;
573610
if (values.empty()) {
574611
auto psuite = poseidon2HashSuite();
575612
auto hashVal = psuite->hash(nullptr, 0);
@@ -634,6 +671,7 @@ struct Instructions {
634671
}
635672

636673
uint64_t doIntoDigestPoseidon2(llvm::ArrayRef<uint64_t> words) {
674+
ScopedLocation loc;
637675
// Do pointless adds to make all the words land in sequential spots
638676
uint64_t ret = nextOut;
639677
size_t pad = words.size() / 8 - 1;
@@ -646,6 +684,7 @@ struct Instructions {
646684
}
647685

648686
void addInst(Operation& op) {
687+
ScopedLocation loc(op.getLoc());
649688
TypeSwitch<Operation*>(&op)
650689
.Case<Zll::ExternOp>([&](Zll::ExternOp op) {
651690
if (op.getName() == "write") {
@@ -984,6 +1023,7 @@ uint64_t ShaRng::generateFp(Instructions& insts) {
9841023
}
9851024

9861025
void ShaRng::mix(Instructions& insts, uint64_t digest) {
1026+
ScopedLocation loc;
9871027
uint64_t xorOut = insts.nextOut;
9881028
for (size_t i = 0; i < 8; i++) {
9891029
// Xors and returns 2 shorts: [a, b, 0, 0] ^ [c, d, 0, 0] -> [a ^ c, b ^ d, 0, 0]
@@ -1022,6 +1062,7 @@ uint64_t Poseidon2Rng::generateFp(Instructions& insts) {
10221062
}
10231063

10241064
void Poseidon2Rng::mix(Instructions& insts, uint64_t digest) {
1065+
ScopedLocation loc;
10251066
if (stateUsed != 0) {
10261067
stateUsed = 0;
10271068
mix(insts, 0);
@@ -1069,6 +1110,7 @@ void Poseidon2Rng::mix(Instructions& insts, uint64_t digest) {
10691110
}
10701111

10711112
void MixedPoseidon2ShaRng::mix(Instructions& insts, uint64_t digest) {
1113+
ScopedLocation loc;
10721114
// For each element of the Poseidon2 hash, we convert it to a form usable by SHA.
10731115
// This is done in stages so that macro ops are grouped together for efficiency
10741116
// First, we 'and' things by 0xffff
@@ -1104,7 +1146,7 @@ std::vector<uint32_t> encode(HashType hashType,
11041146
mlir::Block* block,
11051147
llvm::DenseMap<Value, uint64_t>* toIdReturn,
11061148
EncodeStats* stats) {
1107-
Instructions insts(hashType);
1149+
Instructions insts(hashType, stats);
11081150
for (Operation& op : block->without_terminator()) {
11091151
insts.addInst(op);
11101152
}

zirgen/circuit/recursion/encode.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ struct EncodeStats {
3131
size_t totCycles = 0;
3232
size_t shaCycles = 0;
3333
size_t poseidon2Cycles = 0;
34+
35+
// Locations and the number of micro cycles used (= number of macro cycles * 3).
36+
llvm::DenseMap<mlir::Location, /*cycles=*/size_t> locs;
3437
};
3538

3639
std::vector<uint32_t> encode(HashType hashType,

zirgen/circuit/recursion/recursion.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
using namespace zirgen;
2222
using namespace zirgen::recursion;
23-
using namespace risc0;
2423
using namespace mlir;
2524

2625
int main(int argc, char* argv[]) {

zirgen/circuit/rv32im/v1/edsl/rv32im.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
using namespace zirgen;
2525
using namespace zirgen::rv32im_v1;
26-
using namespace risc0;
2726
using namespace mlir;
2827

2928
int main(int argc, char* argv[]) {

0 commit comments

Comments
 (0)