refactor of key-switching & external product

This commit is contained in:
Jean-Philippe Bossuat
2025-05-15 18:24:56 +02:00
parent 723a41acd0
commit ccd7450c5f
15 changed files with 1593 additions and 1740 deletions

View File

@@ -1,8 +1,7 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch,
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
ZnxZero,
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef,
VecZnxOps, ZnxInfos, ZnxZero,
};
use sampling::source::Source;
@@ -13,7 +12,6 @@ use crate::{
glwe_plaintext::GLWEPlaintext,
keys::SecretKeyFourier,
utils::derive_size,
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
};
pub struct GGLWECiphertext<C, B: Backend> {
@@ -212,81 +210,3 @@ where
module.vmp_prepare_row(self, row_i, col_j, a);
}
}
impl VecGLWEProductScratchSpace for GGLWECiphertext<Vec<u8>, FFT64> {
fn prod_with_glwe_scratch_space(
module: &Module<FFT64>,
res_size: usize,
a_size: usize,
grlwe_size: usize,
rank_in: usize,
rank_out: usize,
) -> usize {
module.bytes_of_vec_znx_dft(rank_out + 1, grlwe_size)
+ (module.vec_znx_big_normalize_tmp_bytes()
| (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, rank_in, rank_out + 1, grlwe_size)
+ module.bytes_of_vec_znx_dft(rank_in, a_size)))
}
}
impl<C> VecGLWEProduct for GGLWECiphertext<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
{
fn prod_with_glwe<R, A>(
&self,
module: &Module<FFT64>,
res: &mut GLWECiphertext<R>,
a: &GLWECiphertext<A>,
scratch: &mut Scratch,
) where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
VecZnx<R>: VecZnxToMut,
VecZnx<A>: VecZnxToRef,
{
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(a.rank(), self.rank_in());
assert_eq!(res.rank(), self.rank_out());
assert_eq!(res.basek(), basek);
assert_eq!(a.basek(), basek);
assert_eq!(self.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), module.n());
assert!(
scratch.available()
>= GGLWECiphertext::prod_with_glwe_scratch_space(
module,
res.size(),
a.size(),
self.size(),
self.rank_in(),
self.rank_out()
)
);
}
let cols_in: usize = self.rank_in();
let cols_out: usize = self.rank_out() + 1;
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, self.size()); // Todo optimise
{
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, a.size());
(0..cols_in).for_each(|col_i| {
module.vec_znx_dft(&mut ai_dft, col_i, a, col_i + 1);
});
module.vmp_apply(&mut res_dft, &ai_dft, self, scratch2);
}
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
(0..cols_out).for_each(|i| {
module.vec_znx_big_normalize(basek, res, i, &res_big, i, scratch1);
});
}
}