Added rgsw ops

This commit is contained in:
Jean-Philippe Bossuat
2025-05-10 11:26:01 +02:00
parent 9913040aa1
commit ee7b5744e4

View File

@@ -1,12 +1,14 @@
use base2k::{ use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxOps, ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch,
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
ZnxZero, ZnxZero,
}; };
use sampling::source::Source; use sampling::source::Source;
use crate::{ use crate::{
elem::Infos, elem::Infos,
elem_grlwe::GRLWECt,
elem_rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, elem_rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk},
keys::SecretKeyDft, keys::SecretKeyDft,
utils::derive_size, utils::derive_size,
@@ -69,20 +71,42 @@ impl RGSWCt<Vec<u8>, FFT64> {
+ module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx(1, size)
+ module.bytes_of_vec_znx_dft(2, size) + module.bytes_of_vec_znx_dft(2, size)
} }
/// Returns the number of scratch bytes required by [Self::mul_rlwe].
///
/// The scratch is consumed in two sequential phases that can reuse the same
/// region:
///   1. the DFT of `a` (2 columns of `a_size`) plus the temporary space of
///      `vmp_apply`,
///   2. the temporary space of `vec_znx_big_normalize`,
/// so only the larger of the two is needed, on top of the buffer holding the
/// product in DFT form (2 columns of `rgsw_size`).
pub fn mul_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, rgsw_size: usize) -> usize {
    module.bytes_of_vec_znx_dft(2, rgsw_size)
        // `.max` (not bitwise `|`): the two phases run one after the other and
        // share the region, so the requirement is the larger of the two, not
        // their bitwise union (which merely over-estimates).
        + (module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size))
            .max(module.vec_znx_big_normalize_tmp_bytes())
}
/// Returns the number of scratch bytes required by [Self::mul_rlwe_inplace].
///
/// In-place multiplication reads and writes the same ciphertext, so the input
/// size equals the output size and the requirement reduces to
/// [Self::mul_rlwe_scratch_space] with `a_size = res_size`.
pub fn mul_rlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rgsw_size: usize) -> usize {
Self::mul_rlwe_scratch_space(module, res_size, res_size, rgsw_size)
}
} }
impl<C> RGSWCt<C, FFT64> impl<C> RGSWCt<C, FFT64>
where where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>, MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
{ {
pub fn get_row(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut RLWECtDft<C, FFT64>) pub fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut RLWECtDft<R, FFT64>)
where where
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64>, VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64>,
{ {
module.vmp_extract_row(res, self, row_i, col_j); module.vmp_extract_row(res, self, row_i, col_j);
} }
} }
impl<C> RGSWCt<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
{
/// Overwrites row `row_i`, column `col_j` of this RGSW ciphertext with the
/// RLWE ciphertext `a` (already in DFT representation).
///
/// Thin wrapper over `Module::vmp_prepare_row`; inverse of [Self::get_row].
pub fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &RLWECtDft<R, FFT64>)
where
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64>,
{
module.vmp_prepare_row(self, row_i, col_j, a);
}
}
pub fn encrypt_rgsw_sk<C, P, S>( pub fn encrypt_rgsw_sk<C, P, S>(
module: &Module<FFT64>, module: &Module<FFT64>,
ct: &mut RGSWCt<C, FFT64>, ct: &mut RGSWCt<C, FFT64>,
@@ -168,4 +192,237 @@ impl<C> RGSWCt<C, FFT64> {
module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
) )
} }
/// Multiplies the time-domain RLWE ciphertext `a` by this RGSW ciphertext
/// (external product), writing the renormalized result into `res`.
///
/// Scratch must be at least [Self::mul_rlwe_scratch_space] bytes.
pub fn mul_rlwe<R, A>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, a: &RLWECt<A>, scratch: &mut Scratch)
where
    MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
    VecZnx<R>: VecZnxToMut,
    VecZnx<A>: VecZnxToRef,
{
    let log_base2k: usize = self.log_base2k();

    #[cfg(debug_assertions)]
    {
        assert_eq!(res.log_base2k(), log_base2k);
        assert_eq!(a.log_base2k(), log_base2k);
        assert_eq!(self.n(), module.n());
        assert_eq!(res.n(), module.n());
        assert_eq!(a.n(), module.n());
    }

    // Accumulator for the vector-matrix product in the DFT domain.
    // Todo optimise: sized by the RGSW size, which may exceed what is needed.
    let (mut prod_dft, outer_scratch) = scratch.tmp_vec_znx_dft(module, 2, self.size());
    {
        // Forward DFT of both columns of `a`, then apply the RGSW matrix.
        let (mut input_dft, inner_scratch) = outer_scratch.tmp_vec_znx_dft(module, 2, a.size());
        for col in 0..2 {
            module.vec_znx_dft(&mut input_dft, col, a, col);
        }
        module.vmp_apply(&mut prod_dft, &input_dft, self, inner_scratch);
    }

    // Back to the time domain (consuming the DFT buffer), then renormalize
    // each column into `res`.
    let prod_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(prod_dft);
    for col in 0..2 {
        module.vec_znx_big_normalize(log_base2k, res, col, &prod_big, col, outer_scratch);
    }
}
/// In-place variant of [Self::mul_rlwe]: replaces `res` with the external
/// product of `res` by this RGSW ciphertext.
///
/// NOTE(review): the `unsafe` block below materializes a `&mut` and a `&` to
/// the same `RLWECt` simultaneously. Even though [Self::mul_rlwe] only writes
/// `res` after it has finished reading `a`, aliasing a mutable reference is
/// undefined behavior under Rust's aliasing rules regardless of the order of
/// accesses — consider copying `res` into scratch instead; run under Miri to
/// confirm.
pub fn mul_rlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
VecZnx<R>: VecZnxToMut + VecZnxToRef,
{
unsafe {
let res_ptr: *mut RLWECt<R> = res as *mut RLWECt<R>; // This is ok because [Self::mul_rlwe] only updates res at the end.
self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch);
}
}
/// External product for RLWE ciphertexts stored in DFT representation:
/// `res = self × a`. Internally round-trips through the time domain so the
/// result can be renormalized before being transformed back.
pub fn mul_rlwe_dft<R, A>(
    &self,
    module: &Module<FFT64>,
    res: &mut RLWECtDft<R, FFT64>,
    a: &RLWECtDft<A, FFT64>,
    scratch: &mut Scratch,
) where
    MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
    VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64> + ZnxInfos,
    VecZnxDft<A, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
{
    let log_base2k: usize = self.log_base2k();

    #[cfg(debug_assertions)]
    {
        assert_eq!(res.log_base2k(), log_base2k);
        assert_eq!(self.n(), module.n());
        assert_eq!(res.n(), module.n());
    }

    // Bring `a` back to the time domain inside a scratch-backed ciphertext.
    let (a_buf, rem_a) = scratch.tmp_vec_znx(module, 2, a.size());
    let mut a_time = RLWECt::<&mut [u8]> {
        data: a_buf,
        log_base2k: a.log_base2k(),
        log_k: a.log_k(),
    };
    a.idft(module, &mut a_time, rem_a);

    // Time-domain destination for the external product.
    let (res_buf, rem_res) = rem_a.tmp_vec_znx(module, 2, res.size());
    let mut res_time = RLWECt::<&mut [u8]> {
        data: res_buf,
        log_base2k: res.log_base2k(),
        log_k: res.log_k(),
    };
    self.mul_rlwe(module, &mut res_time, &a_time, rem_res);

    // Forward DFT of both columns back into `res`.
    for col in 0..2 {
        module.vec_znx_dft(res, col, &res_time, col);
    }
}
/// In-place variant of [Self::mul_rlwe_dft]: replaces the DFT-domain RLWE
/// ciphertext `res` with its external product by this RGSW ciphertext.
pub fn mul_rlwe_dft_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECtDft<R, FFT64>, scratch: &mut Scratch)
where
    MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
    VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64> + ZnxInfos,
{
    let log_base2k: usize = self.log_base2k();

    #[cfg(debug_assertions)]
    {
        assert_eq!(res.log_base2k(), log_base2k);
        assert_eq!(self.n(), module.n());
        assert_eq!(res.n(), module.n());
    }

    // Scratch-backed time-domain mirror of `res`.
    let (buf, rem) = scratch.tmp_vec_znx(module, 2, res.size());
    let mut res_time = RLWECt::<&mut [u8]> {
        data: buf,
        log_base2k: res.log_base2k(),
        log_k: res.log_k(),
    };

    // Inverse DFT, multiply in the time domain, then transform back.
    res.idft(module, &mut res_time, rem);
    self.mul_rlwe_inplace(module, &mut res_time, rem);
    for col in 0..2 {
        module.vec_znx_dft(res, col, &res_time, col);
    }
}
pub fn mul_grlwe<R, A>(
&self,
module: &Module<FFT64>,
res: &mut GRLWECt<R, FFT64>,
a: &GRLWECt<A, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
MatZnxDft<R, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64> + ZnxInfos,
MatZnxDft<A, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
{
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size());
let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> {
data: tmp_row_data,
log_base2k: a.log_base2k(),
log_k: a.log_k(),
};
let min_rows: usize = res.rows().min(a.rows());
(0..min_rows).for_each(|row_i| {
a.get_row(module, row_i, &mut tmp_row);
self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1);
res.set_row(module, row_i, &tmp_row);
});
tmp_row.data.zero();
(min_rows..res.rows()).for_each(|row_i| {
res.set_row(module, row_i, &tmp_row);
})
}
pub fn mul_grlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut GRLWECt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
MatZnxDft<R, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64> + ZnxInfos,
{
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size());
let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> {
data: tmp_row_data,
log_base2k: res.log_base2k(),
log_k: res.log_k(),
};
(0..res.rows()).for_each(|row_i| {
res.get_row(module, row_i, &mut tmp_row);
self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1);
res.set_row(module, row_i, &tmp_row);
});
}
/// Multiplies every row of both columns of the RGSW ciphertext `a` by this
/// RGSW ciphertext, writing the results into `res`. Rows of `res` with no
/// counterpart in `a` are set to zero in both columns.
pub fn mul_rgsw<R, A>(&self, module: &Module<FFT64>, res: &mut RGSWCt<R, FFT64>, a: &RGSWCt<A, FFT64>, scratch: &mut Scratch)
where
    MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
    MatZnxDft<R, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64> + ZnxInfos,
    MatZnxDft<A, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
{
    // Scratch-backed RLWE (DFT) row buffer, reused across iterations.
    let (row_buf, rem) = scratch.tmp_vec_znx_dft(module, 2, a.size());
    let mut row = RLWECtDft::<&mut [u8], FFT64> {
        data: row_buf,
        log_base2k: a.log_base2k(),
        log_k: a.log_k(),
    };

    // For each column, process the rows present in both operands: extract,
    // multiply in place, store back. (Column-major order matches the
    // sequential per-column passes of the original layout.)
    let shared = res.rows().min(a.rows());
    for col_j in 0..2 {
        for row_i in 0..shared {
            a.get_row(module, row_i, col_j, &mut row);
            self.mul_rlwe_dft_inplace(module, &mut row, rem);
            res.set_row(module, row_i, col_j, &row);
        }
    }

    // Remaining rows of `res` have no source row in `a`: zero both columns.
    row.data.zero();
    for row_i in shared..res.rows() {
        for col_j in 0..2 {
            res.set_row(module, row_i, col_j, &row);
        }
    }
}
pub fn mul_rgsw_inplace<R>(&self, module: &Module<FFT64>, res: &mut RGSWCt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
MatZnxDft<R, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64> + ZnxInfos,
{
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size());
let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> {
data: tmp_row_data,
log_base2k: res.log_base2k(),
log_k: res.log_k(),
};
(0..res.rows()).for_each(|row_i| {
res.get_row(module, row_i, 0, &mut tmp_row);
self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1);
res.set_row(module, row_i, 0, &tmp_row);
});
(0..res.rows()).for_each(|row_i| {
res.get_row(module, row_i, 1, &mut tmp_row);
self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1);
res.set_row(module, row_i, 1, &tmp_row);
});
}
} }