abstracted products for all cross types

This commit is contained in:
Jean-Philippe Bossuat
2025-05-11 18:33:47 +02:00
parent 54fab8e4f3
commit 73098af73a
9 changed files with 1219 additions and 946 deletions

View File

@@ -1,10 +1,11 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut,
VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps,
VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
};
use crate::{
grlwe::GRLWECt,
rgsw::RGSWCt,
rlwe::{RLWECt, RLWECtDft},
utils::derive_size,
};
@@ -65,6 +66,36 @@ pub trait SetRow<B: Backend> {
VecZnxDft<A, B>: VecZnxDftToRef<B>;
}
pub trait ProdByScratchSpace {
fn prod_by_grlwe_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize;
fn prod_by_rgsw_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize;
}
pub trait ProdBy<D> {
fn prod_by_grlwe<R>(&mut self, module: &Module<FFT64>, rhs: &GRLWECt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>;
fn prod_by_rgsw<R>(&mut self, module: &Module<FFT64>, rhs: &RGSWCt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>;
}
pub trait FromProdByScratchSpace {
fn from_prod_by_grlwe_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize;
fn from_prod_by_rgsw_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize;
}
pub trait FromProdBy<D, L> {
fn from_prod_by_grlwe<R>(&mut self, module: &Module<FFT64>, lhs: &L, rhs: &GRLWECt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>;
fn from_prod_by_rgsw<R>(&mut self, module: &Module<FFT64>, lhs: &L, rhs: &RGSWCt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>;
}
pub(crate) trait MatZnxDftProducts<D, C>: Infos
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
@@ -75,6 +106,31 @@ where
VecZnx<R>: VecZnxToMut,
VecZnx<A>: VecZnxToRef;
fn mul_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize;
fn mul_rlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, mat_size: usize) -> usize {
Self::mul_rlwe_scratch_space(module, res_size, res_size, mat_size)
}
fn mul_rlwe_dft_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, mat_size: usize) -> usize {
(Self::mul_rlwe_scratch_space(module, res_size, a_size, mat_size) | module.vec_znx_idft_tmp_bytes())
+ module.bytes_of_vec_znx(2, a_size)
+ module.bytes_of_vec_znx(2, res_size)
}
fn mul_rlwe_dft_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, mat_size: usize) -> usize {
(Self::mul_rlwe_inplace_scratch_space(module, res_size, mat_size) | module.vec_znx_idft_tmp_bytes())
+ module.bytes_of_vec_znx(2, res_size)
}
fn mul_mat_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, mat_size: usize) -> usize {
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, a_size)
}
fn mul_mat_rlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, mat_size: usize) -> usize {
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, res_size)
}
fn mul_rlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
@@ -132,7 +188,6 @@ where
fn mul_rlwe_dft_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECtDft<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64> + VecZnxDftToMut<FFT64>,
{
let log_base2k: usize = self.log_base2k();
@@ -160,11 +215,10 @@ where
module.vec_znx_dft(res, 1, &res_idft, 1);
}
fn mul_grlwe<R, A>(&self, module: &Module<FFT64>, res: &mut GRLWECt<R, FFT64>, a: &GRLWECt<A, FFT64>, scratch: &mut Scratch)
fn mul_mat_rlwe<R, A>(&self, module: &Module<FFT64>, res: &mut R, a: &A, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
MatZnxDft<R, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64> + ZnxInfos,
MatZnxDft<A, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
A: GetRow<FFT64> + Infos,
R: SetRow<FFT64> + Infos,
{
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size());
@@ -176,22 +230,25 @@ where
let min_rows: usize = res.rows().min(a.rows());
(0..min_rows).for_each(|row_i| {
a.get_row(module, row_i, &mut tmp_row);
self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1);
res.set_row(module, row_i, &tmp_row);
(0..res.rows()).for_each(|row_i| {
(0..self.cols()).for_each(|col_j| {
a.get_row(module, row_i, col_j, &mut tmp_row);
self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1);
res.set_row(module, row_i, col_j, &tmp_row);
});
});
tmp_row.data.zero();
(min_rows..res.rows()).for_each(|row_i| {
res.set_row(module, row_i, &tmp_row);
})
(0..self.cols()).for_each(|col_j| {
res.set_row(module, row_i, col_j, &tmp_row);
});
});
}
fn mul_grlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut R, scratch: &mut Scratch)
fn mul_mat_rlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut R, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
R: GetRow<FFT64> + SetRow<FFT64> + Infos,
{
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size());
@@ -202,12 +259,12 @@ where
log_k: res.log_k(),
};
(0..self.cols()).for_each(|col_j| {
(0..res.rows()).for_each(|row_i| {
(0..res.rows()).for_each(|row_i| {
(0..self.cols()).for_each(|col_j| {
res.get_row(module, row_i, col_j, &mut tmp_row);
self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1);
res.set_row(module, row_i, col_j, &tmp_row);
});
})
});
}
}