Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ cargo run --example function_evaluation
cargo run --example polynomial_evaluation
cargo run --example simple_integers
cargo run --example simple_real_numbers
cargo run --example dcrt_poly
```

# Contributing
Expand Down
12 changes: 12 additions & 0 deletions examples/dcrt_poly.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use num_bigint::BigUint;
use num_traits::Num;
use openfhe::ffi::{self, GetMatrixElement};
use openfhe::parse_coefficients_bytes;

fn main() {
let val = String::from("123456789099999");
Expand Down Expand Up @@ -56,6 +57,17 @@ fn main() {
let coeffs_poly_add = poly_add.GetCoefficients();
println!("coeffs_poly_add: {:?}", coeffs_poly_add);

let coeffs_poly_bytes = poly.GetCoefficientsBytes();
println!("coeffs_poly_bytes: {:?}", coeffs_poly_bytes);

// decode coeff_poly_bytes
let parsed_coefficients = parse_coefficients_bytes(&coeffs_poly_bytes);
println!("decoded mod: {:?}", parsed_coefficients.modulus);
println!(
"decoded coeffs: {:?}",
parsed_coefficients.coefficients
);

let poly_modulus = poly.GetModulus();
assert_eq!(poly_modulus, modulus);

Expand Down
28 changes: 28 additions & 0 deletions src/DCRTPoly.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,34 @@ rust::Vec<rust::String> DCRTPoly::GetCoefficients() const
return result;
}

rust::Vec<rust::u8> DCRTPoly::GetCoefficientsBytes() const
{
auto tempPoly = m_poly;
tempPoly.SetFormat(Format::COEFFICIENT);

lbcrypto::DCRTPoly::PolyLargeType polyLarge = tempPoly.CRTInterpolate();

const lbcrypto::BigVector &coeffs = polyLarge.GetValues();

// Serialize the coefficients to a binary format
std::stringstream ss;
lbcrypto::Serial::Serialize(coeffs, ss, lbcrypto::SerType::BINARY);

// Get the binary data as a string
std::string serializedData = ss.str();

// Convert to a rust::Vec<rust::u8>
rust::Vec<rust::u8> result;
result.reserve(serializedData.size());

// Copy each byte from the serialized data to the result vector
for (size_t i = 0; i < serializedData.size(); ++i) {
result.push_back(static_cast<rust::u8>(static_cast<unsigned char>(serializedData[i])));
}

return result;
}

std::unique_ptr<DCRTPoly> DCRTPoly::Negate() const
{
return std::make_unique<DCRTPoly>(-m_poly);
Expand Down
2 changes: 2 additions & 0 deletions src/DCRTPoly.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "openfhe/core/lattice/hal/lat-backend.h"
#include "rust/cxx.h"
#include "openfhe/core/math/matrix.h"
#include "openfhe/core/utils/serial.h"

namespace openfhe
{
Expand All @@ -23,6 +24,7 @@ class DCRTPoly final
[[nodiscard]] rust::String GetString() const;
[[nodiscard]] bool IsEqual(const DCRTPoly& other) const noexcept;
[[nodiscard]] rust::Vec<rust::String> GetCoefficients() const;
[[nodiscard]] rust::Vec<rust::u8> GetCoefficientsBytes() const;
[[nodiscard]] rust::String GetModulus() const;
[[nodiscard]] std::unique_ptr<DCRTPoly> Negate() const;
[[nodiscard]] std::unique_ptr<Matrix> Decompose() const;
Expand Down
4 changes: 3 additions & 1 deletion src/Trapdoor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,10 @@ std::unique_ptr<Matrix> DCRTSquareMatTrapdoorGaussSamp(usint n, usint k, const M
{
lbcrypto::DCRTPoly::DggType dgg(sigma);

size_t d = U.GetCols(); // U is a square matrix

double c = (base + 1) * sigma;
double s = lbcrypto::SPECTRAL_BOUND(n, k, base);
double s = lbcrypto::SPECTRAL_BOUND_D(n, k, base, d);
lbcrypto::DCRTPoly::DggType dggLargeSigma(sqrt(s * s - c * c));

auto result = lbcrypto::RLWETrapdoorUtility<lbcrypto::DCRTPoly>::GaussSampSquareMat(
Expand Down
73 changes: 73 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

use cxx::{CxxVector, let_cxx_string};
pub use cxx;
use num_bigint::BigUint;

#[cxx::bridge(namespace = "openfhe")]
pub mod ffi
Expand Down Expand Up @@ -729,6 +730,7 @@ pub mod ffi
fn GetString(self: &DCRTPoly) -> String;
fn IsEqual(self: &DCRTPoly, other: &DCRTPoly) -> bool;
fn GetCoefficients(self: &DCRTPoly) -> Vec<String>;
fn GetCoefficientsBytes(self: &DCRTPoly) -> Vec<u8>;
fn GetModulus(self: &DCRTPoly) -> String;
fn Negate(self: &DCRTPoly) -> UniquePtr<DCRTPoly>;
fn Decompose(self: &DCRTPoly) -> UniquePtr<Matrix>;
Expand Down Expand Up @@ -1233,6 +1235,77 @@ impl PartialEq for DCRTPoly {
}
}

pub struct ParsedCoefficients {
pub coefficients: Vec<BigUint>,
pub modulus: BigUint,
}

/// Parses raw bytes from the serialized format into a vector of BigUint values
/// Returns a vector containing all coefficients followed by the modulus as the last element
pub fn parse_coefficients_bytes(bytes: &[u8]) -> ParsedCoefficients {
if bytes.len() < 14 {
return ParsedCoefficients {
coefficients: Vec::new(),
modulus: BigUint::from(0u32),
};
}

// Number of coefficients
let coeff_count = u64::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8],
bytes[9], bytes[10], bytes[11], bytes[12]]) as usize;

let mut coefficients = Vec::with_capacity(coeff_count);
let mut offset = 17; // Start after the header

// Parse coefficients
for _ in 0..coeff_count {
if offset + 8 > bytes.len() { break; }

// Number of chunks for this coefficient
let chunk_count = u64::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3],
bytes[offset+4], bytes[offset+5], bytes[offset+6], bytes[offset+7]]) as usize;
offset += 8;

// Read and combine chunks
let mut value = BigUint::from(0u32);
for i in 0..chunk_count {
if offset + 8 > bytes.len() { break; }

let chunk = u64::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3],
bytes[offset+4], bytes[offset+5], bytes[offset+6], bytes[offset+7]]);
// Add chunk with proper shifting (little-endian format)
value += BigUint::from(chunk) << (i * 64);
offset += 8;
}

coefficients.push(value);
offset += 4; // Skip the m value
}

// Parse modulus
let mut modulus = BigUint::from(0u32);
if offset + 8 <= bytes.len() {
let mod_chunk_count = u64::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3],
bytes[offset+4], bytes[offset+5], bytes[offset+6], bytes[offset+7]]) as usize;
offset += 8;

for i in 0..mod_chunk_count {
if offset + 8 > bytes.len() { break; }

let chunk = u64::from_le_bytes([bytes[offset], bytes[offset+1], bytes[offset+2], bytes[offset+3],
bytes[offset+4], bytes[offset+5], bytes[offset+6], bytes[offset+7]]);
// Add chunk with proper shifting (little-endian format)
modulus += BigUint::from(chunk) << (i * 64);
offset += 8;
}
}

ParsedCoefficients {
coefficients,
modulus,
}
}

#[cfg(test)]
mod tests
{
Expand Down