Skip to content

Commit

Permalink
fix: counts and small shape fixing (#2007)
Browse files Browse the repository at this point in the history
  • Loading branch information
tamirhemo authored Jan 31, 2025
1 parent 1721066 commit bd1e88e
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 39 deletions.
5 changes: 1 addition & 4 deletions crates/core/machine/src/riscv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,10 +448,7 @@ impl<F: PrimeField32> RiscvAir<F> {
(RiscvAirId::Auipc, record.auipc_events.len()),
(RiscvAirId::Branch, record.branch_events.len()),
(RiscvAirId::Jump, record.jump_events.len()),
(
RiscvAirId::Global,
2 * record.get_local_mem_events().count() + record.syscall_events.len(),
),
(RiscvAirId::Global, record.global_interaction_events.len()),
(RiscvAirId::SyscallCore, record.syscall_events.len()),
(RiscvAirId::SyscallInstrs, record.syscall_events.len()),
]
Expand Down
131 changes: 96 additions & 35 deletions crates/core/machine/src/utils/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,57 +250,118 @@ where

// We combine the memory init/finalize events if they are "small"
// and would affect performance.
let last_record = if done
let mut shape_fixed_records = if done
&& num_cycles < 1 << 21
&& deferred.global_memory_initialize_events.len()
< opts.split_opts.combine_memory_threshold
&& deferred.global_memory_finalize_events.len()
< opts.split_opts.combine_memory_threshold
{
records.last_mut()
let mut records_clone = records.clone();
let last_record = records_clone.last_mut();
// See if any deferred shards are ready to be committed to.
let mut deferred =
deferred.split(done, last_record, opts.split_opts);
log::info!("deferred {} records", deferred.len());

// Update the public values & prover state for the shards which do not
// contain "cpu events" before committing to them.
if !done {
state.execution_shard += 1;
}
for record in deferred.iter_mut() {
state.shard += 1;
state.previous_init_addr_bits =
record.public_values.previous_init_addr_bits;
state.last_init_addr_bits =
record.public_values.last_init_addr_bits;
state.previous_finalize_addr_bits =
record.public_values.previous_finalize_addr_bits;
state.last_finalize_addr_bits =
record.public_values.last_finalize_addr_bits;
state.start_pc = state.next_pc;
record.public_values = *state;
}
records_clone.append(&mut deferred);

// Generate the dependencies.
tracing::debug_span!("generate dependencies", index).in_scope(
|| {
prover.machine().generate_dependencies(
&mut records_clone,
&opts,
None,
);
},
);

// Let another worker update the state.
record_gen_sync.advance_turn();

// Fix the shape of the records.
let mut fixed_shape = true;
if let Some(shape_config) = shape_config {
for record in records_clone.iter_mut() {
if shape_config.fix_shape(record).is_err() {
fixed_shape = false;
}
}
}
fixed_shape.then_some(records_clone)
} else {
None
};

// See if any deferred shards are ready to be committed to.
let mut deferred = deferred.split(done, last_record, opts.split_opts);
log::info!("deferred {} records", deferred.len());
if shape_fixed_records.is_none() {
// See if any deferred shards are ready to be committed to.
let mut deferred = deferred.split(done, None, opts.split_opts);
log::info!("deferred {} records", deferred.len());

// Update the public values & prover state for the shards which do not
// contain "cpu events" before committing to them.
if !done {
state.execution_shard += 1;
}
for record in deferred.iter_mut() {
state.shard += 1;
state.previous_init_addr_bits =
record.public_values.previous_init_addr_bits;
state.last_init_addr_bits =
record.public_values.last_init_addr_bits;
state.previous_finalize_addr_bits =
record.public_values.previous_finalize_addr_bits;
state.last_finalize_addr_bits =
record.public_values.last_finalize_addr_bits;
state.start_pc = state.next_pc;
record.public_values = *state;
}
records.append(&mut deferred);

// Generate the dependencies.
tracing::debug_span!("generate dependencies", index).in_scope(|| {
prover.machine().generate_dependencies(&mut records, &opts, None);
});
// Update the public values & prover state for the shards which do not
// contain "cpu events" before committing to them.
if !done {
state.execution_shard += 1;
}
for record in deferred.iter_mut() {
state.shard += 1;
state.previous_init_addr_bits =
record.public_values.previous_init_addr_bits;
state.last_init_addr_bits =
record.public_values.last_init_addr_bits;
state.previous_finalize_addr_bits =
record.public_values.previous_finalize_addr_bits;
state.last_finalize_addr_bits =
record.public_values.last_finalize_addr_bits;
state.start_pc = state.next_pc;
record.public_values = *state;
}
records.append(&mut deferred);

// Generate the dependencies.
tracing::debug_span!("generate dependencies", index).in_scope(
|| {
prover.machine().generate_dependencies(
&mut records,
&opts,
None,
);
},
);

// Let another worker update the state.
record_gen_sync.advance_turn();
// Let another worker update the state.
record_gen_sync.advance_turn();

// Fix the shape of the records.
if let Some(shape_config) = shape_config {
for record in records.iter_mut() {
shape_config.fix_shape(record).unwrap();
// Fix the shape of the records.
if let Some(shape_config) = shape_config {
for record in records.iter_mut() {
shape_config.fix_shape(record).unwrap();
}
}
shape_fixed_records = Some(records);
}

let mut records = shape_fixed_records.unwrap();

// Send the shapes to the channel, if necessary.
for record in records.iter() {
let mut heights = vec![];
Expand Down

0 comments on commit bd1e88e

Please sign in to comment.