Skip to content

Commit 3b28051

Browse files
chore(gpu): bench KS latency batches
1 parent f9268b8 commit 3b28051

File tree

22 files changed

+1071
-148
lines changed

22 files changed

+1071
-148
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ build_debug_integer_short_run_gpu: install_rs_check_toolchain install_cargo_next
773773
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile debug_lto_off \
774774
--features=integer,gpu-debug-fake-multi-gpu -p tfhe -- integer::gpu::server_key::radix::tests_long_run::test_random_op_sequence::test_gpu_short_random --list
775775
@echo "To debug fake-multi-gpu short run tests run:"
776-
@echo "TFHE_RS_TEST_LONG_TESTS_MINIMAL=TRUE <executable> integer::gpu::server_key::radix::tests_long_run::test_random_op_sequence::test_gpu_short_random_op_sequence_param_gpu_multi_bit_group_4_message_2_carry_2_ks_pbs_tuniform_2m128 --nocapture"
776+
@echo "TFHE_RS_LONGRUN_TESTS_SEED=<SEED_FROM_CI> TFHE_RS_TEST_LONG_TESTS_MINIMAL=TRUE <executable> integer::gpu::server_key::radix::tests_long_run::test_random_op_sequence::test_gpu_short_random_op_sequence_param_gpu_multi_bit_group_4_message_2_carry_2_ks_pbs_tuniform_2m128 --nocapture"
777777
@echo "Where <executable> = the one printed in the () in the 'Running unittests src/lib.rs ()' line above"
778778

779779
.PHONY: test_integer_compression
@@ -806,7 +806,7 @@ test_unsigned_integer_gpu_ci: install_rs_check_toolchain install_cargo_nextest
806806
NIGHTLY_TESTS="$(NIGHTLY_TESTS)" \
807807
./scripts/integer-tests.sh --rust-toolchain $(CARGO_RS_CHECK_TOOLCHAIN) \
808808
--cargo-profile "$(CARGO_PROFILE)" --backend "gpu" \
809-
--unsigned-only --tfhe-package "tfhe"
809+
--unsigned-only --tfhe-package "tfhe" -- --nocapture
810810

811811
.PHONY: test_signed_integer_gpu_ci # Run the tests for signed integer ci on gpu backend
812812
test_signed_integer_gpu_ci: install_rs_check_toolchain install_cargo_nextest

backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
#include <stdio.h>
1515

16+
#include "crypto/keyswitch.cuh"
17+
1618
class NoiseLevel {
1719
public:
1820
// Constants equivalent to the Rust code
@@ -336,7 +338,11 @@ struct int_radix_lut_custom_input_output {
336338
std::vector<InputTorus *> lwe_after_ks_vec;
337339
std::vector<OutputTorus *> lwe_after_pbs_vec;
338340
std::vector<InputTorus *> lwe_trivial_indexes_vec;
341+
std::vector<ks_mem<InputTorus> *>
342+
ks_tmp_buf_vec; // buffers on each GPU to store keyswitch temporary data
343+
339344
std::vector<InputTorus *> lwe_aligned_vec;
345+
uint64_t numSamplesKsTmp = 0;
340346

341347
bool gpu_memory_allocated;
342348

@@ -439,6 +445,26 @@ struct int_radix_lut_custom_input_output {
439445
multi_gpu_copy_array_async(active_streams, lwe_trivial_indexes_vec,
440446
lwe_trivial_indexes, num_radix_blocks,
441447
allocate_gpu_memory);
448+
449+
auto inputs_on_gpu = std::max(
450+
THRESHOLD_MULTI_GPU,
451+
get_num_inputs_on_gpu(num_radix_blocks, 0, active_streams.count()));
452+
453+
this->numSamplesKsTmp = inputs_on_gpu;
454+
if (inputs_on_gpu >= 144) {
455+
for (auto i = 0; i < active_streams.count(); ++i) {
456+
ks_mem<InputTorus> *ks_buffer;
457+
uint64_t sub_size_tracker = scratch_cuda_keyswitch<InputTorus>(
458+
active_streams.stream(i), active_streams.gpu_index(i), &ks_buffer,
459+
params.small_lwe_dimension, params.big_lwe_dimension, num_blocks,
460+
allocate_gpu_memory);
461+
462+
if (i == 0) {
463+
size_tracker += sub_size_tracker;
464+
}
465+
ks_tmp_buf_vec.push_back(ks_buffer);
466+
}
467+
}
442468
}
443469

444470
void setup_mem_reuse(uint32_t num_radix_blocks,
@@ -459,6 +485,8 @@ struct int_radix_lut_custom_input_output {
459485
lwe_after_pbs_vec = base_lut_object->lwe_after_pbs_vec;
460486
lwe_trivial_indexes_vec = base_lut_object->lwe_trivial_indexes_vec;
461487

488+
ks_tmp_buf_vec = base_lut_object->ks_tmp_buf_vec;
489+
462490
mem_reuse = true;
463491
}
464492

@@ -861,6 +889,13 @@ struct int_radix_lut_custom_input_output {
861889
}
862890
lwe_aligned_vec.clear();
863891
}
892+
893+
for (auto i = 0; i < ks_tmp_buf_vec.size(); i++) {
894+
cleanup_cuda_keyswitch(active_streams.stream(i),
895+
active_streams.gpu_index(i), ks_tmp_buf_vec[i],
896+
gpu_memory_allocated);
897+
}
898+
ks_tmp_buf_vec.clear();
864899
}
865900
free(h_lut_indexes);
866901
free(degrees);

backends/tfhe-cuda-backend/cuda/include/integer/rerand_utilities.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ template <typename Torus> struct int_rerand_mem {
1515

1616
bool gpu_memory_allocated;
1717

18+
std::vector<ks_mem<Torus> *>
19+
ks_tmp_buf_vec; // buffers on each GPU to store keyswitch temporary data
20+
1821
expand_job<Torus> *d_expand_jobs;
1922
expand_job<Torus> *h_expand_jobs;
2023

@@ -56,6 +59,21 @@ template <typename Torus> struct int_rerand_mem {
5659

5760
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
5861

62+
for (auto i = 0; i < streams.count(); ++i) {
63+
ks_mem<Torus> *ks_buffer;
64+
uint64_t sub_size_tracker = scratch_cuda_keyswitch<Torus>(
65+
streams.stream(i), streams.gpu_index(i), &ks_buffer,
66+
params.small_lwe_dimension, params.big_lwe_dimension, num_lwes,
67+
allocate_gpu_memory);
68+
69+
if (i == 0) {
70+
size_tracker += sub_size_tracker;
71+
}
72+
ks_tmp_buf_vec.push_back(ks_buffer);
73+
}
74+
75+
streams.synchronize();
76+
5977
free(h_lwe_trivial_indexes);
6078
}
6179

@@ -72,6 +90,13 @@ template <typename Torus> struct int_rerand_mem {
7290
cuda_drop_with_size_tracking_async(d_expand_jobs, streams.stream(0),
7391
streams.gpu_index(0),
7492
gpu_memory_allocated);
93+
94+
for (auto i = 0; i < ks_tmp_buf_vec.size(); i++) {
95+
cleanup_cuda_keyswitch(streams.stream(i), streams.gpu_index(i),
96+
ks_tmp_buf_vec[i], gpu_memory_allocated);
97+
}
98+
ks_tmp_buf_vec.clear();
99+
75100
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
76101
free(h_expand_jobs);
77102
}

backends/tfhe-cuda-backend/cuda/include/keyswitch/keyswitch.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,22 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
1717
void const *lwe_output_indexes, void const *lwe_array_in,
1818
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
1919
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
20-
uint32_t num_samples);
20+
uint32_t num_samples, const void *ks_tmp_buffer, bool uses_trivial_indexes);
2121

2222
uint64_t scratch_packing_keyswitch_lwe_list_to_glwe_64(
2323
void *stream, uint32_t gpu_index, int8_t **fp_ks_buffer,
2424
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
2525
uint32_t num_lwes, bool allocate_gpu_memory);
2626

27+
uint64_t scratch_cuda_keyswitch_64(void *stream, uint32_t gpu_index,
28+
void **ks_tmp_memory,
29+
uint32_t lwe_dimension_in,
30+
uint32_t lwe_dimension_out,
31+
uint32_t num_lwes, bool allocate_gpu_memory);
32+
33+
void cleanup_cuda_keyswitch_64(void *stream, uint32_t gpu_index,
34+
void **ks_tmp_memory, bool allocate_gpu_memory);
35+
2736
void cuda_packing_keyswitch_lwe_list_to_glwe_64(
2837
void *stream, uint32_t gpu_index, void *glwe_array_out,
2938
void const *lwe_array_in, void const *fp_ksk_array, int8_t *fp_ks_buffer,

backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cu

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@ void cuda_keyswitch_lwe_ciphertext_vector_32(
99
void *stream, uint32_t gpu_index, void *lwe_array_out,
1010
void *lwe_output_indexes, void *lwe_array_in, void *lwe_input_indexes,
1111
void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
12-
uint32_t base_log, uint32_t level_count, uint32_t num_samples) {
13-
host_keyswitch_lwe_ciphertext_vector<uint32_t>(
12+
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
13+
void *ksk_tmp_buffer, bool uses_trivial_indices) {
14+
host_gemm_keyswitch_lwe_ciphertext_vector<uint32_t>(
1415
static_cast<cudaStream_t>(stream), gpu_index,
1516
static_cast<uint32_t *>(lwe_array_out),
1617
static_cast<uint32_t *>(lwe_output_indexes),
1718
static_cast<uint32_t *>(lwe_array_in),
1819
static_cast<uint32_t *>(lwe_input_indexes), static_cast<uint32_t *>(ksk),
19-
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples);
20+
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples,
21+
static_cast<uint32_t *>(ksk_tmp_buffer), uses_trivial_indices);
2022
}
2123

2224
/* Perform keyswitch on a batch of 64 bits input LWE ciphertexts.
@@ -40,15 +42,19 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
4042
void const *lwe_output_indexes, void const *lwe_array_in,
4143
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
4244
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
43-
uint32_t num_samples) {
44-
host_keyswitch_lwe_ciphertext_vector<uint64_t>(
45+
uint32_t num_samples, const void *ks_tmp_buffer,
46+
bool uses_trivial_indices) {
47+
48+
host_gemm_keyswitch_lwe_ciphertext_vector<uint64_t>(
4549
static_cast<cudaStream_t>(stream), gpu_index,
4650
static_cast<uint64_t *>(lwe_array_out),
4751
static_cast<const uint64_t *>(lwe_output_indexes),
4852
static_cast<const uint64_t *>(lwe_array_in),
4953
static_cast<const uint64_t *>(lwe_input_indexes),
5054
static_cast<const uint64_t *>(ksk), lwe_dimension_in, lwe_dimension_out,
51-
base_log, level_count, num_samples);
55+
base_log, level_count, num_samples,
56+
static_cast<const ks_mem<uint64_t> *>(ks_tmp_buffer)->d_buffer,
57+
uses_trivial_indices);
5258
}
5359

5460
uint64_t scratch_packing_keyswitch_lwe_list_to_glwe_64(
@@ -60,6 +66,26 @@ uint64_t scratch_packing_keyswitch_lwe_list_to_glwe_64(
6066
glwe_dimension, polynomial_size, num_lwes, allocate_gpu_memory);
6167
}
6268

69+
uint64_t scratch_cuda_keyswitch_64(void *stream, uint32_t gpu_index,
70+
void **ks_tmp_buffer,
71+
uint32_t lwe_dimension_in,
72+
uint32_t lwe_dimension_out,
73+
uint32_t num_lwes,
74+
bool allocate_gpu_memory) {
75+
return scratch_cuda_keyswitch<uint64_t>(
76+
static_cast<cudaStream_t>(stream), gpu_index,
77+
(ks_mem<uint64_t> **)ks_tmp_buffer, lwe_dimension_in, lwe_dimension_out,
78+
num_lwes, allocate_gpu_memory);
79+
}
80+
81+
void cleanup_cuda_keyswitch_64(void *stream, uint32_t gpu_index,
82+
void **ks_tmp_buffer, bool allocate_gpu_memory) {
83+
cleanup_cuda_keyswitch<uint64_t>(static_cast<cudaStream_t>(stream), gpu_index,
84+
(ks_mem<uint64_t> *)*ks_tmp_buffer,
85+
allocate_gpu_memory);
86+
*ks_tmp_buffer = nullptr;
87+
}
88+
6389
/* Perform functional packing keyswitch on a batch of 64 bits input LWE
6490
* ciphertexts.
6591
*/

0 commit comments

Comments
 (0)