Skip to content

Commit e3028e6

Browse files
committed
fixed 1 test
1 parent 3e9c66c commit e3028e6

File tree

2 files changed

+55
-10
lines changed

2 files changed

+55
-10
lines changed

sdk-libs/macros/src/hasher.rs

+51-8
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ pub(crate) fn hasher(input: ItemStruct) -> Result<TokenStream> {
8282
.any(|attr| attr.path().is_ident("flatten"))
8383
});
8484

85-
let mut flattned_fields_added = vec![quote! { Self::NUM_FIELDS as usize }];
85+
let mut flattned_fields_added = vec![quote! { 0usize }];
8686
let mut truncate_code = Vec::new();
8787

8888
// Process each field
@@ -206,21 +206,37 @@ pub(crate) fn hasher(input: ItemStruct) -> Result<TokenStream> {
206206
flattned_fields_added.push(quote! {
207207
+ #field_type::NUM_FIELDS as usize
208208
});
209+
210+
// For flattened fields, we need to use their individual byte arrays directly
211+
// rather than hashing the whole struct first
209212
code.push(quote! {
210213
{
211-
for (j, element) in <#field_type as ::light_hasher::to_byte_array::ToByteArray>::to_byte_arrays::<{#field_type::NUM_FIELDS}>(&self.#field_name)?.iter().enumerate() {
212-
field_array[#i + j + num_flattned_fields ] = *element;
213-
num_flattned_fields +=1;
214+
// Get individual byte arrays from the flattened field
215+
let flattened_arrays = <#field_type as ::light_hasher::to_byte_array::ToByteArray>::to_byte_arrays::<{#field_type::NUM_FIELDS}>(&self.#field_name)?;
216+
// Add each element individually to the field_array
217+
for element in flattened_arrays.iter() {
218+
field_array[num_flattned_fields] = *element;
219+
num_flattned_fields += 1;
214220
}
215221
}
216222
});
217223
} else {
224+
if flatten_field_exists {
225+
flattned_fields_added.push(quote! {
226+
+ 1
227+
});
228+
}
218229
to_byte_arrays_fields.push(quote! {
219230
arrays[#i ] = self.#field_name.to_byte_array()?;
220231
});
221232
if flatten_field_exists {
233+
// Store field index in the field_assignments for later non-flattened field processing
222234
field_assignments.push(quote! {
223-
field_array[#i + num_flattned_fields ] = self.#field_name.to_byte_array()?;
235+
#i
236+
});
237+
code.push(quote! {
238+
field_array[num_flattned_fields] = self.#field_name.to_byte_array()?;
239+
num_flattned_fields += 1;
224240
});
225241
} else {
226242
field_assignments.push(quote! {
@@ -242,8 +258,9 @@ pub(crate) fn hasher(input: ItemStruct) -> Result<TokenStream> {
242258
},
243259
);
244260
code.push(quote! {
245-
for element in field_array.iter() {
246-
slices[num_flattned_fields] = element.as_slice();
261+
// Set all slices properly for both flattened and non-flattened fields
262+
for i in 0..num_flattned_fields {
263+
slices[i] = field_array[i].as_slice();
247264
}
248265
});
249266
quote! {
@@ -310,9 +327,35 @@ pub(crate) fn hasher(input: ItemStruct) -> Result<TokenStream> {
310327
}
311328
};
312329

330+
// Calculate the total number of fields, accounting for flattened fields
331+
let total_field_count = if flatten_field_exists {
332+
// When there are flattened fields, we need to adjust the total field count
333+
let mut sum = quote! { 0 };
334+
335+
for field in fields.named.iter() {
336+
let flatten = field
337+
.attrs
338+
.iter()
339+
.any(|attr| attr.path().is_ident("flatten"));
340+
341+
if flatten {
342+
// Use the field type's NUM_FIELDS instead of counting as one field
343+
let field_type = &field.ty;
344+
sum = quote! { #sum + #field_type::NUM_FIELDS };
345+
} else {
346+
// Regular fields count as one
347+
sum = quote! { #sum + 1 };
348+
}
349+
}
350+
sum
351+
} else {
352+
// Without flattened fields, just use the regular field count
353+
quote! { #field_count }
354+
};
355+
313356
Ok(quote! {
314357
impl #impl_gen ::light_hasher::to_byte_array::ToByteArray for #struct_name #type_gen #where_clause {
315-
const NUM_FIELDS: usize = #field_count;
358+
const NUM_FIELDS: usize = #total_field_count;
316359

317360
fn to_byte_array(&self) -> ::std::result::Result<[u8; 32], ::light_hasher::HasherError> {
318361
#to_byte_array

sdk-libs/macros/tests/flatten.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ fn flatten() {
4343
array[31] = 4;
4444
array
4545
};
46-
let manual_hash =
47-
Poseidon::hashv(&[one.as_ref(), two.as_ref(), three.as_ref(), four.as_ref()]).unwrap();
46+
let hash = Poseidon::hashv(&[three.as_slice(), four.as_slice()]).unwrap();
47+
let manual_slices = [one.as_ref(), two.as_ref(), hash.as_ref()];
48+
println!("manual_slices {:?}", manual_slices);
49+
let manual_hash = Poseidon::hashv(&manual_slices).unwrap();
4850
assert_eq!(test.hash::<Poseidon>().unwrap(), manual_hash);
4951
}

0 commit comments

Comments
 (0)