From bd1e88ef6c832447c806043835f1760baca170ec Mon Sep 17 00:00:00 2001 From: Tamir Hemo Date: Thu, 30 Jan 2025 17:40:04 -0800 Subject: [PATCH] fix: counts and small shape fixing (#2007) --- crates/core/machine/src/riscv/mod.rs | 5 +- crates/core/machine/src/utils/prove.rs | 131 ++++++++++++++++++------- 2 files changed, 97 insertions(+), 39 deletions(-) diff --git a/crates/core/machine/src/riscv/mod.rs b/crates/core/machine/src/riscv/mod.rs index e8164bb48..525e435a3 100644 --- a/crates/core/machine/src/riscv/mod.rs +++ b/crates/core/machine/src/riscv/mod.rs @@ -448,10 +448,7 @@ impl RiscvAir { (RiscvAirId::Auipc, record.auipc_events.len()), (RiscvAirId::Branch, record.branch_events.len()), (RiscvAirId::Jump, record.jump_events.len()), - ( - RiscvAirId::Global, - 2 * record.get_local_mem_events().count() + record.syscall_events.len(), - ), + (RiscvAirId::Global, record.global_interaction_events.len()), (RiscvAirId::SyscallCore, record.syscall_events.len()), (RiscvAirId::SyscallInstrs, record.syscall_events.len()), ] diff --git a/crates/core/machine/src/utils/prove.rs b/crates/core/machine/src/utils/prove.rs index 6979975b1..87f31f447 100644 --- a/crates/core/machine/src/utils/prove.rs +++ b/crates/core/machine/src/utils/prove.rs @@ -250,57 +250,118 @@ where // We combine the memory init/finalize events if they are "small" // and would affect performance. - let last_record = if done + let mut shape_fixed_records = if done && num_cycles < 1 << 21 && deferred.global_memory_initialize_events.len() < opts.split_opts.combine_memory_threshold && deferred.global_memory_finalize_events.len() < opts.split_opts.combine_memory_threshold { - records.last_mut() + let mut records_clone = records.clone(); + let last_record = records_clone.last_mut(); + // See if any deferred shards are ready to be committed to. + let mut deferred = + deferred.split(done, last_record, opts.split_opts); + log::info!("deferred {} records", deferred.len()); + + // Update the public values & prover state for the shards which do not + // contain "cpu events" before committing to them. + if !done { + state.execution_shard += 1; + } + for record in deferred.iter_mut() { + state.shard += 1; + state.previous_init_addr_bits = + record.public_values.previous_init_addr_bits; + state.last_init_addr_bits = + record.public_values.last_init_addr_bits; + state.previous_finalize_addr_bits = + record.public_values.previous_finalize_addr_bits; + state.last_finalize_addr_bits = + record.public_values.last_finalize_addr_bits; + state.start_pc = state.next_pc; + record.public_values = *state; + } + records_clone.append(&mut deferred); + + // Generate the dependencies. + tracing::debug_span!("generate dependencies", index).in_scope( + || { + prover.machine().generate_dependencies( + &mut records_clone, + &opts, + None, + ); + }, + ); + + // Let another worker update the state. + record_gen_sync.advance_turn(); + + // Fix the shape of the records. + let mut fixed_shape = true; + if let Some(shape_config) = shape_config { + for record in records_clone.iter_mut() { + if shape_config.fix_shape(record).is_err() { + fixed_shape = false; + } + } + } + fixed_shape.then_some(records_clone) } else { None }; - // See if any deferred shards are ready to be committed to. - let mut deferred = deferred.split(done, last_record, opts.split_opts); - log::info!("deferred {} records", deferred.len()); + if shape_fixed_records.is_none() { + // See if any deferred shards are ready to be committed to. + let mut deferred = deferred.split(done, None, opts.split_opts); + log::info!("deferred {} records", deferred.len()); - // Update the public values & prover state for the shards which do not - // contain "cpu events" before committing to them. - if !done { - state.execution_shard += 1; - } - for record in deferred.iter_mut() { - state.shard += 1; - state.previous_init_addr_bits = - record.public_values.previous_init_addr_bits; - state.last_init_addr_bits = - record.public_values.last_init_addr_bits; - state.previous_finalize_addr_bits = - record.public_values.previous_finalize_addr_bits; - state.last_finalize_addr_bits = - record.public_values.last_finalize_addr_bits; - state.start_pc = state.next_pc; - record.public_values = *state; - } - records.append(&mut deferred); - - // Generate the dependencies. - tracing::debug_span!("generate dependencies", index).in_scope(|| { - prover.machine().generate_dependencies(&mut records, &opts, None); - }); + // Update the public values & prover state for the shards which do not + // contain "cpu events" before committing to them. + if !done { + state.execution_shard += 1; + } + for record in deferred.iter_mut() { + state.shard += 1; + state.previous_init_addr_bits = + record.public_values.previous_init_addr_bits; + state.last_init_addr_bits = + record.public_values.last_init_addr_bits; + state.previous_finalize_addr_bits = + record.public_values.previous_finalize_addr_bits; + state.last_finalize_addr_bits = + record.public_values.last_finalize_addr_bits; + state.start_pc = state.next_pc; + record.public_values = *state; + } + records.append(&mut deferred); + + // Generate the dependencies. + tracing::debug_span!("generate dependencies", index).in_scope( + || { + prover.machine().generate_dependencies( + &mut records, + &opts, + None, + ); + }, + ); - // Let another worker update the state. - record_gen_sync.advance_turn(); + // Let another worker update the state. + record_gen_sync.advance_turn(); - // Fix the shape of the records. - if let Some(shape_config) = shape_config { - for record in records.iter_mut() { - shape_config.fix_shape(record).unwrap(); + // Fix the shape of the records. + if let Some(shape_config) = shape_config { + for record in records.iter_mut() { + shape_config.fix_shape(record).unwrap(); + } } + shape_fixed_records = Some(records); } + let mut records = shape_fixed_records.unwrap(); + // Send the shapes to the channel, if necessary. for record in records.iter() { let mut heights = vec![];