
Commit 1242657

guillermo-oyarzun authored and agnesLeroy committed
fix(gpu): add upper bound to lwe_chunk_size calculation
1 parent 6f105cd commit 1242657

7 files changed: +79 -60 lines changed

backends/tfhe-cuda-backend/cuda/include/pbs/pbs_multibit_utilities.h

Lines changed: 9 additions & 8 deletions

@@ -97,20 +97,21 @@ uint64_t get_buffer_size_full_sm_tbc_multibit_programmable_bootstrap(
     uint32_t polynomial_size);

 template <typename Torus, class params>
-uint32_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs,
-                            uint32_t polynomial_size,
-                            uint64_t full_sm_keybundle);
+uint64_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs,
+                            uint32_t polynomial_size, uint32_t glwe_dimension,
+                            uint32_t level_count, uint64_t full_sm_keybundle);
 template <typename Torus, class params>
-uint32_t get_lwe_chunk_size_128(uint32_t gpu_index, uint32_t max_num_pbs,
+uint64_t get_lwe_chunk_size_128(uint32_t gpu_index, uint32_t max_num_pbs,
                                 uint32_t polynomial_size,
+                                uint32_t glwe_dimension, uint32_t level_count,
                                 uint64_t full_sm_keybundle);
 template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::MULTI_BIT> {
   int8_t *d_mem_keybundle = NULL;
   int8_t *d_mem_acc_step_one = NULL;
   int8_t *d_mem_acc_step_two = NULL;
   int8_t *d_mem_acc_cg = NULL;
   int8_t *d_mem_acc_tbc = NULL;
-  uint32_t lwe_chunk_size;
+  uint64_t lwe_chunk_size;
   double2 *keybundle_fft;
   Torus *global_accumulator;
   double2 *global_join_buffer;

@@ -120,7 +121,7 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::MULTI_BIT> {

   pbs_buffer(cudaStream_t stream, uint32_t gpu_index, uint32_t glwe_dimension,
              uint32_t polynomial_size, uint32_t level_count,
-             uint32_t input_lwe_ciphertext_count, uint32_t lwe_chunk_size,
+             uint32_t input_lwe_ciphertext_count, uint64_t lwe_chunk_size,
              PBS_VARIANT pbs_variant, bool allocate_gpu_memory,
              uint64_t &size_tracker) {
     gpu_memory_allocated = allocate_gpu_memory;

@@ -295,7 +296,7 @@ struct pbs_buffer_128<InputTorus, PBS_TYPE::MULTI_BIT> {
   int8_t *d_mem_acc_step_two = NULL;
   int8_t *d_mem_acc_cg = NULL;
   int8_t *d_mem_acc_tbc = NULL;
-  uint32_t lwe_chunk_size;
+  uint64_t lwe_chunk_size;
   double *keybundle_fft;
   __uint128_t *global_accumulator;
   double *global_join_buffer;

@@ -306,7 +307,7 @@ struct pbs_buffer_128<InputTorus, PBS_TYPE::MULTI_BIT> {
   pbs_buffer_128(cudaStream_t stream, uint32_t gpu_index,
                  uint32_t glwe_dimension, uint32_t polynomial_size,
                  uint32_t level_count, uint32_t input_lwe_ciphertext_count,
-                 uint32_t lwe_chunk_size, PBS_VARIANT pbs_variant,
+                 uint64_t lwe_chunk_size, PBS_VARIANT pbs_variant,
                  bool allocate_gpu_memory, uint64_t &size_tracker) {
     gpu_memory_allocated = allocate_gpu_memory;
     cuda_set_device(gpu_index);
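
The widening in this header matters beyond the chunk size itself: in C++, a single uint64_t factor promotes an entire product to 64 bits, so size computations such as input_lwe_ciphertext_count * lwe_chunk_size * level_count * ... (see get_buffer_size_cg_multibit_programmable_bootstrap in the next file) no longer wrap modulo 2^32 before landing in a uint64_t. A minimal sketch of the difference, with illustrative values that are not taken from the commit:

#include <cstdint>
#include <cstdio>

int main() {
  uint32_t input_lwe_ciphertext_count = 100000;
  uint32_t lwe_chunk_size_32 = 64; // old type
  uint64_t lwe_chunk_size_64 = 64; // new type
  uint32_t level_count = 2, glwe_dimension = 1, polynomial_size = 2048;

  // All-uint32_t product wraps modulo 2^32 before the uint64_t assignment.
  uint64_t wrapped = input_lwe_ciphertext_count * lwe_chunk_size_32 *
                     level_count * (glwe_dimension + 1) *
                     (glwe_dimension + 1) * (polynomial_size / 2);

  // One uint64_t factor promotes every later multiplication to 64 bits.
  uint64_t exact = input_lwe_ciphertext_count * lwe_chunk_size_64 *
                   level_count * (glwe_dimension + 1) * (glwe_dimension + 1) *
                   (polynomial_size / 2);

  // Prints wrapped=889192448 exact=52428800000 on a typical 64-bit target.
  printf("wrapped=%llu exact=%llu\n", (unsigned long long)wrapped,
         (unsigned long long)exact);
  return 0;
}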

backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_cg_multibit.cuh

Lines changed: 8 additions & 8 deletions

@@ -30,7 +30,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     Torus *global_accumulator, uint32_t lwe_dimension,
     uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
     uint32_t level_count, uint32_t grouping_factor, uint32_t lwe_offset,
-    uint32_t lwe_chunk_size, uint32_t keybundle_size_per_input,
+    uint64_t lwe_chunk_size, uint64_t keybundle_size_per_input,
     int8_t *device_mem, uint64_t device_memory_size_per_block,
     uint32_t num_many_lut, uint32_t lut_stride) {

@@ -193,7 +193,7 @@ template <typename Torus>
 uint64_t get_buffer_size_cg_multibit_programmable_bootstrap(
     uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
     uint32_t level_count, uint32_t input_lwe_ciphertext_count,
-    uint32_t grouping_factor, uint32_t lwe_chunk_size) {
+    uint32_t grouping_factor, uint64_t lwe_chunk_size) {

   uint64_t buffer_size = 0;
   buffer_size += input_lwe_ciphertext_count * lwe_chunk_size * level_count *

@@ -280,9 +280,9 @@ __host__ uint64_t scratch_cg_multi_bit_programmable_bootstrap(
     check_cuda_error(cudaGetLastError());
   }

-  auto lwe_chunk_size =
-      get_lwe_chunk_size<Torus, params>(gpu_index, input_lwe_ciphertext_count,
-                                        polynomial_size, full_sm_keybundle);
+  auto lwe_chunk_size = get_lwe_chunk_size<Torus, params>(
+      gpu_index, input_lwe_ciphertext_count, polynomial_size, glwe_dimension,
+      level_count, full_sm_keybundle);
   uint64_t size_tracker = 0;
   *buffer = new pbs_buffer<Torus, MULTI_BIT>(
       stream, gpu_index, glwe_dimension, polynomial_size, level_count,

@@ -317,12 +317,12 @@ __host__ void execute_cg_external_product_loop(
   auto lwe_chunk_size = buffer->lwe_chunk_size;
   auto max_shared_memory = cuda_get_max_shared_memory(gpu_index);

-  uint32_t keybundle_size_per_input =
+  uint64_t keybundle_size_per_input =
       lwe_chunk_size * level_count * (glwe_dimension + 1) *
       (glwe_dimension + 1) * (polynomial_size / 2);

-  uint32_t chunk_size =
-      std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
+  uint64_t chunk_size = std::min(
+      lwe_chunk_size, (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset);

   auto d_mem = buffer->d_mem_acc_cg;
   auto keybundle_fft = buffer->keybundle_fft;
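
Note the cast in the last hunk: with lwe_chunk_size now uint64_t, the old call std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset) would no longer compile, because std::min deduces a single type T from both arguments. A minimal sketch of the pattern, with arbitrary illustrative values:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  uint64_t lwe_chunk_size = 16;
  uint32_t lwe_dimension = 928, grouping_factor = 4, lwe_offset = 0;

  // std::min(const T&, const T&) deduces T from both arguments, so mixing
  // uint64_t and uint32_t fails to compile:
  //   auto bad = std::min(lwe_chunk_size,
  //                       (lwe_dimension / grouping_factor) - lwe_offset);

  // Widening the second operand resolves the deduction, exactly as the
  // diff does in each host function:
  uint64_t chunk_size = std::min(
      lwe_chunk_size, (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset);
  printf("chunk_size=%llu\n", (unsigned long long)chunk_size); // 16
  return 0;
}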

backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu

Lines changed: 21 additions & 6 deletions

@@ -456,9 +456,9 @@ void cleanup_cuda_multi_bit_programmable_bootstrap(void *stream,
  * benchmarking on an RTX 4090 GPU, balancing performance and resource use.
  */
 template <typename Torus, class params>
-uint32_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs,
-                            uint32_t polynomial_size,
-                            uint64_t full_sm_keybundle) {
+uint64_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs,
+                            uint32_t polynomial_size, uint32_t glwe_dimension,
+                            uint32_t level_count, uint64_t full_sm_keybundle) {

   int max_blocks_per_sm;
   auto max_shared_memory = cuda_get_max_shared_memory(gpu_index);

@@ -479,6 +479,22 @@ uint32_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs,
   check_cuda_error(cudaDeviceGetAttribute(
       &num_sms, cudaDevAttrMultiProcessorCount, gpu_index));

+  size_t total_mem, free_mem;
+  check_cuda_error(cudaMemGetInfo(&free_mem, &total_mem));
+  // Estimate the size of one chunk
+  uint64_t size_one_chunk = max_num_pbs * polynomial_size *
+                            (glwe_dimension + 1) * (glwe_dimension + 1) *
+                            level_count * sizeof(Torus);
+
+  // We calculate the maximum number of chunks that can fit in 50% of the free
+  // memory: the PBS temporary array must not use more than 50% of the free
+  // memory, and if even one chunk doesn't fit in that budget, we panic.
+  uint32_t max_num_chunks =
+      static_cast<uint32_t>(free_mem / (2 * size_one_chunk));
+  PANIC_IF_FALSE(
+      max_num_chunks > 0,
+      "Cuda error (multi-bit PBS): Not enough GPU memory to allocate PBS "
+      "temporary arrays.");
   int x = num_sms * max_blocks_per_sm;
   int count = 0;

@@ -500,7 +516,7 @@ uint32_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs,
   // applied only to few number of samples(8) because it can have a negative
   // effect of over saturation.
   if (max_num_pbs <= 8) {
-    return num_sms / 2;
+    return (max_num_chunks > num_sms / 2) ? num_sms / 2 : max_num_chunks;
   }
 #endif

@@ -514,8 +530,7 @@ uint32_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs,
       }
     }
   }
-
-  return divisor;
+  return (max_num_chunks > divisor) ? divisor : max_num_chunks;
 }

 template uint64_t scratch_cuda_multi_bit_programmable_bootstrap<uint64_t>(
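
The hunks above are the core of the fix: the chunk size chosen by the occupancy heuristic is now clamped so the keybundle temporary array stays within half of the currently free device memory. A standalone sketch of that bound, mirroring the diff; sizeof(Torus) is fixed to 8 here (the 64-bit path), the backend-internal PANIC_IF_FALSE macro is replaced with a plain check, and error checking on the CUDA call is omitted:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

uint64_t bound_lwe_chunk_size(uint64_t heuristic_chunk_size,
                              uint32_t max_num_pbs, uint32_t polynomial_size,
                              uint32_t glwe_dimension, uint32_t level_count) {
  size_t total_mem, free_mem;
  cudaMemGetInfo(&free_mem, &total_mem);

  // Keybundle storage needed per chunk, across all max_num_pbs inputs.
  uint64_t size_one_chunk = (uint64_t)max_num_pbs * polynomial_size *
                            (glwe_dimension + 1) * (glwe_dimension + 1) *
                            level_count * sizeof(uint64_t);

  // Let the temporary array use at most 50% of the currently free memory;
  // if even a single chunk does not fit there, no configuration works.
  uint64_t max_num_chunks = free_mem / (2 * size_one_chunk);
  if (max_num_chunks == 0) {
    fprintf(stderr, "Cuda error (multi-bit PBS): not enough GPU memory for "
                    "PBS temporary arrays\n");
    abort();
  }
  return std::min(heuristic_chunk_size, max_num_chunks);
}

For a sense of scale, with hypothetical parameters max_num_pbs = 10000, polynomial_size = 2048, glwe_dimension = 1, level_count = 1 and an 8-byte Torus, one chunk is 10000 * 2048 * 4 * 1 * 8 B, roughly 655 MB, so a card with 20 GB free would allow at most 15 chunks under the 50% budget.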

backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cuh

Lines changed: 14 additions & 13 deletions

@@ -45,8 +45,8 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
     const Torus *__restrict__ lwe_input_indexes, double2 *keybundle_array,
     const Torus *__restrict__ bootstrapping_key, uint32_t lwe_dimension,
     uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
-    uint32_t level_count, uint32_t lwe_offset, uint32_t lwe_chunk_size,
-    uint32_t keybundle_size_per_input, int8_t *device_mem,
+    uint32_t level_count, uint32_t lwe_offset, uint64_t lwe_chunk_size,
+    uint64_t keybundle_size_per_input, int8_t *device_mem,
     uint64_t device_memory_size_per_block) {

   extern __shared__ int8_t sharedmem[];

@@ -164,8 +164,8 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle_2_2_params(
     const Torus *__restrict__ lwe_array_in,
     const Torus *__restrict__ lwe_input_indexes, double2 *keybundle_array,
     const Torus *__restrict__ bootstrapping_key, uint32_t lwe_dimension,
-    uint32_t lwe_offset, uint32_t lwe_chunk_size,
-    uint32_t keybundle_size_per_input) {
+    uint32_t lwe_offset, uint64_t lwe_chunk_size,
+    uint64_t keybundle_size_per_input) {

   constexpr uint32_t polynomial_size = 2048;
   constexpr uint32_t grouping_factor = 4;

@@ -387,7 +387,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
     const double2 *__restrict__ keybundle_array, Torus *global_accumulator,
     double2 *join_buffer, uint32_t glwe_dimension, uint32_t polynomial_size,
-    uint32_t level_count, uint32_t iteration, uint32_t lwe_chunk_size,
+    uint32_t level_count, uint32_t iteration, uint64_t lwe_chunk_size,
     int8_t *device_mem, uint64_t device_memory_size_per_block,
     uint32_t num_many_lut, uint32_t lut_stride) {
   // We use shared memory for the polynomials that are used often during the

@@ -658,9 +658,9 @@ __host__ uint64_t scratch_multi_bit_programmable_bootstrap(
     check_cuda_error(cudaGetLastError());
   }

-  auto lwe_chunk_size =
-      get_lwe_chunk_size<Torus, params>(gpu_index, input_lwe_ciphertext_count,
-                                        polynomial_size, full_sm_keybundle);
+  auto lwe_chunk_size = get_lwe_chunk_size<Torus, params>(
+      gpu_index, input_lwe_ciphertext_count, polynomial_size, glwe_dimension,
+      level_count, full_sm_keybundle);
   uint64_t size_tracker = 0;
   *buffer = new pbs_buffer<Torus, MULTI_BIT>(
       stream, gpu_index, glwe_dimension, polynomial_size, level_count,

@@ -679,10 +679,10 @@ __host__ void execute_compute_keybundle(
   cuda_set_device(gpu_index);

   auto lwe_chunk_size = buffer->lwe_chunk_size;
-  uint32_t chunk_size =
-      std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
+  uint64_t chunk_size = std::min(
+      lwe_chunk_size, (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset);

-  uint32_t keybundle_size_per_input =
+  uint64_t keybundle_size_per_input =
       lwe_chunk_size * level_count * (glwe_dimension + 1) *
       (glwe_dimension + 1) * (polynomial_size / 2);

@@ -859,8 +859,9 @@ __host__ void host_multi_bit_programmable_bootstrap(
         buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size,
         grouping_factor, level_count, lwe_offset);
     // Accumulate
-    uint32_t chunk_size = std::min(
-        lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
+    uint32_t chunk_size =
+        std::min((uint32_t)lwe_chunk_size,
+                 (lwe_dimension / grouping_factor) - lwe_offset);
     for (uint32_t j = 0; j < chunk_size; j++) {
       bool is_first_iter = (j + lwe_offset) == 0;
       bool is_last_iter =
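
In host_multi_bit_programmable_bootstrap the per-iteration chunk_size deliberately stays uint32_t (hence the (uint32_t)lwe_chunk_size cast): it is bounded by lwe_dimension / grouping_factor, which fits comfortably in 32 bits. The loop shape the host code drives looks roughly like the sketch below, with illustrative parameter values not taken from the commit:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  uint32_t lwe_dimension = 920, grouping_factor = 4;
  uint64_t lwe_chunk_size = 64; // heuristic value, now memory-bounded
  uint64_t total_groups = lwe_dimension / grouping_factor; // 230

  // The multi-bit groups are consumed lwe_chunk_size at a time;
  // std::min truncates the final, partial chunk (here: 64, 64, 64, 38).
  for (uint64_t lwe_offset = 0; lwe_offset < total_groups;
       lwe_offset += lwe_chunk_size) {
    uint64_t chunk_size = std::min(lwe_chunk_size, total_groups - lwe_offset);
    printf("lwe_offset=%llu chunk_size=%llu\n",
           (unsigned long long)lwe_offset, (unsigned long long)chunk_size);
  }
  return 0;
}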

backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit_128.cu

Lines changed: 2 additions & 1 deletion

@@ -307,8 +307,9 @@ void cleanup_cuda_multi_bit_programmable_bootstrap_128(void *stream,
  * benchmarking on an RTX 4090 GPU, balancing performance and resource use.
  */
 template <typename Torus, class params>
-uint32_t get_lwe_chunk_size_128(uint32_t gpu_index, uint32_t max_num_pbs,
+uint64_t get_lwe_chunk_size_128(uint32_t gpu_index, uint32_t max_num_pbs,
                                 uint32_t polynomial_size,
+                                uint32_t glwe_dimension, uint32_t level_count,
                                 uint64_t full_sm_keybundle) {

   int max_blocks_per_sm;

backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit_128.cuh

Lines changed: 17 additions & 16 deletions

@@ -23,8 +23,8 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle_128(
     const InputTorus *__restrict__ lwe_input_indexes, double *keybundle_array,
     const __uint128_t *__restrict__ bootstrapping_key, uint32_t lwe_dimension,
     uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
-    uint32_t level_count, uint32_t lwe_offset, uint32_t lwe_chunk_size,
-    uint32_t keybundle_size_per_input, int8_t *device_mem,
+    uint32_t level_count, uint32_t lwe_offset, uint64_t lwe_chunk_size,
+    uint64_t keybundle_size_per_input, int8_t *device_mem,
     uint64_t device_memory_size_per_block) {

   extern __shared__ int8_t sharedmem[];

@@ -237,7 +237,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     const double *__restrict__ keybundle_array,
     __uint128_t *global_accumulator, double *global_accumulator_fft,
     uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
-    uint32_t iteration, uint32_t lwe_chunk_size, int8_t *device_mem,
+    uint32_t iteration, uint64_t lwe_chunk_size, int8_t *device_mem,
     uint64_t device_memory_size_per_block, uint32_t num_many_lut,
     uint32_t lut_stride) {
   // We use shared memory for the polynomials that are used often during the

@@ -372,7 +372,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     __uint128_t *global_accumulator, uint32_t lwe_dimension,
     uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
     uint32_t level_count, uint32_t grouping_factor, uint32_t lwe_offset,
-    uint32_t lwe_chunk_size, uint32_t keybundle_size_per_input,
+    uint64_t lwe_chunk_size, uint64_t keybundle_size_per_input,
     int8_t *device_mem, uint64_t device_memory_size_per_block,
     uint32_t num_many_lut, uint32_t lut_stride) {

@@ -546,10 +546,10 @@ __host__ void execute_compute_keybundle_128(
   cuda_set_device(gpu_index);

   auto lwe_chunk_size = buffer->lwe_chunk_size;
-  uint32_t chunk_size =
-      std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
+  uint64_t chunk_size = std::min(
+      lwe_chunk_size, (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset);

-  uint32_t keybundle_size_per_input =
+  uint64_t keybundle_size_per_input =
       lwe_chunk_size * level_count * (glwe_dimension + 1) *
       (glwe_dimension + 1) * (polynomial_size / 2) * 4;

@@ -703,8 +703,9 @@ __host__ void host_multi_bit_programmable_bootstrap_128(
         buffer, num_samples, lwe_dimension, glwe_dimension, polynomial_size,
         grouping_factor, level_count, lwe_offset);
     // Accumulate
-    uint32_t chunk_size = std::min(
-        lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
+    uint64_t chunk_size =
+        std::min((uint32_t)lwe_chunk_size,
+                 (lwe_dimension / grouping_factor) - lwe_offset);
     for (uint32_t j = 0; j < chunk_size; j++) {
       bool is_first_iter = (j + lwe_offset) == 0;
       bool is_last_iter =

@@ -761,12 +762,12 @@ __host__ void execute_cg_external_product_loop_128(
   auto lwe_chunk_size = buffer->lwe_chunk_size;
   auto max_shared_memory = cuda_get_max_shared_memory(gpu_index);

-  uint32_t keybundle_size_per_input =
+  uint64_t keybundle_size_per_input =
       lwe_chunk_size * level_count * (glwe_dimension + 1) *
       (glwe_dimension + 1) * (polynomial_size / 2) * 4;

-  uint32_t chunk_size =
-      std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
+  uint64_t chunk_size = std::min(
+      lwe_chunk_size, (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset);

   auto d_mem = buffer->d_mem_acc_cg;
   auto keybundle_fft = buffer->keybundle_fft;

@@ -994,8 +995,8 @@ __host__ uint64_t scratch_multi_bit_programmable_bootstrap_128(
   }

   auto lwe_chunk_size = get_lwe_chunk_size_128<InputTorus, params>(
-      gpu_index, input_lwe_ciphertext_count, polynomial_size,
-      full_sm_keybundle);
+      gpu_index, input_lwe_ciphertext_count, polynomial_size, glwe_dimension,
+      level_count, full_sm_keybundle);
   uint64_t size_tracker = 0;
   *buffer = new pbs_buffer_128<InputTorus, MULTI_BIT>(
       stream, gpu_index, glwe_dimension, polynomial_size, level_count,

@@ -1079,8 +1080,8 @@ __host__ uint64_t scratch_cg_multi_bit_programmable_bootstrap_128(
   }

   auto lwe_chunk_size = get_lwe_chunk_size_128<InputTorus, params>(
-      gpu_index, input_lwe_ciphertext_count, polynomial_size,
-      full_sm_keybundle);
+      gpu_index, input_lwe_ciphertext_count, polynomial_size, glwe_dimension,
+      level_count, full_sm_keybundle);
   uint64_t size_tracker = 0;
   *buffer = new pbs_buffer_128<InputTorus, MULTI_BIT>(
       stream, gpu_index, glwe_dimension, polynomial_size, level_count,
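
The 128-bit keybundle formulas carry an extra factor of 4 relative to the 64-bit path, consistent with keybundle_fft being double * here rather than double2 *: each of the polynomial_size / 2 complex FFT coefficients appears to occupy four doubles in this path. That reading is ours, not stated in the commit; the two per-input element counts themselves are taken directly from the diffs:

#include <cstdint>

// Per-input keybundle footprint in the two paths, as written in the diffs.
uint64_t keybundle_elems_64(uint64_t lwe_chunk_size, uint32_t level_count,
                            uint32_t glwe_dimension,
                            uint32_t polynomial_size) {
  // Elements of type double2: one complex value per element.
  return lwe_chunk_size * level_count * (glwe_dimension + 1) *
         (glwe_dimension + 1) * (polynomial_size / 2);
}

uint64_t keybundle_elems_128(uint64_t lwe_chunk_size, uint32_t level_count,
                             uint32_t glwe_dimension,
                             uint32_t polynomial_size) {
  // Elements of type double: four doubles per complex coefficient.
  return lwe_chunk_size * level_count * (glwe_dimension + 1) *
         (glwe_dimension + 1) * (polynomial_size / 2) * 4;
}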

backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_tbc_multibit.cuh

Lines changed: 8 additions & 8 deletions

@@ -30,7 +30,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     Torus *global_accumulator, uint32_t lwe_dimension,
     uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
     uint32_t level_count, uint32_t grouping_factor, uint32_t lwe_offset,
-    uint32_t lwe_chunk_size, uint32_t keybundle_size_per_input,
+    uint64_t lwe_chunk_size, uint64_t keybundle_size_per_input,
     int8_t *device_mem, uint64_t device_memory_size_per_block,
     bool support_dsm, uint32_t num_many_lut, uint32_t lut_stride) {

@@ -207,7 +207,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     const Torus *__restrict__ lwe_input_indexes,
     const double2 *__restrict__ keybundle_array, double2 *join_buffer,
     Torus *global_accumulator, uint32_t lwe_dimension, uint32_t lwe_offset,
-    uint32_t lwe_chunk_size, uint32_t keybundle_size_per_input,
+    uint64_t lwe_chunk_size, uint64_t keybundle_size_per_input,
     uint32_t num_many_lut, uint32_t lut_stride) {

   constexpr uint32_t level_count = 1;

@@ -502,9 +502,9 @@ __host__ uint64_t scratch_tbc_multi_bit_programmable_bootstrap(
     check_cuda_error(cudaGetLastError());
   }

-  auto lwe_chunk_size =
-      get_lwe_chunk_size<Torus, params>(gpu_index, input_lwe_ciphertext_count,
-                                        polynomial_size, full_sm_keybundle);
+  auto lwe_chunk_size = get_lwe_chunk_size<Torus, params>(
+      gpu_index, input_lwe_ciphertext_count, polynomial_size, glwe_dimension,
+      level_count, full_sm_keybundle);
   uint64_t size_tracker = 0;
   *buffer = new pbs_buffer<uint64_t, MULTI_BIT>(
       stream, gpu_index, glwe_dimension, polynomial_size, level_count,

@@ -544,12 +544,12 @@ __host__ void execute_tbc_external_product_loop(
       get_buffer_size_sm_dsm_plus_tbc_multibit_programmable_bootstrap<Torus>(
           polynomial_size);

-  uint32_t keybundle_size_per_input =
+  uint64_t keybundle_size_per_input =
       lwe_chunk_size * level_count * (glwe_dimension + 1) *
       (glwe_dimension + 1) * (polynomial_size / 2);

-  uint32_t chunk_size =
-      std::min(lwe_chunk_size, (lwe_dimension / grouping_factor) - lwe_offset);
+  uint64_t chunk_size = std::min(
+      lwe_chunk_size, (uint64_t)(lwe_dimension / grouping_factor) - lwe_offset);

   auto d_mem = buffer->d_mem_acc_tbc;
   auto keybundle_fft = buffer->keybundle_fft;
