abstracted products for all cross types

This commit is contained in:
Jean-Philippe Bossuat
2025-05-11 18:33:47 +02:00
parent 54fab8e4f3
commit 73098af73a
9 changed files with 1219 additions and 946 deletions

View File

@@ -1,10 +1,11 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut,
VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps,
VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
};
use crate::{
grlwe::GRLWECt,
rgsw::RGSWCt,
rlwe::{RLWECt, RLWECtDft},
utils::derive_size,
};
@@ -65,6 +66,36 @@ pub trait SetRow<B: Backend> {
VecZnxDft<A, B>: VecZnxDftToRef<B>;
}
pub trait ProdByScratchSpace {
fn prod_by_grlwe_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize;
fn prod_by_rgsw_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize;
}
pub trait ProdBy<D> {
fn prod_by_grlwe<R>(&mut self, module: &Module<FFT64>, rhs: &GRLWECt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>;
fn prod_by_rgsw<R>(&mut self, module: &Module<FFT64>, rhs: &RGSWCt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>;
}
pub trait FromProdByScratchSpace {
fn from_prod_by_grlwe_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize;
fn from_prod_by_rgsw_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize;
}
pub trait FromProdBy<D, L> {
fn from_prod_by_grlwe<R>(&mut self, module: &Module<FFT64>, lhs: &L, rhs: &GRLWECt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>;
fn from_prod_by_rgsw<R>(&mut self, module: &Module<FFT64>, lhs: &L, rhs: &RGSWCt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>;
}
pub(crate) trait MatZnxDftProducts<D, C>: Infos
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
@@ -75,6 +106,31 @@ where
VecZnx<R>: VecZnxToMut,
VecZnx<A>: VecZnxToRef;
fn mul_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize;
fn mul_rlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, mat_size: usize) -> usize {
Self::mul_rlwe_scratch_space(module, res_size, res_size, mat_size)
}
fn mul_rlwe_dft_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, mat_size: usize) -> usize {
(Self::mul_rlwe_scratch_space(module, res_size, a_size, mat_size) | module.vec_znx_idft_tmp_bytes())
+ module.bytes_of_vec_znx(2, a_size)
+ module.bytes_of_vec_znx(2, res_size)
}
fn mul_rlwe_dft_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, mat_size: usize) -> usize {
(Self::mul_rlwe_inplace_scratch_space(module, res_size, mat_size) | module.vec_znx_idft_tmp_bytes())
+ module.bytes_of_vec_znx(2, res_size)
}
fn mul_mat_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, mat_size: usize) -> usize {
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, a_size)
}
fn mul_mat_rlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, mat_size: usize) -> usize {
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, res_size)
}
fn mul_rlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
@@ -132,7 +188,6 @@ where
fn mul_rlwe_dft_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECtDft<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64> + VecZnxDftToMut<FFT64>,
{
let log_base2k: usize = self.log_base2k();
@@ -160,11 +215,10 @@ where
module.vec_znx_dft(res, 1, &res_idft, 1);
}
fn mul_grlwe<R, A>(&self, module: &Module<FFT64>, res: &mut GRLWECt<R, FFT64>, a: &GRLWECt<A, FFT64>, scratch: &mut Scratch)
fn mul_mat_rlwe<R, A>(&self, module: &Module<FFT64>, res: &mut R, a: &A, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
MatZnxDft<R, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64> + ZnxInfos,
MatZnxDft<A, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
A: GetRow<FFT64> + Infos,
R: SetRow<FFT64> + Infos,
{
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size());
@@ -176,22 +230,25 @@ where
let min_rows: usize = res.rows().min(a.rows());
(0..min_rows).for_each(|row_i| {
a.get_row(module, row_i, &mut tmp_row);
(0..res.rows()).for_each(|row_i| {
(0..self.cols()).for_each(|col_j| {
a.get_row(module, row_i, col_j, &mut tmp_row);
self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1);
res.set_row(module, row_i, &tmp_row);
res.set_row(module, row_i, col_j, &tmp_row);
});
});
tmp_row.data.zero();
(min_rows..res.rows()).for_each(|row_i| {
res.set_row(module, row_i, &tmp_row);
})
(0..self.cols()).for_each(|col_j| {
res.set_row(module, row_i, col_j, &tmp_row);
});
});
}
fn mul_grlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut R, scratch: &mut Scratch)
fn mul_mat_rlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut R, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
R: GetRow<FFT64> + SetRow<FFT64> + Infos,
{
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size());
@@ -202,12 +259,12 @@ where
log_k: res.log_k(),
};
(0..self.cols()).for_each(|col_j| {
(0..res.rows()).for_each(|row_i| {
(0..self.cols()).for_each(|col_j| {
res.get_row(module, row_i, col_j, &mut tmp_row);
self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1);
res.set_row(module, row_i, col_j, &tmp_row);
});
})
});
}
}

View File

@@ -7,8 +7,9 @@ use base2k::{
use sampling::source::Source;
use crate::{
elem::{GetRow, Infos, MatZnxDftProducts, SetRow},
elem::{FromProdBy, FromProdByScratchSpace, GetRow, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace, SetRow},
keys::SecretKeyDft,
rgsw::RGSWCt,
rlwe::{RLWECt, RLWECtDft, RLWEPt},
utils::derive_size,
};
@@ -41,18 +42,6 @@ where
}
}
impl<C> GRLWECt<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
{
pub fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, a: &RLWECtDft<R, FFT64>)
where
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64>,
{
module.vmp_prepare_row(self, row_i, 0, a);
}
}
impl<T, B: Backend> Infos for GRLWECt<T, B> {
type Inner = MatZnxDft<T, B>;
@@ -94,36 +83,6 @@ impl GRLWECt<Vec<u8>, FFT64> {
+ module.bytes_of_vec_znx(1, size)
+ module.bytes_of_vec_znx_dft(2, size)
}
pub fn mul_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
module.bytes_of_vec_znx_dft(2, grlwe_size)
+ (module.vec_znx_big_normalize_tmp_bytes()
| (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size)
+ module.bytes_of_vec_znx_dft(1, a_size)))
}
pub fn mul_rlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, grlwe_size: usize) -> usize {
Self::mul_rlwe_scratch_space(module, res_size, res_size, grlwe_size)
}
pub fn mul_rlwe_dft_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
(Self::mul_rlwe_scratch_space(module, res_size, a_size, grlwe_size) | module.vec_znx_idft_tmp_bytes())
+ module.bytes_of_vec_znx(2, a_size)
+ module.bytes_of_vec_znx(2, res_size)
}
pub fn mul_rlwe_dft_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, grlwe_size: usize) -> usize {
(Self::mul_rlwe_inplace_scratch_space(module, res_size, grlwe_size) | module.vec_znx_idft_tmp_bytes())
+ module.bytes_of_vec_znx(2, res_size)
}
pub fn mul_grlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size)
}
pub fn mul_grlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size)
}
}
pub fn encrypt_grlwe_sk<C, P, S>(
@@ -209,67 +168,6 @@ impl<C> GRLWECt<C, FFT64> {
module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
)
}
pub fn mul_rlwe<R, A>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, a: &RLWECt<A>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
VecZnx<R>: VecZnxToMut,
VecZnx<A>: VecZnxToRef,
{
MatZnxDftProducts::mul_rlwe(self, module, res, a, scratch);
}
pub fn mul_rlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
VecZnx<R>: VecZnxToMut + VecZnxToRef,
{
MatZnxDftProducts::mul_rlwe_inplace(self, module, res, scratch);
}
pub fn mul_rlwe_dft<R, A>(
&self,
module: &Module<FFT64>,
res: &mut RLWECtDft<R, FFT64>,
a: &RLWECtDft<A, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64> + ZnxInfos,
VecZnxDft<A, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
{
MatZnxDftProducts::mul_rlwe_dft(self, module, res, a, scratch);
}
pub fn mul_rlwe_dft_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECtDft<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64> + VecZnxDftToMut<FFT64>,
{
MatZnxDftProducts::mul_rlwe_dft_inplace(self, module, res, scratch);
}
pub fn mul_grlwe<R, A>(
&self,
module: &Module<FFT64>,
res: &mut GRLWECt<R, FFT64>,
a: &GRLWECt<A, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
MatZnxDft<R, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64> + ZnxInfos,
MatZnxDft<A, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
{
MatZnxDftProducts::mul_grlwe(self, module, res, a, scratch);
}
pub fn mul_grlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut R, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
R: GetRow<FFT64> + SetRow<FFT64> + Infos,
{
MatZnxDftProducts::mul_grlwe_inplace(self, module, res, scratch);
}
}
impl<C> GetRow<FFT64> for GRLWECt<C, FFT64>
@@ -308,6 +206,13 @@ impl<C> MatZnxDftProducts<GRLWECt<C, FFT64>, C> for GRLWECt<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
{
fn mul_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
module.bytes_of_vec_znx_dft(2, grlwe_size)
+ (module.vec_znx_big_normalize_tmp_bytes()
| (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size)
+ module.bytes_of_vec_znx_dft(1, a_size)))
}
fn mul_rlwe<R, A>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, a: &RLWECt<A>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
@@ -341,3 +246,80 @@ where
module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1);
}
}
impl ProdByScratchSpace for GRLWECt<Vec<u8>, FFT64> {
fn prod_by_grlwe_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize {
<GRLWECt<Vec<u8>, FFT64> as MatZnxDftProducts<GRLWECt<Vec<u8>, FFT64>, Vec<u8>>>::mul_mat_rlwe_inplace_scratch_space(
module, lhs, rhs,
)
}
fn prod_by_rgsw_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize {
<RGSWCt<Vec<u8>, FFT64> as MatZnxDftProducts<RGSWCt<Vec<u8>, FFT64>, Vec<u8>>>::mul_mat_rlwe_inplace_scratch_space(
module, lhs, rhs,
)
}
}
impl FromProdByScratchSpace for GRLWECt<Vec<u8>, FFT64> {
fn from_prod_by_grlwe_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
<GRLWECt<Vec<u8>, FFT64> as MatZnxDftProducts<GRLWECt<Vec<u8>, FFT64>, Vec<u8>>>::mul_mat_rlwe_scratch_space(
module, res_size, lhs, rhs,
)
}
fn from_prod_by_rgsw_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
<RGSWCt<Vec<u8>, FFT64> as MatZnxDftProducts<RGSWCt<Vec<u8>, FFT64>, Vec<u8>>>::mul_mat_rlwe_scratch_space(
module, res_size, lhs, rhs,
)
}
}
impl<MUT> ProdBy<GRLWECt<MUT, FFT64>> for GRLWECt<MUT, FFT64>
where
GRLWECt<MUT, FFT64>: GetRow<FFT64> + SetRow<FFT64> + Infos,
{
fn prod_by_grlwe<R>(&mut self, module: &Module<FFT64>, rhs: &GRLWECt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.mul_mat_rlwe_inplace(module, self, scratch);
}
fn prod_by_rgsw<R>(&mut self, module: &Module<FFT64>, rhs: &RGSWCt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.mul_mat_rlwe_inplace(module, self, scratch);
}
}
impl<MUT, REF> FromProdBy<GRLWECt<MUT, FFT64>, GRLWECt<REF, FFT64>> for GRLWECt<MUT, FFT64>
where
GRLWECt<MUT, FFT64>: GetRow<FFT64> + SetRow<FFT64> + Infos,
GRLWECt<REF, FFT64>: GetRow<FFT64> + Infos,
{
fn from_prod_by_grlwe<R>(
&mut self,
module: &Module<FFT64>,
lhs: &GRLWECt<REF, FFT64>,
rhs: &GRLWECt<R, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.mul_mat_rlwe(module, self, lhs, scratch);
}
fn from_prod_by_rgsw<R>(
&mut self,
module: &Module<FFT64>,
lhs: &GRLWECt<REF, FFT64>,
rhs: &RGSWCt<R, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.mul_mat_rlwe(module, self, lhs, scratch);
}
}

View File

@@ -7,7 +7,7 @@ use base2k::{
use sampling::source::Source;
use crate::{
elem::{GetRow, Infos, MatZnxDftProducts, SetRow},
elem::{FromProdBy, FromProdByScratchSpace, GetRow, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace, SetRow},
grlwe::GRLWECt,
keys::SecretKeyDft,
rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk},
@@ -71,43 +71,6 @@ impl RGSWCt<Vec<u8>, FFT64> {
+ module.bytes_of_vec_znx(1, size)
+ module.bytes_of_vec_znx_dft(2, size)
}
pub fn mul_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, rgsw_size: usize) -> usize {
module.bytes_of_vec_znx_dft(2, rgsw_size)
+ ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size))
| module.vec_znx_big_normalize_tmp_bytes())
}
pub fn mul_rlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rgsw_size: usize) -> usize {
Self::mul_rlwe_scratch_space(module, res_size, res_size, rgsw_size)
}
pub fn mul_rlwe_dft_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
(Self::mul_rlwe_scratch_space(module, res_size, a_size, grlwe_size) | module.vec_znx_idft_tmp_bytes())
+ module.bytes_of_vec_znx(2, a_size)
+ module.bytes_of_vec_znx(2, res_size)
}
pub fn mul_rlwe_dft_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, grlwe_size: usize) -> usize {
(Self::mul_rlwe_inplace_scratch_space(module, res_size, grlwe_size) | module.vec_znx_idft_tmp_bytes())
+ module.bytes_of_vec_znx(2, res_size)
}
pub fn mul_grlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size)
}
pub fn mul_grlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size)
}
pub fn mul_rgsw_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size)
}
pub fn mul_rgsw_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size)
}
}
pub fn encrypt_rgsw_sk<C, P, S>(
@@ -195,67 +158,6 @@ impl<C> RGSWCt<C, FFT64> {
module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
)
}
pub fn mul_rlwe<R, A>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, a: &RLWECt<A>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
VecZnx<R>: VecZnxToMut,
VecZnx<A>: VecZnxToRef,
{
MatZnxDftProducts::mul_rlwe(self, module, res, a, scratch);
}
pub fn mul_rlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
VecZnx<R>: VecZnxToMut + VecZnxToRef,
{
MatZnxDftProducts::mul_rlwe_inplace(self, module, res, scratch);
}
pub fn mul_rlwe_dft<R, A>(
&self,
module: &Module<FFT64>,
res: &mut RLWECtDft<R, FFT64>,
a: &RLWECtDft<A, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64> + ZnxInfos,
VecZnxDft<A, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
{
MatZnxDftProducts::mul_rlwe_dft(self, module, res, a, scratch);
}
pub fn mul_rlwe_dft_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECtDft<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64> + VecZnxDftToMut<FFT64>,
{
MatZnxDftProducts::mul_rlwe_dft_inplace(self, module, res, scratch);
}
pub fn mul_grlwe<R, A>(
&self,
module: &Module<FFT64>,
res: &mut GRLWECt<R, FFT64>,
a: &GRLWECt<A, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
MatZnxDft<R, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64> + ZnxInfos,
MatZnxDft<A, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
{
MatZnxDftProducts::mul_grlwe(self, module, res, a, scratch);
}
pub fn mul_grlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut R, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
R: GetRow<FFT64> + SetRow<FFT64> + Infos,
{
MatZnxDftProducts::mul_grlwe_inplace(self, module, res, scratch);
}
}
impl<C> GetRow<FFT64> for RGSWCt<C, FFT64>
@@ -286,6 +188,12 @@ impl<C> MatZnxDftProducts<RGSWCt<C, FFT64>, C> for RGSWCt<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
{
fn mul_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, rgsw_size: usize) -> usize {
module.bytes_of_vec_znx_dft(2, rgsw_size)
+ ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size))
| module.vec_znx_big_normalize_tmp_bytes())
}
fn mul_rlwe<R, A>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, a: &RLWECt<A>, scratch: &mut Scratch)
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
@@ -318,3 +226,80 @@ where
module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1);
}
}
impl ProdByScratchSpace for RGSWCt<Vec<u8>, FFT64> {
fn prod_by_grlwe_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize {
<GRLWECt<Vec<u8>, FFT64> as MatZnxDftProducts<GRLWECt<Vec<u8>, FFT64>, Vec<u8>>>::mul_mat_rlwe_inplace_scratch_space(
module, lhs, rhs,
)
}
fn prod_by_rgsw_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize {
<RGSWCt<Vec<u8>, FFT64> as MatZnxDftProducts<RGSWCt<Vec<u8>, FFT64>, Vec<u8>>>::mul_mat_rlwe_inplace_scratch_space(
module, lhs, rhs,
)
}
}
impl FromProdByScratchSpace for RGSWCt<Vec<u8>, FFT64> {
fn from_prod_by_grlwe_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
<GRLWECt<Vec<u8>, FFT64> as MatZnxDftProducts<GRLWECt<Vec<u8>, FFT64>, Vec<u8>>>::mul_mat_rlwe_scratch_space(
module, res_size, lhs, rhs,
)
}
fn from_prod_by_rgsw_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
<RGSWCt<Vec<u8>, FFT64> as MatZnxDftProducts<RGSWCt<Vec<u8>, FFT64>, Vec<u8>>>::mul_mat_rlwe_scratch_space(
module, res_size, lhs, rhs,
)
}
}
impl<MUT> ProdBy<RGSWCt<MUT, FFT64>> for RGSWCt<MUT, FFT64>
where
RGSWCt<MUT, FFT64>: GetRow<FFT64> + SetRow<FFT64> + Infos,
{
fn prod_by_grlwe<R>(&mut self, module: &Module<FFT64>, rhs: &GRLWECt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.mul_mat_rlwe_inplace(module, self, scratch);
}
fn prod_by_rgsw<R>(&mut self, module: &Module<FFT64>, rhs: &RGSWCt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.mul_mat_rlwe_inplace(module, self, scratch);
}
}
impl<MUT, REF> FromProdBy<RGSWCt<MUT, FFT64>, RGSWCt<REF, FFT64>> for RGSWCt<MUT, FFT64>
where
RGSWCt<MUT, FFT64>: GetRow<FFT64> + SetRow<FFT64> + Infos,
RGSWCt<REF, FFT64>: GetRow<FFT64> + Infos,
{
fn from_prod_by_grlwe<R>(
&mut self,
module: &Module<FFT64>,
lhs: &RGSWCt<REF, FFT64>,
rhs: &GRLWECt<R, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.mul_mat_rlwe(module, self, lhs, scratch);
}
fn from_prod_by_rgsw<R>(
&mut self,
module: &Module<FFT64>,
lhs: &RGSWCt<REF, FFT64>,
rhs: &RGSWCt<R, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.mul_mat_rlwe(module, self, lhs, scratch);
}
}

View File

@@ -6,9 +6,10 @@ use base2k::{
use sampling::source::Source;
use crate::{
elem::Infos,
elem::{FromProdBy, FromProdByScratchSpace, Infos, MatZnxDftProducts, ProdBy, ProdByScratchSpace},
grlwe::GRLWECt,
keys::{PublicKey, SecretDistribution, SecretKeyDft},
rgsw::RGSWCt,
utils::derive_size,
};
@@ -83,134 +84,70 @@ where
}
}
pub struct RLWEPt<C> {
pub data: VecZnx<C>,
pub log_base2k: usize,
pub log_k: usize,
}
impl<T> Infos for RLWEPt<T> {
type Inner = VecZnx<T>;
fn inner(&self) -> &Self::Inner {
&self.data
impl ProdByScratchSpace for RLWECt<Vec<u8>> {
fn prod_by_grlwe_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize {
<GRLWECt<Vec<u8>, FFT64> as MatZnxDftProducts<GRLWECt<Vec<u8>, FFT64>, Vec<u8>>>::mul_rlwe_inplace_scratch_space(
module, lhs, rhs,
)
}
fn log_base2k(&self) -> usize {
self.log_base2k
}
fn log_k(&self) -> usize {
self.log_k
fn prod_by_rgsw_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize {
<RGSWCt<Vec<u8>, FFT64> as MatZnxDftProducts<RGSWCt<Vec<u8>, FFT64>, Vec<u8>>>::mul_rlwe_inplace_scratch_space(
module, lhs, rhs,
)
}
}
impl<C> VecZnxToMut for RLWEPt<C>
impl FromProdByScratchSpace for RLWECt<Vec<u8>> {
fn from_prod_by_grlwe_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
<GRLWECt<Vec<u8>, FFT64> as MatZnxDftProducts<GRLWECt<Vec<u8>, FFT64>, Vec<u8>>>::mul_rlwe_scratch_space(
module, res_size, lhs, rhs,
)
}
fn from_prod_by_rgsw_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
<RGSWCt<Vec<u8>, FFT64> as MatZnxDftProducts<RGSWCt<Vec<u8>, FFT64>, Vec<u8>>>::mul_rlwe_scratch_space(
module, res_size, lhs, rhs,
)
}
}
impl<MUT> ProdBy<RLWECt<MUT>> for RLWECt<MUT>
where
VecZnx<C>: VecZnxToMut,
VecZnx<MUT>: VecZnxToMut + VecZnxToRef,
{
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
self.data.to_mut()
}
}
impl<C> VecZnxToRef for RLWEPt<C>
where
VecZnx<C>: VecZnxToRef,
{
fn to_ref(&self) -> VecZnx<&[u8]> {
self.data.to_ref()
}
}
impl RLWEPt<Vec<u8>> {
pub fn new<B: Backend>(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
Self {
data: module.new_vec_znx(1, derive_size(log_base2k, log_k)),
log_base2k: log_base2k,
log_k: log_k,
}
}
}
pub struct RLWECtDft<C, B: Backend> {
pub data: VecZnxDft<C, B>,
pub log_base2k: usize,
pub log_k: usize,
}
impl<B: Backend> RLWECtDft<Vec<u8>, B> {
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
Self {
data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)),
log_base2k: log_base2k,
log_k: log_k,
}
}
}
impl<T, B: Backend> Infos for RLWECtDft<T, B> {
type Inner = VecZnxDft<T, B>;
fn inner(&self) -> &Self::Inner {
&self.data
}
fn log_base2k(&self) -> usize {
self.log_base2k
}
fn log_k(&self) -> usize {
self.log_k
}
}
impl<C, B: Backend> VecZnxDftToMut<B> for RLWECtDft<C, B>
where
VecZnxDft<C, B>: VecZnxDftToMut<B>,
{
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
self.data.to_mut()
}
}
impl<C, B: Backend> VecZnxDftToRef<B> for RLWECtDft<C, B>
where
VecZnxDft<C, B>: VecZnxDftToRef<B>,
{
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
self.data.to_ref()
}
}
impl<C> RLWECtDft<C, FFT64>
where
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64>,
{
#[allow(dead_code)]
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())
}
pub(crate) fn idft<R>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, scratch: &mut Scratch)
fn prod_by_grlwe<R>(&mut self, module: &Module<FFT64>, rhs: &GRLWECt<R, FFT64>, scratch: &mut Scratch)
where
VecZnx<R>: VecZnxToMut,
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.cols(), 2);
assert_eq!(res.cols(), 2);
assert_eq!(self.log_base2k(), res.log_base2k())
rhs.mul_rlwe_inplace(module, self, scratch);
}
let min_size: usize = self.size().min(res.size());
fn prod_by_rgsw<R>(&mut self, module: &Module<FFT64>, rhs: &RGSWCt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.mul_rlwe_inplace(module, self, scratch);
}
}
let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size);
impl<MUT, REF> FromProdBy<RLWECt<MUT>, RLWECt<REF>> for RLWECt<MUT>
where
VecZnx<MUT>: VecZnxToMut + VecZnxToRef,
VecZnx<REF>: VecZnxToRef,
{
fn from_prod_by_grlwe<R>(&mut self, module: &Module<FFT64>, lhs: &RLWECt<REF>, rhs: &GRLWECt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.mul_rlwe(module, self, lhs, scratch);
}
module.vec_znx_idft(&mut res_big, 0, &self.data, 0, scratch1);
module.vec_znx_idft(&mut res_big, 1, &self.data, 1, scratch1);
module.vec_znx_big_normalize(self.log_base2k(), res, 0, &res_big, 0, scratch1);
module.vec_znx_big_normalize(self.log_base2k(), res, 1, &res_big, 1, scratch1);
fn from_prod_by_rgsw<R>(&mut self, module: &Module<FFT64>, lhs: &RLWECt<REF>, rhs: &RGSWCt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.mul_rlwe(module, self, lhs, scratch);
}
}
@@ -390,6 +327,204 @@ impl<C> RLWECt<C> {
}
}
pub(crate) fn encrypt_rlwe_pk<C, P, S>(
module: &Module<FFT64>,
ct: &mut RLWECt<C>,
pt: Option<&RLWEPt<P>>,
pk: &PublicKey<S, FFT64>,
source_xu: &mut Source,
source_xe: &mut Source,
sigma: f64,
bound: f64,
scratch: &mut Scratch,
) where
VecZnx<C>: VecZnxToMut + VecZnxToRef,
VecZnx<P>: VecZnxToRef,
VecZnxDft<S, FFT64>: VecZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(ct.log_base2k(), pk.log_base2k());
assert_eq!(ct.n(), module.n());
assert_eq!(pk.n(), module.n());
if let Some(pt) = pt {
assert_eq!(pt.log_base2k(), pk.log_base2k());
assert_eq!(pt.n(), module.n());
}
}
let log_base2k: usize = pk.log_base2k();
let size_pk: usize = pk.size();
// Generates u according to the underlying secret distribution.
let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1);
{
let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1);
match pk.dist {
SecretDistribution::NONE => panic!(
"invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate"
),
SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu),
SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu),
SecretDistribution::ZERO => {}
}
module.svp_prepare(&mut u_dft, 0, &u, 0);
}
let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity)
let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity)
// ct[0] = pk[0] * u + m + e0
module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0);
module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0);
tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound);
if let Some(pt) = pt {
module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0);
}
module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3);
// ct[1] = pk[1] * u + e1
module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1);
module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0);
tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound);
module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3);
}
pub struct RLWEPt<C> {
pub data: VecZnx<C>,
pub log_base2k: usize,
pub log_k: usize,
}
impl<T> Infos for RLWEPt<T> {
type Inner = VecZnx<T>;
fn inner(&self) -> &Self::Inner {
&self.data
}
fn log_base2k(&self) -> usize {
self.log_base2k
}
fn log_k(&self) -> usize {
self.log_k
}
}
impl<C> VecZnxToMut for RLWEPt<C>
where
VecZnx<C>: VecZnxToMut,
{
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
self.data.to_mut()
}
}
impl<C> VecZnxToRef for RLWEPt<C>
where
VecZnx<C>: VecZnxToRef,
{
fn to_ref(&self) -> VecZnx<&[u8]> {
self.data.to_ref()
}
}
impl RLWEPt<Vec<u8>> {
pub fn new<B: Backend>(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
Self {
data: module.new_vec_znx(1, derive_size(log_base2k, log_k)),
log_base2k: log_base2k,
log_k: log_k,
}
}
}
pub struct RLWECtDft<C, B: Backend> {
pub data: VecZnxDft<C, B>,
pub log_base2k: usize,
pub log_k: usize,
}
impl<B: Backend> RLWECtDft<Vec<u8>, B> {
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
Self {
data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)),
log_base2k: log_base2k,
log_k: log_k,
}
}
}
impl<T, B: Backend> Infos for RLWECtDft<T, B> {
type Inner = VecZnxDft<T, B>;
fn inner(&self) -> &Self::Inner {
&self.data
}
fn log_base2k(&self) -> usize {
self.log_base2k
}
fn log_k(&self) -> usize {
self.log_k
}
}
impl<C, B: Backend> VecZnxDftToMut<B> for RLWECtDft<C, B>
where
VecZnxDft<C, B>: VecZnxDftToMut<B>,
{
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
self.data.to_mut()
}
}
impl<C, B: Backend> VecZnxDftToRef<B> for RLWECtDft<C, B>
where
VecZnxDft<C, B>: VecZnxDftToRef<B>,
{
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
self.data.to_ref()
}
}
impl<C> RLWECtDft<C, FFT64>
where
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64>,
{
#[allow(dead_code)]
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())
}
pub(crate) fn idft<R>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, scratch: &mut Scratch)
where
VecZnx<R>: VecZnxToMut,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.cols(), 2);
assert_eq!(res.cols(), 2);
assert_eq!(self.log_base2k(), res.log_base2k())
}
let min_size: usize = self.size().min(res.size());
let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size);
module.vec_znx_idft(&mut res_big, 0, &self.data, 0, scratch1);
module.vec_znx_idft(&mut res_big, 1, &self.data, 1, scratch1);
module.vec_znx_big_normalize(self.log_base2k(), res, 0, &res_big, 0, scratch1);
module.vec_znx_big_normalize(self.log_base2k(), res, 1, &res_big, 1, scratch1);
}
}
pub(crate) fn encrypt_zero_rlwe_dft_sk<C, S>(
module: &Module<FFT64>,
ct: &mut RLWECtDft<C, FFT64>,
@@ -528,79 +663,81 @@ impl<C> RLWECtDft<C, FFT64> {
{
decrypt_rlwe_dft(module, pt, self, sk_dft, scratch);
}
}
pub fn mul_grlwe_assign<A>(&mut self, module: &Module<FFT64>, a: &GRLWECt<A, FFT64>, scratch: &mut Scratch)
where
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
MatZnxDft<A, FFT64>: MatZnxDftToRef<FFT64>,
{
a.mul_rlwe_dft_inplace(module, self, scratch);
impl ProdByScratchSpace for RLWECtDft<Vec<u8>, FFT64> {
fn prod_by_grlwe_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize {
<GRLWECt<Vec<u8>, FFT64> as MatZnxDftProducts<GRLWECt<Vec<u8>, FFT64>, Vec<u8>>>::mul_rlwe_dft_inplace_scratch_space(
module, lhs, rhs,
)
}
fn prod_by_rgsw_scratch_space(module: &Module<FFT64>, lhs: usize, rhs: usize) -> usize {
<RGSWCt<Vec<u8>, FFT64> as MatZnxDftProducts<RGSWCt<Vec<u8>, FFT64>, Vec<u8>>>::mul_rlwe_dft_inplace_scratch_space(
module, lhs, rhs,
)
}
}
pub(crate) fn encrypt_rlwe_pk<C, P, S>(
module: &Module<FFT64>,
ct: &mut RLWECt<C>,
pt: Option<&RLWEPt<P>>,
pk: &PublicKey<S, FFT64>,
source_xu: &mut Source,
source_xe: &mut Source,
sigma: f64,
bound: f64,
scratch: &mut Scratch,
) where
VecZnx<C>: VecZnxToMut + VecZnxToRef,
VecZnx<P>: VecZnxToRef,
VecZnxDft<S, FFT64>: VecZnxDftToRef<FFT64>,
impl FromProdByScratchSpace for RLWECtDft<Vec<u8>, FFT64> {
fn from_prod_by_grlwe_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
<GRLWECt<Vec<u8>, FFT64> as MatZnxDftProducts<GRLWECt<Vec<u8>, FFT64>, Vec<u8>>>::mul_rlwe_dft_scratch_space(
module, res_size, lhs, rhs,
)
}
fn from_prod_by_rgsw_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
<RGSWCt<Vec<u8>, FFT64> as MatZnxDftProducts<RGSWCt<Vec<u8>, FFT64>, Vec<u8>>>::mul_rlwe_dft_scratch_space(
module, res_size, lhs, rhs,
)
}
}
impl<MUT> ProdBy<RLWECtDft<MUT, FFT64>> for RLWECtDft<MUT, FFT64>
where
VecZnxDft<MUT, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
fn prod_by_grlwe<R>(&mut self, module: &Module<FFT64>, rhs: &GRLWECt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
assert_eq!(ct.log_base2k(), pk.log_base2k());
assert_eq!(ct.n(), module.n());
assert_eq!(pk.n(), module.n());
if let Some(pt) = pt {
assert_eq!(pt.log_base2k(), pk.log_base2k());
assert_eq!(pt.n(), module.n());
}
rhs.mul_rlwe_dft_inplace(module, self, scratch);
}
let log_base2k: usize = pk.log_base2k();
let size_pk: usize = pk.size();
// Generates u according to the underlying secret distribution.
let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1);
fn prod_by_rgsw<R>(&mut self, module: &Module<FFT64>, rhs: &RGSWCt<R, FFT64>, scratch: &mut Scratch)
where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1);
match pk.dist {
SecretDistribution::NONE => panic!(
"invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate"
),
SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu),
SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu),
SecretDistribution::ZERO => {}
rhs.mul_rlwe_dft_inplace(module, self, scratch);
}
}
impl<MUT, REF> FromProdBy<RLWECtDft<MUT, FFT64>, RLWECtDft<REF, FFT64>> for RLWECtDft<MUT, FFT64>
where
VecZnxDft<MUT, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
VecZnxDft<REF, FFT64>: VecZnxDftToRef<FFT64>,
{
fn from_prod_by_grlwe<R>(
&mut self,
module: &Module<FFT64>,
lhs: &RLWECtDft<REF, FFT64>,
rhs: &GRLWECt<R, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.mul_rlwe_dft(module, self, lhs, scratch);
}
fn from_prod_by_rgsw<R>(
&mut self,
module: &Module<FFT64>,
lhs: &RLWECtDft<REF, FFT64>,
rhs: &RGSWCt<R, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<R, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.mul_rlwe_dft(module, self, lhs, scratch);
}
module.svp_prepare(&mut u_dft, 0, &u, 0);
}
let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity)
let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity)
// ct[0] = pk[0] * u + m + e0
module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0);
module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0);
tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound);
if let Some(pt) = pt {
module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0);
}
module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3);
// ct[1] = pk[1] * u + e1
module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1);
module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0);
tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound);
module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3);
}

View File

@@ -1,14 +1,14 @@
#[cfg(test)]
mod tests {
use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps};
use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps};
use sampling::source::Source;
use crate::{
elem::Infos,
elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace},
grlwe::GRLWECt,
keys::{SecretKey, SecretKeyDft},
rlwe::{RLWECt, RLWECtDft, RLWEPt},
rlwe::{RLWECtDft, RLWEPt},
test_fft64::grlwe::noise_grlwe_rlwe_product,
};
@@ -67,413 +67,7 @@ mod tests {
}
#[test]
fn mul_rlwe() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 12;
let log_k_grlwe: usize = 60;
let log_k_rlwe_in: usize = 45;
let log_k_rlwe_out: usize = 60;
let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut ct_grlwe: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
let mut ct_rlwe_in: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_in);
let mut ct_rlwe_out: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_out);
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in);
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size())
| RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size())
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size())
| GRLWECt::mul_rlwe_scratch_space(
&module,
ct_rlwe_out.size(),
ct_rlwe_in.size(),
ct_grlwe.size(),
),
);
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk0.fill_ternary_prob(0.5, &mut source_xs);
let mut sk0_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk0_dft.dft(&module, &sk0);
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk1.fill_ternary_prob(0.5, &mut source_xs);
let mut sk1_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk1_dft.dft(&module, &sk1);
ct_grlwe.encrypt_sk(
&module,
&sk0.data,
&sk1_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe_in.encrypt_sk(
&module,
Some(&pt_want),
&sk0_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_grlwe.mul_rlwe(&module, &mut ct_rlwe_out, &ct_rlwe_in, scratch.borrow());
ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
let noise_want: f64 = noise_grlwe_rlwe_product(
module.n() as f64,
log_base2k,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
log_k_rlwe_in,
log_k_grlwe,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
module.free();
}
#[test]
fn mul_rlwe_inplace() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 12;
let log_k_grlwe: usize = 60;
let log_k_rlwe: usize = 45;
let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut ct_grlwe: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
let mut ct_rlwe: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe);
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe);
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size())
| RLWECt::decrypt_scratch_space(&module, ct_rlwe.size())
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size())
| GRLWECt::mul_rlwe_scratch_space(&module, ct_rlwe.size(), ct_rlwe.size(), ct_grlwe.size()),
);
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk0.fill_ternary_prob(0.5, &mut source_xs);
let mut sk0_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk0_dft.dft(&module, &sk0);
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk1.fill_ternary_prob(0.5, &mut source_xs);
let mut sk1_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk1_dft.dft(&module, &sk1);
ct_grlwe.encrypt_sk(
&module,
&sk0.data,
&sk1_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe.encrypt_sk(
&module,
Some(&pt_want),
&sk0_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_grlwe.mul_rlwe_inplace(&module, &mut ct_rlwe, scratch.borrow());
ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
let noise_want: f64 = noise_grlwe_rlwe_product(
module.n() as f64,
log_base2k,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
log_k_rlwe,
log_k_grlwe,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
module.free();
}
#[test]
fn mul_rlwe_dft() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 12;
let log_k_grlwe: usize = 60;
let log_k_rlwe_in: usize = 45;
let log_k_rlwe_out: usize = 60;
let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut ct_grlwe: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
let mut ct_rlwe_in: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_in);
let mut ct_rlwe_in_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in);
let mut ct_rlwe_out: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_out);
let mut ct_rlwe_out_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out);
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in);
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size())
| RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size())
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size())
| GRLWECt::mul_rlwe_scratch_space(
&module,
ct_rlwe_out.size(),
ct_rlwe_in.size(),
ct_grlwe.size(),
),
);
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk0.fill_ternary_prob(0.5, &mut source_xs);
let mut sk0_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk0_dft.dft(&module, &sk0);
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk1.fill_ternary_prob(0.5, &mut source_xs);
let mut sk1_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk1_dft.dft(&module, &sk1);
ct_grlwe.encrypt_sk(
&module,
&sk0.data,
&sk1_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe_in.encrypt_sk(
&module,
Some(&pt_want),
&sk0_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft);
ct_grlwe.mul_rlwe_dft(
&module,
&mut ct_rlwe_out_dft,
&ct_rlwe_in_dft,
scratch.borrow(),
);
ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow());
ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
let noise_want: f64 = noise_grlwe_rlwe_product(
module.n() as f64,
log_base2k,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
log_k_rlwe_in,
log_k_grlwe,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
module.free();
}
#[test]
fn mul_rlwe_dft_inplace() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 12;
let log_k_grlwe: usize = 60;
let log_k_rlwe: usize = 45;
let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut ct_grlwe: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
let mut ct_rlwe: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe);
let mut ct_rlwe_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe);
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe);
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size())
| RLWECt::decrypt_scratch_space(&module, ct_rlwe.size())
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size())
| GRLWECt::mul_rlwe_scratch_space(&module, ct_rlwe.size(), ct_rlwe.size(), ct_grlwe.size()),
);
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk0.fill_ternary_prob(0.5, &mut source_xs);
let mut sk0_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk0_dft.dft(&module, &sk0);
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk1.fill_ternary_prob(0.5, &mut source_xs);
let mut sk1_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk1_dft.dft(&module, &sk1);
ct_grlwe.encrypt_sk(
&module,
&sk0.data,
&sk1_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe.encrypt_sk(
&module,
Some(&pt_want),
&sk0_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe.dft(&module, &mut ct_rlwe_dft);
ct_grlwe.mul_rlwe_dft_inplace(&module, &mut ct_rlwe_dft, scratch.borrow());
ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow());
ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
let noise_want: f64 = noise_grlwe_rlwe_product(
module.n() as f64,
log_base2k,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
log_k_rlwe,
log_k_grlwe,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
module.free();
}
#[test]
fn mul_grlwe() {
fn from_prod_by_grlwe() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 12;
let log_k_grlwe: usize = 60;
@@ -493,7 +87,7 @@ mod tests {
let mut scratch: ScratchOwned = ScratchOwned::new(
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size())
| RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s2.size())
| GRLWECt::mul_grlwe_scratch_space(
| GRLWECt::from_prod_by_grlwe_scratch_space(
&module,
ct_grlwe_s0s2.size(),
ct_grlwe_s0s1.size(),
@@ -544,12 +138,7 @@ mod tests {
);
// GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0)
ct_grlwe_s1s2.mul_grlwe(
&module,
&mut ct_grlwe_s0s2,
&ct_grlwe_s0s1,
scratch.borrow(),
);
ct_grlwe_s0s2.from_prod_by_grlwe(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow());
let mut ct_rlwe_dft_s0s2: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe);
let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_grlwe);
@@ -584,7 +173,7 @@ mod tests {
}
#[test]
fn mul_grlwe_inplace() {
fn prod_by_grlwe() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 12;
let log_k_grlwe: usize = 60;
@@ -603,12 +192,7 @@ mod tests {
let mut scratch: ScratchOwned = ScratchOwned::new(
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size())
| RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s1.size())
| GRLWECt::mul_grlwe_scratch_space(
&module,
ct_grlwe_s0s1.size(),
ct_grlwe_s0s1.size(),
ct_grlwe_s1s2.size(),
),
| GRLWECt::prod_by_grlwe_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()),
);
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
@@ -654,7 +238,7 @@ mod tests {
);
// GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0)
ct_grlwe_s1s2.mul_grlwe_inplace(&module, &mut ct_grlwe_s0s1, scratch.borrow());
ct_grlwe_s0s1.prod_by_grlwe(&module, &ct_grlwe_s1s2, scratch.borrow());
let ct_grlwe_s0s2: GRLWECt<Vec<u8>, FFT64> = ct_grlwe_s0s1;

View File

@@ -1,3 +1,4 @@
mod grlwe;
mod rgsw;
mod rlwe;
mod rlwe_dft;

View File

@@ -2,7 +2,7 @@
mod tests {
use base2k::{
FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps,
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero,
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero,
};
use sampling::source::Source;
@@ -86,120 +86,6 @@ mod tests {
module.free();
}
#[test]
fn mul_rlwe() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 12;
let log_k_grlwe: usize = 60;
let log_k_rlwe_in: usize = 45;
let log_k_rlwe_out: usize = 60;
let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut ct_rgsw: RGSWCt<Vec<u8>, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows);
let mut ct_rlwe_in: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_in);
let mut ct_rlwe_out: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_out);
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in);
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
// pt_want
// .data
// .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
pt_want.to_mut().at_mut(0, 0)[1] = 1;
let k: usize = 1;
pt_rgsw.raw_mut()[k] = 1; // X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new(
RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size())
| RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size())
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size())
| RGSWCt::mul_rlwe_scratch_space(
&module,
ct_rlwe_out.size(),
ct_rlwe_in.size(),
ct_rgsw.size(),
),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk_dft.dft(&module, &sk);
ct_rgsw.encrypt_sk(
&module,
&pt_rgsw,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe_in.encrypt_sk(
&module,
Some(&pt_want),
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rgsw.mul_rlwe(&module, &mut ct_rlwe_out, &ct_rlwe_in, scratch.borrow());
ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
let var_gct_err_lhs: f64 = sigma * sigma;
let var_gct_err_rhs: f64 = 0f64;
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_rgsw_rlwe_product(
module.n() as f64,
log_base2k,
0.5,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
log_k_rlwe_in,
log_k_grlwe,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
module.free();
}
}
#[allow(dead_code)]

View File

@@ -1,13 +1,19 @@
#[cfg(test)]
mod tests {
use base2k::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero};
mod tests_rlwe {
use base2k::{
Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut,
ZnxViewMut, ZnxZero,
};
use itertools::izip;
use sampling::source::Source;
use crate::{
elem::Infos,
elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace},
grlwe::GRLWECt,
keys::{PublicKey, SecretKey, SecretKeyDft},
rgsw::RGSWCt,
rlwe::{RLWECt, RLWECtDft, RLWEPt},
test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_rlwe_product},
};
#[test]
@@ -193,4 +199,423 @@ mod tests {
module.free();
}
#[test]
fn from_prod_by_grlwe() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 12;
let log_k_grlwe: usize = 60;
let log_k_rlwe_in: usize = 45;
let log_k_rlwe_out: usize = 60;
let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut ct_grlwe: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
let mut ct_rlwe_in: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_in);
let mut ct_rlwe_out: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_out);
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in);
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size())
| RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size())
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size())
| RLWECt::from_prod_by_grlwe_scratch_space(
&module,
ct_rlwe_out.size(),
ct_rlwe_in.size(),
ct_grlwe.size(),
),
);
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk0.fill_ternary_prob(0.5, &mut source_xs);
let mut sk0_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk0_dft.dft(&module, &sk0);
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk1.fill_ternary_prob(0.5, &mut source_xs);
let mut sk1_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk1_dft.dft(&module, &sk1);
ct_grlwe.encrypt_sk(
&module,
&sk0.data,
&sk1_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe_in.encrypt_sk(
&module,
Some(&pt_want),
&sk0_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe_out.from_prod_by_grlwe(&module, &ct_rlwe_in, &ct_grlwe, scratch.borrow());
ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
let noise_want: f64 = noise_grlwe_rlwe_product(
module.n() as f64,
log_base2k,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
log_k_rlwe_in,
log_k_grlwe,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
module.free();
}
#[test]
fn prod_grlwe() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 12;
let log_k_grlwe: usize = 60;
let log_k_rlwe: usize = 45;
let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut ct_grlwe: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
let mut ct_rlwe: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe);
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe);
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size())
| RLWECt::decrypt_scratch_space(&module, ct_rlwe.size())
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size())
| RLWECt::prod_by_grlwe_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()),
);
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk0.fill_ternary_prob(0.5, &mut source_xs);
let mut sk0_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk0_dft.dft(&module, &sk0);
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk1.fill_ternary_prob(0.5, &mut source_xs);
let mut sk1_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk1_dft.dft(&module, &sk1);
ct_grlwe.encrypt_sk(
&module,
&sk0.data,
&sk1_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe.encrypt_sk(
&module,
Some(&pt_want),
&sk0_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe.prod_by_grlwe(&module, &ct_grlwe, scratch.borrow());
ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
let noise_want: f64 = noise_grlwe_rlwe_product(
module.n() as f64,
log_base2k,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
log_k_rlwe,
log_k_grlwe,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
module.free();
}
#[test]
fn from_prod_by_rgsw() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 12;
let log_k_grlwe: usize = 60;
let log_k_rlwe_in: usize = 45;
let log_k_rlwe_out: usize = 60;
let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut ct_rgsw: RGSWCt<Vec<u8>, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows);
let mut ct_rlwe_in: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_in);
let mut ct_rlwe_out: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_out);
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in);
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
pt_want.to_mut().at_mut(0, 0)[1] = 1;
let k: usize = 1;
pt_rgsw.raw_mut()[k] = 1; // X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new(
RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size())
| RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size())
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size())
| RLWECt::from_prod_by_rgsw_scratch_space(
&module,
ct_rlwe_out.size(),
ct_rlwe_in.size(),
ct_rgsw.size(),
),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk_dft.dft(&module, &sk);
ct_rgsw.encrypt_sk(
&module,
&pt_rgsw,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe_in.encrypt_sk(
&module,
Some(&pt_want),
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe_out.from_prod_by_rgsw(&module, &ct_rlwe_in, &ct_rgsw, scratch.borrow());
ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
let var_gct_err_lhs: f64 = sigma * sigma;
let var_gct_err_rhs: f64 = 0f64;
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_rgsw_rlwe_product(
module.n() as f64,
log_base2k,
0.5,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
log_k_rlwe_in,
log_k_grlwe,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
module.free();
}
#[test]
fn prod_by_rgsw() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 12;
let log_k_grlwe: usize = 60;
let log_k_rlwe_in: usize = 45;
let log_k_rlwe_out: usize = 60;
let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut ct_rgsw: RGSWCt<Vec<u8>, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows);
let mut ct_rlwe: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_in);
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in);
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
pt_want.to_mut().at_mut(0, 0)[1] = 1;
let k: usize = 1;
pt_rgsw.raw_mut()[k] = 1; // X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new(
RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size())
| RLWECt::decrypt_scratch_space(&module, ct_rlwe.size())
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size())
| RLWECt::prod_by_rgsw_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk_dft.dft(&module, &sk);
ct_rgsw.encrypt_sk(
&module,
&pt_rgsw,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe.encrypt_sk(
&module,
Some(&pt_want),
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe.prod_by_rgsw(&module, &ct_rgsw, scratch.borrow());
ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
let var_gct_err_lhs: f64 = sigma * sigma;
let var_gct_err_rhs: f64 = 0f64;
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_rgsw_rlwe_product(
module.n() as f64,
log_base2k,
0.5,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
log_k_rlwe_in,
log_k_grlwe,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
module.free();
}
}

View File

@@ -0,0 +1,216 @@
#[cfg(test)]
mod tests {
use crate::{
elem::{FromProdBy, FromProdByScratchSpace, Infos, ProdBy, ProdByScratchSpace},
grlwe::GRLWECt,
keys::{SecretKey, SecretKeyDft},
rlwe::{RLWECt, RLWECtDft, RLWEPt},
test_fft64::grlwe::noise_grlwe_rlwe_product,
};
use base2k::{FFT64, FillUniform, Module, ScratchOwned, Stats, VecZnxOps};
use sampling::source::Source;
#[test]
fn from_prod_by_grlwe() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 12;
let log_k_grlwe: usize = 60;
let log_k_rlwe_in: usize = 45;
let log_k_rlwe_out: usize = 60;
let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut ct_grlwe: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
let mut ct_rlwe_in: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_in);
let mut ct_rlwe_in_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in);
let mut ct_rlwe_out: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_out);
let mut ct_rlwe_out_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out);
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in);
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size())
| RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size())
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size())
| RLWECtDft::from_prod_by_grlwe_scratch_space(
&module,
ct_rlwe_out.size(),
ct_rlwe_in.size(),
ct_grlwe.size(),
),
);
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk0.fill_ternary_prob(0.5, &mut source_xs);
let mut sk0_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk0_dft.dft(&module, &sk0);
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk1.fill_ternary_prob(0.5, &mut source_xs);
let mut sk1_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk1_dft.dft(&module, &sk1);
ct_grlwe.encrypt_sk(
&module,
&sk0.data,
&sk1_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe_in.encrypt_sk(
&module,
Some(&pt_want),
&sk0_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft);
ct_rlwe_out_dft.from_prod_by_grlwe(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow());
ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow());
ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
let noise_want: f64 = noise_grlwe_rlwe_product(
module.n() as f64,
log_base2k,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
log_k_rlwe_in,
log_k_grlwe,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
module.free();
}
#[test]
fn prod_by_grlwe() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 12;
let log_k_grlwe: usize = 60;
let log_k_rlwe: usize = 45;
let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut ct_grlwe: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
let mut ct_rlwe: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe);
let mut ct_rlwe_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe);
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe);
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size())
| RLWECt::decrypt_scratch_space(&module, ct_rlwe.size())
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size())
| RLWECtDft::prod_by_grlwe_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()),
);
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk0.fill_ternary_prob(0.5, &mut source_xs);
let mut sk0_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk0_dft.dft(&module, &sk0);
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk1.fill_ternary_prob(0.5, &mut source_xs);
let mut sk1_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk1_dft.dft(&module, &sk1);
ct_grlwe.encrypt_sk(
&module,
&sk0.data,
&sk1_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe.encrypt_sk(
&module,
Some(&pt_want),
&sk0_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_rlwe.dft(&module, &mut ct_rlwe_dft);
ct_rlwe_dft.prod_by_grlwe(&module, &ct_grlwe, scratch.borrow());
ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow());
ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
let noise_want: f64 = noise_grlwe_rlwe_product(
module.n() as f64,
log_base2k,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
log_k_rlwe,
log_k_grlwe,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
module.free();
}
}