Skip to content

Commit

Permalink
refactor recursion: simplify as_array impl for public values
Browse files Browse the repository at this point in the history
  • Loading branch information
tcoratger authored and nhtyy committed Jan 23, 2025
1 parent ab3b5cc commit 97d2de5
Showing 1 changed file with 196 additions and 17 deletions.
213 changes: 196 additions & 17 deletions crates/recursion/core/src/air/public_values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@ use sp1_core_machine::utils::indices_arr;
use sp1_derive::AlignedBorrow;
use sp1_stark::{air::POSEIDON_NUM_WORDS, septic_digest::SepticDigest, Word, PROOF_MAX_NUM_PVS};
use static_assertions::const_assert_eq;
use std::{
borrow::BorrowMut,
mem::{size_of, transmute, MaybeUninit},
};
use std::mem::{size_of, transmute};

pub const PV_DIGEST_NUM_WORDS: usize = 8;

Expand Down Expand Up @@ -64,12 +61,7 @@ impl<T: Clone> ChallengerPublicValues<T> {
where
T: Copy,
{
unsafe {
let mut ret = [MaybeUninit::<T>::zeroed().assume_init(); CHALLENGER_STATE_NUM_ELTS];
let pv: &mut ChallengerPublicValues<T> = ret.as_mut_slice().borrow_mut();
*pv = *self;
ret
}
unsafe { std::mem::transmute_copy(self) }
}
}

Expand Down Expand Up @@ -126,7 +118,7 @@ pub struct RecursionPublicValues<T> {
pub vk_root: [T; DIGEST_SIZE],

/// Current cumulative sum of lookup bus. Note that for recursive proofs for core proofs, this
/// contains the global cumulative sum.
/// contains the global cumulative sum.
pub global_cumulative_sum: SepticDigest<T>,

/// Whether the proof completely proves the program execution.
Expand All @@ -146,12 +138,7 @@ pub struct RecursionPublicValues<T> {
/// Converts the public values to an array of elements.
impl<F: Copy> RecursionPublicValues<F> {
pub fn as_array(&self) -> [F; RECURSIVE_PROOF_NUM_PV_ELTS] {
unsafe {
let mut ret = [MaybeUninit::<F>::zeroed().assume_init(); RECURSIVE_PROOF_NUM_PV_ELTS];
let pv: &mut RecursionPublicValues<F> = ret.as_mut_slice().borrow_mut();
*pv = *self;
ret
}
unsafe { std::mem::transmute_copy(self) }
}
}

Expand All @@ -172,3 +159,195 @@ impl<T: Copy> IntoIterator for ChallengerPublicValues<T> {
self.as_array().into_iter()
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_recursion_public_values_as_array() {
// Create a sample RecursionPublicValues with arbitrary values.
let test_values = RecursionPublicValues {
committed_value_digest: [Word([1, 2, 3, 4]); PV_DIGEST_NUM_WORDS],
deferred_proofs_digest: [2; POSEIDON_NUM_WORDS],
start_pc: 3,
next_pc: 4,
start_shard: 5,
next_shard: 6,
start_execution_shard: 7,
next_execution_shard: 8,
previous_init_addr_bits: [9; 32],
last_init_addr_bits: [10; 32],
previous_finalize_addr_bits: [11; 32],
last_finalize_addr_bits: [12; 32],
start_reconstruct_deferred_digest: [13; POSEIDON_NUM_WORDS],
end_reconstruct_deferred_digest: [14; POSEIDON_NUM_WORDS],
sp1_vk_digest: [15; DIGEST_SIZE],
vk_root: [16; DIGEST_SIZE],
global_cumulative_sum: Default::default(),
is_complete: 18,
contains_execution_shard: 19,
exit_code: 20,
digest: [21; DIGEST_SIZE],
};

// Convert to array and verify the array length.
let as_array = test_values.as_array();
assert_eq!(as_array.len(), RECURSIVE_PROOF_NUM_PV_ELTS);

// Verify specific elements in the array (by index, depending on layout).
for i in 0..PV_DIGEST_NUM_WORDS {
assert_eq!(as_array[4 * i + 0], 1);

Check failure on line 200 in crates/recursion/core/src/air/public_values.rs

View workflow job for this annotation

GitHub Actions / Formatting & Clippy

this operation has no effect
assert_eq!(as_array[4 * i + 1], 2);
assert_eq!(as_array[4 * i + 2], 3);
assert_eq!(as_array[4 * i + 3], 4);
}

// Verify deferred_proofs_digest.
let mut index = 4 * PV_DIGEST_NUM_WORDS;
for &value in &test_values.deferred_proofs_digest {
assert_eq!(as_array[index], value);
index += 1;
}

// Verify scalar fields.
assert_eq!(as_array[index], test_values.start_pc);
index += 1;
assert_eq!(as_array[index], test_values.next_pc);
index += 1;
assert_eq!(as_array[index], test_values.start_shard);
index += 1;
assert_eq!(as_array[index], test_values.next_shard);
index += 1;
assert_eq!(as_array[index], test_values.start_execution_shard);
index += 1;
assert_eq!(as_array[index], test_values.next_execution_shard);
index += 1;

// Verify previous_init_addr_bits.
for &value in &test_values.previous_init_addr_bits {
assert_eq!(as_array[index], value);
index += 1;
}

// Verify last_init_addr_bits.
for &value in &test_values.last_init_addr_bits {
assert_eq!(as_array[index], value);
index += 1;
}

// Verify previous_finalize_addr_bits.
for &value in &test_values.previous_finalize_addr_bits {
assert_eq!(as_array[index], value);
index += 1;
}

// Verify last_finalize_addr_bits.
for &value in &test_values.last_finalize_addr_bits {
assert_eq!(as_array[index], value);
index += 1;
}

// Verify start_reconstruct_deferred_digest.
for &value in &test_values.start_reconstruct_deferred_digest {
assert_eq!(as_array[index], value);
index += 1;
}

// Verify end_reconstruct_deferred_digest.
for &value in &test_values.end_reconstruct_deferred_digest {
assert_eq!(as_array[index], value);
index += 1;
}

// Verify sp1_vk_digest.
for &value in &test_values.sp1_vk_digest {
assert_eq!(as_array[index], value);
index += 1;
}

// Verify vk_root.
for &value in &test_values.vk_root {
assert_eq!(as_array[index], value);
index += 1;
}

// Verify global_cumulative_sum (default is [0; DIGEST_SIZE]).
for &value in &test_values.global_cumulative_sum.0.x.0 {
assert_eq!(as_array[index], value);
index += 1;
}

for &value in &test_values.global_cumulative_sum.0.y.0 {
assert_eq!(as_array[index], value);
index += 1;
}

// Verify is_complete.
assert_eq!(as_array[index], test_values.is_complete);
index += 1;

// Verify contains_execution_shard.
assert_eq!(as_array[index], test_values.contains_execution_shard);
index += 1;

// Verify exit_code.
assert_eq!(as_array[index], test_values.exit_code);
index += 1;

// Verify digest.
for &value in &test_values.digest {
assert_eq!(as_array[index], value);
index += 1;
}

// Verify the final index of the array.
assert_eq!(index, RECURSIVE_PROOF_NUM_PV_ELTS);
}

#[test]
fn test_challenger_public_values_as_array() {
// Create a sample ChallengerPublicValues with arbitrary values.
let test_values = ChallengerPublicValues {
sponge_state: [1; PERMUTATION_WIDTH],
num_inputs: 2,
input_buffer: [3; PERMUTATION_WIDTH],
num_outputs: 4,
output_buffer: [5; PERMUTATION_WIDTH],
};

// Convert to array and verify the array length.
let as_array = test_values.as_array();
assert_eq!(as_array.len(), CHALLENGER_STATE_NUM_ELTS);

// Verify sponge_state.
let mut index = 0;
for &value in &test_values.sponge_state {
assert_eq!(as_array[index], value);
index += 1;
}

// Verify num_inputs.
assert_eq!(as_array[index], test_values.num_inputs);
index += 1;

// Verify input_buffer.
for &value in &test_values.input_buffer {
assert_eq!(as_array[index], value);
index += 1;
}

// Verify num_outputs.
assert_eq!(as_array[index], test_values.num_outputs);
index += 1;

// Verify output_buffer.
for &value in &test_values.output_buffer {
assert_eq!(as_array[index], value);
index += 1;
}

// Verify the final index of the array.
assert_eq!(index, CHALLENGER_STATE_NUM_ELTS);
}
}

0 comments on commit 97d2de5

Please sign in to comment.