Skip to content

Commit 379c5ae

Browse files
committed
feat(gpu): add support for GPU-accelerated expand on the HL Api
- includes documentation about GPU's accelerated expand on the HL API - rework CudaKeySwitchingKey - Cloning the key is no longer necessary on the HL API
1 parent d197a2a commit 379c5ae

File tree

17 files changed

+837
-215
lines changed

17 files changed

+837
-215
lines changed

Makefile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ check_typos: install_typos_checker
290290
.PHONY: clippy_gpu # Run clippy lints on tfhe with "gpu" enabled
291291
clippy_gpu: install_rs_check_toolchain
292292
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
293-
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types \
293+
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types,zk-pok \
294294
--all-targets \
295295
-p $(TFHE_SPEC) -- --no-deps -D warnings
296296

@@ -854,7 +854,7 @@ test_high_level_api: install_rs_build_toolchain
854854

855855
test_high_level_api_gpu: install_rs_build_toolchain install_cargo_nextest
856856
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) nextest run --cargo-profile $(CARGO_PROFILE) \
857-
--features=integer,internal-keycache,gpu -p $(TFHE_SPEC) \
857+
--features=integer,internal-keycache,gpu,zk-pok -p $(TFHE_SPEC) \
858858
-E "test(/high_level_api::.*gpu.*/)"
859859

860860
.PHONY: test_strings # Run the tests for strings ci
@@ -1012,7 +1012,7 @@ check_compile_tests: install_rs_build_toolchain
10121012
.PHONY: check_compile_tests_benches_gpu # Build tests in debug without running them
10131013
check_compile_tests_benches_gpu: install_rs_build_toolchain
10141014
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --no-run \
1015-
--features=experimental,boolean,shortint,integer,internal-keycache,gpu \
1015+
--features=experimental,boolean,shortint,integer,internal-keycache,gpu,zk-pok \
10161016
-p $(TFHE_SPEC)
10171017
mkdir -p "$(TFHECUDA_BUILD)" && \
10181018
cd "$(TFHECUDA_BUILD)" && \

tfhe/benches/integer/zk_pke.rs

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -434,11 +434,11 @@ fn cpu_pke_zk_verify(c: &mut Criterion, results_file: &Path) {
434434
#[cfg(all(feature = "gpu", feature = "zk-pok"))]
435435
mod cuda {
436436
use super::*;
437-
use crate::utilities::{cuda_local_keys, cuda_local_streams};
437+
use crate::utilities::cuda_local_streams;
438438
use criterion::BatchSize;
439439
use itertools::Itertools;
440440
use tfhe::core_crypto::gpu::{get_number_of_gpus, CudaStreams};
441-
use tfhe::integer::gpu::key_switching_key::CudaKeySwitchingKey;
441+
use tfhe::integer::gpu::key_switching_key::{CudaKeySwitchingKey, CudaKeySwitchingKeyMaterial};
442442
use tfhe::integer::gpu::zk::CudaProvenCompactCiphertextList;
443443
use tfhe::integer::gpu::CudaServerKey;
444444
use tfhe::integer::CompressedServerKey;
@@ -467,14 +467,17 @@ mod cuda {
467467
let param_name = param_name.as_str();
468468
let cks = ClientKey::new(param_fhe);
469469
let compressed_server_key = CompressedServerKey::new_radix_compressed_server_key(&cks);
470+
let sk = compressed_server_key.decompress();
470471
let gpu_sks = CudaServerKey::decompress_from_cpu(&compressed_server_key, &streams);
472+
471473
let compact_private_key = CompactPrivateKey::new(param_pke);
472474
let pk = CompactPublicKey::new(&compact_private_key);
473-
let d_ksk = CudaKeySwitchingKey::new(
474-
(&compact_private_key, None),
475-
(&cks, &gpu_sks),
476-
param_ksk,
477-
&streams,
475+
let ksk = KeySwitchingKey::new((&compact_private_key, None), (&cks, &sk), param_ksk);
476+
let d_ksk_material =
477+
CudaKeySwitchingKeyMaterial::from_key_switching_key(&ksk, &streams);
478+
let d_ksk = CudaKeySwitchingKey::from_cuda_key_switching_key_material(
479+
&d_ksk_material,
480+
&gpu_sks,
478481
);
479482

480483
// We have a use case with 320 bits of metadata
@@ -625,7 +628,6 @@ mod cuda {
625628
});
626629
}
627630
BenchmarkType::Throughput => {
628-
let gpu_sks_vec = cuda_local_keys(&cks);
629631
let gpu_count = get_number_of_gpus() as usize;
630632

631633
// Execute the operation once to know its cost.
@@ -669,27 +671,18 @@ mod cuda {
669671
})
670672
.collect::<Vec<_>>();
671673

672-
let gpu_cts = cts.iter().map(|ct| {
673-
CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
674-
&ct, &streams,
675-
)
676-
}).collect_vec();
677-
678674
let local_streams = cuda_local_streams(num_block, elements as usize);
679-
let d_ksk_vec = gpu_sks_vec
675+
let d_ksk_material_vec = local_streams
680676
.par_iter()
681-
.zip(local_streams.par_iter())
682-
.map(|(gpu_sks, local_stream)| {
683-
CudaKeySwitchingKey::new(
684-
(&compact_private_key, None),
685-
(&cks, &gpu_sks),
686-
param_ksk,
677+
.map(|local_stream| {
678+
CudaKeySwitchingKeyMaterial::from_key_switching_key(
679+
&ksk,
687680
local_stream,
688681
)
689682
})
690683
.collect::<Vec<_>>();
691684

692-
assert_eq!(d_ksk_vec.len(), gpu_count);
685+
assert_eq!(d_ksk_material_vec.len(), gpu_count);
693686

694687
bench_group.bench_function(&bench_id_verify, |b| {
695688
b.iter(|| {
@@ -705,7 +698,7 @@ mod cuda {
705698

706699
let gpu_cts = cts.iter().enumerate().map(|(i, ct)| {
707700
CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
708-
&ct, &local_streams[i],
701+
ct, &local_streams[i],
709702
)
710703
}).collect_vec();
711704

@@ -716,8 +709,11 @@ mod cuda {
716709
|(gpu_cts, local_streams)| {
717710
gpu_cts.par_iter().zip(local_streams.par_iter()).enumerate().for_each
718711
(|(i, (gpu_ct, local_stream))| {
712+
let d_ksk =
713+
CudaKeySwitchingKey::from_cuda_key_switching_key_material(&d_ksk_material_vec[i % gpu_count], &gpu_sks);
714+
719715
gpu_ct
720-
.expand_without_verification(&d_ksk_vec[i % gpu_count], local_stream)
716+
.expand_without_verification(&d_ksk, local_stream)
721717
.unwrap();
722718
});
723719
}, BatchSize::SmallInput);
@@ -729,7 +725,7 @@ mod cuda {
729725

730726
let gpu_cts = cts.iter().enumerate().map(|(i, ct)| {
731727
CudaProvenCompactCiphertextList::from_proven_compact_ciphertext_list(
732-
&ct, &local_streams[i],
728+
ct, &local_streams[i],
733729
)
734730
}).collect_vec();
735731

@@ -738,8 +734,8 @@ mod cuda {
738734

739735
b.iter_batched(setup_encrypted_values,
740736
|(gpu_cts, local_streams)| {
741-
gpu_cts.par_iter().zip(local_streams.par_iter()).enumerate().for_each
742-
(|(i, (gpu_ct, local_stream))| {
737+
gpu_cts.par_iter().zip(local_streams.par_iter()).for_each
738+
(|(gpu_ct, local_stream)| {
743739
gpu_ct
744740
.verify_and_expand(
745741
&crs, &pk, &metadata, &d_ksk, local_stream,
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Zero-knowledge proofs
2+
3+
Zero-knowledge proofs (ZK) are a powerful tool to assert that the encryption of a message is correct, as discussed in [advanced features](../../fhe-computation/advanced-features/zk-pok.md).
4+
However, computation is not possible directly on the proven ciphertexts themselves (i.e., `ProvenCompactCiphertextList`). This document explains how to use the GPU to accelerate the
5+
preprocessing step needed to convert ciphertexts formatted for ZK to ciphertexts in the right format for computation purposes on GPU. This
6+
operation is called "expansion".
7+
8+
## Proven compact ciphertext list
9+
10+
A proven compact list of ciphertexts can be seen as a compacted collection of ciphertexts whose encryption can be verified.
11+
This verification is currently only supported on the CPU, but the expansion can be accelerated using the GPU.
12+
This way, verification and expansion can be performed in parallel, efficiently using all the available computational resources.
13+
14+
## Supported types
15+
Encrypted messages can be integers (like `FheUint64`) or booleans. The GPU backend does not currently support encrypted strings.
16+
17+
{% hint style="info" %}
18+
You can enable this feature using the flag: `--features=zk-pok,gpu` when building **TFHE-rs**.
19+
{% endhint %}
20+
21+
22+
## Example
23+
24+
The following example shows how a client can encrypt and prove a ciphertext, and how a server can verify the proof, preprocess the ciphertext, and run a computation on it on the GPU:
25+
26+
```rust
27+
use rand::random;
28+
use tfhe::CompressedServerKey;
29+
use tfhe::prelude::*;
30+
use tfhe::set_server_key;
31+
use tfhe::zk::{CompactPkeCrs, ZkComputeLoad};
32+
33+
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
34+
let params = tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
35+
// Indicate which parameters to use for the Compact Public Key encryption
36+
let cpk_params = tfhe::shortint::parameters::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
37+
// And the parameters used to keyswitch/cast to the computation parameters.
38+
let casting_params = tfhe::shortint::parameters::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
39+
// Enable the dedicated parameters on the config
40+
let config = tfhe::ConfigBuilder::with_custom_parameters(params)
41+
.use_dedicated_compact_public_key_parameters((cpk_params, casting_params)).build();
42+
43+
// The CRS should be generated in an offline phase, then shared with all clients and the server
44+
let crs = CompactPkeCrs::from_config(config, 64).unwrap();
45+
46+
// Then use TFHE-rs as usual
47+
let client_key = tfhe::ClientKey::generate(config);
48+
let compressed_server_key = CompressedServerKey::new(&client_key);
49+
let gpu_server_key = compressed_server_key.decompress_to_gpu();
50+
51+
let public_key = tfhe::CompactPublicKey::try_new(&client_key).unwrap();
52+
// This can be left empty, but, if provided, it ties the proof to arbitrary data
53+
let metadata = [b'T', b'F', b'H', b'E', b'-', b'r', b's'];
54+
55+
let clear_a = random::<u64>();
56+
let clear_b = random::<u64>();
57+
58+
let proven_compact_list = tfhe::ProvenCompactCiphertextList::builder(&public_key)
59+
.push(clear_a)
60+
.push(clear_b)
61+
.build_with_proof_packed(&crs, &metadata, ZkComputeLoad::Verify)?;
62+
63+
// Server side
64+
let result = {
65+
set_server_key(gpu_server_key);
66+
67+
// Verify the ciphertexts
68+
let expander =
69+
proven_compact_list.verify_and_expand(&crs, &public_key, &metadata)?;
70+
let a: tfhe::FheUint64 = expander.get(0)?.unwrap();
71+
let b: tfhe::FheUint64 = expander.get(1)?.unwrap();
72+
73+
a + b
74+
};
75+
76+
// Back on the client side
77+
let a_plus_b: u64 = result.decrypt(&client_key);
78+
assert_eq!(a_plus_b, clear_a.wrapping_add(clear_b));
79+
80+
Ok(())
81+
}
82+
```

tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::core_crypto::prelude::{
77
use tfhe_cuda_backend::cuda_bind::cuda_memcpy_async_gpu_to_gpu;
88

99
/// A structure representing a vector of LWE ciphertexts with 64 bits of precision on the GPU.
10-
#[derive(Debug)]
10+
#[derive(Clone, Debug)]
1111
pub struct CudaLweCiphertextList<T: UnsignedInteger>(pub(crate) CudaLweList<T>);
1212

1313
#[allow(dead_code)]

tfhe/src/core_crypto/gpu/entities/lwe_compact_ciphertext_list.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ use crate::core_crypto::prelude::{
1010
pub struct CudaLweCompactCiphertextList<T: UnsignedInteger>(pub CudaLweList<T>);
1111

1212
impl<T: UnsignedInteger> CudaLweCompactCiphertextList<T> {
13+
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
14+
Self(self.0.duplicate(streams))
15+
}
16+
1317
pub fn from_lwe_compact_ciphertext_list<C: Container<Element = T>>(
1418
h_ct: &LweCompactCiphertextList<C>,
1519
streams: &CudaStreams,

tfhe/src/core_crypto/gpu/entities/lwe_keyswitch_key.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::core_crypto::prelude::{
1010
UnsignedInteger,
1111
};
1212

13+
#[derive(Clone)]
1314
#[allow(dead_code)]
1415
pub struct CudaLweKeyswitchKey<T: UnsignedInteger> {
1516
pub(crate) d_vec: CudaVec<T>,

tfhe/src/core_crypto/gpu/mod.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,7 @@ pub unsafe fn fourier_transform_backward_as_torus_f128_async<T: UnsignedInteger>
993993
);
994994
}
995995

996-
#[derive(Debug)]
996+
#[derive(Clone, Debug)]
997997
pub struct CudaLweList<T: UnsignedInteger> {
998998
// Pointer to GPU data
999999
pub d_vec: CudaVec<T>,
@@ -1005,6 +1005,17 @@ pub struct CudaLweList<T: UnsignedInteger> {
10051005
pub ciphertext_modulus: CiphertextModulus<T>,
10061006
}
10071007

1008+
impl<T: UnsignedInteger> CudaLweList<T> {
1009+
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
1010+
Self {
1011+
d_vec: self.d_vec.duplicate(streams),
1012+
lwe_ciphertext_count: self.lwe_ciphertext_count,
1013+
lwe_dimension: self.lwe_dimension,
1014+
ciphertext_modulus: self.ciphertext_modulus,
1015+
}
1016+
}
1017+
}
1018+
10081019
#[derive(Debug, Clone)]
10091020
pub struct CudaGlweList<T: UnsignedInteger> {
10101021
// Pointer to GPU data

tfhe/src/high_level_api/backward_compatibility/compact_list.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
33

44
use crate::{CompactCiphertextList, Tag};
55

6+
#[cfg(feature = "zk-pok")]
7+
use crate::ProvenCompactCiphertextList;
8+
69
#[derive(Version)]
710
pub struct CompactCiphertextListV0(crate::integer::ciphertext::CompactCiphertextList);
811

@@ -17,9 +20,6 @@ impl Upgrade<CompactCiphertextList> for CompactCiphertextListV0 {
1720
}
1821
}
1922

20-
#[cfg(feature = "zk-pok")]
21-
use crate::ProvenCompactCiphertextList;
22-
2323
#[derive(VersionsDispatch)]
2424
pub enum CompactCiphertextListVersions {
2525
V0(CompactCiphertextListV0),

0 commit comments

Comments
 (0)