Skip to content

Commit 2bc2c86

Browse files
committed
feat(gpu): add memory tracking functions for shift/rotate
1 parent 66e3c02 commit 2bc2c86

File tree

13 files changed

+1896
-8
lines changed

13 files changed

+1896
-8
lines changed

tfhe/src/high_level_api/integers/signed/ops.rs

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ use crate::high_level_api::keys::InternalServerKey;
66
#[cfg(feature = "gpu")]
77
use crate::high_level_api::traits::{
88
AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu,
9-
FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, SizeOnGpu, SubSizeOnGpu,
9+
FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu,
10+
ShlSizeOnGpu, ShrSizeOnGpu, SizeOnGpu, SubSizeOnGpu,
1011
};
1112
use crate::high_level_api::traits::{
1213
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
@@ -2363,3 +2364,92 @@ where
23632364
})
23642365
}
23652366
}
2367+
2368+
/// GPU size query for a left shift of an `FheInt` by an encrypted `FheUint` amount.
///
/// `get_left_shift_size_on_gpu` forwards both ciphertexts and the thread-local CUDA
/// streams to the CUDA server key's `get_left_shift_size_on_gpu` and returns its `u64`
/// result. When the key supplied by `global_state::with_internal_keys` is not
/// `InternalServerKey::Cuda`, the method returns `0`.
///
/// NOTE(review): the returned value is presumably a byte count of GPU memory needed by
/// the operation — confirm against the `ShlSizeOnGpu` trait documentation.
#[cfg(feature = "gpu")]
2369+
impl<Id, Id2> ShlSizeOnGpu<&FheUint<Id2>> for FheInt<Id>
2370+
where
2371+
Id: FheIntId,
2372+
Id2: FheUintId,
2373+
{
2374+
fn get_left_shift_size_on_gpu(&self, rhs: &FheUint<Id2>) -> u64 {
2375+
global_state::with_internal_keys(|key| {
2376+
if let InternalServerKey::Cuda(cuda_key) = key {
2377+
with_thread_local_cuda_streams(|streams| {
2378+
// Both operands are obtained through `on_gpu(streams)` before being handed to the key.
cuda_key.key.key.get_left_shift_size_on_gpu(
2379+
&*self.ciphertext.on_gpu(streams),
2380+
&rhs.ciphertext.on_gpu(streams),
2381+
streams,
2382+
)
2383+
})
2384+
} else {
2385+
// Any non-CUDA server key: report a size of 0.
0
2386+
}
2387+
})
2388+
}
2389+
}
2390+
/// GPU size query for a right shift of an `FheInt` by an encrypted `FheUint` amount.
///
/// `get_right_shift_size_on_gpu` forwards both ciphertexts and the thread-local CUDA
/// streams to the CUDA server key's `get_right_shift_size_on_gpu` and returns its `u64`
/// result. When the key supplied by `global_state::with_internal_keys` is not
/// `InternalServerKey::Cuda`, the method returns `0`.
///
/// NOTE(review): the returned value is presumably a byte count of GPU memory needed by
/// the operation — confirm against the `ShrSizeOnGpu` trait documentation.
#[cfg(feature = "gpu")]
2391+
impl<Id, Id2> ShrSizeOnGpu<&FheUint<Id2>> for FheInt<Id>
2392+
where
2393+
Id: FheIntId,
2394+
Id2: FheUintId,
2395+
{
2396+
fn get_right_shift_size_on_gpu(&self, rhs: &FheUint<Id2>) -> u64 {
2397+
global_state::with_internal_keys(|key| {
2398+
if let InternalServerKey::Cuda(cuda_key) = key {
2399+
with_thread_local_cuda_streams(|streams| {
2400+
// Both operands are obtained through `on_gpu(streams)` before being handed to the key.
cuda_key.key.key.get_right_shift_size_on_gpu(
2401+
&*self.ciphertext.on_gpu(streams),
2402+
&rhs.ciphertext.on_gpu(streams),
2403+
streams,
2404+
)
2405+
})
2406+
} else {
2407+
// Any non-CUDA server key: report a size of 0.
0
2408+
}
2409+
})
2410+
}
2411+
}
2412+
/// GPU size query for a left rotation of an `FheInt` by an encrypted `FheUint` amount.
///
/// `get_rotate_left_size_on_gpu` forwards both ciphertexts and the thread-local CUDA
/// streams to the CUDA server key's `get_rotate_left_size_on_gpu` and returns its `u64`
/// result. When the key supplied by `global_state::with_internal_keys` is not
/// `InternalServerKey::Cuda`, the method returns `0`.
///
/// NOTE(review): the returned value is presumably a byte count of GPU memory needed by
/// the operation — confirm against the `RotateLeftSizeOnGpu` trait documentation.
#[cfg(feature = "gpu")]
2413+
impl<Id, Id2> RotateLeftSizeOnGpu<&FheUint<Id2>> for FheInt<Id>
2414+
where
2415+
Id: FheIntId,
2416+
Id2: FheUintId,
2417+
{
2418+
fn get_rotate_left_size_on_gpu(&self, rhs: &FheUint<Id2>) -> u64 {
2419+
global_state::with_internal_keys(|key| {
2420+
if let InternalServerKey::Cuda(cuda_key) = key {
2421+
with_thread_local_cuda_streams(|streams| {
2422+
// Both operands are obtained through `on_gpu(streams)` before being handed to the key.
cuda_key.key.key.get_rotate_left_size_on_gpu(
2423+
&*self.ciphertext.on_gpu(streams),
2424+
&rhs.ciphertext.on_gpu(streams),
2425+
streams,
2426+
)
2427+
})
2428+
} else {
2429+
// Any non-CUDA server key: report a size of 0.
0
2430+
}
2431+
})
2432+
}
2433+
}
2434+
/// GPU size query for a right rotation of an `FheInt` by an encrypted `FheUint` amount.
///
/// `get_rotate_right_size_on_gpu` forwards both ciphertexts and the thread-local CUDA
/// streams to the CUDA server key's `get_rotate_right_size_on_gpu` and returns its `u64`
/// result. When the key supplied by `global_state::with_internal_keys` is not
/// `InternalServerKey::Cuda`, the method returns `0`.
///
/// NOTE(review): the returned value is presumably a byte count of GPU memory needed by
/// the operation — confirm against the `RotateRightSizeOnGpu` trait documentation.
#[cfg(feature = "gpu")]
2435+
impl<Id, Id2> RotateRightSizeOnGpu<&FheUint<Id2>> for FheInt<Id>
2436+
where
2437+
Id: FheIntId,
2438+
Id2: FheUintId,
2439+
{
2440+
fn get_rotate_right_size_on_gpu(&self, rhs: &FheUint<Id2>) -> u64 {
2441+
global_state::with_internal_keys(|key| {
2442+
if let InternalServerKey::Cuda(cuda_key) = key {
2443+
with_thread_local_cuda_streams(|streams| {
2444+
// Both operands are obtained through `on_gpu(streams)` before being handed to the key.
cuda_key.key.key.get_rotate_right_size_on_gpu(
2445+
&*self.ciphertext.on_gpu(streams),
2446+
&rhs.ciphertext.on_gpu(streams),
2447+
streams,
2448+
)
2449+
})
2450+
} else {
2451+
// Any non-CUDA server key: report a size of 0.
0
2452+
}
2453+
})
2454+
}
2455+
}

tfhe/src/high_level_api/integers/signed/scalar_ops.rs

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ use crate::high_level_api::keys::InternalServerKey;
1010
#[cfg(feature = "gpu")]
1111
use crate::high_level_api::traits::{
1212
AddSizeOnGpu, BitAndSizeOnGpu, BitOrSizeOnGpu, BitXorSizeOnGpu, FheMaxSizeOnGpu,
13-
FheMinSizeOnGpu, FheOrdSizeOnGpu, SubSizeOnGpu,
13+
FheMinSizeOnGpu, FheOrdSizeOnGpu, RotateLeftSizeOnGpu, RotateRightSizeOnGpu, ShlSizeOnGpu,
14+
ShrSizeOnGpu, SubSizeOnGpu,
1415
};
1516
use crate::high_level_api::traits::{
1617
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
@@ -628,6 +629,30 @@ macro_rules! define_scalar_rotate_shifts {
628629
)*
629630
);
630631

632+
// GPU-only: invokes `generic_integer_impl_get_scalar_operation_size_on_gpu!` for
// `ShlSizeOnGpu` / `get_left_shift_size_on_gpu`, passing along every
// `($concrete_type, $scalar_type...)` group received by the enclosing macro.
// NOTE(review): the helper macro is defined outside this view; presumably it emits one
// trait impl per (FHE type, scalar type) pair using the closure below — confirm.
#[cfg(feature = "gpu")]
633+
generic_integer_impl_get_scalar_operation_size_on_gpu!(
634+
rust_trait: ShlSizeOnGpu(get_left_shift_size_on_gpu),
635+
implem: {
636+
// The scalar operand is ignored (`_rhs`): only the ciphertext and the streams reach the key.
|lhs: &FheInt<_>, _rhs| {
637+
global_state::with_internal_keys(|key|
638+
if let InternalServerKey::Cuda(cuda_key) = key {
639+
with_thread_local_cuda_streams(|streams| {
640+
cuda_key.key.key.get_scalar_left_shift_size_on_gpu(
641+
&*lhs.ciphertext.on_gpu(streams),
642+
streams,
643+
)
644+
})
645+
} else {
646+
// Any non-CUDA server key: report a size of 0.
0
647+
})
648+
}
649+
},
650+
fhe_and_scalar_type:
651+
$(
652+
($concrete_type, $($scalar_type,)*),
653+
)*
654+
);
655+
631656
generic_integer_impl_scalar_operation!(
632657
rust_trait: Shr(shr),
633658
implem: {
@@ -661,6 +686,30 @@ macro_rules! define_scalar_rotate_shifts {
661686
)*
662687
);
663688

689+
// GPU-only: invokes `generic_integer_impl_get_scalar_operation_size_on_gpu!` for
// `ShrSizeOnGpu` / `get_right_shift_size_on_gpu`, passing along every
// `($concrete_type, $scalar_type...)` group received by the enclosing macro.
// NOTE(review): the helper macro is defined outside this view; presumably it emits one
// trait impl per (FHE type, scalar type) pair using the closure below — confirm.
#[cfg(feature = "gpu")]
690+
generic_integer_impl_get_scalar_operation_size_on_gpu!(
691+
rust_trait: ShrSizeOnGpu(get_right_shift_size_on_gpu),
692+
implem: {
693+
// The scalar operand is ignored (`_rhs`): only the ciphertext and the streams reach the key.
|lhs: &FheInt<_>, _rhs| {
694+
global_state::with_internal_keys(|key|
695+
if let InternalServerKey::Cuda(cuda_key) = key {
696+
with_thread_local_cuda_streams(|streams| {
697+
cuda_key.key.key.get_scalar_right_shift_size_on_gpu(
698+
&*lhs.ciphertext.on_gpu(streams),
699+
streams,
700+
)
701+
})
702+
} else {
703+
// Any non-CUDA server key: report a size of 0.
0
704+
})
705+
}
706+
},
707+
fhe_and_scalar_type:
708+
$(
709+
($concrete_type, $($scalar_type,)*),
710+
)*
711+
);
712+
664713
generic_integer_impl_scalar_operation!(
665714
rust_trait: RotateLeft(rotate_left),
666715
implem: {
@@ -694,6 +743,30 @@ macro_rules! define_scalar_rotate_shifts {
694743
)*
695744
);
696745

746+
// GPU-only: invokes `generic_integer_impl_get_scalar_operation_size_on_gpu!` for
// `RotateLeftSizeOnGpu` / `get_rotate_left_size_on_gpu`, passing along every
// `($concrete_type, $scalar_type...)` group received by the enclosing macro.
// NOTE(review): the helper macro is defined outside this view; presumably it emits one
// trait impl per (FHE type, scalar type) pair using the closure below — confirm.
#[cfg(feature = "gpu")]
747+
generic_integer_impl_get_scalar_operation_size_on_gpu!(
748+
rust_trait: RotateLeftSizeOnGpu(get_rotate_left_size_on_gpu),
749+
implem: {
750+
// The scalar operand is ignored (`_rhs`): only the ciphertext and the streams reach the key.
|lhs: &FheInt<_>, _rhs| {
751+
global_state::with_internal_keys(|key|
752+
if let InternalServerKey::Cuda(cuda_key) = key {
753+
with_thread_local_cuda_streams(|streams| {
754+
cuda_key.key.key.get_scalar_rotate_left_size_on_gpu(
755+
&*lhs.ciphertext.on_gpu(streams),
756+
streams,
757+
)
758+
})
759+
} else {
760+
// Any non-CUDA server key: report a size of 0.
0
761+
})
762+
}
763+
},
764+
fhe_and_scalar_type:
765+
$(
766+
($concrete_type, $($scalar_type,)*),
767+
)*
768+
);
769+
697770
generic_integer_impl_scalar_operation!(
698771
rust_trait: RotateRight(rotate_right),
699772
implem: {
@@ -727,6 +800,30 @@ macro_rules! define_scalar_rotate_shifts {
727800
)*
728801
);
729802

803+
// GPU-only: invokes `generic_integer_impl_get_scalar_operation_size_on_gpu!` for
// `RotateRightSizeOnGpu` / `get_rotate_right_size_on_gpu`, passing along every
// `($concrete_type, $scalar_type...)` group received by the enclosing macro.
// NOTE(review): the helper macro is defined outside this view; presumably it emits one
// trait impl per (FHE type, scalar type) pair using the closure below — confirm.
#[cfg(feature = "gpu")]
804+
generic_integer_impl_get_scalar_operation_size_on_gpu!(
805+
rust_trait: RotateRightSizeOnGpu(get_rotate_right_size_on_gpu),
806+
implem: {
807+
// The scalar operand is ignored (`_rhs`): only the ciphertext and the streams reach the key.
|lhs: &FheInt<_>, _rhs| {
808+
global_state::with_internal_keys(|key|
809+
if let InternalServerKey::Cuda(cuda_key) = key {
810+
with_thread_local_cuda_streams(|streams| {
811+
cuda_key.key.key.get_scalar_rotate_right_size_on_gpu(
812+
&*lhs.ciphertext.on_gpu(streams),
813+
streams,
814+
)
815+
})
816+
} else {
817+
// Any non-CUDA server key: report a size of 0.
0
818+
})
819+
}
820+
},
821+
fhe_and_scalar_type:
822+
$(
823+
($concrete_type, $($scalar_type,)*),
824+
)*
825+
);
826+
730827
generic_integer_impl_scalar_operation_assign!(
731828
rust_trait: ShlAssign(shl_assign),
732829
implem: {

tfhe/src/high_level_api/integers/signed/tests/gpu.rs

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ use crate::high_level_api::integers::unsigned::tests::gpu::setup_gpu;
55
use crate::prelude::{
66
check_valid_cuda_malloc, AddSizeOnGpu, BitAndSizeOnGpu, BitNotSizeOnGpu, BitOrSizeOnGpu,
77
BitXorSizeOnGpu, FheMaxSizeOnGpu, FheMinSizeOnGpu, FheOrdSizeOnGpu, FheTryEncrypt,
8-
SubSizeOnGpu,
8+
RotateLeftSizeOnGpu, RotateRightSizeOnGpu, ShlSizeOnGpu, ShrSizeOnGpu, SubSizeOnGpu,
99
};
1010
use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS;
11-
use crate::{FheInt32, GpuIndex};
11+
use crate::{FheInt32, FheUint32, GpuIndex};
1212
use rand::Rng;
1313

1414
#[test]
@@ -236,3 +236,58 @@ fn test_gpu_get_comparisons_size_on_gpu() {
236236
GpuIndex::new(0)
237237
));
238238
}
239+
240+
/// Exercises the shift/rotate `*SizeOnGpu` queries on a signed 32-bit ciphertext.
///
/// For each of left shift, right shift, rotate left and rotate right, the size is
/// queried twice — once with an encrypted `FheUint32` right-hand side and once with the
/// clear `u32` — and each result must make `check_valid_cuda_malloc` return `true` for
/// `GpuIndex::new(0)`.
///
/// NOTE(review): `check_valid_cuda_malloc` presumably verifies that an allocation of the
/// given size is feasible on that GPU — confirm against its definition.
#[test]
241+
fn test_gpu_get_shift_rotate_size_on_gpu() {
242+
let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
243+
let mut rng = rand::thread_rng();
244+
// Both clear values are drawn from ranges starting at 1, so they are strictly positive.
let clear_a = rng.gen_range(1..=i32::MAX);
245+
let clear_b = rng.gen_range(1..=u32::MAX);
246+
let mut a = FheInt32::try_encrypt(clear_a, &cks).unwrap();
247+
let mut b = FheUint32::try_encrypt(clear_b, &cks).unwrap();
248+
a.move_to_current_device();
249+
b.move_to_current_device();
250+
// Rebind as shared references: the size queries below are called on `&FheInt32`
// and take `&FheUint32` for the encrypted right-hand side.
let a = &a;
251+
let b = &b;
252+
253+
// Left shift: encrypted amount, then clear (scalar) amount.
let left_shift_tmp_buffer_size = a.get_left_shift_size_on_gpu(b);
254+
let scalar_left_shift_tmp_buffer_size = a.get_left_shift_size_on_gpu(clear_b);
255+
assert!(check_valid_cuda_malloc(
256+
left_shift_tmp_buffer_size,
257+
GpuIndex::new(0)
258+
));
259+
assert!(check_valid_cuda_malloc(
260+
scalar_left_shift_tmp_buffer_size,
261+
GpuIndex::new(0)
262+
));
263+
// Right shift: encrypted amount, then clear (scalar) amount.
let right_shift_tmp_buffer_size = a.get_right_shift_size_on_gpu(b);
264+
let scalar_right_shift_tmp_buffer_size = a.get_right_shift_size_on_gpu(clear_b);
265+
assert!(check_valid_cuda_malloc(
266+
right_shift_tmp_buffer_size,
267+
GpuIndex::new(0)
268+
));
269+
assert!(check_valid_cuda_malloc(
270+
scalar_right_shift_tmp_buffer_size,
271+
GpuIndex::new(0)
272+
));
273+
// Rotate left: encrypted amount, then clear (scalar) amount.
let rotate_left_tmp_buffer_size = a.get_rotate_left_size_on_gpu(b);
274+
let scalar_rotate_left_tmp_buffer_size = a.get_rotate_left_size_on_gpu(clear_b);
275+
assert!(check_valid_cuda_malloc(
276+
rotate_left_tmp_buffer_size,
277+
GpuIndex::new(0)
278+
));
279+
assert!(check_valid_cuda_malloc(
280+
scalar_rotate_left_tmp_buffer_size,
281+
GpuIndex::new(0)
282+
));
283+
// Rotate right: encrypted amount, then clear (scalar) amount.
let rotate_right_tmp_buffer_size = a.get_rotate_right_size_on_gpu(b);
284+
let scalar_rotate_right_tmp_buffer_size = a.get_rotate_right_size_on_gpu(clear_b);
285+
assert!(check_valid_cuda_malloc(
286+
rotate_right_tmp_buffer_size,
287+
GpuIndex::new(0)
288+
));
289+
assert!(check_valid_cuda_malloc(
290+
scalar_rotate_right_tmp_buffer_size,
291+
GpuIndex::new(0)
292+
));
293+
}

0 commit comments

Comments
 (0)