|
3 | 3 | //! Definition of random oracle input structure and |
4 | 4 | //! methods for serializing into bytes and field elements |
5 | 5 |
|
| 6 | +use core::fmt::Error; |
| 7 | + |
6 | 8 | use super::Hashable; |
7 | 9 | use alloc::{vec, vec::Vec}; |
8 | | -use ark_ff::PrimeField; |
| 10 | +use ark_ff::{BigInteger, PrimeField}; |
9 | 11 | use bitvec::{prelude::*, view::AsBits}; |
10 | 12 | use mina_curves::pasta::{Fp, Fq}; |
11 | 13 | use o1_utils::FieldHelpers; |
12 | 14 |
|
| 15 | +const SER_HEADER_SIZE: usize = 8; // total number of bytes for the header of the serialized ROInput |
| 16 | +const SINGLE_HEADER_SIZE: usize = 4; // number of bytes for each part of the header of the serialized ROInput |
| 17 | + |
13 | 18 | /// Random oracle input structure |
14 | 19 | /// |
15 | 20 | /// The random oracle input encapsulates the serialization format and methods using during hashing. |
@@ -172,6 +177,66 @@ impl ROInput { |
172 | 177 |
|
173 | 178 | fields |
174 | 179 | } |
| 180 | + |
| 181 | + /// Serialize the ROInput into bytes |
| 182 | + pub fn serialize(&self) -> Vec<u8> { |
| 183 | + // 4-byte LE field count, 4-byte LE bit count, then payload |
| 184 | + let fields_len = self.fields.len() as u32; |
| 185 | + let bits_len = self.bits.len() as u32; |
| 186 | + |
| 187 | + let mut bytes = Vec::with_capacity(SER_HEADER_SIZE + self.to_bytes().len()); |
| 188 | + bytes.extend_from_slice(&fields_len.to_le_bytes()); |
| 189 | + bytes.extend_from_slice(&bits_len.to_le_bytes()); |
| 190 | + bytes.extend_from_slice(&self.to_bytes()); |
| 191 | + bytes |
| 192 | + } |
| 193 | + |
| 194 | + /// Deserialize a `ROInput` from bytes |
| 195 | + pub fn deserialize(input: &[u8]) -> Result<Self, Error> { |
| 196 | + if input.len() < SER_HEADER_SIZE { |
| 197 | + return Err(Error); |
| 198 | + } |
| 199 | + |
| 200 | + // read back our two u32 little-endian lengths |
| 201 | + let fields_len = |
| 202 | + u32::from_le_bytes(input[0..SINGLE_HEADER_SIZE].try_into().unwrap()) as usize; |
| 203 | + let bits_len = u32::from_le_bytes( |
| 204 | + input[SINGLE_HEADER_SIZE..SER_HEADER_SIZE] |
| 205 | + .try_into() |
| 206 | + .unwrap(), |
| 207 | + ) as usize; |
| 208 | + |
| 209 | + // the rest is payload |
| 210 | + let bits = input[SER_HEADER_SIZE..].view_bits::<Lsb0>(); |
| 211 | + |
| 212 | + // Check that the number of bytes is consistent with the expected lengths |
| 213 | + let expected_len_bits = fields_len * Fp::MODULUS_BIT_SIZE as usize + bits_len; |
| 214 | + // Round up to nearest multiple of 8 |
| 215 | + let expected_len = (expected_len_bits + 7) / 8 + SER_HEADER_SIZE; |
| 216 | + if input.len() != expected_len { |
| 217 | + return Err(Error); |
| 218 | + } |
| 219 | + |
| 220 | + // allocate space for exactly `fields_len` elements |
| 221 | + let mut fields = Vec::with_capacity(fields_len); |
| 222 | + |
| 223 | + for chunk in bits.chunks(Fp::MODULUS_BIT_SIZE as usize).take(fields_len) { |
| 224 | + let bools: Vec<bool> = chunk.iter().by_vals().collect(); |
| 225 | + // conver little-endian bits to a big integer representation |
| 226 | + let repr = <Fp as PrimeField>::BigInt::from_bits_le(&bools); |
| 227 | + // convert to field element (reduces mod p) |
| 228 | + let elt = Fp::from_bigint(repr).ok_or(Error)?; |
| 229 | + fields.push(elt); |
| 230 | + } |
| 231 | + |
| 232 | + let remainder = &bits[fields_len * Fp::MODULUS_BIT_SIZE as usize..]; |
| 233 | + // Delete the final bits according to the bits length |
| 234 | + let bits = remainder.iter().take(bits_len).collect::<BitVec<u8>>(); |
| 235 | + |
| 236 | + let roi = ROInput { fields, bits }; |
| 237 | + |
| 238 | + Ok(roi) |
| 239 | + } |
175 | 240 | } |
176 | 241 |
|
177 | 242 | #[cfg(test)] |
@@ -909,4 +974,198 @@ mod tests { |
909 | 974 | }; |
910 | 975 | assert_ne!(b1.to_roinput(), b2.to_roinput()); |
911 | 976 | } |
| 977 | + |
| 978 | + #[test] |
| 979 | + fn serialize_empty() { |
| 980 | + let roi = ROInput::new(); |
| 981 | + |
| 982 | + let serialized = roi.serialize(); |
| 983 | + |
| 984 | + assert_eq!( |
| 985 | + serialized, |
| 986 | + vec![0; SER_HEADER_SIZE], |
| 987 | + "Serialized empty ROInput should be zero bytes" |
| 988 | + ); |
| 989 | + |
| 990 | + let deserialized_roi = |
| 991 | + ROInput::deserialize(&serialized).expect("Failed to deserialize ROInput"); |
| 992 | + assert_eq!( |
| 993 | + roi, deserialized_roi, |
| 994 | + "Serialized and deserialized ROInput do not match" |
| 995 | + ); |
| 996 | + } |
| 997 | + |
| 998 | + #[test] |
| 999 | + fn serialize_single_field() { |
| 1000 | + let roi = ROInput::new().append_field( |
| 1001 | + Fp::from_hex("41203c6bbac14b357301e1f386d80f52123fd00f02197491b690bddfa742ca22") |
| 1002 | + .expect("failed to create field"), |
| 1003 | + ); |
| 1004 | + |
| 1005 | + let serialized = roi.serialize(); |
| 1006 | + let expected_length = SER_HEADER_SIZE + 32; // 32 bytes for the field |
| 1007 | + assert_eq!( |
| 1008 | + serialized.len(), |
| 1009 | + expected_length, |
| 1010 | + "Serialized ROInput length mismatch" |
| 1011 | + ); |
| 1012 | + assert_eq!( |
| 1013 | + serialized, |
| 1014 | + [ |
| 1015 | + 0x01, 0x00, 0x00, 0x00, // Field count |
| 1016 | + 0x00, 0x00, 0x00, 0x00, // Bit count |
| 1017 | + 0x41, 0x20, 0x3c, 0x6b, 0xba, 0xc1, 0x4b, 0x35, 0x73, 0x01, 0xe1, 0xf3, 0x86, 0xd8, |
| 1018 | + 0x0f, 0x52, 0x12, 0x3f, 0xd0, 0x0f, 0x02, 0x19, 0x74, 0x91, 0xb6, 0x90, 0xbd, 0xdf, |
| 1019 | + 0xa7, 0x42, 0xca, 0x22 |
| 1020 | + ] |
| 1021 | + .to_vec(), |
| 1022 | + "Serialized ROInput does not match expected output" |
| 1023 | + ); |
| 1024 | + |
| 1025 | + assert_eq!( |
| 1026 | + roi, |
| 1027 | + ROInput::deserialize(&serialized).expect("Failed to deserialize ROInput"), |
| 1028 | + "Serialized and deserialized ROInput do not match" |
| 1029 | + ) |
| 1030 | + } |
| 1031 | + |
| 1032 | + #[test] |
| 1033 | + fn serialize_single_bool() { |
| 1034 | + let roi = ROInput::new().append_bool(true); |
| 1035 | + |
| 1036 | + let serialized = roi.serialize(); |
| 1037 | + let expected_length = SER_HEADER_SIZE + 1; // 1 byte for the boolean |
| 1038 | + assert_eq!( |
| 1039 | + serialized.len(), |
| 1040 | + expected_length, |
| 1041 | + "Serialized ROInput length mismatch" |
| 1042 | + ); |
| 1043 | + assert_eq!( |
| 1044 | + serialized, |
| 1045 | + [ |
| 1046 | + 0x00, 0x00, 0x00, 0x00, |
| 1047 | + 0x01, 0x00, 0x00, 0x00, |
| 1048 | + 0x01 // Boolean value |
| 1049 | + ] |
| 1050 | + .to_vec(), |
| 1051 | + "Serialized ROInput does not match expected output" |
| 1052 | + ); |
| 1053 | + |
| 1054 | + assert_eq!( |
| 1055 | + roi, |
| 1056 | + ROInput::deserialize(&serialized).expect("Failed to deserialize ROInput"), |
| 1057 | + "Serialized and deserialized ROInput do not match" |
| 1058 | + ); |
| 1059 | + } |
| 1060 | + |
| 1061 | + #[test] |
| 1062 | + fn serialize_multiple_bools_length() { |
| 1063 | + for i in 0..1024 { |
| 1064 | + let roi = ROInput::new().append_bool(i % 2 == 0); |
| 1065 | + let serialized = roi.serialize(); |
| 1066 | + |
| 1067 | + // Deserialize and check if it matches |
| 1068 | + let deserialized_roi = |
| 1069 | + ROInput::deserialize(&serialized).expect("Failed to deserialize ROInput"); |
| 1070 | + assert_eq!( |
| 1071 | + roi, deserialized_roi, |
| 1072 | + "Serialized and deserialized ROInput do not match for i={}", |
| 1073 | + i |
| 1074 | + ); |
| 1075 | + } |
| 1076 | + } |
| 1077 | + |
| 1078 | + #[test] |
| 1079 | + fn deserialize_invalid() { |
| 1080 | + let invalid_data = vec![0x01, 0x00, 0x00, 0x00]; // Invalid header, missing fields and bits |
| 1081 | + |
| 1082 | + let result = ROInput::deserialize(&invalid_data); |
| 1083 | + assert!( |
| 1084 | + result.is_err(), |
| 1085 | + "Deserialization should fail for invalid data" |
| 1086 | + ); |
| 1087 | + } |
| 1088 | + |
| 1089 | + #[test] |
| 1090 | + fn deserialize_invalid_inconsistent_bitlen() { |
| 1091 | + let invalid_data = vec![ |
| 1092 | + 0x01, 0x00, 0x00, // Field count |
| 1093 | + 0x01, 0x00, 0x00, 0x00, // Bit count |
| 1094 | + 0x01, // Boolean value |
| 1095 | + // Missing bits for the boolean |
| 1096 | + ]; |
| 1097 | + |
| 1098 | + let result = ROInput::deserialize(&invalid_data); |
| 1099 | + assert!( |
| 1100 | + result.is_err(), |
| 1101 | + "Deserialization should fail for inconsistent bit length" |
| 1102 | + ); |
| 1103 | + } |
| 1104 | + |
| 1105 | + #[test] |
| 1106 | + fn deserialize_invalid_message() { |
| 1107 | + let msg = b"Test message for Mina compatibility".to_vec(); |
| 1108 | + let result = ROInput::deserialize(&msg); |
| 1109 | + assert!( |
| 1110 | + result.is_err(), |
| 1111 | + "Deserialization should fail for invalid message format" |
| 1112 | + ); |
| 1113 | + } |
| 1114 | + |
| 1115 | + #[test] |
| 1116 | + fn deserialize_invalid_fieldheader() { |
| 1117 | + let invalid_data = vec![ |
| 1118 | + 0x01, 0x00, 0x00, 0x00, // Field count |
| 1119 | + 0x01, 0x00, 0x00, 0x00, // Bit count |
| 1120 | + // Incorrect number of bytes for field header |
| 1121 | + 0x01, 0x02, 0x03, 0x04, 0x01, // Boolean value |
| 1122 | + ]; |
| 1123 | + |
| 1124 | + let result = ROInput::deserialize(&invalid_data); |
| 1125 | + assert!( |
| 1126 | + result.is_err(), |
| 1127 | + "Deserialization should fail for overflow in field header" |
| 1128 | + ); |
| 1129 | + } |
| 1130 | + |
| 1131 | + #[test] |
| 1132 | + fn serialize_tx() { |
| 1133 | + let tx_roi = ROInput::new() |
| 1134 | + .append_field( |
| 1135 | + Fp::from_hex("41203c6bbac14b357301e1f386d80f52123fd00f02197491b690bddfa742ca22") |
| 1136 | + .expect("failed to create field"), |
| 1137 | + ) |
| 1138 | + .append_field( |
| 1139 | + Fp::from_hex("992cdaf29ffe15b2bcea5d00e498ed4fffd117c197f0f98586e405f72ef88e00") |
| 1140 | + .expect("failed to create field"), |
| 1141 | + ) // source |
| 1142 | + .append_field( |
| 1143 | + Fp::from_hex("3fba4fa71bce0dfdf709d827463036d6291458dfef772ff65e87bd6d1b1e062a") |
| 1144 | + .expect("failed to create field"), |
| 1145 | + ) // receiver |
| 1146 | + .append_u64(1000000) // fee |
| 1147 | + .append_u64(1) // fee token |
| 1148 | + .append_bool(true) // fee payer pk odd |
| 1149 | + .append_u32(0) // nonce |
| 1150 | + .append_u32(u32::MAX) // valid_until |
| 1151 | + .append_bytes(&[0; 34]) // memo |
| 1152 | + .append_bool(false) // tags[0] |
| 1153 | + .append_bool(false) // tags[1] |
| 1154 | + .append_bool(false) // tags[2] |
| 1155 | + .append_bool(true) // sender pk odd |
| 1156 | + .append_bool(false) // receiver pk odd |
| 1157 | + .append_u64(1) // token_id |
| 1158 | + .append_u64(10000000000) // amount |
| 1159 | + .append_bool(false); // token_locked |
| 1160 | + |
| 1161 | + let tx_bytes = tx_roi.serialize(); |
| 1162 | + |
| 1163 | + let deserialized_roi = |
| 1164 | + ROInput::deserialize(&tx_bytes).expect("Failed to deserialize ROInput"); |
| 1165 | + |
| 1166 | + assert_eq!( |
| 1167 | + tx_roi, deserialized_roi, |
| 1168 | + "Serialized and deserialized ROInput do not match" |
| 1169 | + ); |
| 1170 | + } |
912 | 1171 | } |
0 commit comments