Skip to content

Commit 8c12b84

Browse files
committed
refactor(gpu): match_value to backend with multiple streams
1 parent 36fb820 commit 8c12b84

File tree

7 files changed

+2311
-330
lines changed

7 files changed

+2311
-330
lines changed

backends/tfhe-cuda-backend/cuda/include/integer/integer.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,80 @@ void cuda_integer_ilog2_64(
735735

736736
void cleanup_cuda_integer_ilog2_64(CudaStreamsFFI streams,
737737
int8_t **mem_ptr_void);
738+
739+
uint64_t scratch_cuda_compute_equality_selectors_64(
740+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
741+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
742+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
743+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
744+
uint32_t num_possible_values, uint32_t num_blocks, uint32_t message_modulus,
745+
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
746+
PBS_MS_REDUCTION_T noise_reduction_type);
747+
748+
void cuda_compute_equality_selectors_64(
749+
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out_list,
750+
CudaRadixCiphertextFFI const *lwe_array_in, uint32_t num_blocks,
751+
const uint64_t *h_decomposed_cleartexts, int8_t *mem, void *const *bsks,
752+
void *const *ksks);
753+
754+
void cleanup_cuda_compute_equality_selectors_64(CudaStreamsFFI streams,
755+
int8_t **mem_ptr_void);
756+
757+
uint64_t scratch_cuda_create_possible_results_64(
758+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
759+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
760+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
761+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
762+
uint32_t num_possible_values, uint32_t num_blocks, uint32_t message_modulus,
763+
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
764+
PBS_MS_REDUCTION_T noise_reduction_type);
765+
766+
void cuda_create_possible_results_64(
767+
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out_list,
768+
CudaRadixCiphertextFFI const *lwe_array_in_list,
769+
uint32_t num_possible_values, const uint64_t *h_decomposed_cleartexts,
770+
uint32_t num_blocks, int8_t *mem, void *const *bsks, void *const *ksks);
771+
772+
void cleanup_cuda_create_possible_results_64(CudaStreamsFFI streams,
773+
int8_t **mem_ptr_void);
774+
775+
uint64_t scratch_cuda_aggregate_one_hot_vector_64(
776+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
777+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
778+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
779+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
780+
uint32_t num_blocks, uint32_t num_matches, uint32_t message_modulus,
781+
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
782+
PBS_MS_REDUCTION_T noise_reduction_type);
783+
784+
void cuda_aggregate_one_hot_vector_64(
785+
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out,
786+
CudaRadixCiphertextFFI const *lwe_array_in_list,
787+
uint32_t num_input_ciphertexts, uint32_t num_blocks, int8_t *mem,
788+
void *const *bsks, void *const *ksks);
789+
790+
void cleanup_cuda_aggregate_one_hot_vector_64(CudaStreamsFFI streams,
791+
int8_t **mem_ptr_void);
792+
793+
uint64_t scratch_cuda_unchecked_match_value_64(
794+
CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension,
795+
uint32_t polynomial_size, uint32_t big_lwe_dimension,
796+
uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log,
797+
uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor,
798+
uint32_t num_matches, uint32_t num_input_blocks,
799+
uint32_t num_output_packed_blocks, uint32_t max_output_is_zero,
800+
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
801+
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
802+
803+
void cuda_unchecked_match_value_64(
804+
CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out_result,
805+
CudaRadixCiphertextFFI *lwe_array_out_boolean,
806+
CudaRadixCiphertextFFI const *lwe_array_in_ct,
807+
const uint64_t *h_match_inputs, const uint64_t *h_match_outputs,
808+
int8_t *mem, void *const *bsks, void *const *ksks);
809+
810+
void cleanup_cuda_unchecked_match_value_64(CudaStreamsFFI streams,
811+
int8_t **mem_ptr_void);
738812
} // extern C
739813

740814
#endif // CUDA_INTEGER_H

0 commit comments

Comments
 (0)