mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
refactor of key-switching & external product
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
|
||||
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch,
|
||||
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
|
||||
ZnxZero,
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
|
||||
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef,
|
||||
VecZnxOps, ZnxInfos, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
@@ -13,7 +12,6 @@ use crate::{
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::SecretKeyFourier,
|
||||
utils::derive_size,
|
||||
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
|
||||
};
|
||||
|
||||
pub struct GGLWECiphertext<C, B: Backend> {
|
||||
@@ -212,81 +210,3 @@ where
|
||||
module.vmp_prepare_row(self, row_i, col_j, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl VecGLWEProductScratchSpace for GGLWECiphertext<Vec<u8>, FFT64> {
|
||||
fn prod_with_glwe_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
grlwe_size: usize,
|
||||
rank_in: usize,
|
||||
rank_out: usize,
|
||||
) -> usize {
|
||||
module.bytes_of_vec_znx_dft(rank_out + 1, grlwe_size)
|
||||
+ (module.vec_znx_big_normalize_tmp_bytes()
|
||||
| (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, rank_in, rank_out + 1, grlwe_size)
|
||||
+ module.bytes_of_vec_znx_dft(rank_in, a_size)))
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> VecGLWEProduct for GGLWECiphertext<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
fn prod_with_glwe<R, A>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
res: &mut GLWECiphertext<R>,
|
||||
a: &GLWECiphertext<A>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
VecZnx<R>: VecZnxToMut,
|
||||
VecZnx<A>: VecZnxToRef,
|
||||
{
|
||||
let basek: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.rank(), self.rank_in());
|
||||
assert_eq!(res.rank(), self.rank_out());
|
||||
assert_eq!(res.basek(), basek);
|
||||
assert_eq!(a.basek(), basek);
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), module.n());
|
||||
assert!(
|
||||
scratch.available()
|
||||
>= GGLWECiphertext::prod_with_glwe_scratch_space(
|
||||
module,
|
||||
res.size(),
|
||||
a.size(),
|
||||
self.size(),
|
||||
self.rank_in(),
|
||||
self.rank_out()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let cols_in: usize = self.rank_in();
|
||||
let cols_out: usize = self.rank_out() + 1;
|
||||
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, self.size()); // Todo optimise
|
||||
|
||||
{
|
||||
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, a.size());
|
||||
(0..cols_in).for_each(|col_i| {
|
||||
module.vec_znx_dft(&mut ai_dft, col_i, a, col_i + 1);
|
||||
});
|
||||
module.vmp_apply(&mut res_dft, &ai_dft, self, scratch2);
|
||||
}
|
||||
|
||||
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
|
||||
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
|
||||
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, res, i, &res_big, i, scratch1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user