Skip to content

Commit 66e3c02

Browse files
committed
feat(gpu): add memory tracking functions for comparisons
1 parent 408e81c commit 66e3c02

File tree

11 files changed

+919
-12
lines changed

11 files changed

+919
-12
lines changed

tfhe/src/high_level_api/integers/signed/ops.rs

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ use crate::high_level_api::integers::{FheIntId, FheUintId};
55
use crate::high_level_api::keys::InternalServerKey;
66
#[cfg(feature = "gpu")]
77
use crate::high_level_api::traits::{
8-
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SizeOnGpu,
9-
SubSizeOnGpu,
8+
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
9+
FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, SizeOnGpu, SubSizeOnGpu,
1010
};
1111
use crate::high_level_api::traits::{
1212
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
@@ -2253,3 +2253,113 @@ where
22532253
})
22542254
}
22552255
}
2256+
2257+
#[cfg(feature = "gpu")]
2258+
impl<Id> FheOrdSizeOnGpu<&Self> for FheInt<Id>
2259+
where
2260+
Id: FheIntId,
2261+
{
2262+
fn get_gt_size_on_gpu(&self, rhs: &Self) -> u64 {
2263+
global_state::with_internal_keys(|key| {
2264+
if let InternalServerKey::Cuda(cuda_key) = key {
2265+
with_thread_local_cuda_streams(|streams| {
2266+
cuda_key.key.key.get_gt_size_on_gpu(
2267+
&*self.ciphertext.on_gpu(streams),
2268+
&rhs.ciphertext.on_gpu(streams),
2269+
streams,
2270+
)
2271+
})
2272+
} else {
2273+
0
2274+
}
2275+
})
2276+
}
2277+
fn get_ge_size_on_gpu(&self, rhs: &Self) -> u64 {
2278+
global_state::with_internal_keys(|key| {
2279+
if let InternalServerKey::Cuda(cuda_key) = key {
2280+
with_thread_local_cuda_streams(|streams| {
2281+
cuda_key.key.key.get_ge_size_on_gpu(
2282+
&*self.ciphertext.on_gpu(streams),
2283+
&rhs.ciphertext.on_gpu(streams),
2284+
streams,
2285+
)
2286+
})
2287+
} else {
2288+
0
2289+
}
2290+
})
2291+
}
2292+
fn get_lt_size_on_gpu(&self, rhs: &Self) -> u64 {
2293+
global_state::with_internal_keys(|key| {
2294+
if let InternalServerKey::Cuda(cuda_key) = key {
2295+
with_thread_local_cuda_streams(|streams| {
2296+
cuda_key.key.key.get_lt_size_on_gpu(
2297+
&*self.ciphertext.on_gpu(streams),
2298+
&rhs.ciphertext.on_gpu(streams),
2299+
streams,
2300+
)
2301+
})
2302+
} else {
2303+
0
2304+
}
2305+
})
2306+
}
2307+
fn get_le_size_on_gpu(&self, rhs: &Self) -> u64 {
2308+
global_state::with_internal_keys(|key| {
2309+
if let InternalServerKey::Cuda(cuda_key) = key {
2310+
with_thread_local_cuda_streams(|streams| {
2311+
cuda_key.key.key.get_le_size_on_gpu(
2312+
&*self.ciphertext.on_gpu(streams),
2313+
&rhs.ciphertext.on_gpu(streams),
2314+
streams,
2315+
)
2316+
})
2317+
} else {
2318+
0
2319+
}
2320+
})
2321+
}
2322+
}
2323+
#[cfg(feature = "gpu")]
2324+
impl<Id> FheMinSizeOnGpu<&Self> for FheInt<Id>
2325+
where
2326+
Id: FheIntId,
2327+
{
2328+
fn get_min_size_on_gpu(&self, rhs: &Self) -> u64 {
2329+
global_state::with_internal_keys(|key| {
2330+
if let InternalServerKey::Cuda(cuda_key) = key {
2331+
with_thread_local_cuda_streams(|streams| {
2332+
cuda_key.key.key.get_min_size_on_gpu(
2333+
&*self.ciphertext.on_gpu(streams),
2334+
&rhs.ciphertext.on_gpu(streams),
2335+
streams,
2336+
)
2337+
})
2338+
} else {
2339+
0
2340+
}
2341+
})
2342+
}
2343+
}
2344+
2345+
#[cfg(feature = "gpu")]
2346+
impl<Id> FheMaxSizeOnGpu<&Self> for FheInt<Id>
2347+
where
2348+
Id: FheIntId,
2349+
{
2350+
fn get_max_size_on_gpu(&self, rhs: &Self) -> u64 {
2351+
global_state::with_internal_keys(|key| {
2352+
if let InternalServerKey::Cuda(cuda_key) = key {
2353+
with_thread_local_cuda_streams(|streams| {
2354+
cuda_key.key.key.get_max_size_on_gpu(
2355+
&*self.ciphertext.on_gpu(streams),
2356+
&rhs.ciphertext.on_gpu(streams),
2357+
streams,
2358+
)
2359+
})
2360+
} else {
2361+
0
2362+
}
2363+
})
2364+
}
2365+
}

tfhe/src/high_level_api/integers/signed/scalar_ops.rs

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ use crate::high_level_api::integers::FheIntId;
99
use crate::high_level_api::keys::InternalServerKey;
1010
#[cfg(feature = "gpu")]
1111
use crate::high_level_api::traits::{
12-
AddSizeOnGpu, BitAndSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, SubSizeOnGpu,
12+
AddSizeOnGpu, BitAndSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, FheMaxSizeOnGpu,
13+
FheMinSizeOnGpu, FheOrdSizeOnGpu, SubSizeOnGpu,
1314
};
1415
use crate::high_level_api::traits::{
1516
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
@@ -406,6 +407,112 @@ where
406407
}
407408
}
408409

410+
#[cfg(feature = "gpu")]
411+
impl<Id, Clear> FheOrdSizeOnGpu<Clear> for FheInt<Id>
412+
where
413+
Id: FheIntId,
414+
Clear: DecomposableInto<u64>,
415+
{
416+
fn get_gt_size_on_gpu(&self, _rhs: Clear) -> u64 {
417+
global_state::with_internal_keys(|key| {
418+
if let InternalServerKey::Cuda(cuda_key) = key {
419+
with_thread_local_cuda_streams(|streams| {
420+
cuda_key
421+
.key
422+
.key
423+
.get_scalar_le_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
424+
})
425+
} else {
426+
0
427+
}
428+
})
429+
}
430+
fn get_ge_size_on_gpu(&self, _rhs: Clear) -> u64 {
431+
global_state::with_internal_keys(|key| {
432+
if let InternalServerKey::Cuda(cuda_key) = key {
433+
with_thread_local_cuda_streams(|streams| {
434+
cuda_key
435+
.key
436+
.key
437+
.get_scalar_le_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
438+
})
439+
} else {
440+
0
441+
}
442+
})
443+
}
444+
fn get_lt_size_on_gpu(&self, _rhs: Clear) -> u64 {
445+
global_state::with_internal_keys(|key| {
446+
if let InternalServerKey::Cuda(cuda_key) = key {
447+
with_thread_local_cuda_streams(|streams| {
448+
cuda_key
449+
.key
450+
.key
451+
.get_scalar_le_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
452+
})
453+
} else {
454+
0
455+
}
456+
})
457+
}
458+
fn get_le_size_on_gpu(&self, _rhs: Clear) -> u64 {
459+
global_state::with_internal_keys(|key| {
460+
if let InternalServerKey::Cuda(cuda_key) = key {
461+
with_thread_local_cuda_streams(|streams| {
462+
cuda_key
463+
.key
464+
.key
465+
.get_scalar_le_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
466+
})
467+
} else {
468+
0
469+
}
470+
})
471+
}
472+
}
473+
474+
#[cfg(feature = "gpu")]
475+
impl<Id, Clear> FheMinSizeOnGpu<Clear> for FheInt<Id>
476+
where
477+
Id: FheIntId,
478+
Clear: DecomposableInto<u64>,
479+
{
480+
fn get_min_size_on_gpu(&self, _rhs: Clear) -> u64 {
481+
global_state::with_internal_keys(|key| {
482+
if let InternalServerKey::Cuda(cuda_key) = key {
483+
with_thread_local_cuda_streams(|streams| {
484+
cuda_key
485+
.key
486+
.key
487+
.get_scalar_min_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
488+
})
489+
} else {
490+
0
491+
}
492+
})
493+
}
494+
}
495+
#[cfg(feature = "gpu")]
496+
impl<Id, Clear> FheMaxSizeOnGpu<Clear> for FheInt<Id>
497+
where
498+
Id: FheIntId,
499+
Clear: DecomposableInto<u64>,
500+
{
501+
fn get_max_size_on_gpu(&self, _rhs: Clear) -> u64 {
502+
global_state::with_internal_keys(|key| {
503+
if let InternalServerKey::Cuda(cuda_key) = key {
504+
with_thread_local_cuda_streams(|streams| {
505+
cuda_key
506+
.key
507+
.key
508+
.get_scalar_max_size_on_gpu(&*self.ciphertext.on_gpu(streams), streams)
509+
})
510+
} else {
511+
0
512+
}
513+
})
514+
}
515+
}
409516
// DivRem is a bit special as it returns a tuple of quotient and remainder
410517
macro_rules! generic_integer_impl_scalar_div_rem {
411518
(

tfhe/src/high_level_api/integers/signed/tests/gpu.rs

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ use crate::high_level_api::integers::signed::tests::{
22
test_case_ilog2, test_case_leading_trailing_zeros_ones,
33
};
44
use crate::high_level_api::integers::unsigned::tests::gpu::setup_gpu;
5-
use crate::high_level_api::traits::AddSizeOnGpu;
65
use crate::prelude::{
7-
check_valid_cuda_malloc, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
8-
FheTryEncrypt, SubSizeOnGpu,
6+
check_valid_cuda_malloc, AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu,
7+
BitXorSizeOnGpu, FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, FheTryEncrypt,
8+
SubSizeOnGpu,
99
};
1010
use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS;
1111
use crate::{FheInt32, GpuIndex};
@@ -162,3 +162,77 @@ fn test_gpu_get_bitops_size_on_gpu() {
162162
GpuIndex::new(0)
163163
));
164164
}
165+
#[test]
166+
fn test_gpu_get_comparisons_size_on_gpu() {
167+
let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
168+
let mut rng = rand::thread_rng();
169+
let clear_a = rng.gen_range(1..=i32::MAX);
170+
let clear_b = rng.gen_range(1..=i32::MAX);
171+
let mut a = FheInt32::try_encrypt(clear_a, &cks).unwrap();
172+
let mut b = FheInt32::try_encrypt(clear_b, &cks).unwrap();
173+
a.move_to_current_device();
174+
b.move_to_current_device();
175+
let a = &a;
176+
let b = &b;
177+
178+
let gt_tmp_buffer_size = a.get_gt_size_on_gpu(b);
179+
let scalar_gt_tmp_buffer_size = a.get_gt_size_on_gpu(clear_b);
180+
assert!(check_valid_cuda_malloc(
181+
gt_tmp_buffer_size,
182+
GpuIndex::new(0)
183+
));
184+
assert!(check_valid_cuda_malloc(
185+
scalar_gt_tmp_buffer_size,
186+
GpuIndex::new(0)
187+
));
188+
let ge_tmp_buffer_size = a.get_ge_size_on_gpu(b);
189+
let scalar_ge_tmp_buffer_size = a.get_ge_size_on_gpu(clear_b);
190+
assert!(check_valid_cuda_malloc(
191+
ge_tmp_buffer_size,
192+
GpuIndex::new(0)
193+
));
194+
assert!(check_valid_cuda_malloc(
195+
scalar_ge_tmp_buffer_size,
196+
GpuIndex::new(0)
197+
));
198+
let lt_tmp_buffer_size = a.get_lt_size_on_gpu(b);
199+
let scalar_lt_tmp_buffer_size = a.get_lt_size_on_gpu(clear_b);
200+
assert!(check_valid_cuda_malloc(
201+
lt_tmp_buffer_size,
202+
GpuIndex::new(0)
203+
));
204+
assert!(check_valid_cuda_malloc(
205+
scalar_lt_tmp_buffer_size,
206+
GpuIndex::new(0)
207+
));
208+
let le_tmp_buffer_size = a.get_le_size_on_gpu(b);
209+
let scalar_le_tmp_buffer_size = a.get_le_size_on_gpu(clear_b);
210+
assert!(check_valid_cuda_malloc(
211+
le_tmp_buffer_size,
212+
GpuIndex::new(0)
213+
));
214+
assert!(check_valid_cuda_malloc(
215+
scalar_le_tmp_buffer_size,
216+
GpuIndex::new(0)
217+
));
218+
let max_tmp_buffer_size = a.get_max_size_on_gpu(b);
219+
let scalar_max_tmp_buffer_size = a.get_max_size_on_gpu(clear_b);
220+
assert!(check_valid_cuda_malloc(
221+
max_tmp_buffer_size,
222+
GpuIndex::new(0)
223+
));
224+
assert!(check_valid_cuda_malloc(
225+
scalar_max_tmp_buffer_size,
226+
GpuIndex::new(0)
227+
));
228+
let min_tmp_buffer_size = a.get_min_size_on_gpu(b);
229+
let scalar_min_tmp_buffer_size = a.get_min_size_on_gpu(clear_b);
230+
assert!(check_valid_cuda_malloc(
231+
min_tmp_buffer_size,
232+
GpuIndex::new(0)
233+
));
234+
assert!(check_valid_cuda_malloc(
235+
scalar_min_tmp_buffer_size,
236+
GpuIndex::new(0)
237+
));
238+
}

0 commit comments

Comments
 (0)