@andrei-stoian-zama
Contributor

@andrei-stoian-zama andrei-stoian-zama commented Oct 16, 2025

  • Uses GEMM-based KS in integer ops when large batches of LWEs need to be keyswitched
  • Implements GEMM KS with non-trivial indexes
  • Improves KS GPU bench and KS GPU test


@cla-bot cla-bot bot added the cla-signed label Oct 16, 2025
@andrei-stoian-zama andrei-stoian-zama force-pushed the as/gemm_ks branch 4 times, most recently from 6dc060b to 60145aa Compare November 10, 2025 10:32
@andrei-stoian-zama andrei-stoian-zama changed the title feat(gpu): use gemm ks for trivial indexes feat(gpu): use gemm ks in HL ops Nov 13, 2025
@andrei-stoian-zama andrei-stoian-zama force-pushed the as/gemm_ks branch 3 times, most recently from 7ff1c96 to 3fa7b3f Compare November 17, 2025 20:05
Member

@IceTDrinker IceTDrinker left a comment

Nitpick comments on the core non-CUDA part (I did not read the CUDA part), thanks!

@IceTDrinker reviewed 7 of 22 files at r1, all commit messages.
Reviewable status: 7 of 22 files reviewed, 13 unresolved discussions (waiting on @agnesLeroy and @soonum)


tfhe/src/core_crypto/gpu/algorithms/lwe_keyswitch.rs line 79 at r1 (raw file):

    let mut ks_tmp_buffer: *mut ffi::c_void = std::ptr::null_mut();

    let num_lwes_to_ks = min(

Is it possible to only partially keyswitch an input?


tfhe/src/core_crypto/gpu/algorithms/lwe_keyswitch.rs line 84 at r1 (raw file):

    );

    assert_eq!(input_indexes.len, output_indexes.len);

An error message would be welcome here.
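
A minimal sketch of what such a message could look like. The helper name `check_index_lengths` and the message wording are hypothetical and only illustrate the optional format arguments of `assert_eq!`; they are not taken from the PR.

```rust
/// Hypothetical helper (not part of the PR): checks that the input and
/// output index lists have the same length, with a descriptive message.
fn check_index_lengths(input_indexes_len: usize, output_indexes_len: usize) {
    // The arguments after the two compared values are a format string and
    // its parameters; they are only evaluated if the assertion fails.
    assert_eq!(
        input_indexes_len, output_indexes_len,
        "Mismatch between input indexes length ({}) \
         and output indexes length ({})",
        input_indexes_len, output_indexes_len
    );
}
```

With this shape, a failure reports both lengths in the panic message, so the offending call site is easier to diagnose from a test log.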


tfhe-benchmark/benches/core_crypto/ks_bench.rs line 458 at r1 (raw file):

                            let input_ks_list = LweCiphertextList::from_container(
                                input_ct_list.into_container(),
                                big_lwe_sk.lwe_dimension().to_lwe_size(),

The same nits as in the tests apply here: key dimensions, ciphertext counts, and places where things get collected or transformed into vecs.


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 118 at r1 (raw file):

        msg = msg.wrapping_sub(Scalar::ONE);
        for test_idx in 0..NB_TESTS {
            let num_blocks = test_idx * test_idx * 3 + 1;

Are those magic numbers? Or could they be randomly chosen?


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 139 at r1 (raw file):

                &mut rsc.encryption_random_generator,
            );
            let input_ks_list = LweCiphertextList::from_container(

Why is this required? It looks like it is just recreating the input_ct_list.


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 149 at r1 (raw file):

            let output_ct_list = LweCiphertextList::new(
                Scalar::ZERO,
                lwe_sk.lwe_dimension().to_lwe_size(),

nit: prefer using the dimension of the compute key that will be used (here the ksk); it tends to help with local reasoning


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 150 at r1 (raw file):

                Scalar::ZERO,
                lwe_sk.lwe_dimension().to_lwe_size(),
                LweCiphertextCount(num_blocks),

nit: use input.lwe_ciphertext_count() here, again for local reasoning


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 151 at r1 (raw file):

                lwe_sk.lwe_dimension().to_lwe_size(),
                LweCiphertextCount(num_blocks),
                ciphertext_modulus,

nit: again, use the output modulus of the compute key (ksk)
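
A rough sketch of the idea behind the three nits above: derive every parameter of the output list from the keyswitch key and the input list rather than from unrelated local variables. `MockKsk`, `MockCtList`, and `new_output_list` are hypothetical stand-ins written for this illustration; they are not the tfhe-rs types or API.

```rust
/// Hypothetical stand-in for a keyswitch key; NOT the tfhe-rs type.
struct MockKsk {
    output_lwe_size: usize,
    modulus_bits: u32,
}

/// Hypothetical stand-in for an LWE ciphertext list; NOT the tfhe-rs type.
struct MockCtList {
    lwe_size: usize,
    ciphertext_count: usize,
    modulus_bits: u32,
}

/// Builds the output list description from the objects that the keyswitch
/// will actually consume, so the relationship is visible at the call site.
fn new_output_list(ksk: &MockKsk, input: &MockCtList) -> MockCtList {
    MockCtList {
        // Output size comes from the key that produces the output.
        lwe_size: ksk.output_lwe_size,
        // One output ciphertext per input ciphertext.
        ciphertext_count: input.ciphertext_count,
        // Output modulus is the one carried by the key.
        modulus_bits: ksk.modulus_bits,
    }
}
```

Written this way, a reader does not need to look up how `lwe_sk` or `num_blocks` relate to the key and the input to check that the sizes agree.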


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 173 at r1 (raw file):

            };
            let lwe_indexes_usize = (0..num_blocks).collect_vec();
            let mut lwe_indexes = lwe_indexes_usize.iter().collect_vec();

let mut lwe_indexes = lwe_indexes_usize.clone();

?
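
A small sketch of the difference, using `collect()` from the standard library instead of itertools' `collect_vec()`. Collecting `.iter()` yields a `Vec<&usize>` that borrows from the first vector, while `.clone()` yields an owned `Vec<usize>`. The function name and the `reverse()` call (a stand-in for whatever reordering the test applies) are assumptions for illustration.

```rust
/// Hypothetical helper: returns an owned, reordered copy of the identity
/// index list 0..num_blocks.
fn owned_indexes(num_blocks: usize) -> Vec<usize> {
    let lwe_indexes_usize: Vec<usize> = (0..num_blocks).collect();

    // Vec<&usize>: every element is a reference into lwe_indexes_usize,
    // so each use site needs a dereference.
    let borrowed: Vec<&usize> = lwe_indexes_usize.iter().collect();
    if let Some(first) = borrowed.first() {
        assert_eq!(**first, 0);
    }

    // Vec<usize>: an independent copy that can be reordered freely.
    let mut owned = lwe_indexes_usize.clone();
    owned.reverse(); // stand-in for the reordering done in the test
    owned
}
```

With the owned version, expressions such as `*lwe_indexes[i]` at the use sites become plain `lwe_indexes[i]`.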


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 182 at r1 (raw file):

            }

            if num_blocks_to_ks < num_blocks {

No need to do the check, you can always do

lwe_indexes = lwe_indexes[..num_blocks_to_ks];

I believe, since the number to keyswitch should always be <= num_blocks?
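
A minimal sketch of why the guard is unnecessary, assuming the count never exceeds the length: slicing with `..n` where `n` equals the length returns the whole slice, so the unconditional form covers both cases. The function name is hypothetical.

```rust
/// Hypothetical helper: keeps the first `num_blocks_to_ks` indexes.
/// Panics if `num_blocks_to_ks > lwe_indexes.len()`; when the two are
/// equal, the full slice is returned unchanged.
fn truncate_indexes(lwe_indexes: &[usize], num_blocks_to_ks: usize) -> &[usize] {
    &lwe_indexes[..num_blocks_to_ks]
}
```

Note that this borrows rather than copies, so the result is a slice and not a new `Vec`.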


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 184 at r1 (raw file):

            if num_blocks_to_ks < num_blocks {
                lwe_indexes = lwe_indexes[0..num_blocks_to_ks].to_vec();
                lwe_indexes_out = lwe_indexes_out[0..num_blocks_to_ks].to_vec();

I don't think you need the to_vec here.

You can take a slice, like

lwe_indexes = &lwe_indexes[..num_blocks_to_ks];

The whole thing above can be folded into the iterator below; I will give an example.


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 187 at r1 (raw file):

            }

            let h_lwe_indexes: Vec<Scalar> = lwe_indexes

lwe_indexes.iter().take(num_blocks_to_ks).map(...)
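
A sketch of the iterator form, with the elided `map(...)` body filled in by an assumed cast. `u64` stands in for the generic `Scalar`, and the function name is hypothetical; the actual conversion used in the test is not shown in this thread.

```rust
/// Hypothetical helper: converts the first `num_blocks_to_ks` indexes to
/// the scalar type expected on the host side (u64 is an assumption).
fn host_indexes(lwe_indexes: &[usize], num_blocks_to_ks: usize) -> Vec<u64> {
    lwe_indexes
        .iter()
        // take() stops after num_blocks_to_ks items, or earlier if the
        // slice is shorter; unlike slicing, it never panics.
        .take(num_blocks_to_ks)
        // Assumed conversion; the real test maps into Scalar.
        .map(|&idx| idx as u64)
        .collect()
}
```

This removes the need for the intermediate truncated vectors, since the truncation and the conversion happen in a single pass.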


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 216 at r1 (raw file):

            for i in 0..num_blocks_to_ks {
                ref_vec[*lwe_indexes_out[i]] =
                    round_decode(*plaintext_list.get(*lwe_indexes[i]).0, delta); // % msg_modulus;

The comment can be removed, I think?

@zama-bot zama-bot removed the approved label Nov 18, 2025
3 participants