Skip to content

Commit 9cdb147

Browse files
committed
feat(shortint): adds generic client key for atomic pattern support
1 parent 8278a93 commit 9cdb147

File tree

64 files changed

+1421
-822
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+1421
-822
lines changed

tests/backward_compatibility/shortint.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,13 @@ pub fn test_shortint_clientkey(
138138

139139
let key: ClientKey = load_and_unversionize(dir, test, format)?;
140140

141-
if test_params != key.parameters {
141+
if test_params != key.parameters() {
142142
Err(test.failure(
143143
format!(
144144
"Invalid {} parameters:\n Expected :\n{:?}\nGot:\n{:?}",
145-
format, test_params, key.parameters
145+
format,
146+
test_params,
147+
key.parameters()
146148
),
147149
format,
148150
))

tfhe-benchmark/benches/shortint/bench.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ fn bench_server_key_unary_function<F>(
2727

2828
let mut rng = rand::thread_rng();
2929

30-
let modulus = cks.parameters.message_modulus().0;
30+
let modulus = cks.parameters().message_modulus().0;
3131

3232
let clear_text = rng.gen::<u64>() % modulus;
3333

@@ -70,7 +70,7 @@ fn bench_server_key_binary_function<F>(
7070

7171
let mut rng = rand::thread_rng();
7272

73-
let modulus = cks.parameters.message_modulus().0;
73+
let modulus = cks.parameters().message_modulus().0;
7474

7575
let clear_0 = rng.gen::<u64>() % modulus;
7676
let clear_1 = rng.gen::<u64>() % modulus;
@@ -115,7 +115,7 @@ fn bench_server_key_binary_scalar_function<F>(
115115

116116
let mut rng = rand::thread_rng();
117117

118-
let modulus = cks.parameters.message_modulus().0;
118+
let modulus = cks.parameters().message_modulus().0;
119119

120120
let clear_0 = rng.gen::<u64>() % modulus;
121121
let clear_1 = rng.gen::<u64>() % modulus;
@@ -159,7 +159,7 @@ fn bench_server_key_binary_scalar_division_function<F>(
159159

160160
let mut rng = rand::thread_rng();
161161

162-
let modulus = cks.parameters.message_modulus().0;
162+
let modulus = cks.parameters().message_modulus().0;
163163
assert_ne!(modulus, 1);
164164

165165
let clear_0 = rng.gen::<u64>() % modulus;
@@ -200,7 +200,7 @@ fn carry_extract_bench(c: &mut Criterion) {
200200

201201
let mut rng = rand::thread_rng();
202202

203-
let modulus = cks.parameters.message_modulus().0;
203+
let modulus = cks.parameters().message_modulus().0;
204204

205205
let clear_0 = rng.gen::<u64>() % modulus;
206206

@@ -236,7 +236,7 @@ fn programmable_bootstrapping_bench(c: &mut Criterion) {
236236

237237
let mut rng = rand::thread_rng();
238238

239-
let modulus = cks.parameters.message_modulus().0;
239+
let modulus = cks.parameters().message_modulus().0;
240240

241241
let acc = sks.generate_lookup_table(|x| x);
242242

tfhe/docs/references/fine-grained-apis/quick_start.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ fn main() {
114114
let msg1 = 1;
115115
let msg2 = 0;
116116

117-
let modulus = client_key.parameters.message_modulus().0;
117+
let modulus = client_key.parameters().message_modulus().0;
118118

119119
// We use the client key to encrypt two messages:
120120
let ct_1 = client_key.encrypt(msg1);

tfhe/docs/references/fine-grained-apis/shortint/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ fn main() {
8686
let msg1 = 1;
8787
let msg2 = 0;
8888

89-
let modulus = client_key.parameters.message_modulus().0;
89+
let modulus = client_key.parameters().message_modulus().0;
9090

9191
// We use the client key to encrypt two messages:
9292
let ct_1 = client_key.encrypt(msg1);

tfhe/docs/references/fine-grained-apis/shortint/operations.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ fn main() {
5959
let msg2 = 3;
6060
let scalar = 4;
6161

62-
let modulus = client_key.parameters.message_modulus().0;
62+
let modulus = client_key.parameters().message_modulus().0;
6363

6464
// We use the client key to encrypt two messages:
6565
let mut ct_1 = client_key.encrypt(msg1);
@@ -91,7 +91,7 @@ fn main() {
9191
let msg2 = 3;
9292
let scalar = 4;
9393

94-
let modulus = client_key.parameters.message_modulus().0;
94+
let modulus = client_key.parameters().message_modulus().0;
9595

9696
// We use the client key to encrypt two messages:
9797
let mut ct_1 = client_key.encrypt(msg1);
@@ -134,7 +134,7 @@ fn main() {
134134
let msg2 = 3;
135135
let scalar = 4;
136136

137-
let modulus = client_key.parameters.message_modulus().0;
137+
let modulus = client_key.parameters().message_modulus().0;
138138

139139
// We use the client key to encrypt two messages:
140140
let mut ct_1 = client_key.encrypt(msg1);
@@ -168,7 +168,7 @@ fn main() {
168168
let msg2 = 3;
169169
let scalar = 4;
170170

171-
let modulus = client_key.parameters.message_modulus().0;
171+
let modulus = client_key.parameters().message_modulus().0;
172172

173173
// We use the client key to encrypt two messages:
174174
let mut ct_1 = client_key.encrypt(msg1);
@@ -244,7 +244,7 @@ fn main() {
244244
let msg1 = 2;
245245
let msg2 = 1;
246246

247-
let modulus = client_key.parameters.message_modulus().0;
247+
let modulus = client_key.parameters().message_modulus().0;
248248

249249
// We use the private client key to encrypt two messages:
250250
let ct_1 = client_key.encrypt(msg1);
@@ -275,7 +275,7 @@ fn main() {
275275
let msg1 = 2;
276276
let msg2 = 1;
277277

278-
let modulus = client_key.parameters.message_modulus().0;
278+
let modulus = client_key.parameters().message_modulus().0;
279279

280280
// We use the private client key to encrypt two messages:
281281
let ct_1 = client_key.encrypt(msg1);
@@ -306,7 +306,7 @@ fn main() {
306306
let msg1 = 2;
307307
let msg2 = 1;
308308

309-
let modulus = client_key.parameters.message_modulus().0;
309+
let modulus = client_key.parameters().message_modulus().0;
310310

311311
// We use the private client key to encrypt two messages:
312312
let ct_1 = client_key.encrypt(msg1);
@@ -365,7 +365,7 @@ fn main() {
365365
let msg1 = 3;
366366
let msg2 = 2;
367367

368-
let modulus = client_key.parameters.message_modulus().0;
368+
let modulus = client_key.parameters().message_modulus().0;
369369

370370
// We use the private client key to encrypt two messages:
371371
let ct_1 = client_key.encrypt(msg1);

tfhe/src/high_level_api/keys/inner.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ impl IntegerClientKey {
110110
);
111111
let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);
112112
let cks = crate::shortint::engine::ShortintEngine::new_from_seeder(&mut seeder)
113-
.new_client_key(config.block_parameters.into());
113+
.new_client_key(config.block_parameters);
114114

115115
let key = crate::integer::ClientKey::from(cks);
116116

@@ -172,15 +172,15 @@ impl IntegerClientKey {
172172

173173
if let Some(dedicated_compact_private_key) = dedicated_compact_private_key.as_ref() {
174174
assert_eq!(
175-
shortint_cks.parameters.message_modulus(),
175+
shortint_cks.parameters().message_modulus(),
176176
dedicated_compact_private_key
177177
.0
178178
.key
179179
.parameters()
180180
.message_modulus,
181181
);
182182
assert_eq!(
183-
shortint_cks.parameters.carry_modulus(),
183+
shortint_cks.parameters().carry_modulus(),
184184
dedicated_compact_private_key
185185
.0
186186
.key

tfhe/src/integer/client_key/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ impl ClientKey {
143143
}
144144

145145
pub fn parameters(&self) -> crate::shortint::AtomicPatternParameters {
146-
self.key.parameters.ap_parameters().unwrap()
146+
self.key.parameters().ap_parameters().unwrap()
147147
}
148148

149149
#[cfg(test)]
@@ -333,7 +333,7 @@ impl ClientKey {
333333
return T::ZERO;
334334
}
335335

336-
let bits_in_block = self.key.parameters.message_modulus().0.ilog2();
336+
let bits_in_block = self.key.parameters().message_modulus().0.ilog2();
337337
let decrypted_block_iter = blocks.iter().map(|block| decrypt_block(&self.key, block));
338338
BlockRecomposer::recompose_unsigned(decrypted_block_iter, bits_in_block)
339339
}
@@ -417,7 +417,7 @@ impl ClientKey {
417417
return T::ZERO;
418418
}
419419

420-
let bits_in_block = self.key.parameters.message_modulus().0.ilog2();
420+
let bits_in_block = self.key.parameters().message_modulus().0.ilog2();
421421
let decrypted_block_iter = ctxt
422422
.blocks
423423
.iter()

tfhe/src/integer/encryption.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ pub(crate) trait KnowsMessageModulus {
77

88
impl KnowsMessageModulus for crate::shortint::ClientKey {
99
fn message_modulus(&self) -> MessageModulus {
10-
self.parameters.message_modulus()
10+
self.parameters().message_modulus()
1111
}
1212
}
1313

tfhe/src/integer/gpu/client_key/radix.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::integer::gpu::list_compression::server_keys::{
1010
};
1111
use crate::integer::gpu::server_key::CudaBootstrappingKey;
1212
use crate::integer::RadixClientKey;
13+
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
1314
use crate::shortint::engine::ShortintEngine;
1415
use crate::shortint::EncryptionKeyChoice;
1516

@@ -21,7 +22,11 @@ impl RadixClientKey {
2122
) -> (CudaCompressionKey, CudaDecompressionKey) {
2223
let private_compression_key = &private_compression_key.key;
2324

24-
let params = &private_compression_key.params;
25+
let compression_params = &private_compression_key.params;
26+
27+
let AtomicPatternClientKey::Standard(std_cks) = &self.as_ref().key.atomic_pattern else {
28+
panic!("Only the standard atomic pattern is supported on GPU")
29+
};
2530

2631
assert_eq!(
2732
self.parameters().encryption_key_choice(),
@@ -32,11 +37,11 @@ impl RadixClientKey {
3237
// Compression key
3338
let packing_key_switching_key = ShortintEngine::with_thread_local_mut(|engine| {
3439
allocate_and_generate_new_lwe_packing_keyswitch_key(
35-
&self.as_ref().key.large_lwe_secret_key(),
40+
&std_cks.large_lwe_secret_key(),
3641
&private_compression_key.post_packing_ks_key,
37-
params.packing_ks_base_log,
38-
params.packing_ks_level,
39-
params.packing_ks_key_noise_distribution,
42+
compression_params.packing_ks_base_log,
43+
compression_params.packing_ks_level,
44+
compression_params.packing_ks_key_noise_distribution,
4045
self.parameters().ciphertext_modulus(),
4146
&mut engine.encryption_generator,
4247
)
@@ -45,7 +50,7 @@ impl RadixClientKey {
4550
let glwe_compression_key = CompressionKey {
4651
key: crate::shortint::list_compression::CompressionKey {
4752
packing_key_switching_key,
48-
lwe_per_glwe: params.lwe_per_glwe,
53+
lwe_per_glwe: compression_params.lwe_per_glwe,
4954
storage_log_modulus: private_compression_key.params.storage_log_modulus,
5055
},
5156
};
@@ -60,9 +65,9 @@ impl RadixClientKey {
6065
self.parameters().polynomial_size(),
6166
private_compression_key.params.br_base_log,
6267
private_compression_key.params.br_level,
63-
params
68+
compression_params
6469
.packing_ks_glwe_dimension
65-
.to_equivalent_lwe_dimension(params.packing_ks_polynomial_size),
70+
.to_equivalent_lwe_dimension(compression_params.packing_ks_polynomial_size),
6671
self.parameters().ciphertext_modulus(),
6772
);
6873

@@ -71,7 +76,7 @@ impl RadixClientKey {
7176
&private_compression_key
7277
.post_packing_ks_key
7378
.as_lwe_secret_key(),
74-
&self.as_ref().key.glwe_secret_key,
79+
&std_cks.glwe_secret_key,
7580
&mut bsk,
7681
self.parameters().glwe_noise_distribution(),
7782
&mut engine.encryption_generator,
@@ -84,7 +89,7 @@ impl RadixClientKey {
8489

8590
let cuda_decompression_key = CudaDecompressionKey {
8691
blind_rotate_key,
87-
lwe_per_glwe: params.lwe_per_glwe,
92+
lwe_per_glwe: compression_params.lwe_per_glwe,
8893
glwe_dimension: self.parameters().glwe_dimension(),
8994
polynomial_size: self.parameters().polynomial_size(),
9095
message_modulus: self.parameters().message_modulus(),

tfhe/src/integer/gpu/key_switching_key/mod.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,23 @@ impl<'keys> CudaKeySwitchingKey<'keys> {
2626
{
2727
let input_secret_key: SecretEncryptionKeyView<'_> = input_key_pair.0.into();
2828

29+
let std_cks = output_key_pair
30+
.0
31+
.key
32+
.as_view()
33+
.try_into()
34+
.expect("Only the standard atomic pattern is supported on GPU");
35+
2936
// Creation of the key switching key
3037
let key_switching_key = ShortintEngine::with_thread_local_mut(|engine| {
31-
engine.new_key_switching_key(&input_secret_key.key, output_key_pair.0.as_ref(), params)
38+
engine.new_key_switching_key(&input_secret_key.key, std_cks, params)
3239
});
3340
let d_key_switching_key =
3441
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&key_switching_key, streams);
3542
let full_message_modulus_input =
3643
input_secret_key.key.carry_modulus.0 * input_secret_key.key.message_modulus.0;
37-
let full_message_modulus_output = output_key_pair.0.key.parameters.carry_modulus().0
38-
* output_key_pair.0.key.parameters.message_modulus().0;
44+
let full_message_modulus_output = output_key_pair.0.key.parameters().carry_modulus().0
45+
* output_key_pair.0.key.parameters().message_modulus().0;
3946
assert!(
4047
full_message_modulus_input.is_power_of_two()
4148
&& full_message_modulus_output.is_power_of_two(),

0 commit comments

Comments
 (0)