Some traits updates + added missing tests for products on RGSWCt

This commit is contained in:
Jean-Philippe Bossuat
2025-05-12 14:40:17 +02:00
parent e38ca404f9
commit d8a7d6cdaf
9 changed files with 2295 additions and 1914 deletions

View File

@@ -66,92 +66,88 @@ pub trait SetRow<B: Backend> {
VecZnxDft<A, B>: VecZnxDftToRef<B>;
}
pub trait ProdByScratchSpace {
fn prod_by_grlwe_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize;
fn prod_by_rgsw_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize;
pub trait ProdInplaceScratchSpace {
fn prod_by_grlwe_inplace_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize;
fn prod_by_rgsw_inplace_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize;
}
pub trait ProdBy<D> {
fn prod_by_grlwe<R>(&mut self, module: &Module<FFT64>, rhs: &GRLWECt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>;
fn prod_by_rgsw<R>(&mut self, module: &Module<FFT64>, rhs: &RGSWCt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>;
}
pub trait FromProdByScratchSpace {
fn from_prod_by_grlwe_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize;
fn from_prod_by_rgsw_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize;
}
pub trait FromProdBy<D, L> {
fn from_prod_by_grlwe<R>(&mut self, module: &Module<FFT64>, lhs: &L, rhs: &GRLWECt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>;
fn from_prod_by_rgsw<R>(&mut self, module: &Module<FFT64>, lhs: &L, rhs: &RGSWCt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>;
}
pub(crate) trait MatZnxDftProducts<D, C>: Infos
pub trait ProdInplace<MUT, REF>
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
MatZnxDft<REF, FFT64>: MatZnxDftToRef<FFT64>,
{
fn mul_rlwe<R, A>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, a: &RLWECt<A>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
VecZnx<R>: VecZnxToMut,
VecZnx<A>: VecZnxToRef;
fn prod_by_grlwe_inplace(&mut self, module: &Module<FFT64>, rhs: &GRLWECt<REF, FFT64>, scratch: &mut Scratch);
fn prod_by_rgsw_inplace(&mut self, module: &Module<FFT64>, rhs: &RGSWCt<REF, FFT64>, scratch: &mut Scratch);
}
fn mul_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize;
pub trait ProdScratchSpace {
fn prod_by_grlwe_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize;
fn prod_by_rgsw_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize;
}
fn mul_rlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, mat_size: usize) -> usize {
Self::mul_rlwe_scratch_space(module, res_size, res_size, mat_size)
pub trait Product<MUT, REF>
where
MatZnxDft<REF, FFT64>: MatZnxDftToRef<FFT64>,
{
type Lhs;
fn prod_by_grlwe(&mut self, module: &Module<FFT64>, lhs: &Self::Lhs, rhs: &GRLWECt<REF, FFT64>, scratch: &mut Scratch);
fn prod_by_rgsw(&mut self, module: &Module<FFT64>, lhs: &Self::Lhs, rhs: &RGSWCt<REF, FFT64>, scratch: &mut Scratch);
}
pub(crate) trait MatRLWEProductScratchSpace {
fn prod_with_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize;
fn prod_with_rlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, mat_size: usize) -> usize {
Self::prod_with_rlwe_scratch_space(module, res_size, res_size, mat_size)
}
fn mul_rlwe_dft_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, mat_size: usize) -> usize {
(Self::mul_rlwe_scratch_space(module, res_size, a_size, mat_size) | module.vec_znx_idft_tmp_bytes())
fn prod_with_rlwe_dft_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, mat_size: usize) -> usize {
(Self::prod_with_rlwe_scratch_space(module, res_size, a_size, mat_size) | module.vec_znx_idft_tmp_bytes())
+ module.bytes_of_vec_znx(2, a_size)
+ module.bytes_of_vec_znx(2, res_size)
}
fn mul_rlwe_dft_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, mat_size: usize) -> usize {
(Self::mul_rlwe_inplace_scratch_space(module, res_size, mat_size) | module.vec_znx_idft_tmp_bytes())
fn prod_with_rlwe_dft_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, mat_size: usize) -> usize {
(Self::prod_with_rlwe_inplace_scratch_space(module, res_size, mat_size) | module.vec_znx_idft_tmp_bytes())
+ module.bytes_of_vec_znx(2, res_size)
}
fn mul_mat_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, mat_size: usize) -> usize {
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, a_size)
fn prod_with_mat_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, mat_size: usize) -> usize {
Self::prod_with_rlwe_dft_scratch_space(module, res_size, a_size, mat_size)
+ module.bytes_of_vec_znx_dft(2, a_size)
+ module.bytes_of_vec_znx_dft(2, res_size)
}
fn mul_mat_rlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, mat_size: usize) -> usize {
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, res_size)
fn prod_with_mat_rlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, mat_size: usize) -> usize {
Self::prod_with_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, res_size)
}
}
fn mul_rlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, scratch: &mut Scratch)
pub(crate) trait MatRLWEProduct: Infos {
fn prod_with_rlwe<MUT, REF>(&self, module: &Module<FFT64>, res: &mut RLWECt<MUT>, a: &RLWECt<REF>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
VecZnx<R>: VecZnxToMut + VecZnxToRef,
VecZnx<MUT>: VecZnxToMut,
VecZnx<REF>: VecZnxToRef;
fn prod_with_rlwe_inplace<MUT>(&self, module: &Module<FFT64>, res: &mut RLWECt<MUT>, scratch: &mut Scratch)
where
VecZnx<MUT>: VecZnxToMut + VecZnxToRef,
{
unsafe {
let res_ptr: *mut RLWECt<R> = res as *mut RLWECt<R>; // This is ok because [Self::mul_rlwe] only updates res at the end.
self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch);
let res_ptr: *mut RLWECt<MUT> = res as *mut RLWECt<MUT>; // This is ok because [Self::mul_rlwe] only updates res at the end.
self.prod_with_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch);
}
}
fn mul_rlwe_dft<R, A>(
fn prod_with_rlwe_dft<MUT, REF>(
&self,
module: &Module<FFT64>,
res: &mut RLWECtDft<R, FFT64>,
a: &RLWECtDft<A, FFT64>,
res: &mut RLWECtDft<MUT, FFT64>,
a: &RLWECtDft<REF, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64> + ZnxInfos,
VecZnxDft<A, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
VecZnxDft<MUT, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64> + ZnxInfos,
VecZnxDft<REF, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
{
let log_base2k: usize = self.log_base2k();
@@ -180,15 +176,15 @@ where
log_k: res.log_k(),
};
self.mul_rlwe(module, &mut res_idft, &a_idft, scratch_2);
self.prod_with_rlwe(module, &mut res_idft, &a_idft, scratch_2);
module.vec_znx_dft(res, 0, &res_idft, 0);
module.vec_znx_dft(res, 1, &res_idft, 1);
}
fn mul_rlwe_dft_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECtDft<R, FFT64>, scratch: &mut Scratch)
fn prod_with_rlwe_dft_inplace<MUT>(&self, module: &Module<FFT64>, res: &mut RLWECtDft<MUT, FFT64>, scratch: &mut Scratch)
where
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64> + VecZnxDftToMut<FFT64>,
VecZnxDft<MUT, FFT64>: VecZnxDftToRef<FFT64> + VecZnxDftToMut<FFT64>,
{
let log_base2k: usize = self.log_base2k();
@@ -209,47 +205,55 @@ where
res.idft(module, &mut res_idft, scratch_1);
self.mul_rlwe_inplace(module, &mut res_idft, scratch_1);
self.prod_with_rlwe_inplace(module, &mut res_idft, scratch_1);
module.vec_znx_dft(res, 0, &res_idft, 0);
module.vec_znx_dft(res, 1, &res_idft, 1);
}
fn mul_mat_rlwe<R, A>(&self, module: &Module<FFT64>, res: &mut R, a: &A, scratch: &mut Scratch)
fn prod_with_mat_rlwe<RES, LHS>(&self, module: &Module<FFT64>, res: &mut RES, a: &LHS, scratch: &mut Scratch)
where
A: GetRow<FFT64> + Infos,
R: SetRow<FFT64> + Infos,
LHS: GetRow<FFT64> + Infos,
RES: SetRow<FFT64> + Infos,
{
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size());
let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> {
let mut tmp_a_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> {
data: tmp_row_data,
log_base2k: a.log_base2k(),
log_k: a.log_k(),
};
let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, res.size());
let mut tmp_res_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> {
data: tmp_res_data,
log_base2k: res.log_base2k(),
log_k: res.log_k(),
};
let min_rows: usize = res.rows().min(a.rows());
(0..res.rows()).for_each(|row_i| {
(0..res.cols()).for_each(|col_j| {
a.get_row(module, row_i, col_j, &mut tmp_row);
self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1);
res.set_row(module, row_i, col_j, &tmp_row);
a.get_row(module, row_i, col_j, &mut tmp_a_row);
self.prod_with_rlwe_dft(module, &mut tmp_res_row, &tmp_a_row, scratch2);
res.set_row(module, row_i, col_j, &tmp_res_row);
});
});
tmp_row.data.zero();
tmp_res_row.data.zero();
(min_rows..res.rows()).for_each(|row_i| {
(0..self.cols()).for_each(|col_j| {
res.set_row(module, row_i, col_j, &tmp_row);
res.set_row(module, row_i, col_j, &tmp_res_row);
});
});
}
fn mul_mat_rlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut R, scratch: &mut Scratch)
fn prod_with_mat_rlwe_inplace<RES>(&self, module: &Module<FFT64>, res: &mut RES, scratch: &mut Scratch)
where
R: GetRow<FFT64> + SetRow<FFT64> + Infos,
RES: GetRow<FFT64> + SetRow<FFT64> + Infos,
{
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size());
@@ -262,7 +266,7 @@ where
(0..res.rows()).for_each(|row_i| {
(0..res.cols()).for_each(|col_j| {
res.get_row(module, row_i, col_j, &mut tmp_row);
self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1);
self.prod_with_rlwe_dft_inplace(module, &mut tmp_row, scratch1);
res.set_row(module, row_i, col_j, &tmp_row);
});
});